Skip to content

Commit

Permalink
Move model adapters into their own folder
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Jan 31, 2025
1 parent eb3b75d commit f38450e
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 50 deletions.
14 changes: 7 additions & 7 deletions libs/core/kiln_ai/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,31 @@
Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln.
BaseAdapter is extensible, and used for adding adapters that provide AI functionality. There's currently a LangChain adapter which provides a bridge to LangChain.
Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc.
The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems.
The prompt_builders submodule contains classes that build prompts for use with the AI agents.
The repair submodule contains an adapter for the repair task.
The parser submodule contains parsers for the output of the AI models.
"""

from . import (
base_adapter,
data_gen,
fine_tune,
langchain_adapters,
ml_model_list,
model_adapters,
prompt_builders,
repair,
)

__all__ = [
"base_adapter",
"langchain_adapters",
"model_adapters",
"data_gen",
"fine_tune",
"ml_model_list",
"prompt_builders",
"repair",
"data_gen",
"fine_tune",
]
6 changes: 3 additions & 3 deletions libs/core/kiln_ai/adapters/adapter_registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os import getenv

from kiln_ai import datamodel
from kiln_ai.adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.model_adapters.open_ai_model_adapter import (
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.openai_model_adapter import (
OpenAICompatibleAdapter,
)
from kiln_ai.adapters.prompt_builders import BasePromptBuilder
Expand Down
18 changes: 18 additions & 0 deletions libs/core/kiln_ai/adapters/model_adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
# Model Adapters
Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc.
"""

from . import (
base_adapter,
langchain_adapters,
openai_model_adapter,
)

__all__ = [
"base_adapter",
"langchain_adapters",
"openai_model_adapter",
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from dataclasses import dataclass
from typing import Dict

from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
from kiln_ai.adapters.prompt_builders import BasePromptBuilder, SimplePromptBuilder
from kiln_ai.adapters.provider_tools import kiln_model_provider_from
from kiln_ai.adapters.run_output import RunOutput
from kiln_ai.datamodel import (
Expand All @@ -15,10 +18,6 @@
from kiln_ai.datamodel.json_schema import validate_schema
from kiln_ai.utils.config import Config

from .ml_model_list import KilnModelProvider, StructuredOutputMode
from .parsers.parser_registry import model_parser_from_id
from .prompt_builders import BasePromptBuilder, SimplePromptBuilder


@dataclass
class AdapterInfo:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,25 @@
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import (
KilnModelProvider,
ModelProviderName,
StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
AdapterInfo,
BaseAdapter,
BasePromptBuilder,
RunOutput,
)
from kiln_ai.adapters.ollama_tools import (
get_ollama_connection,
ollama_base_url,
ollama_model_installed,
)
from kiln_ai.adapters.provider_tools import kiln_model_provider_from
from kiln_ai.utils.config import Config

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import (
KilnModelProvider,
ModelProviderName,
StructuredOutputMode,
)
from .provider_tools import kiln_model_provider_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Any, Dict, NoReturn

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.base_adapter import (
from kiln_ai.adapters.ml_model_list import StructuredOutputMode
from kiln_ai.adapters.model_adapters.base_adapter import (
AdapterInfo,
BaseAdapter,
BasePromptBuilder,
RunOutput,
)
from kiln_ai.adapters.ml_model_list import StructuredOutputMode
from kiln_ai.adapters.parsers.json_parser import parse_json_string
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

from kiln_ai.adapters.langchain_adapters import (
LangchainAdapter,
get_structured_output_options,
langchain_model_from_provider,
)
from kiln_ai.adapters.ml_model_list import (
KilnModelProvider,
ModelProviderName,
StructuredOutputMode,
)
from kiln_ai.adapters.model_adapters.langchain_adapters import (
LangchainAdapter,
get_structured_output_options,
langchain_model_from_provider,
)
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
from kiln_ai.adapters.test_prompt_adaptors import build_test_task

Expand Down Expand Up @@ -94,7 +94,8 @@ async def test_langchain_adapter_with_cot(tmp_path):
# Patch both the langchain_model_from function and self.model()
with (
patch(
"kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
"kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from",
mock_model_from,
),
patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
):
Expand Down Expand Up @@ -151,7 +152,7 @@ async def test_get_structured_output_options(structured_output_mode, expected_me

# Test with provider that has options
with patch(
"kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
"kiln_ai.adapters.model_adapters.langchain_adapters.kiln_model_provider_from",
AsyncMock(return_value=mock_provider),
):
options = await get_structured_output_options("model_name", "provider")
Expand All @@ -164,7 +165,9 @@ async def test_langchain_model_from_provider_openai():
name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
)

with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
) as mock_config:
mock_config.return_value.open_ai_api_key = "test_key"
model = await langchain_model_from_provider(provider, "gpt-4")
assert isinstance(model, ChatOpenAI)
Expand All @@ -177,7 +180,9 @@ async def test_langchain_model_from_provider_groq():
name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
)

with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
) as mock_config:
mock_config.return_value.groq_api_key = "test_key"
model = await langchain_model_from_provider(provider, "mixtral-8x7b")
assert isinstance(model, ChatGroq)
Expand All @@ -191,7 +196,9 @@ async def test_langchain_model_from_provider_bedrock():
provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
)

with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
) as mock_config:
mock_config.return_value.bedrock_access_key = "test_access"
mock_config.return_value.bedrock_secret_key = "test_secret"
model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
Expand All @@ -206,7 +213,9 @@ async def test_langchain_model_from_provider_fireworks():
name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
)

with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
) as mock_config:
mock_config.return_value.fireworks_api_key = "test_key"
model = await langchain_model_from_provider(provider, "mixtral-8x7b")
assert isinstance(model, ChatFireworks)
Expand All @@ -222,15 +231,15 @@ async def test_langchain_model_from_provider_ollama():
mock_connection = MagicMock()
with (
patch(
"kiln_ai.adapters.langchain_adapters.get_ollama_connection",
"kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection",
return_value=AsyncMock(return_value=mock_connection),
),
patch(
"kiln_ai.adapters.langchain_adapters.ollama_model_installed",
"kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed",
return_value=True,
),
patch(
"kiln_ai.adapters.langchain_adapters.ollama_base_url",
"kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url",
return_value="http://localhost:11434",
),
):
Expand Down Expand Up @@ -281,16 +290,16 @@ async def test_langchain_adapter_model_structured_output(tmp_path):
mock_model.with_structured_output = MagicMock(return_value="structured_model")

adapter = LangchainAdapter(
kiln_task=task, model_name="test_model", provider="test_provider"
kiln_task=task, model_name="test_model", provider="ollama"
)

with (
patch(
"kiln_ai.adapters.langchain_adapters.langchain_model_from",
"kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from",
AsyncMock(return_value=mock_model),
),
patch(
"kiln_ai.adapters.langchain_adapters.get_structured_output_options",
"kiln_ai.adapters.model_adapters.langchain_adapters.get_structured_output_options",
AsyncMock(return_value={"option1": "value1"}),
),
):
Expand Down Expand Up @@ -322,11 +331,11 @@ async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
del mock_model.with_structured_output

adapter = LangchainAdapter(
kiln_task=task, model_name="test_model", provider="test_provider"
kiln_task=task, model_name="test_model", provider="ollama"
)

with patch(
"kiln_ai.adapters.langchain_adapters.langchain_model_from",
"kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from",
AsyncMock(return_value=mock_model),
):
with pytest.raises(ValueError, match="does not support structured output"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import pytest

from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
from kiln_ai.adapters.model_adapters.base_adapter import (
AdapterInfo,
BaseAdapter,
RunOutput,
)
from kiln_ai.datamodel import (
DataSource,
DataSourceType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
from kiln_ai.adapters.ml_model_list import (
built_in_models,
)
from kiln_ai.adapters.model_adapters.base_adapter import (
AdapterInfo,
BaseAdapter,
RunOutput,
)
from kiln_ai.adapters.ollama_tools import ollama_online
from kiln_ai.adapters.prompt_builders import (
BasePromptBuilder,
Expand Down
Empty file.
4 changes: 2 additions & 2 deletions libs/core/kiln_ai/adapters/repair/test_repair_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pydantic import ValidationError

from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.base_adapter import RunOutput
from kiln_ai.adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.repair.repair_task import (
RepairTaskInput,
RepairTaskRun,
Expand Down
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/test_prompt_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.ml_model_list import built_in_models
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.ollama_tools import ollama_online
from kiln_ai.adapters.prompt_builders import (
BasePromptBuilder,
Expand Down
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/test_prompt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
from kiln_ai.adapters.prompt_builders import (
FewShotChainOfThoughtPromptBuilder,
FewShotPromptBuilder,
Expand Down
2 changes: 1 addition & 1 deletion libs/server/kiln_server/test_run_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from kiln_ai.adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.datamodel import (
DataSource,
DataSourceType,
Expand Down

0 comments on commit f38450e

Please sign in to comment.