Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/cle…
Browse files Browse the repository at this point in the history
…an-api
  • Loading branch information
ishaansehgal99 committed Feb 20, 2024
2 parents f6ec3d3 + 559efed commit b1638c3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
17 changes: 8 additions & 9 deletions presets/inference/llama2-chat/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def master_inference(input_string, max_gen_len, temperature, top_p):

def shutdown_server():
"""Shut down the server."""
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)

# Default values for the generator
gen_params = {
Expand Down Expand Up @@ -221,10 +221,10 @@ def get_metrics():
return {"error": str(e)}

def start_worker_server():
print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
print(f"Worker {dist.get_rank()} HTTP health server started at port 5000\n")
uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)

def worker_listen_tasks():
def worker_listen_tasks():
while True:
worker_num = dist.get_rank()
print(f"Worker {worker_num} ready to recieve next command")
Expand All @@ -247,15 +247,15 @@ def worker_listen_tasks():
print(f"Error in generation: {str(e)}")
elif command == "shutdown":
print(f"Worker {worker_num} shutting down")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except torch.distributed.DistBackendError as e:
print("torch.distributed.DistBackendError", e)
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except Exception as e:
print(f"Error in Worker Listen Task", e)
if 'Socket Timeout' in str(e):
print("A socket timeout occurred.")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()

if __name__ == "__main__":
# Fetch the LOCAL_RANK environment variable to determine the rank of this process
Expand All @@ -276,8 +276,7 @@ def worker_listen_tasks():
# Uncomment to enable worker logs
# sys.stdout = sys.__stdout__

os.setpgrp()
try:
try:
# If the current process is the locally ranked 0 (i.e., the primary process)
# on its node, then it starts a worker server that exposes a health check endpoint.
if local_rank == 0:
Expand All @@ -294,4 +293,4 @@ def worker_listen_tasks():
worker_listen_tasks()
finally:
# Additional fail-safe (to ensure no lingering processes)
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)
15 changes: 7 additions & 8 deletions presets/inference/llama2-completion/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def master_inference(prompts, max_gen_len, temperature, top_p):

def shutdown_server():
"""Shut down the server."""
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)

# Default values for the generator
gen_params = {
Expand Down Expand Up @@ -210,10 +210,10 @@ def get_metrics():
return {"error": str(e)}

def start_worker_server():
print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
print(f"Worker {dist.get_rank()} HTTP health server started at port 5000\n")
uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)

def worker_listen_tasks():
def worker_listen_tasks():
while True:
worker_num = dist.get_rank()
print(f"Worker {worker_num} ready to recieve next command")
Expand All @@ -236,15 +236,15 @@ def worker_listen_tasks():
print(f"Error in generation: {str(e)}")
elif command == "shutdown":
print(f"Worker {worker_num} shutting down")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except torch.distributed.DistBackendError as e:
print("torch.distributed.DistBackendError", e)
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()
except Exception as e:
print(f"Error in Worker Listen Task", e)
if 'Socket Timeout' in str(e):
print("A socket timeout occurred.")
os.killpg(os.getpgrp(), signal.SIGTERM)
sys.exit()

if __name__ == "__main__":
# Fetch the LOCAL_RANK environment variable to determine the rank of this process
Expand All @@ -265,7 +265,6 @@ def worker_listen_tasks():
# Uncomment to enable worker logs
# sys.stdout = sys.__stdout__

os.setpgrp()
try:
# If the current process is the locally ranked 0 (i.e., the primary process)
# on its node, then it starts a worker server that exposes a health check endpoint.
Expand All @@ -283,4 +282,4 @@ def worker_listen_tasks():
worker_listen_tasks()
finally:
# Additional fail-safe (to ensure no lingering processes)
os.killpg(os.getpgrp(), signal.SIGTERM)
os.kill(os.getpid(), signal.SIGTERM)

0 comments on commit b1638c3

Please sign in to comment.