Fp8 Quantization Support #62

Merged Jun 20, 2024 (26 commits)
Changes from 10 commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/compressed_tensors/compressors/__init__.py
@@ -17,7 +17,11 @@
from .base import Compressor
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .int_quantized import IntQuantizationCompressor
from .model_compressor import ModelCompressor
from .naive_quantized import (
FloatQuantizationCompressor,
IntQuantizationCompressor,
QuantizationCompressor,
)
from .pack_quantized import PackedQuantizationCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
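With the reworked exports above, the float compressor is available from the same package as the existing integer one. A minimal import check, assuming an installed compressed-tensors build that includes this change:

```python
# Illustration only; assumes a compressed-tensors install containing this PR.
from compressed_tensors.compressors import (
    FloatQuantizationCompressor,
    IntQuantizationCompressor,
    QuantizationCompressor,
)

# Both named compressors are thin aliases over the shared QuantizationCompressor
# (see naive_quantized.py further down in this diff).
print(issubclass(FloatQuantizationCompressor, QuantizationCompressor))  # True
print(issubclass(IntQuantizationCompressor, QuantizationCompressor))    # True
```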
20 changes: 20 additions & 0 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -16,8 +16,11 @@
import logging
import operator
import os
import re
from typing import Dict, Optional, Union

import torch
import transformers
from compressed_tensors.base import (
COMPRESSION_CONFIG_NAME,
QUANTIZATION_CONFIG_NAME,
@@ -185,6 +188,11 @@ def compress(
compressed_state_dict
)

# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

return compressed_state_dict

def decompress(self, model_path: str, model: Module):
@@ -263,3 +271,15 @@ def _get_weight_arg_mappings(model: Module) -> Dict:
quantized_modules_to_args[name] = submodule.quantization_scheme.weights

return quantized_modules_to_args


# HACK: Override the dtype_byte_size function in transformers to support float8 types
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
def new_dtype_byte_size(dtype):
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
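The override is needed because the stock `dtype_byte_size` in transformers (prior to the linked upstream fix) anchors its regex on trailing digits in the dtype name, which float8 names such as `torch.float8_e4m3fn` don't have. A quick check of the two patterns, assuming a PyTorch build that exposes `torch.float8_e4m3fn`:

```python
# Illustration only; compares the pre-fix upstream pattern with the patched one above.
import re

import torch

fp8_name = str(torch.float8_e4m3fn)  # "torch.float8_e4m3fn"

# Pre-fix transformers pattern: expects the name to end in its bit width,
# so fp8 names with an "_e4m3fn" suffix fall through to the ValueError.
print(re.search(r"[^\d](\d+)$", fp8_name))                        # None
# Patched pattern: grabs the first digit group instead, still 8 -> 1 byte.
print(re.search(r"[^\d](\d+)_?", fp8_name).group(1))              # "8"
print(re.search(r"[^\d](\d+)_?", str(torch.bfloat16)).group(1))   # "16"
```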
src/compressed_tensors/compressors/naive_quantized.py
@@ -18,7 +18,11 @@
import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization import (
FP8_DTYPE,
QuantizationArgs,
QuantizationType,
)
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.utils import can_quantize
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -27,17 +31,21 @@
from tqdm import tqdm


__all__ = ["IntQuantizationCompressor"]
__all__ = [
"QuantizationCompressor",
"IntQuantizationCompressor",
"FloatQuantizationCompressor",
]

_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(Compressor):
@Compressor.register(name=CompressionFormat.naive_quantized.value)
class QuantizationCompressor(Compressor):
"""
Integer compression for quantized models. Weight of each quantized layer is
converted from its original float type to the format specified by the layer's
quantization scheme.
Implements naive compression for quantized models. The weight of each
quantized layer is converted from its original float type to the closest PyTorch
dtype matching the type specified by the layer's QuantizationArgs.
"""

COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
@@ -76,7 +84,7 @@ def compress(
scale=scale,
zero_point=zp,
args=quant_args,
dtype=torch.int8,
dtype=self._parse_compression_dtype(quant_args),
)
elif name.endswith("zero_point"):
if torch.all(value == 0):
@@ -123,3 +131,32 @@ def decompress(
zero_point=zero_point,
)
yield merge_names(weight_name, "weight"), decompressed

def _parse_compression_dtype(self, args: QuantizationArgs) -> torch.dtype:
if args.type is QuantizationType.FLOAT:
return FP8_DTYPE
else: # QuantizationType.INT
if args.num_bits <= 8:
return torch.int8
elif args.num_bits <= 16:
return torch.int16
else:
return torch.int32


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(QuantizationCompressor):
"""
Alias for integer quantized models
"""

pass


@Compressor.register(name=CompressionFormat.float_quantized.value)
class FloatQuantizationCompressor(QuantizationCompressor):
"""
Alias for fp quantized models
"""

pass
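The two registered subclasses only change the format name they are registered under; the actual dtype selection lives in `_parse_compression_dtype` above. A standalone restatement of that mapping, with the caveat that `FP8_DTYPE` is assumed here to be `torch.float8_e4m3fn` (the constant itself is defined in `compressed_tensors.quantization.quant_args`, outside the hunks shown):

```python
import torch

ASSUMED_FP8_DTYPE = torch.float8_e4m3fn  # stand-in for FP8_DTYPE


def parse_compression_dtype(quant_type: str, num_bits: int) -> torch.dtype:
    """Mirror of QuantizationCompressor._parse_compression_dtype, for illustration."""
    if quant_type == "float":
        return ASSUMED_FP8_DTYPE
    # integer quantization: narrowest torch integer type that fits num_bits
    if num_bits <= 8:
        return torch.int8
    if num_bits <= 16:
        return torch.int16
    return torch.int32


assert parse_compression_dtype("float", 8) is ASSUMED_FP8_DTYPE
assert parse_compression_dtype("int", 4) is torch.int8
assert parse_compression_dtype("int", 16) is torch.int16
```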
2 changes: 2 additions & 0 deletions src/compressed_tensors/config/base.py
@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
int_quantized = "int-quantized"
float_quantized = "float-quantized"
naive_quantized = "naive-quantized"
pack_quantized = "pack-quantized"


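The two new enum members give float and naive quantization their own format strings to register compressors under. The helper below is hypothetical (not part of this PR) and only sketches how a caller might pick a format from a quantization type:

```python
from enum import Enum


class CompressionFormat(Enum):
    # Mirrors src/compressed_tensors/config/base.py after this change
    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    int_quantized = "int-quantized"
    float_quantized = "float-quantized"
    naive_quantized = "naive-quantized"
    pack_quantized = "pack-quantized"


def format_for(quant_type: str) -> CompressionFormat:
    # Hypothetical helper, not in the PR: map a quantization type string to
    # the format its compressor was registered under.
    if quant_type == "float":
        return CompressionFormat.float_quantized
    if quant_type == "int":
        return CompressionFormat.int_quantized
    return CompressionFormat.naive_quantized


assert format_for("float").value == "float-quantized"
```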
66 changes: 54 additions & 12 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -17,12 +17,17 @@
from typing import Optional

import torch
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.quant_args import (
FP8_DTYPE,
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
round_fp8,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import is_float_quantization
from torch.nn import Module


@@ -145,18 +150,26 @@ def _process_quantization(
do_quantize: bool = True,
do_dequantize: bool = True,
) -> torch.Tensor:
bit_range = 2**args.num_bits
q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
q_min = torch.tensor(-bit_range / 2, device=x.device)

q_min, q_max = calculate_range(args, x.device)
group_size = args.group_size

if args.strategy == QuantizationStrategy.GROUP:

if do_dequantize: # if dequantizing the output should be a fp type
if do_dequantize:
# if dequantizing the output should match the original weight dtype,
# which is the same as the scale's
output = torch.zeros_like(x, dtype=scale.dtype)
else:
# outputting a quantized output, use the dtype passed in as a kwarg if it's
# specified, otherwise default to the input type
output_dtype = dtype if dtype is not None else x.dtype
output = torch.zeros_like(x, dtype=output_dtype)
if output_dtype is FP8_DTYPE:
# zeros_like doesn't support fp8 types directly, workaround
output = torch.zeros_like(x)
output = output.to(FP8_DTYPE)
else:
output = torch.zeros_like(x, dtype=output_dtype)

# TODO: vectorize the for loop
# TODO: fix generic assumption about the tensor size for computing group
@@ -184,7 +197,13 @@
idx = i * group_size
if do_quantize:
output[:, idx : (idx + group_size)] = _quantize(
x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
x[:, idx : (idx + group_size)],
sc,
zp,
q_min,
q_max,
quantization_type=args.type,
dtype=dtype,
)
if do_dequantize:
input = (
@@ -196,7 +215,15 @@

else: # covers channel, token and tensor strategies
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
output = _quantize(
x,
scale,
zero_point,
q_min,
q_max,
quantization_type=args.type,
dtype=dtype,
)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

@@ -290,13 +317,25 @@ def _quantize(
zero_point: torch.Tensor,
q_min: torch.Tensor,
q_max: torch.Tensor,
quantization_type: Optional[QuantizationType] = QuantizationType.INT,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
quantized_value = torch.clamp(
torch.round(x / scale + zero_point),
q_min,
q_max,
)

scaled = x / scale + zero_point
if quantization_type is QuantizationType.FLOAT:
# clamp first because cast isn't saturated
quantized_value = torch.clamp(
scaled,
q_min,
q_max,
)
quantized_value = round_fp8(quantized_value, FP8_DTYPE)
else:
quantized_value = torch.clamp(
torch.round(scaled),
q_min,
q_max,
)

if dtype is not None:
quantized_value = quantized_value.to(dtype)
@@ -310,4 +349,7 @@ def _dequantize(
scale: torch.Tensor,
zero_point: torch.Tensor,
) -> torch.Tensor:
if is_float_quantization(x_q):
# can't perform arithmetic in fp8 types, need to convert first
return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale
return (x_q - zero_point) * scale
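Taken together, the float branch of `_quantize` clamps to the representable range before casting (the inline comment notes the fp8 cast is not saturated), and `_dequantize` up-casts to the scale's dtype first because elementwise arithmetic is not available on fp8 tensors. A self-contained round-trip sketch, assuming `FP8_DTYPE` is `torch.float8_e4m3fn` and using a plain cast where the PR calls its `round_fp8` helper:

```python
import torch

FP8 = torch.float8_e4m3fn                                   # assumed value of FP8_DTYPE
q_min, q_max = torch.finfo(FP8).min, torch.finfo(FP8).max   # +/-448 for e4m3fn

x = torch.randn(4, 8) * 1000        # original high-precision weight
scale = x.abs().max() / q_max       # symmetric per-tensor scale
zero_point = torch.zeros(())        # symmetric -> zero offset

# quantize: clamp BEFORE casting, since the fp8 cast does not saturate
scaled = x / scale + zero_point
x_q = torch.clamp(scaled, q_min, q_max).to(FP8)   # stand-in for round_fp8

# dequantize: up-cast to the scale dtype first, fp8 tensors don't do arithmetic
x_dq = (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale

print((x - x_dq).abs().max())       # reconstruction error bounded by fp8 precision
```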
12 changes: 10 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -20,7 +20,10 @@
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationType,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter
@@ -95,7 +98,12 @@ def _initialize_scale_zero_point_observer(
)
module.register_parameter(f"{base_name}_scale", init_scale)

zp_dtype = (
torch.int8
if quantization_args.type is QuantizationType.INT
else module.weight.dtype
)
init_zero_point = Parameter(
torch.empty(0, device=device, dtype=int), requires_grad=False
torch.empty(0, device=device, dtype=zp_dtype), requires_grad=False
)
module.register_parameter(f"{base_name}_zero_point", init_zero_point)
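The zero-point parameter previously defaulted to an integer dtype; with fp8 weights the zero point is now initialized in the weight's own dtype instead. A minimal sketch of that branch outside the lifecycle code, with the quantization type reduced to a plain string for illustration:

```python
import torch
from torch.nn import Linear, Parameter

module = Linear(16, 16, dtype=torch.bfloat16)
quant_type = "float"  # stand-in for QuantizationArgs.type

# int quantization keeps int8 zero points; float (fp8) quantization stores the
# zero point in the weight's dtype, matching the change above
zp_dtype = torch.int8 if quant_type == "int" else module.weight.dtype

init_zero_point = Parameter(
    torch.empty(0, device=module.weight.device, dtype=zp_dtype),
    requires_grad=False,
)
module.register_parameter("weight_zero_point", init_zero_point)
print(module.weight_zero_point.dtype)  # torch.bfloat16 on the float branch
```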
61 changes: 52 additions & 9 deletions src/compressed_tensors/quantization/observers/helpers.py
@@ -15,11 +15,16 @@
from typing import Tuple

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
FP8_DTYPE,
QuantizationArgs,
QuantizationType,
round_fp8,
)
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["calculate_qparams"]
__all__ = ["calculate_qparams", "calculate_range"]


def calculate_qparams(
@@ -37,18 +42,56 @@ def calculate_qparams(
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
device = min_vals.device

bit_range = 2**quantization_args.num_bits - 1
bit_min = -(bit_range + 1) / 2
bit_max = bit_min + bit_range
bit_min, bit_max = calculate_range(quantization_args, device)
bit_range = bit_max - bit_min

if quantization_args.symmetric:
max_val_pos = torch.max(-min_vals, max_vals)
max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
scales = max_val_pos / (float(bit_range) / 2)
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)

# set zero_points to correct types
if quantization_args.type == QuantizationType.FLOAT:
zero_points = round_fp8(zero_points, FP8_DTYPE)
else: # QuantizationType.INT
zero_points = zero_points.to(torch.int8)
else:
scales = (max_vals - min_vals) / float(bit_range)
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
zero_points = bit_min - torch.round(min_vals / scales)
zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)

if quantization_args.type == QuantizationType.FLOAT:
zero_points = bit_min - (min_vals / scales)
zero_points = round_fp8(
torch.clamp(zero_points, bit_min, bit_max), FP8_DTYPE
)
else: # QuantizationType.INT
zero_points = bit_min - torch.round(min_vals / scales)
zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)

return scales, zero_points


def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
"""
Calculates the effective quantization range for the given QuantizationArgs

:param quantization_args: quantization args to get the range of
:param device: device to store the range tensors on
:return: tuple of (min, max) endpoints for the given quantization range
"""
if quantization_args.type == QuantizationType.INT:
bit_range = 2**quantization_args.num_bits
q_max = torch.tensor(bit_range / 2 - 1, device=device)
q_min = torch.tensor(-bit_range / 2, device=device)
else: # QuantizationType.FLOAT
if quantization_args.num_bits != 8:
raise ValueError(
"Floating point quantization is only supported for 8 bits,"
f"got {quantization_args.num_bits}"
)
fp_range_info = torch.finfo(FP8_DTYPE)
q_max = torch.tensor(fp_range_info.max, device=device)
q_min = torch.tensor(fp_range_info.min, device=device)

return q_min, q_max
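`calculate_range` now branches on the quantization type: the integer path keeps the usual two's-complement bounds, while the float path reads its limits directly from `torch.finfo` of the fp8 dtype, and `calculate_qparams` divides by half of that span when computing symmetric scales. A short numeric check, again assuming `FP8_DTYPE` is `torch.float8_e4m3fn`:

```python
import torch

FP8 = torch.float8_e4m3fn          # assumed FP8_DTYPE

# int path (num_bits = 8): [-128, 127]
bit_range = 2 ** 8
int_min, int_max = -bit_range / 2, bit_range / 2 - 1

# float path: bounds come from the dtype itself (+/-448 for e4m3fn)
fp_info = torch.finfo(FP8)

# symmetric scales in calculate_qparams use (q_max - q_min) / 2 as the divisor,
# i.e. 127.5 for int8 and 448.0 for fp8
max_abs = torch.tensor(3.2)
print(max_abs / ((int_max - int_min) / 2))           # int8 scale
print(max_abs / ((fp_info.max - fp_info.min) / 2))   # fp8 scale
print(int_min, int_max, fp_info.min, fp_info.max)    # -128.0 127.0 -448.0 448.0
```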