Skip to content

Commit

Permalink
address review comment (1)
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 15, 2025
1 parent 4437108 commit d7671a7
Showing 1 changed file with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,17 @@ def get_tuple_type(
"""
Compute the `ts.TupleType` corresponding to the tuple structure of input data expressions.
"""
return ts.TupleType(
types=[
get_tuple_type(d)
if isinstance(d, tuple)
else d.get_field_type()
if isinstance(d, IteratorExpr)
else d.gt_dtype
for d in data
]
)
data_types: list[ts.DataType] = []
for dataitem in data:
if isinstance(dataitem, tuple):
data_types.append(get_tuple_type(dataitem))
elif isinstance(dataitem, IteratorExpr):
data_types.append(dataitem.get_field_type())
elif isinstance(dataitem, MemletExpr):
data_types.append(dataitem.gt_dtype)
else:
data_types.append(dataitem.gt_dtype)
return ts.TupleType(data_types)


@dataclasses.dataclass(frozen=True)
Expand Down

0 comments on commit d7671a7

Please sign in to comment.