Skip to content

Commit

Permalink
Ensure categorical column order is the same across dask partitions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ianthomas23 authored Jun 23, 2023
1 parent 6dce648 commit 5a89820
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
2 changes: 1 addition & 1 deletion datashader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,7 @@ def _bypixel_sanitise(source, glyph, agg):
source[glyph.geometry].array._sindex = sindex
dshape = dshape_from_pandas(source)
elif isinstance(source, dd.DataFrame):
dshape = dshape_from_dask(source)
dshape, source = dshape_from_dask(source)
elif isinstance(source, Dataset):
# Multi-dimensional Dataset
dshape = dshape_from_xarray_dataset(source)
Expand Down
33 changes: 33 additions & 0 deletions datashader/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,3 +2128,36 @@ def test_dataframe_dtypes(ddf, npartitions):
ddf = ddf.repartition(npartitions)
assert ddf.npartitions == npartitions
ds.Canvas(2, 2).points(ddf, 'x', 'y', ds.count())


@pytest.mark.parametrize('on_gpu', [False, True])
def test_dask_categorical_counts(on_gpu):
# Issue 1202
if on_gpu and not test_gpu:
pytest.skip('gpu tests not enabled')

df = pd.DataFrame(
data=dict(
x = [0, 1, 2, 0, 1, 2, 1, 1, 1, 1, 1, 1],
y = [0]*12,
cat = ['a', 'b', 'c', 'a', 'b', 'c', 'b', 'b', 'b', 'b', 'b', 'c'],
)
)
ddf = dd.from_pandas(df, npartitions=2)
assert ddf.npartitions == 2
ddf.cat = ddf.cat.astype('category')

# Categorical counts at the dataframe level to confirm test is reasonable.
cat_totals = ddf.cat.value_counts().compute()
assert cat_totals['a'] == 2
assert cat_totals['b'] == 7
assert cat_totals['c'] == 3

canvas = ds.Canvas(3, 1, x_range=(0, 2), y_range=(-1, 1))
agg = canvas.points(ddf, 'x', 'y', ds.by("cat", ds.count()))
assert all(agg.cat == ['a', 'b', 'c'])

# Prior to fix, this gives [7, 3, 2]
sum_cat = agg.sum(dim=['x', 'y'])
assert all(sum_cat.cat == ['a', 'b', 'c'])
assert all(sum_cat.values == [2, 7, 3])
2 changes: 1 addition & 1 deletion datashader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def dshape_from_dask(df):
# for dask-cudf DataFrames with multiple partitions
return datashape.var * datashape.Record([
(k, dshape_from_pandas_helper(df[k].get_partition(0))) for k in df.columns
])
]), df


def dshape_from_xarray_dataset(xr_ds):
Expand Down

0 comments on commit 5a89820

Please sign in to comment.