From 00d7bdadab606c681bf565922434568649d72ee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20H=C3=B8xbro=20Hansen?= Date: Fri, 24 May 2024 17:17:45 +0200 Subject: [PATCH] Update test_xarray --- datashader/tests/test_xarray.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/datashader/tests/test_xarray.py b/datashader/tests/test_xarray.py index 8397d1bc2..71f3f1f40 100644 --- a/datashader/tests/test_xarray.py +++ b/datashader/tests/test_xarray.py @@ -1,7 +1,7 @@ from __future__ import annotations +from copy import deepcopy import numpy as np from numpy import nan -import os import xarray as xr import datashader as ds @@ -14,9 +14,6 @@ except ImportError: cupy = None -test_gpu = bool(int(os.getenv("DATASHADER_TEST_GPU", 0))) - - xda = xr.DataArray(data=np.array(([1.] * 10 + [10] * 10)), dims=('record'), coords={'x': xr.DataArray(np.array(([0.]*10 + [1]*10)), dims=('record')), @@ -46,9 +43,7 @@ def assert_eq(agg, b): assert agg.equals(b) -@pytest.mark.parametrize("source", [ - (xda), (xdda), (xds), (xdds), -]) +@pytest.mark.parametrize("source", [xda, xdda, xds, xdds]) def test_count(source): out = xr.DataArray(np.array([[5, 5], [5, 5]], dtype='i4'), coords=coords, dims=dims) @@ -97,7 +92,7 @@ def test_count(source): @pytest.mark.parametrize("ds2d", ds2ds) -@pytest.mark.parametrize("cuda", [False, True]) +@pytest.mark.parametrize('on_gpu', [False, pytest.param(True, marks=pytest.mark.gpu)]) @pytest.mark.parametrize("chunksizes", [ None, dict(x=10, channel=10), @@ -105,12 +100,10 @@ def test_count(source): dict(x=3, channel=10), dict(x=3, channel=1), ]) -def test_lines_xarray_common_x(ds2d, cuda, chunksizes): - source = ds2d.copy() - if cuda: - if not (cupy and test_gpu): - pytest.skip("CUDA tests not requested") - elif chunksizes is not None: +def test_lines_xarray_common_x(ds2d, on_gpu, chunksizes): + source = deepcopy(ds2d) + if on_gpu: + if chunksizes is not None: pytest.skip("CUDA-dask for LinesXarrayCommonX not implemented") # CPU -> GPU @@ -162,7 +155,7 @@ def test_lines_xarray_common_x(ds2d, cuda, chunksizes): assert_eq_ndarray(agg.x_range, (0, 4), close=True) assert_eq_ndarray(agg.y_range, (0, 2), close=True) assert_eq_ndarray(agg.data, sol_count) - assert isinstance(agg.data, cupy.ndarray if cuda else np.ndarray) + assert isinstance(agg.data, cupy.ndarray if on_gpu else np.ndarray) # any agg = canvas.line(source, x="x", y="name", agg=ds.any())