Adding Functionality for Linearized Optimal Transport Computations #187

Open
wants to merge 7 commits into
base: master
61 changes: 61 additions & 0 deletions examples/LinearizedTransport/script.jl
@@ -0,0 +1,61 @@
"""
Optimal transport (OT) costs are expensive to compute in general, so scaling can be quite bad if we need to, say, compute
the OT cost pairwise for a reasonably sized family of measures. In this situation it can be beneficial to
linearize the OT distance using the manifold-like structure induced by the Wasserstein cost. Fix a reference measure μ and consider the
embedding ν ↦ T_ν, where T_ν is the optimal transport map pushing μ forward to ν. Now fix two other measures ν and ρ, not equal to μ.
We may approximate the 2-Wasserstein distance (the square root of the OT cost for the squared Euclidean cost) via
W₂(ν, ρ) ≈ ‖T_ν - T_ρ‖_{L²(μ)}. If μ, ν, and ρ are "nice" (i.e. have smooth and accessible densities
w.r.t. the Lebesgue measure), then the right-hand side is easy to approximate well via standard numerical methods.

Now, it is a sad fact that recovering the maps T_ν is generally no easy task either. But in the case of entropically regularized
transport, there is a convenient entropic estimator of the transport map, which depends only on the entropic dual potential
for ν and a family of N i.i.d. samples Y_i ∼ ν.

The following example is rather contrived: if we only wanted to compute one distance, solving two Sinkhorn problems and
approximating an integral on top of that is much more work than we need. The main application is the case where
we have O(n^2) pairwise distances to compute.

Note that the choice of reference measure can significantly affect the quality of the approximation, and as of writing there is
no non-heuristic method for choosing a "good" reference.

Relevant sources:

Moosmüller, Caroline, and Alexander Cloninger. "Linear optimal transport embedding: provable Wasserstein classification for certain rigid transformations and perturbations." Information and Inference: A Journal of the IMA 12.1 (2023): 363-389.
Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. "Entropic estimation of optimal transport maps." arXiv:2109.12004 (2021).

"""

using Distances
using Distributions
using OptimalTransport

N = 1000 # number of samples

# sample some points according to our chosen reference and target distributions
μ = rand(Normal(1,1), N)
ν = rand(Normal(0,1), N)
ρ = rand(Normal(2,1), N)

# set the weights on the samples to be uniform
a = fill(1/N, N)

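# compute the cost matrices; the costs are ‖x - y‖² / 2 to match the cost used inside `entropic_transport_map`
C = pairwise(SqEuclidean(), μ', ν'; dims=2) ./ 2
D = pairwise(SqEuclidean(), μ', ρ'; dims=2) ./ 2
E = pairwise(SqEuclidean(), ν', ρ'; dims=2) ./ 2 # between the two targets; not used below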

# get the entropic transport maps
T_ν = entropic_transport_map(a, a, ν, C, 0.1, SinkhornGibbs())
T_ρ = entropic_transport_map(a, a, ρ, D, 0.1, SinkhornGibbs())

# integrand for the linearization
f(x) = sum(abs2, T_ν([x]) - T_ρ([x])) # the maps return arrays, so take the squared norm of the difference

# convert target distributions to dirac clouds
ν_dist = DiscreteNonParametric(ν, a)
ρ_dist = DiscreteNonParametric(ρ, a)

# compute and compare
I = sqrt(sum(f.(μ)) / N) # naive Monte Carlo approximation of the L2(μ) distance between the entropic maps
J = sqrt(ot_cost(sqeuclidean, ν_dist, ρ_dist)) # W₂ distance: square root of the squared Euclidean OT cost

println("Linear approximation of the distance: $I; True OT distance: $J")
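
# Sketch of the pairwise use case described in the docstring above. This is an
# illustration under stated assumptions, not part of the proposed API: the names
# `targets`, `embed`, `embeddings`, and `W_lin` are hypothetical and introduced only here.
# Each target is embedded once (one Sinkhorn problem per target); afterwards every
# pairwise distance is a plain L²(μ) norm between two embeddings.
targets = [rand(Normal(m, 1), N) for m in (0.0, 0.5, 1.5, 2.0)]

# evaluate the entropic map for one target at every reference sample
function embed(samples)
    # costs ‖x - y‖² / 2, matching the cost used inside `entropic_transport_map`
    cost = pairwise(SqEuclidean(), μ', samples'; dims=2) ./ 2
    T = entropic_transport_map(a, a, samples, cost, 0.1, SinkhornGibbs())
    # in one dimension the map returns a 1-element array, so take its only entry
    return [first(T([x])) for x in μ]
end

embeddings = [embed(t) for t in targets]

# W_lin[i, j] approximates W₂ between targets[i] and targets[j] (Monte Carlo L²(μ) norm)
W_lin = [sqrt(sum(abs2, ei - ej) / N) for ei in embeddings, ej in embeddings]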
2 changes: 1 addition & 1 deletion src/OptimalTransport.jl
@@ -18,7 +18,7 @@ export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
export SinkhornBarycenterGibbs
export QuadraticOTNewton

-export sinkhorn, sinkhorn2
+export sinkhorn, sinkhorn2, sinkhorn_potentials, entropic_transport_map
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
export sinkhorn_divergence, sinkhorn_divergence_unbalanced
88 changes: 88 additions & 0 deletions src/entropic/sinkhorn.jl
@@ -183,6 +183,94 @@ function sinkhorn(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
return γ
end

"""
    sinkhorn_potentials(
        μ, ν, C, ε, alg::Sinkhorn;
        atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000,
    )

Compute the dual potentials `(f, g)` of the entropically regularized optimal transport
problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
`(length(μ), length(ν))`, and entropic regularization parameter `ε`.

The potentials are obtained from the Sinkhorn scaling vectors `u` and `v` as
`f = ε * log.(u)` and `g = ε * log.(v)`.

Every `check_convergence` steps it is assessed if the algorithm is converged by checking if
the iterate of the transport plan `G` satisfies
```julia
isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
```
The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
the computation is stopped.

Batch computations for multiple histograms with a common cost matrix `C` can be performed by
passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required that
the number of source and target marginals is equal or that a single source or single target
marginal is provided (either as matrix or as vector). In that case `f` and `g` are matrices
whose `i`th columns are the potentials for the `i`th pair of source and target marginals.

See also: [`sinkhorn`](@ref), [`sinkhorn2`](@ref)
"""
function sinkhorn_potentials(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
    # build solver
    solver = build_solver(μ, ν, C, ε, alg; kwargs...)

    # perform Sinkhorn algorithm
    solve!(solver)

    # recover the dual potentials from the scaling vectors `u` and `v`
    u = solver.cache.u
    v = solver.cache.v
    f = ε * log.(u)
    g = ε * log.(v)

    return (f, g)
end


"""
    entropic_transport_map(
        μ, ν, samples_ν, C, ε, alg::Sinkhorn;
        atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000,
    )

Compute the entropic estimator of the optimal transport map for the entropically regularized
optimal transport problem with source and target marginals `μ` and `ν`, cost matrix `C` of size
`(length(μ), length(ν))`, and entropic regularization parameter `ε`.

`samples_ν` contains the support points `yᵢ` of `ν`, one point per row (or a vector in the
one-dimensional case). The returned function `T` maps a point `x` to the weighted average

    T(x) = Σᵢ yᵢ exp((gᵢ - ‖x - yᵢ‖² / 2) / ε) / Σᵢ exp((gᵢ - ‖x - yᵢ‖² / 2) / ε),

where `g` is the target potential returned by [`sinkhorn_potentials`](@ref). The formula uses
the cost `‖x - y‖² / 2`, so `C` should contain these costs for `T` to be consistent with `g`.

The keyword arguments are forwarded to the Sinkhorn solver. Every `check_convergence` steps it
is assessed if the algorithm is converged by checking if the iterate of the transport plan `G`
satisfies
```julia
isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
```
The default `rtol` depends on the types of `μ`, `ν`, and `C`. After `maxiter` iterations,
the computation is stopped.

See also: [`sinkhorn_potentials`](@ref), [`sinkhorn`](@ref)
"""
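function entropic_transport_map(μ, ν, samples_ν, C, ε, alg::Sinkhorn; kwargs...)
    # only the target potential `g` enters the map estimator
    _, g = sinkhorn_potentials(μ, ν, C, ε, alg; kwargs...)
    N = size(ν, 1)
    function T(x::AbstractVecOrMat)
        # exponents (g_i - ‖x - y_i‖² / 2) / ε of the unnormalized weights
        e = zeros(N)
        for i in 1:N
            y = x .- samples_ν[i, :]
            e[i] = (g[i] - 0.5 * sum(abs2, y)) / ε
        end
        # shifting by the maximum avoids underflow and cancels in the ratio below
        b = exp.(e .- maximum(e))
        # weighted average of the target samples
        return sum(b .* samples_ν; dims=1) / sum(b)
    end
    return T
end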


function sinkhorn_cost_from_plan(γ, C, ε; regularization=false)
cost = if regularization
dot_matwise(γ, C) .+