Readable token streaming support (#397)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
greshilov and jeffra authored Feb 1, 2024
1 parent 690645e commit e917dae
Showing 3 changed files with 126 additions and 4 deletions.
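
This change makes streamed responses arrive as readable text increments rather than raw fragments that can split multi-byte characters. A minimal client-side sketch of the behaviour exercised by the new streaming test below; obtaining the client via mii.serve and the chosen model are assumptions for illustration only:

    import mii

    # Assumption: mii.serve returns a callable client, used the same way as the
    # `deployment` fixture in the tests below.
    client = mii.serve("mistralai/Mistral-7B-v0.1")

    parts = []

    def callback(response):
        # With streaming enabled, each callback carries only the newly decoded text.
        parts.append(response[0].generated_text)

    client("Seattle is", do_sample=False, streaming_fn=callback)
    print("".join(parts))  # equals the non-streaming generated_text

    client.terminate_server()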
66 changes: 62 additions & 4 deletions mii/batching/ragged_batching.py
@@ -9,6 +9,7 @@
import random
import threading
import time
from dataclasses import dataclass
from collections import deque, defaultdict
from functools import cached_property
from typing import Dict, Tuple, List, Any, Union, DefaultDict
@@ -33,6 +34,7 @@
from mii.config import GenerateParamsConfig
from mii.constants import GenerationFinishReason, ZMQ_RECV_TIMEOUT
from mii.logging import logger
from mii.modeling.tokenizers import MIITokenizerWrapper


class RaggedBatchBase:
@@ -205,6 +207,7 @@ def _generate_output(self, r: Request) -> bool:
r.prompt_length,
r.num_generated_tokens,
GenerationFinishReason.NONE,
r.stream,
))
if r.finish_reason != GenerationFinishReason.NONE:
if r.stream or not r.generated_tokens:
@@ -221,6 +224,7 @@ def _generate_output(self, r: Request) -> bool:
r.prompt_length,
r.num_generated_tokens,
r.finish_reason,
r.stream,
))
for output in outputs:
self.result_queues[r.tid].put_nowait(output)
@@ -452,6 +456,51 @@ def flush(self, uids: List[int]) -> None:
self.inference_engine.flush(uid)


@dataclass
class StreamState:
    prev_token_size: int
    token_ids: List[int]


class ReadableStream:
    def __init__(self, tokenizer: MIITokenizerWrapper) -> None:
        self.tokenizer = tokenizer
        self.stream_state: Dict[int, StreamState] = {}

    def init_state(self, thread_id: int) -> StreamState:
        if thread_id not in self.stream_state:
            self.stream_state[thread_id] = StreamState(token_ids=[], prev_token_size=0)
        return self.stream_state[thread_id]

    def flush_state(self, thread_id: int) -> None:
        if thread_id in self.stream_state:
            del self.stream_state[thread_id]

    def decode(self, thread_id: int, token_ids: List[int]) -> str:
        state = self.init_state(thread_id)
        output = []

        for token_id in token_ids:
            state.token_ids.append(token_id)
            decoded = self.tokenizer.decode(state.token_ids)

            # Not enough token_ids in the buffer yet: the tokenizer returned the
            # Unicode U+FFFD REPLACEMENT CHARACTER, so keep buffering.
            if "\ufffd" in decoded:
                continue

            if state.prev_token_size > 0:
                # Drop the text already emitted for the previously buffered tokens
                # so that only the newly decoded increment is returned.
                prev_token = state.token_ids[:state.prev_token_size]
                state.token_ids = state.token_ids[state.prev_token_size:]
                decoded = decoded.replace(self.tokenizer.decode(prev_token), "", 1)

            output.append(decoded)
            state.prev_token_size = len(state.token_ids)

        return "".join(output)


class MIIPipeline(RaggedBatchBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -548,6 +597,7 @@ def __init__(self, *args, **kwargs):
self._is_shutdown = False
self.UID_RANGE_LB = 1
self.UID_RANGE_UB = 10000
self.readable_stream = ReadableStream(self.tokenizer)

def __call__(self) -> None:
# CUDA device gets reset, must set it again to avoid problems
@@ -605,14 +655,22 @@ def get_response(self) -> Tuple[int, Response]:
generated_length=None,
finish_reason=None)
tid = threading.get_ident()
result = self.result_queues[tid].get()
uid = result[0]
generated_token_ids = result[1]
uid, generated_token_ids, prompt_length, generated_length, finish_reason, streaming = self.result_queues[tid].get()
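        # The producer side (_generate_output above) now enqueues a sixth element, r.stream,
        # which selects between incremental ReadableStream decoding and a one-shot decode below.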

if len(generated_token_ids) == 0:
generated_text = ""
self.readable_stream.flush_state(tid)
elif streaming:
generated_text = self.readable_stream.decode(tid, generated_token_ids)
else:
generated_text = self.tokenizer.decode(generated_token_ids)
response = self.make_response(generated_text, result[2], result[3], result[4])

response = self.make_response(
generated_text=generated_text,
prompt_length=prompt_length,
generated_length=generated_length,
finish_reason=finish_reason,
)
return uid, response

def start(self) -> None:
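
Why the U+FFFD buffering in ReadableStream.decode matters: a multi-byte character can span several tokens, and decoding only a prefix of them yields the replacement character, so the stream keeps buffering until valid text appears. A self-contained sketch with a hypothetical byte-level FakeTokenizer standing in for the real MIITokenizerWrapper:

    from mii.batching.ragged_batching import ReadableStream


    class FakeTokenizer:
        """Toy tokenizer in which every UTF-8 byte is its own token id."""
        def __init__(self, text: str):
            self._bytes = text.encode("utf-8")

        def decode(self, token_ids):
            return bytes(self._bytes[i] for i in token_ids).decode("utf-8", errors="replace")


    stream = ReadableStream(FakeTokenizer("例如"))  # two characters, three bytes each
    chunks = [stream.decode(thread_id=0, token_ids=[i]) for i in range(6)]
    assert chunks == ["", "", "例", "", "", "如"]  # partial characters yield no output
    assert "".join(chunks) == "例如"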
13 changes: 13 additions & 0 deletions tests/test_deployment.py
@@ -28,6 +28,19 @@ def callback(response):
assert outputs, "output is empty"


def test_streaming_consistency(deployment, query):
expected_output = deployment(query, do_sample=False)
streaming_parts = []

def callback(response):
streaming_parts.append(response[0].generated_text)

deployment(query, do_sample=False, streaming_fn=callback)
streaming_output = "".join(streaming_parts)

    assert streaming_output == expected_output[0].generated_text, "outputs with and without streaming are not equal"


def test_multi_prompt(deployment, query):
outputs = deployment([query] * 4)
for r in outputs:
51 changes: 51 additions & 0 deletions tests/test_ragged_batching.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import pytest

from mii.batching.ragged_batching import ReadableStream
from mii.config import ModelConfig
from mii.modeling.tokenizers import load_tokenizer


@pytest.mark.parametrize(
"model_name",
[
"tiiuae/falcon-7b",
"NousResearch/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"cloudyu/Mixtral_11Bx2_MoE_19B",
"facebook/opt-125m",
],
ids=["falcon",
"llama",
"mistral",
"mixtral",
"opt"],
)
@pytest.mark.parametrize(
"query",
[
"It’s a region that includes Washington, Oregon, and Idaho.",
"# Heading\n\n<s>title</s> redundant spaces, #id — an anchor",
"例如",
],
ids=[
"apostrophe",
"markdown",
"chinese",
])
def test_readable_stream(model_config, query):
tokenizer = load_tokenizer(ModelConfig(**model_config))
thread_id = 42

token_ids = tokenizer.encode(query)
expected = tokenizer.decode(token_ids)
decoded = []

stream = ReadableStream(tokenizer)
for token_id in token_ids:
decoded.append(stream.decode(thread_id, [token_id]))

assert "".join(decoded) == expected
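
For contrast with the incremental decode verified above, a naive per-token decode can surface replacement characters. A sketch assuming the transformers library and network access to download facebook/opt-125m (one of the parametrized models):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
    ids = tok.encode("例如", add_special_tokens=False)

    naive = "".join(tok.decode([i]) for i in ids)
    print(naive)            # may contain U+FFFD where byte-level tokens split a character
    print(tok.decode(ids))  # "例如" -- the text ReadableStream reproduces incrementally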
