Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG(shells): fix and improve partition() with NNLS #145

Merged
merged 1 commit into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions glass/core/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def nnls(
a: ArrayLike,
b: ArrayLike,
*,
tol: float = 1e-10,
tol: float = 0.0,
maxiter: int | None = None,
) -> ArrayLike:
"""Compute a non-negative least squares solution.
Expand All @@ -39,7 +39,7 @@ def nnls(
if a.shape[0] != b.shape[0]:
raise ValueError("the shapes of `a` and `b` do not match")

m, n = a.shape
_, n = a.shape

if maxiter is None:
maxiter = 3 * n
Expand Down
80 changes: 66 additions & 14 deletions glass/shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,36 @@ def partition(z: ArrayLike,
redshift arrays of all window functions. Intermediate function
values are found by linear interpolation.

When partitioning a density function, it is usually desirable to
keep the normalisation fixed. In that case, the problem can be
enhanced with the further constraint that the sum of the solution
equals the integral of the target function,

.. math::
\\begin{pmatrix}
w_1(z_1) \\Delta z_1 & w_2(z_1) \\, \\Delta z_1 & \\cdots \\\\
w_1(z_2) \\Delta z_2 & w_2(z_2) \\, \\Delta z_2 & \\cdots \\\\
\\vdots & \\vdots & \\ddots \\\\
\\hline
\\lambda & \\lambda & \\cdots
\\end{pmatrix} \\, \\begin{pmatrix}
x_1 \\\\ x_2 \\\\ \\vdots
\\end{pmatrix} = \\begin{pmatrix}
f(z_1) \\, \\Delta z_1 \\\\ f(z_2) \\, \\Delta z_2 \\\\ \\vdots
\\\\ \\hline \\lambda \\int \\! f(z) \\, dz
\\end{pmatrix} \\;,

where :math:`\\lambda` is a multiplier to enforce the integral
contraints.

The :func:`partition()` function implements a number of methods to
obtain a solution:

If ``method="nnls"`` (the default), obtain a partition from a
non-negative least-squares solution. This will match the shape of
the input function well, but the overall normalisation might be
differerent. The contribution from each shell is a positive number,
which is required to partition e.g. density distributions.
non-negative least-squares solution. This will usually match the
shape of the input function closely. The contribution from each
shell is a positive number, which is required to partition e.g.
density functions.

If ``method="lstsq"``, obtain a partition from an unconstrained
least-squares solution. This will more closely match the shape of
Expand All @@ -402,8 +427,8 @@ def partition(z: ArrayLike,

If ``method="restrict"``, obtain a partition by integrating the
restriction (using :func:`restrict`) of the function :math:`f` to
each window. This will more closely match the normalisation of the
input function, but the shape might differ.
each window. For overlapping shells, this method might produce
results which are far from the input function.

"""
try:
Expand All @@ -417,9 +442,15 @@ def partition_lstsq(
z: ArrayLike,
fz: ArrayLike,
shells: Sequence[RadialWindow],
*,
sumtol: float = 0.01,
) -> ArrayLike:
"""Least-squares partition."""

# make sure nothing breaks
if sumtol < 1e-4:
sumtol = 1e-4

# compute the union of all given redshift grids
zp = z
for w in shells:
Expand All @@ -440,11 +471,16 @@ def partition_lstsq(
b = ndinterp(zp, z, fz, left=0., right=0.)
b = b*dz

# now a is a matrix of shape (len(shells), len(zp))
# and b is a matrix of shape (*dims, len(zp))
# append a constraint for the integral
mult = 1/sumtol
a = np.concatenate([a, mult * np.ones((len(shells), 1))], axis=-1)
b = np.concatenate([b, mult * np.reshape(np.trapz(fz, z), (*dims, 1))], axis=-1)

# now a is a matrix of shape (len(shells), len(zp) + 1)
# and b is a matrix of shape (*dims, len(zp) + 1)
# need to find weights x such that b == x @ a over all axes of b
# do the least-squares fit over partially flattened b, then reshape
x = np.linalg.lstsq(a.T, b.reshape(-1, zp.size).T, rcond=None)[0]
x = np.linalg.lstsq(a.T, b.reshape(-1, zp.size + 1).T, rcond=None)[0]
x = x.T.reshape(*dims, len(shells))
# roll the last axis of size len(shells) to the front
x = np.moveaxis(x, -1, 0)
Expand All @@ -456,6 +492,8 @@ def partition_nnls(
z: ArrayLike,
fz: ArrayLike,
shells: Sequence[RadialWindow],
*,
sumtol: float = 0.01,
) -> ArrayLike:
"""Non-negative least-squares partition.

Expand All @@ -466,6 +504,10 @@ def partition_nnls(

from .core.algorithm import nnls

# make sure nothing breaks
if sumtol < 1e-4:
sumtol = 1e-4

# compute the union of all given redshift grids
zp = z
for w in shells:
Expand All @@ -486,13 +528,23 @@ def partition_nnls(
b = ndinterp(zp, z, fz, left=0., right=0.)
b = b*dz

# now a is a matrix of shape (len(shells), len(zp))
# and b is a matrix of shape (*dims, len(zp))
# for each dim, find weights x such that b == a.T @ x
# the output is more conveniently shapes with len(shells) first
# append a constraint for the integral
mult = 1/sumtol
a = np.concatenate([a, mult * np.ones((len(shells), 1))], axis=-1)
b = np.concatenate([b, mult * np.reshape(np.trapz(fz, z), (*dims, 1))], axis=-1)

# now a is a matrix of shape (len(shells), len(zp) + 1)
# and b is a matrix of shape (*dims, len(zp) + 1)
# for each dim, find non-negative weights x such that b == a.T @ x

# reduce the dimensionality of the problem using a thin QR decomposition
q, r = np.linalg.qr(a.T)
y = np.einsum('ji,...j', q, b)

# for each dim, find non-negative weights x such that y == r @ x
x = np.empty([len(shells)] + dims)
for i in np.ndindex(*dims):
x[(slice(None),) + i] = nnls(a.T, b[i])[0]
x[(...,) + i] = nnls(r, y[i])

# all done
return x
Expand Down
2 changes: 2 additions & 0 deletions glass/test/test_shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,5 @@ def test_partition(method):
part = partition(z, fz, shells, method=method)

assert part.shape == (len(shells), 3, 2)

assert np.allclose(part.sum(axis=0), np.trapz(fz, z))
Loading