Skip to content

Commit

Permalink
fix avro and growable
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Oct 8, 2024
1 parent cf5c63b commit 3beca20
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 28 deletions.
1 change: 1 addition & 0 deletions crates/polars-arrow/src/array/growable/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ impl<'a> Growable<'a> for GrowableStruct<'a> {
if let Some(validity) = &mut self.validity {
validity.extend_constant(additional, false);
}
self.length += additional;
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/io/avro/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn make_mutable(
.iter()
.map(|field| make_mutable(field.dtype(), None, capacity))
.collect::<PolarsResult<Vec<_>>>()?;
Box::new(DynMutableStructArray::new(values, 0, dtype.clone()))
Box::new(DynMutableStructArray::new(values, dtype.clone()))
as Box<dyn MutableArray>
},
other => {
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-arrow/src/io/avro/read/nested.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ pub struct DynMutableStructArray {
}

impl DynMutableStructArray {
pub fn new(values: Vec<Box<dyn MutableArray>>, length: usize, dtype: ArrowDataType) -> Self {
pub fn new(values: Vec<Box<dyn MutableArray>>, dtype: ArrowDataType) -> Self {
Self {
dtype,
length,
length: 0,
values,
validity: None,
}
Expand All @@ -243,11 +243,11 @@ impl DynMutableStructArray {
#[inline]
fn push_null(&mut self) {
self.values.iter_mut().for_each(|x| x.push_null());
self.length += 1;
match &mut self.validity {
Some(validity) => validity.push(false),
None => self.init_validity(),
}
self.length += 1;
}

fn init_validity(&mut self) {
Expand Down
26 changes: 11 additions & 15 deletions py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import io
from dataclasses import dataclass
from datetime import datetime, time
from typing import TYPE_CHECKING

import pandas as pd
import pyarrow as pa
import pytest
import io

import polars as pl
import polars.selectors as cs
Expand Down Expand Up @@ -1084,12 +1084,14 @@ def test_zfs_nullable_when_otherwise() -> None:
df = pl.DataFrame([a, b])

df = df.select(
x = pl.when(pl.col.a.is_not_null()).then(pl.col.a).otherwise(pl.col.b),
y = pl.when(pl.col.a.is_null()).then(pl.col.a).otherwise(pl.col.b),
x=pl.when(pl.col.a.is_not_null()).then(pl.col.a).otherwise(pl.col.b),
y=pl.when(pl.col.a.is_null()).then(pl.col.a).otherwise(pl.col.b),
)

assert_series_equal(df['x'], pl.Series('x', [{}, {}, {}, {}, None], pl.Struct([])))
assert_series_equal(df['y'], pl.Series('y', [None, None, None, {}, None], pl.Struct([])))
assert_series_equal(df["x"], pl.Series("x", [{}, {}, {}, {}, None], pl.Struct([])))
assert_series_equal(
df["y"], pl.Series("y", [None, None, None, {}, None], pl.Struct([]))
)


def test_zfs_struct_fns() -> None:
Expand All @@ -1098,13 +1100,13 @@ def test_zfs_struct_fns() -> None:
assert a.struct.fields == []

# @TODO: This should really throw an error as per #19132
assert a.struct.rename_fields(['a']).struct.unnest().shape == (1, 0)
assert a.struct.rename_fields(["a"]).struct.unnest().shape == (1, 0)
assert a.struct.rename_fields([]).struct.unnest().shape == (1, 0)

assert_series_equal(a.struct.json_encode(), pl.Series('a', ["{}"], pl.String))
assert_series_equal(a.struct.json_encode(), pl.Series("a", ["{}"], pl.String))


@pytest.mark.parametrize("format", ['binary', 'json'])
@pytest.mark.parametrize("format", ["binary", "json"])
@pytest.mark.parametrize("size", [0, 1, 2, 13])
def test_zfs_serialization_roundtrip(format: pl.SerializationFormat, size: int) -> None:
a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame()
Expand All @@ -1125,13 +1127,7 @@ def test_zfs_row_encoding(size: int) -> None:

df = pl.DataFrame([a, pl.Series("x", list(range(size)), pl.Int8)])

gb = (
df
.lazy()
.group_by(["a", "x"])
.agg(pl.all().min())
.collect(streaming=True)
)
gb = df.lazy().group_by(["a", "x"]).agg(pl.all().min()).collect(streaming=True)

# We need to ignore the order because the group_by is undeterministic
assert_frame_equal(gb, df, check_row_order=False)
18 changes: 10 additions & 8 deletions py-polars/tests/unit/io/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import io
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any

import pandas as pd
import pytest
Expand Down Expand Up @@ -358,9 +358,9 @@ def test_ipc_variadic_buffers_categorical_binview_18636() -> None:

@pytest.mark.parametrize("size", [0, 1, 2, 13])
def test_ipc_chunked_roundtrip(size: int) -> None:
a = pl.Series("a", [{ 'x': 1 }] * size, pl.Struct({ 'x': pl.Int8 })).to_frame()
a = pl.Series("a", [{"x": 1}] * size, pl.Struct({"x": pl.Int8})).to_frame()

c = pl.concat([a] * 2, how='vertical')
c = pl.concat([a] * 2, how="vertical")

f = io.BytesIO()
c.write_ipc(f)
Expand All @@ -384,7 +384,7 @@ def test_zfs_ipc_roundtrip(size: int) -> None:
def test_zfs_ipc_chunked_roundtrip(size: int) -> None:
a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame()

c = pl.concat([a] * 2, how='vertical')
c = pl.concat([a] * 2, how="vertical")

f = io.BytesIO()
c.write_ipc(f)
Expand All @@ -394,13 +394,15 @@ def test_zfs_ipc_chunked_roundtrip(size: int) -> None:


@pytest.mark.parametrize("size", [0, 1, 2, 13])
@pytest.mark.parametrize("value", [{}, { 'x': 1 }])
@pytest.mark.parametrize("value", [{}, {"x": 1}])
@pytest.mark.write_disk
def test_memmap_ipc_chunked_structs(size: int, value: Dict[str, int], tmp_path: Path) -> None:
def test_memmap_ipc_chunked_structs(
size: int, value: dict[str, int], tmp_path: Path
) -> None:
a = pl.Series("a", [value] * size, pl.Struct).to_frame()

c = pl.concat([a] * 2, how='vertical')
c = pl.concat([a] * 2, how="vertical")

f = tmp_path / 'f.ipc'
f = tmp_path / "f.ipc"
c.write_ipc(f)
assert_frame_equal(c, pl.read_ipc(f))
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def test_json_infer_3_dtypes() -> None:
assert out.dtypes[0] == pl.List(pl.String)


#NOTE: This doesn't work for 0, but that is normal
# NOTE: This doesn't work for 0, but that is normal
@pytest.mark.parametrize("size", [1, 2, 13])
def test_zfs_json_roundtrip(size: int) -> None:
a = pl.Series("a", [{}] * size, pl.Struct([])).to_frame()
Expand Down

0 comments on commit 3beca20

Please sign in to comment.