Skip to content

Commit

Permalink
Use complex nan by default when interpolating out of bounds (#6019)
Browse files Browse the repository at this point in the history
* use complex nan by default when interpolating out of bounds

* update whats-new.rst

* remove unecessary complexity

* analyse `dtype.kind` instead of using `np.iscomplexobj`

Co-authored-by: Alexandre Poux <work@alexandrepoux.fr>
  • Loading branch information
pums974 and Alexandre Poux authored Nov 28, 2021
1 parent 23d345b commit fb01c72
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ New Features

Breaking changes
~~~~~~~~~~~~~~~~

- Use complex nan when interpolating complex values out of bounds by default (instead of real nan) (:pull:`6019`).
By `Alexandre Poux <https://github.com/pums974>`_.

Deprecations
~~~~~~~~~~~~
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
self._xi = xi
self._yi = yi

nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j

if fill_value is None:
self._left = np.nan
self._right = np.nan
self._left = nan
self._right = nan
elif isinstance(fill_value, Sequence) and len(fill_value) == 2:
self._left = fill_value[0]
self._right = fill_value[1]
Expand Down Expand Up @@ -143,10 +145,12 @@ def __init__(
self.cons_kwargs = kwargs
self.call_kwargs = {}

nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j

if fill_value is None and method == "linear":
fill_value = np.nan, np.nan
fill_value = nan, nan
elif fill_value is None:
fill_value = np.nan
fill_value = nan

self.f = interp1d(
xi,
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,16 @@ def test_coord_attrs(x, expect_same_attrs):

has_same_attrs = ds.interp(x=x).x.attrs == base_attrs
assert expect_same_attrs == has_same_attrs


@requires_scipy
def test_interp1d_complex_out_of_bounds():
"""Ensure complex nans are used by default"""
da = xr.DataArray(
np.exp(0.3j * np.arange(4)),
[("time", np.arange(4))],
)

expected = da.interp(time=3.5, kwargs=dict(fill_value=np.nan + np.nan * 1j))
actual = da.interp(time=3.5)
assert_identical(actual, expected)
22 changes: 22 additions & 0 deletions xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,25 @@ def test_interpolate_na_2d(coords):
coords=coords,
)
assert_equal(actual, expected_x)


@requires_scipy
def test_interpolators_complex_out_of_bounds():
"""Ensure complex nans are used for complex data"""

xi = np.array([-1, 0, 1, 2, 5], dtype=np.float64)
yi = np.exp(1j * xi)
x = np.array([-2, 1, 6], dtype=np.float64)

expected = np.array(
[np.nan + np.nan * 1j, np.exp(1j), np.nan + np.nan * 1j], dtype=yi.dtype
)

for method, interpolator in [
("linear", NumpyInterpolator),
("linear", ScipyInterpolator),
]:

f = interpolator(xi, yi, method=method)
actual = f(x)
assert_array_equal(actual, expected)

0 comments on commit fb01c72

Please sign in to comment.