Commit 8bdc14a

review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 30, 2025
1 parent d27826d commit 8bdc14a
Showing 6 changed files with 13 additions and 23 deletions.
vllm/attention/layer.py: 2 additions & 2 deletions

@@ -44,7 +44,7 @@ def __init__(
         use_mla: bool = False,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
-        **kwargs,
+        **extra_impl_args,
     ) -> None:
         super().__init__()
         if per_layer_sliding_window is not None:
@@ -114,7 +114,7 @@ def __init__(
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap, attn_type,
-                             **kwargs)
+                             **extra_impl_args)
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
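The rename of `**kwargs` to `**extra_impl_args` above is cosmetic but self-documenting: whatever the generic `Attention` layer does not itself consume is forwarded verbatim to the backend implementation class. A minimal sketch of that pass-through pattern, using stand-in classes rather than vLLM's real ones:

```python
# Stand-in classes (not vLLM's actual API), illustrating how backend-specific
# keyword arguments flow through the generic attention layer untouched.
class DummyBackendImpl:
    def __init__(self, num_heads: int, head_size: int, **extra_impl_args):
        # Backend-only knobs (e.g. MLA projection ranks) land here.
        self.extra_impl_args = extra_impl_args

class Attention:
    def __init__(self, num_heads: int, head_size: int, **extra_impl_args):
        # Forward anything the wrapper does not understand, unmodified.
        self.impl = DummyBackendImpl(num_heads, head_size, **extra_impl_args)

attn = Attention(8, 64, q_lora_rank=1536)  # q_lora_rank: hypothetical extra arg
assert attn.impl.extra_impl_args == {"q_lora_rank": 1536}
```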
vllm/config.py: 3 additions & 7 deletions

@@ -165,7 +165,6 @@ class ModelConfig:
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
         generation_config: Configuration parameter file for generation.
-        disable_mla: Whether to disable MLA for DeepSeek models.
         override_generation_config: Override the generation config with the
             given config.
     """
@@ -227,7 +226,6 @@ def __init__(
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
         generation_config: Optional[str] = None,
-        disable_mla: bool = False,
         enable_sleep_mode: bool = False,
         override_generation_config: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -278,7 +276,6 @@ def __init__(
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
-        self.disable_mla = disable_mla
         self.enable_sleep_mode = enable_sleep_mode

         from vllm.platforms import current_platform
@@ -748,7 +745,7 @@ def is_deepseek_mla(self) -> bool:
     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
-            if self.should_use_mla:
+            if self.use_mla:
                 return self.hf_text_config.kv_lora_rank
             else:
                 qk_rope_head_dim = getattr(self.hf_text_config,
@@ -815,7 +812,7 @@ def get_total_num_kv_heads(self) -> int:

     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
-        if self.should_use_mla:
+        if self.use_mla:
             # When using MLA during decode it becomes MQA
             return 1

@@ -971,8 +968,7 @@ def is_cross_encoder(self) -> bool:

     @property
     def use_mla(self) -> bool:
-        use_mla = (self.is_deepseek_mla and not self.disable_mla
-                   and not envs.VLLM_MLA_DISABLE)
+        use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
         return use_mla

     def supported_runner_types(self) -> Set[RunnerType]:
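With the `disable_mla` constructor flag gone, `ModelConfig.use_mla` reduces to a single environment-variable kill switch. A standalone sketch of the resulting logic; the real property reads the typed `envs.VLLM_MLA_DISABLE` accessor, for which the direct `os.environ` lookup below is an assumed-equivalent stand-in:

```python
import os

def use_mla(is_deepseek_mla: bool) -> bool:
    # Mirrors ModelConfig.use_mla after this commit: MLA is used whenever the
    # model supports it and the kill switch is unset.
    mla_disabled = os.environ.get("VLLM_MLA_DISABLE", "0") == "1"
    return is_deepseek_mla and not mla_disabled

os.environ["VLLM_MLA_DISABLE"] = "1"
assert use_mla(True) is False  # the env var overrides model support
```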
vllm/engine/arg_utils.py: 0 additions & 5 deletions

@@ -100,7 +100,6 @@ class EngineArgs:
     kv_cache_dtype: str = 'auto'
     seed: int = 0
     max_model_len: Optional[int] = None
-    disable_mla: bool = False
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
@@ -932,9 +931,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=str,
             default="auto",
             help='The worker class to use for distributed execution.')
-        parser.add_argument('--disable-mla',
-                            action='store_true',
-                            help='Disable MLA for DeepSeek models.')
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
@@ -1015,7 +1011,6 @@ def create_model_config(self) -> ModelConfig:
             disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-            disable_mla=self.disable_mla,
             logits_processor_pattern=self.logits_processor_pattern,
             generation_config=self.generation_config,
             override_generation_config=self.override_generation_config,
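Since `--disable-mla` is removed here, `VLLM_MLA_DISABLE` becomes the only opt-out. A sketch of how a user would disable MLA after this commit; the model name is illustrative, and the variable must be set before the engine reads its environment:

```python
import os

# Set the kill switch first, then construct the engine as usual.
os.environ["VLLM_MLA_DISABLE"] = "1"

# from vllm import LLM
# llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # takes the non-MLA path
```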
vllm/model_executor/models/deepseek_v2.py: 1 addition & 1 deletion

@@ -488,7 +488,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
-        if model_config.should_use_mla:
+        if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
             attn_cls = DeepseekV2Attention
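Two details of the hunk above are easy to miss: `make_layers` passes each decoder layer a prefix ending in its index, and the attention implementation is chosen once, at construction time. A small sketch with stand-in classes covering both:

```python
class DeepseekV2Attention: ...     # stand-ins for the real vLLM classes
class DeepseekV2MLAAttention: ...

def build_attn(prefix: str, use_mla: bool):
    # A prefix like "model.layers.7" carries the layer index as its last
    # dotted component.
    layer_idx = int(prefix.split(sep='.')[-1])
    attn_cls = DeepseekV2MLAAttention if use_mla else DeepseekV2Attention
    return layer_idx, attn_cls

idx, cls = build_attn("model.layers.7", use_mla=True)
assert idx == 7 and cls is DeepseekV2MLAAttention
```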
vllm/worker/cache_engine.py: 6 additions & 7 deletions

@@ -52,13 +52,12 @@ def __init__(
         self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

         # Get attention backend.
-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            model_config.dtype,
-            cache_config.cache_dtype,
-            self.block_size,
-            model_config.is_attention_free,
-            use_mla=model_config.should_use_mla)
+        self.attn_backend = get_attn_backend(self.head_size,
+                                             model_config.dtype,
+                                             cache_config.cache_dtype,
+                                             self.block_size,
+                                             model_config.is_attention_free,
+                                             use_mla=model_config.use_mla)

         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(
vllm/worker/model_runner.py: 1 addition & 1 deletion

@@ -1066,7 +1066,7 @@ def __init__(
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-            use_mla=self.model_config.should_use_mla,
+            use_mla=self.model_config.use_mla,
         ) if needs_attn_backend else None
         if self.attn_backend:
             self.attn_state = self.attn_backend.get_state_cls()(
