Helper initialiser for Ensemble
msainsburydale committed Aug 28, 2024
1 parent 932ce23 commit 04ba9a7
Showing 2 changed files with 38 additions and 15 deletions.
51 changes: 37 additions & 14 deletions src/Estimators.jl
@@ -882,6 +882,7 @@ coercetotuple(x) = (x...,)

"""
Ensemble(estimators)
Ensemble(architecture::Function, J::Integer)
(ensemble::Ensemble)(Z; aggr = median)
Defines an ensemble based on a collection of `estimators` which,
@@ -890,9 +891,10 @@ when applied to data `Z`, returns the median
The ensemble can be initialised with a collection of trained `estimators` and then
applied immediately to observed data. Alternatively, the ensemble can be
initialised with a collection of untrained `estimators`
(or a function defining the architecture of each estimator, and the number of estimators in the ensemble),
trained with `train()`, and then applied to observed data. In the latter case, where the ensemble is trained directly,
if `savepath` is specified, both the ensemble and the component estimators will be saved.
Note that the training of ensemble components can be done in parallel; however,
currently this needs to be done manually by the user, since `train()` currently
@@ -913,18 +915,19 @@ sampler(K) = randn32(p, K)
simulator(θ, m) = [μ .+ randn32(d, m) for μ ∈ eachcol(θ)]
# Architecture of each ensemble component
function architecture()
    ψ = Chain(Dense(d, 64, relu), Dense(64, 64, relu))
    ϕ = Chain(Dense(64, 64, relu), Dense(64, p))
    deepset = DeepSet(ψ, ϕ)
    PointEstimator(deepset)
end

# Initialise ensemble
J = 5 # ensemble size
ensemble = Ensemble(architecture, J)
ensemble[1] # access component estimators by indexing
length(ensemble) # number of component estimators
# Training
@@ -937,15 +940,34 @@ assessment = assess(ensemble, θ, Z)
rmse(assessment)
# Apply to data
Z = Z[1]
ensemble(Z)
# Testing
J = 5 # ensemble size
ensemble = Ensemble(architecture, J)
train(ensemble, sampler, simulator, m = m, epochs = 5, savepath="testing-path")
ensemble = Ensemble(architecture, J)
ensemble(Z)
loadpath = joinpath(pwd(), "testing-path", "ensemble.bson")
Flux.loadparams!(ensemble, load(loadpath, @__MODULE__)[:weights])
ensemble(Z)
# Testing
J = 5 # ensemble size
ensemble = Ensemble(architecture, J)
trainx(ensemble, sampler, simulator, [30, 50], epochs = 5, savepath="testing-path")
ensemble = Ensemble(architecture, J)
ensemble(Z)
loadpath = joinpath(pwd(), "testing-path_m50", "ensemble.bson")
Flux.loadparams!(ensemble, load(loadpath, @__MODULE__)[:weights])
ensemble(Z)
```
"""
struct Ensemble <: NeuralEstimator
estimators
end
Ensemble(architecture::Function, J::Integer) = Ensemble([architecture() for j in 1:J])
@layer Ensemble

# NB would be great to have parallel version of this as a package extension
function train(ensemble::Ensemble, args...; kwargs...)
@@ -981,7 +1003,8 @@ function (ensemble::Ensemble)(Z; aggr = median)
return θ̂
end

# Overload Base functions
Base.getindex(e::Ensemble, i::Integer) = e.estimators[i]
Base.length(e::Ensemble) = length(e.estimators)
Base.eachindex(e::Ensemble) = eachindex(e.estimators)
Base.show(io::IO, ensemble::Ensemble) = print(io, "\nEnsemble with $(length(ensemble.estimators)) component estimators")
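
The docstring above notes that training of the ensemble components can be done in parallel but must currently be orchestrated manually by the user (see also the `NB` comment preceding `train`). Below is a minimal sketch of one way a user might do this; it is not part of this commit, and it assumes that `architecture`, `sampler`, `simulator`, and `m` are defined as in the docstring example, that Julia was started with multiple threads, and that applying `train` to a single component estimator accepts the same arguments as the ensemble-level call in the docstring and returns the trained estimator.

```julia
# Illustrative sketch (not part of this commit): manual parallel training of
# ensemble components. Assumes `architecture`, `sampler`, `simulator`, and `m`
# are defined as in the docstring example, and that Julia was started with
# multiple threads (e.g., `julia --threads 4`).
using NeuralEstimators

J = 5                                       # ensemble size
components = [architecture() for _ in 1:J]  # untrained component estimators

# Train each component independently, capturing the returned (trained) estimator
Threads.@threads for j in eachindex(components)
    components[j] = train(components[j], sampler, simulator, m = m, epochs = 5)
end

# Wrap the trained components using the existing Ensemble(estimators) constructor
ensemble = Ensemble(components)
```

Whether concurrent calls to `train` are safe in a given setup (e.g., when training on a GPU) is not established by this commit, so the loop above should be treated as a starting point rather than a recommended pattern.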
2 changes: 1 addition & 1 deletion src/NeuralEstimators.jl
@@ -2,7 +2,7 @@ module NeuralEstimators

using Base: @propagate_inbounds, @kwdef
using Base.GC: gc
import Base: join, merge, show, size, summary, getindex, length, eachindex
using BSON: @save, load
using ChainRulesCore: @non_differentiable, @ignore_derivatives
using CSV
