From 45bb5e78837dce1b031244888d6194f7b1f32a03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20H=C3=B8xbro=20Hansen?= Date: Thu, 23 May 2024 15:16:55 +0200 Subject: [PATCH] Improve compatibility with dask-expr (#1335) --- .github/workflows/test.yaml | 2 +- datashader/data_libraries/dask.py | 19 +++++++++++++++++-- datashader/tests/test_dask.py | 15 +++++++++++++++ pyproject.toml | 2 +- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 56d4ef42a..7a036d5fd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -146,7 +146,7 @@ jobs: echo "ENVS=$envs" >> $GITHUB_ENV - uses: holoviz-dev/holoviz_tasks/install@v0 with: - name: unit_test_suite + name: unit_test_suite_np${{ matrix.numpy-version }} python-version: ${{ matrix.python-version }} channel-priority: flexible channels: ${{ env.CHANNELS }} diff --git a/datashader/data_libraries/dask.py b/datashader/data_libraries/dask.py index 8aa5954d1..dc0a9f125 100644 --- a/datashader/data_libraries/dask.py +++ b/datashader/data_libraries/dask.py @@ -16,6 +16,20 @@ __all__ = () +def _dask_compat(df): + """ + Calls to this helper exist to stay compatible with both + `dask-expr` and classic `dask.dataframe` (where `optimize` does not exist). + With dask-expr, calling df.__dask_graph__() or df.__dask_keys__() will + make the graph no longer match df._name, so we preemptively call + `optimize` to make them match. 
+ + For more information, see the following comment: + https://github.com/holoviz/datashader/pull/1317#issuecomment-2039986852 + """ + return getattr(df, 'optimize', lambda: df)() + + @bypixel.pipeline.register(dd.DataFrame) def dask_pipeline(df, schema, canvas, glyph, summary, *, antialias=False, cuda=False): dsk, name = glyph_dispatch(glyph, df, schema, canvas, summary, antialias=antialias, cuda=cuda) @@ -27,6 +41,7 @@ def dask_pipeline(df, schema, canvas, glyph, summary, *, antialias=False, cuda=F if isinstance(dsk, da.Array): return da.compute(dsk, scheduler=scheduler)[0] + df = _dask_compat(df) keys = df.__dask_keys__() optimize = df.__dask_optimize__ graph = df.__dask_graph__() @@ -100,7 +115,7 @@ def func(partition: pd.DataFrame, cumulative_lens, partition_info=None): # Here be dragons # Get the dataframe graph - df = getattr(df, 'optimize', lambda: df)() # Work with new dask_expr + df = _dask_compat(df) graph = df.__dask_graph__() # Guess a reasonable output dtype from combination of dataframe dtypes @@ -211,7 +226,7 @@ def line(glyph, df, schema, canvas, summary, *, antialias=False, cuda=False): shape, bounds, st, axis = shape_bounds_st_and_axis(df, canvas, glyph) # Compile functions - df = getattr(df, 'optimize', lambda: df)() # Work with new dask_expr + df = _dask_compat(df) partitioned = isinstance(df, dd.DataFrame) and df.npartitions > 1 create, info, append, combine, finalize, antialias_stage_2, antialias_stage_2_funcs, _ = \ compile_components(summary, schema, glyph, antialias=antialias, cuda=cuda, diff --git a/datashader/tests/test_dask.py b/datashader/tests/test_dask.py index 413f98dc3..dc4af5c5e 100644 --- a/datashader/tests/test_dask.py +++ b/datashader/tests/test_dask.py @@ -2636,3 +2636,18 @@ def test_categorical_where_last_n(ddf, npartitions): assert_eq_ndarray(agg[:, :, :, 0].data, c.points(ddf, 'x', 'y', ds.by('cat2', ds.where(ds.last('plusminus'), 'reverse'))).data) + +@pytest.mark.parametrize('ddf', ddfs) 
+@pytest.mark.parametrize('npartitions', [1, 2, 3, 4]) +def test_series_reset_index(ddf, npartitions): + # Test for: https://github.com/holoviz/datashader/issues/1331 + ser = ddf['i32'].reset_index() + cvs = ds.Canvas(plot_width=2, plot_height=2) + out = cvs.line(ser, x='index', y='i32') + + expected = xr.DataArray( + data=[[True, False], [False, True]], + coords={"index": [4.75, 14.25], "i32": [4.75, 14.25]}, + dims=['i32', 'index'], + ) + assert_eq_xr(out, expected) diff --git a/pyproject.toml b/pyproject.toml index 83954127c..46c37fa53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ [tool.codespell] -ignore-words-list = "trough,thi" +ignore-words-list = "trough,thi,ser" [tool.ruff]