diff --git a/.gitignore b/.gitignore
index cafd598..2ea307f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 __pycache__/
-.venv/
\ No newline at end of file
+.venv/
+.DS_Store
diff --git a/mlx_engine/cache_wrapper.py b/mlx_engine/cache_wrapper.py
index 09684a0..9c0cde3 100644
--- a/mlx_engine/cache_wrapper.py
+++ b/mlx_engine/cache_wrapper.py
@@ -6,7 +6,7 @@
     trim_prompt_cache,
     can_trim_prompt_cache,
 )
-from mlx_lm.utils import generation_stream
+from mlx_lm.utils import generation_stream, maybe_quantize_kv_cache
 import mlx.core as mx
 import mlx.nn as nn
 import sys
@@ -21,7 +21,11 @@ def __init__(
         self,
         model: nn.Module,
         max_kv_size: Optional[int],
+        *,
         verbose: bool = False,
+        kv_bits: Optional[int] = None,
+        kv_group_size: Optional[int] = None,
+        quantized_kv_start: Optional[int] = None,
     ):
         """
         Initialize the CacheWrapper.
@@ -37,6 +41,11 @@ def __init__(
         self.draft_model: Optional[nn.Module] = None
         self.max_kv_size = max_kv_size
         self.verbose = verbose
+        self.kv_cache_qtn_params = dict(
+            kv_bits=kv_bits,
+            kv_group_size=kv_group_size,
+            quantized_kv_start=quantized_kv_start,
+        )

     @staticmethod
     def _find_common_prefix(
@@ -151,6 +160,7 @@ def _prefill(
             current_chunk = remaining_tokens[:current_chunk_size]

             model(current_chunk[None], cache=cache)
+            maybe_quantize_kv_cache(prompt_cache=cache, **self.kv_cache_qtn_params)
             mx.eval([c.state for c in cache])

             remaining_tokens = remaining_tokens[current_chunk_size:]
diff --git a/mlx_engine/model_kit.py b/mlx_engine/model_kit.py
index 5bf115f..d556c1f 100644
--- a/mlx_engine/model_kit.py
+++ b/mlx_engine/model_kit.py
@@ -1,5 +1,5 @@
 import sys
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from mlx_engine.logging import log_info, log_warn
 import mlx_lm
@@ -58,8 +58,10 @@ def _full_model_init(
         kv_group_size: Optional[int] = None,
         quantized_kv_start: Optional[int] = None,
     ):
-        self._validate_kv_cache_quantization_params(
-            kv_bits, kv_group_size, quantized_kv_start
+        kv_bits, kv_group_size, quantized_kv_start = (
+            self._get_kv_cache_quantization_params(
+                kv_bits, kv_group_size, quantized_kv_start
+            )
         )
         if kv_bits and max_kv_size is not None:
             # Quantized KV cache is only supported for non-rotating KV cache
@@ -73,7 +75,13 @@ def _full_model_init(
         log_info(prefix="ModelKit", message=f"Loading model from {model_path}...")
         self.model, self.tokenizer = mlx_lm.utils.load(self.model_path)
         self.detokenizer = self.tokenizer.detokenizer
-        self.cache_wrapper = CacheWrapper(self.model, max_kv_size)
+        self.cache_wrapper = CacheWrapper(
+            self.model,
+            max_kv_size,
+            kv_bits=kv_bits,
+            kv_group_size=kv_group_size,
+            quantized_kv_start=quantized_kv_start,
+        )
         self.kv_bits = kv_bits
         self.kv_group_size = kv_group_size
         self.quantized_kv_start = quantized_kv_start
@@ -100,23 +108,49 @@ def __init__(
         )

     @staticmethod
-    def _validate_kv_cache_quantization_params(
+    def _get_kv_cache_quantization_params(
         kv_bits: Optional[int],
         kv_group_size: Optional[int],
         quantized_kv_start: Optional[int],
-    ):
+    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+        """
+        Validates and processes KV cache quantization parameters.
+
+        Args:
+            kv_bits: Number of bits for quantization. If None, disables quantization.
+            kv_group_size: Group size for quantization. Defaults to 64 if quantization enabled.
+            quantized_kv_start: Step to begin quantization. Defaults to 0 if quantization enabled.
+
+        Returns:
+            Tuple of (kv_bits, kv_group_size, quantized_kv_start), all None if quantization disabled.
+
+        Raises:
+            ValueError: If kv_bits is invalid or missing when other params are set.
+        """
         if any([kv_group_size, quantized_kv_start]) and kv_bits is None:
             raise ValueError(
                 "Enabling KV Cache Quantization requires kv_bits to be set"
             )
-        if kv_bits and kv_bits not in VALID_KV_BITS:
+        if kv_bits is None:
+            return None, None, None
+
+        # defaults taken from here:
+        # https://github.com/ml-explore/mlx-examples/blob/3d793ec/llms/mlx_lm/utils.py#L352-L353
+        if kv_group_size is None:
+            kv_group_size = 64
+        if quantized_kv_start is None:
+            quantized_kv_start = 0
+
+        if kv_bits not in VALID_KV_BITS:
             raise ValueError(f"Invalid kv_bits value. Must be one of {VALID_KV_BITS}")
-        if kv_group_size and kv_group_size not in VALID_KV_GROUP_SIZE:
+        if kv_group_size not in VALID_KV_GROUP_SIZE:
             raise ValueError(
                 f"Invalid kv_group_size value. Must be one of {VALID_KV_GROUP_SIZE}"
            )
+        return kv_bits, kv_group_size, quantized_kv_start
+
     def tokenize(self, prompt: str) -> List[int]:
         ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt))
         if type(ids) == int:
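
For reference, a minimal sketch of how the reworked helper behaves end to end, based only on the code in this diff. It assumes mlx_engine is importable and that 4 and 64 are members of VALID_KV_BITS and VALID_KV_GROUP_SIZE respectively (64 must be, since it is the default the method itself assigns):

    from mlx_engine.model_kit import ModelKit

    # Quantization disabled: every parameter comes back as None.
    assert ModelKit._get_kv_cache_quantization_params(None, None, None) == (None, None, None)

    # Only kv_bits given: group size and start step pick up the mlx_lm
    # defaults (64 and 0) applied inside the method.
    assert ModelKit._get_kv_cache_quantization_params(4, None, None) == (4, 64, 0)

    # kv_group_size without kv_bits is rejected up front.
    try:
        ModelKit._get_kv_cache_quantization_params(None, 64, None)
    except ValueError as e:
        print(e)  # "Enabling KV Cache Quantization requires kv_bits to be set"

Returning the normalized tuple rather than merely validating means _full_model_init can forward concrete values into CacheWrapper, so the maybe_quantize_kv_cache call in the prefill loop always receives fully specified quantization parameters instead of re-deriving defaults.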