readable token streaming support
jeffra authored and greshilov committed Jan 30, 2024
1 parent 690645e commit 44b09d7
Showing 1 changed file with 61 additions and 18 deletions: mii/batching/ragged_batching.py
@@ -199,13 +199,12 @@ def _process_logits(
     def _generate_output(self, r: Request) -> bool:
         outputs = []
         if r.stream:
-            outputs.append((
-                r.uid,
-                [r.next_token],
-                r.prompt_length,
-                r.num_generated_tokens,
-                GenerationFinishReason.NONE,
-            ))
+            outputs.append((r.uid,
+                            [r.next_token],
+                            r.prompt_length,
+                            r.num_generated_tokens,
+                            GenerationFinishReason.NONE,
+                            r.stream))
         if r.finish_reason != GenerationFinishReason.NONE:
             if r.stream or not r.generated_tokens:
                 output_tokens = []
@@ -215,13 +214,12 @@ def _generate_output(
                 if r.return_full_text:
                     # Avoid returning bos token, refactor this later
                     output_tokens = torch.cat((r.prompt_tokens[1:], output_tokens))
-            outputs.append((
-                r.uid,
-                output_tokens,
-                r.prompt_length,
-                r.num_generated_tokens,
-                r.finish_reason,
-            ))
+            outputs.append((r.uid,
+                            output_tokens,
+                            r.prompt_length,
+                            r.num_generated_tokens,
+                            r.finish_reason,
+                            r.stream))
         for output in outputs:
             self.result_queues[r.tid].put_nowait(output)

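With these two changes, the tuples placed on the result queues grow from five fields to six: r.stream is appended so the consumer side can tell streamed chunks from final results. A hypothetical NamedTuple, not part of the commit (the code keeps plain tuples), spelling out that contract:

    from typing import List, NamedTuple

    class ResultQueueItem(NamedTuple):
        # Illustration only: mirrors the field order of the plain tuples
        # that _generate_output() puts on self.result_queues.
        uid: int
        tokens: List[int]            # [next_token] per streamed step, or the full output
        prompt_length: int
        generated_length: int        # r.num_generated_tokens on the producer side
        finish_reason: "GenerationFinishReason"   # MII enum, defined elsewhere in MII
        streaming: bool              # new in this commit: r.stream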
@@ -452,6 +450,44 @@ def flush(self, uids: List[int]) -> None:
             self.inference_engine.flush(uid)
+
+
+@dataclass
+class StreamState:
+    prev_tok: str
+    tids: List[int]
+
+
+class ReadableStream():
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.stream_state = {}
+
+    def init_state(self, thread_id):
+        if thread_id not in self.stream_state:
+            self.stream_state[thread_id] = StreamState(prev_tok=None, tids=[])
+        return self.stream_state[thread_id]
+
+    def flush_state(self, thread_id):
+        if thread_id in self.stream_state:
+            del self.stream_state[thread_id]
+
+    def decode(self, thread_id, token_ids):
+        sstate = self.init_state(thread_id)
+        final = []
+        for tid in token_ids:
+            sstate.tids.append(tid)
+            r = self.tokenizer.decode(sstate.tids)
+            if " " in r:
+                if sstate.prev_tok is not None:
+                    r = r.replace(sstate.prev_tok, "")
+                sstate.tids = [sstate.tids[-1]]
+            elif len(sstate.tids) > 1:
+                sstate.tids.pop(0)
+                r = r.replace(sstate.prev_tok, "")
+            sstate.prev_tok = self.tokenizer.decode(tid)
+            final.append(r)
+        return "".join(final)


 class MIIPipeline(RaggedBatchBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
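The new ReadableStream keeps a small per-thread window of recent token ids and decodes each incoming token together with its predecessor, which aims to keep word boundaries intact even though tokens arrive one at a time. A minimal usage sketch, not part of the commit; the "gpt2" tokenizer and the fixed thread id are assumptions for illustration:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # assumed tokenizer for this sketch
    stream = ReadableStream(tokenizer)                   # class added in this commit
    thread_id = 1234                                     # stands in for threading.get_ident()

    token_ids = tokenizer.encode("Readable streaming stitches sub-word tokens back together.")

    # Feed the ids one at a time, the way the streaming path delivers them,
    # and collect the readable chunk produced for each step.
    chunks = [stream.decode(thread_id, [t]) for t in token_ids]

    stream.flush_state(thread_id)   # drop this thread's state once the request is done
    print("".join(chunks))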
@@ -548,6 +584,7 @@ def __init__(self, *args, **kwargs):
         self._is_shutdown = False
         self.UID_RANGE_LB = 1
         self.UID_RANGE_UB = 10000
+        self.readable_stream = ReadableStream(self.tokenizer)

     def __call__(self) -> None:
         # CUDA device gets reset, must set it again to avoid problems
@@ -605,14 +642,20 @@ def get_response(self) -> Tuple[int, Response]:
                                      generated_length=None,
                                      finish_reason=None)
         tid = threading.get_ident()
-        result = self.result_queues[tid].get()
-        uid = result[0]
-        generated_token_ids = result[1]
+        uid, generated_token_ids, prompt_length, generated_length, finish_reason, streaming = self.result_queues[tid].get()
+
         if len(generated_token_ids) == 0:
             generated_text = ""
+            self.readable_stream.flush_state(tid)
+        elif streaming:
+            generated_text = self.readable_stream.decode(tid, generated_token_ids)
         else:
             generated_text = self.tokenizer.decode(generated_token_ids)
-        response = self.make_response(generated_text, result[2], result[3], result[4])
+
+        response = self.make_response(generated_text=generated_text,
+                                      prompt_length=prompt_length,
+                                      generated_length=generated_length,
+                                      finish_reason=finish_reason)
         return uid, response

     def start(self) -> None:
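On the consumer side, get_response() now unpacks the six-field tuple, routes streamed chunks through readable_stream.decode(), and calls flush_state() when the empty final message arrives so the same worker thread starts the next request with clean state. A small sketch of that per-thread lifecycle, again assuming a Hugging Face tokenizer and a fixed thread id (not part of the commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # assumed tokenizer for this sketch
    stream = ReadableStream(tokenizer)
    tid = 42                                             # stands in for threading.get_ident()

    # First request streams its tokens through this thread's state.
    first = tokenizer.encode("first request")
    text_1 = "".join(stream.decode(tid, [t]) for t in first)

    # The empty final message makes get_response() call flush_state(),
    # so leftover prev_tok/tids cannot bleed into the next request.
    stream.flush_state(tid)

    # A second request served by the same thread starts from fresh state.
    second = tokenizer.encode("second request")
    text_2 = "".join(stream.decode(tid, [t]) for t in second)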
