From 235052f61efbd5e15cfb55ddf699bd4a0715f6f5 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Wed, 6 Nov 2024 08:53:35 -0800
Subject: [PATCH] Refactored files

---
 .../floatx/floatx_tensor_core_layout.py | 176 ++++--------
 1 file changed, 35 insertions(+), 141 deletions(-)

diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py
index e801182559..cfdb566279 100644
--- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py
+++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -1,6 +1,5 @@
-from dataclasses import dataclass
 from functools import reduce
-from typing import Optional, Tuple
+from typing import Tuple, Optional
 
 import torch
 from torch import Tensor
@@ -25,23 +24,11 @@
 
 
 def _pack(x: Tensor, n_bits: int) -> Tensor:
-    return reduce(
-        torch.bitwise_or,
-        [
-            x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits)
-            for i in range(8 // n_bits)
-        ],
-    )
+    return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
 
 
 def _unpack(x: Tensor, n_bits: int) -> Tensor:
-    return torch.stack(
-        [
-            (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1)
-            for i in range(8 // n_bits)
-        ],
-        dim=-1,
-    ).flatten(-2)
+    return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
 
 
 # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
@@ -55,40 +42,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
 
     if not undo:
         bit_order = {
-            1: [
-                1,
-                5,
-                9,
-                13,
-                17,
-                21,
-                25,
-                29,
-                3,
-                7,
-                11,
-                15,
-                19,
-                23,
-                27,
-                31,
-                0,
-                4,
-                8,
-                12,
-                16,
-                20,
-                24,
-                28,
-                2,
-                6,
-                10,
-                14,
-                18,
-                22,
-                26,
-                30,
-            ],
+            1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
+                0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
             2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
             4: [1, 5, 3, 7, 0, 4, 2, 6],
         }[n_bits]
@@ -97,40 +52,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
         # this is inverse of the above, obtained by running
         # [v.index(i) for i in range(len(v))]
         bit_order = {
-            1: [
-                16,
-                0,
-                24,
-                8,
-                17,
-                1,
-                25,
-                9,
-                18,
-                2,
-                26,
-                10,
-                19,
-                3,
-                27,
-                11,
-                20,
-                4,
-                28,
-                12,
-                21,
-                5,
-                29,
-                13,
-                22,
-                6,
-                30,
-                14,
-                23,
-                7,
-                31,
-                15,
-            ],
+            1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
+                20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
             2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
             4: [4, 0, 6, 2, 5, 1, 7, 3],
         }[n_bits]
@@ -166,12 +89,8 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
             tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
             tensor_ybit = _pack(tensor_ybit, y)
 
-            tensor_ybit = (
-                tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
-            )  # Pass 2 from original code
-            tensor_ybit = _bit_interleave(
-                tensor_ybit.flatten(), y
-            )  # Pass 3 from original code
+            tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code
+            tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code
             fragments.append(tensor_ybit)
             used_bits += y
 
@@ -205,9 +124,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
     return _pack_tc_floatx(tensor, nbits)
 
 
-def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
-) -> Tuple[Tensor, Tensor]:
+def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]:
     # _n_ones() is not compatible with torch.compile() due to << operator
     # https://github.com/pytorch/pytorch/issues/119152
     # exp_bias = _n_ones(ebits - 1)
@@ -215,9 +132,7 @@ def to_scaled_tc_floatx(
 
     # workaround: global lookup table
     exp_bias = _ONES_TABLE[ebits - 1]
-    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
-        _ONES_TABLE[mbits + 1] / (2**mbits)
-    )
+    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
 
     dtype = tensor.dtype
     tensor = tensor.float()
@@ -244,10 +159,8 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
             tensor_ybit = tensor[offset : offset + size_ybit]
             offset += size_ybit
 
-            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)  # undo Pass 3
-            tensor_ybit = (
-                tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)
-            )  # undo Pass 2
+            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
+            tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2
             tensor_ybit = _unpack(tensor_ybit.flatten(), y)
             tensor_ybit = tensor_ybit << (nbits - used_bits - y)
 
@@ -318,7 +231,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 5,
         14336: 7,
         28672: 7,
-        57344: 7,
+        57344: 7
     },
     {  # tokens: [65:128]
         3072: 9,
@@ -329,7 +242,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 5,
         14336: 7,
         28672: 7,
-        57344: 6,
+        57344: 6
     },
     {  # tokens: [129:192]
         3072: 6,
@@ -340,7 +253,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 5,
         14336: 5,
         28672: 5,
-        57344: 4,
+        57344: 4
     },
     {  # tokens: [193:256]
         3072: 9,
@@ -351,7 +264,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 4,
         14336: 8,
         28672: 6,
-        57344: 4,
+        57344: 4
     },
     {  # tokens: [257:320]
         3072: 7,
@@ -362,7 +275,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 1,
         14336: 3,
         28672: 3,
-        57344: 4,
+        57344: 4
     },
     {  # tokens: [321:384]
         3072: 3,
@@ -373,7 +286,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 8,
         14336: 3,
         28672: 4,
-        57344: 3,
+        57344: 3
     },
     {  # tokens: [385:448]
         3072: 5,
@@ -384,7 +297,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 3,
         14336: 1,
         28672: 1,
-        57344: 3,
+        57344: 3
     },
     {  # tokens: [449:512]
         3072: 2,
@@ -395,7 +308,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 2,
         14336: 6,
         28672: 4,
-        57344: 1,
+        57344: 1
     },
     {  # tokens: [513:576]
         3072: 2,
@@ -406,7 +319,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 3,
         14336: 3,
         28672: 1,
-        57344: 1,
+        57344: 1
     },
     {  # tokens: [577:640]
         3072: 5,
@@ -417,7 +330,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 1,
         14336: 1,
         28672: 1,
-        57344: 1,
+        57344: 1
     },
     {  # tokens: [641:704]
         3072: 3,
@@ -428,7 +341,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 2,
         14336: 1,
         28672: 1,
-        57344: 1,
+        57344: 1
     },
     {  # tokens: [705:768]
         3072: 3,
@@ -439,18 +352,17 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
         10240: 1,
         14336: 1,
         28672: 1,
-        57344: 1,
-    },
+        57344: 1
+    }
 ]
 
 
 # quantization api integrations
 
-
 @dataclass(frozen=True)
 class FloatxTensorCoreLayout(Layout):
-    """Layout type for FloatxTensorCoreAQTTensorImpl"""
-
+    """Layout type for FloatxTensorCoreAQTTensorImpl
+    """
     ebits: int
     mbits: int
 
@@ -478,7 +390,6 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl):
     it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor
     FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit)
     """
-
     def __new__(
         cls,
         packed_floatx_data: torch.Tensor,
@@ -487,16 +398,11 @@ def __new__(
     ):
         assert packed_floatx_data.ndim == 2
         assert packed_floatx_data.dtype == torch.uint8
-        shape = (
-            packed_floatx_data.shape[0],
-            packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8,
-        )
+        shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8)
         kwargs = {}
         kwargs["device"] = packed_floatx_data.device
         kwargs["layout"] = (
-            kwargs.get("layout")
-            if kwargs.get("layout", False)
-            else packed_floatx_data.layout
+            kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout
         )
         kwargs["dtype"] = packed_floatx_data.dtype
         kwargs["requires_grad"] = False
@@ -519,17 +425,12 @@ def __tensor_flatten__(self):
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
-        packed_floatx_data, scale = (
-            tensor_data_dict["packed_floatx_data"],
-            tensor_data_dict["scale"],
-        )
-        (_layout,) = tensor_attributes
+        packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"]
+        _layout, = tensor_attributes
         return cls(packed_floatx_data, scale, _layout)
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        unpacked_floatx_data = unpack_tc_floatx(
-            self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits
-        )
+        unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits)
         return unpacked_floatx_data, self.scale
 
     @classmethod
@@ -548,9 +449,7 @@ def from_plain(
         bit, M is mantissa bit
         """
         assert isinstance(_layout, FloatxTensorCoreLayout)
-        packed_floatx_data = pack_tc_floatx(
-            unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits
-        )
+        packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits)
         return cls(packed_floatx_data, scale, _layout)
 
     def __repr__(self):
@@ -588,12 +487,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
         elif func is aten._to_copy.default:
             return return_and_correct_aliasing(
-                func,
-                args,
-                kwargs,
-                args[0]._apply_fn_to_data(
-                    lambda x: x.to(device=kwargs.pop("device", None))
-                ),
+                func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
             )
 
         raise NotImplementedError(
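Note for reviewers (not part of the patch): the `_pack`/`_unpack` helpers that this patch collapses into one-liners are plain bit-packing utilities. The sketch below copies their bodies from the diff and adds a small round-trip check; the 2-bit example values and the (4, 8) input shape are illustrative assumptions, not code from the repository.

from functools import reduce

import torch
from torch import Tensor


def _pack(x: Tensor, n_bits: int) -> Tensor:
    # Pack each group of (8 // n_bits) consecutive values, each n_bits wide, into one uint8.
    return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])


def _unpack(x: Tensor, n_bits: int) -> Tensor:
    # Split every uint8 back into its (8 // n_bits) fields of n_bits each.
    return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)


if __name__ == "__main__":
    # Round-trip check with 2-bit values: four values fit per byte, so the last dim shrinks 8 -> 2.
    x = torch.randint(0, 4, (4, 8), dtype=torch.uint8)
    packed = _pack(x, 2)
    assert packed.shape == (4, 2)
    assert torch.equal(_unpack(packed, 2), x)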