Saving ensemble during training
msainsburydale committed Aug 26, 2024
1 parent 2f72c78 commit 932ce23
Showing 5 changed files with 72 additions and 46 deletions.
32 changes: 28 additions & 4 deletions src/Estimators.jl
@@ -891,12 +891,16 @@ when applied to data `Z`, returns the median
The ensemble can be initialised with a collection of trained `estimators` and then
applied immediately to observed data. Alternatively, the ensemble can be
initialised with a collection of untrained `estimators`,
-trained with `train()`, and then applied to observed data.
+trained with `train()`, and then applied to observed data. In the latter case,
+if `savepath` is specified, both the ensemble and component estimators will be saved.
Note that the training of ensemble components can be done in parallel; however, this
currently needs to be done manually by the user, since `train()` trains the ensemble
components sequentially.
The ensemble components can be accessed by indexing the ensemble directly; the number
of component estimators can be obtained using `length()`.
# Examples
```
using NeuralEstimators, Flux
@@ -920,6 +924,8 @@ end
J = 5 # ensemble size
estimators = [estimator() for j in 1:J]
ensemble = Ensemble(estimators)
ensemble[1] # can access component estimators by indexing
length(ensemble) # number of component estimators
# Training
ensemble = train(ensemble, sampler, simulator, m = m, epochs = 5)
@@ -941,13 +947,26 @@ end
@layer Ensemble
Base.show(io::IO, ensemble::Ensemble) = print(io, "\nEnsemble with $(length(ensemble.estimators)) component estimators")

-# TODO parallel version of this as a package extension
+# NB would be great to have parallel version of this as a package extension
function train(ensemble::Ensemble, args...; kwargs...)
kwargs = (;kwargs...)
savepath = haskey(kwargs, :savepath) ? kwargs.savepath : ""
estimators = map(enumerate(ensemble.estimators)) do (i, estimator)
@info "Training estimator $i"
@info "Training estimator $i of $(length(ensemble))"
if savepath != "" # modify the savepath before passing it onto train
kwargs = merge(kwargs, (savepath = joinpath(savepath, "estimator$i"),))
end
train(estimator, args...; kwargs...)
end
-Ensemble(estimators)
+ensemble = Ensemble(estimators)

if savepath != "" # save ensemble
if !ispath(savepath) mkpath(savepath) end
weights = Flux.params(cpu(ensemble)) # ensure we are on the cpu before serialization
@save joinpath(savepath, "ensemble.bson") weights
end

return ensemble
end

function (ensemble::Ensemble)(Z; aggr = median)
@@ -961,3 +980,8 @@ function (ensemble::Ensemble)(Z; aggr = median)
θ̂ = dropdims(θ̂; dims = 3)
return θ̂
end

# Overload getindex to enable indexing
getindex(e::Ensemble, i::Integer) = e.estimators[i]
# Overload length to obtain number of component estimators
length(e::Ensemble) = length(e.estimators)
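
For context, here is a minimal sketch of how the weights saved by the new `train` method above might be restored and applied after training. This is hypothetical usage, not part of this commit: it assumes the user-defined `estimator()` constructor and observed data `Z` from the docstring example, a hypothetical `savepath = "ensemble_run"` passed to `train()`, and Flux's implicit-parameter `loadparams!` API (consistent with the Flux ≤ 0.14 style noted in `src/train.jl`).

```julia
using NeuralEstimators, Flux, BSON
using Statistics: mean

J = 5                      # ensemble size, as in the docstring example
savepath = "ensemble_run"  # hypothetical path passed to train(ensemble, ...; savepath = savepath)

# Rebuild an ensemble with the same architecture used during training
# (`estimator()` is the user-defined constructor from the docstring example)
ensemble = Ensemble([estimator() for _ in 1:J])

# Restore the saved weights; train() stored them as Flux.params(cpu(ensemble))
weights = BSON.load(joinpath(savepath, "ensemble.bson"))[:weights]
Flux.loadparams!(ensemble, weights)

# Apply the ensemble to observed data Z, aggregating the component estimates
θ̂ = ensemble(Z)                    # default aggregation: median
θ̂_mean = ensemble(Z; aggr = mean)  # any other aggregation function may be supplied
```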
2 changes: 1 addition & 1 deletion src/NeuralEstimators.jl
@@ -2,7 +2,7 @@ module NeuralEstimators

using Base: @propagate_inbounds, @kwdef
using Base.GC: gc
-import Base: join, merge, show, size, summary
+import Base: join, merge, show, size, summary, getindex, length
using BSON: @save, load
using ChainRulesCore: @non_differentiable, @ignore_derivatives
using CSV
58 changes: 29 additions & 29 deletions src/assess.jl
@@ -15,7 +15,7 @@ with columns:
If `estimator` is an `IntervalEstimator`, the column `estimate` will be replaced by the columns `lower` and `upper`, containing the lower and upper bounds of the interval, respectively.
If `estimator` is a `QuantileEstimator`, the `df` will also contain a column `prob` indicating the probability level of the corresponding quantile estimate.
Multiple `Assessment` objects can be combined with `merge()`
(used for combining assessments from multiple point estimators) or `join()`
@@ -169,9 +169,9 @@ end
"""
coverage(assessment::Assessment; ...)
Computes a Monte Carlo approximation of an interval estimator's expected coverage,
as defined in [Hermans et al. (2022, Definition 2.1)](https://arxiv.org/abs/2110.06581),
and the proportion of parameters below and above the lower and upper bounds, respectively.
# Keyword arguments
- `average_over_parameters::Bool = false`: if true, the coverage is averaged over all parameters; otherwise (default), it is computed over each parameter separately.
@@ -202,21 +202,21 @@ function coverage(assessment::Assessment;
return df
end
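
As a brief aside, here is a self-contained toy illustration (hypothetical numbers, not part of this commit) of the quantity that `coverage()` approximates: for each parameter, the proportion of interval estimates `[lower, upper]` that contain the true value.

```julia
using DataFrames
using Statistics: mean

# Toy assessment output: one row per (parameter, simulated data set)
df = DataFrame(
    parameter = repeat(["θ1", "θ2"], inner = 3),
    lower     = [0.1, 0.2, 0.3, 1.0, 1.1, 0.9],
    upper     = [0.9, 0.8, 1.1, 2.0, 1.9, 2.1],
    truth     = [0.5, 0.9, 0.7, 1.5, 1.0, 2.3],
)

# Monte Carlo estimate of coverage, computed separately for each parameter
combine(groupby(df, :parameter),
        [:lower, :upper, :truth] => ((l, u, t) -> mean((l .<= t) .& (t .<= u))) => :coverage)
```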

#TODO bootstrap sampling for bounds on this diagnostic
function empiricalprob(assessment::Assessment;
average_over_parameters::Bool = false,
average_over_sample_sizes::Bool = true)

df = assessment.df

@assert all(["prob", "estimate", "truth"] .∈ Ref(names(df)))

grouping_variables = [:prob]
if "estimator" names(df) push!(grouping_variables, :estimator) end
if "estimator" names(df) push!(grouping_variables, :estimator) end
if !average_over_parameters push!(grouping_variables, :parameter) end
if !average_over_sample_sizes push!(grouping_variables, :m) end
df = groupby(df, grouping_variables)
df = combine(df,
[:estimate, :truth] => ((x, y) -> x .> y) => :below,
ungroup = false)
df = combine(df, :below => mean => :empirical_prob)
@@ -405,7 +405,7 @@ function assess(
if boot == true
verbose && println(" Computing $((probs[2] - probs[1]) * 100)% non-parametric bootstrap intervals...")
# bootstrap estimates
@assert !(typeof(Z) <: Tuple) "bootstrap() is not currently set up for dealing with set-level information; please contact the package maintainer"
bs = bootstrap.(Ref(estimator), Z, use_gpu = use_gpu, B = B)
else # if boot is not a Bool, we will assume it is a bootstrap data set. # TODO probably should add some checks on boot in this case (length should be equal to K, for example)
verbose && println(" Computing $((probs[2] - probs[1]) * 100)% parametric bootstrap intervals...")
@@ -437,11 +437,11 @@ end

function assess(
estimator::Union{QuantileEstimatorContinuous, QuantileEstimatorDiscrete}, θ::P, Z;
parameter_names::Vector{String} = ["θ$i" for i ∈ 1:size(θ, 1)],
estimator_name::Union{Nothing, String} = nothing,
estimator_names::Union{Nothing, String} = nothing, # for backwards compatibility
use_gpu::Bool = true,
probs = Float32.(range(0.01, stop=0.99, length=100))
) where {P <: Union{AbstractMatrix, ParameterConfigurations}}

# Extract the matrix of parameters
@@ -467,46 +467,46 @@ function assess(
# If the estimator is a QuantileEstimatorDiscrete, then we use its probability levels
if typeof(estimator) <: QuantileEstimatorDiscrete
probs = estimator.probs
else
τ = [permutedims(probs) for _ in eachindex(Z)] # convert from vector to vector of matrices
end
n_probs = length(probs)

# Construct input set
i = estimator.i
if isnothing(i)
if typeof(estimator) <: QuantileEstimatorDiscrete
set_info = nothing
else
set_info = τ
end
else
θ₋ᵢ = θ[Not(i), :]
if typeof(estimator) <: QuantileEstimatorDiscrete
set_info = eachcol(θ₋ᵢ)
else
# Combine each θ₋ᵢ with the corresponding vector of
# probability levels, which requires repeating θ₋ᵢ appropriately
set_info = map(1:K) do k
θ₋ᵢₖ = repeat(θ₋ᵢ[:, k:k], inner = (1, n_probs))
vcat(θ₋ᵢₖ, probs')
end
end
θ = θ[i:i, :]
parameter_names = parameter_names[i:i]
end
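# Illustration (added comment, not part of this commit): in the QuantileEstimatorContinuous
# case with, say, p = 3 parameters, i = 1, and n_probs = 2 probability levels, each element
# of set_info constructed above is the 3×2 matrix
#   [θ₂ θ₂
#    θ₃ θ₃
#    τ₁ τ₂]
# i.e., θ₋ᵢ repeated across the columns, with the probability levels appended as the final row.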

# Compute estimates using memory-safe version of estimator((Z, set_info))
runtime = @elapsed θ̂ = estimateinbatches(estimator, Z, set_info, use_gpu = use_gpu)

# Convert to DataFrame and add information
p = size(θ, 1)
runtime = DataFrame(runtime = runtime)
df = DataFrame(
parameter = repeat(repeat(parameter_names, inner = n_probs), K),
truth = repeat(vec(θ), inner = n_probs),
prob = repeat(repeat(probs, outer = p), K),
estimate = vec(θ̂),
m = repeat(m, inner = n_probs*p),
k = repeat(1:K, inner = n_probs*p),
j = 1 # just for consistency with other methods
@@ -549,7 +549,7 @@ function assess(

# run the estimators
assessments = map(1:E) do i
-verbose && println(" Running estimator $(estimator_names[i])...")
+verbose && println(" Running $(estimator_names[i])...")
if use_ξ[i]
assess(estimators[i], θ, Z, ξ = ξ; use_gpu = use_gpu[i], estimator_name = estimator_names[i], kwargs...)
else
24 changes: 12 additions & 12 deletions src/train.jl
@@ -46,7 +46,7 @@ function sampler(K)
return θ
end
function simulator(θ_matrix, m)
[θ[1] .+ θ[2] * randn(1, m) for θ ∈ eachcol(θ_matrix)]
end
@@ -215,7 +215,7 @@ function _train(θ̂, sampler, simulator;
savebool && _saveinfo(loss_per_epoch, train_time, savepath, verbose = verbose)
savebool && _savebestweights(savepath)

# TODO if the user has relied on using train() as a mutating function, the optimal estimator will not be returned. Can I set θ̂ = θ̂_best to fix this? This also ties in with the other TODO below (above trainx()), regarding which device the estimator is on at the end of training.

return θ̂_best
end
@@ -291,7 +291,7 @@ function _train(θ̂, θ_train::P, θ_val::P, simulator;
# Update θ̂ and compute the training risk
epoch_time = 0.0
train_risk = []

for θ ∈ _ParameterLoader(θ_train, batchsize = batchsize)
sim_time += @elapsed set = _constructset(θ̂, simulator, θ, m, batchsize)
epoch_time += @elapsed rsk = _risk(θ̂, loss, set, device, optimiser)
@@ -487,7 +487,7 @@ function _constructset(θ̂, simulator::Function, θ::P, m, batchsize) where {P
_constructset(θ̂, Z, θ, batchsize)
end
function _constructset(θ̂, Z, θ::P, batchsize) where {P <: Union{AbstractMatrix, ParameterConfigurations}}
Z = ZtoFloat32(Z)
θ = θtoFloat32(_extractθ(θ))
_DataLoader((Z, θ), batchsize)
end
@@ -543,7 +543,7 @@ function _constructset(θ̂::QuantileEstimatorContinuous, Zτ, θ::P, batchsize)
θ = θtoFloat32(_extractθ(θ))
Z, τ = Zτ
Z = ZtoFloat32(Z)
τ = ZtoFloat32.(τ)

i = θ̂.i
if isnothing(i)
@@ -553,13 +553,13 @@ function _constructset(θ̂::QuantileEstimatorContinuous, Zτ, θ::P, batchsize)
@assert size(θ, 1) >= i "The number of parameters in the model (size(θ, 1) = $(size(θ, 1))) must be at least as large as the value of i stored in the estimator (θ̂.i = $(θ̂.i))"
θᵢ = θ[i:i, :]
θ₋ᵢ = θ[Not(i), :]
# Combine each θ₋ᵢ with the corresponding vector of
# probability levels, which requires repeating θ₋ᵢ appropriately
θ₋ᵢτ = map(eachindex(τ)) do k
τₖ = τ[k]
θ₋ᵢₖ = repeat(θ₋ᵢ[:, k:k], inner = (1, length(τₖ)))
vcat(θ₋ᵢₖ, τₖ')
end
input = (Z, θ₋ᵢτ) # "Tupleise" the input
output = θᵢ
end
@@ -620,19 +620,19 @@ function _risk(θ̂::QuantileEstimatorContinuous, loss, set::DataLoader, device,
input1 = Z
input2 = permutedims.(τ)
input = (input1, input2)
τ = reduce(hcat, τ) # reduce from vector of vectors to matrix
else
Z, θ₋ᵢτ = input
τ = [x[end, :] for x ∈ θ₋ᵢτ] # extract probability levels
τ = reduce(hcat, τ) # reduce from vector of vectors to matrix
end

# repeat τ and θ to facilitate broadcasting and indexing
# note that repeat() cannot be differentiated by Zygote
p = size(output, 1)
@ignore_derivatives τ = repeat(τ, inner = (p, 1))
@ignore_derivatives output = repeat(output, inner = (size(τ, 1) ÷ p, 1))

if !isnothing(optimiser)

# "Implicit" style used by Flux <= 0.14.
@@ -655,7 +655,7 @@ end

# ---- Wrapper function for training multiple estimators over a range of sample sizes ----

#TODO (not sure what we want to do about the following behaviour; need to think about it): If called as est = trainx(est) then est will be on the GPU; if called as trainx(est) then est will not be on the GPU. Note that the same thing occurs for train(). That is, when the function is treated as mutating, the estimator will be on the same device that was used during training; otherwise, it will be on whichever device it was on when input to the function. Need consistency to improve the user experience.
"""
trainx(θ̂, sampler::Function, simulator::Function, m::Vector{Integer}; ...)
trainx(θ̂, θ_train, θ_val, simulator::Function, m::Vector{Integer}; ...)
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -872,6 +872,8 @@ end
J = 2 # ensemble size
estimators = [estimator() for j in 1:J]
ensemble = Ensemble(estimators)
ensemble[1] # can be indexed
@test length(ensemble) == J # number of component estimators

# Training
ensemble = train(ensemble, sampler, simulator, m = m, epochs = 2, verbose = false)
