Function Request: np.interp #50334
Comments
I would not add an
A basic approximation (without periodic boundaries):

import torch
from torch import Tensor

def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
    """One-dimensional linear interpolation for monotonically increasing sample
    points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: the :math:`x`-coordinates of the data points, must be increasing.
        fp: the :math:`y`-coordinates of the data points, same length as `xp`.

    Returns:
        the interpolated values, same size as `x`.
    """
    m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])  # slope of each segment
    b = fp[:-1] - (m * xp[:-1])                  # intercept of each segment

    # Index of the segment each x falls into (count of xp values <= x, minus one)
    indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
    indices = torch.clamp(indices, 0, len(m) - 1)

    return m[indices] * x + b[indices]

import math

import matplotlib.pyplot
import torch

xp = torch.linspace(0, math.tau, 10)
fp = torch.sin(xp)

x = torch.linspace(0, math.tau, 50)
y = interp(x, xp, fp)

matplotlib.pyplot.plot(x, y, "x")
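The "(without periodic boundaries)" caveat refers to something np.interp supports via its period argument. A rough sketch of how that could be layered on top of the interp above (interp_periodic is a hypothetical helper name, and this mirrors NumPy's wrap-and-pad approach rather than anything from this thread):

import torch
from torch import Tensor

def interp_periodic(x: Tensor, xp: Tensor, fp: Tensor, period: float) -> Tensor:
    # Wrap query and sample coordinates into one period.
    x = x % period
    xp = xp % period
    # Sort the wrapped sample points, then pad one sample on each side so every
    # wrapped query lands inside a segment.
    order = torch.argsort(xp)
    xp, fp = xp[order], fp[order]
    xp = torch.cat([xp[-1:] - period, xp, xp[:1] + period])
    fp = torch.cat([fp[-1:], fp, fp[:1]])
    return interp(x, xp, fp)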
Hey! Thanks to @0x00b1, who left us a simple version of the linear interpolation.
You can test this function with the code below.
And if you prefer the method based on torch.searchsorted, you can try modifying this line of the code:
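Presumably this refers to the index computation in @0x00b1's snippet above; a rough guess at what the torch.searchsorted version might look like (not the original code):

# Hypothetical replacement for the torch.ge/torch.sum index computation above:
# find the segment each x falls into by counting how many xp values are <= x.
indices = torch.searchsorted(xp, x, right=True) - 1
indices = torch.clamp(indices, 0, len(m) - 1)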
Test code:
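A small stand-in test, assuming interp is the function defined above; it checks the result against numpy.interp on random increasing data (the sizes, seed, and tolerance are arbitrary choices):

import numpy as np
import torch

torch.manual_seed(0)
xp = torch.sort(torch.rand(10)).values          # strictly increasing sample points
fp = torch.rand(10)                             # sample values
x = xp[0] + (xp[-1] - xp[0]) * torch.rand(50)   # queries inside [xp[0], xp[-1]]

expected = torch.from_numpy(np.interp(x.numpy(), xp.numpy(), fp.numpy()))
actual = interp(x, xp, fp)
print(torch.allclose(actual, expected.to(actual.dtype), atol=1e-6))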
Here's a version that supports interpolation along an arbitrary dimension and extrapolation beyond the sample points¹:

¹Of course it still requires

import torch

def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor, dim: int = -1, extrapolate: str = 'constant') -> torch.Tensor:
    """One-dimensional linear interpolation between monotonically increasing sample
    points, with extrapolation beyond sample points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: The :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: The :math:`x`-coordinates of the data points, must be increasing.
        fp: The :math:`y`-coordinates of the data points, same shape as `xp`.
        dim: Dimension across which to interpolate.
        extrapolate: How to handle values outside the range of `xp`. Options are:
            - 'linear': Extrapolate linearly beyond the range of xp values.
            - 'constant': Use the boundary value of `fp` for `x` values outside `xp`.

    Returns:
        The interpolated values, same size as `x`.
    """
    # Move the interpolation dimension to the last axis
    x = x.movedim(dim, -1)
    xp = xp.movedim(dim, -1)
    fp = fp.movedim(dim, -1)

    m = torch.diff(fp) / torch.diff(xp)  # slope
    b = fp[..., :-1] - m * xp[..., :-1]  # offset
    indices = torch.searchsorted(xp, x, right=False)

    if extrapolate == 'constant':
        # Pad m and b to get constant values outside of xp range
        m = torch.cat([torch.zeros_like(m)[..., :1], m, torch.zeros_like(m)[..., :1]], dim=-1)
        b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
    else:  # extrapolate == 'linear'
        indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)

    values = m.gather(-1, indices) * x + b.gather(-1, indices)

    return values.movedim(-1, dim)

Note: If there might be duplicates of (xp, fp) points in the data, that's OK, but then you need to use some offset in the division, such as the small epsilon added to torch.diff(xp) sketched after the usage block below.

Usage/test:

>>> import torch
>>> from interp import interp
########### 2d data, dim=1 ###########
>>> fp = torch.tensor([[1, 2, 5, 9], [-1, 0, 2, -1]])
>>> xp = torch.tensor([[0, 2, 4, 10], [3, 5, 9, 11]])
>>> interpolated_vals = interp(torch.stack(2 * [torch.arange(15)]), xp, fp, extrapolate='linear')
>>> print(interpolated_vals)
tensor([[ 1.0000, 1.5000, 2.0000, 3.5000, 5.0000, 5.6667, 6.3333, 7.0000,
7.6667, 8.3333, 9.0000, 9.6667, 10.3333, 11.0000, 11.6667],
[-2.5000, -2.0000, -1.5000, -1.0000, -0.5000, 0.0000, 0.5000, 1.0000,
1.5000, 2.0000, 0.5000, -1.0000, -2.5000, -4.0000, -5.5000]])
>>> interpolated_vals = interp(torch.stack(2 * [torch.arange(15)]), xp, fp, extrapolate='constant')
>>> print(interpolated_vals)
tensor([[ 1.0000, 1.5000, 2.0000, 3.5000, 5.0000, 5.6667, 6.3333, 7.0000,
7.6667, 8.3333, 9.0000, 9.0000, 9.0000, 9.0000, 9.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -0.5000, 0.0000, 0.5000, 1.0000,
1.5000, 2.0000, 0.5000, -1.0000, -1.0000, -1.0000, -1.0000]])
########### 3d data, dim=1 ###########
>>> fp = torch.arange(0, 24).reshape((2, 4, 3))
>>> print(fp)
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23]]])
>>> xp = torch.tile(torch.tensor([1, 3, 4, 8]), (2,3,1)).swapaxes(1, 2) + torch.tensor([0, 1, -1])
>>> print(xp)
tensor([[[1, 2, 0],
[3, 4, 2],
[4, 5, 3],
[8, 9, 7]],
[[1, 2, 0],
[3, 4, 2],
[4, 5, 3],
[8, 9, 7]]])
>>> x = torch.tile(torch.arange(0, 8), (2, 3, 1)).swapaxes(1, 2)
>>> print(x)
tensor([[[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4],
[5, 5, 5],
[6, 6, 6],
[7, 7, 7]],
[[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4],
[5, 5, 5],
[6, 6, 6],
[7, 7, 7]]])
>>> interpolated_vals = interp(x, xp, fp, dim=1, extrapolate='constant')
>>> print(interpolated_vals)
tensor([[[ 0.0000, 1.0000, 2.0000],
[ 0.0000, 1.0000, 3.5000],
[ 1.5000, 1.0000, 5.0000],
[ 3.0000, 2.5000, 8.0000],
[ 6.0000, 4.0000, 8.7500],
[ 6.7500, 7.0000, 9.5000],
[ 7.5000, 7.7500, 10.2500],
[ 8.2500, 8.5000, 11.0000]],
[[12.0000, 13.0000, 14.0000],
[12.0000, 13.0000, 15.5000],
[13.5000, 13.0000, 17.0000],
[15.0000, 14.5000, 20.0000],
[18.0000, 16.0000, 20.7500],
[18.7500, 19.0000, 21.5000],
[19.5000, 19.7500, 22.2500],
        [20.2500, 20.5000, 23.0000]]])
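A minimal sketch of the kind of offset the duplicate-points note above seems to mean (the eps value is an arbitrary choice, not from the original comment):

# Guard against zero-width segments when xp contains repeated values.
eps = 1e-9
m = torch.diff(fp) / (torch.diff(xp) + eps)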
An implementation of NumPy's np.interp.
This was first requested on #1552.
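For reference, a toy illustration of the NumPy behaviour being requested (arbitrary numbers):

import numpy as np

xp = [1.0, 2.0, 3.0]
fp = [3.0, 2.0, 0.0]
print(np.interp(2.5, xp, fp))           # 1.0 -- piecewise-linear between samples
print(np.interp([0.0, 4.0], xp, fp))    # [3. 0.] -- clamped to boundary values by default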
cc @mruberry @rgommers @heitorschueroff