Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

glm #24

Merged
merged 6 commits into from
Dec 21, 2021
Merged

glm #24

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Authors@R:
comment = c(ORCID = "0000-0001-9140-9028"))
)
Description: Miscellaneous functions for retrieving data, creating and evaluating time series
forecasting models for influenza-like illness (ILI) cases, deaths, and hospitalizations in
forecasting models for influenza-like illness (ILI) and influenza hospitalizations in
the United States.
License: GPL (>= 3)
Encoding: UTF-8
Expand All @@ -39,6 +39,11 @@ Imports:
RSocrata,
tibble,
tidyr,
tsibble
tsibble,
trending,
trendeval,
ggplot2,
evalcast,
yardstick
Depends:
R (>= 2.10)
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ export(get_cdc_hosp)
export(get_cdc_ili)
export(get_cdc_vax)
export(get_hdgov_hosp)
export(glm_fit)
export(glm_forecast)
export(glm_quibble)
export(glm_wrap)
export(is_monday)
export(make_tsibble)
export(plot_forc)
export(this_monday)
export(wis_score)
importFrom(magrittr,"%>%")
15 changes: 15 additions & 0 deletions R/fiphde.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,20 @@ if(getRversion() >= "2.15.1") utils::globalVariables(c(".",
"sea_label",
"monday",
"yweek",
"0.025",
"0.975",
"<NA>",
"error",
"estimate",
"model",
"lower",
"upper",
"lower_pi",
"upper_pi",
"flu.admits",
"quantile",
"truth",
"rmse",
"value",
"."))

208 changes: 208 additions & 0 deletions R/glm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#' Fit glm models
#'
#' This helper function is used in \link[fiphde]{glm_wrap} to fit a list of models and select the best one.
#'
#' Every model in ".models" is evaluated with \link[trendeval]{evaluate_models} (resampling method, RMSE metric). Models whose evaluation returned an error are discarded, the remaining model with the lowest RMSE is selected, and that model is then fit to ".data". Progress (location, selected model family, and model terms) is reported via `message()`.
#'
#' @param .data Data including all explanatory and outcome variables needed for modeling; must include column for "location"
#' @param .models List of models defined as \link[trending]{trending_model} objects
#'
#' @return A `tibble` containing characteristics from the "best" `glm` model (i.e., the model from ".models" list with lowest RMSE). The columns in this `tibble` include:
#'
#' - model_class: The "type" of model for the best fit
#' - fit: The fitted model object for the best fit
#' - location: The unique value(s) of the "location" column in ".data"
#' - data: Original model fit data, nested with \link[tidyr]{nest} into a list column named "fit_data"
#'
#' An error is raised if none of the models in ".models" can be evaluated without error.
#'
#' @export
#'
#' @md
#'
glm_fit <- function(.data,
                    .models) {

  dat <- .data

  message(paste0("Location ... "))
  message(unique(dat$location))

  ## evaluate every candidate model by resampling, scored with rmse
  res <-
    trendeval::evaluate_models(
      .models,
      dat,
      method = trendeval::evaluate_resampling,
      metrics = list(yardstick::rmse)
    )

  ## keep only the models that were evaluated without an error
  candidates <-
    res %>%
    dplyr::filter(purrr::map_lgl(error, is.null))

  ## fail early with a clear message rather than passing NULL on to trending::fit()
  if (nrow(candidates) == 0) {
    stop("None of the models in '.models' could be evaluated without error.",
         call. = FALSE)
  }

  ## pull the best model by rmse; ties are resolved by taking the first such model
  best_by_rmse <-
    candidates %>%
    dplyr::slice_min(rmse, n = 1, with_ties = FALSE) %>%
    dplyr::select(model) %>%
    purrr::pluck(1, 1)

  ## fit the selected model to the full data
  tmp_fit <-
    best_by_rmse %>%
    trending::fit(dat)

  ## construct tibble with model type, actual fit, the location, and the nested fit data
  ret <- dplyr::tibble(model_class = best_by_rmse$model_class,
                       fit = tmp_fit,
                       location = unique(dat$location),
                       data = tidyr::nest(dat, fit_data = dplyr::everything()))

  message("Selected model ...")
  message(as.character(ret$fit$fitted_model$family)[1])

  message("Variables ...")
  message(paste0(names(ret$fit$fitted_model$coefficients), collapse = " + "))

  return(ret)

}


#' Get quantiles from prediction intervals
#'
#' This function runs the \link[trending]{predict.trending_model_fit} method on a fitted model at specified values of "alpha" in order to create a range of prediction intervals. The processing also includes steps to convert the alpha to corresponding quantile values at upper and lower bounds.
#'
#' Each value of "alpha" is processed separately (one call to the predict method per value) and the results are row-bound, so both a single alpha and a vector of alphas are handled. For a given alpha the lower bound is labelled with quantile `alpha / 2` and the upper bound with quantile `1 - alpha / 2`.
#'
#' @param fit Fitted model object from \link[fiphde]{glm_fit}
#' @param new_data Tibble with new data on which the \link[trending]{predict.trending_model_fit} method should run
#' @param alpha Vector specifying the threshold(s) to be used for prediction intervals; alpha of `0.05` would correspond to 95% PI; default is `c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2` to range of intervals
#'
#' @return A tibble with predicted values at each quantile (lower and upper bound for each value of "alpha"); columns are "epiyear", "epiweek", "quantile", and "value"
#' @export
#' @md
#'
glm_quibble <- function(fit, new_data, alpha = c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2) {

  ## prediction interval bounds for a single alpha value
  quibble_one <- function(a) {

    ## get the quantiles from the alpha
    q_lower <- a / 2
    q_upper <- 1 - q_lower

    ## run the predict method on the fitted model
    ## use the given alpha
    fit %>%
      stats::predict(new_data, alpha = a) %>%
      ## get just the time columns and the prediction interval bounds
      dplyr::select(epiyear, epiweek, lower_pi, upper_pi) %>%
      ## reshape so that its in long format
      tidyr::gather(quantile, value, lower_pi:upper_pi) %>%
      ## and sub out lower_pi/upper_pi for the appropriate quantile
      ## (q_lower and q_upper are scalars here, so nothing is recycled across rows)
      dplyr::mutate(quantile = ifelse(quantile == "lower_pi", q_lower, q_upper))
  }

  ## one predict call per alpha, combined into a single tibble
  purrr::map_df(alpha, quibble_one)
}


#' Forecast glm models
#'
#' This function uses fitted model object from \link[fiphde]{glm_fit} and future covariate data to create probablistic forecasts at specific quantiles derived from the "alpha" parameter.
#'
#' The forecast is made for the week following the most recent epiyear/epiweek found in ".data". The lagged predictors ("lag_1" to "lag_4") are taken from the last four values of the "flu.admits" column in the row order of ".data", so the rows are expected to be in chronological order.
#'
#' @param .data Data including all explanatory and outcome variables needed for modeling; must include "epiyear", "epiweek", and "flu.admits" columns
#' @param new_covariates Tibble with one column per covariate for the week being forecasted; column-bound to the lagged predictors before prediction. Default is `NULL` (no additional covariates)
#' @param fit Fitted model object from \link[fiphde]{glm_fit}
#' @param alpha Vector specifying the threshold(s) to be used for prediction intervals; alpha of `0.05` would correspond to 95% PI; default is `c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2` to range of intervals
#'
#' @return Tibble with forecasts (quantiles and point estimates) in columns "epiweek", "epiyear", "quantile", and "value"; the point estimate is rounded and has `NA` in the "quantile" column
#' @export
#' @md
#'
glm_forecast <- function(.data, new_covariates = NULL, fit, alpha = c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2) {

  ## get the most recent date from the data provided
  ## NOTE: taking the maximum date (rather than sorting by week before year) ...
  ## ... keeps this correct when the data spans more than one epiyear
  last_date <- max(MMWRweek::MMWRweek2Date(.data$epiyear, .data$epiweek))

  ## most recent flu admissions first: element 1 is one week back, element 4 is four weeks back
  ## with fewer than four rows the unavailable (oldest) lags are NA
  recent_admits <- rev(utils::tail(.data$flu.admits, 4))

  ## set up "new data" with the epiweek/epiyear week being forecasted ...
  ## and calculation of lagged flu admissions (1 to 4 weeks back)
  new_data <-
    dplyr::tibble(lag_1 = recent_admits[1],
                  lag_2 = recent_admits[2],
                  lag_3 = recent_admits[3],
                  lag_4 = recent_admits[4],
                  epiweek = lubridate::epiweek(last_date + 7),
                  epiyear = lubridate::epiyear(last_date + 7))

  ## covariates are optional; when supplied they are added as extra columns
  if (!is.null(new_covariates)) {
    new_data <-
      cbind(new_data, new_covariates)
  }

  ## take the fit object provided and use predict to get the (rounded) point estimate
  ## the point estimate is flagged by a missing quantile
  point_estimates <-
    fit %>%
    stats::predict(new_data) %>%
    dplyr::select(epiweek, epiyear, estimate) %>%
    dplyr::mutate(estimate = round(estimate)) %>%
    dplyr::mutate(quantile = NA_real_) %>%
    dplyr::select(epiweek, epiyear, quantile, value = estimate)

  ## map the quibble function over the alphas
  quants <- purrr::map_df(alpha, .f = function(x) glm_quibble(fit = fit, new_data = new_data, alpha = x))

  ## prep data
  dplyr::bind_rows(point_estimates, quants) %>%
    dplyr::arrange(epiyear, epiweek, quantile) %>%
    dplyr::select(epiweek, epiyear, quantile, value)
}

#' Run glm modeling and forecasting
#'
#' This is a wrapper function that pipelines influenza hospitalization modeling (\link[fiphde]{glm_fit}) and forecasting (\link[fiphde]{glm_forecast}).
#'
#' Forecasts beyond the first week are produced iteratively: the point estimates from the horizons already forecasted are appended to ".data" as if they were observed "flu.admits" values (with "flu.admits.cov" set to `NA` and "location" set to the location found in ".data") before the next horizon is forecasted.
#'
#' @param .data Data including all explanatory and outcome variables needed for modeling
#' @param .models List of models defined as \link[trending]{trending_model} objects
#' @param new_covariates Tibble with one column per covariate, and n rows for n horizons being forecasted; default is `NULL` (no additional covariates). When supplied, the number of rows must equal "horizon"
#' @param horizon Number of weeks ahead for forecasting
#' @param alpha Vector specifying the threshold(s) to be used for prediction intervals; alpha of `0.05` would correspond to 95% PI; default is `c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2` to range of intervals
#'
#' @return Named list with two elements:
#'
#' - model: Output from \link[fiphde]{glm_fit} with selected model fit
#' - forecasts: Output from \link[fiphde]{glm_forecast} with forecasts from each horizon combined as a single tibble
#'
#' @export
#' @md
glm_wrap <- function(.data, .models, new_covariates = NULL, horizon = 4, alpha = c(0.01, 0.025, seq(0.05, 0.45, by = 0.05)) * 2) {

  ## covariates are optional, but when given there must be one row per horizon
  if (!is.null(new_covariates)) {
    stopifnot(nrow(new_covariates) == horizon)
  }

  ## covariate row for horizon i, kept as a one-row table (NULL when no covariates)
  covariates_for <- function(i) {
    if (is.null(new_covariates)) {
      return(NULL)
    }
    new_covariates[i, , drop = FALSE]
  }

  tmp_fit <- glm_fit(.data, .models = .models)

  ## forecasted rows fed back into the data carry the location of the input data
  forc_location <- unique(.data$location)[1]

  message("Forecasting 1 week ahead")
  tmp_forc <-
    glm_forecast(.data = .data, new_covariates = covariates_for(1), fit = tmp_fit$fit, alpha = alpha)

  forc_list <- list()
  forc_list[[1]] <- tmp_forc

  if (horizon > 1) {

    for (i in 2:horizon) {

      message(sprintf("Forecasting %d week ahead", i))

      ## treat the point estimates forecasted so far as observed flu admissions
      prev_weeks <-
        dplyr::bind_rows(forc_list) %>%
        ## get point estimate ... which will have NA quantile value
        dplyr::filter(is.na(quantile)) %>%
        dplyr::rename(flu.admits = value) %>%
        dplyr::select(-quantile) %>%
        dplyr::mutate(flu.admits.cov = NA, location = forc_location)

      forc_list[[i]] <- glm_forecast(.data = dplyr::bind_rows(.data, prev_weeks),
                                     new_covariates = covariates_for(i),
                                     fit = tmp_fit$fit,
                                     alpha = alpha)

    }
  }

  forc_res <- dplyr::bind_rows(forc_list)
  return(list(model = tmp_fit, forecasts = forc_res))
}
58 changes: 58 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,61 @@ this_monday <- function() {
is_monday <- function() {
  ## compare the numeric day of week instead of its label: the label returned by
  ## wday(label = TRUE) follows the session locale, so matching on "Mon" only works
  ## in English locales; with week_start = 1 Monday is always day 1
  lubridate::wday(lubridate::today(), week_start = 1) == 1
}

#' Calculate WIS score
#'
#' Helper function to calculate the weighted interval score (WIS) for prepped forecasts. The forecasts are joined to the test data on the columns the two tibbles have in common, and the score is computed with \link[evalcast]{weighted_interval_score} within each epiweek and epiyear.
#'
#' @param .forecasts Tibble with prepped forecasts; must include "epiweek", "epiyear", "quantile", and "value" columns
#' @param .test Tibble with test data including observed value for flu admissions stored in "flu.admits" column
#'
#' @return Tibble with the WIS (column "wis") for each combination of epiweek and epiyear
#' @export
#'
wis_score <- function(.forecasts, .test) {

  ## attach the observed values to the forecasts (join keys are the shared columns)
  with_truth <- dplyr::left_join(.forecasts, .test)

  ## keep only what is needed for scoring
  to_score <- dplyr::select(with_truth, epiweek, epiyear, quantile, value, flu.admits)

  ## one score per forecasted week
  by_week <- dplyr::group_by(to_score, epiweek, epiyear)

  dplyr::summarise(
    by_week,
    wis = evalcast::weighted_interval_score(quantile = quantile, value = value, actual_value = flu.admits)
  )
}

#' Plot forecasts against observed data
#'
#' This helper function draws the observed flu admissions (training and held out data combined) as a black line, and overlays the forecasted point estimates as a red line with the 95% prediction interval (0.025 and 0.975 quantiles) as a shaded ribbon. The plot is faceted by location.
#'
#' @param .forecasts Tibble with prepped forecasts; the point estimate is expected to have `NA` in the "quantile" column
#' @param .train Tibble with data used for modeling
#' @param .test Tibble with observed data held out from modeling
#'
#' @return A `ggplot2` plot object
#' @export
#'
#'
plot_forc <- function(.forecasts, .train, .test) {

  ## point estimate (NA quantile) and 95% interval bounds, one column each
  forc_dat <- dplyr::filter(.forecasts, quantile %in% c(NA, 0.025, 0.975))
  forc_dat <- tidyr::spread(forc_dat, quantile, value)
  forc_dat <- dplyr::rename(forc_dat, lower = `0.025`, upper = `0.975`, mean = `<NA>`)

  ## observed data (held out and training) with the forecasts attached
  observed <- dplyr::bind_rows(.test, .train)
  observed <- dplyr::select(observed, epiweek, epiyear, truth = flu.admits, location)
  plot_dat <- dplyr::left_join(observed, forc_dat)
  plot_dat <- dplyr::mutate(plot_dat, date = MMWRweek::MMWRweek2Date(epiyear, epiweek))

  ## upper limit of the y axis is the max of observed or forecasted hospitalizations
  y_upper <- max(c(.test$flu.admits, .train$flu.admits, forc_dat$upper))

  ggplot2::ggplot(plot_dat) +
    ggplot2::geom_line(ggplot2::aes(date, truth), lwd = 2, col = "black") +
    ggplot2::geom_line(ggplot2::aes(date, mean), lwd = 2, alpha = 0.5, lty = "solid", col = "firebrick") +
    ggplot2::geom_ribbon(ggplot2::aes(date, ymin = lower, ymax = upper), alpha = 0.25, fill = "firebrick") +
    ggplot2::scale_y_continuous(limits = c(0, y_upper)) +
    ggplot2::scale_x_date(date_labels = "%Y-%m", date_breaks = "month") +
    ggplot2::labs(x = "Date", y = "Count", title = "Influenza hospitalizations") +
    ggplot2::theme_minimal() +
    ggplot2::facet_wrap(~ location)

}


25 changes: 25 additions & 0 deletions man/glm_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions man/glm_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading