Commit

fixes
percevalw committed Feb 10, 2024
1 parent e80ac8c commit 37983fe
Showing 2 changed files with 95 additions and 53 deletions.
11 changes: 6 additions & 5 deletions edspdf/lazy_collection.py
@@ -104,7 +104,7 @@ def process_start_method(self):
def set_processing(
self,
batch_size: int = INFER,
batch_unit: Literal["docs", "pages", "lines"] = INFER,
batch_unit: Literal["doc", "page", "content_boxe"] = INFER,
chunk_size: int = INFER,
num_cpu_workers: int = INFER,
num_gpu_workers: int = INFER,
@@ -122,9 +122,10 @@ def set_processing(
batch_size: int
Number of documents to process at a time in a GPU worker (or in the
main process if no workers are used).
batch_unit: Literal["docs", "pages", "lines"]
The unit of the batch size. Can be "docs" or "words". If "words", the
batch size is total number of words in the documents.
batch_unit: Literal["doc", "page", "content_box"]
The unit of the batch size. Can be "doc", "page" or "content_box". If
"content_box", the batch size is the total number of content boxes in
the documents.
chunk_size: int
Number of documents to build before splitting into batches. Only used
with "simple" and "multiprocessing" backends. This is also the number of
@@ -211,7 +212,7 @@ def map_pipeline(self, model: Pipeline) -> "LazyCollection":
for name, pipe, kwargs in self.pipeline:
new_steps.append((name, pipe, kwargs))

new_steps.append(("_ensure_doc", model.ensure_doc, {}))
new_steps.append((None, model.ensure_doc, {}))

for name, pipe in model.pipeline:
if name not in model._disabled:
137 changes: 89 additions & 48 deletions edspdf/processing/multiprocessing.py
@@ -183,7 +183,11 @@ def dump(*args, **kwargs):
old = dill.Pickler.dispatch.get(AlignDevicesHook)
dill.Pickler.dispatch[AlignDevicesHook] = save_align_devices_hook
dill.settings["recurse"] = True
return torch_save(*args, pickle_module=dill, **kwargs)
return torch_save(
*args,
pickle_module=dill,
**kwargs,
)
finally:
dill.settings["recurse"] = False
if AlignDevicesHook is not None:
@@ -196,12 +200,22 @@ def load(*args, map_location=None, **kwargs):
MAP_LOCATION = map_location
if torch.__version__ >= "2.1" and isinstance(args[0], str):
kwargs["mmap"] = True
result = torch_load(
*args,
pickle_module=dill,
map_location=map_location,
**kwargs,
)
# with open(args[0], "rb") as f:
# result = dill.load(f, **kwargs)
try:
if torch.__version__ < "2.0.0":
pickle = torch_load.__globals__["pickle"]
torch_load.__globals__["pickle"] = dill
result = torch_load(
*args,
pickle_module=dill,
map_location=map_location,
**kwargs,
)
finally:
import pickle

torch_load.__globals__["pickle"] = pickle
MAP_LOCATION = None
return result
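
The dump/load pair above routes torch's serialization through dill so that objects the standard pickler rejects (closures, locally patched classes such as AlignDevicesHook) survive the round trip. Stripped of the edspdf-specific globals and of the version-dependent __globals__ patch, the underlying pattern is roughly this sketch, using only public torch and dill APIs:

import dill
import torch

def dill_save(obj, path):
    # dill serializes closures and locally defined classes that pickle cannot
    torch.save(obj, path, pickle_module=dill)

def dill_load(path, map_location="cpu"):
    return torch.load(path, pickle_module=dill, map_location=map_location)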

@@ -312,21 +326,21 @@ def _run(self):
def read_tasks():
nonlocal next_batch_id, expect_new_tasks, had_error

try:
if lc.batch_unit == "lines":
if lc.batch_unit == "content_box":

def formula(batch):
return sum(len(doc.content_boxes) for doc in batch)
def formula(batch):
return sum(len(doc.content_boxes) for doc in batch)

elif lc.batch_unit == "pages":
elif lc.batch_unit == "page":

def formula(batch):
return sum(len(doc.pages) for doc in batch)
def formula(batch):
return sum(len(doc.pages) for doc in batch)

else:
formula = len
else:
formula = len

while expect_new_tasks or len(active_batches) > 0:
while expect_new_tasks or len(active_batches) > 0:
try:
stage, task = self.exchanger.get_cpu_task(
idx=self.cpu_idx,
)
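
The branch above picks how a batch is measured under the new unit names; the same selection can be written as a standalone helper. A sketch using the PDFDoc attributes referenced in the worker (content_boxes and pages):

def batch_measure(batch_unit):
    if batch_unit == "content_box":
        return lambda batch: sum(len(doc.content_boxes) for doc in batch)
    if batch_unit == "page":
        return lambda batch: sum(len(doc.pages) for doc in batch)
    return len  # default "doc": count the documents themselves
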
@@ -350,16 +364,16 @@ def formula(batch):
batchify(lc.reader.read_worker(fragments), lc.chunk_size)
)
for chunk_idx, docs in enumerate(chunks):
print("CHUNK", chunk_idx, "LEN", len(docs), flush=True)
# If we sort by size, we must first create the documents
# to have features against which we will sort
if do_preprocess_chunk:
for pipe, kwargs in stages[0]["cpu_components"]:
if hasattr(pipe, "batch_process"):
docs = pipe.batch_process(docs)
else:
docs: List[PDFDoc] = [
pipe(doc, **kwargs) for doc in docs
]
for pipe, kwargs in preprocess_pipes:
if hasattr(pipe, "batch_process"):
docs = pipe.batch_process(docs)
else:
docs: List[PDFDoc] = [
pipe(doc, **kwargs) for doc in docs
]

batches = [
batch
@@ -373,6 +387,7 @@ def formula(batch):
for batch_idx, batch in enumerate(batches):
assert len(batch) > 0
batch_id = next_batch_id
print("BATCH", batch_id, "LEN", len(batch), flush=True)

# We mark the task id only for the last batch of a task
# since the purpose of storing the task id is to know
@@ -393,23 +408,38 @@ def formula(batch):
yield stage, (None, batch_id, None)
else:
yield stage, task
except BaseException as e:
had_error = True
import traceback
except BaseException as e:
had_error = True
import traceback

print(traceback.format_exc(), flush=True)
self.exchanger.put_results((e, 0, self.cpu_idx, None))
print(traceback.format_exc(), flush=True)
self.exchanger.put_results((e, 0, self.cpu_idx, None))

lc: LazyCollection = load(self.lazy_collection_path, map_location=self.device)
do_preprocess_chunk = lc.batch_unit != "docs"
preprocess_pipes = []
is_before_split = True
split_into_batches_after = None
if lc.batch_unit != "docs":
split_into_batches_after = next(
(p[0] for p in lc.pipeline if p[0] is not None), None
)
is_before_split = split_into_batches_after is not None

print("split_into_batches_after", split_into_batches_after)

stages: List[Stage] = [{"cpu_components": [], "gpu_component": None}]
for name, pipe, *rest in lc.pipeline:
if name in self.gpu_pipe_names:
is_before_split = False
stages[-1]["gpu_component"] = pipe
stages.append({"cpu_components": [], "gpu_component": None})
else:
stages[-1]["cpu_components"].append((pipe, *rest))
if is_before_split:
preprocess_pipes.append((pipe, *rest))
else:
stages[-1]["cpu_components"].append((pipe, *rest))
if name is split_into_batches_after:
is_before_split = False

# Start at cpu_idx to avoid having all workers sending their
# first batch (0 % num_device, cf below) to the same gpu
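
The loop above now splits the pipeline in two: components before split_into_batches_after (the first named pipe) go into preprocess_pipes and run on whole chunks, so page and box counts exist before batching, while the remaining components are grouped into stages around each GPU pipe. A toy, standalone rerun of that logic with hypothetical pipe names (strings stand in for the real pipe objects):

pipeline = [("extractor", "EXTRACTOR"), ("embedding", "EMBEDDING"), ("classifier", "CLASSIFIER")]
gpu_pipe_names = {"embedding"}
split_into_batches_after = next((name for name, _ in pipeline if name is not None), None)

preprocess_pipes = []
is_before_split = split_into_batches_after is not None
stages = [{"cpu_components": [], "gpu_component": None}]
for name, pipe in pipeline:
    if name in gpu_pipe_names:
        is_before_split = False
        stages[-1]["gpu_component"] = pipe
        stages.append({"cpu_components": [], "gpu_component": None})
    else:
        (preprocess_pipes if is_before_split else stages[-1]["cpu_components"]).append(pipe)
        if name == split_into_batches_after:
            is_before_split = False

# preprocess_pipes == ["EXTRACTOR"]
# stages == [{"cpu_components": [], "gpu_component": "EMBEDDING"},
#            {"cpu_components": ["CLASSIFIER"], "gpu_component": None}]
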
@@ -430,12 +460,11 @@ def formula(batch):
gpu_pipe = stages[stage - 1]["gpu_component"]
docs = gpu_pipe.postprocess(docs, result) # type: ignore

if not do_preprocess_chunk or stage > 0:
for pipe, kwargs in stages[stage]["cpu_components"]:
if hasattr(pipe, "batch_process"):
docs = pipe.batch_process(docs)
else:
docs = [pipe(doc, **kwargs) for doc in docs]
for pipe, kwargs in stages[stage]["cpu_components"]:
if hasattr(pipe, "batch_process"):
docs = pipe.batch_process(docs)
else:
docs = [pipe(doc, **kwargs) for doc in docs]

gpu_pipe: "TrainablePipe" = stages[stage]["gpu_component"]
if gpu_pipe is not None:
@@ -535,13 +564,13 @@ def _run(self):
self.exchanger.outputs_queue.put(None)
with torch.no_grad():
while True:
stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
if task is None:
break
if had_error:
continue # pragma: no cover

try:
stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
if task is None:
break
if had_error:
continue # pragma: no cover

cpu_idx, batch_id, batch = task
pipe = stage_components[stage]
pipe.enable_cache(batch_id)
Expand All @@ -562,11 +591,12 @@ def _run(self):
pipe.disable_cache(batch_id)
del batch, task
except BaseException as e:
had_error = True
import traceback
if not had_error:
had_error = True
import traceback

print(traceback.format_exc(), flush=True)
self.exchanger.put_results((e, 0, None, None))
print(traceback.format_exc(), flush=True)
self.exchanger.put_results((e, 0, None, None))
task = batch = res = None # noqa
# We need to drain the queues of CPUWorker fed inputs (pre-moved to GPU)
# to ensure no tensor allocated on producer processes (CPUWorker via
@@ -745,6 +775,8 @@ def execute_multiprocessing_backend(

revert_pickler = replace_pickler()

print("FP", fp.name, flush=True)

for gpu_idx in range(num_gpu_workers):
gpu_workers.append(
mp.Process(
@@ -892,5 +924,14 @@ def process():
logging.error(f"Killing cpu worker {i}")
worker.kill()

for queue_group in (
*exchanger.cpu_inputs_queues,
*exchanger.gpu_inputs_queues,
[exchanger.outputs_queue],
):
for queue in queue_group:
if hasattr(queue, "cancel_join_thread"):
queue.cancel_join_thread()

gen = process()
return lc.writer.write_main(gen) if lc.writer is not None else flatten(gen)
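
The queue cleanup added before this return relies on a standard-library call; a small, self-contained illustration of the same pattern (hypothetical helper name, not an edspdf API):

import multiprocessing as mp

def release_queues(*queue_groups):
    # cancel_join_thread() tells a queue's feeder thread not to block process
    # exit while flushing buffered items; SimpleQueue has no such thread,
    # hence the hasattr guard used in the commit above.
    for group in queue_groups:
        for q in group:
            if hasattr(q, "cancel_join_thread"):
                q.cancel_join_thread()

# release_queues([mp.Queue(), mp.Queue()], [mp.SimpleQueue()])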
