diff --git a/ultravox/inference/infer.py b/ultravox/inference/infer.py
index 4752dbfe..a06a4e68 100644
--- a/ultravox/inference/infer.py
+++ b/ultravox/inference/infer.py
@@ -1,3 +1,4 @@
+import threading
 from typing import Optional
 
 import librosa
@@ -29,7 +30,6 @@ def __init__(
         self.processor = processor
         self.dtype = dtype
 
-    @torch.inference_mode()
     def infer(
         self,
         sample: datasets.VoiceSample,
@@ -38,27 +38,36 @@ def infer(
     ) -> base.VoiceOutput:
         inputs = self._dataproc(sample)
         input_len = inputs["input_ids"].shape[1]
-        temperature = temperature or None
-        do_sample = temperature is not None
-
-        terminators = [self.tokenizer.eos_token_id]
-        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
-            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
-
-        output = self.model.generate(
-            **inputs,
-            do_sample=do_sample,
-            max_new_tokens=max_tokens or MAX_TOKENS,
-            temperature=temperature,
-            repetition_penalty=REPETITION_PENALTY,
-            pad_token_id=self.tokenizer.eos_token_id,
-            eos_token_id=terminators,
-        )
+        output = self._generate(inputs, max_tokens, temperature)
         output_tokens = output[0][input_len:]
         output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
         output_len = len(output_tokens)
         return base.VoiceOutput(output_text, input_len, output_len)
 
+    def infer_stream(
+        self,
+        sample: datasets.VoiceSample,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+    ) -> base.InferenceGenerator:
+        inputs = self._dataproc(sample)
+        input_tokens = inputs["input_ids"].shape[1]
+        decode_kwargs = {"skip_special_tokens": True}
+        streamer = transformers.TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, decode_kwargs=decode_kwargs
+        )
+
+        thread_args = (inputs, max_tokens, temperature, streamer)
+        thread = threading.Thread(target=self._generate, args=thread_args)
+        thread.start()
+        output_tokens = 0
+        for chunk in streamer:
+            if chunk:
+                yield base.InferenceChunk(chunk)
+                output_tokens += 1
+        yield base.InferenceStats(input_tokens, output_tokens)
+        thread.join()
+
     def _dataproc(self, sample: datasets.VoiceSample):
         text_input = self.tokenizer.apply_chat_template(
             sample.messages, add_generation_prompt=True, tokenize=False
@@ -94,3 +103,29 @@ def _dataproc(self, sample: datasets.VoiceSample):
         if "audio_values" in inputs:
             inputs["audio_values"] = inputs["audio_values"].to(dtype=self.dtype)
         return inputs
+
+    @torch.inference_mode()
+    def _generate(
+        self,
+        inputs: torch.Tensor,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        streamer: Optional[transformers.TextStreamer] = None,
+    ):
+        temperature = temperature or None
+        do_sample = temperature is not None
+
+        terminators = [self.tokenizer.eos_token_id]
+        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
+            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
+
+        return self.model.generate(
+            **inputs,
+            do_sample=do_sample,
+            max_new_tokens=max_tokens or MAX_TOKENS,
+            temperature=temperature,
+            repetition_penalty=REPETITION_PENALTY,
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=terminators,
+            streamer=streamer,
+        )
diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py
index ec165dad..12c84e09 100644
--- a/ultravox/inference/infer_test.py
+++ b/ultravox/inference/infer_test.py
@@ -33,6 +33,16 @@ def __init__(
         tokenizer: transformers.PreTrainedTokenizer,
         audio_processor: transformers.ProcessorMixin,
     ):
+        def fake_generate(**kwargs):
+            input = kwargs.get("input_ids")
+            output = [range(25)]
+            streamer = kwargs.get("streamer", None)
+            if streamer:
+                for token in output[0][input.shape[1] :]:
+                    streamer.on_finalized_text(tokenizer.decode(token))
+                streamer.on_finalized_text("", stream_end=True)
+            return output
+
         processor = ultravox_processing.UltravoxProcessor(
             audio_processor, tokenizer=tokenizer
         )
@@ -44,7 +54,7 @@ def __init__(
             dtype=torch.float32,
         )
         self.model.device = "cpu"
-        self.model.generate = mock.MagicMock(return_value=[range(25)])
+        self.model.generate = mock.MagicMock(side_effect=fake_generate)
 
 
 EXPECTED_TOKEN_IDS_START = [128000, 128006, 882, 128007]