From 559efedd24fca28e351f057656ac569f3bf414ed Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Mon, 19 Feb 2024 21:38:15 -0800
Subject: [PATCH] fix: Eliminate Unnecessary Process Group Creation in Worker
 Initialization (#244)

Fix for using the new Python base image for the llama models.

Creating and destroying a process group for each worker proved unnecessary;
in testing, exit conditions and a SIGTERM signal shut down each worker
reliably.
---
 presets/inference/llama2-chat/inference-api.py | 17 ++++++++---------
 .../llama2-completion/inference-api.py         | 15 +++++++--------
 presets/models/supported_models.yaml           | 12 ++++++------
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/presets/inference/llama2-chat/inference-api.py b/presets/inference/llama2-chat/inference-api.py
index b41018416..11776bf3d 100644
--- a/presets/inference/llama2-chat/inference-api.py
+++ b/presets/inference/llama2-chat/inference-api.py
@@ -83,7 +83,7 @@ def master_inference(input_string, max_gen_len, temperature, top_p):
 
 def shutdown_server():
     """Shut down the server."""
-    os.killpg(os.getpgrp(), signal.SIGTERM)
+    os.kill(os.getpid(), signal.SIGTERM)
 
 # Default values for the generator
 gen_params = {
@@ -221,10 +221,10 @@ def get_metrics():
         return {"error": str(e)}
 
 def start_worker_server():
-    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
+    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000\n")
     uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)
 
-def worker_listen_tasks(): 
+def worker_listen_tasks():
     while True:
         worker_num = dist.get_rank()
         print(f"Worker {worker_num} ready to recieve next command")
@@ -247,15 +247,15 @@ def worker_listen_tasks():
                     print(f"Error in generation: {str(e)}")
             elif command == "shutdown":
                 print(f"Worker {worker_num} shutting down")
-                os.killpg(os.getpgrp(), signal.SIGTERM)
+                sys.exit()
         except torch.distributed.DistBackendError as e:
             print("torch.distributed.DistBackendError", e)
-            os.killpg(os.getpgrp(), signal.SIGTERM)
+            sys.exit()
         except Exception as e:
             print(f"Error in Worker Listen Task", e)
             if 'Socket Timeout' in str(e):
                 print("A socket timeout occurred.")
-                os.killpg(os.getpgrp(), signal.SIGTERM)
+                sys.exit()
 
 if __name__ == "__main__":
     # Fetch the LOCAL_RANK environment variable to determine the rank of this process
@@ -276,8 +276,7 @@ def worker_listen_tasks():
     # Uncomment to enable worker logs
     # sys.stdout = sys.__stdout__
 
-    os.setpgrp()
-    try: 
+    try:
         # If the current process is the locally ranked 0 (i.e., the primary process)
         # on its node, then it starts a worker server that exposes a health check endpoint.
         if local_rank == 0:
@@ -294,4 +293,4 @@ def worker_listen_tasks():
             worker_listen_tasks()
     finally:
         # Additional fail-safe (to ensure no lingering processes)
-        os.killpg(os.getpgrp(), signal.SIGTERM)
\ No newline at end of file
+        os.kill(os.getpid(), signal.SIGTERM)
\ No newline at end of file
diff --git a/presets/inference/llama2-completion/inference-api.py b/presets/inference/llama2-completion/inference-api.py
index ffd8b9d3f..cf500146a 100644
--- a/presets/inference/llama2-completion/inference-api.py
+++ b/presets/inference/llama2-completion/inference-api.py
@@ -83,7 +83,7 @@ def master_inference(prompts, max_gen_len, temperature, top_p):
 
 def shutdown_server():
     """Shut down the server."""
-    os.killpg(os.getpgrp(), signal.SIGTERM)
+    os.kill(os.getpid(), signal.SIGTERM)
 
 # Default values for the generator
 gen_params = {
@@ -210,10 +210,10 @@ def get_metrics():
         return {"error": str(e)}
 
 def start_worker_server():
-    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000")
+    print(f"Worker {dist.get_rank()} HTTP health server started at port 5000\n")
     uvicorn.run(app=app_worker, host='0.0.0.0', port=5000)
 
-def worker_listen_tasks(): 
+def worker_listen_tasks():
     while True:
         worker_num = dist.get_rank()
         print(f"Worker {worker_num} ready to recieve next command")
@@ -236,15 +236,15 @@ def worker_listen_tasks():
                     print(f"Error in generation: {str(e)}")
             elif command == "shutdown":
                 print(f"Worker {worker_num} shutting down")
-                os.killpg(os.getpgrp(), signal.SIGTERM)
+                sys.exit()
         except torch.distributed.DistBackendError as e:
             print("torch.distributed.DistBackendError", e)
-            os.killpg(os.getpgrp(), signal.SIGTERM)
+            sys.exit()
         except Exception as e:
             print(f"Error in Worker Listen Task", e)
             if 'Socket Timeout' in str(e):
                 print("A socket timeout occurred.")
-                os.killpg(os.getpgrp(), signal.SIGTERM)
+                sys.exit()
 
 if __name__ == "__main__":
     # Fetch the LOCAL_RANK environment variable to determine the rank of this process
@@ -265,7 +265,6 @@ def worker_listen_tasks():
     # Uncomment to enable worker logs
     # sys.stdout = sys.__stdout__
 
-    os.setpgrp()
     try:
         # If the current process is the locally ranked 0 (i.e., the primary process)
         # on its node, then it starts a worker server that exposes a health check endpoint.
@@ -283,4 +282,4 @@ def worker_listen_tasks():
             worker_listen_tasks()
     finally:
         # Additional fail-safe (to ensure no lingering processes)
-        os.killpg(os.getpgrp(), signal.SIGTERM)
\ No newline at end of file
+        os.kill(os.getpid(), signal.SIGTERM)
\ No newline at end of file
diff --git a/presets/models/supported_models.yaml b/presets/models/supported_models.yaml
index d5e4fea1f..ee5dffddd 100644
--- a/presets/models/supported_models.yaml
+++ b/presets/models/supported_models.yaml
@@ -3,27 +3,27 @@ models:
   - name: llama-2-7b
     type: llama2-completion
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-7b-chat
     type: llama2-chat
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-13b
     type: llama2-completion
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-13b-chat
     type: llama2-chat
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-70b
     type: llama2-completion
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
   - name: llama-2-70b-chat
     type: llama2-chat
     runtime: llama-2
-    tag: 0.0.1
+    tag: 0.0.2
 
  # Falcon
  - name: falcon-7b
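
For reviewers skimming the hunks, here is a minimal, self-contained Python sketch (illustrative only, not part of the patch; the function names are made up) contrasting the old process-group shutdown with the per-process shutdown this change adopts:

```python
import os
import signal
import sys

def shutdown_before():
    # Old behavior: each worker called os.setpgrp() at startup, becoming the
    # leader of its own process group, then delivered SIGTERM to that entire
    # group just to stop itself (and any children it had spawned).
    os.killpg(os.getpgrp(), signal.SIGTERM)

def shutdown_after():
    # New behavior on a "shutdown" command or a DistBackendError: the worker
    # raises SystemExit, unwinding the stack so the try/finally in __main__
    # still runs.
    sys.exit()

def failsafe_after():
    # New fail-safe in the __main__ finally block: SIGTERM is sent to this
    # PID only, so no process-group bookkeeping is needed.
    os.kill(os.getpid(), signal.SIGTERM)
```

As the commit message notes, exit conditions plus a per-process SIGTERM proved reliable in testing, which is why the per-worker process groups could be dropped entirely.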