improve readable token streaming
greshilov committed Jan 30, 2024
1 parent 44b09d7 commit 1ae8852
Showing 3 changed files with 119 additions and 40 deletions.
95 changes: 55 additions & 40 deletions mii/batching/ragged_batching.py
@@ -9,6 +9,7 @@
import random
import threading
import time
from dataclasses import dataclass
from collections import deque, defaultdict
from functools import cached_property
from typing import Dict, Tuple, List, Any, Union, DefaultDict
@@ -33,6 +34,7 @@
from mii.config import GenerateParamsConfig
from mii.constants import GenerationFinishReason, ZMQ_RECV_TIMEOUT
from mii.logging import logger
from mii.modeling.tokenizers import MIITokenizerWrapper


class RaggedBatchBase:
@@ -199,12 +201,14 @@ def _process_logits(
def _generate_output(self, r: Request) -> bool:
outputs = []
if r.stream:
outputs.append((r.uid,
[r.next_token],
r.prompt_length,
r.num_generated_tokens,
GenerationFinishReason.NONE,
r.stream))
outputs.append((
r.uid,
[r.next_token],
r.prompt_length,
r.num_generated_tokens,
GenerationFinishReason.NONE,
r.stream,
))
if r.finish_reason != GenerationFinishReason.NONE:
if r.stream or not r.generated_tokens:
output_tokens = []
@@ -214,12 +218,14 @@ def _generate_output(self, r: Request) -> bool:
if r.return_full_text:
# Avoid returning bos token, refactor this later
output_tokens = torch.cat((r.prompt_tokens[1:], output_tokens))
outputs.append((r.uid,
output_tokens,
r.prompt_length,
r.num_generated_tokens,
r.finish_reason,
r.stream))
outputs.append((
r.uid,
output_tokens,
r.prompt_length,
r.num_generated_tokens,
r.finish_reason,
r.stream,
))
for output in outputs:
self.result_queues[r.tid].put_nowait(output)
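
A minimal consumer sketch (not part of this commit), assuming the six-field tuple layout enqueued above; drain_stream and result_queue are illustrative names.

import queue

from mii.constants import GenerationFinishReason


def drain_stream(result_queue: queue.Queue) -> list:
    # Collect streamed token batches for one request until its finish reason
    # changes from NONE; each item follows the tuple layout built above.
    collected = []
    while True:
        uid, token_ids, prompt_len, gen_len, finish_reason, is_stream = result_queue.get()
        collected.append(token_ids)
        if finish_reason != GenerationFinishReason.NONE:
            break
    return collected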

@@ -452,40 +458,47 @@ def flush(self, uids: List[int]) -> None:

@dataclass
class StreamState:
prev_tok: str
tids: List[int]
prev_token_size: int
token_ids: List[int]


class ReadableStream():
def __init__(self, tokenizer):
class ReadableStream:
def __init__(self, tokenizer: MIITokenizerWrapper) -> None:
self.tokenizer = tokenizer
self.stream_state = {}
self.stream_state: Dict[int, StreamState] = {}

def init_state(self, thread_id):
def init_state(self, thread_id: int) -> StreamState:
if thread_id not in self.stream_state:
self.stream_state[thread_id] = StreamState(prev_tok=None, tids=[])
self.stream_state[thread_id] = StreamState(token_ids=[], prev_token_size=0)
return self.stream_state[thread_id]
return self.stream_state[thread_id]

def flush_state(self, thread_id):
def flush_state(self, thread_id: int) -> None:
if thread_id in self.stream_state:
del self.stream_state[thread_id]

def decode(self, thread_id, token_ids):
sstate = self.init_state(thread_id)
final = []
for tid in token_ids:
sstate.tids.append(tid)
r = self.tokenizer.decode(sstate.tids)
if " " in r:
if sstate.prev_tok is not None:
r = r.replace(sstate.prev_tok, "")
sstate.tids = [sstate.tids[-1]]
elif len(sstate.tids) > 1:
sstate.tids.pop(0)
r = r.replace(sstate.prev_tok, "")
sstate.prev_tok = self.tokenizer.decode(tid)
final.append(r)
return "".join(final)
def decode(self, thread_id: int, token_ids: List[int]) -> str:
state = self.init_state(thread_id)
output = []

for token_id in token_ids:
state.token_ids.append(token_id)
decoded = self.tokenizer.decode(state.token_ids)

# We don't have enough token_ids in the buffer and
# tokenizer returned unicode 'U+FFFD REPLACEMENT CHARACTER'
if "\ufffd" in decoded:
continue

if state.prev_token_size > 0:
prev_token = state.token_ids[:state.prev_token_size]
state.token_ids = state.token_ids[state.prev_token_size:]
decoded = decoded.replace(self.tokenizer.decode(prev_token), "", 1)

output.append(decoded)
state.prev_token_size = len(state.token_ids)

return "".join(output)


class MIIPipeline(RaggedBatchBase):
@@ -652,10 +665,12 @@ def get_response(self) -> Tuple[int, Response]:
else:
generated_text = self.tokenizer.decode(generated_token_ids)

response = self.make_response(generated_text=generated_text,
prompt_length=prompt_length,
generated_length=generated_length,
finish_reason=finish_reason)
response = self.make_response(
generated_text=generated_text,
prompt_length=prompt_length,
generated_length=generated_length,
finish_reason=finish_reason,
)
return uid, response

def start(self) -> None:
13 changes: 13 additions & 0 deletions tests/test_deployment.py
@@ -28,6 +28,19 @@ def callback(response):
assert outputs, "output is empty"


def test_streaming_consistency(deployment, query):
expected_output = deployment(query, do_sample=False)
streaming_parts = []

def callback(response):
streaming_parts.append(response[0].generated_text)

deployment(query, do_sample=False, streaming_fn=callback)
streaming_output = "".join(streaming_parts)

assert streaming_output == expected_output[0].generated_text, "outputs w and w/o streaming are not equal"


def test_multi_prompt(deployment, query):
outputs = deployment([query] * 4)
for r in outputs:
51 changes: 51 additions & 0 deletions tests/test_ragged_batching.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import pytest

from mii.batching.ragged_batching import ReadableStream
from mii.config import ModelConfig
from mii.modeling.tokenizers import load_tokenizer


@pytest.mark.parametrize(
"model_name",
[
"tiiuae/falcon-7b",
"NousResearch/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"cloudyu/Mixtral_11Bx2_MoE_19B",
"facebook/opt-125m",
],
ids=["falcon",
"llama",
"mistral",
"mixtral",
"opt"],
)
@pytest.mark.parametrize(
"query",
[
"It’s a region that includes Washington, Oregon, and Idaho.",
"# Heading\n\n<s>title</s> redundant spaces, #id — an anchor",
"例如",
],
ids=[
"apostrophe",
"markdown",
"chinese",
])
def test_readable_stream(model_config, query):
tokenizer = load_tokenizer(ModelConfig(**model_config))
thread_id = 42

token_ids = tokenizer.encode(query)
expected = tokenizer.decode(token_ids)
decoded = []

stream = ReadableStream(tokenizer)
for token_id in token_ids:
decoded.append(stream.decode(thread_id, [token_id]))

assert "".join(decoded) == expected
