From cedc2c24d96cd4d28e13271fa9c299d1855017c7 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 1 Feb 2024 09:59:04 -0800 Subject: [PATCH] Fix generate output order (#401) --- mii/grpc_related/modelresponse_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index e6da3097..0b6c99c3 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -65,12 +65,13 @@ def GeneratorReply(self, request, context): task_methods = self._get_task_methods("GeneratorReply") prompts, kwargs = task_methods.unpack_request_from_proto(request) - uids_running, uids_complete_order, responses = [], [], [] + uids_put_order, uids_running, uids_complete_order, responses = [], [], [], [] # Put requests for all prompts into the pipeline for p in prompts: request_kwargs = kwargs.copy() uid = self.inference_pipeline.put_request(p, request_kwargs) + uids_put_order.append(uid) uids_running.append(uid) # Get responses from the pipeline as they are ready, flush finished uids @@ -82,7 +83,7 @@ def GeneratorReply(self, request, context): uid = uids_running[0] responses.append(response) self.inference_pipeline.flush_uid(uid) - uids_complete_order.append(uids_running.index(uid)) + uids_complete_order.append(uids_put_order.index(uid)) uids_running.remove(uid) # Sort responses in the order of prompts