diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 21d4355b36ab0..57dd6e310297d 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -245,20 +245,24 @@ def create_weights(
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
-                weight, weight_scale, _ = \
+                weight, weight_scale_inv, _ = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=layer.weight,
-                        weight_scale=layer.weight_scale_inv,
-                        input_scale=layer.input_scale)
-                layer.weight = Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = Parameter(weight_scale,
-                                                   requires_grad=False)
+                        weight_scale=layer.weight_scale_inv)
+            else:
+                weight = layer.weight.data
+                weight_scale_inv = layer.weight_scale_inv.data
+
+            # Torch.compile cannot use Parameter subclasses.
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale_inv,
+                                               requires_grad=False)
             return
 
-        layer.weight = torch.nn.Parameter(layer.weight.data,
-                                          requires_grad=False)
+        # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -507,8 +511,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
                 w13_weight, w13_weight_scale_inv, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -518,22 +523,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
-                # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight,
-                                                      requires_grad=False)
-                layer.w13_weight_scale_inv = torch.nn.Parameter(
-                    w13_weight_scale_inv, requires_grad=False)
-                if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(w2_weight,
-                                                     requires_grad=False)
-                layer.w2_weight_scale_inv = torch.nn.Parameter(
-                    w2_weight_scale_inv, requires_grad=False)
-                if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False)
+            else:
+                w13_weight = layer.w13_weight.data
+                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+                w2_weight = layer.w2_weight
+                w2_weight_scale_inv = layer.w2_weight_scale_inv
+
+            # torch.compile() cannot use Parameter subclasses.
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                   requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                  requires_grad=False)
             return
 
+        # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
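Context for the change above: once loading completes, the block-quant weights and scales are re-registered as plain `torch.nn.Parameter` objects because, per the in-diff comment, torch.compile cannot use Parameter subclasses. The snippet below is a minimal sketch of that pattern, not vLLM code; `LoaderParameter` and `finalize_for_compile` are hypothetical names standing in for vLLM's loading-time Parameter subclasses and for the re-wrapping done in `process_weights_after_loading`.

```python
import torch
from torch.nn import Parameter


class LoaderParameter(Parameter):
    """Hypothetical Parameter subclass used only while weights are loaded."""
    # vLLM's real subclasses carry sharding/loading metadata; none is needed
    # for this illustration.


def finalize_for_compile(layer: torch.nn.Module, names: list[str]) -> None:
    """Re-register attributes as plain Parameters backed by the same storage."""
    for name in names:
        tensor = getattr(layer, name).data
        setattr(layer, name, Parameter(tensor, requires_grad=False))


layer = torch.nn.Module()
layer.weight = LoaderParameter(torch.randn(8, 8), requires_grad=False)
layer.weight_scale_inv = LoaderParameter(torch.ones(1), requires_grad=False)

finalize_for_compile(layer, ["weight", "weight_scale_inv"])
assert type(layer.weight) is Parameter  # no longer the loading-time subclass
```

Re-wrapping `.data` keeps the same underlying storage, so no copy is made; only the Python type of the registered parameter changes.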