
Commit

Merge pull request #147 from Kiln-AI/thinking_wip
Adding tests to adapter registry and adapters
scosman authored Feb 1, 2025
2 parents 42d86e2 + a659625 commit 8e69878
Showing 18 changed files with 651 additions and 80 deletions.
app/desktop/studio_server/finetune_api.py (4 changes: 3 additions & 1 deletion)
@@ -317,7 +317,9 @@ def system_message_from_request(
)
try:
prompt_builder = prompt_builder_from_ui_name(system_message_generator, task)
system_message = prompt_builder.build_prompt()
system_message = prompt_builder.build_prompt(
include_json_instructions=False
)
except Exception as e:
raise HTTPException(
status_code=400,
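For context, a hedged sketch of the pattern this change (and the matching data_gen change below) follows: prompt builders expose an include_json_instructions flag, callers that export a reusable system message opt out, and BaseAdapter.build_prompt() (further down in this commit) decides at run time whether to add JSON formatting instructions. Names mirror the diff; the surrounding plumbing is assumed.

prompt_builder = prompt_builder_from_ui_name(system_message_generator, task)

# Exported system message: no JSON formatting instructions baked in.
system_message = prompt_builder.build_prompt(include_json_instructions=False)

# At inference time the adapter decides (see base_adapter.py below), roughly:
add_json = has_structured_output and provider_wants_json_instructions  # assumed booleans
runtime_prompt = prompt_builder.build_prompt(include_json_instructions=add_json)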
libs/core/kiln_ai/adapters/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln.
Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc.
Model adapters are used to call AI models, like Ollama, OpenAI, etc.
The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems.
libs/core/kiln_ai/adapters/adapter_registry.py (9 changes: 7 additions & 2 deletions)
@@ -71,10 +71,15 @@ def adapter_for_task(
pass
case ModelProviderName.fireworks_ai:
pass
# These are virtual providers that should have been mapped to an actual provider in core_provider
case ModelProviderName.kiln_fine_tune:
pass
raise ValueError(
"Fine tune is not a supported core provider. It should map to an actual provider."
)
case ModelProviderName.kiln_custom_registry:
pass
raise ValueError(
"Custom openai compatible provider is not a supported core provider. It should map to an actual provider."
)
case _:
raise ValueError(f"Unsupported provider: {provider}")
# Triggers typechecking if I miss a case
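The "# Triggers typechecking if I miss a case" comment above refers to the standard match-statement exhaustiveness pattern. A generic sketch follows (illustrative enum and names; the registry's real check may be implemented differently):

from enum import Enum
from typing import assert_never  # Python 3.11+, or typing_extensions


class Provider(Enum):
    openai = "openai"
    ollama = "ollama"


def adapter_name_for(provider: Provider) -> str:
    match provider:
        case Provider.openai:
            return "openai_adapter"
        case Provider.ollama:
            return "langchain_adapter"
        case _:
            # If a new Provider member is added but not handled above, type checkers
            # flag this call because the argument is no longer narrowed to Never.
            assert_never(provider)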
libs/core/kiln_ai/adapters/data_gen/data_gen_task.py (4 changes: 2 additions & 2 deletions)
@@ -55,7 +55,7 @@ def from_task(
num_subtopics=num_subtopics,
human_guidance=human_guidance,
existing_topics=existing_topics,
system_prompt=prompt_builder.build_prompt(),
system_prompt=prompt_builder.build_prompt(include_json_instructions=False),
)


@@ -132,7 +132,7 @@ def from_task(
topic=topic,
num_samples=num_samples,
human_guidance=human_guidance,
system_prompt=prompt_builder.build_prompt(),
system_prompt=prompt_builder.build_prompt(include_json_instructions=False),
)


libs/core/kiln_ai/adapters/model_adapters/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
"""
# Model Adapters
Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc.
Model adapters are used to call AI models, like Ollama, OpenAI, etc.
"""

libs/core/kiln_ai/adapters/model_adapters/base_adapter.py (33 changes: 27 additions & 6 deletions)
@@ -1,7 +1,7 @@
import json
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Dict
from typing import Dict, Literal, Tuple

from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
@@ -59,15 +59,15 @@ def __init__(
self.model_provider_name = model_provider_name
self._model_provider: KilnModelProvider | None = None

async def model_provider(self) -> KilnModelProvider:
def model_provider(self) -> KilnModelProvider:
"""
Lazy load the model provider for this adapter.
"""
if self._model_provider is not None:
return self._model_provider
if not self.model_name or not self.model_provider_name:
raise ValueError("model_name and model_provider_name must be provided")
self._model_provider = await kiln_model_provider_from(
self._model_provider = kiln_model_provider_from(
self.model_name, self.model_provider_name
)
if not self._model_provider:
@@ -102,7 +102,7 @@ async def invoke(
run_output = await self._run(input)

# Parse
provider = await self.model_provider()
provider = self.model_provider()
parser = model_parser_from_id(provider.parser)(
structured_output=self.has_structured_output()
)
@@ -144,9 +144,9 @@ def adapter_info(self) -> AdapterInfo:
async def _run(self, input: Dict | str) -> RunOutput:
pass

async def build_prompt(self) -> str:
def build_prompt(self) -> str:
# The prompt builder needs to know if we want to inject formatting instructions
provider = await self.model_provider()
provider = self.model_provider()
add_json_instructions = self.has_structured_output() and (
provider.structured_output_mode == StructuredOutputMode.json_instructions
or provider.structured_output_mode
@@ -157,6 +157,27 @@ async def build_prompt(self) -> str:
include_json_instructions=add_json_instructions
)

def run_strategy(
self,
) -> Tuple[Literal["cot_as_message", "cot_two_call", "basic"], str | None]:
# Determine the run strategy for COT prompting. 3 options:
# 1. "Thinking" LLM designed to output thinking in a structured format plus a COT prompt: we make 1 call to the LLM, which outputs thinking in a structured format. We include the thinking instuctions as a message.
# 2. Normal LLM with COT prompt: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call. It also separates the thinking from the final response.
# 3. Non chain of thought: we make 1 call to the LLM, with no COT prompt.
cot_prompt = self.prompt_builder.chain_of_thought_prompt()
reasoning_capable = self.model_provider().reasoning_capable

if cot_prompt and reasoning_capable:
# 1: "Thinking" LLM designed to output thinking in a structured format
# A simple message with the COT prompt appended to the message list is sufficient
return "cot_as_message", cot_prompt
elif cot_prompt:
# 2: Unstructured output with COT
# Two calls to separate the thinking from the final response
return "cot_two_call", cot_prompt
else:
return "basic", None

# create a run and task output
def generate_run(
self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
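Since this PR is about adding tests, here is a hedged sketch of how the new run_strategy() could be unit tested (pytest style; the class name BaseAdapter and the stubbing approach are assumptions, and the tests actually added in this commit may differ):

import pytest
from unittest.mock import MagicMock

from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter  # assumed class name


@pytest.mark.parametrize(
    "cot_prompt,reasoning_capable,expected_strategy",
    [
        ("Think step by step.", True, "cot_as_message"),
        ("Think step by step.", False, "cot_two_call"),
        (None, True, "basic"),
        (None, False, "basic"),
    ],
)
def test_run_strategy(cot_prompt, reasoning_capable, expected_strategy):
    adapter = MagicMock()
    adapter.prompt_builder.chain_of_thought_prompt.return_value = cot_prompt
    adapter.model_provider.return_value.reasoning_capable = reasoning_capable

    strategy, returned_prompt = BaseAdapter.run_strategy(adapter)

    assert strategy == expected_strategy
    assert returned_prompt == cot_prompt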
libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py (40 changes: 19 additions & 21 deletions)
@@ -94,7 +94,7 @@ async def model(self) -> LangChainModelType:
# Decide if we want to use Langchain's structured output:
# 1. Only for structured tasks
# 2. Only if the provider's mode isn't json_instructions (only mode that doesn't use an API option for structured output capabilities)
provider = await self.model_provider()
provider = self.model_provider()
use_lc_structured_output = (
self.has_structured_output()
and provider.structured_output_mode
@@ -116,7 +116,7 @@ async def model(self) -> LangChainModelType:
)
output_schema["title"] = "task_response"
output_schema["description"] = "A response from the task"
with_structured_output_options = await self.get_structured_output_options(
with_structured_output_options = self.get_structured_output_options(
self.model_name, self.model_provider_name
)
self._model = self._model.with_structured_output(
@@ -127,36 +127,34 @@ async def model(self) -> LangChainModelType:
return self._model

async def _run(self, input: Dict | str) -> RunOutput:
provider = await self.model_provider()
provider = self.model_provider()
model = await self.model()
chain = model
intermediate_outputs = {}

prompt = await self.build_prompt()
prompt = self.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
SystemMessage(content=prompt),
HumanMessage(content=user_msg),
]

# Handle chain of thought if enabled. 3 Modes:
# 1. Unstructured output: just call the LLM, with prompting for thinking
# 2. "Thinking" LLM designed to output thinking in a structured format: we make 1 call to the LLM, which outputs thinking in a structured format.
# 3. Normal LLM with structured output: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call.
cot_prompt = self.prompt_builder.chain_of_thought_prompt()
thinking_llm = provider.reasoning_capable

if cot_prompt and (not self.has_structured_output() or thinking_llm):
# Case 1 or 2: Unstructured output, or "Thinking" LLM designed to output thinking in a structured format
messages.append({"role": "system", "content": cot_prompt})
elif not thinking_llm and cot_prompt and self.has_structured_output():
# Case 3: Normal LLM with structured output
# Base model (without structured output) used for COT message
base_model = await self.langchain_model_from()
run_strategy, cot_prompt = self.run_strategy()

if run_strategy == "cot_as_message":
if not cot_prompt:
raise ValueError("cot_prompt is required for cot_as_message strategy")
messages.append(SystemMessage(content=cot_prompt))
elif run_strategy == "cot_two_call":
if not cot_prompt:
raise ValueError("cot_prompt is required for cot_two_call strategy")
messages.append(
SystemMessage(content=cot_prompt),
)

# Base model (without structured output) used for COT message
base_model = await self.langchain_model_from()

cot_messages = [*messages]
cot_response = await base_model.ainvoke(cot_messages)
intermediate_outputs["chain_of_thought"] = cot_response.content
@@ -212,10 +210,10 @@ def _munge_response(self, response: Dict) -> Dict:
return response["arguments"]
return response

async def get_structured_output_options(
def get_structured_output_options(
self, model_name: str, model_provider_name: str
) -> Dict[str, Any]:
provider = await self.model_provider()
provider = self.model_provider()
if not provider:
return {}

@@ -244,7 +242,7 @@ async def get_structured_output_options(
return options

async def langchain_model_from(self) -> BaseChatModel:
provider = await self.model_provider()
provider = self.model_provider()
return await langchain_model_from_provider(provider, self.model_name)


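The tail of the cot_two_call branch is truncated in this view. A hedged sketch of how such a two-call flow typically finishes (the follow-up message wording and the final call are illustrative, not the commit's actual code):

from langchain_core.messages import AIMessage, HumanMessage  # same family as SystemMessage above

# Call 1: the base (unstructured) model produces the chain of thought.
cot_messages = [*messages]
cot_response = await base_model.ainvoke(cot_messages)
intermediate_outputs["chain_of_thought"] = cot_response.content

# Call 2 (illustrative): the structured-output model sees the thinking and returns the
# final answer, so json_schema / tool-based output modes still work.
final_messages = [
    *messages,
    AIMessage(content=cot_response.content),
    HumanMessage(content="Considering the thinking above, return the final answer."),
]
response = await chain.ainvoke(final_messages)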
libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py (31 changes: 15 additions & 16 deletions)
@@ -46,27 +46,26 @@ def __init__(
)

async def _run(self, input: Dict | str) -> RunOutput:
provider = await self.model_provider()
provider = self.model_provider()
intermediate_outputs: dict[str, str] = {}
prompt = await self.build_prompt()
prompt = self.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
ChatCompletionSystemMessageParam(role="system", content=prompt),
ChatCompletionUserMessageParam(role="user", content=user_msg),
]

# Handle chain of thought if enabled. 3 Modes:
# 1. Unstructured output: just call the LLM, with prompting for thinking
# 2. "Thinking" LLM designed to output thinking in a structured format: we make 1 call to the LLM, which outputs thinking in a structured format.
# 3. Normal LLM with structured output: we make 2 calls to the LLM - one for thinking and one for the final response. This helps us use the LLM's structured output modes (json_schema, tools, etc), which can't be used in a single call.
cot_prompt = self.prompt_builder.chain_of_thought_prompt()
thinking_llm = provider.reasoning_capable

if cot_prompt and (not self.has_structured_output() or thinking_llm):
# Case 1 or 2: Unstructured output or "Thinking" LLM designed to output thinking in a structured format
messages.append({"role": "system", "content": cot_prompt})
elif not thinking_llm and cot_prompt and self.has_structured_output():
# Case 3: Normal LLM with structured output, requires 2 calls
run_strategy, cot_prompt = self.run_strategy()

if run_strategy == "cot_as_message":
if not cot_prompt:
raise ValueError("cot_prompt is required for cot_as_message strategy")
messages.append(
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
)
elif run_strategy == "cot_two_call":
if not cot_prompt:
raise ValueError("cot_prompt is required for cot_two_call strategy")
messages.append(
ChatCompletionSystemMessageParam(role="system", content=cot_prompt)
)
@@ -93,7 +92,7 @@ async def _run(self, input: Dict | str) -> RunOutput:
)

extra_body = {}
if self.config.openrouter_style_reasoning and thinking_llm:
if self.config.openrouter_style_reasoning and provider.reasoning_capable:
extra_body["include_reasoning"] = True
# Filter to providers that support the reasoning parameter
extra_body["provider"] = {"require_parameters": True}
@@ -176,7 +175,7 @@ async def response_format_options(self) -> dict[str, Any]:
if not self.has_structured_output():
return {}

provider = await self.model_provider()
provider = self.model_provider()
match provider.structured_output_mode:
case StructuredOutputMode.json_mode:
return {"response_format": {"type": "json_object"}}
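A hedged sketch of where response_format_options() and the extra_body built above typically land; parameter names follow the openai Python client, but the adapter's actual call site is not shown in this view:

# Assumed: `client` is an openai.AsyncOpenAI-compatible client, and `messages` /
# `extra_body` are built as in the diff above.
response_format = await adapter.response_format_options()
completion = await client.chat.completions.create(
    model=model_name,        # illustrative
    messages=messages,
    extra_body=extra_body,   # e.g. OpenRouter include_reasoning / provider filters
    **response_format,       # e.g. {"response_format": {"type": "json_object"}}
)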
(Diffs for the remaining 10 changed files were not loaded.)
