Commit

feat(ingest): key-partitioning for rest emitter (#9613)
hsheth2 authored Jan 11, 2024
1 parent 7a78824 commit f05056a
Showing 3 changed files with 236 additions and 52 deletions.
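At a high level, this change replaces the sink's ad-hoc BoundedExecutor with a new PartitionExecutor and keys each async emit on the record's entity URN, so writes for the same entity run one at a time and in submission order, while writes for different entities can still proceed in parallel. A minimal sketch of that pattern, assuming only the PartitionExecutor API added in this commit (the URNs and the emit function below are illustrative stand-ins, not part of the change):

from datahub.utilities.advanced_thread_executor import PartitionExecutor

def emit(urn: str, aspect: str) -> None:
    # Stand-in for the REST emitter; just shows which write runs when.
    print(f"writing {aspect} for {urn}")

executor = PartitionExecutor(max_workers=4, max_pending=100)

dataset = "urn:li:dataset:(urn:li:dataPlatform:hive,db.table,PROD)"
other = "urn:li:dataset:(urn:li:dataPlatform:hive,db.other,PROD)"

# Same partition key: these run sequentially, in the order they were submitted.
executor.submit(dataset, emit, dataset, "schemaMetadata")
executor.submit(dataset, emit, dataset, "datasetProperties")

# Different partition key: may run concurrently with the writes above.
executor.submit(other, emit, other, "status")

executor.shutdown()  # flushes pending work, then stops the underlying thread pool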
86 changes: 34 additions & 52 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
@@ -2,12 +2,10 @@
import contextlib
import functools
import logging
from concurrent.futures import ThreadPoolExecutor
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import auto
from threading import BoundedSemaphore
from typing import Tuple, Union
from typing import Optional, Union

from datahub.cli.cli_utils import set_env_variables_override_config
from datahub.configuration.common import (
@@ -25,6 +23,7 @@
MetadataChangeEvent,
MetadataChangeProposal,
)
from datahub.utilities.advanced_thread_executor import PartitionExecutor
from datahub.utilities.server_config_util import set_gms_config

logger = logging.getLogger(__name__)
@@ -40,7 +39,7 @@ class DatahubRestSinkConfig(DatahubClientConfig):

# These only apply in async mode.
max_threads: int = 15
max_pending_requests: int = 1000
max_pending_requests: int = 500


@dataclass
@@ -51,39 +50,26 @@ class DataHubRestSinkReport(SinkReport):
def compute_stats(self) -> None:
super().compute_stats()

def report_write_latency(self, delta: timedelta) -> None:
pass

def _get_urn(record_envelope: RecordEnvelope) -> Optional[str]:
metadata = record_envelope.record

class BoundedExecutor:
"""BoundedExecutor behaves as a ThreadPoolExecutor which will block on
calls to submit() once the limit given as "bound" work items are queued for
execution.
:param bound: Integer - the maximum number of items in the work queue
:param max_workers: Integer - the size of the thread pool
"""
if isinstance(metadata, MetadataChangeEvent):
return metadata.proposedSnapshot.urn
elif isinstance(metadata, (MetadataChangeProposalWrapper, MetadataChangeProposal)):
return metadata.entityUrn

def __init__(self, bound, max_workers):
self.executor = ThreadPoolExecutor(max_workers)
self.semaphore = BoundedSemaphore(bound + max_workers)
return None

"""See concurrent.futures.Executor#submit"""

def submit(self, fn, *args, **kwargs):
self.semaphore.acquire()
try:
future = self.executor.submit(fn, *args, **kwargs)
except Exception:
self.semaphore.release()
raise
else:
future.add_done_callback(lambda x: self.semaphore.release())
return future
def _get_partition_key(record_envelope: RecordEnvelope) -> str:
urn = _get_urn(record_envelope)
if urn:
return urn

"""See concurrent.futures.Executor#shutdown"""

def shutdown(self, wait=True):
self.executor.shutdown(wait)
# This shouldn't happen very often; as a fallback, generate a random UUID so that
# such records are effectively unpartitioned.
return str(uuid.uuid4())


class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
@@ -120,9 +106,9 @@ def __post_init__(self) -> None:
set_env_variables_override_config(self.config.server, self.config.token)
logger.debug("Setting gms config")
set_gms_config(gms_config)
self.executor = BoundedExecutor(
self.executor = PartitionExecutor(
max_workers=self.config.max_threads,
bound=self.config.max_pending_requests,
max_pending=self.config.max_pending_requests,
)

def handle_work_unit_start(self, workunit: WorkUnit) -> None:
@@ -147,9 +133,7 @@ def _write_done_callback(
elif future.done():
e = future.exception()
if not e:
start_time, end_time = future.result()
self.report.report_record_written(record_envelope)
self.report.report_write_latency(end_time - start_time)
write_callback.on_success(record_envelope, {})
elif isinstance(e, OperationalError):
# only OperationalErrors should be ignored
@@ -164,13 +148,9 @@ def _write_done_callback(
]

# Include information about the entity that failed.
record = record_envelope.record
if isinstance(record, MetadataChangeProposalWrapper):
entity_id = record.entityUrn
e.info["id"] = entity_id
elif isinstance(record, MetadataChangeEvent):
entity_id = record.proposedSnapshot.urn
e.info["id"] = entity_id
record_urn = _get_urn(record_envelope)
if record_urn:
e.info["urn"] = record_urn

if not self.treat_errors_as_warnings:
self.report.report_failure({"error": e.message, "info": e.info})
@@ -188,10 +168,9 @@ def _emit_wrapper(
MetadataChangeProposal,
MetadataChangeProposalWrapper,
],
) -> Tuple[datetime, datetime]:
start_time = datetime.now()
) -> None:
# TODO: Add timing metrics
self.emitter.emit(record)
return start_time, datetime.now()

def write_record_async(
self,
@@ -206,23 +185,26 @@ def write_record_async(
) -> None:
record = record_envelope.record
if self.config.mode == SyncOrAsync.ASYNC:
write_future = self.executor.submit(self._emit_wrapper, record)
write_future.add_done_callback(
functools.partial(
partition_key = _get_partition_key(record_envelope)
self.executor.submit(
partition_key,
self._emit_wrapper,
record,
done_callback=functools.partial(
self._write_done_callback, record_envelope, write_callback
)
),
)
self.report.pending_requests += 1
else:
# execute synchronously
try:
(start, end) = self._emit_wrapper(record)
self._emit_wrapper(record)
write_callback.on_success(record_envelope, success_metadata={})
except Exception as e:
write_callback.on_failure(record_envelope, e, failure_metadata={})

def close(self):
self.executor.shutdown(wait=True)
self.executor.shutdown()

def __repr__(self) -> str:
return self.emitter.__repr__()
132 changes: 132 additions & 0 deletions metadata-ingestion/src/datahub/utilities/advanced_thread_executor.py
@@ -0,0 +1,132 @@
import collections
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import BoundedSemaphore
from typing import Any, Callable, Deque, Dict, Optional, Tuple, TypeVar

from datahub.ingestion.api.closeable import Closeable

_R = TypeVar("_R")
_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL = 0.05


class PartitionExecutor(Closeable):
def __init__(self, max_workers: int, max_pending: int) -> None:
"""A thread pool executor with partitioning and a pending request bound.
It works similarly to a ThreadPoolExecutor, with the following changes:
- At most one request per partition key will be executing at a time.
- If the number of pending requests exceeds the threshold, the submit call
will block until the number of pending requests drops below the threshold.
Due to the interaction between max_workers and max_pending, it is possible
for execution to effectively be serialized when there's a large influx of
requests with the same key. This can be mitigated by setting a reasonably
large max_pending value.
Args:
max_workers: The maximum number of threads to use for executing requests.
max_pending: The maximum number of pending (i.e. not-yet-executing) requests to allow.
"""
self.max_workers = max_workers
self.max_pending = max_pending

self._executor = ThreadPoolExecutor(max_workers=max_workers)

# Each pending or executing request will acquire a permit from this semaphore.
self._semaphore = BoundedSemaphore(max_pending + max_workers)

# A key existing in this dict means that there is a submitted request for that key.
# Any entries in the key's value (i.e. the deque) are requests that are waiting
# to be submitted once the current request for that key completes.
self._pending_by_key: Dict[
str, Deque[Tuple[Callable, tuple, dict, Optional[Callable[[Future], None]]]]
] = {}

def submit(
self,
key: str,
fn: Callable[..., _R],
*args: Any,
# Ideally, we would've used ParamSpec to annotate this method. However,
# due to the limitations of PEP 612, we can't add a keyword argument here.
# See https://peps.python.org/pep-0612/#concatenating-keyword-parameters
# As such, we're using Any here, and won't validate the args to this method.
# We might be able to work around it by moving the done_callback arg to be before
# the *args, but that would mean making done_callback a required arg instead of
# optional as it is now.
done_callback: Optional[Callable[[Future], None]] = None,
**kwargs: Any,
) -> None:
"""See concurrent.futures.Executor#submit"""

self._semaphore.acquire()

if key in self._pending_by_key:
self._pending_by_key[key].append((fn, args, kwargs, done_callback))

else:
self._pending_by_key[key] = collections.deque()
self._submit_nowait(key, fn, args, kwargs, done_callback=done_callback)

def _submit_nowait(
self,
key: str,
fn: Callable[..., _R],
args: tuple,
kwargs: dict,
done_callback: Optional[Callable[[Future], None]],
) -> Future:
future = self._executor.submit(fn, *args, **kwargs)

def _system_done_callback(future: Future) -> None:
self._semaphore.release()

# If there is another pending request for this key, submit it now.
# The key must exist in the map.
if self._pending_by_key[key]:
fn, args, kwargs, user_done_callback = self._pending_by_key[
key
].popleft()
self._submit_nowait(key, fn, args, kwargs, user_done_callback)

else:
# If there are no pending requests for this key, mark the key
# as no longer in progress.
del self._pending_by_key[key]

if done_callback:
future.add_done_callback(done_callback)
future.add_done_callback(_system_done_callback)
return future

def flush(self) -> None:
"""Wait for all pending requests to complete."""

# Acquire all the semaphore permits so that no more requests can be submitted.
for _i in range(self.max_pending):
self._semaphore.acquire()

# Now, wait for all the pending requests to complete.
while len(self._pending_by_key) > 0:
# TODO: There should be a better way to wait for all executor threads to be idle.
# One option would be to just shutdown the existing executor and create a new one.
time.sleep(_PARTITION_EXECUTOR_FLUSH_SLEEP_INTERVAL)

# Now allow new requests to be submitted.
# TODO: With Python 3.9, release() can take a count argument.
for _i in range(self.max_pending):
self._semaphore.release()

def shutdown(self) -> None:
"""See concurrent.futures.Executor#shutdown. Behaves as if wait=True."""

self.flush()
assert len(self._pending_by_key) == 0

# Technically, the wait=True here is redundant, since all the threads should
# be idle now.
self._executor.shutdown(wait=True)

def close(self) -> None:
self.shutdown()
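For reference, a small usage sketch of the callback and backpressure behavior described above; the names here are illustrative and not part of the commit:

from concurrent.futures import Future

from datahub.utilities.advanced_thread_executor import PartitionExecutor

results = []

def record_result(future: Future) -> None:
    # done_callback receives the completed Future, like Future.add_done_callback.
    results.append(future.result())

executor = PartitionExecutor(max_workers=2, max_pending=4)
for i in range(10):
    # With 2 workers and 4 pending slots there are 6 semaphore permits, so submit()
    # blocks once 6 tasks are outstanding, applying backpressure to the producer.
    executor.submit(f"key{i % 3}", lambda x=i: x * x, done_callback=record_result)

executor.flush()  # blocks until every submitted task has completed
assert len(results) == 10
executor.shutdown()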
@@ -0,0 +1,70 @@
import time
from concurrent.futures import Future

from datahub.utilities.advanced_thread_executor import PartitionExecutor
from datahub.utilities.perf_timer import PerfTimer


def test_partitioned_executor():
executing_tasks = set()
done_tasks = set()

def task(key: str, id: str) -> None:
executing_tasks.add((key, id))
time.sleep(0.8)
done_tasks.add(id)
executing_tasks.remove((key, id))

with PartitionExecutor(max_workers=2, max_pending=10) as executor:
# Submit tasks with the same key. They should be executed sequentially.
executor.submit("key1", task, "key1", "task1")
executor.submit("key1", task, "key1", "task2")
executor.submit("key1", task, "key1", "task3")

# Submit a task with a different key. It should be executed in parallel.
executor.submit("key2", task, "key2", "task4")

saw_keys_in_parallel = False
while executing_tasks or not done_tasks:
keys_executing = [key for key, _ in executing_tasks]
assert list(sorted(keys_executing)) == list(
sorted(set(keys_executing))
), "partitioning not working"

if len(keys_executing) == 2:
saw_keys_in_parallel = True

time.sleep(0.1)

executor.flush()
assert saw_keys_in_parallel
assert not executing_tasks
assert done_tasks == {"task1", "task2", "task3", "task4"}


def test_partitioned_executor_bounding():
task_duration = 0.5
done_tasks = set()

def on_done(future: Future) -> None:
done_tasks.add(future.result())

def task(id: str) -> str:
time.sleep(task_duration)
return id

with PartitionExecutor(
max_workers=5, max_pending=10
) as executor, PerfTimer() as timer:
# The first 15 submits should be non-blocking.
for i in range(15):
executor.submit(f"key{i}", task, f"task{i}", done_callback=on_done)
assert timer.elapsed_seconds() < task_duration

# This submit should block.
executor.submit("key-blocking", task, "task-blocking", done_callback=on_done)
assert timer.elapsed_seconds() > task_duration

# Wait for everything to finish.
executor.flush()
assert len(done_tasks) == 16

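A note on the bounding test's arithmetic: with max_workers=5 and max_pending=10, the executor's semaphore holds 5 + 10 = 15 permits, so the first 15 submits return immediately; the 16th submit has to wait for a permit, which only frees up once one of the initial tasks finishes roughly task_duration later. That is what the two elapsed-time assertions check.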