
[python] Support reasoning content (#2722)
xyang16 authored Feb 5, 2025
1 parent 4ad8ce3 commit 40588c3
Showing 8 changed files with 115 additions and 4 deletions.
@@ -38,6 +38,7 @@ def parse_chat_completions_request_vllm(
):

tool_parser = rolling_batch.get_tool_parser()
reasoning_parser = rolling_batch.get_reasoning_parser()
model = input_map.pop("model", "lmi")
chat_params = ChatCompletionRequest(**input_map, model=model)

@@ -90,6 +91,7 @@ def parse_chat_completions_request_vllm(
"request_prompts": request_prompt,
"engine_prompt": engine_prompt,
"tool_parser": tool_parser,
"reasoning_parser": reasoning_parser,
"chat_params": chat_params,
}
return input_text, params
39 changes: 37 additions & 2 deletions engines/python/setup/djl_python/output_formatter.py
@@ -284,6 +284,7 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
parameters = request_output.input.parameters
chat_params = parameters.get("chat_params")
tool_parser = parameters.get("tool_parser")
reasoning_parser = parameters.get("reasoning_parser")
best_sequence = request_output.sequences[
request_output.best_sequence_index]
generated_text = get_generated_text(best_sequence, request_output)
@@ -301,7 +302,24 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
"logprobs": None,
"finish_reason": best_sequence.finish_reason,
}
if chat_params and chat_params.tool_choice and type(

if reasoning_parser:
reasoning_content, content = (
reasoning_parser.extract_reasoning_content(generated_text,
request=chat_params))

if reasoning_content:
choice = {
"index": 0,
"message": {
"role": "assistant",
"content": content,
},
"reasoning_content": reasoning_content,
"logprobs": None,
"finish_reason": best_sequence.finish_reason,
}
elif chat_params and chat_params.tool_choice and type(
chat_params.tool_choice
).__name__ == "ChatCompletionNamedToolChoiceParam":
tool_calls = [{
@@ -386,6 +404,7 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):
parameters = request_output.input.parameters
chat_params = parameters.get("chat_params")
tool_parser = parameters.get("tool_parser")
reasoning_parser = parameters.get("reasoning_parser")
best_sequence = request_output.sequences[
request_output.best_sequence_index]
next_token, index, first_token, last_token = best_sequence.get_next_token()
@@ -396,7 +415,23 @@ def _jsonlines_chat_output_formatter(request_output: TextGenerationOutput):

created = int(time.time())

if chat_params and chat_params.tool_choice and type(
if reasoning_parser:
current_text = get_generated_text(best_sequence, request_output)
previous_text = current_text[0:-len(next_token.text)]
current_token_ids = [t.id for t in best_sequence.tokens]
previous_token_ids = current_token_ids[:-1]
delta = reasoning_parser.extract_reasoning_content_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=next_token.text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=[next_token.id],
)
if delta is None:
return ""
delta = delta.model_dump(exclude_unset=True)
elif chat_params and chat_params.tool_choice and type(
chat_params.tool_choice
).__name__ == "ChatCompletionNamedToolChoiceParam":
tool_calls = [{
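
A minimal sketch of the payloads the two formatters above emit when a reasoning parser is configured. The choice layout is copied from the non-streaming branch added above; the example strings, and the exact streaming delta fields, are assumptions about how vLLM's deepseek_r1 parser splits <think>...</think> output.

# Non-streaming (_json_chat_output_formatter): the parser splits the generated
# text into reasoning and answer, and the choice carries both parts.
choice = {
    "index": 0,
    "message": {
        "role": "assistant",
        "content": "9.8 is greater than 9.11.",  # text after </think>
    },
    "reasoning_content": "Compare the tenths digits: 8 > 1 ...",  # text inside <think>...</think>
    "logprobs": None,
    "finish_reason": "stop",
}

# Streaming (_jsonlines_chat_output_formatter): each token produces a delta from
# extract_reasoning_content_streaming; model_dump(exclude_unset=True) keeps only
# the populated field, so the chunks look roughly like (assumed shape):
deltas = [
    {"reasoning_content": "Compare"},
    {"reasoning_content": " the tenths digits ..."},
    {"content": "9.8"},
    {"content": " is greater."},
]
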
@@ -74,6 +74,10 @@ class VllmRbProperties(Properties):
enable_auto_tool_choice: bool = False
tool_call_parser: Optional[str] = None

# Reasoning properties
enable_reasoning: bool = False
reasoning_parser: Optional[str] = None

# Neuron vLLM properties
device: str = 'auto'
preloaded_model: Optional[Any] = None
@@ -129,6 +133,18 @@ def validate_tool_call_parser(self):
f"(chose from {{ {','.join(valid_tool_parses)} }})")
return self

@model_validator(mode='after')
def validate_reasoning_parser(self):
if self.enable_reasoning:
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys(
)
if self.reasoning_parser not in valid_reasoning_parses:
raise ValueError(
f"Invalid reasoning parser: {self.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
return self

@field_validator('override_neuron_config', mode="before")
def validate_override_neuron_config(cls, val):
if isinstance(val, str):
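
The validator above only accepts reasoning parser names registered with vLLM's ReasoningParserManager, the same registry the rolling batch uses further down. A quick sketch of how those names can be inspected, assuming a vLLM build that ships the reasoning_parsers module this commit imports:

from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager

# Names accepted by option.reasoning_parser, e.g. dict_keys(['deepseek_r1'])
# depending on the installed vLLM version.
print(ReasoningParserManager.reasoning_parsers.keys())

# Resolving a registered name to its parser class; a name missing from this
# registry is what makes validate_reasoning_parser above raise its ValueError.
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
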
@@ -126,6 +126,12 @@ def get_tool_parser(self):
"""
return None

def get_reasoning_parser(self):
"""
:return: the reasoning parser if available
"""
return None

@abstractmethod
def inference(self, new_requests: List[Request]) -> List:
"""
@@ -52,6 +52,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.lora_requests = {}
self.is_mistral_tokenizer = args.tokenizer_mode == 'mistral'
self.tool_parser = None
self.reasoning_parser = None
if self.vllm_configs.enable_auto_tool_choice:
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
try:
@@ -61,6 +62,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
self.engine.tokenizer.tokenizer)
except Exception as e:
raise TypeError("Error in tool parser creation.") from e
if self.vllm_configs.enable_reasoning:
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
try:
self.reasoning_parser = ReasoningParserManager.get_reasoning_parser(
self.vllm_configs.reasoning_parser)
self.reasoning_parser = self.reasoning_parser(
self.engine.tokenizer.tokenizer)
except Exception as e:
raise TypeError("Error in reasoning parser creation.") from e

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
@@ -77,6 +87,9 @@ def use_vllm_chat_completions(self):
def get_tool_parser(self):
return self.tool_parser

def get_reasoning_parser(self):
return self.reasoning_parser

def get_chat_template(self):
if self.is_mistral_tokenizer:
# Mistral tokenizer chat template cannot be overridden
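
The parser instance built in __init__ above is what get_reasoning_parser() hands to the output formatters. A rough sketch of the non-streaming call the formatter makes, with a hypothetical model output string:

# reasoning_parser and chat_params as wired up in this commit; the model output
# below is invented for illustration.
reasoning, answer = reasoning_parser.extract_reasoning_content(
    "<think>9.8 has the larger tenths digit.</think>9.8 is greater.",
    request=chat_params,
)
# With the deepseek_r1 parser this is expected to yield roughly:
#   reasoning == "9.8 has the larger tenths digit."
#   answer    == "9.8 is greater."
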
30 changes: 28 additions & 2 deletions tests/integration/llm/client.py
@@ -606,7 +606,12 @@ def get_model_name():
"batch_size": [1, 4],
"seq_length": [256],
"tokenizer": "TheBloke/Llama-2-7B-Chat-fp16",
}
},
"deepseek-r1-distill-qwen-1-5b": {
"batch_size": [1, 4],
"seq_length": [256],
"tokenizer": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
},
}

vllm_tool_model_spec = {
@@ -1423,6 +1428,24 @@ def batch_generation_tool(batch_size):
return data[:batch_size]


def batch_generation_reasoning(batch_size):
messages = [
[{
"role": "user",
"content": "9.11 and 9.8, which is greater?"
}],
[{
"role": "user",
"content": "How many Rs are there in the word 'strawberry'?"
}],
]

if batch_size > len(messages):
# dynamically extend to support larger bs by repetition
messages *= math.ceil(batch_size / len(messages))
return messages[:batch_size]


def t5_batch_generation(batch_size):
input_sentences = [
"translate English to German: The house is wonderful.",
@@ -1667,7 +1690,10 @@ def test_handler_rolling_batch_chat(model, model_spec):
check_worker_number(spec["worker"])
stream_values = spec.get("stream", [False, True])
# dryrun phase
req = {"messages": batch_generation_chat(1)[0]}
if spec.get("enable_reasoning", False):
req = {"messages": batch_generation_reasoning(1)[0]}
else:
req = {"messages": batch_generation_chat(1)[0]}
seq_length = 100
req["max_tokens"] = seq_length
req["logprobs"] = True
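
For the repetition logic in the batch_generation_reasoning helper added above, a batch size larger than the two seed prompts simply wraps around (illustrative check, not part of the test suite):

# batch_size=3 repeats the two prompts and truncates: [q1, q2, q1].
batch = batch_generation_reasoning(3)
assert len(batch) == 3
assert batch[2] == batch[0]
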
7 changes: 7 additions & 0 deletions tests/integration/llm/prepare.py
@@ -1086,6 +1086,13 @@
"option.enable_auto_tool_choice": True,
"option.tool_call_parser": "mistral",
},
"deepseek-r1-distill-qwen-1-5b": {
"option.model_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"option.tensor_parallel_degree": 1,
"option.max_rolling_batch_size": 4,
"option.enable_reasoning": True,
"option.reasoning_parser": "deepseek_r1",
},
}

vllm_neo_model_list = {
6 changes: 6 additions & 0 deletions tests/integration/tests.py
@@ -643,6 +643,12 @@ def test_mistral_7b_instruct_v03_tool(self):
r.launch()
client.run("vllm_tool mistral-7b-instruct-v03-tool".split())

def test_deepseek_r1_distill_qwen_1_5b(self):
with Runner('lmi', 'deepseek-r1-distill-qwen-1-5b') as r:
prepare.build_vllm_model("deepseek-r1-distill-qwen-1-5b")
r.launch()
client.run("vllm_chat deepseek-r1-distill-qwen-1-5b".split())


@pytest.mark.vllm
@pytest.mark.lora
