Fix pooled_hazard_task bug #333

Open · wants to merge 10 commits into base: devel
Changes from all commits
152 changes: 65 additions & 87 deletions R/Lrnr_hal9001.R
@@ -1,137 +1,115 @@
-#' The Scalable Highly Adaptive Lasso
+#' Scalable Highly Adaptive Lasso (HAL)
#'
-#' The Highly Adaptive Lasso is an estimation procedure that generates a design
-#' matrix consisting of basis functions corresponding to covariates and
-#' interactions of covariates and fits Lasso regression to this (usually) very
-#' wide matrix, recovering a nonparametric functional form that describes the
-#' target prediction function as a composition of subset functions with finite
-#' variation norm. This implementation uses \pkg{hal9001}, which provides both
-#' a custom implementation (based on \pkg{origami}) of the cross-validated
-#' lasso as well the standard call to \code{\link[glmnet]{cv.glmnet}} from the
-#' \pkg{glmnet}.
+#' The Highly Adaptive Lasso (HAL) is a nonparametric regression estimator
+#' that has been demonstrated to optimally estimate functions with bounded
+#' (finite) variation norm. The algorithm proceeds by first building an
+#' adaptive basis (i.e., the HAL basis) based on indicator basis functions
+#' (or higher-order spline basis functions) representing covariates and
+#' interactions of the covariates up to a pre-specified degree. The fitting
+#' procedures included in this learner use \code{\link[hal9001]{fit_hal}} from
+#' the \pkg{hal9001} package. For details on HAL regression, consult
+#' \insertCite{benkeser2016hal;textual}{sl3},
+#' \insertCite{coyle2020hal9001-rpkg;textual}{sl3}, and
+#' \insertCite{hejazi2020hal9001-joss;textual}{sl3}.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom origami folds2foldvec
#' @importFrom stats predict quasibinomial
#'
#' @export
#'
#' @keywords data
#'
-#' @return Learner object with methods for training and prediction. See
-#' \code{\link{Lrnr_base}} for documentation on learners.
+#' @return A learner object inheriting from \code{\link{Lrnr_base}} with
+#' methods for training and prediction. For a full list of learner
+#' functionality, see the complete documentation of \code{\link{Lrnr_base}}.
#'
-#' @format \code{\link{R6Class}} object.
+#' @format An \code{\link[R6]{R6Class}} object inheriting from
+#' \code{\link{Lrnr_base}}.
#'
#' @family Learners
#'
#' @section Parameters:
-#' \describe{
-#' \item{\code{max_degree=3}}{ The highest order of interaction
-#' terms for which the basis functions ought to be generated. The default
-#' corresponds to generating basis functions up to all 3-way interactions of
-#' covariates in the input matrix, matching the default in \pkg{hal9001}.
-#' }
-#' \item{\code{fit_type="glmnet"}}{The specific routine to be called when
-#' fitting the Lasso regression in a cross-validated manner. Choosing the
-#' \code{"glmnet"} option calls either \code{\link[glmnet]{cv.glmnet}} or
-#' \code{\link[glmnet]{glmnet}}.
-#' }
-#' \item{\code{n_folds=10}}{Integer for the number of folds to be used
-#' when splitting the data for cross-validation. This defaults to 10 as this
-#' is the convention for V-fold cross-validation.
-#' }
-#' \item{\code{use_min=TRUE}}{Determines which lambda is selected from
-#' \code{\link[glmnet]{cv.glmnet}}. \code{TRUE} corresponds to
-#' \code{"lambda.min"} and \code{FALSE} corresponds to \code{"lambda.1se"}.
-#' }
-#' \item{\code{reduce_basis=NULL}}{A \code{numeric} value bounded in the open
-#' interval (0,1) indicating the minimum proportion of ones in a basis
-#' function column needed for the basis function to be included in the
-#' procedure to fit the Lasso. Any basis functions with a lower proportion
-#' of 1's than the specified cutoff will be removed. This argument defaults
-#' to \code{NULL}, in which case all basis functions are used in the Lasso
-#' stage of HAL.
-#' }
-#' \item{\code{return_lasso=TRUE}}{A \code{logical} indicating whether or not
-#' to return the \code{\link[glmnet]{glmnet}} fit of the Lasso model.
-#' }
-#' \item{\code{return_x_basis=FALSE}}{A \code{logical} indicating whether or
-#' not to return the matrix of (possibly reduced) basis functions used in
-#' the HAL Lasso fit.
-#' }
-#' \item{\code{basis_list=NULL}}{The full set of basis functions generated
-#' from the input data (from \code{\link[hal9001]{enumerate_basis}}). The
-#' dimensionality of this structure is roughly (n * 2^(d - 1)), where n is
-#' the number of observations and d is the number of columns in the input.
-#' }
-#' \item{\code{cv_select=TRUE}}{A \code{logical} specifying whether the array
-#' of values specified should be passed to \code{\link[glmnet]{cv.glmnet}}
-#' in order to pick the optimal value (based on cross-validation) (when set
-#' to \code{TRUE}) or to fit along the sequence of values (or a single value)
-#' using \code{\link[glmnet]{glmnet}} (when set to \code{FALSE}).
-#' }
-#' \item{\code{...}}{Other parameters passed directly to
-#' \code{\link[hal9001]{fit_hal}}. See its documentation for details.
-#' }
-#' }
+#'   - \code{...}: Arguments passed to \code{\link[hal9001]{fit_hal}}. See
+#'     its documentation for details.
#'
+#' @examples
+#' data(cpp_imputed)
+#' covs <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs")
+#' task <- sl3_Task$new(cpp_imputed, covariates = covs, outcome = "haz")
+#'
+#' # instantiate with max 2-way interactions, 0-order splines, and binning
+#' # (i.e., num_knots) that decreases with increasing interaction degree
+#' hal_lrnr <- Lrnr_hal9001$new(
+#'   max_degree = 2, num_knots = c(20, 10), smoothness_orders = 0
+#' )
+#' hal_fit <- hal_lrnr$train(task)
+#' hal_preds <- hal_fit$predict()
Lrnr_hal9001 <- R6Class(
-  classname = "Lrnr_hal9001", inherit = Lrnr_base,
-  portable = TRUE, class = TRUE,
+  classname = "Lrnr_hal9001",
+  inherit = Lrnr_base, portable = TRUE, class = TRUE,
public = list(
-    initialize = function(max_degree = 3,
-                          fit_type = "glmnet",
-                          n_folds = 10,
-                          use_min = TRUE,
-                          reduce_basis = NULL,
-                          return_lasso = TRUE,
-                          return_x_basis = FALSE,
-                          basis_list = NULL,
-                          cv_select = TRUE,
-                          ...) {
+    initialize = function(...) {
params <- args_to_list()
super$initialize(params = params, ...)
}
),
private = list(
.properties = c("continuous", "binomial", "weights", "ids"),

.train = function(task) {
args <- self$params

args$X <- as.matrix(task$X)

outcome_type <- self$get_outcome_type(task)
args$Y <- outcome_type$format(task$Y)

if (is.null(args$family)) {
-        args$family <- args$family <- outcome_type$glm_family()
+        args$family <- outcome_type$glm_family()
}

-      args$X <- as.matrix(task$X)
-      args$Y <- outcome_type$format(task$Y)
-      args$yolo <- FALSE
+      if (!any(grepl("fit_control", names(args)))) {
+        args$fit_control <- list()
+      }
+      args$fit_control$foldid <- origami::folds2foldvec(task$folds)

if (task$has_node("id")) {
args$id <- task$id
}

if (task$has_node("weights")) {
-        args$weights <- task$weights
+        args$fit_control$weights <- task$weights
}

if (task$has_node("offset")) {
args$offset <- task$offset
}

if (task$has_node("id")) {
args$id <- task$id
}
+      # fit HAL, allowing glmnet-fitting arguments
+      other_valid <- c(
+        names(formals(glmnet::cv.glmnet)), names(formals(glmnet::glmnet))
+      )
+
+      fit_object <- call_with_args(
+        hal9001::fit_hal, args,
+        other_valid = other_valid
+      )

-      fit_object <- call_with_args(hal9001::fit_hal, args)
return(fit_object)
},
.predict = function(task = NULL) {
-      predictions <- predict(self$fit_object, new_data = as.matrix(task$X))
+      predictions <- stats::predict(
+        self$fit_object,
+        new_data = data.matrix(task$X)
+      )
if (!is.na(safe_dim(predictions)[2])) {
p <- ncol(predictions)
colnames(predictions) <- sprintf("lambda_%0.3e", self$params$lambda)
}
return(predictions)
},
.required_packages = c("hal9001")
.required_packages = c("hal9001", "glmnet")
)
)
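
Taken together, the constructor change (`initialize = function(...)`) and the new `other_valid` handling mean that arguments for `fit_hal()` and for `glmnet()` can be passed straight through `Lrnr_hal9001$new()`. A minimal sketch of that usage, assuming the devel-branch behavior in this diff (that `nlambda` reaches `glmnet()` via `other_valid` is an assumption, not something the PR demonstrates):

```r
library(sl3)
data(cpp_imputed)
covs <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs")
task <- sl3_Task$new(cpp_imputed, covariates = covs, outcome = "haz")

# fit_hal() arguments (max_degree, num_knots, smoothness_orders) alongside a
# glmnet() argument (nlambda, assumed to be admitted via other_valid)
hal_lrnr <- Lrnr_hal9001$new(
  max_degree = 2, num_knots = c(20, 10), smoothness_orders = 0, nlambda = 50
)
hal_fit <- hal_lrnr$train(task)
hal_preds <- hal_fit$predict()
```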
16 changes: 11 additions & 5 deletions R/survival_utils.R
@@ -32,11 +32,17 @@ pooled_hazard_task <- function(task, trim = TRUE) {
repeated_data <- underlying_data[index, ]
new_folds <- origami::id_folds_to_folds(task$folds, index)

-  repeated_task <- task$next_in_chain(
-    column_names = column_names,
-    data = repeated_data, id = "id",
-    folds = new_folds
-  )
+  nodes <- task$nodes
+  nodes$id <- "id"
+  repeated_task <- sl3_Task$new(
+    repeated_data,
+    column_names = column_names, nodes = nodes, folds = new_folds,
+    outcome_levels = outcome_levels, outcome_type = task$outcome_type$type
+  )
+  # If "task" has a non-null row_index, next_in_chain() fails: it does not
+  # reset the row_index when data is passed in, so CV learners and pooled
+  # hazards don't work. Hence the commented-out call below:
+  # repeated_task <- task$next_in_chain(
+  #   column_names = column_names,
+  #   data = repeated_data, id = "id",
+  #   folds = new_folds, row_index = NULL
+  # )

# make bin indicators
bin_number <- rep(level_index, each = task$nrow)
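
The root cause: when `task` carries a non-NULL `row_index`, as the fold-specific training tasks built by CV learners do, `next_in_chain()` keeps that stale `row_index` even though new data is passed in, so the repeated-data task indexes rows that no longer exist. A hypothetical sketch of a task that used to trigger the failure (the data frame `df` and its columns are assumptions, not part of this PR):

```r
library(sl3)

# df is assumed to hold covariates W1, W2 and a categorical outcome "bin"
task <- sl3_Task$new(df,
  covariates = c("W1", "W2"), outcome = "bin",
  outcome_type = "categorical"
)

# subsetting to a fold's training set leaves a non-NULL row_index behind
train_task <- task$subset_task(task$folds[[1]]$training_set)

# previously errored inside next_in_chain(); works with the direct
# sl3_Task$new() construction above
pooled <- pooled_hazard_task(train_task)
```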
2 changes: 2 additions & 0 deletions man/Lrnr_hal9001.Rd

Some generated files are not rendered by default.

89 changes: 89 additions & 0 deletions vignettes/testing.Rmd
@@ -0,0 +1,89 @@
---
title: "Untitled"
output: html_document
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```

```{r}
library(sl3)
```

```{r}
library(simcausal)

n <- 2500
D <- DAG.empty()

D <- D +
  node("W1f", distr = "runif", min = -1, max = 1) +
  node("W2f", distr = "runif", min = -1, max = 1) +
  node("W3f", distr = "runif", min = -1, max = 1) +
  node("W1", distr = "rconst", const = W1f) +
  node("W2", distr = "rconst", const = W2f) +
  node("W3", distr = "rconst", const = W3f) +
  node("g", distr = "rconst", const = 0.2 + 0.65 * plogis(
    sin(W1 * 5) + W1 * sin(W1 * 5) + cos(W2 * 5) + 2 * W1 * W2 -
      sin(W3 * 5) + sin(5 * W1 * W3) + 2 * W1 * W2 * W3 +
      W3 * sin(W1 * 5) + cos(W2 * 4) * sin(W1 * 5)
  )) +
  node("A", distr = "rbinom", size = 1, prob = g) +
  node("gR", distr = "rconst", const = 2 * (W1 + W2 + W3) +
    A * (W1 + W2 + W3 + W1 * W2 + W2 * W3 + W1 * W3) +
    W1 * W2 + W2 * W3 + W1 * W3 + W1^2 - W2^2 + W3^2) +
  node("R", distr = "rnorm", mean = gR, sd = 1)

setD <- set.DAG(D)
data <- sim(setD, n = n)
data
```


```{r}
library(R6)

task <- sl3_Task$new(data, covariates = c("W1", "W2", "W3", "A"), outcome = "R")
task$data

lrnr_ranger <- Lrnr_ranger$new(num.trees = 50, predict.all = TRUE)
lrnr_ranger_fit <- lrnr_ranger$train(task)
data.table::as.data.table(lrnr_ranger_fit$predict(task))
```



```{r}
lrnr_xgboost <- Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE)
lrnr_xgboost_fit <- lrnr_xgboost$train(task)
data.table::as.data.table(lrnr_xgboost_fit$predict(task))

lrnr_xgboost_stacked <- make_learner(
  Pipeline,
  Lrnr_cv$new(Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE)),
  Lrnr_nnls$new(convex = FALSE)
)
lrnr_xgboost_stacked_fit <- lrnr_xgboost_stacked$train(task)
data.table::as.data.table(lrnr_xgboost_stacked_fit$predict(task))
```


```{r}
lrnr_xg_stack <- make_learner(
  Stack,
  Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 3),
  Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 5),
  Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 7),
  Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 10)
)
lrnr_xg_sl <- Lrnr_sl$new(lrnr_xg_stack, metalearner = Lrnr_nnls$new(convex = FALSE))
```
```{r}
lrnr_stack <- make_learner(
  Stack,
  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 3),
  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 5),
  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 7),
  Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 10),
  lrnr_xg_sl
)
lrnr_stack_fit <- lrnr_stack$train(task)
preds <- lrnr_stack_fit$predict(task)

# mean squared error of each learner's predictions against the true mean gR
as.data.frame(apply(preds - data$gR, 2, function(v) mean(v^2)))
```
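
A cross-validated version of the same comparison, sketched under the assumption that `Lrnr_cv` fits expose a `cv_risk()` method taking a loss such as `loss_squared_error`:

```{r}
# lrnr_stack is still untrained here, so it can be wrapped directly
lrnr_cv <- Lrnr_cv$new(lrnr_stack)
lrnr_cv_fit <- lrnr_cv$train(task)
lrnr_cv_fit$cv_risk(loss_squared_error)
```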