Skip to content

Commit

Permalink
merge with devel
Browse files Browse the repository at this point in the history
  • Loading branch information
rachaelvp committed Jun 23, 2021
2 parents 3aab316 + 14b1630 commit b41ed5d
Show file tree
Hide file tree
Showing 24 changed files with 349 additions and 68 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ As of April 2021:
* Support for the custom lasso procedure implemented in `Rcpp` has been
discontinued. Accordingly, the `"lassi"` option and argument `fit_type` have
been removed from `fit_hal`.
* Re-added `lambda.min.ratio` as a `fit_control` argument to `fit_hal`. We've
seen that not setting `lambda.min.ratio` in `glmnet` can lead to no `lambda`
values that fit the data sufficiently well, so it seems appropriate to
override the `glmnet` default.

# hal9001 0.3.0

Expand Down
22 changes: 16 additions & 6 deletions R/hal.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@
#' \code{\link[glmnet]{cv.glmnet}}. When \code{TRUE}, \code{"lambda.min"} is
#' used; otherwise, \code{"lambda.1se"}. Only used when
#' \code{cv_select = TRUE}.
#' - \code{lambda.min.ratio}: A \code{\link[glmnet]{glmnet}} argument specifying
#' the smallest value for \code{lambda}, as a fraction of \code{lambda.max},
#' the (data derived) entry value (i.e. the smallest value for which all
#' coefficients are zero). We've seen that not setting \code{lambda.min.ratio}
#' can lead to no \code{lambda} values that fit the data sufficiently well.
#' - \code{prediction_bounds}: A vector of size two that provides the lower and
#' upper bounds for predictions. When \code{prediction_bounds = "default"},
#' the predictions are bounded between \code{min(Y) - sd(Y)} and
Expand Down Expand Up @@ -173,8 +178,8 @@ fit_hal <- function(X,
X_unpenalized = NULL,
max_degree = ifelse(ncol(X) >= 20, 2, 3),
smoothness_orders = 1,
num_knots = sapply(seq_len(max_degree),
num_knots_generator,
num_knots = num_knots_generator(
max_degree = max_degree,
smoothness_orders = smoothness_orders,
base_num_knots_0 = 500,
base_num_knots_1 = 200
Expand All @@ -189,6 +194,7 @@ fit_hal <- function(X,
n_folds = 10,
foldid = NULL,
use_min = TRUE,
lambda.min.ratio = 1e-4,
prediction_bounds = "default"
),
formula_control = list(
Expand All @@ -206,7 +212,7 @@ fit_hal <- function(X,
# errors when a supplied control list is missing arguments
defaults <- list(
cv_select = TRUE, n_folds = 10, foldid = NULL, use_min = TRUE,
prediction_bounds = "default"
lambda.min.ratio = 1e-4, prediction_bounds = "default"
)
if (any(!names(defaults) %in% names(fit_control))) {
fit_control <- c(
Expand Down Expand Up @@ -472,12 +478,16 @@ fit_hal <- function(X,
#' the basis function.
#'
#' @keywords internal
num_knots_generator <- function(d, smoothness_orders, base_num_knots_0 = 500,
num_knots_generator <- function(max_degree, smoothness_orders, base_num_knots_0 = 500,
base_num_knots_1 = 200) {
if (all(smoothness_orders > 0)) {
return(round(base_num_knots_1 / 2^(d - 1)))
return(sapply(seq_len(max_degree), function(d) {
round(base_num_knots_1 / 2^(d - 1))
}))
}
else {
return(round(base_num_knots_0 / 2^(d - 1)))
return(sapply(seq_len(max_degree), function(d) {
round(base_num_knots_0 / 2^(d - 1))
}))
}
}
4 changes: 2 additions & 2 deletions R/summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ summary.hal9001 <- function(object,
if (lambda != object$lambda_star) {
if (is.null(object$lasso_fit)) {
stop(
"Coefficients for specified lamdba do not exist, or are not ",
"Coefficients for specified lambda do not exist, or are not ",
"accessible since the fit of the lasso model was not returned ",
"(i.e., return_lasso was set to FALSE in `hal_fit()`)."
)
Expand All @@ -85,7 +85,7 @@ summary.hal9001 <- function(object,
coefs <- object$coefs
if (length(lambda) > 1) {
warning(
"Coefficients for many lamdba exist --\n",
"Coefficients for many lambda exist --\n",
"Summarizing coefficients corresponding to minimum lambda."
)
lambda_idx <- which.min(lambda)
Expand Down
1 change: 1 addition & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ knitr::opts_chunk$set(
[![Coverage Status](https://img.shields.io/codecov/c/github/tlverse/hal9001/master.svg)](https://codecov.io/github/tlverse/hal9001?branch=master)
[![CRAN](https://www.r-pkg.org/badges/version/hal9001)](https://www.r-pkg.org/pkg/hal9001)
[![CRAN downloads](https://cranlogs.r-pkg.org/badges/hal9001)](https://CRAN.R-project.org/package=hal9001)
[![CRAN total downloads](http://cranlogs.r-pkg.org/badges/grand-total/hal9001)](https://CRAN.R-project.org/package=hal9001)
[![Project Status: Active – The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![License: GPL v3](https://img.shields.io/badge/License-GPL%20v3-blue.svg)](http://www.gnu.org/licenses/gpl-3.0)
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3558313.svg)](https://doi.org/10.5281/zenodo.3558313)
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Status](https://img.shields.io/codecov/c/github/tlverse/hal9001/master.svg)](htt
[![CRAN](https://www.r-pkg.org/badges/version/hal9001)](https://www.r-pkg.org/pkg/hal9001)
[![CRAN
downloads](https://cranlogs.r-pkg.org/badges/hal9001)](https://CRAN.R-project.org/package=hal9001)
[![CRAN total
downloads](http://cranlogs.r-pkg.org/badges/grand-total/hal9001)](https://CRAN.R-project.org/package=hal9001)
[![Project Status: Active – The project has reached a stable, usable
state and is being actively
developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
Expand Down Expand Up @@ -111,7 +113,7 @@ hal_fit$times
# training sample prediction
preds <- predict(hal_fit, new_data = x)
mean(hal_mse <- (preds - y)^2)
#> [1] 0.0357806
#> [1] 0.03481173
```

-----
Expand Down Expand Up @@ -167,7 +169,7 @@ See file `LICENSE` for details.

## References

<div id="refs" class="references hanging-indent">
<div id="refs" class="references">

<div id="ref-benkeser2016hal">

Expand Down
Loading

0 comments on commit b41ed5d

Please sign in to comment.