Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantize the prompt when it's longer than quantized_kv_start #105

Merged
merged 2 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
__pycache__/
.venv/
.venv/
.DS_Store
12 changes: 11 additions & 1 deletion mlx_engine/cache_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
trim_prompt_cache,
can_trim_prompt_cache,
)
from mlx_lm.utils import generation_stream
from mlx_lm.utils import generation_stream, maybe_quantize_kv_cache
import mlx.core as mx
import mlx.nn as nn
import sys
Expand All @@ -21,7 +21,11 @@ def __init__(
self,
model: nn.Module,
max_kv_size: Optional[int],
*,
verbose: bool = False,
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
):
"""
Initialize the CacheWrapper.
Expand All @@ -37,6 +41,11 @@ def __init__(
self.draft_model: Optional[nn.Module] = None
self.max_kv_size = max_kv_size
self.verbose = verbose
self.kv_cache_qtn_params = dict(
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
)

@staticmethod
def _find_common_prefix(
Expand Down Expand Up @@ -151,6 +160,7 @@ def _prefill(
current_chunk = remaining_tokens[:current_chunk_size]

model(current_chunk[None], cache=cache)
maybe_quantize_kv_cache(prompt_cache=cache, **self.kv_cache_qtn_params)
mx.eval([c.state for c in cache])

remaining_tokens = remaining_tokens[current_chunk_size:]
Expand Down
50 changes: 42 additions & 8 deletions mlx_engine/model_kit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, Optional
from typing import List, Optional, Tuple

from mlx_engine.logging import log_info, log_warn
import mlx_lm
Expand Down Expand Up @@ -58,8 +58,10 @@ def _full_model_init(
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
):
self._validate_kv_cache_quantization_params(
kv_bits, kv_group_size, quantized_kv_start
kv_bits, kv_group_size, quantized_kv_start = (
self._get_kv_cache_quantization_params(
kv_bits, kv_group_size, quantized_kv_start
)
)
if kv_bits and max_kv_size is not None:
# Quantized KV cache is only supported for non-rotating KV cache
Expand All @@ -73,7 +75,13 @@ def _full_model_init(
log_info(prefix="ModelKit", message=f"Loading model from {model_path}...")
self.model, self.tokenizer = mlx_lm.utils.load(self.model_path)
self.detokenizer = self.tokenizer.detokenizer
self.cache_wrapper = CacheWrapper(self.model, max_kv_size)
self.cache_wrapper = CacheWrapper(
self.model,
max_kv_size,
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
)
self.kv_bits = kv_bits
self.kv_group_size = kv_group_size
self.quantized_kv_start = quantized_kv_start
Expand All @@ -100,23 +108,49 @@ def __init__(
)

@staticmethod
def _validate_kv_cache_quantization_params(
def _get_kv_cache_quantization_params(
kv_bits: Optional[int],
kv_group_size: Optional[int],
quantized_kv_start: Optional[int],
):
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""
Validates and processes KV cache quantization parameters.

Args:
kv_bits: Number of bits for quantization. If None, disables quantization.
kv_group_size: Group size for quantization. Defaults to 64 if quantization enabled.
quantized_kv_start: Step to begin quantization. Defaults to 0 if quantization enabled.

Returns:
Tuple of (kv_bits, kv_group_size, quantized_kv_start), all None if quantization disabled.

Raises:
ValueError: If kv_bits is missing when other params are set, or if kv_bits or kv_group_size is not one of the allowed values.
"""
if any([kv_group_size, quantized_kv_start]) and kv_bits is None:
raise ValueError(
"Enabling KV Cache Quantization requires kv_bits to be set"
)

if kv_bits and kv_bits not in VALID_KV_BITS:
if kv_bits is None:
return None, None, None

# defaults taken from here:
# https://github.com/ml-explore/mlx-examples/blob/3d793ec/llms/mlx_lm/utils.py#L352-L353
if kv_group_size is None:
kv_group_size = 64
if quantized_kv_start is None:
quantized_kv_start = 0

if kv_bits not in VALID_KV_BITS:
raise ValueError(f"Invalid kv_bits value. Must be one of {VALID_KV_BITS}")
if kv_group_size and kv_group_size not in VALID_KV_GROUP_SIZE:
if kv_group_size not in VALID_KV_GROUP_SIZE:
raise ValueError(
f"Invalid kv_group_size value. Must be one of {VALID_KV_GROUP_SIZE}"
)

return kv_bits, kv_group_size, quantized_kv_start

def tokenize(self, prompt: str) -> List[int]:
ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
if type(ids) == int:
Expand Down