Merge pull request #141 from Kiln-AI/thinking_wip
Thinking+Chain of thought
scosman authored Jan 31, 2025
2 parents a21273d + 8b6e9ed commit f9c319f
Showing 4 changed files with 39 additions and 19 deletions.
5 changes: 5 additions & 0 deletions libs/core/kiln_ai/adapters/ml_model_list.py
@@ -103,6 +103,7 @@ class KilnModelProvider(BaseModel):
provider_options: Additional provider-specific configuration options
structured_output_mode: The mode we should use to call the model for structured output, if it was trained with structured output.
parser: A parser to use for the model, if applicable
reasoning_capable: Whether the model is designed to output thinking in a structured format (e.g. <think></think>). If so, we don't use COT across 2 calls, and instead ask for thinking and the final response in the same call.
"""

name: ModelProviderName
@@ -113,6 +114,7 @@ class KilnModelProvider(BaseModel):
provider_options: Dict = {}
structured_output_mode: StructuredOutputMode = StructuredOutputMode.default
parser: ModelParserID | None = None
reasoning_capable: bool = False


class KilnModel(BaseModel):
@@ -222,19 +224,22 @@ class KilnModel(BaseModel):
provider_options={"model": "deepseek/deepseek-r1"},
# No custom parser -- openrouter implemented it themselves
structured_output_mode=StructuredOutputMode.json_instructions,
reasoning_capable=True,
),
KilnModelProvider(
name=ModelProviderName.fireworks_ai,
provider_options={"model": "accounts/fireworks/models/deepseek-r1"},
parser=ModelParserID.r1_thinking,
structured_output_mode=StructuredOutputMode.json_instructions,
reasoning_capable=True,
),
KilnModelProvider(
# I want your RAM
name=ModelProviderName.ollama,
provider_options={"model": "deepseek-r1:671b"},
parser=ModelParserID.r1_thinking,
structured_output_mode=StructuredOutputMode.json_instructions,
reasoning_capable=True,
),
],
),
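
The `reasoning_capable` flag pairs with the `r1_thinking` parser referenced above for the Fireworks and Ollama providers. The parser itself is not part of this diff; as a rough sketch of what such a parser might do (the `<think>` delimiters follow the docstring above, but the function name and behavior here are assumptions, not the actual Kiln implementation):

```python
import re


def split_r1_thinking(raw: str) -> tuple[str | None, str]:
    """Hypothetical helper: split an R1-style completion into (thinking, final answer).

    Assumes the model wraps its reasoning in a single <think>...</think> block,
    as described in the reasoning_capable docstring. Not the real r1_thinking parser.
    """
    match = re.search(r"<think>(.*?)</think>", raw, flags=re.DOTALL)
    if not match:
        # No structured thinking found; treat the whole output as the answer.
        return None, raw.strip()
    thinking = match.group(1).strip()
    # Everything after the closing tag is the final response.
    final = raw[match.end():].strip()
    return thinking, final
```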
17 changes: 12 additions & 5 deletions libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -128,6 +128,7 @@ async def model(self) -> LangChainModelType:
return self._model

async def _run(self, input: Dict | str) -> RunOutput:
provider = await self.model_provider()
model = await self.model()
chain = model
intermediate_outputs = {}
@@ -139,10 +140,18 @@ async def _run(self, input: Dict | str) -> RunOutput:
HumanMessage(content=user_msg),
]

# TODO: make this compatible with thinking models
# COT with structured output
# Handle chain of thought if enabled. 3 Modes:
# 1. Unstructured output: just call the LLM, with prompting for thinking
# 2. "Thinking" LLM designed to output thinking in a structured format: we make 1 call to the LLM, which outputs thinking in a structured format.
# 3. Normal LLM with structured output: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call.
cot_prompt = self.prompt_builder.chain_of_thought_prompt()
if cot_prompt and self.has_structured_output():
thinking_llm = provider.reasoning_capable

if cot_prompt and (not self.has_structured_output() or thinking_llm):
# Case 1 or 2: Unstructured output, or "Thinking" LLM designed to output thinking in a structured format
messages.append({"role": "system", "content": cot_prompt})
elif not thinking_llm and cot_prompt and self.has_structured_output():
# Case 3: Normal LLM with structured output
# Base model (without structured output) used for COT message
base_model = await self.langchain_model_from()
messages.append(
@@ -156,8 +165,6 @@ async def _run(self, input: Dict | str) -> RunOutput:
messages.append(
SystemMessage(content="Considering the above, return a final result.")
)
elif cot_prompt:
messages.append(SystemMessage(content=cot_prompt))

response = await chain.ainvoke(messages)

34 changes: 21 additions & 13 deletions libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
@@ -55,25 +55,33 @@ def __init__(

async def _run(self, input: Dict | str) -> RunOutput:
provider = await self.model_provider()

intermediate_outputs: dict[str, str] = {}

prompt = await self.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
ChatCompletionSystemMessageParam(role="system", content=prompt),
ChatCompletionUserMessageParam(role="user", content=user_msg),
]

# Handle chain of thought if enabled
# Handle chain of thought if enabled. 3 Modes:
# 1. Unstructured output: just call the LLM, with prompting for thinking
# 2. "Thinking" LLM designed to output thinking in a structured format: we make 1 call to the LLM, which outputs thinking in a structured format.
# 3. Normal LLM with structured output: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call.
cot_prompt = self.prompt_builder.chain_of_thought_prompt()
if cot_prompt and self.has_structured_output():
# TODO P0: Fix COT
thinking_llm = provider.reasoning_capable

if cot_prompt and (not self.has_structured_output() or thinking_llm):
# Case 1 or 2: Unstructured output or "Thinking" LLM designed to output thinking in a structured format
messages.append({"role": "system", "content": cot_prompt})
elif not thinking_llm and cot_prompt and self.has_structured_output():
# Case 3: Normal LLM with structured output, requires 2 calls
messages.append(
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
)

# First call for chain of thought
cot_response = await self.client.chat.completions.create(
model=self.model_name,
model=provider.provider_options["model"],
messages=messages,
)
cot_content = cot_response.choices[0].message.content
@@ -91,19 +99,19 @@ async def _run(self, input: Dict | str) -> RunOutput:
),
]
)
elif cot_prompt:
messages.append({"role": "system", "content": cot_prompt})
else:
intermediate_outputs = {}

extra_body = {}
if self.config.openrouter_style_reasoning and thinking_llm:
extra_body["include_reasoning"] = True
# Filter to providers that support the reasoning parameter
extra_body["provider"] = {"require_parameters": True}

# Main completion call
response_format_options = await self.response_format_options()
response = await self.client.chat.completions.create(
model=provider.provider_options["model"],
messages=messages,
extra_body={"include_reasoning": True}
if self.config.openrouter_style_reasoning
else {},
extra_body=extra_body,
**response_format_options,
)

2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/parsers/json_parser.py
@@ -32,4 +32,4 @@ def parse_json_string(json_string: str) -> Dict[str, Any]:
try:
return json.loads(cleaned_string)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON: {str(e)}") from e
raise ValueError(f"Failed to parse JSON: {str(e)}\n\n{cleaned_string}") from e
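
Only the error path of `parse_json_string` is visible here; the change appends the cleaned payload to the error message to make malformed model output easier to debug. For context, a minimal sketch of what such a helper typically does before the `json.loads` call (the fence-stripping step is an assumption, not shown in this diff):

```python
import json
from typing import Any, Dict


def parse_json_string_sketch(json_string: str) -> Dict[str, Any]:
    """Hypothetical approximation of parse_json_string: strip optional markdown
    code fences, then parse. The real cleaning logic is not shown in this diff."""
    cleaned_string = json_string.strip()
    if cleaned_string.startswith("```"):
        # Drop a leading ```json / ``` fence and a trailing ``` fence.
        cleaned_string = cleaned_string.removeprefix("```json").removeprefix("```")
        cleaned_string = cleaned_string.removesuffix("```").strip()
    try:
        return json.loads(cleaned_string)
    except json.JSONDecodeError as e:
        # Include the cleaned payload so the failing input is visible in the error.
        raise ValueError(f"Failed to parse JSON: {str(e)}\n\n{cleaned_string}") from e
```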
