Skip to content

Commit

Permalink
Merge pull request #293 from reichlab/issue-291
Browse files Browse the repository at this point in the history
make load_forecasts() results locale-independent
  • Loading branch information
Serena-Wang authored Aug 13, 2021
2 parents 0d7954d + f313d33 commit 793c460
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 14 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## Changes since last release
- sort `models` parameter in `load_forecasts()` and `load_latest_forecasts()` so that the resulting data frame is
locale-independent
- add `hub` parameter in `get_all_models()`. It does not support loading model names for ECDC hub from remote hub repo for now.

## covidHubUtils 0.1.6

This is a release focusing on new features in most of the major functions.
Expand Down
32 changes: 19 additions & 13 deletions R/get_all_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
#'
#' @param source string specifying where to get all valid model names
#' Currently support `"local_hub_repo"`, `"remote_hub_repo"` and `"zoltar"`.
#' @param hub character vector, where the first element indicates the hub
#' from which to load forecasts. Possible options are "US" and "ECDC"
#' @param hub_repo_path path to local clone of the `reichlab/covid19-forecast-hub`
#' repository
#' @importFrom httr GET stop_for_status content
#' @return a list of valid model names
#'
#' @export
get_all_models <- function(source = "zoltar", hub_repo_path) {

get_all_models <- function(source = "zoltar", hub = c("US", "ECDC"), hub_repo_path) {
# validate source
source <- match.arg(source,
choices = c("remote_hub_repo", "local_hub_repo", "zoltar"),
Expand All @@ -26,10 +27,16 @@ get_all_models <- function(source = "zoltar", hub_repo_path) {
stop("Error in get_all_models: data-processed subdirectory does not exist.")
}
models <- list.dirs(data_processed, full.names = FALSE)
models <- models[nchar(models) > 0]
models <- unique(models[nchar(models) > 0])
models <- sort(models, method = "radix")
} else if (source == "remote_hub_repo") {
if (hub[1] == "US") {
req_url <- "https://api.github.com/repos/reichlab/covid19-forecast-hub/git/trees/master?recursive=1"
} else if (hub[1] == "ECDC") {
stop("Error in get_all_models: loading all model names for ECDC hub from remote hub repo is not supported now.")
}
# set up remote hub repo request
req <- httr::GET("https://api.github.com/repos/reichlab/covid19-forecast-hub/git/trees/master?recursive=1")
req <- httr::GET(req_url)
httr::stop_for_status(req)

# get all files in data-processed/ from tree structure
Expand All @@ -45,21 +52,20 @@ get_all_models <- function(source = "zoltar", hub_repo_path) {
models <- sapply(folders, function(filename) {
unlist(strsplit(filename, "/"))[2]
})
models <- sort(unique(models), method = "radix")
} else if (source == "zoltar") {
# set up Zoltar connection
zoltar_connection <- zoltr::new_connection()
if (Sys.getenv("Z_USERNAME") == "" | Sys.getenv("Z_PASSWORD") == "") {
zoltr::zoltar_authenticate(zoltar_connection, "zoltar_demo", "Dq65&aP0nIlG")
} else {
zoltr::zoltar_authenticate(zoltar_connection, Sys.getenv("Z_USERNAME"), Sys.getenv("Z_PASSWORD"))
}
zoltar_connection <- setup_zoltar_connection(staging = FALSE)

# construct Zoltar project url
the_projects <- zoltr::projects(zoltar_connection)
project_url <- the_projects[the_projects$name == "COVID-19 Forecasts", "url"]
project_url <- get_zoltar_project_url(
hub = hub,
zoltar_connection = zoltar_connection
)

models <- zoltr::models(zoltar_connection, project_url)$model_abbr
models <- sort(unique(models), method = "radix")
}

return(unique(models))
return(models)
}
3 changes: 3 additions & 0 deletions R/load_forecasts_repo.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ load_forecasts_repo <- function(
# validate models
all_valid_models <- list.dirs(file_path, full.names = FALSE)
all_valid_models <- all_valid_models[nchar(all_valid_models) > 0]

if (!is.null(models)) {
models <- unlist(purrr::map(models, function(model) {
match.arg(model, choices = all_valid_models)
}))
} else {
models <- all_valid_models
}

models <- sort(models, method = "radix")

# get valid location codes
if (hub[1] == "US") {
Expand Down
1 change: 1 addition & 0 deletions R/load_forecasts_zoltar.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ load_forecasts_zoltar <- function(models = NULL,

if (is.null(models)){
models <- all_models$model_abbr
models <- sort(models, method = "radix")
}

# set 2 workers
Expand Down
3 changes: 3 additions & 0 deletions R/load_latest_forecasts_repo.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ load_latest_forecasts_repo <- function(file_path,
# validate models
all_valid_models <- list.dirs(file_path, full.names = FALSE)
all_valid_models <- all_valid_models[nchar(all_valid_models) > 0]

if (!is.null(models)) {
models <- unlist(purrr::map(models, function(model) {
match.arg(model, choices = all_valid_models)
}))
} else {
models <- all_valid_models
}

models <- sort(models, method = "radix")

# get valid location codes
if (hub[1] == "US") {
Expand Down
5 changes: 4 additions & 1 deletion man/get_all_models.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 793c460

Please sign in to comment.