diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py
index c0e921d861f..a76b4bb50f0 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py
@@ -1,8 +1,8 @@
 import logging  # added import
 import re
-from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, cast
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union, cast
 
-from autogen_core import EVENT_LOGGER_NAME, FunctionCall, MessageHandlerContext
+from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
 from autogen_core.logging import LLMCallEvent
 from autogen_core.models import (
     AssistantMessage,
@@ -10,6 +10,7 @@
     CreateResult,
     FinishReasons,
     FunctionExecutionResultMessage,
+    LLMMessage,
     ModelInfo,
     RequestUsage,
     SystemMessage,
@@ -31,6 +32,36 @@
 logger = logging.getLogger(EVENT_LOGGER_NAME)  # initialize logger
 
 
+def from_pretrained(
+    model_path: str,
+    repo_id: str | None = None,
+    model_info: ModelInfo | None = None,
+    additional_files: List[str] | None = None,
+    local_dir: str | None = None,
+    local_dir_use_symlinks: str = "auto",
+    cache_dir: str | None = None,
+    **kwargs: Any,
+) -> Llama:
+    """
+    Load a model from the Hugging Face Hub or a local directory.
+
+    :param model_path: The path to the model file, or the filename within the repository.
+    :param repo_id: The repository ID of the model.
+    :param model_info: The model info.
+    :param additional_files: Additional files to download.
+    :param local_dir: The local directory to load the model from.
+    :param local_dir_use_symlinks: Whether to use symlinks for the local directory.
+    :param cache_dir: The cache directory.
+    :param kwargs: Additional keyword arguments.
+    :return: The loaded model.
+    """
+    if repo_id:
+        return Llama.from_pretrained(repo_id=repo_id, filename=model_path, **kwargs)  # pyright: ignore[reportUnknownMemberType]
+        # The partially unknown type is in the `llama_cpp` package
+    else:
+        return Llama(model_path=model_path, **kwargs)
+
+
 def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
     if stop_reason is None:
         return "unknown"
@@ -107,24 +138,26 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
     def __init__(
         self,
         filename: str,
+        repo_id: str | None = None,
+        model_info: ModelInfo | None = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Initialize the LlamaCpp client.
         """
-        self.llm: Llama = (
-            Llama.from_pretrained(filename=filename, repo_id=kwargs.pop("repo_id"), **kwargs)  # pyright: ignore[reportUnknownMemberType]
-            # The partially unknown type is in the `llama_cpp` package
-            if "repo_id" in kwargs
-            else Llama(model_path=filename, **kwargs)
-        )
+        self.llm: Llama = from_pretrained(repo_id=repo_id, model_path=filename, model_info=model_info, **kwargs)
 
         self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}
 
     async def create(
         self,
-        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
-        tools: Optional[Sequence[Tool | ToolSchema]] = None,
-        **kwargs: Any,
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        # None means do not override the default
+        # A value means to override the client default - often specified in the constructor
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
         tools = tools or []
@@ -192,7 +225,7 @@ async def create(
 
         # Parse the response
         response_tool_calls: ChatCompletionTool | None = None
-        response_text:str| None = None
+        response_text: str | None = None
         if "choices" in response and len(response["choices"]) > 0:
             if "message" in response["choices"][0]:
                 response_text = response["choices"][0]["message"]["content"]
@@ -230,28 +263,19 @@ async def create(
             )
 
         # Create a CreateResult object
-        # breakpoint()
         if "finish_reason" in response["choices"][0]:
             finish_reason = response["choices"][0]["finish_reason"]
         else:
             finish_reason = "unknown"
         if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
             finish_reason = "unknown"
-        if thought:
-            create_result = CreateResult(
-                content=content,
-                thought=thought,
-                usage=cast(RequestUsage, response["usage"]),
-                finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
-                cached=False,
-            )
-        else:
-            create_result = CreateResult(
-                content=content,
-                usage=cast(RequestUsage, response["usage"]),
-                finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
-                cached=False,
-            )
+        create_result = CreateResult(
+            content=content,
+            thought=thought,
+            usage=cast(RequestUsage, response["usage"]),
+            finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
+            cached=False,
+        )
 
         # If we are running in the context of a handler we can get the agent_id
         try:
@@ -312,13 +336,17 @@ def _extract_tool_arguments(self, response_text: str) -> str:
 
     async def create_stream(
         self,
-        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
-        tools: Optional[Sequence[Tool | ToolSchema]] = None,
-        **kwargs: Any,
-    ) -> AsyncGenerator[str, None]:
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        # None means do not override the default
+        # A value means to override the client default - often specified in the constructor
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> AsyncGenerator[Union[str, CreateResult], None]:
         raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
-        if False:  # Unreachable code to satisfy the return type.
-            yield ""
+        yield ""
 
     # Implement abstract methods
     def actual_usage(self) -> RequestUsage: