Skip to content

Commit

Permalink
(0.91.5) Tweak initialization procedure so that callback schedules wo…
Browse files Browse the repository at this point in the history
…rk with checkpointed simulations (#3660)

* Add a test for initialization of callback schedules

* initialize schedules separately from callbacks

* New simulations test for initialiation

* Bump to 0.91.5

* Fix initialization bug

* Fix checkpointer test

* Update test/test_simulations.jl

* code cleanup

* Update test_checkpointer.jl

---------

Co-authored-by: Navid C. Constantinou <navidcy@users.noreply.github.com>
  • Loading branch information
glwagner and navidcy authored Aug 5, 2024
1 parent feaa386 commit bf767af
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Oceananigans"
uuid = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"
authors = ["Climate Modeling Alliance and contributors"]
version = "0.91.4"
version = "0.91.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
13 changes: 7 additions & 6 deletions src/Simulations/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ end
@inline (callback::Callback)(sim) = callback.func(sim, callback.parameters)
@inline (callback::Callback{<:Nothing})(sim) = callback.func(sim)

# Fallback initialization: initialize the schedule.
# Then, if the schedule calls for it, execute the callback.
function initialize!(callback::Callback, sim)
initialize!(callback.schedule, sim.model) && callback(sim)
return nothing
end
"""
initialize!(callback::Callback, sim)
Initialize `callback`. By default, this does nothing, but
can be optionally specialized on the type parameters of `Callback`.
"""
initialize!(callback::Callback, sim) = nothing

"""
Callback(func, schedule=IterationInterval(1);
Expand Down
12 changes: 10 additions & 2 deletions src/Simulations/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,18 +193,26 @@ function initialize!(sim::Simulation)
# Output and diagnostics initialization
[add_dependencies!(sim.diagnostics, writer) for writer in values(sim.output_writers)]

# Initialize schedules
scheduled_activities = Iterators.flatten((values(sim.diagnostics),
values(sim.callbacks),
values(sim.output_writers)))

for activity in scheduled_activities
initialize!(activity.schedule, sim.model)
end

# Reset! the model time-stepper, evaluate all diagnostics, and write all output at first iteration
if clock.iteration == 0
reset!(timestepper(sim.model))

# Initialize schedules and run diagnostics, callbacks, and output writers
for diag in values(sim.diagnostics)
diag.schedule(sim.model)
run_diagnostic!(diag, model)
end

for callback in values(sim.callbacks)
callback.callsite isa TimeStepCallsite && initialize!(callback, sim)
callback.callsite isa TimeStepCallsite && callback(sim)
end

for writer in values(sim.output_writers)
Expand Down
19 changes: 10 additions & 9 deletions src/Utils/schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ schedule_aligned_time_step(schedule, clock, Δt) = Δt
function initialize!(schedule::AbstractSchedule, model)
schedule(model)

# `return true` indicates that the schedule
# "actuates" at initial call.
# the default behavior `return true` dictates that by default,
# schedules actuate at the initial call.
return true
end

Expand Down Expand Up @@ -47,17 +47,19 @@ on a `interval` of simulation time, as kept by `model.clock`.
"""
TimeInterval(interval) = TimeInterval(convert(Float64, interval), 0.0, 0)

function initialize!(schedule::TimeInterval, model)
schedule.first_actuation_time = model.clock.time
schedule(model)
function initialize!(schedule::TimeInterval, first_actuation_time::Number)
schedule.first_actuation_time = first_actuation_time
schedule.actuations = 0
return true
end

initialize!(schedule::TimeInterval, model) = initialize!(schedule, model.clock.time)

function next_actuation_time(schedule::TimeInterval)
t₀ = schedule.first_actuation_time
N = schedule.actuations
T = schedule.interval
return t₀ + N * T
return t₀ + (N + 1) * T
end

function (schedule::TimeInterval)(model)
Expand All @@ -67,9 +69,8 @@ function (schedule::TimeInterval)(model)
if t >= t★
if schedule.actuations < typemax(Int)
schedule.actuations += 1
else
schedule.first_actuation_time = t★
schedule.actuations = 1
else # re-initialize the schedule to t★
initialize!(schedule, t★)
end
return true
else
Expand Down
50 changes: 37 additions & 13 deletions test/test_checkpointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ function test_model_equality(test_model, true_model)
return nothing
end

""" Set up a simple simulation to test picking up from a checkpoint. """
function initialization_test_simulation(arch, stop_time, Δt=1, δt=2)
grid = RectilinearGrid(arch, size=(), topology=(Flat, Flat, Flat))
model = NonhydrostaticModel(; grid)
simulation = Simulation(model; Δt, stop_time)

progress_message(sim) = @info string("Iter: ", iteration(sim), ", time: ", prettytime(sim))
simulation.callbacks[:progress] = Callback(progress_message, TimeInterval(δt))

checkpointer = Checkpointer(model,
schedule = TimeInterval(stop_time),
prefix = "initialization_test",
cleanup = false)

simulation.output_writers[:checkpointer] = checkpointer

return simulation
end

"""
Run two coarse rising thermal bubble simulations and make sure
Expand All @@ -34,10 +53,7 @@ Run two coarse rising thermal bubble simulations and make sure
3. run!(test_model, pickup) works as expected
"""
function test_thermal_bubble_checkpointer_output(arch)
#####
##### Create and run "true model"
#####

# Create and run "true model"
Nx, Ny, Nz = 16, 16, 16
Lx, Ly, Lz = 100, 100, 100
Δt = 6
Expand All @@ -58,10 +74,7 @@ function test_thermal_bubble_checkpointer_output(arch)
end

function test_hydrostatic_splash_checkpointer(arch, free_surface)
#####
##### Create and run "true model"
#####

# Create and run "true model"
Nx, Ny, Nz = 16, 16, 4
Lx, Ly, Lz = 1, 1, 1

Expand All @@ -78,7 +91,6 @@ function test_hydrostatic_splash_checkpointer(arch, free_surface)
end

function run_checkpointer_tests(true_model, test_model, Δt)

true_simulation = Simulation(true_model, Δt=Δt, stop_iteration=5)

checkpointer = Checkpointer(true_model, schedule=IterationInterval(5), overwrite_existing=true)
Expand Down Expand Up @@ -162,10 +174,7 @@ end

function run_checkpointer_cleanup_tests(arch)
grid = RectilinearGrid(arch, size=(1, 1, 1), extent=(1, 1, 1))
model = NonhydrostaticModel(grid=grid,
buoyancy=SeawaterBuoyancy(), tracers=(:T, :S)
)

model = NonhydrostaticModel(; grid, buoyancy=SeawaterBuoyancy(), tracers=(:T, :S))
simulation = Simulation(model, Δt=0.2, stop_iteration=10)

simulation.output_writers[:checkpointer] = Checkpointer(model, schedule=IterationInterval(3), cleanup=true)
Expand All @@ -191,5 +200,20 @@ for arch in archs
end

run_checkpointer_cleanup_tests(arch)

# Run a simulation that saves data to a checkpoint
rm("initialization_test_iteration*.jld2", force=true)
simulation = initialization_test_simulation(arch, 4)
run!(simulation)

# Now try again, but picking up from the previous checkpoint
N = iteration(simulation)
checkpoint = "initialization_test_iteration$N.jld2"
simulation = initialization_test_simulation(arch, 8)
run!(simulation, pickup=checkpoint)

progress_cb = simulation.callbacks[:progress]
progress_cb.schedule.first_actuation_time
@test progress_cb.schedule.first_actuation_time == 4
end
end
2 changes: 1 addition & 1 deletion test/test_schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using Oceananigans: initialize!
ti = TimeInterval(2)
initialize!(ti, fake_model_at_iter_0)

@test ti.actuations == 1
@test ti.actuations == 0
@test ti.interval == 2.0
@test ti(fake_model_at_time_2)
@test !(ti(fake_model_at_time_3))
Expand Down
16 changes: 14 additions & 2 deletions test/test_simulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ function wall_time_step_wizard_tests(arch)
Δt = new_time_step(Δt, wizard, model)
@test Δt diff_CFL * Δx^2 / model.closure.ν


grid_stretched = RectilinearGrid(arch,
size = (1, 1, 1),
size = (1, 1, 1),
x = (0, 1),
y = (0, 1),
z = z -> z,
Expand Down Expand Up @@ -207,6 +206,19 @@ end
@info "Testing simulations [$(typeof(arch))]..."
run_basic_simulation_tests(arch)

# Test initialization for simulations started with iteration ≠ 0
grid = RectilinearGrid(arch, size=(), topology=(Flat, Flat, Flat))
model = NonhydrostaticModel(; grid)
simulation = Simulation(model; Δt=1, stop_time=6)

progress_message(sim) = @info string("Iter: ", iteration(sim), ", time: ", prettytime(sim))
progress_cb = Callback(progress_message, TimeInterval(2))
simulation.callbacks[:progress] = progress_cb

model.clock.iteration = 1 # we want to start here for some reason
run!(simulation)
@test progress_cb.schedule.actuations == 3

@testset "NaN Checker [$(typeof(arch))]" begin
@info " Testing NaN Checker [$(typeof(arch))]..."
run_nan_checker_test(arch, erroring=true)
Expand Down

2 comments on commit bf767af

@glwagner
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/112440

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.91.5 -m "<description of version>" bf767af3c40c049ebc9499b7553065a9ef350178
git push origin v0.91.5

Please sign in to comment.