Issue 105: Generate test forecasts and scoring output #148

Merged Nov 21, 2024 (25 commits)
25 commits
f1de36c
ignore *.pickle
SamuelBrand1 Nov 12, 2024
d40c81c
Create README.md
SamuelBrand1 Nov 12, 2024
59e138c
Create data_for_model_fit.json
SamuelBrand1 Nov 12, 2024
d0b1b9d
Merge branch 'main' into add-demo-data
SamuelBrand1 Nov 12, 2024
abe1aff
catch bad default n-chains
SamuelBrand1 Nov 12, 2024
eeb09ae
Merge branch 'main' into add-demo-data
SamuelBrand1 Nov 13, 2024
c18807d
reorg dir structure
SamuelBrand1 Nov 13, 2024
ffa6809
add priors.py to the test folder
SamuelBrand1 Nov 18, 2024
3b88d5c
Merge branch 'main' into add-demo-data
SamuelBrand1 Nov 19, 2024
d9cf985
Merge branch 'main' into add-demo-data
SamuelBrand1 Nov 19, 2024
0646a73
Update priors.py
SamuelBrand1 Nov 19, 2024
817f6c4
rename test folder to match name pattern
SamuelBrand1 Nov 19, 2024
e2bf19f
add inference data -> parquet step
SamuelBrand1 Nov 19, 2024
7801236
Delete inference_data.nc
SamuelBrand1 Nov 19, 2024
a6f4ba5
add test dir data
SamuelBrand1 Nov 19, 2024
3b0ee90
catch wrong format
SamuelBrand1 Nov 19, 2024
e69c167
reformat eval_data.tsv
SamuelBrand1 Nov 20, 2024
d106910
full test run bash script
SamuelBrand1 Nov 20, 2024
7644efa
reformat to long format
SamuelBrand1 Nov 20, 2024
6c9222f
change to common R dep pattern
SamuelBrand1 Nov 20, 2024
f0be95a
Merge branch 'main' into add-scoring-data
SamuelBrand1 Nov 20, 2024
59804bb
rename dirs
SamuelBrand1 Nov 20, 2024
a1d576a
Merge branch 'main' into add-scoring-data
SamuelBrand1 Nov 21, 2024
aae759b
reorg test folders
SamuelBrand1 Nov 21, 2024
e4076a3
update README for test directories
SamuelBrand1 Nov 21, 2024
6 changes: 6 additions & 0 deletions .gitignore
@@ -10,6 +10,8 @@
 *.xls
 *.xlsx
 *.rds
+*.pickle
+*.nc

 # Documents
 *.doc
@@ -395,3 +397,7 @@ notebooks/*.md
 private_data/*
 *_files/
 .vscode/settings.json
+
+# Test data exceptions to the general data exclusion
+!pipelines/tests/covid-19_r_run_test_inference/TestDir/data.csv
+!pipelines/tests/covid-19_r_run_test_inference/TestDir/eval_data.tsv
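
The `!` entries work because, for gitignore rules, the last matching pattern decides the outcome, so a later negation can re-include a file that an earlier pattern excluded. The sketch below is a rough illustration of that rule only, not Git's actual matcher. It assumes the general data exclusion includes `*.csv` and `*.tsv` patterns (these are not visible in the hunks above), and `is_ignored` and `notebooks/other.csv` are hypothetical names used for illustration.

```python
from fnmatch import fnmatchcase
from pathlib import PurePosixPath


def is_ignored(path: str, patterns: list[str]) -> bool:
    """Simplified gitignore check: the last matching pattern decides."""
    ignored = False
    for pattern in patterns:
        negated = pattern.startswith("!")
        body = pattern[1:] if negated else pattern
        # Simplification: a pattern without "/" is compared to the file name,
        # a pattern with "/" is compared to the whole repo-relative path.
        target = path if "/" in body else PurePosixPath(path).name
        if fnmatchcase(target, body):
            ignored = not negated
    return ignored


# ASSUMPTION: "*.csv" and "*.tsv" stand in for the general data exclusion,
# which is not shown in the diff.
PATTERNS = [
    "*.csv",
    "*.tsv",
    "!pipelines/tests/covid-19_r_run_test_inference/TestDir/data.csv",
    "!pipelines/tests/covid-19_r_run_test_inference/TestDir/eval_data.tsv",
]

TEST_FILE = "pipelines/tests/covid-19_r_run_test_inference/TestDir/data.csv"
OTHER_FILE = "notebooks/other.csv"  # hypothetical path

print(TEST_FILE, "ignored:", is_ignored(TEST_FILE, PATTERNS))
print(OTHER_FILE, "ignored:", is_ignored(OTHER_FILE, PATTERNS))
```

Real Git has further rules that this sketch skips, notably that a file cannot be re-included if one of its parent directories is itself excluded.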
29 changes: 21 additions & 8 deletions pipelines/convert_inferencedata_to_parquet.R
@@ -1,11 +1,24 @@
-library(forecasttools)
-library(readr)
-library(arrow)
-library(fs)
-library(argparser)
-library(dplyr)
-library(stringr)
-library(tidyr)
+script_packages <- c(
+  "argparser",
+  "arrow",
+  "dplyr",
+  "forecasttools",
+  "fs",
+  "ggplot2",
+  "lubridate",
+  "readr",
+  "scoringutils",
+  "stringr",
+  "tidyr"
+)
+
+## load in packages without messages
+purrr::walk(script_packages, \(pkg) {
+  suppressPackageStartupMessages(
+    library(pkg, character.only = TRUE)
+  )
+})
+

 tidy_and_save_mcmc <- function(model_run_dir,
                                file_name_prefix = "",
1 change: 1 addition & 0 deletions pipelines/iteration_helpers/loop_postprocess.sh
@@ -13,5 +13,6 @@ BASE_DIR="$1"
 for SUBDIR in "$BASE_DIR"/*/; do
     # Run the R script with the current subdirectory as the model_dir argument
     echo "$SUBDIR"
+    Rscript convert_inferencedata_to_parquet.R "$SUBDIR"
     Rscript postprocess_state_forecast.R "$SUBDIR"
 done
10 changes: 10 additions & 0 deletions pipelines/tests/README.md
@@ -0,0 +1,10 @@
# Test data folder

This folder contains test data and test-mode scripts for validating the inference
pipeline. The test data is stored in subdirectories.

To run the test scripts, execute the following command from the `pipelines` directory:

```bash
bash ./tests/test_run.sh ./tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs 1000 28
```

Large diffs are not rendered by default.

@@ -0,0 +1,68 @@
import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

i0_first_obs_n_rv = DistributionalVariable(
    "i0_first_obs_n_rv",
    dist.Beta(1, 10),
)

initialization_rate_rv = DistributionalVariable(
    "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0)
)

r_logmean = jnp.log(1)
r_logsd = jnp.log(jnp.sqrt(2))

log_r_mu_intercept_rv = DistributionalVariable(
    "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)

eta_sd_rv = DistributionalVariable(
    "eta_sd", dist.TruncatedNormal(0.04, 0.02, low=0)
)

autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40))


inf_feedback_strength_rv = TransformedVariable(
    "inf_feedback",
    DistributionalVariable(
        "inf_feedback_raw",
        dist.LogNormal(jnp.log(50), jnp.log(2)),
    ),
    transforms=transformation.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?

p_ed_visit_mean_rv = DistributionalVariable(
    "p_ed_visit_mean",
    dist.Normal(
        transformation.SigmoidTransform().inv(0.005),
        0.3,
    ),
)  # logit scale


p_ed_visit_w_sd_rv = DistributionalVariable(
    "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)
)


autoreg_p_ed_visit_rv = DistributionalVariable(
    "autoreg_p_ed_visit_rv", dist.Beta(1, 100)
)

ed_visit_wday_effect_rv = TransformedVariable(
    "ed_visit_wday_effect",
    DistributionalVariable(
        "ed_visit_wday_effect_raw",
        dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])),
    ),
    transformation.AffineTransform(loc=0, scale=7),
)

# Based on looking at some historical posteriors.
phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1))
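
Two of the priors above are specified on a transformed scale (log for the reproduction-number intercept, logit for the ED visit probability), which makes their practical meaning harder to read off. The sketch below back-transforms a central interval to the natural scale using only the Python standard library. It does not use numpyro or pyrenew, the 95% level is an arbitrary illustrative choice, and `central_interval`, `sigmoid`, and `logit` are helper names introduced here rather than anything defined in the PR.

```python
import math
from statistics import NormalDist


def central_interval(mean: float, sd: float, mass: float = 0.95) -> tuple[float, float]:
    """Central interval of a Normal(mean, sd) holding the given probability mass."""
    tail = (1.0 - mass) / 2.0
    normal = NormalDist(mean, sd)
    return normal.inv_cdf(tail), normal.inv_cdf(1.0 - tail)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def logit(p: float) -> float:
    return math.log(p / (1.0 - p))


# log_r_mu_intercept_rv: Normal(log(1), log(sqrt(2))) on the log scale.
log_lo, log_hi = central_interval(math.log(1), math.log(math.sqrt(2)))
rt_interval = (math.exp(log_lo), math.exp(log_hi))

# p_ed_visit_mean_rv: Normal(logit(0.005), 0.3) on the logit scale.
logit_lo, logit_hi = central_interval(logit(0.005), 0.3)
p_ed_interval = (sigmoid(logit_lo), sigmoid(logit_hi))

print("Rt intercept, 95% prior interval:", rt_interval)
print("P(ED visit) mean, 95% prior interval:", p_ed_interval)
```

Under these assumptions the reproduction-number intercept is centred on 1 with a prior interval that is symmetric on the multiplicative scale, and the ED visit probability stays within a fraction of a percent of 0.005.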
36 changes: 36 additions & 0 deletions pipelines/tests/test_run.sh
@@ -0,0 +1,36 @@
#!/bin/bash

# Check that the base directory, sample count, and forecast horizon are provided
if [ "$#" -lt 3 ]; then
    echo "Usage: $0 <base_dir> <n_samples> <n_ahead>"
    exit 1
fi

# Base directory containing subdirectories
BASE_DIR="$1"
N_SAMPLES="$2"
N_AHEAD="$3"

# Iterate over each subdirectory in the base directory
echo "TEST-MODE: Running loop over subdirectories in $BASE_DIR"
echo "For $N_SAMPLES samples on 1 chain, and $N_AHEAD forecast points"
for SUBDIR in "$BASE_DIR"/*/; do
    echo "TEST-MODE: Inference for $SUBDIR"
    python fit_model.py "$SUBDIR" --n-chains 1 --n-samples "$N_SAMPLES"
    echo "TEST-MODE: Finished inference"
    echo "TEST-MODE: Generating posterior predictions for $SUBDIR"
    python generate_predictive.py "$SUBDIR" --n-forecast-points "$N_AHEAD"
    echo "TEST-MODE: Finished generating posterior predictions"
    echo "TEST-MODE: Converting inferencedata to parquet for $SUBDIR"
    Rscript convert_inferencedata_to_parquet.R "$SUBDIR"
    echo "TEST-MODE: Finished converting inferencedata to parquet"
    echo "TEST-MODE: Forecasting baseline models for $SUBDIR"
    Rscript timeseries_forecasts.R "$SUBDIR" --n-forecast-days "$N_AHEAD" --n-samples "$N_SAMPLES"
    echo "TEST-MODE: Finished forecasting baseline models"
    echo "TEST-MODE: Postprocessing state forecast for $SUBDIR"
    Rscript postprocess_state_forecast.R "$SUBDIR"
    echo "TEST-MODE: Finished postprocessing state forecast"
    echo "TEST-MODE: Scoring forecast for $SUBDIR"
    Rscript score_forecast.R "$SUBDIR"
    echo "TEST-MODE: Finished scoring forecast"
done
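
For reviewers who want to see the order of steps at a glance, the six stages that `test_run.sh` runs per subdirectory can be sketched as a small Python driver. This is not part of the PR: `build_test_commands` and `run_test_mode` are hypothetical helpers, and the script names and flags are copied from the shell script above rather than checked against each script's argument parser.

```python
import subprocess
from pathlib import Path


def build_test_commands(subdir: str, n_samples: int, n_ahead: int) -> list[list[str]]:
    """Return the per-subdirectory commands in the order test_run.sh runs them."""
    return [
        ["python", "fit_model.py", subdir,
         "--n-chains", "1", "--n-samples", str(n_samples)],
        ["python", "generate_predictive.py", subdir,
         "--n-forecast-points", str(n_ahead)],
        ["Rscript", "convert_inferencedata_to_parquet.R", subdir],
        ["Rscript", "timeseries_forecasts.R", subdir,
         "--n-forecast-days", str(n_ahead), "--n-samples", str(n_samples)],
        ["Rscript", "postprocess_state_forecast.R", subdir],
        ["Rscript", "score_forecast.R", subdir],
    ]


def run_test_mode(base_dir: str, n_samples: int, n_ahead: int,
                  dry_run: bool = True) -> list[list[str]]:
    """Collect (and, unless dry_run, execute) the commands for every subdirectory."""
    commands = []
    subdirs = sorted(p for p in Path(base_dir).iterdir() if p.is_dir())
    for subdir in subdirs:
        for command in build_test_commands(str(subdir), n_samples, n_ahead):
            commands.append(command)
            if not dry_run:
                # check=True stops at the first failing step.
                subprocess.run(command, check=True)
    return commands


# Hypothetical subdirectory name, used only to show the generated commands.
example_commands = build_test_commands("model_runs/CA", n_samples=1000, n_ahead=28)
for command in example_commands:
    print(" ".join(command))
```

One design difference worth noting: with `check=True` the Python sketch aborts on the first failing step, whereas the shell script as written continues to the next command.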
2 changes: 1 addition & 1 deletion pipelines/timeseries_forecasts.R
@@ -230,7 +230,7 @@ p <- arg_parser(
   "Forecast other (non-target-disease) ED visits for a given location."
 ) |>
   add_argument(
-    "model-run-dir",
+    "model_run_dir",
     help = "Directory containing the model data and output.",
   ) |>
   add_argument(