Skip to content

Commit

Permalink
🙈 check for contiguous array (#64)
Browse files Browse the repository at this point in the history
* 🙈 check for contiguous array

* 🎄 fix mypy warning
  • Loading branch information
jvdd authored Dec 24, 2023
1 parent 18b9fdf commit e5e4866
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
22 changes: 22 additions & 0 deletions tests/test_tsdownsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,25 @@ def test_error_invalid_args():
with pytest.raises(ValueError) as e_msg:
MinMaxDownsampler().downsample(arr, arr[:-1], n_out=100, n_threads=2)
assert "x and y must have the same length" in str(e_msg.value)


@pytest.mark.parametrize("downsampler", generate_rust_downsamplers())
def test_non_contiguous_array(downsampler: AbstractDownsampler):
"""Test non contiguous array."""
arr = np.random.randint(0, 100, size=10_000)
arr = arr[::2]
assert not arr.flags["C_CONTIGUOUS"]
with pytest.raises(ValueError) as e_msg:
downsampler.downsample(arr, n_out=100)
assert "must be contiguous" in str(e_msg.value)


def test_everynth_non_contiguous_array():
"""Test non contiguous array."""
arr = np.random.randint(0, 100, size=10_000)
arr = arr[::2]
assert not arr.flags["C_CONTIGUOUS"]
downsampler = EveryNthDownsampler()
s_downsampled = downsampler.downsample(arr, n_out=100)
assert s_downsampled[0] == 0
assert s_downsampled[-1] == 4950
3 changes: 3 additions & 0 deletions tsdownsample/downsamplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def downsample(


class EveryNthDownsampler(AbstractDownsampler):
def __init__(self, **kwargs):
super().__init__(check_contiguous=False, **kwargs)

def _downsample(
self, x: Union[np.ndarray, None], y: np.ndarray, n_out: int, **_
) -> np.ndarray:
Expand Down
19 changes: 18 additions & 1 deletion tsdownsample/downsampling_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@ class AbstractDownsampler(ABC):

def __init__(
self,
check_contiguous: bool = True,
x_dtype_regex_list: Optional[List[str]] = None,
y_dtype_regex_list: Optional[List[str]] = None,
):
self.check_contiguous = check_contiguous
self.x_dtype_regex_list = x_dtype_regex_list
self.y_dtype_regex_list = y_dtype_regex_list

def _check_contiguous(self, arr: np.ndarray, y: bool = True):
# necessary for rust downsamplers as they don't support non-contiguous arrays
# (we call .as_slice().unwrap() on the array) in the lib.rs file
# which will panic if the array is not contiguous
if not self.check_contiguous:
return

if arr.flags["C_CONTIGUOUS"]:
return

raise ValueError(f"{'y' if y else 'x'} array must be contiguous.")

def _supports_dtype(self, arr: np.ndarray, y: bool = True):
dtype_regex_list = self.y_dtype_regex_list if y else self.x_dtype_regex_list
# base case
Expand Down Expand Up @@ -66,6 +80,7 @@ def _check_valid_downsample_args(
raise ValueError("x must be 1D array")
if len(x) != len(y):
raise ValueError("x and y must have the same length")

return x, y

@staticmethod
Expand Down Expand Up @@ -113,8 +128,10 @@ def downsample(self, *args, n_out: int, **kwargs): # x and y are optional
self._check_valid_n_out(n_out)
x, y = self._check_valid_downsample_args(*args)
self._supports_dtype(y, y=True)
self._check_contiguous(y, y=True)
if x is not None:
self._supports_dtype(x, y=False)
self._check_contiguous(x, y=False)
return self._downsample(x, y, n_out, **kwargs)


Expand Down Expand Up @@ -144,7 +161,7 @@ class AbstractRustDownsampler(AbstractDownsampler, ABC):
"""RustDownsampler interface-class, subclassed by concrete downsamplers."""

def __init__(self):
super().__init__(_rust_dtypes, _y_rust_dtypes) # same for x and y
super().__init__(True, _rust_dtypes, _y_rust_dtypes) # same for x and y

@property
def rust_mod(self) -> ModuleType:
Expand Down

0 comments on commit e5e4866

Please sign in to comment.