Skip to content

Commit

Permalink
introduce outlines.models.mlxlm
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Jun 11, 2024
1 parent ed44a47 commit e2d8a5c
Show file tree
Hide file tree
Showing 13 changed files with 645 additions and 4 deletions.
32 changes: 32 additions & 0 deletions docs/reference/models/mlxlm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# mlx-lm

Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx-examples/tree/main/llms), allowing models to be run quickly on Apple Silicon via the [mlx](https://ml-explore.github.io/mlx/build/html/index.html) library.

## Installation

In addition to `outlines`, you must install `mlx-lm` and `mlx` libraries. You must use a device which [supports Metal](https://support.apple.com/en-us/102894).

## Using `models.mlxlm`

```python
from outlines import models

model = models.mlxlm("mlx-community/mlx-community/Meta-Llama-3-8B-Instruct-8bit")
```

With the loaded model, you can generate text or perform structured generation, e.g.

```python3
from outlines import models, generate

model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit")

phone_number_pattern = "\\+?[1-9][0-9]{7,14}"
generator = generate.regex(model, phone_number_pattern)

model_output = generator("What's Jennys Number?\n")
print(model_output)
# '8675309'
```

For more examples, see the [cookbook](cookbook/index.md).
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ nav:
- vLLM: reference/models/vllm.md
- Llama.cpp: reference/models/llamacpp.md
- Transformers: reference/models/transformers.md
- MLX: reference/models/mlxlm.md
- ExllamaV2: reference/models/exllamav2.md
- Mamba: reference/models/mamba.md
- OpenAI: reference/models/openai.md
Expand Down
8 changes: 5 additions & 3 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial

Expand Down Expand Up @@ -33,14 +34,15 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
return generator


@cfg.register(MLXLM)
@cfg.register(VLLM)
def cfg_vllm(
model: VLLM,
def cfg_unimplemented(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
):
raise NotImplementedError(
"The CFG Logits processor is not available for the vLLM integration."
f"The CFG Logits processor is not available for {type(model)}."
)


Expand Down
13 changes: 13 additions & 0 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp
from outlines.models.mlxlm import MLXLM
from outlines.models.vllm import VLLM
from outlines.samplers import Sampler, multinomial

Expand Down Expand Up @@ -37,6 +38,18 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
return generator


@regex.register(MLXLM)
def regex_mlxlm(
model: MLXLM,
regex_str: str,
sampler: Sampler = multinomial(),
):
from outlines.processors import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)


@regex.register(LlamaCpp)
def regex_llamacpp(
model: LlamaCpp,
Expand Down
7 changes: 6 additions & 1 deletion outlines/generate/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from outlines.fsm.guide import StopAtEOSGuide
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import VLLM, LlamaCpp, OpenAI
from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -36,6 +36,11 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
return generator


@text.register(MLXLM)
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)


@text.register(VLLM)
def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
Expand Down
1 change: 1 addition & 0 deletions outlines/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .exllamav2 import ExLlamaV2Model, exl2
from .llamacpp import LlamaCpp, llamacpp
from .mamba import Mamba, mamba
from .mlxlm import MLXLM, mlxlm
from .openai import OpenAI, azure_openai, openai
from .transformers import Transformers, transformers
from .vllm import VLLM, vllm
Expand Down
240 changes: 240 additions & 0 deletions outlines/models/mlxlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import dataclasses
from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.processors import BaseLogitsProcessor

try:
import mlx.core as mx
import mlx_lm
except ImportError:
pass


class MLXLM:
"""
Represents an `mlx_lm` model
"""

def __init__(
self,
model: "nn.Module",
tokenizer: "PreTrainedTokenizer",
):
self.model = model
self.mlx_tokenizer = tokenizer # returns mlx tensors, used for encode()
self.tokenizer = TransformerTokenizer(
tokenizer._tokenizer
) # _tokenizer is HF Tokenizer

def generate(
self,
prompts: Union[str, List[str]],
generation_parameters: "GenerationParameters",
logits_processor,
sampling_parameters: "SamplingParameters",
) -> str:
streamer = self.stream(
prompts, generation_parameters, logits_processor, sampling_parameters
)
return "".join(list(streamer))

def stream(
self,
prompts: Union[str, List[str]],
generation_parameters: "GenerationParameters",
logits_processor,
sampling_parameters: "SamplingParameters",
) -> Iterator[str]:
"""Generate text using `mlx_lm`.
Arguments
---------
prompts
A prompt or list of prompts.
generation_parameters
An instance of `GenerationParameters` that contains the prompt,
the maximum number of tokens, stop sequences and seed. All the
arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
logits_processor
The logits processor to use when generating text.
sampling_parameters
An instance of `SamplingParameters`, a dataclass that contains
the name of the sampler to use and related parameters as available
in Outlines.
Returns
-------
The generated text.
"""
max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
sampling_parameters
)
if max_tokens is None:
max_tokens = int(1e9)

if not isinstance(prompts, str):
raise NotImplementedError(
"The `mlx-lm` library does not support batch inference."
)
if sampler == "beam_search":
raise NotImplementedError(
"The `mlx-lm` library does not support Beam Search."
)
if num_samples != 1:
raise NotImplementedError(
"The `mlx-lm` library does not allow to take several samples."
)
if top_k is not None:
raise NotImplementedError("The `mlx-lm` library does not support top_k.")
if seed is not None:
raise NotImplementedError("The `mlx-lm` library does not support seed.")
if stop_at is not None:
raise NotImplementedError("The `mlx-lm` library does not support stop_at.")

generate_kwargs = {
"temp": temperature,
"top_p": top_p,
"sampler": sampler,
"logits_processor": logits_processor,
}

# Adapted from
# https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))

for (token, prob), n in zip(
self.generate_step(prompt_tokens, **generate_kwargs),
range(max_tokens),
):
if token == self.tokenizer.eos_token_id:
break
yield self.tokenizer.decode([token])[0]

def generate_step(
self,
prompt: "mx.array",
temp: Optional[float],
top_p: Optional[float],
sampler: str,
logits_processor: "BaseLogitsProcessor",
) -> Generator[Tuple[int, float], None, None]:
"""
Adapted from
https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
sampler (str): The sampler string defined by SequenceGeneratorAdapter
logits_processor (BaseLogitsProcessor): Augment logits before sampling.
"""
temperature: float = temp or 1.0

def sample(logits: "mx.array") -> Tuple["mx.array", float]:
softmax_logits = mx.softmax(logits)

if temperature == 0.0 or sampler == "greedy":
token = mx.argmax(logits, axis=-1)
elif sampler == "multinomial":
if top_p is not None and top_p > 0 and top_p < 1.0:
token = mlx_lm.sample_utils.top_p_sampling(
logits, top_p, temperature
)
else:
token = mx.random.categorical(logits * (1 / temperature))
else:
raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")

prob = softmax_logits[0, token]
return token, prob

kv_heads = (
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]

# kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
unprocessed_input_ids = prompt
generated_ids: List[int] = []

while True:
logits = self.model(unprocessed_input_ids[None], cache=cache)
logits = logits[:, -1, :]

if logits_processor is not None:
# convert to logits_processor 1d expectation, apply, then convert back
logits_1d = logits.reshape(-1)
logits_1d = logits_processor(generated_ids, logits_1d)
logits = logits_1d.reshape(1, -1)

new_token_single, prob = sample(logits)
new_token = new_token_single.item()
yield new_token, prob

generated_ids.append(new_token)
unprocessed_input_ids = new_token_single


def mlxlm(
model_name: str,
tokenizer_config: dict = {},
model_config: dict = {},
adapter_path: Optional[str] = None,
lazy: bool = False,
):
"""Instantiate a model from the `mlx_lm` library and its tokenizer.
Signature adapted from
https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422
Parameters
----------
Args:
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config(dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
Returns
-------
A `MLXLM` model instance.
"""
try:
import mlx.core as mx
import mlx_lm
except ImportError:
raise ImportError(
"The `mlx_lm` library needs to be installed in order to use `mlx_lm` models."
)
if not mx.metal.is_available():
raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)")

model, tokenizer = mlx_lm.load(
model_name,
tokenizer_config=tokenizer_config,
model_config=model_config,
adapter_path=adapter_path,
lazy=lazy,
)
return MLXLM(model, tokenizer)
7 changes: 7 additions & 0 deletions outlines/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .structured import (
BaseLogitsProcessor,
CFGLogitsProcessor,
FSMLogitsProcessor,
JSONLogitsProcessor,
RegexLogitsProcessor,
)
Loading

0 comments on commit e2d8a5c

Please sign in to comment.