From 58ca6632247cb738d069a585e1ec9a9d5e66da68 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Thu, 18 Jul 2024 10:39:12 -0400
Subject: [PATCH] [ Misc ] Improve Min Capability Checking in
 `compressed-tensors` (#6522)

---
 .../compressed_tensors/compressed_tensors.py  | 22 ++++++++++++-------
 .../schemes/compressed_tensors_scheme.py      |  7 ++++++
 .../schemes/compressed_tensors_unquantized.py |  4 ++++
 .../schemes/compressed_tensors_w4a16_24.py    |  4 ++++
 .../schemes/compressed_tensors_w8a8_fp8.py    |  4 ++++
 .../schemes/compressed_tensors_w8a8_int8.py   |  4 ++++
 .../schemes/compressed_tensors_wNa16.py       |  4 ++++
 7 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 1424c620ae675..659f5a599dc14 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -37,7 +37,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 75
+        return 70
 
     def get_name(self) -> str:
         return "compressed_tensors"
@@ -85,13 +85,14 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
     def get_config_filenames(cls) -> List[str]:
         return []
 
-    def _check_gptq_and_marlin_can_run(self):
+    def _check_scheme_supported(self, min_capability: int):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        if capability < 80:
-            raise RuntimeError("The quantization config is not supported for ",
-                               "the current GPU. Minimum capability: 80. ",
-                               f"Current capability: {capability}.")
+        if capability < min_capability:
+            raise RuntimeError(
+                "Quantization scheme is not supported for ",
+                f"the current GPU. Min capability: {min_capability}. ",
+                f"Current capability: {capability}.")
 
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
@@ -171,7 +172,6 @@ def _get_schema(self, weight_quant: BaseModel,
 
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
-            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
@@ -222,10 +222,16 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
             raise ValueError(
                 f"Could not find quantization details for {layer}.")
 
-        return self._get_schema(
+        scheme = self._get_schema(
             weight_quant=layer_quant_details["weights"],
             input_quant=layer_quant_details["input_activations"])
 
+        # Raise error if device does not support the scheme
+        # (e.g. fp8 needs ada lovelace)
+        self._check_scheme_supported(scheme.get_min_capability())
+
+        return scheme
+
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
index 3aa9130782039..d5f37b47bb87e 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
     of different quantization schemes supported by CompressedTensors.
     """
 
+    @abstractmethod
+    def get_min_capability(self) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def create_weights(self, *args, **kwargs):
         """
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
index 2c7fe3e0e4114..4350ff4e90ae8 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
     in a linear transformation.
     """
 
+    def get_min_capability(self) -> int:
+        # volta and up
+        return 70
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index 54bf85c096f2e..eec523d00372c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -29,6 +29,10 @@ def __init__(self,
             raise ValueError(
                 "group_size must be given when using strategy group")
 
+    def get_min_capability(self) -> int:
+        # ampere + up
+        return 80
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index f1ca9510d92aa..e842475e4f34b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -33,6 +33,10 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
                 "Consider quantizing with per tensor scales or upgrading "
                 "to Hopper.")
 
+    def get_min_capability(self) -> int:
+        # lovelace and up
+        return 89
+
     def process_weights_after_loading(self, layer) -> None:
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index 6fec5d01056d8..e81496c89ac7f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -19,6 +19,10 @@ def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
 
+    def get_min_capability(self) -> int:
+        # turing and up
+        return 75
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # WEIGHT
         # Cutlass kernels need transposed weight.
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 187a3f9877ccf..3f3febcad4f85 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -42,6 +42,10 @@ def __init__(self,
                                 group_size=self.group_size,
                                 is_sym=True)
 
+    def get_min_capability(self) -> int:
+        # ampere and up
+        return 80
+
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
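
Illustrative sketch (not part of the diff above): the patch moves the capability check from a single hard-coded `capability < 80` test onto the schemes themselves, so each CompressedTensorsScheme reports its own minimum compute capability and get_scheme() validates the running device before returning. The standalone Python below mirrors that pattern under simplified assumptions; the class names, DEVICE_CAPABILITY constant, and check_scheme_supported() are stand-ins for demonstration, not real vLLM APIs.

from abc import ABC, abstractmethod

# Stand-in for current_platform.get_device_capability(); assumed to be an
# 8.6-class GPU for this demo. On a real system this comes from the driver.
DEVICE_CAPABILITY = (8, 6)


class Scheme(ABC):
    """Each scheme reports its own minimum compute capability."""

    @abstractmethod
    def get_min_capability(self) -> int:
        raise NotImplementedError


class W8A8Int8(Scheme):
    def get_min_capability(self) -> int:
        # turing and up
        return 75


class W8A8Fp8(Scheme):
    def get_min_capability(self) -> int:
        # lovelace and up
        return 89


def check_scheme_supported(min_capability: int) -> None:
    # Same arithmetic as _check_scheme_supported in the patch:
    # (major, minor) -> major * 10 + minor, then compare to the scheme's minimum.
    capability = DEVICE_CAPABILITY[0] * 10 + DEVICE_CAPABILITY[1]
    if capability < min_capability:
        raise RuntimeError(
            f"Quantization scheme is not supported for the current GPU. "
            f"Min capability: {min_capability}. "
            f"Current capability: {capability}.")


if __name__ == "__main__":
    check_scheme_supported(W8A8Int8().get_min_capability())  # ok: 86 >= 75
    try:
        check_scheme_supported(W8A8Fp8().get_min_capability())  # 86 < 89
    except RuntimeError as err:
        print(err)

On the assumed 8.6 device the int8 check passes while the fp8 check raises, which is the per-scheme behavior the patch introduces in place of the former one-size-fits-all threshold.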