feat: add from_pretrained function for model loading; enhance LlamaCppChatCompletionClient initialization and create methods
xhabit committed Feb 24, 2025
1 parent c29ec8a commit 35f35ca
Showing 1 changed file with 63 additions and 35 deletions.
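A minimal usage sketch of the updated client, assuming it is exported as `autogen_ext.models.llama_cpp.LlamaCppChatCompletionClient`; the repository ID, GGUF filename, and `n_ctx` value are illustrative placeholders, and any extra keyword arguments are passed through to `llama_cpp.Llama`:

```python
# Hypothetical usage sketch; the repository ID, GGUF filename, and n_ctx value
# are placeholders, and the import path is assumed.
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


async def main() -> None:
    # With repo_id set, the new from_pretrained helper downloads the GGUF file
    # from the Hugging Face Hub; extra kwargs are forwarded to llama_cpp.Llama.
    client = LlamaCppChatCompletionClient(
        repo_id="someorg/some-model-GGUF",  # placeholder Hub repository
        filename="some-model-q4_k_m.gguf",  # placeholder GGUF filename
        n_ctx=4096,
    )
    # Local-file variant:
    # client = LlamaCppChatCompletionClient(filename="/path/to/model.gguf")

    result = await client.create(messages=[UserMessage(content="Hello!", source="user")])
    print(result.content)


asyncio.run(main())
```

When `repo_id` is omitted, `filename` is treated as a path to a local GGUF file rather than a file to download from the Hub.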
@@ -1,15 +1,16 @@
import logging # added import
import re
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, cast
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import EVENT_LOGGER_NAME, FunctionCall, MessageHandlerContext
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
from autogen_core.logging import LLMCallEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelInfo,
RequestUsage,
SystemMessage,
@@ -31,6 +32,36 @@
logger = logging.getLogger(EVENT_LOGGER_NAME) # initialize logger


def from_pretrained(
model_path: str,
repo_id: str | None = None,
model_info: ModelInfo | None = None,
additional_files: List[str] | None = None,
local_dir: str | None = None,
local_dir_use_symlinks: str = "auto",
cache_dir: str | None = None,
**kwargs: Any,
) -> Llama:
"""
Load a model from the Hugging Face Hub or a local directory.
:param model_path: The local model path, or the filename of the model to download when repo_id is given.
:param repo_id: The Hugging Face repository ID of the model.
:param model_info: The model info.
:param additional_files: Additional files to download.
:param local_dir: The local directory to load the model from.
:param local_dir_use_symlinks: Whether to use symlinks for the local directory.
:param cache_dir: The cache directory.
:param kwargs: Additional keyword arguments passed through to the Llama constructor.
:return: The loaded model.
"""
if repo_id:
# Forward the download options so they are not silently ignored
# (additional_files assumes a llama-cpp-python version that supports it).
return Llama.from_pretrained(repo_id=repo_id, filename=model_path, additional_files=additional_files, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, cache_dir=cache_dir, **kwargs)  # pyright: ignore[reportUnknownMemberType]
# The partially unknown type is in the `llama_cpp` package
else:
return Llama(model_path=model_path, **kwargs)


def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
if stop_reason is None:
return "unknown"
@@ -107,24 +138,26 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
def __init__(
self,
filename: str,
repo_id: str | None = None,
model_info: ModelInfo | None = None,
**kwargs: Any,
):
) -> None:
"""
Initialize the LlamaCpp client.
"""
self.llm: Llama = (
Llama.from_pretrained(filename=filename, repo_id=kwargs.pop("repo_id"), **kwargs) # pyright: ignore[reportUnknownMemberType]
# The partially unknown type is in the `llama_cpp` package
if "repo_id" in kwargs
else Llama(model_path=filename, **kwargs)
)
self.llm: Llama = from_pretrained(repo_id=repo_id, model_path=filename, model_info=model_info, **kwargs)
self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}

async def create(
self,
messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
tools: Optional[Sequence[Tool | ToolSchema]] = None,
**kwargs: Any,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
tools = tools or []

@@ -192,7 +225,7 @@ async def create(

# Parse the response
response_tool_calls: ChatCompletionTool | None = None
response_text:str| None = None
response_text: str | None = None
if "choices" in response and len(response["choices"]) > 0:
if "message" in response["choices"][0]:
response_text = response["choices"][0]["message"]["content"]
@@ -230,28 +263,19 @@ async def create(
)

# Create a CreateResult object
if "finish_reason" in response["choices"][0]:
finish_reason = response["choices"][0]["finish_reason"]
else:
finish_reason = "unknown"
if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
finish_reason = "unknown"
if thought:
create_result = CreateResult(
content=content,
thought=thought,
usage=cast(RequestUsage, response["usage"]),
finish_reason=normalize_stop_reason(finish_reason), # type: ignore
cached=False,
)
else:
create_result = CreateResult(
content=content,
usage=cast(RequestUsage, response["usage"]),
finish_reason=normalize_stop_reason(finish_reason), # type: ignore
cached=False,
)
create_result = CreateResult(
content=content,
thought=thought,
usage=cast(RequestUsage, response["usage"]),
finish_reason=normalize_stop_reason(finish_reason), # type: ignore
cached=False,
)

# If we are running in the context of a handler we can get the agent_id
try:
@@ -312,13 +336,17 @@ def _extract_tool_arguments(self, response_text: str) -> str:

async def create_stream(
self,
messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
tools: Optional[Sequence[Tool | ToolSchema]] = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
if False: # Unreachable code to satisfy the return type.
yield ""
yield ""

# Implement abstract methods
def actual_usage(self) -> RequestUsage: