Skip to content

Commit

Permalink
Fix load_model in webagent CLI (#855)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova authored Mar 4, 2025
1 parent e2a4690 commit 3a14d0d
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ jobs:
# run: |
# uv run pytest ./tests/test_all_docs.py

- name: CLI tests
run: |
uv run pytest ./tests/test_cli.py
if: ${{ success() || failure() }}

- name: Final answer tests
run: |
uv run pytest ./tests/test_final_answer.py
Expand Down
2 changes: 1 addition & 1 deletion src/smolagents/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def parse_arguments(description):
return parser.parse_args()


def load_model(model_type: str, model_id: str, api_base: str | None, api_key: str | None) -> Model:
def load_model(model_type: str, model_id: str, api_base: str | None = None, api_key: str | None = None) -> Model:
if model_type == "OpenAIServerModel":
return OpenAIServerModel(
api_key=api_key or os.getenv("FIREWORKS_API_KEY"),
Expand Down
3 changes: 2 additions & 1 deletion src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from huggingface_hub import InferenceClient
from huggingface_hub.utils import is_torch_available
from PIL import Image

Expand Down Expand Up @@ -429,6 +428,8 @@ def __init__(
custom_role_conversions: Optional[Dict[str, str]] = None,
**kwargs,
):
from huggingface_hub import InferenceClient

super().__init__(**kwargs)
self.model_id = model_id
self.provider = provider
Expand Down
54 changes: 54 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from unittest.mock import patch

import pytest

from smolagents.cli import load_model
from smolagents.models import HfApiModel, LiteLLMModel, OpenAIServerModel, TransformersModel


@pytest.fixture
def set_env_vars(monkeypatch):
    """Populate the API-key environment variables that load_model falls back to."""
    # Fake credentials: the tests only verify these values get forwarded.
    for name, value in [
        ("FIREWORKS_API_KEY", "test_fireworks_api_key"),
        ("HF_TOKEN", "test_hf_api_key"),
    ]:
        monkeypatch.setenv(name, value)


def test_load_model_openai_server_model(set_env_vars):
    """load_model("OpenAIServerModel", ...) builds a client aimed at the Fireworks endpoint."""
    with patch("openai.OpenAI") as mock_openai:
        model = load_model("OpenAIServerModel", "test_model_id")
    assert isinstance(model, OpenAIServerModel)
    assert model.model_id == "test_model_id"
    # Exactly one client is created, keyed by the env var from the fixture.
    assert mock_openai.call_count == 1
    client_kwargs = mock_openai.call_args.kwargs
    assert client_kwargs["base_url"] == "https://api.fireworks.ai/inference/v1"
    assert client_kwargs["api_key"] == "test_fireworks_api_key"


def test_load_model_litellm_model():
    """Explicit api_key/api_base keyword arguments are forwarded to LiteLLMModel."""
    model = load_model("LiteLLMModel", "test_model_id", api_key="test_api_key", api_base="https://api.test.com")
    assert isinstance(model, LiteLLMModel)
    assert model.model_id == "test_model_id"
    assert model.api_key == "test_api_key"
    assert model.api_base == "https://api.test.com"


def test_load_model_transformers_model():
    """load_model can build a TransformersModel without downloading weights."""
    # Stub the heavy from_pretrained calls so no model/tokenizer is fetched.
    with patch("transformers.AutoModelForCausalLM.from_pretrained"), patch(
        "transformers.AutoTokenizer.from_pretrained"
    ):
        model = load_model("TransformersModel", "test_model_id")
    assert isinstance(model, TransformersModel)
    assert model.model_id == "test_model_id"


def test_load_model_hf_api_model(set_env_vars):
    """load_model("HfApiModel", ...) creates one InferenceClient using HF_TOKEN."""
    with patch("huggingface_hub.InferenceClient") as mock_inference_client:
        model = load_model("HfApiModel", "test_model_id")
    assert isinstance(model, HfApiModel)
    assert model.model_id == "test_model_id"
    # The token must come from the environment set up by the fixture.
    assert mock_inference_client.call_count == 1
    assert mock_inference_client.call_args.kwargs["token"] == "test_hf_api_key"


def test_load_model_invalid_model_type():
    """An unrecognized model type must raise ValueError naming the offender."""
    expected_message = "Unsupported model type: InvalidModel"
    with pytest.raises(ValueError, match=expected_message):
        load_model("InvalidModel", "test_model_id")

0 comments on commit 3a14d0d

Please sign in to comment.