Added ResidualBlock; updated train() to save the model state rather than the parameters, as per latest Flux guidelines
msainsburydale committed Sep 1, 2024
1 parent 04ba9a7 commit 98b216b
Showing 8 changed files with 176 additions and 180 deletions.
9 changes: 5 additions & 4 deletions docs/src/API/architectures.md
@@ -2,7 +2,7 @@

## Modules

The following high-level modules are often used when constructing a neural-network architecture. In particular, the [`DeepSet`](@ref) is the building block for most classes of [Estimators](@ref) in the package.
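For instance, a minimal sketch of constructing a `DeepSet` from standard Flux layers (the input dimension, layer widths, and number of parameters below are illustrative assumptions, not values taken from the package documentation):

```
using NeuralEstimators, Flux

d = 5   # dimension of each data replicate (assumed for illustration)
w = 32  # width of the hidden layers (assumed)
p = 2   # number of parameters in the statistical model (assumed)

ψ = Chain(Dense(d, w, relu), Dense(w, w, relu))  # inner (summary) network
ϕ = Chain(Dense(w, w, relu), Dense(w, p))        # outer (inference) network
deepset = DeepSet(ψ, ϕ)
```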

```@docs
DeepSet
@@ -32,13 +32,15 @@ NeighbourhoodVariogram

## Layers

In addition to the [built-in layers](https://fluxml.ai/Flux.jl/stable/reference/models/layers/) provided by Flux, the following layers may be used when constructing a neural-network architecture.

```@docs
DensePositive
PowerDifference
ResidualBlock
SpatialGraphConv
```
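As one illustration, a minimal sketch of the `PowerDifference` layer, which computes the elementwise quantity |a·x − (1−a)·y|^b (the values of `a` and `b` below are assumptions chosen for illustration):

```
using NeuralEstimators

f = PowerDifference(0.5f0, [2.0f0])  # a = 0.5, b = 2
x = rand(Float32, 3)
y = rand(Float32, 3)
f(x, y)    # abs.(0.5f0 .* x .- 0.5f0 .* y) .^ 2, elementwise
f((x, y))  # equivalent call with a tuple argument
```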

@@ -50,7 +52,7 @@ Order = [:type, :function]
Pages = ["activationfunctions.md"]
```

In addition to the [standard activation functions](https://fluxml.ai/Flux.jl/stable/models/activation/) provided by Flux, the following structs can be used at the end of an architecture to act as output activation functions that ensure valid estimates for certain models. **NB:** Although we refer to the following objects as "activation functions", they should be treated as layers that are included in the final stage of a Flux `Chain()`.

```@docs
Compress
@@ -59,4 +61,3 @@ CorrelationMatrix
CovarianceMatrix
```
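For example, a minimal sketch of using `Compress` as the final layer of an architecture so that estimates respect prescribed parameter bounds (the bounds and layer sizes below are assumptions for illustration):

```
using NeuralEstimators, Flux

a = [0.0f0, 0.0f0]  # assumed lower bounds for two parameters
b = [1.0f0, 5.0f0]  # assumed upper bounds
ϕ = Chain(Dense(32, 64, relu), Dense(64, 2), Compress(a, b))
θ̂ = ϕ(rand(Float32, 32))  # estimates lie within (a, b)
```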

68 changes: 30 additions & 38 deletions docs/src/workflow/advancedusage.md

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions src/Architectures.jl
@@ -760,3 +760,57 @@ PowerDifference(a::Number, b::AbstractArray) = PowerDifference([a], b)
PowerDifference(a::AbstractArray, b::Number) = PowerDifference(a, [b])
(f::PowerDifference)(x, y) = (abs.(f.a .* x - (1 .- f.a) .* y)).^f.b
(f::PowerDifference)(tup::Tuple) = f(tup[1], tup[2])


#TODO add further details
"""
ResidualBlock(filter, in => out; stride = 1)
Basic residual block (see [here](https://en.wikipedia.org/wiki/Residual_neural_network#Basic_block)),
consisting of two sequential convolutional layers and a skip (shortcut) connection
that connects the input of the block directly to the output,
facilitating the training of deep networks.
# Examples
```
using NeuralEstimators
z = rand(16, 16, 1, 1)
b = ResidualBlock((3, 3), 1 => 32)
b(z)
```
"""
struct ResidualBlock{B}
block::B
end
Flux.@functor ResidualBlock
(b::ResidualBlock)(x) = relu.(b.block(x))
function ResidualBlock(filter, channels; stride = 1)

layer = Chain(
Conv(filter, channels; stride = stride, pad=1, bias=false),
BatchNorm(channels[2], relu),
Conv(filter, channels[2]=>channels[2]; pad=1, bias=false),
BatchNorm(channels[2])
)

if stride == 1 && channels[1] == channels[2]
# dimensions match, can add input directly to output
connection = +
else
#TODO options for different dimension matching (padding vs. projection)
# Projection connection using 1x1 convolution
connection = Shortcut(
Chain(
Conv((1, 1), channels; stride = stride, bias=false),
BatchNorm(channels[2])
)
)
end

ResidualBlock(SkipConnection(layer, connection))
end
struct Shortcut{S}
s::S
end
Flux.@functor Shortcut
(s::Shortcut)(mx, x) = mx + s.s(x)
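As a rough illustration of the projection branch above (the array shapes and channel counts are assumptions chosen for the example): when the input and output channel dimensions differ, the 1×1 convolutional shortcut matches dimensions before the skip addition.

```
using NeuralEstimators, Flux

z = rand(Float32, 16, 16, 4, 1)    # 16×16 image, 4 channels, batch size 1
b = ResidualBlock((3, 3), 4 => 8)  # in ≠ out channels, so the 1×1 projection shortcut is used
size(b(z))                         # (16, 16, 8, 1)
```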
55 changes: 16 additions & 39 deletions src/Estimators.jl
@@ -891,17 +891,15 @@ when applied to data `Z`, returns the median
The ensemble can be initialised with a collection of trained `estimators` and then
applied immediately to observed data. Alternatively, the ensemble can be
initialised with a collection of untrained `estimators`
(or a function defining the architecture of each estimator, and the number of estimators in the ensemble),
trained with `train()`, and then applied to observed data. In the latter case, where the ensemble is trained directly,
if `savepath` is specified both the ensemble and component estimators will be saved.
Note that the training of ensemble components can be done in parallel; however,
currently this needs to be done manually by the user, since `train()` currently
trains the ensemble components sequentially.
Note that `train()` currently acts sequentially on the component estimators.
The ensemble components can be accessed by indexing the ensemble directly; the number
of component estimators can be obtained using `length()`.
# Examples
```
@@ -923,12 +921,12 @@ function architecture()
end
# Ensemble size
J = 5
J = 3
# Initialise ensemble
ensemble = Ensemble(architecture, J)
ensemble[1] # access component estimators by indexing
length(ensemble) # number of component estimators
# Training
ensemble = train(ensemble, sampler, simulator, m = m, epochs = 5)
@@ -941,26 +939,6 @@ rmse(assessment)
# Apply to data
ensemble(Z)
# Testing
J = 5 # ensemble size
ensemble = Ensemble(architecture, J)
train(ensemble, sampler, simulator, m = m, epochs = 5, savepath="testing-path")
ensemble = Ensemble(architecture, J)
ensemble(Z)
loadpath = joinpath(pwd(), "testing-path", "ensemble.bson")
Flux.loadparams!(ensemble, load(loadpath, @__MODULE__)[:weights])
ensemble(Z)
# Testing
J = 5 # ensemble size
ensemble = Ensemble(architecture, J)
trainx(ensemble, sampler, simulator, [30, 50], epochs = 5, savepath="testing-path")
ensemble = Ensemble(architecture, J)
ensemble(Z)
loadpath = joinpath(pwd(), "testing-path_m50", "ensemble.bson")
Flux.loadparams!(ensemble, load(loadpath, @__MODULE__)[:weights])
ensemble(Z)
```
"""
struct Ensemble <: NeuralEstimator
@@ -969,7 +947,6 @@ end
Ensemble(architecture::Function, J::Integer) = Ensemble([architecture() for j in 1:J])
@layer Ensemble

# NB would be great to have parallel version of this as a package extension
function train(ensemble::Ensemble, args...; kwargs...)
kwargs = (;kwargs...)
savepath = haskey(kwargs, :savepath) ? kwargs.savepath : ""
@@ -981,14 +958,14 @@ function train(ensemble::Ensemble, args...; kwargs...)
train(estimator, args...; kwargs...)
end
ensemble = Ensemble(estimators)

if savepath != "" # save ensemble
if !ispath(savepath) mkpath(savepath) end
weights = Flux.params(cpu(ensemble)) # ensure we are on the cpu before serialization
@save joinpath(savepath, "ensemble.bson") weights
end

return ensemble
end
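For reference, a minimal sketch of restoring a saved ensemble, mirroring the testing snippet in the docstring above and assuming the BSON file stores the `Flux.params` weights under the key `:weights`, as in the saving code shown here (the commit message indicates that newer versions save the full model state instead):

```
using NeuralEstimators, Flux
using BSON: load

# `architecture` and `J` as defined in the Ensemble docstring example above
ensemble = Ensemble(architecture, J)
loadpath = joinpath(pwd(), "testing-path", "ensemble.bson")
Flux.loadparams!(ensemble, load(loadpath, @__MODULE__)[:weights])
```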

function (ensemble::Ensemble)(Z; aggr = median)
@@ -1003,8 +980,8 @@ function (ensemble::Ensemble)(Z; aggr = median)
return θ̂
end

# Overload Base functions
Base.getindex(e::Ensemble, i::Integer) = e.estimators[i]
Base.length(e::Ensemble) = length(e.estimators)
Base.eachindex(e::Ensemble) = eachindex(e.estimators)
Base.show(io::IO, ensemble::Ensemble) = print(io, "\nEnsemble with $(length(ensemble.estimators)) component estimators")
12 changes: 7 additions & 5 deletions src/NeuralEstimators.jl
@@ -10,7 +10,7 @@ using DataFrames
using Distances
using Distributions: Poisson, Bernoulli, product_distribution
using Flux
using Flux: ofeltype, params, DataLoader, update!, glorot_uniform, onehotbatch, _match_eltype # @layer
using Flux: ofeltype, DataLoader, update!, glorot_uniform, onehotbatch, _match_eltype # @layer
using Flux: @functor; var"@layer" = var"@functor" # NB did this because even semi-recent versions of Flux do not include @layer
using Folds
using Graphs
@@ -21,7 +21,7 @@ using InvertedIndices
using LinearAlgebra
using NamedArrays
using NearestNeighbors: KDTree, knn
using Optim # needed to obtain the MAP with neural ratio
using Optim # needed to obtain the ML/MAP with neural ratio (at least via gradient descent... NB could make it a package extension)
using Random: randexp, shuffle
using RecursiveArrayTools: VectorOfArray, convert
using SparseArrays
@@ -37,7 +37,7 @@ include("loss.jl")
export ParameterConfigurations, subsetparameters
include("Parameters.jl")

export DeepSet, summarystatistics, Compress, CovarianceMatrix, CorrelationMatrix
export DeepSet, summarystatistics, Compress, CovarianceMatrix, CorrelationMatrix, ResidualBlock
export vectotril, vectotriu
include("Architectures.jl")

@@ -63,7 +63,7 @@ include("train.jl")
export assess, Assessment, merge, join, risk, bias, rmse, coverage, intervalscore, empiricalprob
include("assess.jl")

export stackarrays, expandgrid, loadbestweights, loadweights, numberreplicates, nparams, samplesize, drop, containertype, estimateinbatches, rowwisenorm
export stackarrays, expandgrid, numberreplicates, nparams, samplesize, drop, containertype, estimateinbatches, rowwisenorm
include("utility.jl")

export samplesize, samplecorrelation, samplecovariance, NeighbourhoodVariogram
@@ -72,8 +72,10 @@ include("summarystatistics.jl")
export EM, removedata, encodedata
include("missingdata.jl")

# Backwards compatibility:
# Backwards compatibility and deprecations:
simulategaussianprocess = simulategaussian; export simulategaussianprocess
export loadbestweights, loadweights
include("deprecated.jl")

end
