Skip to content

Commit

Permalink
LiteLLMModel - detect message flatenning based on model information (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sysradium authored Feb 13, 2025
1 parent 41a388d commit 392fc5a
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from huggingface_hub import InferenceClient
Expand Down Expand Up @@ -799,6 +800,16 @@ def __init__(
self.api_key = api_key
self.custom_role_conversions = custom_role_conversions

@cached_property
def _flatten_messages_as_text(self):
import litellm

model_info: dict = litellm.get_model_info(self.model_id)
if model_info["litellm_provider"] == "ollama":
return model_info["key"] != "llava"

return False

def __call__(
self,
messages: List[Dict[str, str]],
Expand All @@ -818,7 +829,7 @@ def __call__(
api_base=self.api_base,
api_key=self.api_key,
convert_images_to_image_urls=True,
flatten_messages_as_text=self.model_id.startswith("ollama"),
flatten_messages_as_text=self._flatten_messages_as_text,
custom_role_conversions=self.custom_role_conversions,
**kwargs,
)
Expand Down

0 comments on commit 392fc5a

Please sign in to comment.