Skip to content

Commit

Permalink
fix: avoid torch freeze by disabling torch implicit parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 15, 2024
1 parent 1969189 commit 84710b8
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions edspdf/processing/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def revert():
try: # pragma: no cover
import torch

if "TORCH_SET_NUM_THREADS" in os.environ:
torch.set_num_threads(int(os.environ["TORCH_SET_NUM_THREADS"]))

# Torch may still be imported as a namespace package, so we can access the
# torch.save and torch.load functions
torch_save = torch.save
Expand Down Expand Up @@ -677,6 +680,7 @@ def execute_multiprocessing_backend(
if lc.disable_implicit_parallelism:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TORCH_SET_NUM_THREADS"] = "1"

requires_gpu = (
num_gpu_workers is None
Expand Down Expand Up @@ -879,9 +883,15 @@ def process():
exchanger.cpu_inputs_queues[i][0].put(None)

while any(active_chunks):
print(
"active_chunks", active_chunks, "non_finalized", non_finalized
)
yield from get_and_process_output()

while len(non_finalized):
print(
"non_finalized", non_finalized, "active_chunks", active_chunks
)
yield from get_and_process_output()

finally:
Expand Down

0 comments on commit 84710b8

Please sign in to comment.