From fb01c72626a61310f874664cdb4d7b4c1b327bb3 Mon Sep 17 00:00:00 2001 From: Alexandre Poux Date: Sun, 28 Nov 2021 05:40:06 +0100 Subject: [PATCH] Use complex nan by default when interpolating out of bounds (#6019) * use complex nan by default when interpolating out of bounds * update whats-new.rst * remove unecessary complexity * analyse `dtype.kind` instead of using `np.iscomplexobj` Co-authored-by: Alexandre Poux --- doc/whats-new.rst | 3 ++- xarray/core/missing.py | 12 ++++++++---- xarray/tests/test_interp.py | 13 +++++++++++++ xarray/tests/test_missing.py | 22 ++++++++++++++++++++++ 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eb61cd154cf..8c49a648bd6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,7 +26,8 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ - +- Use complex nan when interpolating complex values out of bounds by default (instead of real nan) (:pull:`6019`). + By `Alexandre Poux `_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 8ed9e23f1eb..efaacfa619a 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -82,9 +82,11 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None): self._xi = xi self._yi = yi + nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j + if fill_value is None: - self._left = np.nan - self._right = np.nan + self._left = nan + self._right = nan elif isinstance(fill_value, Sequence) and len(fill_value) == 2: self._left = fill_value[0] self._right = fill_value[1] @@ -143,10 +145,12 @@ def __init__( self.cons_kwargs = kwargs self.call_kwargs = {} + nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j + if fill_value is None and method == "linear": - fill_value = np.nan, np.nan + fill_value = nan, nan elif fill_value is None: - fill_value = np.nan + fill_value = nan self.f = interp1d( xi, diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 2029e6af05b..fd480436889 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -907,3 +907,16 @@ def test_coord_attrs(x, expect_same_attrs): has_same_attrs = ds.interp(x=x).x.attrs == base_attrs assert expect_same_attrs == has_same_attrs + + +@requires_scipy +def test_interp1d_complex_out_of_bounds(): + """Ensure complex nans are used by default""" + da = xr.DataArray( + np.exp(0.3j * np.arange(4)), + [("time", np.arange(4))], + ) + + expected = da.interp(time=3.5, kwargs=dict(fill_value=np.nan + np.nan * 1j)) + actual = da.interp(time=3.5) + assert_identical(actual, expected) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 1ebcd9ac6f7..69b59a7418c 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -687,3 +687,25 @@ def test_interpolate_na_2d(coords): coords=coords, ) assert_equal(actual, expected_x) + + +@requires_scipy +def test_interpolators_complex_out_of_bounds(): + """Ensure complex nans are used for complex data""" + + xi = np.array([-1, 0, 1, 2, 5], dtype=np.float64) + yi = np.exp(1j * xi) + x = np.array([-2, 1, 6], dtype=np.float64) + + expected = np.array( + [np.nan + np.nan * 1j, np.exp(1j), np.nan + np.nan * 1j], dtype=yi.dtype + ) + + for method, interpolator in [ + ("linear", NumpyInterpolator), + ("linear", ScipyInterpolator), + ]: + + f = interpolator(xi, yi, method=method) + actual = f(x) + assert_array_equal(actual, expected)