Skip to content

Commit

Permalink
Test HfApiModel call with custom_role_conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova committed Feb 6, 2025
1 parent 6cd5b6a commit b547f50
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import unittest
from pathlib import Path
from typing import Optional
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from transformers.testing_utils import get_tests_dir

from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
from smolagents.models import get_clean_message_list, parse_json_if_needed
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed


class ModelTests(unittest.TestCase):
Expand Down Expand Up @@ -103,6 +103,19 @@ def test_parse_json_if_needed(self):
assert parsed_args == 3


class TestHfApiModel:
def test_call_with_custom_role_conversions(self):
custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM}
model = HfApiModel(model_id="test-model", custom_role_conversions=custom_role_conversions)
model.client = MagicMock()
messages = [{"role": "user", "content": "Test message"}]
_ = model(messages)
# Verify that the role conversion was applied
assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", (
"role conversion should be applied"
)


def test_get_clean_message_list_basic():
messages = [
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
Expand Down

0 comments on commit b547f50

Please sign in to comment.