
fix code format
kewang-xlnx committed Dec 4, 2024
1 parent ea90add commit dce4cec
Showing 4 changed files with 22 additions and 20 deletions.
vllm/model_executor/layers/quantization/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -35,6 +35,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         raise ValueError(f"Invalid quantization method: {quantization}")

     # lazy import to avoid triggering `torch.compile` too early
+    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
+
     from .aqlm import AQLMConfig
     from .awq import AWQConfig
     from .awq_marlin import AWQMarlinConfig
@@ -56,7 +58,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .neuron_quant import NeuronQuantConfig
     from .qqq import QQQConfig
     from .tpu_int8 import Int8TpuConfig
-    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

     method_to_config: Dict[str, Type[QuantizationConfig]] = {
         "aqlm": AQLMConfig,
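The lazy-import comment in the hunk above names a common pattern: importing inside the function body so that importing the package itself stays cheap and, in vLLM's case, does not trigger `torch.compile` at module load time. Moving the QuarkConfig import up keeps it grouped with the other deferred imports. A minimal, self-contained sketch of the pattern (the format registry and the `json` stand-in are illustrative, not vLLM's actual modules):

    from typing import Callable, Dict

    def get_serializer(fmt: str) -> Callable[..., str]:
        if fmt != "json":
            raise ValueError(f"Invalid format: {fmt}")
        # Deferred import: the dependency is loaded only when a serializer
        # is actually requested, not when this module is first imported.
        import json
        registry: Dict[str, Callable[..., str]] = {"json": json.dumps}
        return registry[fmt]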
vllm/model_executor/layers/quantization/quark/quark.py (32 changes: 16 additions & 16 deletions)
@@ -1,24 +1,23 @@
-import re
 import fnmatch
+import re
 from typing import Any, Dict, List, Optional, cast

 import torch

-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
-from vllm.model_executor.layers.quantization.quark.utils import (
-    deep_compare, should_ignore_layer)
-
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
-    QuarkMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.quantization.quark.schemes import (
-    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
+    QuarkMoEMethod)
+from vllm.model_executor.layers.quantization.quark.schemes import (
+    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
+from vllm.model_executor.layers.quantization.quark.utils import (
+    deep_compare, should_ignore_layer)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    FUSED_LAYER_NAME_MAPPING)
 from vllm.platforms import current_platform

 __all__ = ["QuarkLinearMethod"]
@@ -104,16 +103,17 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
                     deep_compare(q_config, q_configs[0])
                     for q_config in q_configs):
                 raise ValueError(
-                    "The quantization method used for kv_cache should be the same, "
-                    "but the quantization method for the kv_cache layer in the "
-                    "config is different.")
+                    "The quantization method used for kv_cache should "
+                    "be the same, but the quantization method for the "
+                    "kv_cache layer in the config is different.")
             kv_cache_config = q_configs[0].get("output_tensors")
             if kv_cache_config is None:
                 raise ValueError(
                     "The kv_cache quantization configuration is empty.")

-            # Since we have already set kv_cache quantization configurations, we will remove
-            # the quantization configuration for the output_tensors corresponding to the kv_cache layer.
+            # Since we have already set kv_cache quantization configurations,
+            # we will remove the quantization configuration for the
+            # output_tensors corresponding to the kv_cache layer.
             for q_config in q_configs:
                 q_config["output_tensors"] = None
@@ -217,7 +217,7 @@ def _find_matched_config(self, layer_name: str,
         else:
             layer_quant_config = cast(
                 Dict[str, Any], self.quant_config.get("layer_quant_config"))
-            for name_pattern in layer_quant_config.keys():
+            for name_pattern in layer_quant_config:
                 if fnmatch.fnmatch(layer_name, name_pattern):
                     return layer_quant_config[name_pattern]

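Two things happen in the quark.py hunks beyond import sorting. The `from_config` hunk rewraps an error message and a comment to fit the line-length limit; the logic it wraps requires every kv_cache layer to advertise an identical `output_tensors` scheme and then clears the per-layer entries so the scheme is applied only once through the dedicated kv_cache path. A self-contained sketch of that flow, using a simplified stand-in for `deep_compare` (the real helper lives in quark/utils.py) and hypothetical config values:

    from typing import Any, Dict, List

    def deep_compare(a: Any, b: Any) -> bool:
        # Simplified stand-in: compare nested dicts/lists by value.
        if isinstance(a, dict) and isinstance(b, dict):
            return a.keys() == b.keys() and all(
                deep_compare(a[k], b[k]) for k in a)
        if isinstance(a, list) and isinstance(b, list):
            return len(a) == len(b) and all(
                deep_compare(x, y) for x, y in zip(a, b))
        return a == b

    # Hypothetical per-layer configs; "output_tensors" describes the kv_cache.
    q_configs: List[Dict[str, Any]] = [
        {"output_tensors": {"dtype": "fp8", "is_dynamic": False}},
        {"output_tensors": {"dtype": "fp8", "is_dynamic": False}},
    ]

    if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
        raise ValueError("The quantization method used for kv_cache should "
                         "be the same, but the quantization method for the "
                         "kv_cache layer in the config is different.")
    kv_cache_config = q_configs[0].get("output_tensors")
    if kv_cache_config is None:
        raise ValueError("The kv_cache quantization configuration is empty.")
    # The kv_cache scheme is recorded once, so drop the per-layer copies to
    # avoid applying output quantization twice.
    for q_config in q_configs:
        q_config["output_tensors"] = None

The `_find_matched_config` hunk is purely idiomatic: iterating a dict yields its keys, so `for name_pattern in layer_quant_config:` is equivalent to the removed `.keys()` form.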
vllm/model_executor/layers/quantization/quark/quark_moe.py (4 changes: 2 additions & 2 deletions)
@@ -1,9 +1,9 @@
-from typing import Callable, Optional, Dict, Any
+from typing import Any, Callable, Dict, Optional

 import torch

 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
-
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
vllm/model_executor/layers/quantization/quark/utils.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,6 @@
 import re
-from typing import Any, Optional, Iterable
+from typing import Any, Iterable, Optional
+
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     FUSED_LAYER_NAME_MAPPING)

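Every import hunk in this commit follows the same ordering: plain stdlib imports first, then a single alphabetized block of `vllm` imports, with the names inside each `from ... import (...)` also alphabetized, which is the layout isort produces. Assuming isort is installed (this sketch does not reproduce vLLM's own formatter configuration), the effect on the quark_moe.py typing line can be checked directly:

    import isort  # pip install isort

    # isort alphabetizes the names in a from-import, as in this commit.
    before = "from typing import Callable, Optional, Dict, Any\n"
    print(isort.code(before))
    # -> from typing import Any, Callable, Dict, Optional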
