fix: Eliminate Unnecessary Process Group Creation in Worker Initialization (#244)

Fix for using the new Python base image for the llama models.

Creating and destroying a process group for each worker proved unnecessary; in testing, per-worker exit conditions and a SIGTERM signal sent to the individual process were reliable for shutting each worker down.
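
In effect, shutdown moves from group-wide signalling to per-process exits. A minimal sketch of the before/after pattern (names taken from the diff below; this is a simplification, not the full file):

import os
import signal
import sys

def shutdown_server_before():
    # Old approach: SIGTERM the whole process group, which required an
    # earlier os.setpgrp() call in __main__ and could take down every
    # process in the group.
    os.killpg(os.getpgrp(), signal.SIGTERM)

def shutdown_server_after():
    # New approach: SIGTERM only the current process.
    os.kill(os.getpid(), signal.SIGTERM)

def worker_shutdown_after():
    # Workers now leave their command loop with a plain exit instead of
    # signalling the group.
    sys.exit()
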
ishaansehgal99 authored Feb 20, 2024
1 parent b41e027 commit 559efed
Showing 3 changed files with 21 additions and 23 deletions.
17 changes: 8 additions & 9 deletions presets/inference/llama2-chat/inference-api.py

@@ -83,7 +83,7 @@ def master_inference(input_string, max_gen_len, temperature, top_p):
 
 def shutdown_server():
     """Shut down the server."""
-    os.killpg(os.getpgrp(), signal.SIGTERM)
+    os.kill(os.getpid(), signal.SIGTERM)
 
 # Default values for the generator
 gen_params = {
@@ -221,10 +221,10 @@ def get_metrics():
         return {"error": str(e)}
 
 def start_worker_server():
-    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
+    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000\n")
     uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)
 
-def worker_listen_tasks():
+def worker_listen_tasks():
     while True:
         worker_num = dist.get_rank()
         print(f"Worker {worker_num} ready to recieve next command")
@@ -247,15 +247,15 @@ def worker_listen_tasks():
                     print(f"Error in generation: {str(e)}")
             elif command == "shutdown":
                 print(f"Worker {worker_num} shutting down")
-                os.killpg(os.getpgrp(), signal.SIGTERM)
+                sys.exit()
         except torch.distributed.DistBackendError as e:
             print("torch.distributed.DistBackendError", e)
-            os.killpg(os.getpgrp(), signal.SIGTERM)
+            sys.exit()
         except Exception as e:
             print(f"Error in Worker Listen Task", e)
             if 'Socket Timeout' in str(e):
                 print("A socket timeout occurred.")
-            os.killpg(os.getpgrp(), signal.SIGTERM)
+            sys.exit()
 
 if __name__ == "__main__":
     # Fetch the LOCAL_RANK environment variable to determine the rank of this process
@@ -276,8 +276,7 @@ def worker_listen_tasks():
     # Uncomment to enable worker logs
     # sys.stdout = sys.__stdout__
 
-    os.setpgrp()
-    try:
+    try:
         # If the current process is the locally ranked 0 (i.e., the primary process)
         # on its node, then it starts a worker server that exposes a health check endpoint.
         if local_rank == 0:
@@ -294,4 +293,4 @@ def worker_listen_tasks():
             worker_listen_tasks()
     finally:
         # Additional fail-safe (to ensure no lingering processes)
-        os.killpg(os.getpgrp(), signal.SIGTERM)
+        os.kill(os.getpid(), signal.SIGTERM)
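
For context, a condensed sketch of the worker loop after this change (assuming rank 0 broadcasts a small command list over an already-initialized torch.distributed process group; run_inference is a hypothetical stand-in for the real model call):

import sys
import torch.distributed as dist

def worker_listen_tasks():
    while True:
        # Rank 0 broadcasts [command, payload, params]; workers block here.
        config = [None] * 3
        dist.broadcast_object_list(config, src=0)
        command = config[0]
        if command == "inference":
            run_inference(config[1], config[2])  # hypothetical helper
        elif command == "shutdown":
            print(f"Worker {dist.get_rank()} shutting down")
            sys.exit()  # exits just this worker; no process-group kill
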
15 changes: 7 additions & 8 deletions presets/inference/llama2-completion/inference-api.py

@@ -83,7 +83,7 @@ def master_inference(prompts, max_gen_len, temperature, top_p):
 
 def shutdown_server():
     """Shut down the server."""
-    os.killpg(os.getpgrp(), signal.SIGTERM)
+    os.kill(os.getpid(), signal.SIGTERM)
 
 # Default values for the generator
 gen_params = {
@@ -210,10 +210,10 @@ def get_metrics():
         return {"error": str(e)}
 
 def start_worker_server():
-    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
+    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000\n")
     uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)
 
-def worker_listen_tasks():
+def worker_listen_tasks():
     while True:
         worker_num = dist.get_rank()
         print(f"Worker {worker_num} ready to recieve next command")
@@ -236,15 +236,15 @@ def worker_listen_tasks():
                     print(f"Error in generation: {str(e)}")
             elif command == "shutdown":
                 print(f"Worker {worker_num} shutting down")
-                os.killpg(os.getpgrp(), signal.SIGTERM)
+                sys.exit()
         except torch.distributed.DistBackendError as e:
             print("torch.distributed.DistBackendError", e)
-            os.killpg(os.getpgrp(), signal.SIGTERM)
+            sys.exit()
         except Exception as e:
             print(f"Error in Worker Listen Task", e)
             if 'Socket Timeout' in str(e):
                 print("A socket timeout occurred.")
-            os.killpg(os.getpgrp(), signal.SIGTERM)
+            sys.exit()
 
 if __name__ == "__main__":
     # Fetch the LOCAL_RANK environment variable to determine the rank of this process
@@ -265,7 +265,6 @@ def worker_listen_tasks():
     # Uncomment to enable worker logs
     # sys.stdout = sys.__stdout__
 
-    os.setpgrp()
     try:
         # If the current process is the locally ranked 0 (i.e., the primary process)
         # on its node, then it starts a worker server that exposes a health check endpoint.
@@ -283,4 +282,4 @@ def worker_listen_tasks():
             worker_listen_tasks()
     finally:
         # Additional fail-safe (to ensure no lingering processes)
-        os.killpg(os.getpgrp(), signal.SIGTERM)
+        os.kill(os.getpid(), signal.SIGTERM)
12 changes: 6 additions & 6 deletions presets/models/supported_models.yaml

@@ -3,27 +3,27 @@ models:
   - name: llama-2-7b
     type: llama2-completion
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-7b-chat
     type: llama2-chat
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-13b
     type: llama2-completion
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-13b-chat
     type: llama2-chat
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-70b
     type: llama2-completion
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-70b-chat
     type: llama2-chat
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
 
   # Falcon
   - name: falcon-7b
