Skip to content

Commit

Permalink
Make custom models and fine-tunes work in our new adapter registry. N…
Browse files Browse the repository at this point in the history
…ow they use the correct adapter (OpenAI for OpenAI fine-tunes or custom models, etc)

Improve fine-tune cache.
  • Loading branch information
scosman committed Feb 1, 2025
1 parent 8b6e9ed commit 55c6c2f
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 27 deletions.
6 changes: 5 additions & 1 deletion libs/core/kiln_ai/adapters/adapter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
OpenAICompatibleConfig,
)
from kiln_ai.adapters.prompt_builders import BasePromptBuilder
from kiln_ai.adapters.provider_tools import core_provider
from kiln_ai.utils.config import Config


Expand All @@ -20,7 +21,10 @@ def adapter_for_task(
prompt_builder: BasePromptBuilder | None = None,
tags: list[str] | None = None,
) -> BaseAdapter:
match provider:
# Get the provider to run. For things like the fine-tune provider, we want to run the underlying provider
core_provider_name = core_provider(model_name, provider)

match core_provider_name:
case ModelProviderName.openrouter:
return OpenAICompatibleAdapter(
kiln_task=kiln_task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,21 @@ async def response_format_options(self) -> dict[str, Any]:
return NoReturn

def tool_call_params(self) -> dict[str, Any]:
# Add additional_properties: false to the schema (OpenAI requires this for some models)
output_schema = self.kiln_task.output_schema()
if not isinstance(output_schema, dict):
raise ValueError(
"Invalid output schema for this task. Can not use tool calls."
)
output_schema["additionalProperties"] = False

return {
"tools": [
{
"type": "function",
"function": {
"name": "task_response",
"parameters": self.kiln_task.output_schema(),
"parameters": output_schema,
"strict": True,
},
}
Expand Down
77 changes: 65 additions & 12 deletions libs/core/kiln_ai/adapters/provider_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
KilnModelProvider,
ModelName,
ModelProviderName,
StructuredOutputMode,
built_in_models,
)
from kiln_ai.adapters.ollama_tools import (
get_ollama_connection,
)
from kiln_ai.datamodel import Finetune, Task
from kiln_ai.datamodel.registry import project_from_id

from ..utils.config import Config
from kiln_ai.utils.config import Config


async def provider_enabled(provider_name: ModelProviderName) -> bool:
Expand Down Expand Up @@ -102,6 +102,46 @@ async def builtin_model_from(
return provider


def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName:
    """
    Resolve the provider that should actually execute a request.

    Some provider IDs are wrappers (custom-registry models, fine-tunes);
    this maps them onto a concrete, runnable provider (openai, ollama, etc).
    """

    if provider_name is ModelProviderName.kiln_custom_registry:
        # Custom-registry model IDs embed the real provider in the ID itself.
        underlying_provider, _model = parse_custom_model_id(model_id)
        return underlying_provider

    if provider_name is ModelProviderName.kiln_fine_tune:
        # Fine-tunes record which provider they were trained on; validate it
        # still names a known provider before converting.
        finetune = finetune_from_id(model_id)
        if finetune.provider not in ModelProviderName.__members__:
            raise ValueError(
                f"Finetune {model_id} has no underlying provider {finetune.provider}"
            )
        return ModelProviderName(finetune.provider)

    # Already a concrete, runnable provider.
    return provider_name


def parse_custom_model_id(
    model_id: str,
) -> tuple[ModelProviderName, str]:
    """
    Parse a custom-registry model ID of the form "<provider>::<model_name>".

    Returns:
        A (provider, model_name) tuple.

    Raises:
        ValueError: if the ID lacks a "::" separator, or the provider part
            is not a known ModelProviderName member.
    """
    if "::" not in model_id:
        raise ValueError(f"Invalid custom model ID: {model_id}")

    # Split only once: the provider comes first, and the model name is
    # allowed to contain further "::" sequences.
    provider_name, model_name = model_id.split("::", 1)

    if provider_name not in ModelProviderName.__members__:
        raise ValueError(f"Invalid provider name: {provider_name}")

    return ModelProviderName(provider_name), model_name


async def kiln_model_provider_from(
name: str, provider_name: str | None = None
) -> KilnModelProvider:
Expand All @@ -117,8 +157,7 @@ async def kiln_model_provider_from(

# For custom registry, get the provider name and model name from the model id
if provider_name == ModelProviderName.kiln_custom_registry:
provider_name = name.split("::", 1)[0]
name = name.split("::", 1)[1]
provider_name, name = parse_custom_model_id(name)

# Custom/untested model. Set untested, and build a ModelProvider at runtime
if provider_name is None:
Expand All @@ -136,9 +175,6 @@ async def kiln_model_provider_from(
)


finetune_cache: dict[str, KilnModelProvider] = {}


def openai_compatible_provider_model(
model_id: str,
) -> KilnModelProvider:
Expand Down Expand Up @@ -178,9 +214,10 @@ def openai_compatible_provider_model(
)


def finetune_provider_model(
model_id: str,
) -> KilnModelProvider:
finetune_cache: dict[str, Finetune] = {}


def finetune_from_id(model_id: str) -> Finetune:
if model_id in finetune_cache:
return finetune_cache[model_id]

Expand All @@ -202,6 +239,15 @@ def finetune_provider_model(
f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
)

finetune_cache[model_id] = fine_tune
return fine_tune


def finetune_provider_model(
model_id: str,
) -> KilnModelProvider:
fine_tune = finetune_from_id(model_id)

provider = ModelProviderName[fine_tune.provider]
model_provider = KilnModelProvider(
name=provider,
Expand All @@ -210,11 +256,18 @@ def finetune_provider_model(
},
)

# If we know the model was trained with specific output mode, set it
if fine_tune.structured_output_mode is not None:
# If we know the model was trained with specific output mode, set it
model_provider.structured_output_mode = fine_tune.structured_output_mode
else:
# Some early adopters won't have structured_output_mode set on their fine-tunes.
# We know that OpenAI uses json_schema, and Fireworks (only other provider) use json_mode.
# This can be removed in the future
if provider == ModelProviderName.openai:
model_provider.structured_output_mode = StructuredOutputMode.json_schema
else:
model_provider.structured_output_mode = StructuredOutputMode.json_mode

finetune_cache[model_id] = model_provider
return model_provider


Expand Down
Loading

0 comments on commit 55c6c2f

Please sign in to comment.