Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor simulation architecture #66

Merged
merged 9 commits into from
Feb 15, 2024
14 changes: 4 additions & 10 deletions R/sim_contacts.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,18 @@ sim_contacts <- function(contact_distribution,
population_age = population_age
)

chain <- .sim_bp_linelist(
contacts <- .sim_internal(
sim_type = "contacts",
contact_distribution = contact_distribution,
contact_interval = contact_interval,
prob_infect = prob_infect,
outbreak_start_date = outbreak_start_date,
min_outbreak_size = min_outbreak_size,
population_age = population_age,
contact_tracing_status_probs = contact_tracing_status_probs,
config = config
)

chain <- .add_names(.data = chain)

contacts <- .sim_contacts_tbl(
.data = chain,
contact_tracing_status_probs = contact_tracing_status_probs
)
row.names(chain) <- NULL

# return line list
# return contacts table
contacts
}
46 changes: 0 additions & 46 deletions R/sim_contacts_tbl.R

This file was deleted.

204 changes: 204 additions & 0 deletions R/sim_internal.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
#' Internal simulation function called by the exported simulation functions
#' within \pkg{simulist}
#'
#' @description This internal function simulates a line list, and
#' when `sim_type` is `"contacts"` or `"outbreak"` a contacts table as well.
#'
#' @details The transmission network is re-simulated until the number of
#' infected individuals is at least `min_outbreak_size`; the function stops
#' with an error when the 1000th simulation attempt is reached. Dates, gender
#' and age are then added for every simulated individual. Hospitalisation,
#' death, case type and (optionally) names and Ct values are only added when
#' `sim_type` is `"linelist"` or `"outbreak"`; the contacts table is only
#' built when `sim_type` is `"contacts"` or `"outbreak"`.
#'
#' @inheritParams sim_linelist
#' @inheritParams sim_contacts
#' @inheritParams .check_sim_input
#'
#' @return A `<data.frame>` if `sim_type` is `"linelist"` or `"contacts"`, or a
#' list of two `<data.frame>`s (named `linelist` and `contacts`) if `sim_type`
#' is `"outbreak"`.
#' @keywords internal
.sim_internal <- function(sim_type = c("linelist", "contacts", "outbreak"),
                          contact_distribution,
                          contact_interval,
                          prob_infect,
                          onset_to_hosp = NULL,
                          onset_to_death = NULL,
                          hosp_risk = NULL,
                          hosp_death_risk = NULL,
                          non_hosp_death_risk = NULL,
                          outbreak_start_date,
                          add_names = NULL,
                          add_ct = NULL,
                          min_outbreak_size,
                          population_age,
                          case_type_probs = NULL,
                          contact_tracing_status_probs = NULL,
                          config) {
  sim_type <- match.arg(sim_type)

  outbreak_size <- 0
  max_iter <- 0L
  # condition on a minimum chain size
  while (outbreak_size < min_outbreak_size) {
    .data <- .sim_network_bp(
      contact_distribution = contact_distribution,
      contact_interval = contact_interval,
      prob_infect = prob_infect,
      config = config
    )
    outbreak_size <- sum(.data$infected == "infected")
    max_iter <- max_iter + 1L
    if (max_iter >= 1e3) {
      stop(
        "Exceeded maximum number of iterations for simulating outbreak. \n",
        "Change input parameters or min_outbreak_size.",
        call. = FALSE
      )
    }
  }

  names(.data)[names(.data) == "ancestor"] <- "infector"

  # convert simulated onset times into dates relative to the outbreak start
  .data$date_onset <- .data$time + outbreak_start_date

  # look-up table of each individual's time, to be attached to the individuals
  # they are the infector of
  id_time <- data.frame(infector = .data$id, infector_time = .data$time)

  # left join infector time to data preserving column and row order
  col_order <- c(colnames(.data), "infector_time")
  .data <- merge(.data, id_time, by = "infector", all.x = TRUE)
  # merge() sorts rows by the key; move rows without an infector to the top
  .data <- .data[order(is.na(.data$infector), decreasing = TRUE), ]
  .data <- .data[col_order]
  row.names(.data) <- NULL

  .data <- .add_date_contact(
    .data = .data,
    contact_type = "last",
    distribution = config$last_contact_distribution,
    config$last_contact_distribution_params,
    outbreak_start_date = outbreak_start_date
  )
  .data <- .add_date_contact(
    .data = .data,
    contact_type = "first",
    distribution = config$first_contact_distribution,
    config$first_contact_distribution_params
  )

  # add random age and gender
  .data$gender <- sample(c("m", "f"), replace = TRUE, size = nrow(.data))
  if (is.data.frame(population_age)) {
    # one vector of ages per age group (row), spanning the bounds held in the
    # first two columns. lapply() always returns a list, whereas apply()
    # simplifies to a matrix when every group spans the same number of ages
    # (or there is a single group), which made the rep() call below fail
    lower_age <- population_age[[1]]
    upper_age <- population_age[[2]]
    age_groups <- lapply(
      seq_along(lower_age),
      function(i) lower_age[i]:upper_age[i]
    )
    sample_weight <- rep(population_age$proportion, times = lengths(age_groups))
    # normalise for vector length
    sample_weight <- sample_weight /
      rep(lengths(age_groups), times = lengths(age_groups))
    .data$age <- sample(
      unlist(age_groups),
      size = nrow(.data),
      replace = TRUE,
      prob = sample_weight
    )
  } else {
    # population_age is a two-element range: sample ages uniformly within it
    .data$age <- sample(
      population_age[1]:population_age[2],
      size = nrow(.data),
      replace = TRUE
    )
  }

  if (sim_type %in% c("linelist", "outbreak")) {
    .data <- .add_hospitalisation(
      .data = .data,
      onset_to_hosp = onset_to_hosp,
      hosp_risk = hosp_risk
    )
    .data <- .add_deaths(
      .data = .data,
      onset_to_death = onset_to_death,
      hosp_death_risk = hosp_death_risk,
      non_hosp_death_risk = non_hosp_death_risk
    )

    # add hospitalisation and death dates
    .data$date_admission <- .data$hospitalisation + outbreak_start_date
    .data$date_death <- .data$deaths + outbreak_start_date

    # columns (and their order) kept in the returned line list
    linelist_cols <- c(
      "id", "case_type", "gender", "age", "date_onset", "date_admission",
      "date_death", "date_first_contact", "date_last_contact"
    )

    if (add_names) {
      .data <- .add_names(.data = .data)
      # place the name column directly after the id column
      linelist_cols <- append(linelist_cols, "case_name", after = 1)
    }

    # add confirmed, probable, suspected case types
    .data$case_type[.data$infected == "infected"] <- sample(
      x = names(case_type_probs),
      size = sum(.data$infected == "infected"),
      replace = TRUE,
      prob = case_type_probs
    )

    # add Ct if confirmed
    if (add_ct) {
      .data <- .add_ct(
        .data = .data,
        distribution = config$ct_distribution,
        config$ct_distribution_params
      )
      linelist_cols <- c(linelist_cols, "ct_value")
    }
  }

  if (sim_type %in% c("contacts", "outbreak")) {
    # names are needed for the from/to columns; only add them if the line list
    # branch above has not already done so
    if (!"infector_name" %in% colnames(.data)) {
      .data <- .add_names(.data = .data)
    }

    contacts_tbl <- subset(
      .data,
      select = c(
        "infector_name", "case_name", "age", "gender",
        "date_first_contact", "date_last_contact"
      )
    )
    colnames(contacts_tbl) <- c(
      "from", "to", "age", "gender", "date_first_contact",
      "date_last_contact"
    )
    contacts_tbl$was_case <- ifelse(
      test = .data$infected == "infected",
      yes = "Y",
      no = "N"
    )

    # add contact tracing status
    pick_N <- which(contacts_tbl$was_case == "N")
    status_type <- sample(
      names(contact_tracing_status_probs),
      size = length(pick_N),
      replace = TRUE,
      prob = contact_tracing_status_probs
    )
    contacts_tbl$status <- "case"
    contacts_tbl$status[pick_N] <- status_type

    # drop the first row
    # NOTE(review): presumably the index case, which has no infector - confirm
    contacts_tbl <- contacts_tbl[-1, ]

    row.names(contacts_tbl) <- NULL
  }

  if (sim_type == "contacts") {
    return(contacts_tbl)
  }

  # the line list only contains infected individuals and the selected columns
  .data <- .data[.data$infected == "infected", ]
  .data <- .data[, linelist_cols]
  row.names(.data) <- NULL

  switch(sim_type,
    linelist = .data,
    outbreak = list(
      linelist = .data,
      contacts = contacts_tbl
    )
  )
}
19 changes: 5 additions & 14 deletions R/sim_linelist.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,11 @@ sim_linelist <- function(contact_distribution,
)
}

chain <- .sim_bp_linelist(
linelist <- .sim_internal(
sim_type = "linelist",
contact_distribution = contact_distribution,
contact_interval = contact_interval,
prob_infect = prob_infect,
outbreak_start_date = outbreak_start_date,
min_outbreak_size = min_outbreak_size,
population_age = population_age,
config = config
)

linelist <- .sim_clinical_linelist(
chain = chain,
onset_to_hosp = onset_to_hosp,
onset_to_death = onset_to_death,
hosp_risk = hosp_risk,
Expand All @@ -234,14 +227,12 @@ sim_linelist <- function(contact_distribution,
outbreak_start_date = outbreak_start_date,
add_names = add_names,
add_ct = add_ct,
min_outbreak_size = min_outbreak_size,
population_age = population_age,
case_type_probs = case_type_probs,
config = config
)

linelist$chain <- linelist$chain[linelist$chain$infected == "infected", ]
chain <- linelist$chain[, linelist$cols]
row.names(chain) <- NULL

# return line list
chain
linelist
}
30 changes: 7 additions & 23 deletions R/sim_outbreak.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,11 @@ sim_outbreak <- function(contact_distribution,
)
}

chain <- .sim_bp_linelist(
outbreak <- .sim_internal(
sim_type = "outbreak",
contact_distribution = contact_distribution,
contact_interval = contact_interval,
prob_infect = prob_infect,
outbreak_start_date = outbreak_start_date,
min_outbreak_size = min_outbreak_size,
population_age = population_age,
config = config
)

linelist <- .sim_clinical_linelist(
chain = chain,
onset_to_hosp = onset_to_hosp,
onset_to_death = onset_to_death,
hosp_risk = hosp_risk,
Expand All @@ -159,22 +152,13 @@ sim_outbreak <- function(contact_distribution,
outbreak_start_date = outbreak_start_date,
add_names = add_names,
add_ct = add_ct,
min_outbreak_size = min_outbreak_size,
population_age = population_age,
case_type_probs = case_type_probs,
contact_tracing_status_probs = contact_tracing_status_probs,
config = config
)

contacts <- .sim_contacts_tbl(
.data = linelist$chain,
contact_tracing_status_probs = contact_tracing_status_probs
)

linelist$chain <- linelist$chain[linelist$chain$infected == "infected", ]
chain <- linelist$chain[, linelist$cols]
row.names(chain) <- NULL

# return outbreak data
list(
linelist = chain,
contacts = contacts
)
# return list of line list and contacts table
outbreak
}
Loading
Loading