Skip to content

Commit

Permalink
chore: use primitives instead of typing imports and fixes completion … (
Browse files Browse the repository at this point in the history
#149)

Signed-off-by: Adrian Cole <adrian.cole@elastic.co>
  • Loading branch information
codefromthecrypt authored Oct 15, 2024
1 parent e687b0b commit c247c8e
Show file tree
Hide file tree
Showing 53 changed files with 235 additions and 257 deletions.
3 changes: 1 addition & 2 deletions packages/exchange/src/exchange/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from copy import deepcopy
from typing import List
from attrs import define, field


Expand Down Expand Up @@ -31,7 +30,7 @@ class CheckpointData:
total_token_count: int = field(default=0)

# in order list of individual checkpoints in the exchange
checkpoints: List[Checkpoint] = field(factory=list)
checkpoints: list[Checkpoint] = field(factory=list)

# the offset to apply to the message index when calculating the last message index
# this is useful because messages on the exchange behave like a queue, where you can only
Expand Down
8 changes: 4 additions & 4 deletions packages/exchange/src/exchange/content.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Optional

from attrs import define, asdict

Expand All @@ -7,11 +7,11 @@


class Content:
def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None:
def __init_subclass__(cls, **kwargs: dict[str, any]) -> None:
super().__init_subclass__(**kwargs)
CONTENT_TYPES[cls.__name__] = cls

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, any]:
data = asdict(self, recurse=True)
data["type"] = self.__class__.__name__
return data
Expand All @@ -26,7 +26,7 @@ class Text(Content):
class ToolUse(Content):
id: str
name: str
parameters: Any
parameters: any
is_error: bool = False
error_message: Optional[str] = None

Expand Down
16 changes: 8 additions & 8 deletions packages/exchange/src/exchange/exchange.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
import traceback
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Tuple

from typing import Mapping
from attrs import define, evolve, field, Factory
from tiktoken import get_encoding

Expand Down Expand Up @@ -41,16 +40,16 @@ class Exchange:
model: str
system: str
moderator: Moderator = field(default=ContextTruncate())
tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
messages: List[Message] = field(factory=list)
tools: tuple[Tool, ...] = field(factory=tuple, converter=tuple)
messages: list[Message] = field(factory=list)
checkpoint_data: CheckpointData = field(factory=CheckpointData)
generation_args: dict = field(default=Factory(dict))

@property
def _toolmap(self) -> Mapping[str, Tool]:
return {tool.name: tool for tool in self.tools}

def replace(self, **kwargs: Dict[str, Any]) -> "Exchange":
def replace(self, **kwargs: dict[str, any]) -> "Exchange":
"""Make a copy of the exchange, replacing any passed arguments"""
# TODO: ensure that the checkpoint data is updated correctly. aka,
# if we replace the messages, we need to update the checkpoint data
Expand Down Expand Up @@ -264,7 +263,7 @@ def pop_first_message(self) -> Message:
# we've removed all the checkpoints, so we need to reset the message index offset
self.checkpoint_data.message_index_offset = 0

def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
def pop_last_checkpoint(self) -> tuple[Checkpoint, list[Message]]:
"""
Reverts the exchange back to the last checkpoint, removing associated messages
"""
Expand All @@ -275,7 +274,7 @@ def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
messages.append(self.messages.pop())
return removed_checkpoint, messages

def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
def pop_first_checkpoint(self) -> tuple[Checkpoint, list[Message]]:
"""
Pop the first checkpoint from the exchange, removing associated messages
"""
Expand Down Expand Up @@ -332,5 +331,6 @@ def is_allowed_to_call_llm(self) -> bool:
# this to be a required method of the provider instead.
return len(self.messages) > 0 and self.messages[-1].role == "user"

def get_token_usage(self) -> Dict[str, Usage]:
@staticmethod
def get_token_usage() -> dict[str, Usage]:
return _token_usage_collector.get_token_usage_group_by_model()
5 changes: 1 addition & 4 deletions packages/exchange/src/exchange/invalid_choice_error.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import List


class InvalidChoiceError(Exception):
def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
def __init__(self, attribute_name: str, attribute_value: str, available_values: list[str]) -> None:
self.attribute_name = attribute_name
self.attribute_value = attribute_value
self.available_values = available_values
Expand Down
22 changes: 11 additions & 11 deletions packages/exchange/src/exchange/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import time
from pathlib import Path
from typing import Any, Dict, List, Literal, Type
from typing import Literal

from attrs import define, field
from jinja2 import Environment, FileSystemLoader
Expand All @@ -12,7 +12,7 @@
Role = Literal["user", "assistant"]


def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401
def validate_role_and_content(instance: "Message", *_: any) -> None: # noqa: ANN401
if instance.role == "user":
if not (instance.text or instance.tool_result):
raise ValueError("User message must include a Text or ToolResult")
Expand All @@ -25,7 +25,7 @@ def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: AN
raise ValueError("Assistant message does not support ToolResult")


def content_converter(contents: List[Dict[str, Any]]) -> List[Content]:
def content_converter(contents: list[dict[str, any]]) -> list[Content]:
return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents]


Expand All @@ -48,9 +48,9 @@ class Message:
role: Role = field(default="user")
id: str = field(factory=lambda: str(create_object_id(prefix="msg")))
created: int = field(factory=lambda: int(time.time()))
content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)
content: list[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, any]:
return {
"role": self.role,
"id": self.id,
Expand All @@ -68,7 +68,7 @@ def text(self) -> str:
return "\n".join(result)

@property
def tool_use(self) -> List[ToolUse]:
def tool_use(self) -> list[ToolUse]:
"""All tool use content of this message."""
result = []
for content in self.content:
Expand All @@ -77,7 +77,7 @@ def tool_use(self) -> List[ToolUse]:
return result

@property
def tool_result(self) -> List[ToolResult]:
def tool_result(self) -> list[ToolResult]:
"""All tool result content of this message."""
result = []
for content in self.content:
Expand All @@ -87,10 +87,10 @@ def tool_result(self) -> List[ToolResult]:

@classmethod
def load(
cls: Type["Message"],
cls: type["Message"],
filename: str,
role: Role = "user",
**kwargs: Dict[str, Any],
**kwargs: dict[str, any],
) -> "Message":
"""Load the message from filename relative to where the load is called.
Expand All @@ -113,9 +113,9 @@ def load(
return cls(role=role, content=[Text(text=rendered_content)])

@classmethod
def user(cls: Type["Message"], text: str) -> "Message":
def user(cls: type["Message"], text: str) -> "Message":
return cls(role="user", content=[Text(text)])

@classmethod
def assistant(cls: Type["Message"], text: str) -> "Message":
def assistant(cls: type["Message"], text: str) -> "Message":
return cls(role="assistant", content=[Text(text)])
3 changes: 1 addition & 2 deletions packages/exchange/src/exchange/moderators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import cache
from typing import Type

from exchange.invalid_choice_error import InvalidChoiceError
from exchange.moderators.base import Moderator
Expand All @@ -10,7 +9,7 @@


@cache
def get_moderator(name: str) -> Type[Moderator]:
def get_moderator(name: str) -> type[Moderator]:
moderators = load_plugins(group="exchange.moderator")
if name not in moderators:
raise InvalidChoiceError("moderator", name, moderators.keys())
Expand Down
3 changes: 1 addition & 2 deletions packages/exchange/src/exchange/moderators/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Type


class Moderator(ABC):
@abstractmethod
def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
pass
3 changes: 1 addition & 2 deletions packages/exchange/src/exchange/moderators/passive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Type
from exchange.moderators.base import Moderator


class PassiveModerator(Moderator):
def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
def rewrite(self, _: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
pass
4 changes: 1 addition & 3 deletions packages/exchange/src/exchange/moderators/summarizer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Type

from exchange import Message
from exchange.checkpoint import CheckpointData
from exchange.moderators import ContextTruncate, PassiveModerator


class ContextSummarizer(ContextTruncate):
def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
"""Summarize the context history up to the last few messages in the exchange"""

self._update_system_prompt_token_count(exchange)
Expand Down
4 changes: 2 additions & 2 deletions packages/exchange/src/exchange/moderators/truncate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from exchange.checkpoint import CheckpointData
from exchange.message import Message
Expand Down Expand Up @@ -62,7 +62,7 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None:
exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
exchange.checkpoint_data.total_token_count += self.system_prompt_token_count

def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
def _get_messages_to_remove(self, exchange: Exchange) -> list[Message]:
# this keeps all the messages/checkpoints
throwaway_exchange = exchange.replace(
moderator=PassiveModerator(),
Expand Down
3 changes: 1 addition & 2 deletions packages/exchange/src/exchange/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import cache
from typing import Type

from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers.anthropic import AnthropicProvider # noqa
Expand All @@ -15,7 +14,7 @@


@cache
def get_provider(name: str) -> Type[Provider]:
def get_provider(name: str) -> type[Provider]:
providers = load_plugins(group="exchange.provider")
if name not in providers:
raise InvalidChoiceError("provider", name, providers.keys())
Expand Down
21 changes: 11 additions & 10 deletions packages/exchange/src/exchange/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from typing import Any, Dict, List, Tuple, Type

import httpx

Expand Down Expand Up @@ -29,7 +28,7 @@ def __init__(self, client: httpx.Client) -> None:
self.client = client

@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider":
cls.check_env_vars()
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
key = os.environ.get("ANTHROPIC_API_KEY")
Expand All @@ -45,7 +44,7 @@ def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
return cls(client)

@staticmethod
def get_usage(data: Dict) -> Usage: # noqa: ANN401
def get_usage(data: dict) -> Usage: # noqa: ANN401
usage = data.get("usage")
input_tokens = usage.get("input_tokens")
output_tokens = usage.get("output_tokens")
Expand All @@ -61,7 +60,7 @@ def get_usage(data: Dict) -> Usage: # noqa: ANN401
)

@staticmethod
def anthropic_response_to_message(response: Dict) -> Message:
def anthropic_response_to_message(response: dict) -> Message:
content_blocks = response.get("content", [])
content = []
for block in content_blocks:
Expand All @@ -78,7 +77,7 @@ def anthropic_response_to_message(response: Dict) -> Message:
return Message(role="assistant", content=content)

@staticmethod
def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]:
def tools_to_anthropic_spec(tools: tuple[Tool, ...]) -> list[dict[str, any]]:
return [
{
"name": tool.name,
Expand All @@ -89,7 +88,7 @@ def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]:
]

@staticmethod
def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]:
def messages_to_anthropic_spec(messages: list[Message]) -> list[dict[str, any]]:
messages_spec = []
# if messages is empty - just make a default
for message in messages:
Expand Down Expand Up @@ -127,10 +126,12 @@ def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: List[Tool] = [],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: list[Tool] = None,
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
if tools is None:
tools = []
tools_set = set()
unique_tools = []
for tool in tools:
Expand Down
4 changes: 1 addition & 3 deletions packages/exchange/src/exchange/providers/azure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Type

import httpx
import os

Expand All @@ -21,7 +19,7 @@ def __init__(self, client: httpx.Client) -> None:
super().__init__(client)

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
def from_env(cls: type["AzureProvider"]) -> "AzureProvider":
cls.check_env_vars()
url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME")
deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")
Expand Down
13 changes: 7 additions & 6 deletions packages/exchange/src/exchange/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Optional, Tuple, Type
from typing import Optional

from exchange.message import Message
from exchange.tool import Tool
Expand All @@ -19,11 +19,11 @@ class Provider(ABC):
REQUIRED_ENV_VARS: list[str] = []

@classmethod
def from_env(cls: Type["Provider"]) -> "Provider":
def from_env(cls: type["Provider"]) -> "Provider":
return cls()

@classmethod
def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None:
def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
for env_var in cls.REQUIRED_ENV_VARS:
if env_var not in os.environ:
raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url)
Expand All @@ -33,9 +33,10 @@ def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass

Expand Down
Loading

0 comments on commit c247c8e

Please sign in to comment.