Skip to content

Commit

Permalink
Fix None support in structural types
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and hashhar committed Nov 10, 2022
1 parent aca68a0 commit dae4c41
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
7 changes: 6 additions & 1 deletion tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,25 @@ def test_array(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \
.add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \
.add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \
.add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \
.execute()


def test_map(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) \
.add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None}) \
.add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[2.4, null])", python={'a': Decimal("2.4"), 'b': None}) \
.add_field(sql="MAP(ARRAY[2.4, 4.8], ARRAY[CAST(4.9E-324 AS DOUBLE), null])",
python={Decimal("2.4"): 5e-324, Decimal("4.8"): None}) \
.execute()


def test_row(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None) \
.add_field(sql="CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))", python=(1, 2.0)) \
.add_field(sql="CAST(ROW(1, 2e0, null) AS ROW(x BIGINT, y DOUBLE, z DOUBLE))", python=(1, 2.0, None)) \
.execute()


Expand Down
49 changes: 33 additions & 16 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def decorated(*args, **kwargs):

class ValueMapper(abc.ABC, Generic[T]):
@abc.abstractmethod
def map(self, value: Any) -> T:
def map(self, value: Any) -> Optional[T]:
pass


Expand All @@ -850,12 +850,16 @@ def map(self, value) -> Optional[Any]:


class DecimalValueMapper(ValueMapper[Decimal]):
def map(self, value) -> Decimal:
def map(self, value) -> Optional[Decimal]:
if value is None:
return None
return Decimal(value)


class DoubleValueMapper(ValueMapper[float]):
def map(self, value) -> float:
def map(self, value) -> Optional[float]:
if value is None:
return None
if value == 'Infinity':
return float("inf")
if value == '-Infinity':
Expand Down Expand Up @@ -886,14 +890,18 @@ def __init__(self, column):
self.pattern = pattern
self.time_size = 9 + ms_size - ms_to_trim

def map(self, value) -> time:
def map(self, value) -> Optional[time]:
if value is None:
return None
return datetime.strptime(value[:self.time_size], self.pattern).time()


class TimeWithTimeZoneValueMapper(TimeValueMapper):
PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$'

def map(self, value) -> time:
def map(self, value) -> Optional[time]:
if value is None:
return None
matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value)
assert matches is not None
assert len(matches.groups()) == 4
Expand All @@ -905,7 +913,9 @@ def map(self, value) -> time:


class DateValueMapper(ValueMapper[date]):
def map(self, value) -> date:
def map(self, value) -> Optional[date]:
if value is None:
return None
return datetime.strptime(value, '%Y-%m-%d').date()


Expand All @@ -920,41 +930,51 @@ def __init__(self, column):
self.dt_size = datetime_default_size + ms_size - ms_to_trim
self.dt_tz_offset = datetime_default_size + ms_size

def map(self, value) -> datetime:
def map(self, value) -> Optional[datetime]:
if value is None:
return None
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern)


class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
def map(self, value) -> datetime:
def map(self, value) -> Optional[datetime]:
if value is None:
return None
dt, tz = value.rsplit(' ', 1)
if tz.startswith('+') or tz.startswith('-'):
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z')
date_str = dt[:self.dt_size] + dt[self.dt_tz_offset:]
return datetime.strptime(date_str, self.pattern).replace(tzinfo=pytz.timezone(tz))


class ArrayValueMapper(ValueMapper[List[Any]]):
class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
def __init__(self, mapper: ValueMapper[Any]):
self.mapper = mapper

def map(self, values: List[Any]) -> List[Any]:
def map(self, values: List[Any]) -> Optional[List[Any]]:
if values is None:
return None
return [self.mapper.map(value) for value in values]


class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
def __init__(self, mappers: List[ValueMapper[Any]]):
self.mappers = mappers

def map(self, values: List[Any]) -> Tuple[Any, ...]:
def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
if values is None:
return None
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))


class MapValueMapper(ValueMapper[Dict[Any, Any]]):
class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
self.key_mapper = key_mapper
self.value_mapper = value_mapper

def map(self, values: Any) -> Dict[Any, Any]:
def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
if values is None:
return None
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
}
Expand Down Expand Up @@ -1032,9 +1052,6 @@ def _map_row(self, row):
return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)]

def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
if value is None:
return None

try:
return value_mapper.map(value)
except ValueError as e:
Expand Down

0 comments on commit dae4c41

Please sign in to comment.