Commit

Local vllm support
Sviatoslav Bilokin committed Mar 3, 2025
1 parent eef2c17 commit a074687
Showing 2 changed files with 138 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -53,6 +53,9 @@ mcp = [
mlx-lm = [
"mlx-lm"
]
vllm = [
"vllm"
]
openai = [
"openai>=1.58.1"
]
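
This registers a new `vllm` extra, so the backend's dependency can be installed with `pip install 'smolagents[vllm]'`, matching the install hint in the class docstring below.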
135 changes: 135 additions & 0 deletions src/smolagents/models.py
@@ -587,6 +587,141 @@ def __call__(
return self._to_message(text, tools_to_call_from)


class VLLMModel(Model):
"""A class to interact with local vLLM provider.
> [!TIP]
> You must have `vllm` installed on your machine. Please run `pip install smolagents[vllm]` if it's not the case.
Parameters:
model_id (str):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
trust_remote_code (bool):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
init_kwargs (dict, *optional):
Keyword arguments that are passed to vllm.LLM object initialization.
sampling_kwargs (dict, *optional):
Keyword arguments that are passed to vllm.SamplingParams object initialization.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate().
Example:
```python
>>> engine = VLLMModel(
... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    ...     sampling_kwargs={"max_tokens": 10000},
... )
>>> messages = [
... {
... "role": "user",
... "content": [
... {"type": "text", "text": "Explain quantum mechanics in simple terms."}
... ]
... }
... ]
>>> response = engine(messages)
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
"""

    def __init__(
        self,
        model_id: Optional[str] = None,
        sampling_kwargs: Optional[dict] = None,
        init_kwargs: Optional[dict] = None,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        if model_id is None:
            model_id = default_model_id
            logger.warning(f"`model_id` not provided, using this default model: '{model_id}'")
        self.model_id = model_id
        if not _is_package_available("vllm"):
            raise ModuleNotFoundError(
                "Please install 'vllm' extra to use 'VLLMModel': `pip install 'smolagents[vllm]'`"
            )
from vllm import LLM, SamplingParams

        if not init_kwargs:
            init_kwargs = {}
        init_kwargs.update({"trust_remote_code": trust_remote_code})
        if not sampling_kwargs:
            sampling_kwargs = {}
        default_max_tokens = 5000
        # vllm.SamplingParams expects `max_tokens`; accept `max_new_tokens` as an alias.
        max_new_tokens = sampling_kwargs.pop("max_new_tokens", None) or sampling_kwargs.get("max_tokens")
        if not max_new_tokens:
            max_new_tokens = default_max_tokens
            logger.warning(
                f"`max_new_tokens` not provided, using this default value for `max_new_tokens`: {default_max_tokens}"
            )
        sampling_kwargs["max_tokens"] = max_new_tokens
        self.kwargs = kwargs
        self.sampling_kwargs = sampling_kwargs
        self.sampling_params = SamplingParams(**sampling_kwargs)
        self.model = LLM(model=model_id, **init_kwargs)
        self._is_vlm = False

def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
tools_to_call_from: Optional[List[Tool]] = None,
**kwargs,
) -> ChatMessage:
        max_new_tokens = (
            kwargs.get("max_new_tokens")
            or kwargs.get("max_tokens")
            or self.kwargs.get("max_new_tokens")
            or self.kwargs.get("max_tokens")
        )
        completion_kwargs = {}
        sampling_params = self.sampling_params
        if max_new_tokens:
            completion_kwargs["max_new_tokens"] = max_new_tokens
            # Rebuild the sampling params so a per-call token limit actually takes effect.
            from vllm import SamplingParams

            sampling_params = SamplingParams(**{**self.sampling_kwargs, "max_tokens": max_new_tokens})
        # vLLM applies the model's chat template and generates; take the last request's last output.
        out = self.model.chat(messages, sampling_params=sampling_params, use_tqdm=False)
output = out[-1].outputs[-1].text
if stop_sequences is not None:
output = remove_stop_sequences(output, stop_sequences)
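        # Lazy import: torch ships as a dependency of vllm, so it is imported here
        # rather than required at module level; it is only needed for the raw token ids.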
import torch

raw = {"output": torch.tensor(out[-1].outputs[-1].token_ids), "completion_kwargs": completion_kwargs}

if tools_to_call_from is None:
return ChatMessage(
role="assistant",
content=output,
raw=raw,
)
else:
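            # Tool calls are expected in the output as an "Action:" marker followed by a
            # JSON blob of the form {"name": ..., "arguments": ...}; extract that blob.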
if "Action:" in output:
output = output.split("Action:", 1)[1].strip()
try:
start_index = output.index("{")
end_index = output.rindex("}")
output = output[start_index : end_index + 1]
            except Exception as e:
                raise ValueError("No JSON blob found in output!") from e

try:
parsed_output = json.loads(output)
except json.JSONDecodeError as e:
raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
tool_name = parsed_output.get("name")
tool_arguments = parsed_output.get("arguments")
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
)
],
raw=raw,
)


class TransformersModel(Model):
"""A class that uses Hugging Face's Transformers library for language model interaction.
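For reference, here is a minimal sketch (not part of this commit) of exercising the tool-calling branch of `VLLMModel.__call__` directly. `GetWeatherTool` is a hypothetical tool; the `Tool` base-class API (name/description/inputs/output_type/forward) is assumed from the rest of the smolagents library.

```python
from smolagents import Tool
from smolagents.models import VLLMModel


class GetWeatherTool(Tool):
    name = "get_weather"
    description = "Returns the current weather for a city."
    inputs = {"city": {"type": "string", "description": "Name of the city"}}
    output_type = "string"

    def forward(self, city: str) -> str:
        return f"Sunny in {city}"  # stub implementation


engine = VLLMModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")
messages = [
    {"role": "user", "content": [{"type": "text", "text": "What is the weather in Paris?"}]}
]
# With tools passed, the reply is parsed for an "Action:" JSON blob and returned
# as a ChatMessage carrying tool_calls instead of plain content.
response = engine(messages, tools_to_call_from=[GetWeatherTool()])
if response.tool_calls:
    call = response.tool_calls[0].function
    print(call.name, call.arguments)  # e.g. get_weather {'city': 'Paris'}
```

In normal use a `ToolCallingAgent` builds the surrounding prompt so the model actually emits the `Action:` blob; calling the model directly as above mainly demonstrates the parsing path.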
