diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 28de2eefa..c22c89359 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -21,6 +21,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from enum import Enum +from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from huggingface_hub import InferenceClient @@ -676,6 +677,16 @@ def __init__( self.api_key = api_key self.custom_role_conversions = custom_role_conversions + @cached_property + def _flatten_messages_as_text(self): + import litellm + + model_info: dict = litellm.get_model_info(self.model_id) + if model_info["litellm_provider"] == "ollama": + return model_info["key"] != "llava" + + return False + def __call__( self, messages: List[Dict[str, str]], @@ -695,7 +706,7 @@ def __call__( api_base=self.api_base, api_key=self.api_key, convert_images_to_image_urls=True, - flatten_messages_as_text=self.model_id.startswith("ollama"), + flatten_messages_as_text=self._flatten_messages_as_text, custom_role_conversions=self.custom_role_conversions, **kwargs, )