From 55af1f378f5c07c7b0ab64ab85b06113d635f78f Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Wed, 5 Feb 2025 17:22:31 -0800
Subject: [PATCH] [python] Update reasoning integration test

---
 engines/python/setup/djl_python/output_formatter.py |  2 +-
 tests/integration/llm/client.py                     | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/engines/python/setup/djl_python/output_formatter.py b/engines/python/setup/djl_python/output_formatter.py
index c86709559..e0cb76a84 100644
--- a/engines/python/setup/djl_python/output_formatter.py
+++ b/engines/python/setup/djl_python/output_formatter.py
@@ -313,9 +313,9 @@ def _json_chat_output_formatter(request_output: TextGenerationOutput):
         "index": 0,
         "message": {
             "role": "assistant",
+            "reasoning_content": reasoning_content,
             "content": content,
         },
-        "reasoning_content": reasoning_content,
         "logprobs": None,
         "finish_reason": best_sequence.finish_reason,
     }
diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py
index d1a6f7eb2..d64a4b6cd 100644
--- a/tests/integration/llm/client.py
+++ b/tests/integration/llm/client.py
@@ -610,6 +610,7 @@ def get_model_name():
     "deepseek-r1-distill-qwen-1-5b": {
         "batch_size": [1, 4],
         "seq_length": [256],
+        "enable_reasoning": True,
         "tokenizer": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
     },
 }
@@ -1587,7 +1588,7 @@ def test_handler_rolling_batch(model, model_spec):
     stream_values = spec.get("stream", [False, True])
     # dryrun phase
     req = {"inputs": batch_generation(1)[0]}
-    seq_length = 100
+    seq_length = spec["seq_length"][0]
     params = {"do_sample": True, "max_new_tokens": seq_length, "details": True}
     req["parameters"] = params
     if "parameters" in spec:
@@ -1626,7 +1627,7 @@ def test_handler_adapters(model, model_spec):
     inputs = batch_generation(len(spec.get("adapters")))
     for i, adapter in enumerate(spec.get("adapters")):
         req = {"inputs": inputs[i]}
-        seq_length = 100
+        seq_length = spec["seq_length"][0]
         params = {
             "do_sample": True,
             "max_new_tokens": seq_length,
@@ -1694,8 +1695,7 @@ def test_handler_rolling_batch_chat(model, model_spec):
         req = {"messages": batch_generation_reasoning(1)[0]}
     else:
         req = {"messages": batch_generation_chat(1)[0]}
-    seq_length = 100
-    req["max_tokens"] = seq_length
+    req["max_tokens"] = spec["seq_length"][0]
     req["logprobs"] = True
     req["top_logprobs"] = 1
     if "adapters" in spec:
@@ -1724,8 +1724,7 @@ def test_handler_rolling_batch_tool(model, model_spec):
     stream_values = spec.get("stream", [False, True])
     # dryrun phase
     req = batch_generation_tool(1)[0]
-    seq_length = 100
-    req["max_tokens"] = seq_length
+    req["max_tokens"] = spec["seq_length"][0]
     req["logprobs"] = True
     req["top_logprobs"] = 1
     if "adapters" in spec:
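
Notes:

The output_formatter.py change moves "reasoning_content" inside the "message"
object rather than leaving it at the choice level, matching the OpenAI-style
chat completion layout that reasoning-aware clients read. A minimal sketch of
one resulting "choices" entry; the field values here are illustrative, only
the nesting is taken from the diff:

    # Shape of one choice after this patch; values are made up for illustration.
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            # reasoning trace parsed from the model output, e.g. for
            # DeepSeek-R1 distills when reasoning is enabled
            "reasoning_content": "First, compare the two options ...",
            "content": "Option A is the better choice.",
        },
        "logprobs": None,
        "finish_reason": "stop",
    }

The client.py changes make each dryrun request honor the model's configured
seq_length (256 for deepseek-r1-distill-qwen-1-5b) instead of a hardcoded 100
tokens, presumably so reasoning output is not truncated before the reasoning
content can be produced.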