Initial approach to fixing parallelism
skrydal committed Jan 28, 2025
1 parent 0f538d8 commit 908f4f0
Showing 2 changed files with 69 additions and 28 deletions.
@@ -1,8 +1,9 @@
import concurrent.futures
import logging
import multiprocessing as mp
from dataclasses import dataclass, field
from functools import partial
from math import ceil
from queue import Empty
from typing import Callable, Dict, Iterable, List, Optional, Union

from datahub_classify.helper_classes import ColumnInfo, Metadata
@@ -170,36 +171,75 @@ def update_field_terms(
if term:
field_terms[col_info.metadata.name] = term

@staticmethod
def _worker_process(task_queue, result_queue, classifier):
logger.debug("Starting process to handle classification")
while True:
try:
columns_batch = task_queue.get(timeout=1)
if columns_batch is None:
logger.debug("From the task queue retrieved empty batch - finishing process execution")
break
logger.debug(f"Processing batch of columns: {columns_batch}")
result = classifier.classify(columns_batch)
result_queue.put(result)
except Empty:
continue
except Exception as e:
result_queue.put(e)
break
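
The loop above is a sentinel-terminated queue consumer: each worker polls the task queue, treats None as the signal to stop, and sends either results or a raised exception back through the result queue. A minimal self-contained sketch of the same pattern, with a doubling function standing in for classifier.classify (the names below are illustrative, not the real Classifier API):

import multiprocessing as mp
from queue import Empty

def worker(task_queue, result_queue):
    while True:
        try:
            batch = task_queue.get(timeout=1)
        except Empty:
            continue  # queue momentarily empty; poll again
        if batch is None:  # sentinel: no more work for this worker
            break
        result_queue.put([x * 2 for x in batch])  # stand-in for classify()

if __name__ == "__main__":
    tasks, results = mp.Queue(), mp.Queue()
    p = mp.Process(target=worker, args=(tasks, results), daemon=True)
    p.start()
    tasks.put([1, 2, 3])
    tasks.put(None)
    print(results.get())  # [2, 4, 6]
    p.join()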

def async_classify(
self, classifier: Classifier, columns: List[ColumnInfo]
) -> Iterable[ColumnInfo]:
num_columns = len(columns)
BATCH_SIZE = 5 # Number of columns passed to the classify API at a time

task_queue = mp.Queue()
result_queue = mp.Queue()
batches_count = ceil(num_columns / BATCH_SIZE)
# if batches_count > self.config.classification.max_workers:
# return []
logger.debug(
f"Will Classify {num_columns} column(s) with {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}."
f"Will Classify {num_columns} column(s) with max {self.config.classification.max_workers} worker(s) with batch size {BATCH_SIZE}. Count of batches: {batches_count}"
)
# mp.set_start_method('fork')
processes = []
for _ in range(min(self.config.classification.max_workers, batches_count)):
p = mp.Process(
target=ClassificationHandler._worker_process,
args=(task_queue, result_queue, classifier),
daemon=True,
)
p.start()
processes.append(p)

with concurrent.futures.ProcessPoolExecutor(
max_workers=self.config.classification.max_workers,
) as executor:
column_info_proposal_futures = [
executor.submit(
classifier.classify,
columns[
(i * BATCH_SIZE) : min(i * BATCH_SIZE + BATCH_SIZE, num_columns)
],
)
for i in range(ceil(num_columns / BATCH_SIZE))
]
logger.debug(f"Started {len(processes)} processes to classify columns")
for i in range(batches_count):
batch = columns[(i * BATCH_SIZE):min(i * BATCH_SIZE + BATCH_SIZE, num_columns)]
task_queue.put(batch)

return [
column_with_proposal
for proposal_future in concurrent.futures.as_completed(
column_info_proposal_futures
)
for column_with_proposal in proposal_future.result()
]
logger.debug("Loaded all column batches to the tasks queue")
for _ in processes:
task_queue.put(None)

logger.debug("Loaded termination signals to the task queue")

results = []
for i in range(batches_count):
batch_result = result_queue.get()
logger.debug(f"Received batch results {i}/{batches_count}")
if isinstance(batch_result, Exception):
raise batch_result
results.extend(batch_result)

logger.debug(f"Received {len(results)} results, joining processes")

for p in processes:
p.join()

logger.debug("Processes joined, returning results")

return results
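
A portability note on the commented-out mp.set_start_method('fork') above: set_start_method mutates a process-global setting and can only be called once, so if the fork start method is needed (for example, when the classifier is not picklable under the spawn default on macOS and Windows), a multiprocessing context scopes the choice to just these queues and processes. A hedged, self-contained sketch of that approach (not part of this commit):

import multiprocessing as mp

def worker(task_queue, result_queue):
    result_queue.put(task_queue.get())

if __name__ == "__main__":
    ctx = mp.get_context("fork")  # POSIX only; "spawn" is the portable default
    tasks, results = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=worker, args=(tasks, results), daemon=True)
    p.start()
    tasks.put("hello")
    print(results.get())  # "hello"
    p.join()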

def populate_terms_in_schema_metadata(
self,
13 changes: 7 additions & 6 deletions metadata-ingestion/src/datahub/utilities/partition_executor.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import atexit
import collections
import functools
import logging
@@ -220,10 +219,10 @@ def _shutdown_executors() -> None:
# This entire shutdown hook is largely a backstop mechanism to protect against
# improper usage of the BatchPartitionExecutor. In proper usage that uses
# a context manager or calls shutdown() explicitly, this will be a no-op.
if hasattr(threading, "_register_atexit"):
threading._register_atexit(_shutdown_executors)
else:
atexit.register(_shutdown_executors)
# if hasattr(threading, "_register_atexit"):
# threading._register_atexit(_shutdown_executors)
# else:
# atexit.register(_shutdown_executors)
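
For context, the disabled backstop above follows a common shutdown-hook pattern: prefer the private threading._register_atexit hook where it exists (CPython 3.9+), whose callbacks run during interpreter shutdown before non-daemon threads are joined, and fall back to the public atexit.register otherwise. A generic sketch of the pattern, with _cleanup as a hypothetical stand-in for _shutdown_executors:

import atexit
import threading

def _cleanup():
    pass  # best-effort shutdown of any executors still alive

if hasattr(threading, "_register_atexit"):
    # CPython 3.9+: fires before non-daemon threads are joined at shutdown
    threading._register_atexit(_cleanup)
else:
    atexit.register(_cleanup)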


class BatchPartitionExecutor(Closeable):
@@ -491,7 +490,9 @@ def shutdown(self) -> None:

# By acquiring all the permits, we ensure that no more tasks will be scheduled
# and automatically wait until all existing tasks have completed.
for _ in range(self.max_pending):
# logger.debug(f"Trying to acquire all the tasks before shutdown, max_pending: {self.max_pending}")
for i in range(self.max_pending):
# logger.debug(f"Acquiring _pending_count no {i}")
self._pending_count.acquire()

# We must wait for the clearinghouse worker to exit before calling shutdown
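
The acquire loop above is a counting-semaphore drain: _pending_count starts with max_pending permits, every submitted task holds one permit until it completes, so acquiring all max_pending permits at once both blocks new submissions and waits out everything in flight. A minimal self-contained sketch of the idiom (simplified names, not the real BatchPartitionExecutor internals):

import threading

MAX_PENDING = 4
pending = threading.Semaphore(MAX_PENDING)

def submit(run_task):
    pending.acquire()  # blocks while MAX_PENDING tasks are in flight
    def _wrapper():
        try:
            run_task()
        finally:
            pending.release()  # return the permit on completion
    threading.Thread(target=_wrapper).start()

def shutdown():
    for _ in range(MAX_PENDING):
        pending.acquire()  # drain every permit: no new work, all work done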
