From c247c8eb307addb1803e2a579055ac1799b7dac9 Mon Sep 17 00:00:00 2001 From: Adrian Cole <64215+codefromthecrypt@users.noreply.github.com> Date: Wed, 16 Oct 2024 09:41:37 +1100 Subject: [PATCH] =?UTF-8?q?chore:=20use=20primitives=20instead=20of=20typi?= =?UTF-8?q?ng=20imports=20and=20fixes=20completion=20=E2=80=A6=20(#149)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrian Cole --- packages/exchange/src/exchange/checkpoint.py | 3 +- packages/exchange/src/exchange/content.py | 8 ++--- packages/exchange/src/exchange/exchange.py | 16 +++++----- .../src/exchange/invalid_choice_error.py | 5 +-- packages/exchange/src/exchange/message.py | 22 ++++++------- .../src/exchange/moderators/__init__.py | 3 +- .../exchange/src/exchange/moderators/base.py | 3 +- .../src/exchange/moderators/passive.py | 3 +- .../src/exchange/moderators/summarizer.py | 4 +-- .../src/exchange/moderators/truncate.py | 4 +-- .../src/exchange/providers/__init__.py | 3 +- .../src/exchange/providers/anthropic.py | 21 ++++++------ .../exchange/src/exchange/providers/azure.py | 4 +-- .../exchange/src/exchange/providers/base.py | 13 ++++---- .../src/exchange/providers/bedrock.py | 32 +++++++++---------- .../src/exchange/providers/databricks.py | 12 +++---- .../exchange/src/exchange/providers/google.py | 21 ++++++------ .../exchange/src/exchange/providers/groq.py | 11 +++---- .../exchange/src/exchange/providers/ollama.py | 3 +- .../exchange/src/exchange/providers/openai.py | 11 +++---- .../exchange/src/exchange/providers/utils.py | 10 +++--- .../src/exchange/token_usage_collector.py | 3 +- packages/exchange/src/exchange/tool.py | 9 +++--- packages/exchange/src/exchange/utils.py | 14 ++++---- packages/exchange/tests/providers/conftest.py | 7 ++-- .../exchange/tests/providers/test_azure.py | 8 ++--- .../exchange/tests/providers/test_bedrock.py | 6 ++-- .../tests/providers/test_databricks.py | 4 +-- .../tests/providers/test_provider_utils.py | 10 +++--- packages/exchange/tests/test_exchange.py | 31 ++++++++++-------- .../exchange/tests/test_exchange_frozen.py | 2 +- packages/exchange/tests/test_summarizer.py | 10 +++--- packages/exchange/tests/test_truncate.py | 2 +- src/goose/cli/config.py | 12 +++---- src/goose/cli/prompt/completer.py | 9 +++--- src/goose/cli/prompt/lexer.py | 6 ++-- .../cli/prompt/overwrite_session_prompt.py | 4 +-- src/goose/cli/session.py | 4 +-- src/goose/command/__init__.py | 3 +- src/goose/command/base.py | 4 +-- src/goose/command/file.py | 3 +- src/goose/profile.py | 10 +++--- src/goose/toolkit/base.py | 6 ++-- src/goose/toolkit/developer.py | 9 +++--- .../toolkit/repo_context/repo_context.py | 7 ++-- src/goose/toolkit/repo_context/utils.py | 7 ++-- .../summarization/summarize_project.py | 6 ++-- .../toolkit/summarization/summarize_repo.py | 6 ++-- src/goose/toolkit/summarization/utils.py | 30 ++++++++--------- src/goose/toolkit/utils.py | 4 +-- src/goose/utils/__init__.py | 10 +++--- src/goose/utils/file_utils.py | 20 ++++++------ src/goose/utils/session_file.py | 14 ++++---- 53 files changed, 235 insertions(+), 257 deletions(-) diff --git a/packages/exchange/src/exchange/checkpoint.py b/packages/exchange/src/exchange/checkpoint.py index f355dd0a2..063ef35d1 100644 --- a/packages/exchange/src/exchange/checkpoint.py +++ b/packages/exchange/src/exchange/checkpoint.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import List from attrs import define, field @@ -31,7 +30,7 @@ class CheckpointData: total_token_count: int = field(default=0) # in order list of individual checkpoints in the exchange - checkpoints: List[Checkpoint] = field(factory=list) + checkpoints: list[Checkpoint] = field(factory=list) # the offset to apply to the message index when calculating the last message index # this is useful because messages on the exchange behave like a queue, where you can only diff --git a/packages/exchange/src/exchange/content.py b/packages/exchange/src/exchange/content.py index b9cc986fc..66957b7c6 100644 --- a/packages/exchange/src/exchange/content.py +++ b/packages/exchange/src/exchange/content.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Optional from attrs import define, asdict @@ -7,11 +7,11 @@ class Content: - def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None: + def __init_subclass__(cls, **kwargs: dict[str, any]) -> None: super().__init_subclass__(**kwargs) CONTENT_TYPES[cls.__name__] = cls - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, any]: data = asdict(self, recurse=True) data["type"] = self.__class__.__name__ return data @@ -26,7 +26,7 @@ class Text(Content): class ToolUse(Content): id: str name: str - parameters: Any + parameters: any is_error: bool = False error_message: Optional[str] = None diff --git a/packages/exchange/src/exchange/exchange.py b/packages/exchange/src/exchange/exchange.py index b2fdbc5ec..05eda0df1 100644 --- a/packages/exchange/src/exchange/exchange.py +++ b/packages/exchange/src/exchange/exchange.py @@ -1,8 +1,7 @@ import json import traceback from copy import deepcopy -from typing import Any, Dict, List, Mapping, Tuple - +from typing import Mapping from attrs import define, evolve, field, Factory from tiktoken import get_encoding @@ -41,8 +40,8 @@ class Exchange: model: str system: str moderator: Moderator = field(default=ContextTruncate()) - tools: Tuple[Tool] = field(factory=tuple, converter=tuple) - messages: List[Message] = field(factory=list) + tools: tuple[Tool, ...] = field(factory=tuple, converter=tuple) + messages: list[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData) generation_args: dict = field(default=Factory(dict)) @@ -50,7 +49,7 @@ class Exchange: def _toolmap(self) -> Mapping[str, Tool]: return {tool.name: tool for tool in self.tools} - def replace(self, **kwargs: Dict[str, Any]) -> "Exchange": + def replace(self, **kwargs: dict[str, any]) -> "Exchange": """Make a copy of the exchange, replacing any passed arguments""" # TODO: ensure that the checkpoint data is updated correctly. aka, # if we replace the messages, we need to update the checkpoint data @@ -264,7 +263,7 @@ def pop_first_message(self) -> Message: # we've removed all the checkpoints, so we need to reset the message index offset self.checkpoint_data.message_index_offset = 0 - def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + def pop_last_checkpoint(self) -> tuple[Checkpoint, list[Message]]: """ Reverts the exchange back to the last checkpoint, removing associated messages """ @@ -275,7 +274,7 @@ def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: messages.append(self.messages.pop()) return removed_checkpoint, messages - def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + def pop_first_checkpoint(self) -> tuple[Checkpoint, list[Message]]: """ Pop the first checkpoint from the exchange, removing associated messages """ @@ -332,5 +331,6 @@ def is_allowed_to_call_llm(self) -> bool: # this to be a required method of the provider instead. return len(self.messages) > 0 and self.messages[-1].role == "user" - def get_token_usage(self) -> Dict[str, Usage]: + @staticmethod + def get_token_usage() -> dict[str, Usage]: return _token_usage_collector.get_token_usage_group_by_model() diff --git a/packages/exchange/src/exchange/invalid_choice_error.py b/packages/exchange/src/exchange/invalid_choice_error.py index ffbb9899f..def35bbc0 100644 --- a/packages/exchange/src/exchange/invalid_choice_error.py +++ b/packages/exchange/src/exchange/invalid_choice_error.py @@ -1,8 +1,5 @@ -from typing import List - - class InvalidChoiceError(Exception): - def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: + def __init__(self, attribute_name: str, attribute_value: str, available_values: list[str]) -> None: self.attribute_name = attribute_name self.attribute_value = attribute_value self.available_values = available_values diff --git a/packages/exchange/src/exchange/message.py b/packages/exchange/src/exchange/message.py index 035c60345..5edff692c 100644 --- a/packages/exchange/src/exchange/message.py +++ b/packages/exchange/src/exchange/message.py @@ -1,7 +1,7 @@ import inspect import time from pathlib import Path -from typing import Any, Dict, List, Literal, Type +from typing import Literal from attrs import define, field from jinja2 import Environment, FileSystemLoader @@ -12,7 +12,7 @@ Role = Literal["user", "assistant"] -def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401 +def validate_role_and_content(instance: "Message", *_: any) -> None: # noqa: ANN401 if instance.role == "user": if not (instance.text or instance.tool_result): raise ValueError("User message must include a Text or ToolResult") @@ -25,7 +25,7 @@ def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: AN raise ValueError("Assistant message does not support ToolResult") -def content_converter(contents: List[Dict[str, Any]]) -> List[Content]: +def content_converter(contents: list[dict[str, any]]) -> list[Content]: return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents] @@ -48,9 +48,9 @@ class Message: role: Role = field(default="user") id: str = field(factory=lambda: str(create_object_id(prefix="msg"))) created: int = field(factory=lambda: int(time.time())) - content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter) + content: list[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, any]: return { "role": self.role, "id": self.id, @@ -68,7 +68,7 @@ def text(self) -> str: return "\n".join(result) @property - def tool_use(self) -> List[ToolUse]: + def tool_use(self) -> list[ToolUse]: """All tool use content of this message.""" result = [] for content in self.content: @@ -77,7 +77,7 @@ def tool_use(self) -> List[ToolUse]: return result @property - def tool_result(self) -> List[ToolResult]: + def tool_result(self) -> list[ToolResult]: """All tool result content of this message.""" result = [] for content in self.content: @@ -87,10 +87,10 @@ def tool_result(self) -> List[ToolResult]: @classmethod def load( - cls: Type["Message"], + cls: type["Message"], filename: str, role: Role = "user", - **kwargs: Dict[str, Any], + **kwargs: dict[str, any], ) -> "Message": """Load the message from filename relative to where the load is called. @@ -113,9 +113,9 @@ def load( return cls(role=role, content=[Text(text=rendered_content)]) @classmethod - def user(cls: Type["Message"], text: str) -> "Message": + def user(cls: type["Message"], text: str) -> "Message": return cls(role="user", content=[Text(text)]) @classmethod - def assistant(cls: Type["Message"], text: str) -> "Message": + def assistant(cls: type["Message"], text: str) -> "Message": return cls(role="assistant", content=[Text(text)]) diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py index 82d032e42..925473e98 100644 --- a/packages/exchange/src/exchange/moderators/__init__.py +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -1,5 +1,4 @@ from functools import cache -from typing import Type from exchange.invalid_choice_error import InvalidChoiceError from exchange.moderators.base import Moderator @@ -10,7 +9,7 @@ @cache -def get_moderator(name: str) -> Type[Moderator]: +def get_moderator(name: str) -> type[Moderator]: moderators = load_plugins(group="exchange.moderator") if name not in moderators: raise InvalidChoiceError("moderator", name, moderators.keys()) diff --git a/packages/exchange/src/exchange/moderators/base.py b/packages/exchange/src/exchange/moderators/base.py index d7c630c6a..98a6ad663 100644 --- a/packages/exchange/src/exchange/moderators/base.py +++ b/packages/exchange/src/exchange/moderators/base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod -from typing import Type class Moderator(ABC): @abstractmethod - def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821 pass diff --git a/packages/exchange/src/exchange/moderators/passive.py b/packages/exchange/src/exchange/moderators/passive.py index e3a24efbd..30e6f2c66 100644 --- a/packages/exchange/src/exchange/moderators/passive.py +++ b/packages/exchange/src/exchange/moderators/passive.py @@ -1,7 +1,6 @@ -from typing import Type from exchange.moderators.base import Moderator class PassiveModerator(Moderator): - def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + def rewrite(self, _: type["exchange.exchange.Exchange"]) -> None: # noqa: F821 pass diff --git a/packages/exchange/src/exchange/moderators/summarizer.py b/packages/exchange/src/exchange/moderators/summarizer.py index 7e2dd5588..a7bb1b0f5 100644 --- a/packages/exchange/src/exchange/moderators/summarizer.py +++ b/packages/exchange/src/exchange/moderators/summarizer.py @@ -1,12 +1,10 @@ -from typing import Type - from exchange import Message from exchange.checkpoint import CheckpointData from exchange.moderators import ContextTruncate, PassiveModerator class ContextSummarizer(ContextTruncate): - def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821 """Summarize the context history up to the last few messages in the exchange""" self._update_system_prompt_token_count(exchange) diff --git a/packages/exchange/src/exchange/moderators/truncate.py b/packages/exchange/src/exchange/moderators/truncate.py index 41115f663..a9c08b650 100644 --- a/packages/exchange/src/exchange/moderators/truncate.py +++ b/packages/exchange/src/exchange/moderators/truncate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from exchange.checkpoint import CheckpointData from exchange.message import Message @@ -62,7 +62,7 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None: exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count exchange.checkpoint_data.total_token_count += self.system_prompt_token_count - def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]: + def _get_messages_to_remove(self, exchange: Exchange) -> list[Message]: # this keeps all the messages/checkpoints throwaway_exchange = exchange.replace( moderator=PassiveModerator(), diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py index 088fa7738..56418df47 100644 --- a/packages/exchange/src/exchange/providers/__init__.py +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -1,5 +1,4 @@ from functools import cache -from typing import Type from exchange.invalid_choice_error import InvalidChoiceError from exchange.providers.anthropic import AnthropicProvider # noqa @@ -15,7 +14,7 @@ @cache -def get_provider(name: str) -> Type[Provider]: +def get_provider(name: str) -> type[Provider]: providers = load_plugins(group="exchange.provider") if name not in providers: raise InvalidChoiceError("provider", name, providers.keys()) diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py index 84ecd12fb..a6a0c2262 100644 --- a/packages/exchange/src/exchange/providers/anthropic.py +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -29,7 +28,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": + def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider": cls.check_env_vars() url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) key = os.environ.get("ANTHROPIC_API_KEY") @@ -45,7 +44,7 @@ def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": return cls(client) @staticmethod - def get_usage(data: Dict) -> Usage: # noqa: ANN401 + def get_usage(data: dict) -> Usage: # noqa: ANN401 usage = data.get("usage") input_tokens = usage.get("input_tokens") output_tokens = usage.get("output_tokens") @@ -61,7 +60,7 @@ def get_usage(data: Dict) -> Usage: # noqa: ANN401 ) @staticmethod - def anthropic_response_to_message(response: Dict) -> Message: + def anthropic_response_to_message(response: dict) -> Message: content_blocks = response.get("content", []) content = [] for block in content_blocks: @@ -78,7 +77,7 @@ def anthropic_response_to_message(response: Dict) -> Message: return Message(role="assistant", content=content) @staticmethod - def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]: + def tools_to_anthropic_spec(tools: tuple[Tool, ...]) -> list[dict[str, any]]: return [ { "name": tool.name, @@ -89,7 +88,7 @@ def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]: ] @staticmethod - def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]: + def messages_to_anthropic_spec(messages: list[Message]) -> list[dict[str, any]]: messages_spec = [] # if messages is empty - just make a default for message in messages: @@ -127,10 +126,12 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: List[Tool] = [], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: list[Tool] = None, + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: + if tools is None: + tools = [] tools_set = set() unique_tools = [] for tool in tools: diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index 4d470f978..fa8814f39 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -1,5 +1,3 @@ -from typing import Type - import httpx import os @@ -21,7 +19,7 @@ def __init__(self, client: httpx.Client) -> None: super().__init__(client) @classmethod - def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": + def from_env(cls: type["AzureProvider"]) -> "AzureProvider": cls.check_env_vars() url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME") deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index c8d860ecc..76b1c3391 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -1,7 +1,7 @@ import os from abc import ABC, abstractmethod from attrs import define, field -from typing import List, Optional, Tuple, Type +from typing import Optional from exchange.message import Message from exchange.tool import Tool @@ -19,11 +19,11 @@ class Provider(ABC): REQUIRED_ENV_VARS: list[str] = [] @classmethod - def from_env(cls: Type["Provider"]) -> "Provider": + def from_env(cls: type["Provider"]) -> "Provider": return cls() @classmethod - def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None: + def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None: for env_var in cls.REQUIRED_ENV_VARS: if env_var not in os.environ: raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url) @@ -33,9 +33,10 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: """Generate the next message using the specified model""" pass diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py index 6c32d7cb3..1dd0ebadf 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -4,7 +4,7 @@ import logging import os from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Optional from urllib.parse import quote, urlparse import httpx @@ -36,7 +36,7 @@ def __init__( aws_access_key: str, aws_secret_key: str, aws_session_token: Optional[str] = None, - **kwargs: Dict[str, Any], + **kwargs: dict[str, any], ) -> None: self.region = aws_region self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/" @@ -45,7 +45,7 @@ def __init__( self.session_token = aws_session_token super().__init__(base_url=self.host, timeout=600, **kwargs) - def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response: + def post(self, path: str, json: dict, **kwargs: dict[str, any]) -> httpx.Response: signed_headers = self.sign_and_get_headers( method="POST", url=path, @@ -60,7 +60,7 @@ def sign_and_get_headers( url: str, payload: dict, service: str, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Sign the request and generate the necessary headers for AWS authentication. @@ -72,10 +72,10 @@ def sign_and_get_headers( region (str): The AWS region. access_key (str): The AWS access key. secret_key (str): The AWS secret key. - session_token (Optional[str]): The AWS session token, if any. + session_token (optional[str]): The AWS session token, if any. Returns: - Dict[str, str]: The headers required for the request. + dict[str, str]: The headers required for the request. """ def sign(key: bytes, msg: str) -> bytes: @@ -160,7 +160,7 @@ def __init__(self, client: AwsClient) -> None: self.client = client @classmethod - def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": + def from_env(cls: type["BedrockProvider"]) -> "BedrockProvider": cls.check_env_vars() aws_region = os.environ.get("AWS_REGION", "us-east-1") aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID") @@ -179,22 +179,22 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: """ Generate a completion response from the Bedrock gateway. Args: model (str): The model identifier. system (str): The system prompt or configuration. - messages (List[Message]): A list of messages to be processed by the model. - tools (Tuple[Tool]): A tuple of tools to be used in the completion process. + messages (list[Message]): A list of messages to be processed by the model. + tools (tuple[Tool]): A tuple of tools to be used in the completion process. **kwargs: Additional keyword arguments for inference configuration. Returns: - Tuple[Message, Usage]: A tuple containing the response message and usage data. + tuple[Message, Usage]: A tuple containing the response message and usage data. """ inference_config = dict( @@ -231,7 +231,7 @@ def complete( return self.response_to_message(response_message), usage @retry_procedure - def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401 + def _post(self, payload: any, path: str) -> dict: # noqa: ANN401 response = self.client.post(path, json=payload) return raise_for_status(response).json() @@ -311,7 +311,7 @@ def response_to_message(response_message: dict) -> Message: raise Exception("Invalid response") @staticmethod - def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: + def tools_to_bedrock_spec(tools: tuple[Tool, ...]) -> Optional[dict]: if len(tools) == 0: return None # API requires a non-empty tool config or None tools_added = set() diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index 9bd582dc5..052d78a67 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, List, Tuple, Type - import httpx import os @@ -43,7 +41,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": + def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("DATABRICKS_HOST") key = os.environ.get("DATABRICKS_TOKEN") @@ -73,10 +71,10 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: payload = dict( messages=[ {"role": "system", "content": system}, diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index 1bcac3205..e5f9312d1 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -30,7 +29,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": + def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) key = os.environ.get("GOOGLE_API_KEY") @@ -45,7 +44,7 @@ def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": return cls(client) @staticmethod - def get_usage(data: Dict) -> Usage: # noqa: ANN401 + def get_usage(data: dict) -> Usage: # noqa: ANN401 usage = data.get("usageMetadata") input_tokens = usage.get("promptTokenCount") output_tokens = usage.get("candidatesTokenCount") @@ -61,7 +60,7 @@ def get_usage(data: Dict) -> Usage: # noqa: ANN401 ) @staticmethod - def google_response_to_message(response: Dict) -> Message: + def google_response_to_message(response: dict) -> Message: candidates = response.get("candidates", []) if candidates: # Only use first candidate for now @@ -85,12 +84,12 @@ def google_response_to_message(response: Dict) -> Message: return Message(role="assistant", content=[]) @staticmethod - def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: + def tools_to_google_spec(tools: tuple[Tool, ...]) -> dict[str, list[dict[str, any]]]: if not tools: return {} converted_tools = [] for tool in tools: - converted_tool: Dict[str, Any] = { + converted_tool: dict[str, any] = { "name": tool.name, "description": tool.description or "", } @@ -100,7 +99,7 @@ def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: return {"functionDeclarations": converted_tools} @staticmethod - def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: + def messages_to_google_spec(messages: list[Message]) -> list[dict[str, any]]: messages_spec = [] for message in messages: role = "user" if message.role == "user" else "model" @@ -136,10 +135,10 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: List[Tool] = [], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: list[Tool] = None, + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: tools_set = set() unique_tools = [] for tool in tools: diff --git a/packages/exchange/src/exchange/providers/groq.py b/packages/exchange/src/exchange/providers/groq.py index aeac0d376..edd0945a1 100644 --- a/packages/exchange/src/exchange/providers/groq.py +++ b/packages/exchange/src/exchange/providers/groq.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -37,7 +36,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["GroqProvider"]) -> "GroqProvider": + def from_env(cls: type["GroqProvider"]) -> "GroqProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("GROQ_HOST", GROQ_HOST) key = os.environ.get("GROQ_API_KEY") @@ -69,10 +68,10 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: system_message = [{"role": "system", "content": system}] payload = dict( messages=system_message + messages_to_openai_spec(messages), diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py index f05ea426c..51fef5105 100644 --- a/packages/exchange/src/exchange/providers/ollama.py +++ b/packages/exchange/src/exchange/providers/ollama.py @@ -1,5 +1,4 @@ import os -from typing import Type import httpx @@ -31,7 +30,7 @@ def __init__(self, client: httpx.Client) -> None: super().__init__(client) @classmethod - def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider": + def from_env(cls: type["OllamaProvider"]) -> "OllamaProvider": ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST) timeout = httpx.Timeout(60 * 10) diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index b25c5a70a..02bdcf4a3 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -37,7 +36,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": + def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("OPENAI_HOST", OPENAI_HOST) key = os.environ.get("OPENAI_API_KEY") @@ -69,10 +68,10 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] payload = dict( messages=system_message + messages_to_openai_spec(messages), diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py index 4be7ac31e..9af7287ef 100644 --- a/packages/exchange/src/exchange/providers/utils.py +++ b/packages/exchange/src/exchange/providers/utils.py @@ -1,7 +1,7 @@ import base64 import json import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Optional import httpx from exchange.content import Text, ToolResult, ToolUse @@ -10,10 +10,10 @@ from tenacity import retry_if_exception -def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable: +def retry_if_status(codes: Optional[list[int]] = None, above: Optional[int] = None) -> callable: codes = codes or [] - def predicate(exc: Exception) -> bool: + def predicate(exc: BaseException) -> bool: if isinstance(exc, httpx.HTTPStatusError): if exc.response.status_code in codes: return True @@ -42,7 +42,7 @@ def encode_image(image_path: str) -> str: return base64.b64encode(image_file.read()).decode("utf-8") -def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]: +def messages_to_openai_spec(messages: list[Message]) -> list[dict[str, any]]: messages_spec = [] for message in messages: converted = {"role": message.role} @@ -106,7 +106,7 @@ def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]: return messages_spec -def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]: +def tools_to_openai_spec(tools: tuple[Tool, ...]) -> dict[str, any]: tools_names = set() result = [] for tool in tools: diff --git a/packages/exchange/src/exchange/token_usage_collector.py b/packages/exchange/src/exchange/token_usage_collector.py index 8f0801062..c99110c29 100644 --- a/packages/exchange/src/exchange/token_usage_collector.py +++ b/packages/exchange/src/exchange/token_usage_collector.py @@ -1,5 +1,4 @@ from collections import defaultdict -from typing import Dict from exchange.providers.base import Usage @@ -11,7 +10,7 @@ def __init__(self) -> None: def collect(self, model: str, usage: Usage) -> None: self.usage_data.append((model, usage)) - def get_token_usage_group_by_model(self) -> Dict[str, Usage]: + def get_token_usage_group_by_model(self) -> dict[str, Usage]: usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0)) for model, usage in self.usage_data: usage_by_model = usage_group_by_model[model] diff --git a/packages/exchange/src/exchange/tool.py b/packages/exchange/src/exchange/tool.py index 4ce9e7c50..1ca1f4358 100644 --- a/packages/exchange/src/exchange/tool.py +++ b/packages/exchange/src/exchange/tool.py @@ -1,5 +1,4 @@ import inspect -from typing import Any, Callable, Type from attrs import define @@ -13,17 +12,17 @@ class Tool: Attributes: name (str): The name of the tool description (str): A description of what the tool does - parameters dict[str, Any]: A json schema of the function signature + parameters dict[str, any]: A json schema of the function signature function (Callable): The python function that powers the tool """ name: str description: str - parameters: dict[str, Any] - function: Callable + parameters: dict[str, any] + function: callable @classmethod - def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401 + def from_function(cls: type["Tool"], func: any) -> "Tool": # noqa: ANN401 """Create a tool instance from a function and its docstring The function must have a docstring - we require it to load the description diff --git a/packages/exchange/src/exchange/utils.py b/packages/exchange/src/exchange/utils.py index 04d5ffa18..b95f1c485 100644 --- a/packages/exchange/src/exchange/utils.py +++ b/packages/exchange/src/exchange/utils.py @@ -1,7 +1,7 @@ import inspect import uuid from importlib.metadata import entry_points -from typing import Any, Callable, Dict, List, Type, get_args, get_origin +from typing import get_args, get_origin from griffe import ( Docstring, @@ -20,7 +20,7 @@ def compact(content: str) -> str: return " ".join(content.split()) -def parse_docstring(func: Callable) -> tuple[str, List[Dict]]: +def parse_docstring(func: callable) -> tuple[str, list[dict]]: """Get description and parameters from function docstring""" function_args = list(inspect.signature(func).parameters.keys()) text = str(func.__doc__) @@ -71,7 +71,7 @@ def parse_docstring(func: Callable) -> tuple[str, List[Dict]]: def _check_section_is_present( - parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText] + parsed_docstring: list[DocstringSection], section_type: type[DocstringSectionText] ) -> bool: for section in parsed_docstring: if isinstance(section, section_type): @@ -79,7 +79,7 @@ def _check_section_is_present( return False -def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401 +def json_schema(func: any) -> dict[str, any]: # noqa: ANN401 """Get the json schema for a function""" signature = inspect.signature(func) parameters = signature.parameters @@ -107,16 +107,16 @@ def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401 return schema -def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401 +def _map_type_to_schema(py_type: type) -> dict[str, any]: # noqa: ANN401 origin = get_origin(py_type) args = get_args(py_type) if origin is list or origin is tuple: - return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)} + return {"type": "array", "items": _map_type_to_schema(args[0] if args else any)} elif origin is dict: return { "type": "object", - "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any), + "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else any), } elif py_type is int: return {"type": "integer"} diff --git a/packages/exchange/tests/providers/conftest.py b/packages/exchange/tests/providers/conftest.py index 1aafb9af7..a747e9ce6 100644 --- a/packages/exchange/tests/providers/conftest.py +++ b/packages/exchange/tests/providers/conftest.py @@ -1,7 +1,6 @@ import json import os import re -from typing import Type, Tuple import pytest import yaml @@ -189,14 +188,14 @@ def scrub_response_headers(response): return response -def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: +def complete(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]: provider = provider_cls.from_env() system = "You are a helpful assistant." messages = [Message.user("Hello")] return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs) -def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: +def tools(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]: provider = provider_cls.from_env() system = "You are a helpful assistant. Expect to need to read a file using read_file." messages = [Message.user("What are the contents of this file? test.txt")] @@ -205,7 +204,7 @@ def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, ) -def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: +def vision(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]: provider = provider_cls.from_env() system = "You are a helpful assistant." messages = [ diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py index 4f040ed3c..44b75d380 100644 --- a/packages/exchange/tests/providers/test_azure.py +++ b/packages/exchange/tests/providers/test_azure.py @@ -14,10 +14,10 @@ @pytest.mark.parametrize( "env_var_name", [ - ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), - ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), - ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), - ("AZURE_CHAT_COMPLETIONS_KEY"), + "AZURE_CHAT_COMPLETIONS_HOST_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", + "AZURE_CHAT_COMPLETIONS_KEY", ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py index f8fcaa4b8..f7b68c034 100644 --- a/packages/exchange/tests/providers/test_bedrock.py +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -15,9 +15,9 @@ @pytest.mark.parametrize( "env_var_name", [ - ("AWS_ACCESS_KEY_ID"), - ("AWS_SECRET_ACCESS_KEY"), - ("AWS_SESSION_TOKEN"), + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py index fdbaba474..cd01335a7 100644 --- a/packages/exchange/tests/providers/test_databricks.py +++ b/packages/exchange/tests/providers/test_databricks.py @@ -10,8 +10,8 @@ @pytest.mark.parametrize( "env_var_name", [ - ("DATABRICKS_HOST"), - ("DATABRICKS_TOKEN"), + "DATABRICKS_HOST", + "DATABRICKS_TOKEN", ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): diff --git a/packages/exchange/tests/providers/test_provider_utils.py b/packages/exchange/tests/providers/test_provider_utils.py index 5ad0135ea..2a6ab729b 100644 --- a/packages/exchange/tests/providers/test_provider_utils.py +++ b/packages/exchange/tests/providers/test_provider_utils.py @@ -107,9 +107,9 @@ def test_messages_to_openai_spec() -> None: Message(role="user", content=[Text("How are you?")]), Message( role="assistant", - content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})], + content=[ToolUse(id="1", name="tool1", parameters={"param1": "value1"})], ), - Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]), + Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), ] spec = messages_to_openai_spec(messages) @@ -121,7 +121,7 @@ def test_messages_to_openai_spec() -> None: "role": "assistant", "tool_calls": [ { - "id": 1, + "id": "1", "type": "function", "function": { "name": "tool1", @@ -133,7 +133,7 @@ def test_messages_to_openai_spec() -> None: { "role": "tool", "content": "Result", - "tool_call_id": 1, + "tool_call_id": "1", }, ] @@ -216,7 +216,7 @@ def test_openai_response_to_message_valid_tooluse() -> None: expect = asdict( Message( role="assistant", - content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})], + content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], ) ) actual.pop("id") diff --git a/packages/exchange/tests/test_exchange.py b/packages/exchange/tests/test_exchange.py index f01ef4694..34937630c 100644 --- a/packages/exchange/tests/test_exchange.py +++ b/packages/exchange/tests/test_exchange.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import pytest from exchange.checkpoint import Checkpoint, CheckpointData @@ -29,12 +27,12 @@ def no_overlapping_checkpoints(exchange: Exchange) -> bool: return True -def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]: +def checkpoint_to_index_pairs(checkpoints: list[Checkpoint]) -> list[tuple[int, int]]: return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints] class MockProvider(Provider): - def __init__(self, sequence: List[Message], usage_dicts: List[dict]): + def __init__(self, sequence: list[Message], usage_dicts: list[dict]): # We'll use init to provide a preplanned reply sequence self.sequence = sequence self.call_count = 0 @@ -56,11 +54,18 @@ def get_usage(data: dict) -> Usage: total_tokens=total_tokens, ) - def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message: + def complete( + self, + model: str, + system: str, + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: output = self.sequence[self.call_count] usage = self.get_usage(self.usage_dicts[self.call_count]) self.call_count += 1 - return (output, usage) + return output, usage def test_reply_with_unsupported_tool(): @@ -116,7 +121,7 @@ def test_invalid_tool_parameters(): ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) @@ -154,7 +159,7 @@ def test_max_tool_use_when_limit_reached(): ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) @@ -195,7 +200,7 @@ def long_output_tool_char() -> str: ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(long_output_tool_char)], + tools=(Tool.from_function(long_output_tool_char),), moderator=PassiveModerator(), ) @@ -236,7 +241,7 @@ def long_output_tool_token() -> str: ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(long_output_tool_token)], + tools=(Tool.from_function(long_output_tool_token),), moderator=PassiveModerator(), ) @@ -301,7 +306,7 @@ def resumed_exchange() -> Exchange: ex = Exchange( provider=provider, messages=messages, - tools=[], + tools=(), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", checkpoint_data=CheckpointData(), @@ -399,7 +404,7 @@ def test_pop_first_message_no_messages(): provider=MockProvider(sequence=[], usage_dicts=[]), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) @@ -741,7 +746,7 @@ def test_rewind_with_tool_usage(): ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) ex.add(Message(role="user", content=[Text(text="test")])) diff --git a/packages/exchange/tests/test_exchange_frozen.py b/packages/exchange/tests/test_exchange_frozen.py index a3095b3a3..9d227afab 100644 --- a/packages/exchange/tests/test_exchange_frozen.py +++ b/packages/exchange/tests/test_exchange_frozen.py @@ -9,7 +9,7 @@ class MockProvider(Provider): - def complete(self, model, system, messages, tools=None): + def complete(self, model, system, messages, tools, **kwargs): return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict( {"total_tokens": 35} ) diff --git a/packages/exchange/tests/test_summarizer.py b/packages/exchange/tests/test_summarizer.py index fa7281920..7920fe317 100644 --- a/packages/exchange/tests/test_summarizer.py +++ b/packages/exchange/tests/test_summarizer.py @@ -3,11 +3,11 @@ from exchange.content import ToolResult, ToolUse from exchange.moderators.passive import PassiveModerator from exchange.moderators.summarizer import ContextSummarizer -from exchange.providers import Usage +from exchange.providers import Usage, Provider -class MockProvider: - def complete(self, model, system, messages, tools): +class MockProvider(Provider): + def complete(self, model, system, messages, tools, **kwargs): assistant_message_text = "Summarized content here." output_tokens = len(assistant_message_text) total_input_tokens = sum(len(msg.text) for msg in messages) @@ -138,14 +138,14 @@ def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_inst ] -class AnotherMockProvider: +class AnotherMockProvider(Provider): def __init__(self): self.sequence = MESSAGE_SEQUENCE self.current_index = 1 self.summarize_next = False self.summarized_count = 0 - def complete(self, model, system, messages, tools): + def complete(self, model, system, messages, tools, **kwargs): system_prompt_tokens = 100 input_token_count = system_prompt_tokens diff --git a/packages/exchange/tests/test_truncate.py b/packages/exchange/tests/test_truncate.py index 3875303e7..eeb993ff1 100644 --- a/packages/exchange/tests/test_truncate.py +++ b/packages/exchange/tests/test_truncate.py @@ -73,7 +73,7 @@ def __init__(self): self.summarize_next = False self.summarized_count = 0 - def complete(self, model, system, messages, tools): + def complete(self, model, system, messages, tools, **kwargs): input_token_count = SYSTEM_PROMPT_TOKENS message = self.sequence[self.current_index] diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index 109706b79..9b613ee1c 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -1,6 +1,6 @@ from functools import cache from pathlib import Path -from typing import Callable, Dict, Mapping, Optional, Tuple +from typing import Mapping, Optional from rich import print from rich.panel import Panel @@ -20,7 +20,7 @@ @cache -def default_profiles() -> Mapping[str, Callable]: +def default_profiles() -> Mapping[str, callable]: return load_plugins(group="goose.profile") @@ -29,7 +29,7 @@ def session_path(name: str) -> Path: return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}") -def write_config(profiles: Dict[str, Profile]) -> None: +def write_config(profiles: dict[str, Profile]) -> None: """Overwrite the config with the passed profiles""" PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True) converted = {name: profile.to_dict() for name, profile in profiles.items()} @@ -38,7 +38,7 @@ def write_config(profiles: Dict[str, Profile]) -> None: yaml.dump(converted, f) -def ensure_config(name: Optional[str]) -> Tuple[str, Profile]: +def ensure_config(name: Optional[str]) -> tuple[str, Profile]: """Ensure that the config exists and has the default section""" # TODO we should copy a templated default config in to better document # but this is complicated a bit by autodetecting the provider @@ -70,7 +70,7 @@ def ensure_config(name: Optional[str]) -> Tuple[str, Profile]: return (name, default_profile) -def read_config() -> Dict[str, Profile]: +def read_config() -> dict[str, Profile]: """Return config from the configuration file and validates its contents""" yaml = YAML() @@ -80,7 +80,7 @@ def read_config() -> Dict[str, Profile]: return {name: Profile(**profile) for name, profile in data.items()} -def default_model_configuration() -> Tuple[str, str, str]: +def default_model_configuration() -> tuple[str, str, str]: providers = load_plugins(group="exchange.provider") for provider, cls in providers.items(): try: diff --git a/src/goose/cli/prompt/completer.py b/src/goose/cli/prompt/completer.py index 6739d1530..fb453fd38 100644 --- a/src/goose/cli/prompt/completer.py +++ b/src/goose/cli/prompt/completer.py @@ -1,5 +1,4 @@ import re -from typing import List from prompt_toolkit.completion import CompleteEvent, Completer, Completion from prompt_toolkit.document import Document @@ -8,10 +7,10 @@ class GoosePromptCompleter(Completer): - def __init__(self, commands: List[Command]) -> None: + def __init__(self, commands: list[Command]) -> None: self.commands = commands - def get_command_completions(self, document: Document) -> List[Completion]: + def get_command_completions(self, document: Document) -> list[Completion]: all_completions = [] for command_name, command_instance in self.commands.items(): pattern = rf"(? List[Completion]: all_completions.extend(completions) return all_completions - def get_command_name_completions(self, document: Document) -> List[Completion]: + def get_command_name_completions(self, document: Document) -> list[Completion]: pattern = r"(? List[Completion]: completions.append(Completion(command_name, start_position=-len(query), display=command_name)) return completions - def get_completions(self, document: Document, _: CompleteEvent) -> List[Completion]: + def get_completions(self, document: Document, _: CompleteEvent) -> list[Completion]: command_completions = self.get_command_completions(document) command_name_completions = self.get_command_name_completions(document) return command_name_completions + command_completions diff --git a/src/goose/cli/prompt/lexer.py b/src/goose/cli/prompt/lexer.py index 0e2bb0c91..e00fd207a 100644 --- a/src/goose/cli/prompt/lexer.py +++ b/src/goose/cli/prompt/lexer.py @@ -1,5 +1,5 @@ import re -from typing import Callable, List, Tuple +from typing import Callable from prompt_toolkit.document import Document from prompt_toolkit.lexers import Lexer @@ -27,7 +27,7 @@ def value_for_command(command_string: str) -> re.Pattern[str]: class PromptLexer(Lexer): - def __init__(self, command_names: List[str]) -> None: + def __init__(self, command_names: list[str]) -> None: self.patterns = [] for command_name in command_names: self.patterns.append((completion_for_command(command_name), "class:command")) @@ -35,7 +35,7 @@ def __init__(self, command_names: List[str]) -> None: self.patterns.append((command_itself(command_name), "class:command")) def lex_document(self, document: Document) -> Callable[[int], list]: - def get_line_tokens(line_number: int) -> Tuple[str, str]: + def get_line_tokens(line_number: int) -> tuple[str, str]: line = document.lines[line_number] tokens = [] diff --git a/src/goose/cli/prompt/overwrite_session_prompt.py b/src/goose/cli/prompt/overwrite_session_prompt.py index 1d90cbb1f..64bbeed61 100644 --- a/src/goose/cli/prompt/overwrite_session_prompt.py +++ b/src/goose/cli/prompt/overwrite_session_prompt.py @@ -1,10 +1,8 @@ -from typing import Any - from rich.prompt import Prompt class OverwriteSessionPrompt(Prompt): - def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None: + def __init__(self, *args: tuple[any], **kwargs: dict[str, any]) -> None: super().__init__(*args, **kwargs) self.choices = { "yes": "Overwrite the existing session", diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index d39e66eea..666901b6a 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,6 +1,6 @@ import traceback from pathlib import Path -from typing import Any, Optional +from typing import Optional from exchange import Message, Text, ToolResult, ToolUse from rich import print @@ -62,7 +62,7 @@ def __init__( profile: Optional[str] = None, plan: Optional[dict] = None, log_level: Optional[str] = "INFO", - **kwargs: dict[str, Any], + **kwargs: dict[str, any], ) -> None: if name is None: self.name = droid() diff --git a/src/goose/command/__init__.py b/src/goose/command/__init__.py index d9fd674a4..cef47fec9 100644 --- a/src/goose/command/__init__.py +++ b/src/goose/command/__init__.py @@ -1,5 +1,4 @@ from functools import cache -from typing import Dict from goose.command.base import Command from goose.utils import load_plugins @@ -11,5 +10,5 @@ def get_command(name: str) -> type[Command]: @cache -def get_commands() -> Dict[str, type[Command]]: +def get_commands() -> dict[str, type[Command]]: return load_plugins(group="goose.command") diff --git a/src/goose/command/base.py b/src/goose/command/base.py index 5a8c346ff..081453de6 100644 --- a/src/goose/command/base.py +++ b/src/goose/command/base.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Optional +from typing import Optional from prompt_toolkit.completion import Completion @@ -7,7 +7,7 @@ class Command(ABC): """A command that can be executed by the CLI.""" - def get_completions(self, query: str) -> List[Completion]: + def get_completions(self, query: str) -> list[Completion]: """ Get completions for the command. diff --git a/src/goose/command/file.py b/src/goose/command/file.py index cb8bdfd67..786785cf8 100644 --- a/src/goose/command/file.py +++ b/src/goose/command/file.py @@ -1,5 +1,4 @@ import os -from typing import List from prompt_toolkit.completion import Completion @@ -7,7 +6,7 @@ class FileCommand(Command): - def get_completions(self, query: str) -> List[Completion]: + def get_completions(self, query: str) -> list[Completion]: if query.startswith("/"): directory = os.path.dirname(query) search_term = os.path.basename(query) diff --git a/src/goose/profile.py b/src/goose/profile.py index b1fb2ad8c..a4de3409e 100644 --- a/src/goose/profile.py +++ b/src/goose/profile.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Mapping, Type +from typing import Mapping from attrs import asdict, define, field @@ -21,10 +21,10 @@ class Profile: processor: str accelerator: str moderator: str - toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec)) + toolkits: list[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec)) @toolkits.validator - def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None: + def check_toolkit_requirements(self, _: type["ToolkitSpec"], toolkits: list[ToolkitSpec]) -> None: # checks that the list of toolkits in the profile have their requirements installed_toolkits = set([toolkit.name for toolkit in toolkits]) @@ -36,7 +36,7 @@ def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[Tool msg = f"Toolkit {toolkit_name} requires {req} but it is not present" raise ValueError(msg) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, any]: return asdict(self) def profile_info(self) -> str: @@ -44,7 +44,7 @@ def profile_info(self) -> str: return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}" -def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile: +def default_profile(provider: str, processor: str, accelerator: str, **kwargs: dict[str, any]) -> Profile: """Get the default profile""" # TODO consider if the providers should have recommended models diff --git a/src/goose/toolkit/base.py b/src/goose/toolkit/base.py index d26630ca4..d6c232fb5 100644 --- a/src/goose/toolkit/base.py +++ b/src/goose/toolkit/base.py @@ -1,6 +1,6 @@ import inspect from abc import ABC -from typing import Callable, Mapping, Optional, Tuple, TypeVar +from typing import Mapping, Optional, TypeVar from attrs import define, field from exchange import Tool @@ -8,7 +8,7 @@ from goose.notifier import Notifier # Create a type variable that can represent any function signature -F = TypeVar("F", bound=Callable) +F = TypeVar("F", bound=callable) def tool(func: F) -> F: @@ -55,7 +55,7 @@ def system(self) -> str: """Get the addition to the system prompt for this toolkit.""" return "" - def tools(self) -> Tuple[Tool, ...]: + def tools(self) -> tuple[Tool, ...]: """Get the tools for this toolkit This default method looks for functions on the toolkit annotated diff --git a/src/goose/toolkit/developer.py b/src/goose/toolkit/developer.py index ba600d921..b48a18069 100644 --- a/src/goose/toolkit/developer.py +++ b/src/goose/toolkit/developer.py @@ -3,7 +3,6 @@ import subprocess import time from pathlib import Path -from typing import Dict, List from exchange import Message from goose.toolkit.base import Toolkit, tool @@ -35,9 +34,9 @@ class Developer(Toolkit): We also include some default shell strategies in the prompt, such as using ripgrep """ - def __init__(self, *args: object, **kwargs: Dict[str, object]) -> None: + def __init__(self, *args: object, **kwargs: dict[str, object]) -> None: super().__init__(*args, **kwargs) - self.timestamps: Dict[str, float] = {} + self.timestamps: dict[str, float] = {} def system(self) -> str: """Retrieve system configuration details for developer""" @@ -55,7 +54,7 @@ def system(self) -> str: return system_prompt @tool - def update_plan(self, tasks: List[dict]) -> List[dict]: + def update_plan(self, tasks: list[dict]) -> list[dict]: """ Update the plan by overwriting all current tasks @@ -63,7 +62,7 @@ def update_plan(self, tasks: List[dict]) -> List[dict]: shown to the user directly, you do not need to reiterate it Args: - tasks (List(dict)): The list of tasks, where each task is a dictionary + tasks (list(dict)): The list of tasks, where each task is a dictionary with a key for the task "description" and the task "status". The status MUST be one of "planned", "complete", "failed", "in-progress". diff --git a/src/goose/toolkit/repo_context/repo_context.py b/src/goose/toolkit/repo_context/repo_context.py index 8be8794f6..0f443930e 100644 --- a/src/goose/toolkit/repo_context/repo_context.py +++ b/src/goose/toolkit/repo_context/repo_context.py @@ -1,7 +1,6 @@ import os from functools import cache from subprocess import CompletedProcess, run -from typing import Dict, Tuple from exchange import Message @@ -21,7 +20,7 @@ def __init__(self, notifier: Notifier, requires: Requirements) -> None: self.repo_project_root, self.is_git_repo, self.goose_session_root = self.determine_git_proj() - def determine_git_proj(self) -> Tuple[str, bool, str]: + def determine_git_proj(self) -> tuple[str, bool, str]: """Determines the root as well as where Goose is currently running If the project is not part of a Github repo, the root of the project will be defined as the current working @@ -72,11 +71,11 @@ def is_mono_repo(self) -> bool: return self.repo_size > 2000 @tool - def summarize_current_project(self) -> Dict[str, str]: + def summarize_current_project(self) -> dict[str, str]: """Summarizes the current project based on repo root (if git repo) or current project_directory (if not) Returns: - summary (Dict[str, str]): Keys are file paths and values are the summaries + summary (dict[str, str]): Keys are file paths and values are the summaries """ self.notifier.log("Summarizing the most relevant files in the current project. This may take a while...") diff --git a/src/goose/toolkit/repo_context/utils.py b/src/goose/toolkit/repo_context/utils.py index dca7f04b0..e69cea936 100644 --- a/src/goose/toolkit/repo_context/utils.py +++ b/src/goose/toolkit/repo_context/utils.py @@ -2,7 +2,6 @@ import concurrent.futures import os from collections import deque -from typing import Dict, List, Tuple from exchange import Exchange @@ -26,7 +25,7 @@ def get_repo_size(repo_path: str) -> int: return get_directory_size(git_dir) / (1024**2) -def get_files_and_directories(root_dir: str) -> Dict[str, list]: +def get_files_and_directories(root_dir: str) -> dict[str, list]: """Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that the files actually exist (to avoid downstream file read errors). @@ -61,7 +60,7 @@ def get_files_and_directories(root_dir: str) -> Dict[str, list]: return {"files": files, "directories": dirs} -def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> List[str]: +def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> list[str]: """Lets goose pick files in a BFS manner""" queue = deque([root]) @@ -80,7 +79,7 @@ def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> Li return all_files -def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]: +def process_directory(current_dir: str, exchange: Exchange) -> tuple[list[str], list[str]]: """Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file and directory names in the current folder, then ask Goose to pick which ones to keep. diff --git a/src/goose/toolkit/summarization/summarize_project.py b/src/goose/toolkit/summarization/summarize_project.py index d910fbc47..f5e22562c 100644 --- a/src/goose/toolkit/summarization/summarize_project.py +++ b/src/goose/toolkit/summarization/summarize_project.py @@ -1,5 +1,5 @@ import os -from typing import List, Optional +from typing import Optional from goose.toolkit import Toolkit from goose.toolkit.base import tool @@ -11,7 +11,7 @@ class SummarizeProject(Toolkit): def get_project_summary( self, project_dir_path: Optional[str] = os.getcwd(), - extensions: Optional[List[str]] = None, + extensions: Optional[list[str]] = None, summary_instructions_prompt: Optional[str] = None, ) -> dict: """Generates or retrieves a project summary based on specified file extensions. @@ -19,7 +19,7 @@ def get_project_summary( Args: project_dir_path (Optional[Path]): Path to the project directory. Defaults to the current working directory if None - extensions (Optional[List[str]]): Specific file extensions to summarize. + extensions (Optional[list[str]]): Specific file extensions to summarize. summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g. "Summarize the file in two sentences.". The default instruction is "Please summarize this file." diff --git a/src/goose/toolkit/summarization/summarize_repo.py b/src/goose/toolkit/summarization/summarize_repo.py index 18c7da428..58765dd9a 100644 --- a/src/goose/toolkit/summarization/summarize_repo.py +++ b/src/goose/toolkit/summarization/summarize_repo.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from goose.toolkit import Toolkit from goose.toolkit.base import tool @@ -10,7 +10,7 @@ class SummarizeRepo(Toolkit): def summarize_repo( self, repo_url: str, - specified_extensions: Optional[List[str]] = None, + specified_extensions: Optional[list[str]] = None, summary_instructions_prompt: Optional[str] = None, ) -> dict: """ @@ -19,7 +19,7 @@ def summarize_repo( Args: repo_url (str): The URL of the repository to summarize. - specified_extensions (Optional[List[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If + specified_extensions (Optional[list[str]]): list of file extensions to summarize, e.g., ["tf", "md"]. If this list is empty, then all files in the repo are summarized summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g. "Summarize the file in two sentences.". The default instruction is "Please summarize this file." diff --git a/src/goose/toolkit/summarization/utils.py b/src/goose/toolkit/summarization/utils.py index d398713cc..96e5d363d 100644 --- a/src/goose/toolkit/summarization/utils.py +++ b/src/goose/toolkit/summarization/utils.py @@ -2,7 +2,7 @@ import subprocess from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional from exchange import Exchange from exchange.providers.utils import InitialMessageTooLargeError @@ -15,7 +15,7 @@ # TODO: move git stuff -def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]: +def run_git_command(command: list[str]) -> subprocess.CompletedProcess[str]: result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False) if result.returncode != 0: @@ -28,7 +28,7 @@ def clone_repo(repo_url: str, target_directory: str) -> None: run_git_command(["clone", repo_url, target_directory]) -def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: +def load_summary_file_if_exists(project_name: str) -> Optional[dict]: """Checks if a summary file exists at '.goose/summaries/projectname-summary.json. Returns contents of the file if it exists, otherwise returns None @@ -36,7 +36,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: project_name (str): name of the project or repo Returns: - Optional[Dict]: File contents, else None + Optional[dict]: File contents, else None """ summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json" if Path(summary_file_path).exists(): @@ -44,7 +44,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: return json.load(f) -def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> Tuple[str, str]: +def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> tuple[str, str]: """Summarizes a single file Args: @@ -74,15 +74,15 @@ def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = No def summarize_repo( repo_url: str, exchange: Exchange, - extensions: List[str], + extensions: list[str], summary_instructions_prompt: Optional[str] = None, -) -> Dict[str, str]: +) -> dict[str, str]: """Clones (if needed) and summarizes a repo Args: repo_url (str): Repository url exchange (Exchange): Exchange for summarizing the repo. - extensions (List[str]): List of file-types to summarize. + extensions (list[str]): list of file-types to summarize. summary_instructions_prompt (Optional[str]): Optional parameter to customize summarization results. Defaults to "Please summarize this file" """ @@ -110,15 +110,15 @@ def summarize_repo( def summarize_directory( - directory: str, exchange: Exchange, extensions: List[str], summary_instructions_prompt: Optional[str] = None -) -> Dict[str, str]: + directory: str, exchange: Exchange, extensions: list[str], summary_instructions_prompt: Optional[str] = None +) -> dict[str, str]: """Summarize files in a given directory based on extensions. Will also recursively find files in subdirectories and summarize them. Args: directory (str): path to the top-level directory to summarize exchange (Exchange): Exchange to use to summarize - extensions (List[str]): List of file-type extensions to summarize (and ignore all other extensions). + extensions (list[str]): list of file-type extensions to summarize (and ignore all other extensions). summary_instructions_prompt (Optional[str]): Optional instructions to give to the exchange regarding summarization. Returns: @@ -158,19 +158,19 @@ def summarize_directory( def summarize_files_concurrent( - exchange: Exchange, file_list: List[str], project_name: str, summary_instructions_prompt: Optional[str] = None -) -> Dict[str, str]: + exchange: Exchange, file_list: list[str], project_name: str, summary_instructions_prompt: Optional[str] = None +) -> dict[str, str]: """Takes in a list of files and summarizes them. Exchange does not keep history of the summarized files. Args: exchange (Exchange): Underlying exchange - file_list (List[str]): List of paths to files to summarize + file_list (list[str]): list of paths to files to summarize project_name (str): Used to save the summary of the files to .goose/summaries/-summary.json summary_instructions_prompt (Optional[str]): Summary instructions for the LLM. Defaults to "Please summarize this file." Returns: - file_summaries (Dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange + file_summaries (dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange """ summary_file = load_summary_file_if_exists(project_name) if summary_file: diff --git a/src/goose/toolkit/utils.py b/src/goose/toolkit/utils.py index ad97360f2..d6f0335b1 100644 --- a/src/goose/toolkit/utils.py +++ b/src/goose/toolkit/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Dict +from typing import Optional from pygments.lexers import get_lexer_for_filename from pygments.util import ClassNotFound @@ -67,7 +67,7 @@ def find_last_task_group_index(input_str: str) -> int: return last_group_start_index -def parse_plan(input_plan_str: str) -> Dict: +def parse_plan(input_plan_str: str) -> dict: last_group_start_index = find_last_task_group_index(input_plan_str) if last_group_start_index == -1: return {"kickoff_message": input_plan_str, "tasks": []} diff --git a/src/goose/utils/__init__.py b/src/goose/utils/__init__.py index 69887e7f0..d9535b187 100644 --- a/src/goose/utils/__init__.py +++ b/src/goose/utils/__init__.py @@ -1,7 +1,7 @@ import random import string from importlib.metadata import entry_points -from typing import Any, Callable, Dict, List, Type, TypeVar +from typing import TypeVar, Callable T = TypeVar("T") @@ -31,10 +31,10 @@ def load_plugins(group: str) -> dict: return plugins -def ensure(cls: Type[T]) -> Callable[[Any], T]: +def ensure(cls: type[T]) -> Callable[[any], T]: """Convert dictionary to a class instance""" - def converter(val: Any) -> T: # noqa: ANN401 + def converter(val: any) -> T: # noqa: ANN401 if isinstance(val, cls): return val elif isinstance(val, dict): @@ -47,10 +47,10 @@ def converter(val: Any) -> T: # noqa: ANN401 return converter -def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]: +def ensure_list(cls: type[T]) -> Callable[[list[dict[str, any]]], type[T]]: """Convert a list of dictionaries to class instances""" - def converter(val: List[Dict[str, Any]]) -> List[T]: + def converter(val: list[dict[str, any]]) -> list[T]: output = [] for entry in val: output.append(ensure(cls)(entry)) diff --git a/src/goose/utils/file_utils.py b/src/goose/utils/file_utils.py index eabc50f73..1531ad651 100644 --- a/src/goose/utils/file_utils.py +++ b/src/goose/utils/file_utils.py @@ -2,7 +2,7 @@ import os from collections import Counter from pathlib import Path -from typing import Dict, List, Optional +from typing import Optional def create_extensions_list(project_root: str, max_n: int) -> list: @@ -11,7 +11,7 @@ def create_extensions_list(project_root: str, max_n: int) -> list: project_root (str): Root of the project to analyze max_n (int): The number of file extensions to return Returns: - extensions (List[str]): A list of the top N file extensions + extensions (list[str]): A list of the top N file extensions """ if max_n == 0: raise (ValueError("Number of file extensions must be greater than 0")) @@ -31,14 +31,14 @@ def create_extensions_list(project_root: str, max_n: int) -> list: return extensions -def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]: +def create_language_weighting(files_in_directory: list[str]) -> dict[str, float]: """Calculate language weighting by file size to match GitHub's methodology. Args: - files_in_directory (List[str]): Paths to files in the project directory + files_in_directory (list[str]): Paths to files in the project directory Returns: - Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values + dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values """ # Initialize counters for sizes @@ -59,7 +59,7 @@ def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float] return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True)) -def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]: +def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> list[str]: """List all files in a directory with a given extension. Set extension to '' to return all files. Args: @@ -67,7 +67,7 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L extension (Optional[str]): extension to lookup. Defaults to '' which will return all files. Returns: - files (List[str]): List of file paths + files (list[str]): list of file paths """ # add a leading '.' to extension if needed if extension and not extension.startswith("."): @@ -77,15 +77,15 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L return files -def create_file_list(dir_path: str, extensions: List[str]) -> List[str]: +def create_file_list(dir_path: str, extensions: list[str]) -> list[str]: """Creates a list of files with certain extensions Args: dir_path (str): Directory to list files of. Will include files recursively in sub-directories. - extensions (List[str]): List of file extensions to select for. If empty list, return all files + extensions (list[str]): list of file extensions to select for. If empty list, return all files Returns: - final_file_list (List[str]): List of file paths with specified extensions. + final_file_list (list[str]): list of file paths with specified extensions. """ # if extensions is empty list, return all files if not extensions: diff --git a/src/goose/utils/session_file.py b/src/goose/utils/session_file.py index e367dcf1f..510b4aa19 100644 --- a/src/goose/utils/session_file.py +++ b/src/goose/utils/session_file.py @@ -2,7 +2,7 @@ import os import tempfile from pathlib import Path -from typing import Dict, Iterator, List +from typing import Iterator from exchange import Message @@ -17,12 +17,12 @@ def is_empty_session(path: Path) -> bool: return path.is_file() and path.stat().st_size == 0 -def write_to_file(file_path: Path, messages: List[Message]) -> None: +def write_to_file(file_path: Path, messages: list[Message]) -> None: with open(file_path, "w") as f: _write_messages_to_file(f, messages) -def read_or_create_file(file_path: Path) -> List[Message]: +def read_or_create_file(file_path: Path) -> list[Message]: if file_path.exists(): return read_from_file(file_path) with open(file_path, "w"): @@ -30,7 +30,7 @@ def read_or_create_file(file_path: Path) -> List[Message]: return [] -def read_from_file(file_path: Path) -> List[Message]: +def read_from_file(file_path: Path) -> list[Message]: try: with open(file_path, "r") as f: messages = [json.loads(m) for m in list(f) if m.strip()] @@ -40,7 +40,7 @@ def read_from_file(file_path: Path) -> List[Message]: return [Message(**m) for m in messages] -def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]: +def list_sorted_session_files(session_files_directory: Path) -> dict[str, Path]: logs = list_session_files(session_files_directory) return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)} @@ -55,7 +55,7 @@ def session_file_exists(session_files_directory: Path) -> bool: return any(list_session_files(session_files_directory)) -def save_latest_session(file_path: Path, messages: List[Message]) -> None: +def save_latest_session(file_path: Path, messages: list[Message]) -> None: with tempfile.NamedTemporaryFile("w", delete=False) as temp_file: _write_messages_to_file(temp_file, messages) temp_file_path = temp_file.name @@ -63,7 +63,7 @@ def save_latest_session(file_path: Path, messages: List[Message]) -> None: os.replace(temp_file_path, file_path) -def _write_messages_to_file(file: any, messages: List[Message]) -> None: +def _write_messages_to_file(file: any, messages: list[Message]) -> None: for m in messages: json.dump(m.to_dict(), file) file.write("\n")