fix mypy error
Signed-off-by: kewang2 <kewang2@amd.com>
kewang2 committed Dec 17, 2024
1 parent e556196 commit 6c45013
Showing 16 changed files with 37 additions and 35 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/base_config.py
@@ -134,5 +134,5 @@ def get_quant_method(self, layer: torch.nn.Module,
"""
raise NotImplementedError

-def get_cache_scale(self, name: str) -> Optional[List[str]]:
+def get_cache_scale(self, name: str) -> Optional[str]:
return None
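
The corrected annotation matches how loaders consume this hook: the returned value is a single remapped parameter name used directly as a dictionary key. A minimal sketch of that calling pattern, assuming a caller shaped like the load_weights hunks below (the helper name load_scale is hypothetical, not code from this commit):

from typing import Dict, Optional

def load_scale(scale_name: Optional[str],
               params_dict: Dict[str, object]) -> Optional[object]:
    # scale_name is what get_cache_scale returns: one remapped param name
    # or None. Indexing params_dict with it only type-checks under
    # Optional[str]; the old Optional[List[str]] annotation is what mypy
    # rejected at call sites like this.
    if scale_name is None:
        return None
    return params_dict[scale_name]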
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -336,7 +336,7 @@ def get_scheme(
# utils.py to instance method of CompressedTensorsConfig
# class. By doing this, different QuantizationConfig
# classes can implement their own get_cache_scale method.
-def get_cache_scale(self, name: str) -> Optional[List[str]]:
+def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -79,10 +79,10 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))

# In the export model of quark, the quantization configuration
# of kv_cache is stored in layer_quant_config. First, it is
# judged whether kv_cache_group exists, and then it is judged
# whether layer_quant_config has a quantization configuration
# that matches kv_cache.
if len(kv_cache_group) == 0:
kv_cache_config = None
@@ -277,7 +277,7 @@ def get_scheme(self, layer: torch.nn.Module,

return scheme

-def get_cache_scale(self, name: str) -> Optional[List[str]]:
+def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in quark. If this is the case, return its equivalent param name
@@ -300,9 +300,11 @@ def get_cache_scale(self, name: str) -> Optional[List[str]]:
elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale", ".attn.k_scale")
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale", ".attn.v_scale")
return name.replace(".v_proj.output_scale",
".attn.v_scale")

# If no matches, return None
return None
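For reference, the remapping this override performs, as a hedged doctest sketch (quark_config stands in for a constructed QuarkConfig; the layer path is hypothetical, while the suffix mapping comes from the replace calls above):

>>> quark_config.get_cache_scale("model.layers.0.self_attn.k_proj.output_scale")
'model.layers.0.self_attn.attn.k_scale'
>>> quark_config.get_cache_scale("model.layers.0.self_attn.v_proj.output_scale")
'model.layers.0.self_attn.attn.v_scale'
>>> quark_config.get_cache_scale("model.layers.0.mlp.gate_proj.weight") is None
True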
4 changes: 2 additions & 2 deletions vllm/model_executor/models/aria.py
@@ -395,8 +395,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
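The rewrapped conditional in this hunk (and in the identical hunks below) only moves the line break; behavior is unchanged. What the expression does: checkpoint k/v scales may arrive either as a 0-dim scalar tensor or as a 1-element vector, and the conditional normalizes both to a scalar. A standalone sketch (the function name unwrap_scale is illustrative):

import torch

def unwrap_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
    # A 0-dim tensor is already a scalar; a 1-element vector is unwrapped.
    return loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]

assert unwrap_scale(torch.tensor(0.5)).dim() == 0
assert unwrap_scale(torch.tensor([0.5])).dim() == 0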
4 changes: 2 additions & 2 deletions vllm/model_executor/models/commandr.py
@@ -442,8 +442,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/dbrx.py
@@ -457,8 +457,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/exaone.py
@@ -546,8 +546,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/gpt_j.py
@@ -321,8 +321,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/granite.py
@@ -488,8 +488,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -396,8 +396,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/mixtral.py
@@ -436,8 +436,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/mllama.py
@@ -1438,8 +1438,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/nemotron.py
@@ -503,8 +503,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phimoe.py
@@ -631,8 +631,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -375,8 +375,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
4 changes: 2 additions & 2 deletions vllm/model_executor/models/solar.py
@@ -505,8 +505,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
-loaded_weight = (loaded_weight if loaded_weight.dim() == 0
-                 else loaded_weight[0])
+loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                 loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
