Skip to content

Commit

Permalink
Add support for empty
Browse files Browse the repository at this point in the history
  • Loading branch information
aandres committed Dec 29, 2022
1 parent 4146c2c commit f033a33
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 110 deletions.
6 changes: 6 additions & 0 deletions docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ pre-commit install
This library relies on property based testing.
Tests convert randomly generated data from protobuf to arrow and back, making sure the end result is the same as the input.

To run the tests faster (in parallel):
```shell
pytest -n auto tests
```

To get coverage:
```shell
coverage run --branch --include "*/protarrow/*" -m pytest tests
coverage report
Expand Down
9 changes: 9 additions & 0 deletions protarrow/arrow_to_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ def __iter__(self) -> Iterator[Any]:
else:
yield self.field_descriptor.message_type._concrete_class()

def prime(self) -> None:
    """Merge an empty sub-message into every parent whose mask entry is true.

    This needs to be called if the columns are empty.

    Walks ``self.parents`` and ``self.validity_mask`` in lockstep; for each
    entry that is both valid (non-null) and truthy, fetches the attribute
    named by ``self.field_descriptor.name`` from the parent and calls
    ``MergeFrom`` on it with a freshly built, empty message.

    NOTE(review): presumably this marks the nested field as set on the
    parent even when there are no child columns left to populate it --
    confirm against the protobuf ``MergeFrom`` semantics.
    """
    # One empty instance is enough: it is only read as a merge source.
    empty = self.field_descriptor.message_type._concrete_class()
    field_name = self.field_descriptor.name
    for parent, valid in zip(self.parents, self.validity_mask):
        # ``is_valid`` filters out null mask entries before ``as_py`` is read.
        if valid.is_valid and valid.as_py():
            getattr(parent, field_name).MergeFrom(empty)


@dataclasses.dataclass(frozen=True)
class RepeatedNestedIterable(collections.abc.Iterable):
Expand Down Expand Up @@ -374,6 +382,7 @@ def _extract_struct_field(
messages: Iterable[Message],
) -> None:
nested_list = OptionalNestedIterable(messages, field_descriptor, array.is_valid())
nested_list.prime()
_extract_array_messages(array, field_descriptor.message_type, nested_list)


Expand Down
15 changes: 12 additions & 3 deletions protarrow/cast_to_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,18 @@ def cast_struct_array(
)
arrays.append(array)
fields.append(field)
return pa.StructArray.from_arrays(
arrays=arrays, fields=fields, mask=struct_array.is_null()
)
if len(arrays) == 0:
# TODO: remove when this is fixed
# https://github.com/apache/arrow/issues/15109
return pa.StructArray.from_arrays(
arrays=[pa.nulls(len(struct_array))],
fields=[pa.field("null", pa.null())],
mask=struct_array.is_null(),
).cast(pa.struct([]))
else:
return pa.StructArray.from_arrays(
arrays=arrays, fields=fields, mask=struct_array.is_null()
)


def cast_table(
Expand Down
33 changes: 28 additions & 5 deletions protarrow/proto_to_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def __iter__(self) -> Iterator[Any]:
for child in parent:
yield child

def __len__(self) -> int:
    """Return the combined length of every truthy entry in ``self.parents``.

    Falsy entries (``None`` or empty containers) are skipped, so they
    contribute nothing to the total.
    """
    return sum(len(i) for i in self.parents if i)


@dataclasses.dataclass(frozen=True)
class NestedIterable(collections.abc.Iterable):
Expand All @@ -153,6 +156,9 @@ def __iter__(self) -> Iterator[Optional[Any]]:
else:
yield None

def __len__(self) -> int:
    """Return the number of entries in ``self.parents``."""
    return len(self.parents)


@dataclasses.dataclass(frozen=True)
class NestedMessageGetter:
Expand Down Expand Up @@ -186,6 +192,9 @@ def __iter__(self) -> Iterator[Any]:
for value in scalar_map.values():
yield value

def __len__(self) -> int:
    """Return the combined length of every truthy entry in ``self.scalar_map``.

    Falsy entries (``None`` or empty mappings) are skipped, so they
    contribute nothing to the total.
    """
    return sum(len(i) for i in self.scalar_map if i)


def is_map(field_descriptor: FieldDescriptor) -> bool:
return (
Expand Down Expand Up @@ -502,11 +511,25 @@ def _messages_to_array(
nullable=_proto_field_nullable(field_descriptor, config),
)
)
return pa.StructArray.from_arrays(
arrays=arrays,
fields=fields,
mask=pc.invert(pa.array(validity_mask, pa.bool_())) if validity_mask else None,
)
if len(arrays) == 0:
# TODO: remove when this is fixed
# https://github.com/apache/arrow/issues/15109
size = len(validity_mask) if validity_mask else len(messages)
return pa.StructArray.from_arrays(
arrays=[pa.nulls(size, pa.null())],
fields=[pa.field("null", pa.null())],
mask=pc.invert(pa.array(validity_mask, pa.bool_()))
if validity_mask
else None,
).cast(pa.struct([]))
else:
return pa.StructArray.from_arrays(
arrays=arrays,
fields=fields,
mask=pc.invert(pa.array(validity_mask, pa.bool_()))
if validity_mask
else None,
)


def messages_to_record_batch(
Expand Down
7 changes: 0 additions & 7 deletions protos/bench.proto
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,3 @@ message NestedExampleMessage {
map<string, ExampleMessage> example_message_string_map = 5;

}

// Message with a singular and a repeated google.protobuf.Empty field.
message Foo {
// Singular Empty-typed field.
google.protobuf.Empty empty_value = 29;
// NOTE(review): the map-of-Empty variant is commented out -- presumably not supported yet; confirm.
//map<string, google.protobuf.Empty> empty_string_map = 145;
// Repeated Empty-typed field.
repeated google.protobuf.Empty empty_values = 58;

}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pytest = ">=7.2.0"
pytest-benchmark = ">=4.0.0"
Jinja2 = ">=3.1.2"
inflection = ">=0.5.1"
pytest-xdist = ">=3.1.0"


[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
Expand Down
63 changes: 20 additions & 43 deletions tests/data/ExampleMessage.jsonl

Large diffs are not rendered by default.

63 changes: 20 additions & 43 deletions tests/data/NestedExampleMessage.jsonl

Large diffs are not rendered by default.

15 changes: 6 additions & 9 deletions tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
messages_to_record_batch,
messages_to_table,
)
from protarrow_protos.bench_pb2 import ExampleMessage, Foo, NestedExampleMessage
from protarrow_protos.bench_pb2 import ExampleMessage, NestedExampleMessage
from tests.random_generator import generate_messages, random_date, truncate_nanos

MESSAGES = [ExampleMessage, NestedExampleMessage]
Expand Down Expand Up @@ -500,22 +500,19 @@ def test_extractor_null_values(message_type: Type[Message], config: ProtarrowCon
assert messages == [message_type()] * len(table)


def test_empty_struct_array():
    """Smoke test: building a StructArray from no child arrays must not raise."""
    pa.StructArray.from_arrays([], fields=[], mask=None)


def test_empty():
    """Round-trip messages with a set and an unset ``Empty`` field.

    Converts the messages to a table and back, then checks the result
    matches the input.
    """
    source_messages = [
        ExampleMessage(empty_value=Empty()),  # Empty field explicitly set
        ExampleMessage(),  # Empty field left unset
    ]

    table = messages_to_table(source_messages, ExampleMessage, ProtarrowConfig())
    messages_back = table_to_messages(table, ExampleMessage)
    _check_messages_same(source_messages, messages_back)


def test_empty_struct_not_possible():
    """Pin pyarrow's behaviour when a struct array has no child arrays.

    Even though a two-entry mask is supplied, the resulting array has the
    empty struct type and a length of 0.
    """
    # See https://github.com/apache/arrow/issues/15109
    array = pa.StructArray.from_arrays(arrays=[], names=[], mask=pa.array([True, True]))
    assert array.type == pa.struct([])
    assert len(array) == 0
Expand Down

0 comments on commit f033a33

Please sign in to comment.