Skip to content

Commit

Permalink
Refactor value mappers to separate classes
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and hashhar committed Nov 10, 2022
1 parent a975e5d commit aca68a0
Showing 1 changed file with 150 additions and 103 deletions.
253 changes: 150 additions & 103 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@
>> query = TrinoQuery(request, sql)
>> rows = list(query.execute())
"""

import abc
import copy
import functools
import os
import random
import re
import threading
import urllib.parse
from datetime import datetime, timedelta, timezone
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from time import sleep
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import pytz
import requests
Expand All @@ -64,9 +64,7 @@

_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')

INF = float("inf")
NEGATIVE_INF = float("-inf")
NAN = float("nan")
T = TypeVar("T")


class ClientSession(object):
Expand Down Expand Up @@ -840,6 +838,128 @@ def decorated(*args, **kwargs):
return wrapper


class ValueMapper(abc.ABC, Generic[T]):
@abc.abstractmethod
def map(self, value: Any) -> T:
pass


class NoOpValueMapper(ValueMapper[Any]):
def map(self, value) -> Optional[Any]:
return value


class DecimalValueMapper(ValueMapper[Decimal]):
def map(self, value) -> Decimal:
return Decimal(value)


class DoubleValueMapper(ValueMapper[float]):
def map(self, value) -> float:
if value == 'Infinity':
return float("inf")
if value == '-Infinity':
return float("-inf")
if value == 'NaN':
return float("nan")
return float(value)


class TemporalValueMapper():
def _get_number_of_digits(self, column):
args = column['arguments']
if len(args) == 0:
return 3, 0
ms_size = args[0]['value']
if ms_size == 0:
return -1, 0
ms_to_trim = ms_size - min(ms_size, 6)
return ms_size, ms_to_trim


class TimeValueMapper(ValueMapper[time], TemporalValueMapper):
def __init__(self, column):
pattern = "%H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"
self.pattern = pattern
self.time_size = 9 + ms_size - ms_to_trim

def map(self, value) -> time:
return datetime.strptime(value[:self.time_size], self.pattern).time()


class TimeWithTimeZoneValueMapper(TimeValueMapper):
PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$'

def map(self, value) -> time:
matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value)
assert matches is not None
assert len(matches.groups()) == 4
if matches.group(2) == '-':
tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
else:
tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
return datetime.strptime(matches.group(1)[:self.time_size], self.pattern).time().replace(tzinfo=timezone(tz))


class DateValueMapper(ValueMapper[date]):
def map(self, value) -> date:
return datetime.strptime(value, '%Y-%m-%d').date()


class TimestampValueMapper(ValueMapper[datetime], TemporalValueMapper):
def __init__(self, column):
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
pattern = "%Y-%m-%d %H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"
self.pattern = pattern
self.dt_size = datetime_default_size + ms_size - ms_to_trim
self.dt_tz_offset = datetime_default_size + ms_size

def map(self, value) -> datetime:
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern)


class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
def map(self, value) -> datetime:
dt, tz = value.rsplit(' ', 1)
if tz.startswith('+') or tz.startswith('-'):
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z')
date_str = dt[:self.dt_size] + dt[self.dt_tz_offset:]
return datetime.strptime(date_str, self.pattern).replace(tzinfo=pytz.timezone(tz))


class ArrayValueMapper(ValueMapper[List[Any]]):
def __init__(self, mapper: ValueMapper[Any]):
self.mapper = mapper

def map(self, values: List[Any]) -> List[Any]:
return [self.mapper.map(value) for value in values]


class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
def __init__(self, mappers: List[ValueMapper[Any]]):
self.mappers = mappers

def map(self, values: List[Any]) -> Tuple[Any, ...]:
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))


class MapValueMapper(ValueMapper[Dict[Any, Any]]):
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
self.key_mapper = key_mapper
self.value_mapper = value_mapper

def map(self, values: Any) -> Dict[Any, Any]:
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
}


class NoOpRowMapper:
"""
No-op RowMapper which does not perform any transformation
Expand All @@ -856,117 +976,44 @@ class RowMapperFactory:
lambda functions (one for each column) which will process a data value
and returns a RowMapper instance which will process rows of data
"""
no_op_row_mapper = NoOpRowMapper()
NO_OP_ROW_MAPPER = NoOpRowMapper()

def create(self, columns, experimental_python_types):
assert columns is not None

if experimental_python_types:
return RowMapper([self._col_func(column['typeSignature']) for column in columns])
return RowMapperFactory.no_op_row_mapper
return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns])
return RowMapperFactory.NO_OP_ROW_MAPPER

def _col_func(self, column):
def _create_value_mapper(self, column) -> ValueMapper:
col_type = column['rawType']

if col_type == 'array':
return self._array_map_func(column)
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
return ArrayValueMapper(value_mapper)
elif col_type == 'row':
return self._row_map_func(column)
mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']]
return RowValueMapper(mappers)
elif col_type == 'map':
return self._map_map_func(column)
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
return MapValueMapper(key_mapper, value_mapper)
elif col_type.startswith('decimal'):
return lambda val: Decimal(val)
return DecimalValueMapper()
elif col_type.startswith('double') or col_type.startswith('real'):
return self._double_map_func()
return DoubleValueMapper()
elif col_type.startswith('timestamp') and 'with time zone' in col_type:
return TimestampWithTimeZoneValueMapper(column)
elif col_type.startswith('timestamp'):
return self._timestamp_map_func(column, col_type)
return TimestampValueMapper(column)
elif col_type.startswith('time') and 'with time zone' in col_type:
return TimeWithTimeZoneValueMapper(column)
elif col_type.startswith('time'):
return self._time_map_func(column, col_type)
return TimeValueMapper(column)
elif col_type == 'date':
return lambda val: datetime.strptime(val, '%Y-%m-%d').date()
else:
return lambda val: val

def _array_map_func(self, column):
element_mapping_func = self._col_func(column['arguments'][0]['value'])
return lambda values: [element_mapping_func(value) for value in values]

def _row_map_func(self, column):
element_mapping_func = [self._col_func(arg['value']['typeSignature']) for arg in column['arguments']]
return lambda values: tuple(element_mapping_func[idx](value) for idx, value in enumerate(values))

def _map_map_func(self, column):
key_mapping_func = self._col_func(column['arguments'][0]['value'])
value_mapping_func = self._col_func(column['arguments'][1]['value'])
return lambda values: {key_mapping_func(key): value_mapping_func(value) for key, value in values.items()}

def _double_map_func(self):
return lambda val: INF if val == 'Infinity' \
else NEGATIVE_INF if val == '-Infinity' \
else NAN if val == 'NaN' \
else float(val)

def _timestamp_map_func(self, column, col_type):
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
pattern = "%Y-%m-%d %H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"

dt_size = datetime_default_size + ms_size - ms_to_trim
dt_tz_offset = datetime_default_size + ms_size
if 'with time zone' in col_type:

if ms_to_trim > 0:
return lambda val: \
[datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern + ' %z')
if tz.startswith('+') or tz.startswith('-')
else datetime.strptime(dt[:dt_size] + dt[dt_tz_offset:], pattern)
.replace(tzinfo=pytz.timezone(tz))
for dt, tz in [val.rsplit(' ', 1)]][0]
else:
return lambda val: [datetime.strptime(val, pattern + ' %z')
if tz.startswith('+') or tz.startswith('-')
else datetime.strptime(dt, pattern).replace(tzinfo=pytz.timezone(tz))
for dt, tz in [val.rsplit(' ', 1)]][0]

if ms_to_trim > 0:
return lambda val: datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern)
return DateValueMapper()
else:
return lambda val: datetime.strptime(val, pattern)

def _time_map_func(self, column, col_type):
pattern = "%H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"

time_size = 9 + ms_size - ms_to_trim

if 'with time zone' in col_type:
return lambda val: self._get_time_with_timezome(val, time_size, pattern)
else:
return lambda val: datetime.strptime(val[:time_size], pattern).time()

def _get_time_with_timezome(self, value, time_size, pattern):
matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
assert matches is not None
assert len(matches.groups()) == 4
if matches.group(2) == '-':
tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
else:
tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
return datetime.strptime(matches.group(1)[:time_size], pattern).time().replace(tzinfo=timezone(tz))

def _get_number_of_digits(self, column):
args = column['arguments']
if len(args) == 0:
return 3, 0
ms_size = column['arguments'][0]['value']
if ms_size == 0:
return -1, 0
ms_to_trim = ms_size - min(ms_size, 6)
return ms_size, ms_to_trim
return NoOpValueMapper()


class RowMapper:
Expand All @@ -982,14 +1029,14 @@ def map(self, rows):
return [self._map_row(row) for row in rows]

def _map_row(self, row):
return [self._map_value(value, self.columns[idx]) for idx, value in enumerate(row)]
return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)]

def _map_value(self, value, col_mapping_func):
def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
if value is None:
return None

try:
return col_mapping_func(value)
return value_mapper.map(value)
except ValueError as e:
error_str = f"Could not convert '{value}' into the associated python type"
raise trino.exceptions.TrinoDataError(error_str) from e

0 comments on commit aca68a0

Please sign in to comment.