Function Request: np.interp #50334

Open
Tracked by #50341
mruberry opened this issue Jan 10, 2021 · 5 comments
Labels
function request A request for a new function or the addition of new arguments/modes to an existing function. module: interpolation module: numpy Related to numpy support, and also numpy compatibility of our operators triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@mruberry
Collaborator

mruberry commented Jan 10, 2021

An implementation of NumPy's np.interp.

This was first requested on #1552.

cc @mruberry @rgommers @heitorschueroff
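For reference, here is a short sketch of the np.interp behavior a port would need to match: piecewise-linear interpolation inside the range, constant extrapolation with fp[0] / fp[-1] by default, the left/right overrides, and the period argument. The inputs are modeled on the examples in the np.interp documentation.

```python
import numpy as np

xp = [1.0, 2.0, 3.0]
fp = [3.0, 2.0, 0.0]

# Inside the range: piecewise-linear interpolation.
inside = np.interp(2.5, xp, fp)
# Outside the range: fp[0] / fp[-1] by default, i.e. constant extrapolation.
default_out = np.interp([0.0, 4.0], xp, fp)
# left/right override the values returned outside the range.
custom_out = np.interp([0.0, 4.0], xp, fp, left=-1.0, right=99.0)
# period treats x as an angle-like coordinate; xp is normalized and sorted internally.
periodic = np.interp([-180, -170, -185, 185, -10, -5, 0, 365],
                     [190, -190, 350, -350], [5, 10, 3, 4], period=360)
```

This gives 1.0, [3.0, 0.0], [-1.0, 99.0] and [7.5, 5.0, 8.75, 6.25, 3.0, 3.25, 3.5, 3.75] respectively.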

@mruberry mruberry added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module module: numpy Related to numpy support, and also numpy compatibility of our operators function request A request for a new function or the addition of new arguments/modes to an existing function. module: interpolation labels Jan 10, 2021
@rgommers
Collaborator

I would not add an interp function. NumPy has very limited interpolation functionality; it would be better to be scipy.interpolate compatible. The interp1d function (also requested in gh-50335) is essentially the same.
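To illustrate the "essentially the same" point: with linear interpolation and a (below, above) fill_value tuple, scipy.interpolate.interp1d should reproduce np.interp's default behavior. This is a sketch with illustrative data, and it assumes SciPy is installed.

```python
import numpy as np
from scipy.interpolate import interp1d

xp = np.array([1.0, 2.0, 3.0])
fp = np.array([3.0, 2.0, 0.0])
x = np.array([0.0, 1.5, 2.72, 4.0])  # includes points on both sides of the range

# bounds_error=False plus a 2-tuple fill_value gives fp[0] below and fp[-1] above,
# which is what np.interp does when left/right are not set.
f = interp1d(xp, fp, kind="linear", bounds_error=False, fill_value=(fp[0], fp[-1]))
via_interp1d = f(x)
via_np_interp = np.interp(x, xp, fp)
```

Both should give [3.0, 2.5, 0.56, 0.0] here.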

@0x00b1
Contributor

0x00b1 commented Dec 24, 2021

A basic approximation (without periodic boundaries):

```python
import torch
from torch import Tensor


def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
    """One-dimensional linear interpolation for monotonically increasing sample
    points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: the :math:`x`-coordinates of the data points, must be increasing.
        fp: the :math:`y`-coordinates of the data points, same length as `xp`.

    Returns:
        the interpolated values, same size as `x`.
    """
    m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])  # slope of each segment
    b = fp[:-1] - (m * xp[:-1])                  # intercept of each segment

    # Segment index = (number of knots <= x) - 1, clamped to a valid segment,
    # so points outside [xp[0], xp[-1]] are extrapolated linearly.
    indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
    indices = torch.clamp(indices, 0, len(m) - 1)

    return m[indices] * x + b[indices]
```

```python
import math

import matplotlib.pyplot
import torch


xp = torch.linspace(0, math.tau, 10)
fp = torch.sin(xp)

x = torch.linspace(0, math.tau, 50)
y = interp(x, xp, fp)

matplotlib.pyplot.plot(x, y, "x")
matplotlib.pyplot.show()
```
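A quick way to see how this differs from np.interp: the sketch below re-creates the same slope/intercept approach as a hypothetical `interp_linear_extrap` helper (a condensed copy, so the snippet runs on its own) and compares it with np.interp. Inside [xp[0], xp[-1]] the two agree; outside, this approach extrapolates linearly, while np.interp returns fp[0] or fp[-1].

```python
import numpy as np
import torch


def interp_linear_extrap(x, xp, fp):
    # Hypothetical condensed copy of the function above: one slope/intercept per segment.
    m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
    b = fp[:-1] - m * xp[:-1]
    # Segment index = (number of knots <= x) - 1, clamped to a valid segment.
    idx = torch.sum(x[:, None] >= xp[None, :], dim=1) - 1
    idx = torch.clamp(idx, 0, len(m) - 1)
    return m[idx] * x + b[idx]


xp = torch.tensor([0.0, 1.0, 3.0])
fp = torch.tensor([0.0, 2.0, 0.0])
x = torch.tensor([-1.0, 0.5, 2.0, 4.0])  # first and last points lie outside xp

ours = interp_linear_extrap(x, xp, fp)
ref = np.interp(x.numpy(), xp.numpy(), fp.numpy())
```

With these inputs `ours` is [-2, 1, 1, -1] and `ref` is [0, 1, 1, 0]: the middle points match and the endpoints show the different extrapolation.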


@yiyuzhuang

yiyuzhuang commented Sep 15, 2022

Hey! Thanks to @0x00b1 for leaving us the simplest version of linear interpolation.
I slightly extended this code into a version that supports batches of 1-D data, i.e. shape (N, D).


```python
import torch


def interpolate(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """One-dimensional linear interpolation for monotonically increasing sample
    points, for a batch of 1-D data.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolated
            values, shape (N, M).
        xp: the :math:`x`-coordinates of the data points, shape (N, D); each
            row must be increasing.
        fp: the :math:`y`-coordinates of the data points, same shape as `xp`.

    Returns:
        the interpolated values, same size as `x`.
    """
    m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1])  # slope
    b = fp[:, :-1] - m.mul(xp[:, :-1])                        # offset

    indicies = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1  # torch.ge: x[i] >= xp[i] ? true : false
    indicies = torch.clamp(indicies, 0, m.shape[-1] - 1)

    line_idx = torch.linspace(0, indicies.shape[0], 1, device=indicies.device).to(torch.long)
    line_idx = line_idx.expand(indicies.shape)
    # idx = torch.cat([line_idx, indicies], 0)
    return m[line_idx, indicies].mul(x) + b[line_idx, indicies]
```

You can test this function with the code below.

```python
import matplotlib.pyplot as plt
import torch

x = torch.linspace(0, 10*3.14, 5).expand(3, -1)
y = torch.sin(x)

sc = torch.linspace(x.min(), x.max(), 50).expand(3, -1)
new = interpolate(sc, x, y)

plt.plot(x[0].numpy().squeeze(), y[0].numpy().squeeze(), color='blue')
plt.plot(sc[0].numpy().squeeze(), new[0].numpy().squeeze(), color='orange')
plt.show()
```

If you prefer an approach based on torch.searchsorted, you can modify this line:

```python
# indicies = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1  # torch.ge: x[i] >= xp[i] ? true : false
# change it into
indicies = torch.searchsorted(xp.contiguous(), x.contiguous(), right=False) - 1  # version 2: searchsorted
```
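The two index computations are not identical, but they only disagree when a query point lands exactly on a knot, and there both candidate segments give the same value because the interpolant is continuous. A small sketch with illustrative data (one batch row) makes this concrete:

```python
import torch

xp = torch.tensor([[0.0, 1.0, 2.0, 3.0]])  # one batch row, 4 knots
fp = torch.tensor([[0.0, 2.0, 0.0, 1.0]])
x = torch.tensor([[0.5, 1.0, 2.5]])        # 1.0 sits exactly on a knot

m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1])  # slopes
b = fp[:, :-1] - m * xp[:, :-1]                          # offsets

# Counting approach: (number of knots <= x) - 1.
idx_ge = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1
# searchsorted(right=True) also counts knots <= x, so it matches exactly.
idx_right = torch.searchsorted(xp, x, right=True) - 1
# searchsorted(right=False) counts knots < x, so it differs only on a knot.
idx_left = torch.searchsorted(xp, x, right=False) - 1

# No clamping needed here: every x is inside the range and none equals xp[0].
vals_ge = m.gather(-1, idx_ge) * x + b.gather(-1, idx_ge)
vals_left = m.gather(-1, idx_left) * x + b.gather(-1, idx_left)
```

Here `idx_ge` and `idx_right` are [[0, 1, 2]], `idx_left` is [[0, 0, 2]], and both value tensors are [[1.0, 2.0, 0.5]].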

@logchan

logchan commented Jan 8, 2023

```python
line_idx = torch.linspace(0, indicies.shape[0], 1, device=indicies.device).to(torch.long)
```

Here line_idx is just [0] (steps is 1), so every row is interpolated with the slopes and offsets of the first row. I'd use:

```python
line_idx = torch.arange(len(indicies), device=indicies.device).view(-1, 1)
```

Test code:

```python
import matplotlib.pyplot as plt
import torch

x = torch.linspace(0, 10*3.14, 100).repeat(3, 1)
y = torch.sin(x)
y[1] *= 2
y[2] *= 3

sc = torch.linspace(x.min(), x.max(), 10).repeat(3, 1)
sc[1] *= 0.8
sc[2] *= 0.5
new = interpolate(sc, x, y)

plt.plot(x[0].numpy().squeeze(), y[0].numpy().squeeze(), color='blue')
plt.plot(x[1].numpy().squeeze(), y[1].numpy().squeeze(), color='blue')
plt.plot(x[2].numpy().squeeze(), y[2].numpy().squeeze(), color='blue')

plt.plot(sc[0].numpy().squeeze(), new[0].numpy().squeeze(), color='orange')
plt.plot(sc[1].numpy().squeeze(), new[1].numpy().squeeze(), color='red')
plt.plot(sc[2].numpy().squeeze(), new[2].numpy().squeeze(), color='green')
plt.show()
```
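A minimal illustration of the line_idx issue, using a hypothetical 3-row tensor of slopes `m` with illustrative values: `torch.linspace(..., steps=1)` yields a single 0, so every row reads from row 0, while the `arange` version selects one row per batch element.

```python
import torch

indices = torch.tensor([[1], [1], [1]])  # segment 1 chosen in every row
m = torch.arange(12.0).view(3, 4)        # hypothetical per-row slopes

buggy_rows = torch.linspace(0, indices.shape[0], 1).to(torch.long)  # steps=1 -> tensor([0])
fixed_rows = torch.arange(len(indices)).view(-1, 1)                # tensor([[0], [1], [2]])

buggy = m[buggy_rows.expand(indices.shape), indices]  # always reads row 0
fixed = m[fixed_rows, indices]                        # reads row i for batch element i
```

`buggy` comes out as [[1.], [1.], [1.]], while `fixed` is [[1.], [5.], [9.]].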

@MoritzLange

MoritzLange commented Aug 22, 2024

Here's a version that

  • fixes the problem @logchan has already identified with @yiyuzhuang's code, by using torch.gather()
  • allows arbitrary data shapes¹
  • allows interpolation across any dimension
  • lets you choose the kind of extrapolation used beyond the range covered by xp (both functions above assume linear extrapolation, while np.interp defaults to constant extrapolation, returning fp[0] and fp[-1] unless its left and right keywords are set).

¹Of course, it still requires xp.shape == fp.shape, and x.shape can differ from those two only in the interpolation dimension.

```python
import torch


def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor, dim: int=-1, extrapolate: str='constant') -> torch.Tensor:
    """One-dimensional linear interpolation between monotonically increasing sample
    points, with extrapolation beyond sample points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: The :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: The :math:`x`-coordinates of the data points, must be increasing.
        fp: The :math:`y`-coordinates of the data points, same shape as `xp`.
        dim: Dimension across which to interpolate.
        extrapolate: How to handle values outside the range of `xp`. Options are:
            - 'linear': Extrapolate linearly beyond range of xp values.
            - 'constant': Use the boundary value of `fp` for `x` values outside `xp`.

    Returns:
        The interpolated values, same size as `x`.
    """
    if extrapolate not in ('constant', 'linear'):
        raise ValueError(f"extrapolate must be 'constant' or 'linear', got {extrapolate!r}")

    # Move the interpolation dimension to the last axis
    x = x.movedim(dim, -1)
    xp = xp.movedim(dim, -1)
    fp = fp.movedim(dim, -1)

    m = torch.diff(fp) / torch.diff(xp)  # slope
    b = fp[..., :-1] - m * xp[..., :-1]  # offset
    # searchsorted prefers contiguous inputs; movedim may have made them non-contiguous
    indices = torch.searchsorted(xp.contiguous(), x.contiguous(), right=False)

    if extrapolate == 'constant':
        # Pad m and b to get constant values outside of xp range
        m = torch.cat([torch.zeros_like(m)[..., :1], m, torch.zeros_like(m)[..., :1]], dim=-1)
        b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
    else:  # extrapolate == 'linear'
        indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)

    values = m.gather(-1, indices) * x + b.gather(-1, indices)

    return values.movedim(-1, dim)
```
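To sanity-check the 'constant' branch, one can compare it against np.interp row by row. The sketch below is a condensed, hypothetical `interp_constant` helper that keeps only the 'constant' path with the interpolation dimension fixed to the last axis, tested on random sorted knots in float64. It assumes NumPy is available.

```python
import numpy as np
import torch


def interp_constant(x, xp, fp):
    # Hypothetical condensed version of the 'constant' branch for tensors shaped (..., K).
    m = torch.diff(fp) / torch.diff(xp)
    b = fp[..., :-1] - m * xp[..., :-1]
    idx = torch.searchsorted(xp, x, right=False)  # values in [0, K]
    zeros = torch.zeros_like(m[..., :1])
    # Pad with a zero slope and the boundary value of fp on each side.
    m = torch.cat([zeros, m, zeros], dim=-1)
    b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
    return m.gather(-1, idx) * x + b.gather(-1, idx)


torch.manual_seed(0)
xp = torch.sort(torch.rand(4, 6, dtype=torch.float64), dim=-1).values  # knots in [0, 1)
fp = torch.randn(4, 6, dtype=torch.float64)
# Query points extend past both ends of every row's knots.
x = torch.linspace(-0.5, 1.5, 25, dtype=torch.float64).expand(4, -1).contiguous()

ours = interp_constant(x, xp, fp)
ref = np.stack([np.interp(x[i].numpy(), xp[i].numpy(), fp[i].numpy()) for i in range(4)])
```

The padding is what makes out-of-range points collapse to fp[..., 0] or fp[..., -1], matching np.interp's default.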

Note: Duplicate (xp, fp) points in the data are fine, but torch.diff(xp) is then zero for those segments, so you need to add a small offset in the division, such as m = torch.diff(fp) / (torch.diff(xp) + 1e-10), to avoid NaN values in m. These NaN values are not a problem in the forward pass, because duplicate (xp, fp) points don't factor into the result, but they will produce NaN gradients in the backward pass.
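The gradient issue from this note can be reproduced on a tiny example with one duplicated knot at x = 1. The helper `forward_and_grads` is hypothetical and follows the 'linear' branch of the function above on 1-D tensors; `eps` is the offset added to `torch.diff(xp)`.

```python
import torch


def forward_and_grads(eps: float):
    # Duplicate knot at x = 1: the same (xp, fp) pair appears twice.
    xp = torch.tensor([0.0, 1.0, 1.0, 2.0], requires_grad=True)
    fp = torch.tensor([0.0, 1.0, 1.0, 3.0], requires_grad=True)
    x = torch.tensor([0.5, 1.5])

    m = torch.diff(fp) / (torch.diff(xp) + eps)  # slope per segment (0/0 if eps == 0)
    b = fp[:-1] - m * xp[:-1]                    # offset per segment
    # Same indexing as the 'linear' branch: searchsorted, shift by one, clamp.
    idx = torch.searchsorted(xp.detach(), x, right=False)
    idx = torch.clamp(idx - 1, 0, m.shape[-1] - 1)
    values = m.gather(-1, idx) * x + b.gather(-1, idx)

    values.sum().backward()
    return values.detach(), xp.grad, fp.grad


vals_naive, xp_grad_naive, fp_grad_naive = forward_and_grads(0.0)
vals_eps, xp_grad_eps, fp_grad_eps = forward_and_grads(1e-10)
```

In both cases the forward values are [0.5, 2.0], because the degenerate segment is never selected. Without the offset, fp.grad still contains NaN from the 0/0 slope; with eps = 1e-10 the gradients stay finite. In float32, adding 1e-10 to a spacing of 1.0 leaves it unchanged, so in this example only the duplicated segment is affected.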

Usage/test:

```python
>>> import torch
>>> from interp import interp  # assumes the function above is saved in interp.py

###########   2d data, dim=1   ###########
>>> fp = torch.tensor([[1, 2, 5, 9], [-1, 0, 2, -1]])
>>> xp = torch.tensor([[0, 2, 4, 10], [3, 5, 9, 11]])
>>> interpolated_vals = interp(torch.stack(2 * [torch.arange(15)]), xp, fp, extrapolate='linear')
>>> print(interpolated_vals)
tensor([[ 1.0000,  1.5000,  2.0000,  3.5000,  5.0000,  5.6667,  6.3333,  7.0000,
          7.6667,  8.3333,  9.0000,  9.6667, 10.3333, 11.0000, 11.6667],
        [-2.5000, -2.0000, -1.5000, -1.0000, -0.5000,  0.0000,  0.5000,  1.0000,
          1.5000,  2.0000,  0.5000, -1.0000, -2.5000, -4.0000, -5.5000]])

>>> interpolated_vals = interp(torch.stack(2 * [torch.arange(15)]), xp, fp, extrapolate='constant')
>>> print(interpolated_vals)
tensor([[ 1.0000,  1.5000,  2.0000,  3.5000,  5.0000,  5.6667,  6.3333,  7.0000,
          7.6667,  8.3333,  9.0000,  9.0000,  9.0000,  9.0000,  9.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -0.5000,  0.0000,  0.5000,  1.0000,
          1.5000,  2.0000,  0.5000, -1.0000, -1.0000, -1.0000, -1.0000]])

###########   3d data, dim=1   ###########
>>> fp = torch.arange(0, 24).reshape((2, 4, 3))
>>> print(fp)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],

        [[12, 13, 14],
         [15, 16, 17],
         [18, 19, 20],
         [21, 22, 23]]])

>>> xp = torch.tile(torch.tensor([1, 3, 4, 8]), (2,3,1)).swapaxes(1, 2) + torch.tensor([0, 1, -1])
>>> print(xp)
tensor([[[1, 2, 0],
         [3, 4, 2],
         [4, 5, 3],
         [8, 9, 7]],

        [[1, 2, 0],
         [3, 4, 2],
         [4, 5, 3],
         [8, 9, 7]]])

>>> x = torch.tile(torch.arange(0, 8), (2, 3, 1)).swapaxes(1, 2)
>>> print(x)
tensor([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2],
         [3, 3, 3],
         [4, 4, 4],
         [5, 5, 5],
         [6, 6, 6],
         [7, 7, 7]],

        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2],
         [3, 3, 3],
         [4, 4, 4],
         [5, 5, 5],
         [6, 6, 6],
         [7, 7, 7]]])

>>> interpolated_vals = interp(x, xp, fp, dim=1, extrapolate='constant')
>>> print(interpolated_vals)
tensor([[[ 0.0000,  1.0000,  2.0000],
         [ 0.0000,  1.0000,  3.5000],
         [ 1.5000,  1.0000,  5.0000],
         [ 3.0000,  2.5000,  8.0000],
         [ 6.0000,  4.0000,  8.7500],
         [ 6.7500,  7.0000,  9.5000],
         [ 7.5000,  7.7500, 10.2500],
         [ 8.2500,  8.5000, 11.0000]],

        [[12.0000, 13.0000, 14.0000],
         [12.0000, 13.0000, 15.5000],
         [13.5000, 13.0000, 17.0000],
         [15.0000, 14.5000, 20.0000],
         [18.0000, 16.0000, 20.7500],
         [18.7500, 19.0000, 21.5000],
         [19.5000, 19.7500, 22.2500],
         [20.2500, 20.5000, 23.0000]]])
```

Projects
None yet
Development

No branches or pull requests

6 participants