Now import scatter() and gather() from NNlib; convenience functions for EM()
msainsburydale authored and msainsburydale committed Dec 9, 2024
1 parent f788541 commit 285cc74
Showing 9 changed files with 98 additions and 37 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "NeuralEstimators"
uuid = "38f6df31-6b4a-4144-b2af-7ace2da57606"
authors = ["Matthew Sainsbury-Dale <msainsburydale@gmail.com> and contributors"]
version = "0.2.0"
version = "0.2.1"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
@@ -14,6 +14,7 @@ GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2 changes: 2 additions & 0 deletions docs/src/API/utility.md
@@ -39,6 +39,8 @@ expandgrid
IndicatorWeights
KernelWeights
initialise_estimator
loadbestweights
3 changes: 3 additions & 0 deletions src/Estimators.jl
@@ -923,6 +923,7 @@ end
# Initialise ensemble with three components
ensemble = Ensemble(architecture, 3)
ensemble[1] # access component estimators by indexing
ensemble[1:2] # indexing with an iterable collection returns the corresponding ensemble
length(ensemble) # number of component estimators
# Training
@@ -983,6 +984,8 @@ end

# Overload Base functions
Base.getindex(e::Ensemble, i::Integer) = e.estimators[i]
Base.getindex(e::Ensemble, indices::AbstractVector{<:Integer}) = Ensemble(e.estimators[indices])
Base.getindex(e::Ensemble, indices::UnitRange{<:Integer}) = Ensemble(e.estimators[indices])
Base.length(e::Ensemble) = length(e.estimators)
Base.eachindex(e::Ensemble) = eachindex(e.estimators)
Base.show(io::IO, ensemble::Ensemble) = print(io, "\nEnsemble with $(length(ensemble.estimators)) component estimators")
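A quick illustrative sketch of the new indexing behaviour (assuming `ensemble` was constructed as in the docstring example above; not part of the committed code):

sub = ensemble[1:2]   # returns an Ensemble wrapping the first two component estimators
sub isa Ensemble      # true
length(sub)           # 2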
34 changes: 16 additions & 18 deletions src/Graphs.jl
@@ -283,7 +283,7 @@ end
@doc raw"""
SpatialGraphConv(in => out, g=relu; args...)
Implements a spatial graph convolution for isotropic processes,
Implements a spatial graph convolution for isotropic spatial processes [(Sainsbury-Dale et al., 2025)](https://arxiv.org/abs/2310.02600),
```math
\boldsymbol{h}^{(l)}_{j} =
@@ -580,18 +580,13 @@ on either the `k`-nearest neighbours of each location; all nodes within a disc o
or, if both `r` and `k` are provided, a subset of `k` neighbours within a disc
of fixed radius `r`.
Several subsampling strategies are possible when choosing a subset of `k` neighbours within
a disc of fixed radius `r`. If `random=true` (default), the neighbours are randomly selected from
within the disc (note that this also approximately preserves the distribution of
distances within the neighbourhood set). If `random=false`, a deterministic algorithm is used
that aims to preserve the distribution of distances within the neighbourhood set, by choosing
those nodes with distances to the central node corresponding to the
$\{0, \frac{1}{k}, \frac{2}{k}, \dots, \frac{k-1}{k}, 1\}$ quantiles of the empirical
distribution function of distances within the disc.
(This algorithm in fact yields $k+1$ neighbours, since both the closest and furthest nodes are always included.)
Otherwise,
If `S` is a square matrix, it is treated as a distance matrix; otherwise, it
should be an $n$ x $d$ matrix, where $n$ is the number of spatial locations
and $d$ is the spatial dimension (typically $d$ = 2). In the latter case,
the distance metric is taken to be the Euclidean distance. Note that use of a
maxmin ordering currently requires a matrix of spatial locations (not a distance matrix).
If `maxmin=false` (default) the `k`-nearest neighbours are chosen based on all points in
When using the `k`-nearest neighbours, if `maxmin=false` (default), the neighbours are chosen based on all points in
the graph. If `maxmin=true`, a so-called maxmin ordering is applied,
whereby an initial point is selected, and each subsequent point is selected to
maximise the minimum distance to those points that have already been selected.
@@ -600,11 +595,14 @@ amongst the points that have already appeared in the ordering. If `combined=true
neighbours are defined to be the union of the `k`-nearest neighbours and the
`k`-nearest neighbours subject to a maxmin ordering.
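For intuition, a minimal sketch of a maxmin ordering over an $n$ x $d$ matrix of locations `S` (an assumed stand-alone helper, not the package's implementation):

function maxmin_ordering(S; start = 1)
    n = size(S, 1)
    ordering = [start]
    mindist = [sqrt(sum(abs2, S[i, :] .- S[start, :])) for i in 1:n]
    mindist[start] = -Inf
    while length(ordering) < n
        next = argmax(mindist)              # furthest from the points selected so far
        push!(ordering, next)
        for i in 1:n                        # update minimum distances to the selected set
            mindist[i] = min(mindist[i], sqrt(sum(abs2, S[i, :] .- S[next, :])))
        end
        mindist[next] = -Inf                # never select the same point twice
    end
    ordering
end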
If `S` is a square matrix, it is treated as a distance matrix; otherwise, it
should be an $n$ x $d$ matrix, where $n$ is the number of spatial locations
and $d$ is the spatial dimension (typically $d$ = 2). In the latter case,
the distance metric is taken to be the Euclidean distance. Note that use of a
maxmin ordering currently requires a matrix of spatial locations (not a distance matrix).
Two subsampling strategies are implemented when choosing a subset of `k` neighbours within
a disc of fixed radius `r`. If `random=true` (default), the neighbours are randomly selected from
within the disc. If `random=false`, a deterministic algorithm is used
that aims to preserve the distribution of distances within the neighbourhood set, by choosing
those nodes with distances to the central node corresponding to the
$\{0, \frac{1}{k}, \frac{2}{k}, \dots, \frac{k-1}{k}, 1\}$ quantiles of the empirical
distribution function of distances within the disc (this in fact yields up to $k+1$ neighbours,
since both the closest and furthest nodes are always included).
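A hedged sketch of the deterministic quantile-based selection just described (an assumed helper, not the package's internal `subsetneighbours`), given a vector `d` of distances from the central node:

using Statistics: quantile
function quantile_neighbours(d, r, k)
    candidates = findall(0 .< d .<= r)               # nodes within the disc, excluding the central node
    isempty(candidates) && return Int[]
    targets = quantile(d[candidates], (0:k) ./ k)    # k+1 target distances, including the min and max
    unique(candidates[argmin(abs.(d[candidates] .- t))] for t in targets)
end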
By convention with the functionality in `GraphNeuralNetworks.jl` which is based on directed graphs,
the neighbours of location `i` are stored in the column `A[:, i]` where `A` is the
@@ -658,7 +656,7 @@ function adjacencymatrix(M::Mat, r::F, k::Integer; random::Bool = true) where Ma
@assert k > 0
@assert r > 0

if random == false
if !random
A = adjacencymatrix(M, r)
A = subsetneighbours(A, k)
A = dropzeros!(A) # remove self loops
11 changes: 5 additions & 6 deletions src/NeuralEstimators.jl
@@ -13,12 +13,13 @@ using Flux: @functor; var"@layer" = var"@functor" # NB did this because even sem
using Folds
using Graphs
using GraphNeuralNetworks
using GraphNeuralNetworks: check_num_nodes, scatter, gather
using GraphNeuralNetworks: check_num_nodes
import GraphNeuralNetworks: GraphConv
using InvertedIndices
using LinearAlgebra
using NamedArrays
using NearestNeighbors: KDTree, knn
using NNlib: scatter, gather
using Random: randexp, shuffle
using RecursiveArrayTools: VectorOfArray, convert
using SparseArrays
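For reference, a small stand-alone sketch of the NNlib primitives now imported directly (purely illustrative; unrelated to how the package uses them internally):

using NNlib: scatter, gather
src = [10.0, 20.0, 30.0]
idx = [1, 1, 2]
scatter(+, src, idx)             # [30.0, 30.0]: sums the entries of src that share an index
gather([5.0, 6.0, 7.0], idx)     # [5.0, 5.0, 6.0]: selects entries at the given indices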
@@ -75,9 +76,10 @@ include("deprecated.jl")

end

#TODO makes sense to have m=1 as default in train() (it is often the case that m=1, especially in spatial statistics)

# ---- longer term/lower priority:
# - Amortised posterior approximation (https://github.com/slimgroup/InvertibleNetworks.jl). Also allow for conditioning.
# - Amortised likelihood approximation (https://github.com/slimgroup/InvertibleNetworks.jl)
# - Extension: Incorporate the following package to greatly expand bootstrap functionality: https://github.com/juliangehring/Bootstrap.jl. Note also the "straps()" method that allows one to obtain the bootstrap distribution. I think what I can do is define a method of interval(bs::BootstrapSample). Maybe one difficulty will be how to re-sample... Not sure how the bootstrap method will know to sample from the independent replicates dimension (the last dimension) of each array.
# - Add NeuralEstimators.jl to the list of packages that use Documenter: see https://documenter.juliadocs.org/stable/man/examples/
# - Add NeuralEstimators.jl to https://github.com/smsharma/awesome-neural-sbi#code-packages-and-benchmarks
# - Ensemble: make it “play well” throughout the package. For example, assess() with other kinds of neural estimators (e.g., quantile estimators), and ml/mapestimate() with RatioEstimators.
@@ -91,11 +93,8 @@ end
# - Sequence (e.g., time-series) input: https://jldc.ch/post/seq2one-flux/
# - Precompile NeuralEstimators.jl to reduce latency: See https://julialang.org/blog/2021/01/precompile_tutorial/. Seems easy, just need to add precompile(f, (arg_types…)) to whichever methods we want to precompile
# - Examples: data plots within each example. Can show a histogram for univariate data; a scatterplot for bivariate data; a heatmap for gridded data; and scatterplot for irregular spatial data.
# - Extension: Incorporate the following package to greatly expand bootstrap functionality: https://github.com/juliangehring/Bootstrap.jl. Note also the "straps()" method that allows one to obtain the bootstrap distribution. I think what I can do is define a method of interval(bs::BootstrapSample). Maybe one difficulty will be how to re-sample... Not sure how the bootstrap method will know to sample from the independent replicates dimension (the last dimension) of each array.
# - GPU on MacOS with Metal.jl (already have extension written, need to wait until Metal.jl is further developed; in particular, need convolution layers to be implemented)
# - Explicit learning of summary statistics
# - Amortised posterior approximation (https://github.com/slimgroup/InvertibleNetworks.jl)
# - Amortised likelihood approximation (https://github.com/slimgroup/InvertibleNetworks.jl)
# - Functionality for storing and plotting the training-validation risk in the NeuralEstimator. This will involve changing _train() to return both the estimator and the risk, and then defining train(::NeuralEstimator) to update the slot containing the risk. We will also need _train() to take the argument "loss_vs_epoch", so that we can "continue training"
# - Separate GNN functionality (tried this with package extensions but not possible currently because we need to define custom structs)
# - SpatialPyramidPool for CNNs
13 changes: 8 additions & 5 deletions src/missingdata.jl
@@ -1,5 +1,3 @@
#TODO think it's better if this is kept simple, and designed only for neural EM...

@doc raw"""
EM(simulateconditional::Function, MAP::Union{Function, NeuralEstimator}, θ₀ = nothing)
Implements the (Bayesian) Monte Carlo expectation-maximisation (EM) algorithm,
@@ -69,8 +67,10 @@ struct EM{F,T,S}
end
EM(simulateconditional, MAP) = EM(simulateconditional, MAP, nothing)
EM(em::EM, θ₀) = EM(em.simulateconditional, em.MAP, θ₀)
EM(simulateconditional, MAP, θ₀::Number) = EM(simulateconditional, MAP, [θ₀])
#TODO think it's better if this is kept simple, and designed only for neural EM...

function (em::EM)(Z::A, θ₀ = nothing; args...) where {A <: AbstractArray{T, N}} where {T, N}
function (em::EM)(Z::A, θ₀ = nothing; kwargs...) where {A <: AbstractArray{T, N}} where {T, N}
@warn "Data has been passed to the EM algorithm that contains no missing elements... the MAP estimator will be applied directly to the data"
em.MAP(Z)
end
@@ -81,7 +81,6 @@ function (em::EM)(
niterations::Integer = 50,
nsims::Integer = 1,
nconsecutive::Integer = 3,
#nensemble::Integer = 5, # TODO implement and document
ϵ = 0.01,
ξ = nothing,
use_ξ_in_simulateconditional::Bool = false,
@@ -147,13 +146,17 @@ function (em::EM)(
return_iterates ? θ_all : θₗ
end

function (em::EM)(Z::V, θ₀::Union{Vector, Matrix, Nothing} = nothing; args...) where {V <: AbstractVector{A}} where {A <: AbstractArray{Union{Missing, T}, N}} where {T, N}
function (em::EM)(Z::V, θ₀::Union{Number, Vector, Matrix, Nothing} = nothing; args...) where {V <: AbstractVector{A}} where {A <: AbstractArray{Union{Missing, T}, N}} where {T, N}

if isnothing(θ₀)
@assert !isnothing(em.θ₀) "Please provide initial estimates `θ₀` in the function call or in the `EM` object."
θ₀ = em.θ₀
end

if isa(θ₀, Number)
θ₀ = [θ₀]
end

if isa(θ₀, Vector)
θ₀ = repeat(θ₀, 1, length(Z))
end
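A hedged usage sketch of the new scalar-θ₀ convenience (the names `mysimulateconditional` and `myMAP` are assumed placeholders, not package functions):

em = EM(mysimulateconditional, myMAP, 0.5)   # a scalar initial estimate is now wrapped as [0.5]
# θ̂ = em(Z1)                                 # or supply the scalar at call time: em(Z1, 0.5)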
65 changes: 60 additions & 5 deletions src/simulate.jl
@@ -446,12 +446,47 @@ function simulatepotts(grid::AbstractMatrix{Int}, β; nsims::Int = 1, num_iterat
#TODO sum(chequerboard1) == 0 (easy workaround in this case, just iterate over chequerboard2)
#TODO sum(chequerboard2) == 0 (easy workaround in this case, just iterate over chequerboard1)

# Define neighbours offsets (assuming 4-neighbour connectivity)
# Define neighbours offsets based on 4-neighbour connectivity
neighbour_offsets = [(0, 1), (1, 0), (0, -1), (-1, 0)]

# Gibbs sampling iterations
for _ in 1:num_iterations
for chequerboard in (chequerboard1, chequerboard2)

# ---- Vectorised version (this implementation doesn't seem to save any time) ----

# # Compute neighbour counts for each state
# # NB some wasted computations because we are computing the neighbour counts for all grid cells, not just the current chequerboard
# padded_grid = padarray(grid, 1, minimum(states) - 1) # pad grid with label that is outside support
# neighbor_counts = zeros(nrows, ncols, num_states)
# for (di, dj) in neighbour_offsets
# row_indices = 2+di:nrows+di+1
# col_indices = 2+dj:ncols+dj+1
# shifted_grid = padded_grid[row_indices, col_indices]
# for k in 1:num_states
# neighbor_counts[:, :, k] .+= (shifted_grid .== states[k])
# end
# end

# # Calculate conditional probabilities
# probs = exp.(β .* neighbor_counts)
# probs ./= sum(probs, dims=3)


# # Sample new states for chequerboard cells
# cumulative_probs = cumsum(probs, dims=3)
# rand_matrix = rand(nrows, ncols)

# # Indices of chequerboard cells
# chequerboard_indices = findall(chequerboard)

# # Sample new states and update grid
# sampled_indices = map(i -> findfirst(cumulative_probs[Tuple(i)..., :] .> rand_matrix[Tuple(i)...]), chequerboard_indices)
# grid[chequerboard] .= states[sampled_indices]


# ---- Simple version ----


for ci in findall(chequerboard)

# Get cartesian coordinates of current pixel
@@ -482,10 +517,29 @@
return grid
end
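For clarity, a minimal sketch of the single-site conditional update underlying the simple version above (an assumed helper, not the committed inner loop); the probability of state s is proportional to exp(β × number of neighbours currently in state s):

function gibbs_update(grid, i, j, states, β, offsets)
    nrows, ncols = size(grid)
    counts = zeros(length(states))
    for (di, dj) in offsets                        # 4-neighbour connectivity
        ni, nj = i + di, j + dj
        if 1 <= ni <= nrows && 1 <= nj <= ncols
            counts .+= states .== grid[ni, nj]     # count neighbours in each state
        end
    end
    probs = exp.(β .* counts)
    probs ./= sum(probs)                           # conditional probabilities
    states[something(findfirst(cumsum(probs) .>= rand()), length(states))]  # inverse-CDF draw
end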

# function padarray(grid, pad_size, pad_value)
# padded_grid = fill(pad_value, size(grid)[1] + 2*pad_size, size(grid)[2] + 2*pad_size)
# padded_grid[pad_size+1:end-pad_size, pad_size+1:end-pad_size] .= grid
# return padded_grid
# end

function simulatepotts(nrows::Int, ncols::Int, num_states::Int, β; kwargs...)
grid = rand(1:num_states, nrows, ncols)
simulatepotts(grid, β; kwargs...)
@assert length(β) == 1
β = β[1]
β_crit = log(1 + sqrt(num_states))
if β < β_crit
# Random initialization for high temperature
grid = rand(1:num_states, nrows, ncols)
else
# Clustered initialization for low temperature
cluster_size = max(1, min(nrows, ncols) ÷ 4)
clustered_rows = ceil(Int, nrows / cluster_size)
clustered_cols = ceil(Int, ncols / cluster_size)
base_grid = rand(1:num_states, clustered_rows, clustered_cols)
grid = repeat(base_grid, inner=(cluster_size, cluster_size))
grid = grid[1:nrows, 1:ncols] # Trim to exact dimensions
end
simulatepotts(grid, β; kwargs...)
end
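As a quick sanity check of the threshold used above (simple arithmetic, not committed code), the critical inverse temperature of the q-state Potts model is log(1 + sqrt(q)):

q = 2                       # Ising model
β_crit = log(1 + sqrt(q))   # ≈ 0.8814
# β = 0.7, as used in the updated test below, is sub-critical, so the random-initialisation branch is taken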

function simulatepotts(grid::AbstractMatrix{Union{Missing, I}}, β; kwargs...) where I <: Integer
@@ -499,7 +553,8 @@ function simulatepotts(grid::AbstractMatrix{Union{Missing, I}}, β; kwargs...) w
# Compute the mask
mask = ismissing.(grid)

# Replace missing entries with random states # TODO might converge faster with a better initialisation
# Replace missing entries with random states
# TODO might converge faster with a better initialisation
grid[mask] .= rand(states, sum(mask))

# Convert eltype of grid to Int
3 changes: 1 addition & 2 deletions src/train.jl
@@ -408,8 +408,7 @@ function _train(θ̂, θ_train::P, θ_val::P, Z_train::T, Z_val::T;
# save the loss every epoch in case training is prematurely halted
savebool && @save loss_path loss_per_epoch

# If the current loss is better than the previous best, save θ̂ and
# update the minimum validation risk
# If the current loss is better than the previous best, save θ̂ and update the minimum validation risk
if val_risk <= min_val_risk
savebool && _savestate(θ̂, savepath, epoch)
min_val_risk = val_risk
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -426,6 +426,7 @@ end

## Potts model
β = 0.7
complete_grid = simulatepotts(n, n, 2, 0.99) # simulate marginally from the Ising model
complete_grid = simulatepotts(n, n, 2, β) # simulate marginally from the Ising model
@test size(complete_grid) == (n, n)
@test length(unique(complete_grid)) == 2
