Add a __call__ method to models
RobinPicard committed Jan 20, 2025
1 parent b352c62 commit 685aa01
Showing 6 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -9,6 +9,7 @@
 from typing import Union
 
 from .anthropic import Anthropic
+from .base import Model
 from .exllamav2 import ExLlamaV2Model, exl2
 from .gemini import Gemini
 from .llamacpp import LlamaCpp
3 changes: 2 additions & 1 deletion outlines/models/anthropic.py
@@ -2,6 +2,7 @@
 from functools import singledispatchmethod
 from typing import Union
 
+from outlines.models import Model
 from outlines.prompts import Vision
 
 __all__ = ["Anthropic"]
@@ -63,7 +64,7 @@ def format_vision_input(self, model_input: Vision):
         }
 
 
-class Anthropic(AnthropicBase):
+class Anthropic(Model, AnthropicBase):
     def __init__(self, model_name: str, *args, **kwargs):
         from anthropic import Anthropic
 
10 changes: 10 additions & 0 deletions outlines/models/base.py
@@ -0,0 +1,10 @@
+from abc import ABC
+
+
+class Model(ABC):
+    """Base class for all models."""
+
+    def __call__(self, model_input, output_type=None, **inference_kwargs):
+        from outlines.generate import Generator
+
+        return Generator(self, output_type)(model_input, **inference_kwargs)
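Note: the new base class gives every model a uniform entry point. Calling a model instance builds a Generator on the fly and forwards the prompt and any inference keyword arguments to it. A minimal usage sketch under that assumption follows; the model name, prompt, and keyword argument are illustrative, not taken from this commit.

from outlines.models import OpenAI

model = OpenAI("gpt-4o-mini")  # hypothetical model name, for illustration only

# Before this change, callers had to build a Generator explicitly:
#   from outlines.generate import Generator
#   result = Generator(model, None)("Write a haiku about the sea.")

# With Model.__call__, the instance is callable directly; output_type defaults
# to None and extra keyword arguments are forwarded to the inference call.
result = model("Write a haiku about the sea.", max_tokens=64)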
3 changes: 2 additions & 1 deletion outlines/models/gemini.py
@@ -7,6 +7,7 @@
 from pydantic import BaseModel
 from typing_extensions import _TypedDictMeta  # type: ignore
 
+from outlines.models import Model
 from outlines.prompts import Vision
 from outlines.types import Choice, Json, List
 
@@ -91,7 +92,7 @@ def format_enum_output_type(self, output_type):
         }
 
 
-class Gemini(GeminiBase):
+class Gemini(Model, GeminiBase):
     def __init__(self, model_name: str, *args, **kwargs):
         import google.generativeai as genai
 
3 changes: 2 additions & 1 deletion outlines/models/llamacpp.py
@@ -2,6 +2,7 @@
 import warnings
 from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union
 
+from outlines.models import Model
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -93,7 +94,7 @@ def __setstate__(self, state):
         raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
 
 
-class LlamaCpp:
+class LlamaCpp(Model):
     """Wraps a model provided by the `llama-cpp-python` library."""
 
     def __init__(self, model_path: Union[str, "Llama"], **kwargs):
3 changes: 2 additions & 1 deletion outlines/models/openai.py
@@ -5,6 +5,7 @@
 
 from pydantic import BaseModel
 
+from outlines.models import Model
 from outlines.prompts import Vision
 from outlines.types import Json
 
@@ -114,7 +115,7 @@ def format_json_output_type(self, output_type: Json):
         }
 
 
-class OpenAI(OpenAIBase):
+class OpenAI(Model, OpenAIBase):
     """Thin wrapper around the `openai.OpenAI` client.
 
     This wrapper is used to convert the input and output types specified by the
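Because Anthropic, Gemini, LlamaCpp, and OpenAI all inherit from Model after this commit, the same call pattern applies to every backend. A sketch, with constructor arguments assumed for illustration only:

from outlines.models import Anthropic, LlamaCpp

claude = Anthropic("claude-3-5-sonnet-latest")   # hypothetical model name
local = LlamaCpp("path/to/model.gguf")           # hypothetical local weights path

# Each instance can be invoked directly; inference kwargs are passed through.
print(claude("Name one prime number greater than 10.", max_tokens=16))
print(local("Name one prime number greater than 10.", max_tokens=16))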
