From 389fdd80cffe846d60bd7ef7ba24327d9aa9f269 Mon Sep 17 00:00:00 2001 From: pratikunterwegs Date: Mon, 12 Feb 2024 16:18:10 +0000 Subject: [PATCH] Default model C++ handles lists of vax and populations, WIP #167 --- R/model_default.R | 98 ++++++++++++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 39 deletions(-) diff --git a/R/model_default.R b/R/model_default.R index 8fedfc06..b8c85129 100644 --- a/R/model_default.R +++ b/R/model_default.R @@ -145,7 +145,7 @@ model_default_cpp <- function(population, # check the time end and increment # restrict increment to lower limit of 1e-6 checkmate::assert_integerish(time_end, lower = 0) - checkmate::assert_number(increment, lower = 1e-6, finite = TRUE) + checkmate::assert_number(increment, lower = 1e-3, finite = TRUE) # check all vector lengths are equal or 1L params <- list( @@ -154,51 +154,65 @@ model_default_cpp <- function(population, recovery_rate = recovery_rate, time_end = time_end ) - stopifnot( - "All parameters must be of the same length, or must have length 1" = - test_recyclable(params) - ) + # take parameter names here as names(DT) updates by reference! + param_names <- names(params) - # first check if `intervention` is a list of interventions or a list-of-lists + # Check if parameters can be recycled; + # Check if `population` is a single population or a list of such + # and convert to list for a data.table list column; + # also check if `intervention` is a list of interventions or a list-of-lists + # and convert to a list for a data.table list column. NULL is allowed; + # Check if `vaccination` is a single vaccination or a list + # and convert to a list for a data.table list column is_lofints <- checkmate::test_list( intervention, "intervention", - any.missing = FALSE, null.ok = TRUE + all.missing = FALSE, null.ok = TRUE ) + # allow some NULLs (a valid no intervention scenario) but not all NULLs is_lofls <- checkmate::test_list( intervention, - types = "list", null.ok = FALSE, any.missing = FALSE + types = c("list", "null"), all.missing = FALSE ) && all( - vapply(unlist(intervention, recursive = FALSE), is_intervention, TRUE) + vapply( + unlist(intervention, recursive = FALSE), + FUN = function(x) { + is_intervention(x) || is.null(x) + }, TRUE + ) ) + stopifnot( + "All parameters must be of the same length, or must have length 1" = + test_recyclable(params), + "`population` must be a or a list of s" = + is_population(population) || checkmate::test_list( + population, + types = "population" + ), "`intervention` must be a list of s or a list of such lists" = - is_lofints || is_lofls + is_lofints || is_lofls, + "`vaccination` must be a or a list of s" = + is_vaccination(vaccination) || checkmate::test_list( + vaccination, + type = c("vaccination", "null"), null.ok = TRUE + ) ) - # standardise and prepare intervention lists + + # make lists if not lists + if (is_population(population)) { + population <- list(population) + } if (is_lofints) { - intervention <- .cross_check_intervention( - intervention, population, - c("contacts", "transmissibility", "infectiousness_rate", "recovery_rate") - ) - # convert to list for data.table cross join intervention <- list(intervention) - } else { - intervention <- lapply( - intervention, .cross_check_intervention, population, - c("contacts", "transmissibility", "infectiousness_rate", "recovery_rate") - ) } - - # check the vaccination class - checkmate::assert_class(vaccination, "vaccination", null.ok = TRUE) - # make list after cross-checking - vaccination <- list( - .cross_check_vaccination(vaccination, population, doses = 1L) - ) + if (is_vaccination(vaccination) || is.null(vaccination)) { + vaccination <- list(vaccination) + } # check that time-dependence functions are passed as a list with at least the - # arguments `time` and `x` - # time must be before x, and they must be first two args + # arguments `time` and `x`, in order as the first two args + # NOTE: this functionality is not vectorised; + # convert to list for data.table list column checkmate::assert_list( time_dependence, "function", null.ok = TRUE, @@ -219,31 +233,37 @@ model_default_cpp <- function(population, # collect parameters and add a parameter set identifier params <- data.table::as.data.table(params) - params[, param_set := .I] + params[, "param_set" := .I] # this nested data.table will be returned model_output <- data.table::CJ( - population = list(population), + population = population, intervention = intervention, vaccination = vaccination, time_dependence = time_dependence, increment = increment, sorted = FALSE ) - model_output[, scenario := .I] + + # process the population, interventions, and vaccinations, after + # cross-checking them agains the relevant population + model_output[, args := apply(model_output, 1, function(x) { + .check_prepare_args_default(c(x)) + })] + model_output[, "scenario" := .I] # combine infection parameters and scenarios # NOTE: join X[Y] must have params as X as list cols not supported for X model_output <- params[, as.list(model_output), by = names(params)] # collect model arguments in column data, then overwrite - model_output[, data := apply(model_output, 1, c)] - model_output[, data := lapply(data, function(l) { - tmp_pop_ <- l[["population"]] - l <- .prepare_args_model_default(l) - output_ <- .output_to_df( + model_output[, args := apply(model_output, 1, function(x) { + c(x[["args"]], x[param_names]) # avoid including col "param_set" + })] + model_output[, data := Map(population, args, f = function(p, l) { + .output_to_df( do.call(.model_default_cpp, l), - population = tmp_pop_, # taken from local scope/env + population = p, # taken from local scope/env compartments = compartments ) })]