Skip to content

Commit

Permalink
Fix handling of default values for null nested messages
Browse files Browse the repository at this point in the history
  • Loading branch information
aandres committed Jul 13, 2023
1 parent d092d3d commit 3a1d70a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
25 changes: 15 additions & 10 deletions protarrow/arrow_to_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def __iter__(self) -> Iterator[Callable[[pa.Scalar], None]]:

def __call__(self, scalar: pa.Scalar) -> None:
value = self.converter(scalar) if scalar.is_valid else None
if value is not None:
# `self.message` can be null for nested messages.
# We'd expect the scalar to be either null or the default value in this case
if self.message is not None and value is not None:
if self.nullable:
getattr(self.message, self.field_descriptor.name).value = value
else:
Expand All @@ -284,9 +286,10 @@ def __init__(
def __iter__(self) -> Iterator[Callable[[pa.Scalar], None]]:
assert self.attribute is None
for message, size in zip(self.messages, self.sizes):
self.attribute = getattr(message, self.field_descriptor.name)
for _ in range(size):
yield self
if message is not None:
self.attribute = getattr(message, self.field_descriptor.name)
for _ in range(size):
yield self
self.attribute = None

def __call__(self, scalar: pa.Scalar) -> None:
Expand All @@ -312,9 +315,10 @@ def __post_init__(self, key_arrow_type: pa.DataType):
def __iter__(self) -> Iterator[Callable[[pa.Scalar], Message]]:
assert self.attribute is None
for message, offset in zip(self.messages, self.sizes):
self.attribute = getattr(message, self.field_descriptor.name)
for _ in range(offset):
yield self
if message is not None:
self.attribute = getattr(message, self.field_descriptor.name)
for _ in range(offset):
yield self
self.attribute = None

def __call__(self, scalar: pa.Scalar) -> Message:
Expand Down Expand Up @@ -363,9 +367,10 @@ def __post_init__(self, key_arrow_type: pa.DataType, value_arrow_type: pa.DataTy
def __iter__(self) -> Iterator[Callable[[pa.Scalar, pa.Scalar], Message]]:
assert self.attribute is None
for message, size in zip(self.messages, self.sizes):
self.attribute = getattr(message, self.field_descriptor.name)
for _ in range(size):
yield self
if message is not None:
self.attribute = getattr(message, self.field_descriptor.name)
for _ in range(size):
yield self
self.attribute = None

def __call__(self, key: pa.Scalar, value: pa.Scalar):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,21 @@ def test_missing_enum_from_proto():
assert protarrow.messages_to_table(
[message], ExampleMessage, protarrow.ProtarrowConfig(enum_type=pa.binary())
)["example_enum_value"].to_pylist() == [b"UNKNOWN_EXAMPLE_ENUM"]


def test_nested_missing_values():
"""
Check for cases where nested messages are null.
When a nested message is null, the arrow arrays of their underlying fields
"""
source_messages = [
NestedExampleMessage(),
NestedExampleMessage(example_message=ExampleMessage(double_value=1.0)),
]
table = protarrow.messages_to_table(
source_messages,
NestedExampleMessage,
)
messages_back = protarrow.table_to_messages(table, NestedExampleMessage)
assert messages_back == source_messages

0 comments on commit 3a1d70a

Please sign in to comment.