diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
index 21d76251..d481783d 100644
--- a/mii/batching/ragged_batching.py
+++ b/mii/batching/ragged_batching.py
@@ -199,13 +199,12 @@ def _process_logits(
     def _generate_output(self, r: Request) -> bool:
         outputs = []
         if r.stream:
-            outputs.append((
-                r.uid,
-                [r.next_token],
-                r.prompt_length,
-                r.num_generated_tokens,
-                GenerationFinishReason.NONE,
-            ))
+            outputs.append((r.uid,
+                            [r.next_token],
+                            r.prompt_length,
+                            r.num_generated_tokens,
+                            GenerationFinishReason.NONE,
+                            r.stream))
         if r.finish_reason != GenerationFinishReason.NONE:
             if r.stream or not r.generated_tokens:
                 output_tokens = []
@@ -215,13 +214,12 @@ def _generate_output(self, r: Request) -> bool:
             if r.return_full_text:
                 # Avoid returning bos token, refactor this later
                 output_tokens = torch.cat((r.prompt_tokens[1:], output_tokens))
-            outputs.append((
-                r.uid,
-                output_tokens,
-                r.prompt_length,
-                r.num_generated_tokens,
-                r.finish_reason,
-            ))
+            outputs.append((r.uid,
+                            output_tokens,
+                            r.prompt_length,
+                            r.num_generated_tokens,
+                            r.finish_reason,
+                            r.stream))
         for output in outputs:
             self.result_queues[r.tid].put_nowait(output)
 
@@ -452,6 +450,44 @@ def flush(self, uids: List[int]) -> None:
             self.inference_engine.flush(uid)
 
 
+@dataclass
+class StreamState:
+    prev_tok: str
+    tids: List[int]
+
+
+class ReadableStream():
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.stream_state = {}
+
+    def init_state(self, thread_id):
+        if thread_id not in self.stream_state:
+            self.stream_state[thread_id] = StreamState(prev_tok=None, tids=[])
+        return self.stream_state[thread_id]
+
+    def flush_state(self, thread_id):
+        if thread_id in self.stream_state:
+            del self.stream_state[thread_id]
+
+    def decode(self, thread_id, token_ids):
+        sstate = self.init_state(thread_id)
+        final = []
+        for tid in token_ids:
+            sstate.tids.append(tid)
+            r = self.tokenizer.decode(sstate.tids)
+            if " " in r:
+                if sstate.prev_tok is not None:
+                    r = r.replace(sstate.prev_tok, "")
+                sstate.tids = [sstate.tids[-1]]
+            elif len(sstate.tids) > 1:
+                sstate.tids.pop(0)
+                r = r.replace(sstate.prev_tok, "")
+            sstate.prev_tok = self.tokenizer.decode(tid)
+            final.append(r)
+        return "".join(final)
+
+
 class MIIPipeline(RaggedBatchBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -548,6 +584,7 @@ def __init__(self, *args, **kwargs):
         self._is_shutdown = False
         self.UID_RANGE_LB = 1
         self.UID_RANGE_UB = 10000
+        self.readable_stream = ReadableStream(self.tokenizer)
 
     def __call__(self) -> None:
         # CUDA device gets reset, must set it again to avoid problems
@@ -605,14 +642,20 @@ def get_response(self) -> Tuple[int, Response]:
                                 generated_length=None,
                                 finish_reason=None)
         tid = threading.get_ident()
-        result = self.result_queues[tid].get()
-        uid = result[0]
-        generated_token_ids = result[1]
+        uid, generated_token_ids, prompt_length, generated_length, finish_reason, streaming = self.result_queues[tid].get()
+
         if len(generated_token_ids) == 0:
             generated_text = ""
+            self.readable_stream.flush_state(tid)
+        elif streaming:
+            generated_text = self.readable_stream.decode(tid, generated_token_ids)
         else:
             generated_text = self.tokenizer.decode(generated_token_ids)
-        response = self.make_response(generated_text, result[2], result[3], result[4])
+
+        response = self.make_response(generated_text=generated_text,
+                                      prompt_length=prompt_length,
+                                      generated_length=generated_length,
+                                      finish_reason=finish_reason)
         return uid, response
 
     def start(self) -> None:
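
Illustrative sketch (not part of the diff): the new ReadableStream keeps a small per-thread window of token ids so that streamed fragments retain their word boundaries, which decoding each token id in isolation would drop. The snippet below exercises the class on its own; ToyTokenizer is a hypothetical stand-in for a SentencePiece-style tokenizer, and the import assumes the patched mii.batching.ragged_batching module is installed.

    from mii.batching.ragged_batching import ReadableStream


    class ToyTokenizer:
        # SentencePiece-style pieces: "▁" marks the start of a new word.
        PIECES = {0: "▁Deep", 1: "Speed", 2: "▁MII", 3: "▁rocks"}

        def decode(self, token_ids):
            # ReadableStream.decode() calls this with both a single id and a list of ids.
            if isinstance(token_ids, int):
                token_ids = [token_ids]
            return "".join(self.PIECES[t] for t in token_ids).replace("▁", " ").lstrip()


    stream = ReadableStream(ToyTokenizer())
    thread_id = 1  # state is keyed per thread, mirroring threading.get_ident() in get_response()

    # Tokens arrive one at a time, as on the streaming path.
    chunks = [stream.decode(thread_id, [tok]) for tok in (0, 1, 2, 3)]
    print(chunks)            # ['Deep', 'Speed', ' MII', ' rocks']
    print("".join(chunks))   # DeepSpeed MII rocks  (word boundaries preserved)
    stream.flush_state(thread_id)  # drop the per-thread state once the request finishes

In get_response() above, the same pattern runs on the streaming branch for each request thread, and flush_state() is called once an empty token list signals the end of the stream.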