diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
index d481783d..067b2c62 100644
--- a/mii/batching/ragged_batching.py
+++ b/mii/batching/ragged_batching.py
@@ -9,6 +9,7 @@
 import random
 import threading
 import time
+from dataclasses import dataclass
 from collections import deque, defaultdict
 from functools import cached_property
 from typing import Dict, Tuple, List, Any, Union, DefaultDict
@@ -33,6 +34,7 @@
 from mii.config import GenerateParamsConfig
 from mii.constants import GenerationFinishReason, ZMQ_RECV_TIMEOUT
 from mii.logging import logger
+from mii.modeling.tokenizers import MIITokenizerWrapper


 class RaggedBatchBase:
@@ -199,12 +201,14 @@ def _process_logits(
     def _generate_output(self, r: Request) -> bool:
         outputs = []
         if r.stream:
-            outputs.append((r.uid,
-                            [r.next_token],
-                            r.prompt_length,
-                            r.num_generated_tokens,
-                            GenerationFinishReason.NONE,
-                            r.stream))
+            outputs.append((
+                r.uid,
+                [r.next_token],
+                r.prompt_length,
+                r.num_generated_tokens,
+                GenerationFinishReason.NONE,
+                r.stream,
+            ))
         if r.finish_reason != GenerationFinishReason.NONE:
             if r.stream or not r.generated_tokens:
                 output_tokens = []
@@ -214,12 +218,14 @@ def _generate_output(self, r: Request) -> bool:
             if r.return_full_text:
                 # Avoid returning bos token, refactor this later
                 output_tokens = torch.cat((r.prompt_tokens[1:], output_tokens))
-            outputs.append((r.uid,
-                            output_tokens,
-                            r.prompt_length,
-                            r.num_generated_tokens,
-                            r.finish_reason,
-                            r.stream))
+            outputs.append((
+                r.uid,
+                output_tokens,
+                r.prompt_length,
+                r.num_generated_tokens,
+                r.finish_reason,
+                r.stream,
+            ))
         for output in outputs:
             self.result_queues[r.tid].put_nowait(output)
@@ -452,40 +458,47 @@ def flush(self, uids: List[int]) -> None:

 @dataclass
 class StreamState:
-    prev_tok: str
-    tids: List[int]
+    prev_token_size: int
+    token_ids: List[int]


-class ReadableStream():
-    def __init__(self, tokenizer):
+class ReadableStream:
+    def __init__(self, tokenizer: MIITokenizerWrapper) -> None:
         self.tokenizer = tokenizer
-        self.stream_state = {}
+        self.stream_state: Dict[int, StreamState] = {}

-    def init_state(self, thread_id):
+    def init_state(self, thread_id: int) -> StreamState:
         if thread_id not in self.stream_state:
-            self.stream_state[thread_id] = StreamState(prev_tok=None, tids=[])
+            self.stream_state[thread_id] = StreamState(token_ids=[], prev_token_size=0)
         return self.stream_state[thread_id]

-    def flush_state(self, thread_id):
+    def flush_state(self, thread_id: int) -> None:
         if thread_id in self.stream_state:
             del self.stream_state[thread_id]

-    def decode(self, thread_id, token_ids):
-        sstate = self.init_state(thread_id)
-        final = []
-        for tid in token_ids:
-            sstate.tids.append(tid)
-            r = self.tokenizer.decode(sstate.tids)
-            if " " in r:
-                if sstate.prev_tok is not None:
-                    r = r.replace(sstate.prev_tok, "")
-                sstate.tids = [sstate.tids[-1]]
-            elif len(sstate.tids) > 1:
-                sstate.tids.pop(0)
-                r = r.replace(sstate.prev_tok, "")
-            sstate.prev_tok = self.tokenizer.decode(tid)
-            final.append(r)
-        return "".join(final)
+    def decode(self, thread_id: int, token_ids: List[int]) -> str:
+        state = self.init_state(thread_id)
+        output = []
+
+        for token_id in token_ids:
+            state.token_ids.append(token_id)
+            decoded = self.tokenizer.decode(state.token_ids)
+
+            # We don't have enough token_ids in the buffer and
+            # tokenizer returned unicode 'U+FFFD REPLACEMENT CHARACTER'
+            if "\ufffd" in decoded:
+                continue
+
+            if state.prev_token_size > 0:
+                prev_token = state.token_ids[:state.prev_token_size]
+                state.token_ids = state.token_ids[state.prev_token_size:]
+                decoded = decoded.replace(self.tokenizer.decode(prev_token), "", 1)
+
+            output.append(decoded)
+            state.prev_token_size = len(state.token_ids)
+
+        return "".join(output)


 class MIIPipeline(RaggedBatchBase):
@@ -652,10 +665,12 @@ def get_response(self) -> Tuple[int, Response]:
         else:
             generated_text = self.tokenizer.decode(generated_token_ids)

-        response = self.make_response(generated_text=generated_text,
-                                      prompt_length=prompt_length,
-                                      generated_length=generated_length,
-                                      finish_reason=finish_reason)
+        response = self.make_response(
+            generated_text=generated_text,
+            prompt_length=prompt_length,
+            generated_length=generated_length,
+            finish_reason=finish_reason,
+        )
         return uid, response

     def start(self) -> None:
diff --git a/tests/test_deployment.py b/tests/test_deployment.py
index 9f321210..03a18c87 100644
--- a/tests/test_deployment.py
+++ b/tests/test_deployment.py
@@ -28,6 +28,19 @@ def callback(response):
     assert outputs, "output is empty"


+def test_streaming_consistency(deployment, query):
+    expected_output = deployment(query, do_sample=False)
+    streaming_parts = []
+
+    def callback(response):
+        streaming_parts.append(response[0].generated_text)
+
+    deployment(query, do_sample=False, streaming_fn=callback)
+    streaming_output = "".join(streaming_parts)
+
+    assert streaming_output == expected_output[0].generated_text, "outputs w and w/o streaming are not equal"
+
+
 def test_multi_prompt(deployment, query):
     outputs = deployment([query] * 4)
     for r in outputs:
diff --git a/tests/test_ragged_batching.py b/tests/test_ragged_batching.py
new file mode 100644
index 00000000..583b92f2
--- /dev/null
+++ b/tests/test_ragged_batching.py
@@ -0,0 +1,51 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+import pytest
+
+from mii.batching.ragged_batching import ReadableStream
+from mii.config import ModelConfig
+from mii.modeling.tokenizers import load_tokenizer
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "tiiuae/falcon-7b",
+        "NousResearch/Llama-2-7b-hf",
+        "mistralai/Mistral-7B-v0.1",
+        "cloudyu/Mixtral_11Bx2_MoE_19B",
+        "facebook/opt-125m",
+    ],
+    ids=["falcon",
+         "llama",
+         "mistral",
+         "mixtral",
+         "opt"],
+)
+@pytest.mark.parametrize(
+    "query",
+    [
+        "It’s a region that includes Washington, Oregon, and Idaho.",
+        "# Heading\n\ntitle redundant spaces, #id — an anchor",
+        "例如",
+    ],
+    ids=[
+        "apostrophe",
+        "markdown",
+        "chinese",
+    ])
+def test_readable_stream(model_config, query):
+    tokenizer = load_tokenizer(ModelConfig(**model_config))
+    thread_id = 42
+
+    token_ids = tokenizer.encode(query)
+    expected = tokenizer.decode(token_ids)
+    decoded = []
+
+    stream = ReadableStream(tokenizer)
+    for token_id in token_ids:
+        decoded.append(stream.decode(thread_id, [token_id]))
+
+    assert "".join(decoded) == expected
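Not part of the patch: a minimal usage sketch of the rewritten ReadableStream, assuming the patched mii package is installed. It drives decode() one token at a time, the way the streaming path does. A plain Hugging Face tokenizer (facebook/opt-125m, one of the models already exercised in test_ragged_batching.py) stands in for the MIITokenizerWrapper returned by load_tokenizer; only encode()/decode() are relied on here, so either object should behave the same in this sketch.

from transformers import AutoTokenizer

from mii.batching.ragged_batching import ReadableStream

# Any tokenizer exposing encode()/decode() works for this sketch; the real
# server passes the MIITokenizerWrapper built by mii.modeling.tokenizers.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
stream = ReadableStream(tokenizer)
thread_id = 0  # one StreamState buffer is kept per thread_id

pieces = []
for token_id in tokenizer.encode("It’s a region that includes Washington, Oregon, and Idaho."):
    # decode() buffers token ids until they no longer decode to U+FFFD, then
    # returns only the text that was not already emitted on a previous call.
    pieces.append(stream.decode(thread_id, [token_id]))

stream.flush_state(thread_id)  # drop the per-thread buffer once the request finishes
print("".join(pieces))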