Skip to content

Commit

Permalink
fix: process group error
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaansehgal99 committed Feb 19, 2024
1 parent 1ecb339 commit 16b9a1d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
15 changes: 6 additions & 9 deletions presets/inference/llama2-chat/inference-api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def master_inference(input_string, max_gen_len, temperature, top_p):

def shutdown_server():
"""Shut down the server."""
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)

# Default values for the generator
gen_params = {
Expand Down Expand Up @@ -247,15 +247,15 @@ def worker_listen_tasks():
print(f"Error in generation: {str(e)}")
elif command == "shutdown":
print(f"Worker {worker_num} shutting down")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except torch.distributed.DistBackendError as e:
print("torch.distributed.DistBackendError", e)
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except Exception as e:
print(f"Error in Worker Listen Task", e)
if 'Socket Timeout' in str(e):
print("A socket timeout occurred.")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()

if __name__ == "__main__":
# Fetch the LOCAL_RANK environment variable to determine the rank of this process
Expand All @@ -276,10 +276,7 @@ def worker_listen_tasks():
# Uncomment to enable worker logs
# sys.stdout = sys.__stdout__

pid = os.getpid()
if pid != os.getpgid(pid):
os.setpgrp()
try:
try:
# If the current process is the locally ranked 0 (i.e., the primary process)
# on its node, then it starts a worker server that exposes a health check endpoint.
if local_rank == 0:
Expand All @@ -296,4 +293,4 @@ def worker_listen_tasks():
worker_listen_tasks()
finally:
# Additional fail-safe (to ensure no lingering processes)
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)
13 changes: 5 additions & 8 deletions presets/inference/llama2-completion/inference-api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def master_inference(prompts, max_gen_len, temperature, top_p):

def shutdown_server():
"""Shut down the server."""
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)

# Default values for the generator
gen_params = {
Expand Down Expand Up @@ -236,15 +236,15 @@ def worker_listen_tasks():
print(f"Error in generation: {str(e)}")
elif command == "shutdown":
print(f"Worker {worker_num} shutting down")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except torch.distributed.DistBackendError as e:
print("torch.distributed.DistBackendError", e)
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except Exception as e:
print(f"Error in Worker Listen Task", e)
if 'Socket Timeout' in str(e):
print("A socket timeout occurred.")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()

if __name__ == "__main__":
# Fetch the LOCAL_RANK environment variable to determine the rank of this process
Expand All @@ -265,9 +265,6 @@ def worker_listen_tasks():
# Uncomment to enable worker logs
# sys.stdout = sys.__stdout__

pid = os.getpid()
if pid != os.getpgid(pid):
os.setpgrp()
try:
# If the current process is the locally ranked 0 (i.e., the primary process)
# on its node, then it starts a worker server that exposes a health check endpoint.
Expand All @@ -285,4 +282,4 @@ def worker_listen_tasks():
worker_listen_tasks()
finally:
# Additional fail-safe (to ensure no lingering processes)
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)

0 comments on commit 16b9a1d

Please sign in to comment.