Thinking models Part 1: Parsers and OpenAI adapter (#137)
Lots here:

 - OpenAI adapter for better control.
 - R1 support that actually works, saving reasoning into intermediate outputs.
 - Parser infrastructure and an R1 parser.

Still needs some tests, but working.
scosman authored Jan 31, 2025
1 parent 0327032 commit 5d00cf5
Showing 25 changed files with 984 additions and 221 deletions.
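The parser package itself is truncated from this view, but the idea behind the new R1 parser is simple: split the <think>…</think> reasoning block off the front of the model output, save the reasoning as an intermediate output, and return the remainder as the final answer. A minimal sketch, assuming the <think> tag convention; the names ParsedOutput and parse_r1_thinking are illustrative, not the shipped API:

import re
from dataclasses import dataclass, field


@dataclass
class ParsedOutput:
    """Final answer plus intermediate outputs such as reasoning."""

    output: str
    intermediate_outputs: dict = field(default_factory=dict)


def parse_r1_thinking(raw: str) -> ParsedOutput:
    # R1-style models emit their chain of thought inside <think> tags,
    # followed by the actual answer.
    match = re.match(r"\s*<think>(.*?)</think>(.*)", raw, re.DOTALL)
    if match is None:
        raise ValueError("R1 output is missing a <think> block")
    reasoning, answer = match.group(1).strip(), match.group(2).strip()
    return ParsedOutput(
        output=answer,
        intermediate_outputs={"reasoning": reasoning},
    )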
2 changes: 1 addition & 1 deletion app/desktop/studio_server/test_prompt_api.py
@@ -24,7 +24,7 @@ class MockPromptBuilder(BasePromptBuilder):
def prompt_builder_name(cls):
return "MockPromptBuilder"

def build_prompt(self):
def build_base_prompt(self):
return "Mock prompt"

def build_prompt_for_ui(self):
14 changes: 7 additions & 7 deletions libs/core/kiln_ai/adapters/__init__.py
@@ -3,31 +3,31 @@
Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln.
BaseAdapter is extensible, and used for adding adapters that provide AI functionality. There's currently a LangChain adapter which provides a bridge to LangChain.
Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc.
The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems.
The prompt_builders submodule contains classes that build prompts for use with the AI agents.
The repair submodule contains an adapter for the repair task.
The parser submodule contains parsers for the output of the AI models.
"""

from . import (
base_adapter,
data_gen,
fine_tune,
langchain_adapters,
ml_model_list,
model_adapters,
prompt_builders,
repair,
)

__all__ = [
"base_adapter",
"langchain_adapters",
"model_adapters",
"data_gen",
"fine_tune",
"ml_model_list",
"prompt_builders",
"repair",
"data_gen",
"fine_tune",
]
35 changes: 31 additions & 4 deletions libs/core/kiln_ai/adapters/adapter_registry.py
@@ -1,17 +1,44 @@
from os import getenv

from kiln_ai import datamodel
from kiln_ai.adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.openai_model_adapter import (
OpenAICompatibleAdapter,
OpenAICompatibleConfig,
)
from kiln_ai.adapters.prompt_builders import BasePromptBuilder
from kiln_ai.utils.config import Config


def adapter_for_task(
kiln_task: datamodel.Task,
model_name: str | None = None,
model_name: str,
provider: str | None = None,
prompt_builder: BasePromptBuilder | None = None,
tags: list[str] | None = None,
) -> BaseAdapter:
# We use langchain for everything right now, but can add any others here
if provider == ModelProviderName.openrouter:
return OpenAICompatibleAdapter(
kiln_task=kiln_task,
config=OpenAICompatibleConfig(
base_url=getenv("OPENROUTER_BASE_URL")
or "https://openrouter.ai/api/v1",
api_key=Config.shared().open_router_api_key,
model_name=model_name,
provider_name=provider,
openrouter_style_reasoning=True,
default_headers={
"HTTP-Referer": "https://getkiln.ai/openrouter",
"X-Title": "KilnAI",
},
),
prompt_builder=prompt_builder,
tags=tags,
)

# We use langchain for all others right now, but moving off it as we touch anything.
return LangchainAdapter(
kiln_task,
model_name=model_name,
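In short: requests for ModelProviderName.openrouter now route to the new OpenAICompatibleAdapter, and every other provider still falls through to LangchainAdapter. A hedged usage sketch — the Task construction and model names below are placeholders; only adapter_for_task, ModelProviderName, and the two adapter classes come from this diff:

from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName

# Hypothetical task; a real datamodel.Task takes more fields.
task = datamodel.Task(name="summarize", instruction="Summarize the input.")

# Hits the openrouter branch, so this returns an OpenAICompatibleAdapter.
adapter = adapter_for_task(
    task,
    model_name="deepseek_r1",
    provider=ModelProviderName.openrouter,
)

# Any other provider still falls through to the LangchainAdapter path.
local_adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="ollama")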
47 changes: 32 additions & 15 deletions libs/core/kiln_ai/adapters/ml_model_list.py
@@ -82,6 +82,14 @@ class ModelName(str, Enum):
deepseek_r1 = "deepseek_r1"


class ModelParserID(str, Enum):
"""
Enumeration of supported model parsers.
"""

r1_thinking = "r1_thinking"


class KilnModelProvider(BaseModel):
"""
Configuration for a specific model provider.
@@ -94,6 +102,7 @@ class KilnModelProvider(BaseModel):
provider_finetune_id: The finetune ID for the provider, if applicable
provider_options: Additional provider-specific configuration options
structured_output_mode: The mode we should use to call the model for structured output, if it was trained with structured output.
parser: A parser to use for the model, if applicable
"""

name: ModelProviderName
@@ -103,6 +112,7 @@ class KilnModel(BaseModel):
provider_finetune_id: str | None = None
provider_options: Dict = {}
structured_output_mode: StructuredOutputMode = StructuredOutputMode.default
parser: ModelParserID | None = None


class KilnModel(BaseModel):
@@ -170,6 +180,7 @@ class KilnModel(BaseModel):
providers=[
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.function_calling,
provider_options={"model": "anthropic/claude-3-5-haiku"},
),
],
@@ -182,6 +193,7 @@
providers=[
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.function_calling,
provider_options={"model": "anthropic/claude-3.5-sonnet"},
),
],
@@ -195,6 +207,7 @@
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "deepseek/deepseek-chat"},
structured_output_mode=StructuredOutputMode.function_calling,
),
],
),
@@ -207,21 +220,21 @@
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "deepseek/deepseek-r1"},
structured_output_mode=StructuredOutputMode.json_schema,
# No custom parser -- openrouter implemented it themselves
structured_output_mode=StructuredOutputMode.json_instructions,
),
KilnModelProvider(
name=ModelProviderName.fireworks_ai,
provider_options={"model": "accounts/fireworks/models/deepseek-r1"},
# Truncates the thinking, but json_mode works
structured_output_mode=StructuredOutputMode.json_mode,
supports_structured_output=False,
supports_data_gen=False,
parser=ModelParserID.r1_thinking,
structured_output_mode=StructuredOutputMode.json_instructions,
),
KilnModelProvider(
# I want your RAM
name=ModelProviderName.ollama,
provider_options={"model": "deepseek-r1:671b"},
structured_output_mode=StructuredOutputMode.json_schema,
parser=ModelParserID.r1_thinking,
structured_output_mode=StructuredOutputMode.json_instructions,
),
],
),
@@ -306,7 +319,7 @@
),
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.function_calling,
provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
),
KilnModelProvider(
@@ -340,7 +353,7 @@
),
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.function_calling,
provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
),
KilnModelProvider(
@@ -381,7 +394,7 @@
),
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.function_calling,
provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
),
KilnModelProvider(
@@ -403,6 +416,7 @@
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "mistralai/mistral-nemo"},
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
),
],
),
@@ -442,6 +456,7 @@
name=ModelProviderName.openrouter,
supports_structured_output=False,
supports_data_gen=False,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
provider_options={"model": "meta-llama/llama-3.2-1b-instruct"},
),
KilnModelProvider(
@@ -462,6 +477,7 @@
name=ModelProviderName.openrouter,
supports_structured_output=False,
supports_data_gen=False,
structured_output_mode=StructuredOutputMode.json_schema,
provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
),
KilnModelProvider(
@@ -587,6 +603,7 @@
supports_structured_output=False,
supports_data_gen=False,
provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
structured_output_mode=StructuredOutputMode.json_schema,
),
KilnModelProvider(
name=ModelProviderName.fireworks_ai,
@@ -612,8 +629,7 @@
KilnModelProvider(
name=ModelProviderName.openrouter,
# JSON mode not consistent enough to enable in UI
structured_output_mode=StructuredOutputMode.json_mode,
supports_structured_output=False,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
supports_data_gen=False,
provider_options={"model": "microsoft/phi-4"},
),
@@ -651,7 +667,7 @@
),
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
supports_data_gen=False,
provider_options={"model": "google/gemma-2-9b-it"},
),
@@ -673,6 +689,7 @@
),
KilnModelProvider(
name=ModelProviderName.openrouter,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
supports_data_gen=False,
provider_options={"model": "google/gemma-2-27b-it"},
),
@@ -687,7 +704,7 @@
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
),
KilnModelProvider(
name=ModelProviderName.ollama,
@@ -705,7 +722,7 @@
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "qwen/qwen-2.5-7b-instruct"},
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
),
KilnModelProvider(
name=ModelProviderName.ollama,
@@ -726,7 +743,7 @@
# Not consistent with structure data. Works sometimes but not often
supports_structured_output=False,
supports_data_gen=False,
structured_output_mode=StructuredOutputMode.json_schema,
structured_output_mode=StructuredOutputMode.json_instruction_and_object,
),
KilnModelProvider(
name=ModelProviderName.ollama,
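The pattern across this file: providers that emit raw <think> blocks (Fireworks, Ollama) get parser=ModelParserID.r1_thinking, while OpenRouter strips the reasoning itself. A hypothetical entry for a smaller local R1 tag (deepseek-r1:32b is an assumption; only the 671b Ollama entry appears above) would wire up the same way:

KilnModelProvider(
    name=ModelProviderName.ollama,
    provider_options={"model": "deepseek-r1:32b"},  # hypothetical smaller tag
    parser=ModelParserID.r1_thinking,
    structured_output_mode=StructuredOutputMode.json_instructions,
)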
18 changes: 18 additions & 0 deletions libs/core/kiln_ai/adapters/model_adapters/__init__.py
@@ -0,0 +1,18 @@
"""
# Model Adapters
Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc.
"""

from . import (
base_adapter,
langchain_adapters,
openai_model_adapter,
)

__all__ = [
"base_adapter",
"langchain_adapters",
"openai_model_adapter",
]
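Since OpenAICompatibleAdapter is driven entirely by its config, pointing it at another OpenAI-compatible endpoint should just be a different OpenAICompatibleConfig. A sketch under stated assumptions — the Ollama URL and model tag are guesses; the constructor arguments are the ones visible in the adapter_registry diff above:

from kiln_ai import datamodel
from kiln_ai.adapters.model_adapters.openai_model_adapter import (
    OpenAICompatibleAdapter,
    OpenAICompatibleConfig,
)

task = datamodel.Task(name="chat", instruction="Answer briefly.")  # hypothetical

config = OpenAICompatibleConfig(
    base_url="http://localhost:11434/v1",  # assumption: Ollama's OpenAI-compatible API
    api_key="unused",  # assumption: local endpoints ignore the key
    model_name="deepseek-r1:8b",  # hypothetical model tag
    provider_name="ollama",
    openrouter_style_reasoning=False,
    default_headers={},
)
adapter = OpenAICompatibleAdapter(kiln_task=task, config=config)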
