Skip to content

Commit

Permalink
chore: allow for some duplication to make await cleaner/more obvious
Browse files Browse the repository at this point in the history
Signed-off-by: robhowley <rhowley@seatgeek.com>
  • Loading branch information
robhowley committed May 31, 2024
1 parent c4ba283 commit 8c164c3
Showing 1 changed file with 89 additions and 79 deletions.
168 changes: 89 additions & 79 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,59 +208,6 @@ def online_write_batch(
)
self._write_batch_non_duplicates(table_instance, data, progress, config)

def _read_batches(
    self, online_config, entity_ids, table_name, batch_get_item
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Read feature rows for ``entity_ids`` in batches of ``online_config.batch_size``.

    Args:
        online_config: Object exposing ``batch_size`` and ``consistent_reads``.
        entity_ids: Entity ids to look up; the result follows this order.
        table_name: Key used both in the request payload and to pick this
            table's rows out of ``response["Responses"]``.
        batch_get_item: Callable invoked as ``batch_get_item(RequestItems=...)``
            whose return value is a mapping containing ``"Responses"``.

    Returns:
        One ``(event_timestamp, {feature_name: ValueProto})`` tuple per entity
        id, or ``(None, None)`` for ids with no row in the response.
    """
    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

    batch_size = online_config.batch_size
    entity_ids_iter = iter(entity_ids)
    while True:
        batch = list(itertools.islice(entity_ids_iter, batch_size))
        # No more ids left to read
        if not batch:
            break

        request_items = {
            table_name: {
                "Keys": [{"entity_id": entity_id} for entity_id in batch],
                "ConsistentRead": online_config.consistent_reads,
            }
        }
        response = batch_get_item(RequestItems=request_items)
        # NOTE(review): only "Responses" is read; anything the service reports
        # under "UnprocessedKeys" is not retried and surfaces as (None, None).
        table_responses = response.get("Responses").get(table_name) or []

        # Match rows to the batch by entity_id instead of relying on the
        # order of the response: a sort-then-walk alignment indexes past the
        # end of the batch as soon as the two orders disagree.
        rows_by_id = {row["entity_id"]: row for row in table_responses}
        for entity_id in batch:
            row = rows_by_id.get(entity_id)
            if row is None:
                # Not every entity in a batch has a stored row
                result.append((None, None))
                continue
            features = {}
            for feature_name, value_bin in row["values"].items():
                val = ValueProto()
                val.ParseFromString(value_bin.value)
                features[feature_name] = val
            result.append((datetime.fromisoformat(row["event_ts"]), features))
    return result

def online_read(
self,
config: RepoConfig,
Expand All @@ -278,27 +225,36 @@ def online_read(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)

entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

return self._read_batches(
online_config,
entity_ids,
table_instance.name,
dynamodb_resource.batch_get_item,
)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = self._to_batch_get_payload(
online_config, table_instance.name, batch
)
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_instance.name, response, entity_ids, batch
)
result.extend(batch_result)
return result

async def online_read_async(
self,
Expand All @@ -324,21 +280,30 @@ async def online_read_async(
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
table_name = _get_table_name(online_config, config, table)

async with self._get_aiodynamodb_client(online_config.region) as client:
return self._read_batches(
online_config,
entity_ids,
_get_table_name(online_config, config, table),
lambda **kwargs: (await client(**kwargs)),
)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = self._to_batch_get_payload(
online_config, table_name, batch
)
response = await client.batch_get_item(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_name, response, entity_ids, batch
)
result.extend(batch_result)
return result

def _get_aioboto_session(self):
if self._aioboto_session is None:
Expand Down Expand Up @@ -403,6 +368,51 @@ def _write_batch_non_duplicates(
if progress:
progress(1)

def _process_batch_get_response(self, table_name, response, entity_ids, batch):
    """Convert one batch-get response into results aligned with ``batch``.

    Args:
        table_name: Key used to pick this table's rows out of
            ``response["Responses"]``.
        response: Mapping with a ``"Responses"`` entry that maps table names
            to lists of rows; each row is read via its ``"entity_id"``,
            ``"values"`` and ``"event_ts"`` keys.
        entity_ids: Full list of requested entity ids. Kept so existing
            callers keep working; alignment only needs ``batch``.
        batch: The entity ids that were requested in this call, in order.

    Returns:
        A list with exactly one entry per id in ``batch``: either
        ``(event_timestamp, {feature_name: ValueProto})`` or ``(None, None)``
        when the response holds no row for that id.
    """
    table_responses = response.get("Responses").get(table_name) or []

    # Match rows to the batch by entity_id instead of relying on the order
    # of the response: a sort-then-walk alignment indexes past the end of
    # the batch as soon as the two orders disagree.
    rows_by_id = {row["entity_id"]: row for row in table_responses}

    batch_result = []
    for entity_id in batch:
        row = rows_by_id.get(entity_id)
        if row is None:
            # Not every entity in a batch has a stored row
            batch_result.append((None, None))
            continue
        features = {}
        for feature_name, value_bin in row["values"].items():
            val = ValueProto()
            val.ParseFromString(value_bin.value)
            features[feature_name] = val
        batch_result.append((datetime.fromisoformat(row["event_ts"]), features))
    return batch_result

@staticmethod
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
return [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]

@staticmethod
def _to_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}


def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
return boto3.client(
Expand Down

0 comments on commit 8c164c3

Please sign in to comment.