[python] Update reasoning integration test (#2725)
xyang16 authored Feb 6, 2025
Parent: 40588c3 · Commit: 560f223
Showing 2 changed files with 6 additions and 7 deletions.
engines/python/setup/djl_python/output_formatter.py (1 addition, 1 deletion)

@@ -313,9 +313,9 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
         "index": 0,
         "message": {
             "role": "assistant",
+            "reasoning_content": reasoning_content,
             "content": content,
         },
-        "reasoning_content": reasoning_content,
         "logprobs": None,
         "finish_reason": best_sequence.finish_reason,
     }
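In effect, this hunk moves reasoning_content from the choice level into the nested message object, matching the placement used by OpenAI-style chat completions that expose reasoning output. A minimal sketch of one resulting entry in choices (field values are illustrative, not from this commit):

    # Illustrative choice payload after the change; reasoning_content now
    # lives inside "message", alongside the final "content".
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "reasoning_content": "Step 1: recall that ...",  # model's reasoning text
            "content": "The final answer is 42.",            # user-facing answer
        },
        "logprobs": None,
        "finish_reason": "length",
    }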
tests/integration/llm/client.py (5 additions, 6 deletions)

@@ -610,6 +610,7 @@ def get_model_name():
     "deepseek-r1-distill-qwen-1-5b": {
         "batch_size": [1, 4],
         "seq_length": [256],
+        "enable_reasoning": True,
         "tokenizer": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
     },
 }
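The new enable_reasoning flag is what the chat test keys on to choose reasoning prompts over plain chat prompts. A hedged sketch of that dispatch, using the helper names visible in the chat hunk further down (the exact condition is an assumption, not shown in this commit):

    # Sketch: spec-driven prompt selection in test_handler_rolling_batch_chat.
    # batch_generation_reasoning / batch_generation_chat appear in the diff;
    # reading the flag via spec.get(...) is an assumption.
    if spec.get("enable_reasoning", False):
        req = {"messages": batch_generation_reasoning(1)[0]}
    else:
        req = {"messages": batch_generation_chat(1)[0]}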
@@ -1587,7 +1588,7 @@ def test_handler_rolling_batch(model, model_spec):
     stream_values = spec.get("stream", [False, True])
     # dryrun phase
     req = {"inputs": batch_generation(1)[0]}
-    seq_length = 100
+    seq_length = spec["seq_length"][0]
     params = {"do_sample": True, "max_new_tokens": seq_length, "details": True}
     req["parameters"] = params
     if "parameters" in spec:
@@ -1626,7 +1627,7 @@ def test_handler_adapters(model, model_spec):
     inputs = batch_generation(len(spec.get("adapters")))
     for i, adapter in enumerate(spec.get("adapters")):
         req = {"inputs": inputs[i]}
-        seq_length = 100
+        seq_length = spec["seq_length"][0]
         params = {
             "do_sample": True,
             "max_new_tokens": seq_length,
@@ -1694,8 +1695,7 @@ def test_handler_rolling_batch_chat(model, model_spec):
         req = {"messages": batch_generation_reasoning(1)[0]}
     else:
         req = {"messages": batch_generation_chat(1)[0]}
-    seq_length = 100
-    req["max_tokens"] = seq_length
+    req["max_tokens"] = spec["seq_length"][0]
     req["logprobs"] = True
     req["top_logprobs"] = 1
     if "adapters" in spec:
@@ -1724,8 +1724,7 @@ def test_handler_rolling_batch_tool(model, model_spec):
     stream_values = spec.get("stream", [False, True])
     # dryrun phase
     req = batch_generation_tool(1)[0]
-    seq_length = 100
-    req["max_tokens"] = seq_length
+    req["max_tokens"] = spec["seq_length"][0]
     req["logprobs"] = True
     req["top_logprobs"] = 1
     if "adapters" in spec:
