Skip to content

Commit

Permalink
make rollout_iterations and lead_time_hours params
Browse files Browse the repository at this point in the history
  • Loading branch information
14renus committed Jan 28, 2025
1 parent df3efe5 commit 3fd8093
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 8 deletions.
7 changes: 4 additions & 3 deletions geoarches/configs/dataloader/era5.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
dataset:
_target_: geoarches.dataloaders.era5.Era5Forecast
path: data/era5_240/full/
lead_time_hours: 24 # mixed
lead_time_hours: 24
multistep: ${oc.select:module.rollout_iterations,1}
norm_scheme: pangu
load_prev: True

validation_args:
multistep: 2
multistep: ${oc.select:module.validation.rollout_iterations,1}

test_args:
multistep: 2
multistep: ${module.inference.rollout_iterations}
5 changes: 5 additions & 0 deletions geoarches/configs/module/archesweather.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ module:
use_graphcast_coeffs: True
use_prev: True
loss_delta_normalization: True
lead_time_hours: ${dataloader.dataset.lead_time_hours}
rollout_iterations: 1

inference:
rollout_iterations: 2 # number of rollouts/multistep

backbone:
# default backbone
Expand Down
8 changes: 8 additions & 0 deletions geoarches/configs/module/archesweathergen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ module:
num_warmup_steps: 5000
num_training_steps: ${max_steps}

# for logging metrics
lead_time_hours: ${dataloader.dataset.lead_time_hours}
rollout_iterations: 1 # For training, number of rollouts

validation:
num_members: 5
rollout_iterations: 2 # number of rollouts

inference:
num_steps: 25
num_members: 10
Expand Down
11 changes: 7 additions & 4 deletions geoarches/lightning_modules/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
num_cycles=0.5,
learn_residual=False,
sd3_timestep_sampling=True,
lead_time_hours=24,
**kwargs,
):
"""
Expand Down Expand Up @@ -93,9 +94,11 @@ def __init__(

# set up metrics
save_memory = cfg.inference.num_members > 10
val_kwargs = dict(lead_time_hours=24, rollout_iterations=1, save_memory=save_memory)
val_kwargs = dict(
lead_time_hours=lead_time_hours, rollout_iterations=1, save_memory=save_memory
)
test_kwargs = dict(
lead_time_hours=24,
lead_time_hours=lead_time_hours,
rollout_iterations=cfg.inference.rollout_iterations,
save_memory=save_memory,
)
Expand Down Expand Up @@ -380,8 +383,8 @@ def sample_rollout(

def validation_step(self, batch, batch_nb):
# for the validation, we make some generations and log them
val_num_members = 5
val_rollout_iterations = 2
val_num_members = self.cfg.validation.num_steps
val_rollout_iterations = self.cfg.validation.rollout_iterations
samples = [
self.sample_rollout(
batch,
Expand Down
3 changes: 2 additions & 1 deletion geoarches/lightning_modules/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
add_input_state=False,
save_test_outputs=False,
use_weatherbench_lat_coeffs=True,
lead_time_hours=24,
rollout_iterations=1,
test_filename_suffix="",
**kwargs,
Expand Down Expand Up @@ -96,7 +97,7 @@ def __init__(
compute_lat_weights_fn=compute_lat_weights_weatherbench
if use_weatherbench_lat_coeffs
else compute_lat_weights,
lead_time_hours=24,
lead_time_hours=lead_time_hours,
rollout_iterations=rollout_iterations,
)

Expand Down

0 comments on commit 3fd8093

Please sign in to comment.