diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R
index 6bbc31f6..9cdd02aa 100644
--- a/R/Lrnr_hal9001.R
+++ b/R/Lrnr_hal9001.R
@@ -1,137 +1,115 @@
-#' The Scalable Highly Adaptive Lasso
+#' Scalable Highly Adaptive Lasso (HAL)
 #'
-#' The Highly Adaptive Lasso is an estimation procedure that generates a design
-#' matrix consisting of basis functions corresponding to covariates and
-#' interactions of covariates and fits Lasso regression to this (usually) very
-#' wide matrix, recovering a nonparametric functional form that describes the
-#' target prediction function as a composition of subset functions with finite
-#' variation norm. This implementation uses \pkg{hal9001}, which provides both
-#' a custom implementation (based on \pkg{origami}) of the cross-validated
-#' lasso as well the standard call to \code{\link[glmnet]{cv.glmnet}} from the
-#' \pkg{glmnet}.
+#' The Highly Adaptive Lasso (HAL) is a nonparametric regression estimator
+#' that has been demonstrated to optimally estimate functions with bounded
+#' (finite) variation norm. The algorithm proceeds by first building an
+#' adaptive basis (i.e., the HAL basis) of indicator basis functions (or
+#' higher-order spline basis functions) representing covariates and
+#' interactions of the covariates up to a pre-specified degree. The fitting
+#' procedures in this learner use \code{\link[hal9001]{fit_hal}} from the
+#' \pkg{hal9001} package. For details on HAL regression, see
+#' \insertCite{benkeser2016hal;textual}{sl3},
+#' \insertCite{coyle2020hal9001-rpkg;textual}{sl3}, and
+#' \insertCite{hejazi2020hal9001-joss;textual}{sl3}.
 #'
 #' @docType class
+#'
 #' @importFrom R6 R6Class
+#' @importFrom origami folds2foldvec
+#' @importFrom stats predict quasibinomial
 #'
 #' @export
 #'
 #' @keywords data
 #'
-#' @return Learner object with methods for training and prediction. See
-#' \code{\link{Lrnr_base}} for documentation on learners.
+#' @return A learner object inheriting from \code{\link{Lrnr_base}} with
+#' methods for training and prediction. For a full list of learner
+#' functionality, see the complete documentation of \code{\link{Lrnr_base}}.
 #'
-#' @format \code{\link{R6Class}} object.
+#' @format An \code{\link[R6]{R6Class}} object inheriting from
+#' \code{\link{Lrnr_base}}.
 #'
 #' @family Learners
 #'
 #' @section Parameters:
-#' \describe{
-#' \item{\code{max_degree=3}}{ The highest order of interaction
-#' terms for which the basis functions ought to be generated. The default
-#' corresponds to generating basis functions up to all 3-way interactions of
-#' covariates in the input matrix, matching the default in \pkg{hal9001}.
-#' }
-#' \item{\code{fit_type="glmnet"}}{The specific routine to be called when
-#' fitting the Lasso regression in a cross-validated manner. Choosing the
-#' \code{"glmnet"} option calls either \code{\link[glmnet]{cv.glmnet}} or
-#' \code{\link[glmnet]{glmnet}}.
-#' }
-#' \item{\code{n_folds=10}}{Integer for the number of folds to be used
-#' when splitting the data for cross-validation. This defaults to 10 as this
-#' is the convention for V-fold cross-validation.
-#' }
-#' \item{\code{use_min=TRUE}}{Determines which lambda is selected from
-#' \code{\link[glmnet]{cv.glmnet}}. \code{TRUE} corresponds to
-#' \code{"lambda.min"} and \code{FALSE} corresponds to \code{"lambda.1se"}.
-#' }
-#' \item{\code{reduce_basis=NULL}}{A \code{numeric} value bounded in the open
-#' interval (0,1) indicating the minimum proportion of ones in a basis
-#' function column needed for the basis function to be included in the
-#' procedure to fit the Lasso. Any basis functions with a lower proportion
-#' of 1's than the specified cutoff will be removed. This argument defaults
-#' to \code{NULL}, in which case all basis functions are used in the Lasso
-#' stage of HAL.
-#' }
-#' \item{\code{return_lasso=TRUE}}{A \code{logical} indicating whether or not
-#' to return the \code{\link[glmnet]{glmnet}} fit of the Lasso model.
-#' }
-#' \item{\code{return_x_basis=FALSE}}{A \code{logical} indicating whether or
-#' not to return the matrix of (possibly reduced) basis functions used in
-#' the HAL Lasso fit.
-#' }
-#' \item{\code{basis_list=NULL}}{The full set of basis functions generated
-#' from the input data (from \code{\link[hal9001]{enumerate_basis}}). The
-#' dimensionality of this structure is roughly (n * 2^(d - 1)), where n is
-#' the number of observations and d is the number of columns in the input.
-#' }
-#' \item{\code{cv_select=TRUE}}{A \code{logical} specifying whether the array
-#' of values specified should be passed to \code{\link[glmnet]{cv.glmnet}}
-#' in order to pick the optimal value (based on cross-validation) (when set
-#' to \code{TRUE}) or to fit along the sequence of values (or a single value
-#' using \code{\link[glmnet]{glmnet}} (when set to \code{FALSE}).
-#' }
-#' \item{\code{...}}{Other parameters passed directly to
-#' \code{\link[hal9001]{fit_hal}}. See its documentation for details.
-#' }
-#' }
-#
+#'   - \code{...}: Arguments passed to \code{\link[hal9001]{fit_hal}}. See
+#'     its documentation for details.
+#'
+#' @examples
+#' data(cpp_imputed)
+#' covs <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs")
+#' task <- sl3_Task$new(cpp_imputed, covariates = covs, outcome = "haz")
+#'
+#' # instantiate with max 2-way interactions, 0-order splines, and binning
+#' # (i.e., num_knots) that decreases with increasing interaction degree
+#' hal_lrnr <- Lrnr_hal9001$new(
+#'   max_degree = 2, num_knots = c(20, 10), smoothness_orders = 0
+#' )
+#' hal_fit <- hal_lrnr$train(task)
+#' hal_preds <- hal_fit$predict()
 Lrnr_hal9001 <- R6Class(
-  classname = "Lrnr_hal9001", inherit = Lrnr_base,
-  portable = TRUE, class = TRUE,
+  classname = "Lrnr_hal9001",
+  inherit = Lrnr_base, portable = TRUE, class = TRUE,
   public = list(
-    initialize = function(max_degree = 3,
-                          fit_type = "glmnet",
-                          n_folds = 10,
-                          use_min = TRUE,
-                          reduce_basis = NULL,
-                          return_lasso = TRUE,
-                          return_x_basis = FALSE,
-                          basis_list = NULL,
-                          cv_select = TRUE,
-                          ...) {
+    initialize = function(...) {
      params <- args_to_list()
      super$initialize(params = params, ...)
    }
  ),
  private = list(
    .properties = c("continuous", "binomial", "weights", "ids"),
-
    .train = function(task) {
      args <- self$params
+      args$X <- as.matrix(task$X)
+      outcome_type <- self$get_outcome_type(task)
+      args$Y <- outcome_type$format(task$Y)
      if (is.null(args$family)) {
-        args$family <- args$family <- outcome_type$glm_family()
+        args$family <- outcome_type$glm_family()
      }
-      args$X <- as.matrix(task$X)
-      args$Y <- outcome_type$format(task$Y)
-      args$yolo <- FALSE
+      if (!any(grepl("fit_control", names(args)))) {
+        args$fit_control <- list()
+      }
+      args$fit_control$foldid <- origami::folds2foldvec(task$folds)
+
+      if (task$has_node("id")) {
+        args$id <- task$id
+      }
      if (task$has_node("weights")) {
-        args$weights <- task$weights
+        args$fit_control$weights <- task$weights
      }
      if (task$has_node("offset")) {
        args$offset <- task$offset
      }
-      if (task$has_node("id")) {
-        args$id <- task$id
-      }
+      # fit HAL, allowing glmnet-fitting arguments
+      other_valid <- c(
+        names(formals(glmnet::cv.glmnet)), names(formals(glmnet::glmnet))
+      )
+
+      fit_object <- call_with_args(
+        hal9001::fit_hal, args,
+        other_valid = other_valid
+      )
-      fit_object <- call_with_args(hal9001::fit_hal, args)
      return(fit_object)
    },
    .predict = function(task = NULL) {
-      predictions <- predict(self$fit_object, new_data = as.matrix(task$X))
+      predictions <- stats::predict(
+        self$fit_object,
+        new_data = data.matrix(task$X)
+      )
      if (!is.na(safe_dim(predictions)[2])) {
        p <- ncol(predictions)
        colnames(predictions) <- sprintf("lambda_%0.3e", self$params$lambda)
      }
      return(predictions)
    },
-    .required_packages = c("hal9001")
+    .required_packages = c("hal9001", "glmnet")
  )
)
diff --git a/R/survival_utils.R b/R/survival_utils.R
index dd11429c..65fb0dec 100644
--- a/R/survival_utils.R
+++ b/R/survival_utils.R
@@ -32,11 +32,21 @@ pooled_hazard_task <- function(task, trim = TRUE) {
   repeated_data <- underlying_data[index, ]
   new_folds <- origami::id_folds_to_folds(task$folds, index)

-  repeated_task <- task$next_in_chain(
-    column_names = column_names,
-    data = repeated_data, id = "id",
-    folds = new_folds
-  )
+  nodes <- task$nodes
+  nodes$id <- "id"
+  repeated_task <- sl3_Task$new(
+    repeated_data,
+    column_names = column_names, nodes = nodes, folds = new_folds,
+    outcome_levels = outcome_levels, outcome_type = task$outcome_type$type
+  )
+  # The next_in_chain() approach below fails when "task" has a non-null
+  # row_index: next_in_chain() does not reset the row_index when data is
+  # passed in, so CV learners and pooled hazards break.
+  # repeated_task <- task$next_in_chain(
+  #   column_names = column_names,
+  #   data = repeated_data, id = "id",
+  #   folds = new_folds, row_index = NULL
+  # )

   # make bin indicators
   bin_number <- rep(level_index, each = task$nrow)
diff --git a/man/Lrnr_hal9001.Rd b/man/Lrnr_hal9001.Rd
index f66d40d4..9d8c8203 100644
--- a/man/Lrnr_hal9001.Rd
+++ b/man/Lrnr_hal9001.Rd
@@ -69,6 +69,8 @@
 in order to pick the optimal value (based on cross-validation) (when set
 to \code{TRUE}) or to fit along the sequence of values (or a single value
 using \code{\link[glmnet]{glmnet}} (when set to \code{FALSE}).
 }
+\item{\code{squash=TRUE}}{A \code{logical} specifying whether to call \code{\link[hal9001]{squash_hal_fit}} on the returned \pkg{hal9001} fit object.
+}
 \item{\code{...}}{Other parameters passed directly to
 \code{\link[hal9001]{fit_hal}}. See its documentation for details.
 }
diff --git a/vignettes/testing.Rmd b/vignettes/testing.Rmd
new file mode 100644
index 00000000..2bc812f5
--- /dev/null
+++ b/vignettes/testing.Rmd
@@ -0,0 +1,109 @@
+---
+title: "Testing"
+output: html_document
+---
+
+```{r setup, include=FALSE}
+knitr::opts_chunk$set(echo = TRUE)
+```
+
+```{r}
+library(sl3)
+library(simcausal)
+```
+
+```{r}
+n <- 2500
+D <- DAG.empty()
+D <- D +
+  node("W1f", distr = "runif", min = -1, max = 1) +
+  node("W2f", distr = "runif", min = -1, max = 1) +
+  node("W3f", distr = "runif", min = -1, max = 1) +
+  node("W1", distr = "rconst", const = W1f) +
+  node("W2", distr = "rconst", const = W2f) +
+  node("W3", distr = "rconst", const = W3f) +
+  node("g", distr = "rconst", const = 0.2 + 0.65 * plogis(
+    sin(W1 * 5) + W1 * sin(W1 * 5) + cos(W2 * 5) + 2 * W1 * W2 -
+      sin(W3 * 5) + sin(5 * W1 * W3) + 2 * W1 * W2 * W3 +
+      W3 * sin(W1 * 5) + cos(W2 * 4) * sin(W1 * 5)
+  )) +
+  node("A", distr = "rbinom", size = 1, prob = g) +
+  node("gR", distr = "rconst", const = 2 * (W1 + W2 + W3) +
+    A * (W1 + W2 + W3 + W1 * W2 + W2 * W3 + W1 * W3) +
+    W1 * W2 + W2 * W3 + W1 * W3 + W1^2 - W2^2 + W3^2) +
+  node("R", distr = "rnorm", mean = gR, sd = 1)
+
+setD <- set.DAG(D)
+data <- sim(setD, n = n)
+head(data)
+```
+
+```{r}
+task <- sl3_Task$new(data, covariates = c("W1", "W2", "W3", "A"), outcome = "R")
+
+lrnr_ranger <- Lrnr_ranger$new(num.trees = 50, predict.all = TRUE)
+lrnr_ranger <- lrnr_ranger$train(task)
+data.table::as.data.table(lrnr_ranger$predict(task))
+```
+
+```{r}
+lrnr_xgboost <- Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE)
+lrnr_xgboost <- lrnr_xgboost$train(task)
+data.table::as.data.table(lrnr_xgboost$predict(task))
+
+lrnr_xgboost_stacked <- make_learner(
+  Pipeline,
+  Lrnr_cv$new(Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE)),
+  Lrnr_nnls$new(convex = FALSE)
+)
+lrnr_xgboost_stacked <- lrnr_xgboost_stacked$train(task)
+data.table::as.data.table(lrnr_xgboost_stacked$predict(task))
+```
+
+```{r}
+lrnr_xg_stack <- make_learner(
+  Stack,
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE, max_depth = 3),
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE, max_depth = 5),
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE, max_depth = 7),
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE, max_depth = 10)
+)
+lrnr_xg_stack <- Lrnr_sl$new(lrnr_xg_stack, metalearner = Lrnr_nnls$new(convex = FALSE))
+```
+
+```{r}
+lrnr_stack <- make_learner(
+  Stack,
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 3),
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 5),
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 7),
+  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 10),
+  lrnr_xg_stack
+)
+lrnr_stack <- lrnr_stack$train(task)
+preds <- lrnr_stack$predict(task)
+# mean squared error of each learner against the true conditional mean, gR
+as.data.frame(apply(preds - data$gR, 2, function(v) mean(v^2)))
+```
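+
+The learners in the stack can also be compared on their cross-validated risk;
+a sketch, assuming `loss_squared_error` is the appropriate loss for the
+continuous outcome `R`:
+
+```{r}
+lrnr_cv <- Lrnr_cv$new(lrnr_stack)
+lrnr_cv <- lrnr_cv$train(task)
+lrnr_cv$cv_risk(loss_squared_error)
+```
+
+`Lrnr_hal9001` can be smoke-tested on the same task; a sketch, reusing the
+settings from the learner's documentation examples:
+
+```{r}
+hal_lrnr <- Lrnr_hal9001$new(
+  max_degree = 2, num_knots = c(20, 10), smoothness_orders = 0
+)
+hal_fit <- hal_lrnr$train(task)
+# MSE of HAL predictions against the true conditional mean, gR
+mean((hal_fit$predict(task) - data$gR)^2)
+```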