From dae4c41614eb9295451e22eccb6b02427dadab7b Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Wed, 9 Nov 2022 10:31:08 +0100 Subject: [PATCH] Fix `None` support in structural types --- tests/integration/test_types_integration.py | 7 ++- trino/client.py | 49 ++++++++++++++------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 101c347b..cdb2534c 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -156,6 +156,8 @@ def test_array(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \ .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \ + .add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \ + .add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \ .execute() @@ -163,13 +165,16 @@ def test_map(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) \ .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None}) \ + .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[2.4, null])", python={'a': Decimal("2.4"), 'b': None}) \ + .add_field(sql="MAP(ARRAY[2.4, 4.8], ARRAY[CAST(4.9E-324 AS DOUBLE), null])", + python={Decimal("2.4"): 5e-324, Decimal("4.8"): None}) \ .execute() def test_row(trino_connection): SqlTest(trino_connection) \ .add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None) \ - .add_field(sql="CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))", python=(1, 2.0)) \ + .add_field(sql="CAST(ROW(1, 2e0, null) AS ROW(x BIGINT, y DOUBLE, z DOUBLE))", python=(1, 2.0, None)) \ .execute() diff --git a/trino/client.py b/trino/client.py index 079f7f87..b1c85e63 100644 --- a/trino/client.py +++ b/trino/client.py @@ -840,7 +840,7 @@ def decorated(*args, **kwargs): class ValueMapper(abc.ABC, Generic[T]): @abc.abstractmethod - def map(self, value: Any) -> T: + def map(self, value: Any) -> Optional[T]: pass @@ -850,12 +850,16 @@ def map(self, value) -> Optional[Any]: class DecimalValueMapper(ValueMapper[Decimal]): - def map(self, value) -> Decimal: + def map(self, value) -> Optional[Decimal]: + if value is None: + return None return Decimal(value) class DoubleValueMapper(ValueMapper[float]): - def map(self, value) -> float: + def map(self, value) -> Optional[float]: + if value is None: + return None if value == 'Infinity': return float("inf") if value == '-Infinity': @@ -886,14 +890,18 @@ def __init__(self, column): self.pattern = pattern self.time_size = 9 + ms_size - ms_to_trim - def map(self, value) -> time: + def map(self, value) -> Optional[time]: + if value is None: + return None return datetime.strptime(value[:self.time_size], self.pattern).time() class TimeWithTimeZoneValueMapper(TimeValueMapper): PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$' - def map(self, value) -> time: + def map(self, value) -> Optional[time]: + if value is None: + return None matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value) assert matches is not None assert len(matches.groups()) == 4 @@ -905,7 +913,9 @@ def map(self, value) -> time: class DateValueMapper(ValueMapper[date]): - def map(self, value) -> date: + def map(self, value) -> Optional[date]: + if value is None: + return None return datetime.strptime(value, '%Y-%m-%d').date() @@ -920,12 +930,16 @@ def __init__(self, column): self.dt_size = datetime_default_size + ms_size - ms_to_trim self.dt_tz_offset = datetime_default_size + ms_size - def map(self, value) -> datetime: + def map(self, value) -> Optional[datetime]: + if value is None: + return None return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern) class TimestampWithTimeZoneValueMapper(TimestampValueMapper): - def map(self, value) -> datetime: + def map(self, value) -> Optional[datetime]: + if value is None: + return None dt, tz = value.rsplit(' ', 1) if tz.startswith('+') or tz.startswith('-'): return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z') @@ -933,11 +947,13 @@ def map(self, value) -> datetime: return datetime.strptime(date_str, self.pattern).replace(tzinfo=pytz.timezone(tz)) -class ArrayValueMapper(ValueMapper[List[Any]]): +class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): def __init__(self, mapper: ValueMapper[Any]): self.mapper = mapper - def map(self, values: List[Any]) -> List[Any]: + def map(self, values: List[Any]) -> Optional[List[Any]]: + if values is None: + return None return [self.mapper.map(value) for value in values] @@ -945,16 +961,20 @@ class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): def __init__(self, mappers: List[ValueMapper[Any]]): self.mappers = mappers - def map(self, values: List[Any]) -> Tuple[Any, ...]: + def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: + if values is None: + return None return tuple(self.mappers[index].map(value) for index, value in enumerate(values)) -class MapValueMapper(ValueMapper[Dict[Any, Any]]): +class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): self.key_mapper = key_mapper self.value_mapper = value_mapper - def map(self, values: Any) -> Dict[Any, Any]: + def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]: + if values is None: + return None return { self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items() } @@ -1032,9 +1052,6 @@ def _map_row(self, row): return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)] def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]: - if value is None: - return None - try: return value_mapper.map(value) except ValueError as e: