-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.R
220 lines (201 loc) · 8.59 KB
/
utils.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
#' Make `tsibble`
#'
#' @description
#'
#' This function converts an input `tibble` with columns for \link[lubridate]{epiyear} and \link[lubridate]{epiweek} into a \link[tsibble]{tsibble} object keyed by `location`. The `tsibble` has columns specifying indices for the time series as well as a date for the Monday of the epiyear/epiweek combination at each row. Users can optionally drop the most recent week when generating the `tsibble` via the `chop` argument.
#'
#' @param df A `tibble` containing columns `location`, `epiyear`, and `epiweek`.
#' @param chop Logical indicating whether or not to remove the most recent week (default `TRUE`). The most recent week is the latest `yweek` present in `df`, and it is removed for every location.
#' @return A `tsibble` keyed by `location` containing additional columns `monday` indicating the date
#' for the Monday of that epiweek, and `yweek` (a yearweek vctr class object)
#' that indexes the `tsibble` in 1 week increments.
#' @export
#' @md
make_tsibble <- function(df, chop = TRUE) {
  out <- df %>%
    # MMWRday = 2 gives the Monday of each MMWR epiyear/epiweek
    dplyr::mutate(
      monday = MMWRweek::MMWRweek2Date(MMWRyear = epiyear, MMWRweek = epiweek, MMWRday = 2),
      .after = "epiweek"
    ) %>%
    # represent the week as a yearweek (see package?tsibble)
    dplyr::mutate(yweek = tsibble::yearweek(monday), .after = "monday") %>%
    tsibble::as_tsibble(index = yweek, key = location)

  if (chop && nrow(out) > 0) {
    # Drop the latest week for *every* location. `head(out, -1)` would only
    # drop the final row, i.e. the last week of the last location key.
    out <- dplyr::filter(out, yweek < max(yweek))
  }
  out
}
#' Get Monday
#'
#' @description
#'
#' Helper returning the date of the Monday that starts the current week
#' (weeks are taken to begin on Monday).
#'
#' @return A `Date` for the Monday of the current week. For more details see \link[lubridate]{floor_date}.
#' @export
#' @md
#'
this_monday <- function() {
  current_day <- lubridate::today()
  # week_start = 1 rounds down to Monday rather than Sunday
  lubridate::floor_date(x = current_day, unit = "weeks", week_start = 1)
}
#' Check Monday
#'
#' @description
#'
#' This is a helper function to see if today is Monday. The check uses the
#' numeric day of the week, so it does not depend on the session's locale.
#'
#' @return Logical indicating whether or not today is Monday
#' @export
#' @md
is_monday <- function() {
  # wday(label = TRUE) returns locale-dependent abbreviations ("Mon" only in
  # English locales); with week_start = 1, Monday is always day 1.
  lubridate::wday(lubridate::today(), week_start = 1) == 1
}
#' Visualize forecast output
#'
#' @description
#'
#' This function serves as a plotting mechanism for prepped forecast submission data (see \link[focustools]{format_for_submission}). Using the truth data supplied, the plots show the historical trajectory of each outcome along with the point estimates for forecasts. Optionally, the user can include the 50% prediction interval as well. By default, plots include trajectories of incident cases, incident deaths, and cumulative deaths faceted by location.
#'
#'
#' @param .data Historical truth data for all locations and outcomes in submission targets. Must contain `epiyear`, `epiweek`, `monday`, `yweek`, and `location`; every other column is treated as an outcome.
#' @param submission Formatted submission with columns `location`, `target`, `target_end_date`, `type`, `quantile`, and `value`.
#' @param location Vector specifying locations to filter to; `'US'` by default.
#' @param target Vector specifying target(s) to plot; default is `c('Incident Cases','Incident Deaths','Cumulative Deaths')`
#' @param pi Logical as to whether or not the plot should include 50% prediction interval; default is `TRUE`
#'
#' @return A `ggplot2` plot object with line plots for outcome trajectories faceted by location
#'
#' @md
#' @export
#'
plot_forecast <- function(.data, submission, target = c('Incident Cases', 'Incident Deaths', 'Cumulative Deaths'), location = "US", pi = TRUE) {
  # Copy the arguments into differently named locals: inside dplyr verbs a bare
  # `location` / `target` resolves to the data column of that name (data
  # masking), not to the function argument.
  loc <- location
  tmp_target <- target

  # Check that the specified location is in the data and submission.
  stopifnot("Specified location is not in recorded data" = loc %in% unique(.data$location))
  stopifnot("Specified location is not in forecast data" = loc %in% unique(submission$location))

  # Recorded data, long format: outcome columns such as "icases" become
  # targets such as "inc case".
  real <-
    .data %>%
    tibble::as_tibble() %>%
    dplyr::filter(location %in% loc) %>%
    tidyr::gather(target, value, -epiyear, -epiweek, -monday, -yweek, -location) %>%
    dplyr::mutate(target = target %>% stringr::str_remove_all("s$") %>% stringr::str_replace_all(c("^i" = "inc ", "^c" = "cum "))) %>%
    dplyr::select(location, date = monday, target, point = value) %>%
    dplyr::mutate(type = "recorded")

  # Forecast data: keep the point estimate and the 25%/75% quantiles, one
  # column each ("point", "0.25", "0.75").
  forecasted <-
    submission %>%
    dplyr::filter(type == "point" | quantile == .25 | quantile == .75) %>%
    dplyr::filter(location %in% loc) %>%
    # Convert to character before labelling point rows: inserting the string
    # "point" into a numeric column fails with current tidyr::replace_na().
    dplyr::mutate(quantile = dplyr::if_else(is.na(quantile), "point", as.character(quantile))) %>%
    dplyr::select(-type) %>%
    tidyr::separate(target, into = c("nwk", "target"), sep = " wk ahead ") %>%
    dplyr::select(location, date = target_end_date, target, quantile, value) %>%
    tidyr::spread(quantile, value) %>%
    dplyr::mutate(type = "forecast")

  # Bind, label targets, swap location codes for names, and keep only the
  # requested targets. Both plot layers use this filtered data, so the ribbon
  # cannot add facets for targets the caller did not ask for.
  bound <-
    dplyr::bind_rows(real, forecasted) %>%
    dplyr::arrange(date, location) %>%
    dplyr::filter(location %in% loc) %>%
    dplyr::mutate(target =
                    dplyr::case_when(target == "inc case" ~ "Incident Cases",
                                     target == "inc death" ~ "Incident Deaths",
                                     target == "cum case" ~ "Cumulative Cases",
                                     target == "cum death" ~ "Cumulative Deaths")
    ) %>%
    dplyr::left_join(dplyr::select(locations, location, location_name), by = "location") %>%
    dplyr::select(-location) %>%
    dplyr::rename(location = location_name) %>%
    dplyr::filter(target %in% tmp_target)

  p <-
    ggplot2::ggplot(bound, ggplot2::aes(date, point)) +
    ggplot2::geom_line(ggplot2::aes(col = type)) +
    ggplot2::scale_y_continuous(labels = scales::number_format(big.mark = ",")) +
    ggplot2::facet_wrap(~location + target, scales = "free", ncol = 3) +
    ggplot2::theme_bw() +
    ggplot2::labs(x = "Date", y = NULL) +
    # legend.position values are lower case ("bottom", not "Bottom")
    ggplot2::theme(legend.position = "bottom", legend.title = ggplot2::element_blank())

  if (pi) {
    p <-
      p +
      ggplot2::geom_ribbon(ggplot2::aes(fill = type, ymin = `0.25`, ymax = `0.75`),
                           alpha = 0.5, color = "lightpink",
                           data = dplyr::filter(bound, type == "forecast"))
  }
  p
}
#' Reshape data for submission summary
#'
#' @description
#'
#' This unexported helper function is used in \link[focustools]{submission_summary}. It spreads forecast targets to a wide format, replaces location codes with location names (looked up in `locations`), and forces the "US" location to be at the top of the result.
#'
#' @param .data Tibble with submission data; must contain a `location` column of location codes.
#' @param ... Additional arguments passed to \link[tidyr]{spread}
#'
#' @return A wide data frame whose first column is `location` (the location name), followed by `Previous` when present, then the remaining spread columns. Rows for "US" come first.
#'
#' @md
#'
spread_value <- function(.data, ...) {
  ## spread the data, then swap location codes for location names
  ## (explicit `by` makes the join deterministic and silent)
  tmp <-
    tidyr::spread(.data, ...) %>%
    dplyr::left_join(dplyr::select(locations, location, location_name), by = "location") %>%
    dplyr::select(-location)

  ## put a "Previous" column before the wk ahead columns if need be
  if ("Previous" %in% names(tmp)) {
    tmp <- dplyr::select(tmp, location = location_name, Previous, dplyr::everything())
  } else {
    tmp <- dplyr::select(tmp, location = location_name, dplyr::everything())
  }

  ## if US is in there put it on top; %in% (unlike ==/!=) never yields NA,
  ## so rows whose location name is missing are kept rather than dropped
  if ("US" %in% tmp$location) {
    tmp <- dplyr::bind_rows(
      dplyr::filter(tmp, location %in% "US"),
      dplyr::filter(tmp, !(location %in% "US"))
    )
  }

  ## return explicitly: previously the trailing `if` returned NULL when no US row existed
  tmp
}
#' Extract ARIMA parameters
#'
#' @description
#'
#' Extracts ARIMA model parameters, including p, d, q, P, D, Q, and results from \link[broom]{tidy} and \link[broom]{glance} on an ARIMA model object.
#'
#' @param arimafit A single-row mable (`mdl_df`) from `fabletools::model(arima=ARIMA(...))`. Its first column must be `location`, and its second column must hold the ARIMA model.
#'
#' @return A `tibble` containing ARIMA model parameter and diagnostic information, preceded by `location`, `abbreviation`, and `location_name` from `locations`. List columns are dropped; only atomic columns are kept.
#' NOTE(review): `fabletools::tidy()` may return one row per estimated coefficient, in which case the result has more than one row — confirm.
#' @md
#' @export
extract_arima_params <- function(arimafit) {
  if (!inherits(arimafit, "mdl_df")) stop("Input must be a mdl_df (mable) from fabletools::model().")
  if (nrow(arimafit) > 1) stop("Input mdl_df must have only one row (one location).")
  if (nrow(arimafit) == 0) stop("Input mdl_df must have exactly one row (one location); it has none.")
  if (names(arimafit)[1] != "location") stop("Input mdl_df must have location.")
  # The fitted model object sits in the first (only) cell of the model column.
  # inherits() copes with multi-class objects, where `class(x) != ...` would
  # give a length > 1 condition (an error since R 4.2).
  fit <- arimafit[[2]][[1]]$fit
  if (!inherits(fit, "ARIMA")) stop("Model must be ARIMA.")

  .tidy <- fabletools::tidy(arimafit)
  .glance <- fabletools::glance(arimafit)
  .broom <- dplyr::inner_join(.tidy, .glance, by = c("location", ".model"))
  # model order specification (p, d, q, P, D, Q, ...)
  .params <- fit$spec

  dplyr::bind_cols(.params, .broom) %>%
    dplyr::select(location, .model, dplyr::everything()) %>%
    # drop list columns so the result is a flat table
    dplyr::select_if(is.atomic) %>%
    dplyr::inner_join(locations %>% dplyr::select(location, abbreviation, location_name), ., by = "location")
}