Skip to content

Commit

Permalink
CR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Jan 31, 2025
1 parent f91a594 commit 1965f89
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 39 deletions.
23 changes: 13 additions & 10 deletions libs/core/kiln_ai/adapters/adapter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.adapters.model_adapters.openai_model_adapter import (
OpenAICompatibleAdapter,
OpenAICompatibleConfig,
)
from kiln_ai.adapters.prompt_builders import BasePromptBuilder
from kiln_ai.utils.config import Config
Expand All @@ -19,20 +20,22 @@ def adapter_for_task(
tags: list[str] | None = None,
) -> BaseAdapter:
if provider == ModelProviderName.openrouter:
api_key = Config.shared().open_router_api_key
base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
return OpenAICompatibleAdapter(
base_url=base_url,
api_key=api_key,
kiln_task=kiln_task,
model_name=model_name,
provider_name=provider,
config=OpenAICompatibleConfig(
base_url=getenv("OPENROUTER_BASE_URL")
or "https://openrouter.ai/api/v1",
api_key=Config.shared().open_router_api_key,
model_name=model_name,
provider_name=provider,
openrouter_style_reasoning=True,
default_headers={
"HTTP-Referer": "https://getkiln.ai/openrouter",
"X-Title": "KilnAI",
},
),
prompt_builder=prompt_builder,
tags=tags,
default_headers={
"HTTP-Referer": "https://getkiln.ai/openrouter",
"X-Title": "KilnAI",
},
)

# We use langchain for all others right now, but moving off it as we touch anything.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from os import getenv
from typing import Any, Dict, NoReturn

from langchain_aws import ChatBedrockConverse
Expand Down
58 changes: 30 additions & 28 deletions libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Any, Dict, NoReturn

from openai import AsyncOpenAI
Expand All @@ -7,7 +8,6 @@
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion import Choice

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.ml_model_list import StructuredOutputMode
Expand All @@ -20,32 +20,35 @@
from kiln_ai.adapters.parsers.json_parser import parse_json_string


@dataclass
class OpenAICompatibleConfig:
    """Connection settings for an OpenAI-compatible chat-completions endpoint.

    Bundles everything the adapter needs to build its client and identify the
    model, so the adapter constructor takes one config object instead of a
    long parameter list.
    """

    api_key: str  # Bearer key passed to the AsyncOpenAI client
    model_name: str  # Kiln model name, forwarded to BaseAdapter
    provider_name: str  # Kiln provider name (e.g. ModelProviderName.openrouter value)
    base_url: str | None = None  # None -> client defaults to the official OpenAI endpoint
    default_headers: dict[str, str] | None = None  # Extra headers sent with every request
    # When True, request reasoning via extra_body {"include_reasoning": True}
    # and capture message.reasoning from the response (OpenRouter extension).
    openrouter_style_reasoning: bool = False


class OpenAICompatibleAdapter(BaseAdapter):
    def __init__(
        self,
        config: OpenAICompatibleConfig,
        kiln_task: datamodel.Task,
        prompt_builder: BasePromptBuilder | None = None,
        tags: list[str] | None = None,
    ):
        """Adapter for any OpenAI-compatible chat-completions API.

        NOTE(review): the scraped diff interleaved the pre- and post-commit
        parameter lists (duplicate `api_key=`/`model_name=` keyword args made
        the literal text a SyntaxError); this is the reconstructed post-commit
        version, where all connection details live in `config`.

        Args:
            config: Connection/model settings for the endpoint (API key,
                base URL, model and provider names, headers, reasoning flag).
            kiln_task: The Kiln task this adapter runs.
            prompt_builder: Optional prompt builder; BaseAdapter's default is
                used when None.
            tags: Optional tags attached to runs produced by this adapter.
        """
        self.config = config
        # Async client; base_url=None falls back to the official OpenAI endpoint.
        self.client = AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.base_url,
            default_headers=config.default_headers,
        )

        super().__init__(
            kiln_task,
            model_name=config.model_name,
            model_provider_name=config.provider_name,
            prompt_builder=prompt_builder,
            tags=tags,
        )
Expand Down Expand Up @@ -95,15 +98,15 @@ async def _run(self, input: Dict | str) -> RunOutput:

# Main completion call
response_format_options = await self.response_format_options()
print(f"response_format_options: {response_format_options}")
response = await self.client.chat.completions.create(
model=provider.provider_options["model"],
messages=messages,
# TODO P0: remove this
extra_body={"include_reasoning": True},
extra_body={"include_reasoning": True}
if self.config.openrouter_style_reasoning
else {},
**response_format_options,
)
print(f"response: {response}")

if not isinstance(response, ChatCompletion):
raise RuntimeError(
f"Expected ChatCompletion response, got {type(response)}."
Expand All @@ -121,7 +124,11 @@ async def _run(self, input: Dict | str) -> RunOutput:
message = response.choices[0].message

# Save reasoning if it exists
if hasattr(message, "reasoning") and message.reasoning: # pyright: ignore
if (
self.config.openrouter_style_reasoning
and hasattr(message, "reasoning")
and message.reasoning # pyright: ignore
):
intermediate_outputs["reasoning"] = message.reasoning # pyright: ignore

# the string content of the response
Expand Down Expand Up @@ -170,14 +177,11 @@ async def response_format_options(self) -> dict[str, Any]:
return {}

provider = await self.model_provider()
# TODO check these
match provider.structured_output_mode:
case StructuredOutputMode.json_mode:
return {"response_format": {"type": "json_object"}}
case StructuredOutputMode.json_schema:
# TODO P0: use json_schema
output_schema = self.kiln_task.output_schema()
print(f"output_schema: {output_schema}")
return {
"response_format": {
"type": "json_schema",
Expand All @@ -188,17 +192,15 @@ async def response_format_options(self) -> dict[str, Any]:
}
}
case StructuredOutputMode.function_calling:
# TODO P0
return self.tool_call_params()
case StructuredOutputMode.json_instructions:
# JSON done via instructions in prompt, not the API response format
# TODO try json_object on stable API
# JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below).
return {}
case StructuredOutputMode.json_instruction_and_object:
# We set response_format to json_object, but we also need to set the instructions in the prompt
# We set response_format to json_object and also set json instructions in the prompt
return {"response_format": {"type": "json_object"}}
case StructuredOutputMode.default:
# Default to function calling -- it's older than the other modes
# Default to function calling -- it's older than the other modes. Higher compatibility.
return self.tool_call_params()
case _:
raise ValueError(
Expand Down

0 comments on commit 1965f89

Please sign in to comment.