Skip to content

Commit

Permalink
Default model C++ handles lists of vax and populations, WIP #167
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikunterwegs committed Mar 14, 2024
1 parent af9105f commit 389fdd8
Showing 1 changed file with 59 additions and 39 deletions.
98 changes: 59 additions & 39 deletions R/model_default.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ model_default_cpp <- function(population,
# check the time end and increment
# restrict increment to lower limit of 1e-6
checkmate::assert_integerish(time_end, lower = 0)
checkmate::assert_number(increment, lower = 1e-6, finite = TRUE)
checkmate::assert_number(increment, lower = 1e-3, finite = TRUE)

# check all vector lengths are equal or 1L
params <- list(
Expand All @@ -154,51 +154,65 @@ model_default_cpp <- function(population,
recovery_rate = recovery_rate,
time_end = time_end
)
stopifnot(
"All parameters must be of the same length, or must have length 1" =
test_recyclable(params)
)
# take parameter names here as names(DT) updates by reference!
param_names <- names(params)

# first check if `intervention` is a list of interventions or a list-of-lists
# Check if parameters can be recycled;
# Check if `population` is a single population or a list of such
# and convert to list for a data.table list column;
# also check if `intervention` is a list of interventions or a list-of-lists
# and convert to a list for a data.table list column. NULL is allowed;
# Check if `vaccination` is a single vaccination or a list
# and convert to a list for a data.table list column
is_lofints <- checkmate::test_list(
intervention, "intervention",
any.missing = FALSE, null.ok = TRUE
all.missing = FALSE, null.ok = TRUE
)
# allow some NULLs (a valid no intervention scenario) but not all NULLs
is_lofls <- checkmate::test_list(
intervention,
types = "list", null.ok = FALSE, any.missing = FALSE
types = c("list", "null"), all.missing = FALSE
) && all(
vapply(unlist(intervention, recursive = FALSE), is_intervention, TRUE)
vapply(
unlist(intervention, recursive = FALSE),
FUN = function(x) {
is_intervention(x) || is.null(x)
}, TRUE
)
)

stopifnot(
"All parameters must be of the same length, or must have length 1" =
test_recyclable(params),
"`population` must be a <population> or a list of <population>s" =
is_population(population) || checkmate::test_list(
population,
types = "population"
),
"`intervention` must be a list of <intervention>s or a list of such lists" =
is_lofints || is_lofls
is_lofints || is_lofls,
"`vaccination` must be a <vaccination> or a list of <vaccination>s" =
is_vaccination(vaccination) || checkmate::test_list(
vaccination,
type = c("vaccination", "null"), null.ok = TRUE
)
)
# standardise and prepare intervention lists

# make lists if not lists
if (is_population(population)) {
population <- list(population)
}
if (is_lofints) {
intervention <- .cross_check_intervention(
intervention, population,
c("contacts", "transmissibility", "infectiousness_rate", "recovery_rate")
)
# convert to list for data.table cross join
intervention <- list(intervention)
} else {
intervention <- lapply(
intervention, .cross_check_intervention, population,
c("contacts", "transmissibility", "infectiousness_rate", "recovery_rate")
)
}

# check the vaccination class
checkmate::assert_class(vaccination, "vaccination", null.ok = TRUE)
# make list after cross-checking
vaccination <- list(
.cross_check_vaccination(vaccination, population, doses = 1L)
)
if (is_vaccination(vaccination) || is.null(vaccination)) {
vaccination <- list(vaccination)
}

# check that time-dependence functions are passed as a list with at least the
# arguments `time` and `x`
# time must be before x, and they must be first two args
# arguments `time` and `x`, in order as the first two args
# NOTE: this functionality is not vectorised;
# convert to list for data.table list column
checkmate::assert_list(
time_dependence, "function",
null.ok = TRUE,
Expand All @@ -219,31 +233,37 @@ model_default_cpp <- function(population,

# collect parameters and add a parameter set identifier
params <- data.table::as.data.table(params)
params[, param_set := .I]
params[, "param_set" := .I]

# this nested data.table will be returned
model_output <- data.table::CJ(
population = list(population),
population = population,
intervention = intervention,
vaccination = vaccination,
time_dependence = time_dependence,
increment = increment,
sorted = FALSE
)
model_output[, scenario := .I]

# process the population, interventions, and vaccinations, after
# cross-checking them agains the relevant population
model_output[, args := apply(model_output, 1, function(x) {
.check_prepare_args_default(c(x))
})]
model_output[, "scenario" := .I]

# combine infection parameters and scenarios
# NOTE: join X[Y] must have params as X as list cols not supported for X
model_output <- params[, as.list(model_output), by = names(params)]

# collect model arguments in column data, then overwrite
model_output[, data := apply(model_output, 1, c)]
model_output[, data := lapply(data, function(l) {
tmp_pop_ <- l[["population"]]
l <- .prepare_args_model_default(l)
output_ <- .output_to_df(
model_output[, args := apply(model_output, 1, function(x) {
c(x[["args"]], x[param_names]) # avoid including col "param_set"
})]
model_output[, data := Map(population, args, f = function(p, l) {
.output_to_df(
do.call(.model_default_cpp, l),
population = tmp_pop_, # taken from local scope/env
population = p, # taken from local scope/env
compartments = compartments
)
})]
Expand Down

0 comments on commit 389fdd8

Please sign in to comment.