Skip to content

Commit

Permalink
Merge pull request #24 from signaturescience/glm
Browse files Browse the repository at this point in the history
glm
  • Loading branch information
vpnagraj authored Dec 21, 2021
2 parents dbbf3a0 + 3f8565f commit 15af0d1
Show file tree
Hide file tree
Showing 12 changed files with 543 additions and 2 deletions.
9 changes: 7 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Authors@R:
comment = c(ORCID = "0000-0001-9140-9028"))
)
Description: Miscellaneous functions for retrieving data, creating and evaluating time series
forecasting models for influenza-like illness (ILI) cases, deaths, and hospitalizations in
forecasting models for influenza-like illness (ILI) and influenza hospitalizations in
the United States.
License: GPL (>= 3)
Encoding: UTF-8
Expand All @@ -39,6 +39,11 @@ Imports:
RSocrata,
tibble,
tidyr,
tsibble
tsibble,
trending,
trendeval,
ggplot2,
evalcast,
yardstick
Depends:
R (>= 2.10)
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ export(get_cdc_hosp)
export(get_cdc_ili)
export(get_cdc_vax)
export(get_hdgov_hosp)
export(glm_fit)
export(glm_forecast)
export(glm_quibble)
export(glm_wrap)
export(is_monday)
export(make_tsibble)
export(plot_forc)
export(this_monday)
export(wis_score)
importFrom(magrittr,"%>%")
15 changes: 15 additions & 0 deletions R/fiphde.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,20 @@ if(getRversion() >= "2.15.1") utils::globalVariables(c(".",
"sea_label",
"monday",
"yweek",
"0.025",
"0.975",
"<NA>",
"error",
"estimate",
"model",
"lower",
"upper",
"lower_pi",
"upper_pi",
"flu.admits",
"quantile",
"truth",
"rmse",
"value",
"."))

208 changes: 208 additions & 0 deletions R/glm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#' Fit glm models
#'
#' This helper function is used in \link[fiphde]{glm_wrap} to fit a list of models and select the best one.
#'
#' @param .data Data including all explanatory and outcome variables needed for modeling; must include column for "location"
#' @param .models List of models defined as \link[trending]{trending_model} objects
#'
#' @return A `tibble` containing characteristics from the "best" `glm` model (i.e., the model from ".models" list with lowest RMSE among the models that were evaluated without error). The columns in this `tibble` include:
#'
#' - model_class: The "type" of model for the best fit
#' - fit: The fitted model object for the best fit
#' - location: The unique value(s) found in the "location" column of ".data"
#' - data: Original model fit data as a `tibble` in a list column
#'
#' An error is raised if none of the candidate models could be evaluated.
#'
#' @export
#'
#' @md
#'
glm_fit <- function(.data,
                    .models) {

  dat <- .data

  message("Location ... ")
  ## collapse so that more than one location is still printed legibly
  message(paste0(unique(dat$location), collapse = ", "))

  ## evaluate every candidate model by resampling, scored with rmse
  res <-
    trendeval::evaluate_models(
      .models,
      dat,
      method = trendeval::evaluate_resampling,
      metrics = list(yardstick::rmse)
    )

  ## keep only the candidates whose evaluation did not error
  # dplyr::filter(purrr::map_lgl(warning, is.null)) would also remove models that gave warnings
  evaluated <-
    res %>%
    dplyr::filter(purrr::map_lgl(error, is.null))

  ## without this guard the selection below yields NULL and the fit step
  ## fails with an unrelated, hard to read error
  if (nrow(evaluated) == 0) {
    stop("None of the models in '.models' could be evaluated without error.",
         call. = FALSE)
  }

  ## pull the best model by rmse
  ## (pluck(1, 1) takes the first model when several tie for the minimum)
  best_by_rmse <-
    evaluated %>%
    dplyr::slice_min(rmse) %>%
    dplyr::select(model) %>%
    purrr::pluck(1, 1)

  ## refit the selected model on the full data
  tmp_fit <-
    best_by_rmse %>%
    trending::fit(dat)

  ## construct tibble with model type, actual fit, and the location
  ret <- dplyr::tibble(model_class = best_by_rmse$model_class,
                       fit = tmp_fit,
                       location = unique(dat$location),
                       data = tidyr::nest(dat, fit_data = dplyr::everything()))

  message("Selected model ...")
  message(as.character(ret$fit$fitted_model$family)[1])

  message("Variables ...")
  message(paste0(names(ret$fit$fitted_model$coefficients), collapse = " + "))

  ret

}


#' Get quantiles from prediction intervals
#'
#' This function runs the \link[trending]{predict.trending_model_fit} method on a fitted model at specified values of "alpha" in order to create a range of prediction intervals. The processing also includes steps to convert the alpha to corresponding quantile values at upper and lower bounds.
#'
#' @param fit Fitted model object from \link[fiphde]{glm_fit}
#' @param new_data Tibble with new data on which the \link[trending]{predict.trending_model_fit} method should run; the prediction output must carry "epiyear" and "epiweek" columns
#' @param alpha Vector specifying the threshold(s) to be used for prediction intervals; alpha of `0.05` would correspond to 95% PI; default is `c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2` to range of intervals. Each value is passed to the predict method separately and the results are stacked in the order given.
#'
#' @return A tibble with predicted values at each quantile (lower and upper bound for each value of "alpha"), with columns "epiyear", "epiweek", "quantile" and "value"
#' @export
#' @md
#'
glm_quibble <- function(fit, new_data, alpha = c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2) {

  ## quantiles for a single alpha; handling one alpha at a time keeps the
  ## lower/upper bounds scalar so they cannot be recycled against the wrong rows
  quibble_one <- function(a) {

    ## get the quantiles from the alpha
    q_lower <- a / 2
    q_upper <- 1 - q_lower

    ## run the predict method on the fitted model using the given alpha
    fit %>%
      stats::predict(new_data, alpha = a) %>%
      ## get just the time columns and the prediction interval bounds
      dplyr::select(epiyear, epiweek, lower_pi, upper_pi) %>%
      ## reshape so that it is in long format
      tidyr::gather(quantile, value, lower_pi:upper_pi) %>%
      ## and sub out lower_pi/upper_pi for the appropriate quantile
      dplyr::mutate(quantile = ifelse(quantile == "lower_pi", q_lower, q_upper))
  }

  dplyr::bind_rows(lapply(alpha, quibble_one))
}


#' Forecast glm models
#'
#' This function uses fitted model object from \link[fiphde]{glm_fit} and future covariate data to create probabilistic forecasts at specific quantiles derived from the "alpha" parameter. The week being forecasted is the week after the latest epiyear/epiweek found in ".data".
#'
#' @param .data Data including all explanatory and outcome variables needed for modeling; must include "epiyear", "epiweek" and "flu.admits" columns, with rows in chronological order (the lagged admissions are read from the last rows)
#' @param new_covariates Tibble with one column per covariate, and n rows for n horizons being forecasted
#' @param fit Fitted model object from \link[fiphde]{glm_fit}
#' @param alpha Vector specifying the threshold(s) to be used for prediction intervals; alpha of `0.05` would correspond to 95% PI; default is `c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2` to range of intervals
#'
#' @return Tibble with forecasts (quantiles and point estimates) in columns "epiweek", "epiyear", "quantile" and "value"; the point estimate is the row with a missing "quantile"
#' @export
#' @md
#'
glm_forecast <- function(.data, new_covariates = NULL, fit, alpha = c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2) {

  ## get the last date from the data provided
  ## take the maximum of the actual dates; ordering by epiweek ahead of epiyear
  ## picks the wrong week whenever the data span more than one epiyear
  last_date <- max(MMWRweek::MMWRweek2Date(.data$epiyear, .data$epiweek))

  ## most recent admissions first, so that position k is the lag-k value even
  ## when fewer than four weeks are available (missing lags come back as NA)
  recent_admits <- rev(utils::tail(.data$flu.admits, 4))

  ## set up "new data" with the epiweek/epiyear week being forecasted ...
  ## and the lagged flu admissions (1 to 4 weeks back)
  new_data <-
    dplyr::tibble(lag_1 = recent_admits[1],
                  lag_2 = recent_admits[2],
                  lag_3 = recent_admits[3],
                  lag_4 = recent_admits[4],
                  epiweek = lubridate::epiweek(last_date + 7),
                  epiyear = lubridate::epiyear(last_date + 7))

  ## covariates are optional so that lag-only models can be forecasted
  if (!is.null(new_covariates)) {
    new_data <-
      cbind(new_data, new_covariates)
  }

  ## take the fit object provided and use predict for the point estimate ...
  ## which is flagged by a missing quantile
  point_estimates <-
    fit %>%
    stats::predict(new_data) %>%
    dplyr::transmute(epiweek,
                     epiyear,
                     quantile = NA_real_,
                     value = round(estimate))

  ## map the quibble function over the alphas (one alpha per call)
  quants <- purrr::map_df(alpha, .f = function(x) glm_quibble(fit = fit, new_data = new_data, alpha = x))

  ## prep data
  dplyr::bind_rows(point_estimates, quants) %>%
    dplyr::arrange(epiyear, epiweek, quantile) %>%
    dplyr::select(epiweek, epiyear, quantile, value)
}

#' Run glm modeling and forecasting
#'
#' This is a wrapper function that pipelines influenza hospitalization modeling (\link[fiphde]{glm_fit}) and forecasting (\link[fiphde]{glm_forecast}). Horizons beyond the first are forecasted recursively: the point estimate from each horizon is appended to the data as the observed "flu.admits" for that week before the next horizon is forecasted.
#'
#' @param .data Data including all explanatory and outcome variables needed for modeling
#' @param .models List of models defined as \link[trending]{trending_model} objects
#' @param new_covariates Tibble with one column per covariate, and n rows for n horizons being forecasted; if supplied, the number of rows must equal "horizon"
#' @param horizon Number of weeks ahead for forecasting; must be at least 1
#' @param alpha Vector specifying the threshold(s) to be used for prediction intervals; alpha of `0.05` would correspond to 95% PI; default is `c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2` to range of intervals
#'
#' @return Named list with two elements:
#'
#' - model: Output from \link[fiphde]{glm_fit} with selected model fit
#' - forecasts: Output from \link[fiphde]{glm_forecast} with forecasts from each horizon combined as a single tibble
#'
#' @export
#' @md
glm_wrap <- function(.data, .models, new_covariates = NULL, horizon = 4, alpha = c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2) {

  ## validate up front, before any model fitting is done
  stopifnot(is.numeric(horizon), length(horizon) == 1L, horizon >= 1)
  ## nrow(NULL) is NULL, so the row check is only meaningful with covariates
  if (!is.null(new_covariates)) {
    stopifnot(nrow(new_covariates) == horizon)
  }

  tmp_fit <- glm_fit(.data, .models = .models)

  ## forecasted weeks appended to the data carry the location of the data
  ## itself rather than a fixed label
  loc <- utils::tail(.data$location, 1)

  forc_list <- vector("list", horizon)
  running_data <- .data

  for (i in seq_len(horizon)) {

    message(sprintf("Forecasting %d week ahead", i))

    ## drop = FALSE keeps a single covariate column as a one-row table
    covs <- NULL
    if (!is.null(new_covariates)) {
      covs <- new_covariates[i, , drop = FALSE]
    }

    forc_list[[i]] <- glm_forecast(.data = running_data,
                                   new_covariates = covs,
                                   fit = tmp_fit$fit,
                                   alpha = alpha)

    if (i < horizon) {
      ## get point estimate ... which will have NA quantile value ...
      ## and treat it as the observed value for the week just forecasted
      point_week <-
        forc_list[[i]] %>%
        dplyr::filter(is.na(quantile)) %>%
        dplyr::rename(flu.admits = value) %>%
        dplyr::select(-quantile) %>%
        dplyr::mutate(flu.admits.cov = NA, location = loc)

      running_data <- dplyr::bind_rows(running_data, point_week)
    }
  }

  forc_res <- dplyr::bind_rows(forc_list)
  list(model = tmp_fit, forecasts = forc_res)
}
58 changes: 58 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,61 @@ this_monday <- function() {
is_monday <- function() {
  ## compare the numeric weekday (week_start = 1 makes Monday equal to 1)
  ## instead of the text label, which depends on the locale of the session
  lubridate::wday(lubridate::today(), week_start = 1) == 1
}

#' Calculate WIS score
#'
#' Helper function to calculate weighted interval score (WIS) for prepped forecasts
#'
#' @param .forecasts Tibble with prepped forecasts; must include "epiweek", "epiyear", "quantile" and "value" columns
#' @param .test Tibble with test data including observed value for flu admissions stored in "flu.admits" column
#'
#' @return Ungrouped tibble with the WIS for each combination of epiweek and epiyear
#' @export
#'
wis_score <- function(.forecasts, .test) {

  stopifnot(all(c("epiweek", "epiyear", "quantile", "value") %in% names(.forecasts)))

  ## join on every column the two tables share, stated explicitly so that the
  ## join does not have to guess (and announce) its keys
  join_keys <- intersect(names(.forecasts), names(.test))

  .forecasts %>%
    dplyr::left_join(.test, by = join_keys) %>%
    dplyr::select(epiweek, epiyear, quantile, value, flu.admits) %>%
    dplyr::group_by(epiweek, epiyear) %>%
    ## drop the grouping so that it does not leak into downstream steps
    dplyr::summarise(wis = evalcast::weighted_interval_score(quantile = quantile, value = value, actual_value = flu.admits),
                     .groups = "drop")
}

#' Plot forecasts against observed data
#'
#' This helper function creates a plot to visualize the forecasted point estimates (and 95% prediction interval) alongside the observed data.
#'
#' @param .forecasts Tibble with prepped forecasts; the point estimate is the row with a missing "quantile", and the 95% prediction interval is read from the 0.025 and 0.975 quantiles
#' @param .train Tibble with data used for modeling
#' @param .test Tibble with observed data held out from modeling
#'
#' @return A `ggplot2` plot object
#' @export
#'
#'
plot_forc <- function(.forecasts, .train, .test) {

  ## label the rows needed for plotting before reshaping; near() is used because
  ## the quantiles are computed from alpha and should not be matched exactly
  forc_dat <-
    .forecasts %>%
    dplyr::mutate(quantile = dplyr::case_when(
      is.na(quantile) ~ "mean",
      dplyr::near(quantile, 0.025) ~ "lower",
      dplyr::near(quantile, 0.975) ~ "upper",
      TRUE ~ NA_character_
    )) %>%
    ## everything that is not the point estimate or a 95% bound is dropped
    dplyr::filter(!is.na(quantile)) %>%
    ## one column each for lower / upper / mean
    tidyr::pivot_wider(names_from = quantile, values_from = value)

  observed <-
    .test %>%
    dplyr::bind_rows(.train) %>%
    dplyr::select(epiweek, epiyear, truth = flu.admits, location)

  ## join on the columns shared by observed and forecasted data
  join_keys <- intersect(names(observed), names(forc_dat))

  ## get an upper limit from whatever the max of observed or forecasted hospitalizations is
  y_max <- max(c(.test$flu.admits, .train$flu.admits, forc_dat$upper))

  observed %>%
    dplyr::left_join(forc_dat, by = join_keys) %>%
    dplyr::mutate(date = MMWRweek::MMWRweek2Date(epiyear, epiweek)) %>%
    ggplot2::ggplot() +
    ggplot2::geom_line(ggplot2::aes(date, truth), lwd = 2, col = "black") +
    ggplot2::geom_line(ggplot2::aes(date, mean), lwd = 2, alpha = 0.5, lty = "solid", col = "firebrick") +
    ggplot2::geom_ribbon(ggplot2::aes(date, ymin = lower, ymax = upper), alpha = 0.25, fill = "firebrick") +
    ggplot2::scale_y_continuous(limits = c(0, y_max)) +
    ggplot2::scale_x_date(date_labels = "%Y-%m", date_breaks = "month") +
    ggplot2::labs(x = "Date", y = "Count", title = "Influenza hospitalizations") +
    ggplot2::theme_minimal() +
    ggplot2::facet_wrap(~ location)

}


25 changes: 25 additions & 0 deletions man/glm_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions man/glm_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 15af0d1

Please sign in to comment.