Refactored files
jainapurva committed Nov 7, 2024
1 parent 4921bb3 commit 235052f
Showing 1 changed file with 35 additions and 141 deletions.
176 changes: 35 additions & 141 deletions torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass
from functools import reduce
-from typing import Optional, Tuple
+from typing import Tuple, Optional

import torch
from torch import Tensor
@@ -25,23 +24,11 @@


def _pack(x: Tensor, n_bits: int) -> Tensor:
-    return reduce(
-        torch.bitwise_or,
-        [
-            x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits)
-            for i in range(8 // n_bits)
-        ],
-    )
+    return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])


def _unpack(x: Tensor, n_bits: int) -> Tensor:
-    return torch.stack(
-        [
-            (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1)
-            for i in range(8 // n_bits)
-        ],
-        dim=-1,
-    ).flatten(-2)
+    return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)

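A minimal round-trip sketch of the two helpers above, assuming _pack and _unpack are importable from this module (in the 2-bit case, four values collapse into one byte):

import torch
from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack, _unpack

# Pack 2-bit values into uint8 bytes, then unpack them again.
x = torch.randint(0, 4, (2, 8), dtype=torch.uint8)   # every value fits in 2 bits
packed = _pack(x, n_bits=2)                           # shape (2, 2): 4 values per byte
unpacked = _unpack(packed, n_bits=2)                  # shape (2, 8) again
assert torch.equal(unpacked, x)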

# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
@@ -55,40 +42,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:

    if not undo:
        bit_order = {
-            1: [
-                1,
-                5,
-                9,
-                13,
-                17,
-                21,
-                25,
-                29,
-                3,
-                7,
-                11,
-                15,
-                19,
-                23,
-                27,
-                31,
-                0,
-                4,
-                8,
-                12,
-                16,
-                20,
-                24,
-                28,
-                2,
-                6,
-                10,
-                14,
-                18,
-                22,
-                26,
-                30,
-            ],
+            1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
+                0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
            2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
            4: [1, 5, 3, 7, 0, 4, 2, 6],
        }[n_bits]
@@ -97,40 +52,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
        # this is inverse of the above, obtained by running
        # [v.index(i) for i in range(len(v))]
        bit_order = {
-            1: [
-                16,
-                0,
-                24,
-                8,
-                17,
-                1,
-                25,
-                9,
-                18,
-                2,
-                26,
-                10,
-                19,
-                3,
-                27,
-                11,
-                20,
-                4,
-                28,
-                12,
-                21,
-                5,
-                29,
-                13,
-                22,
-                6,
-                30,
-                14,
-                23,
-                7,
-                31,
-                15,
-            ],
+            1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
+                20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
            2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
            4: [4, 0, 6, 2, 5, 1, 7, 3],
        }[n_bits]
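The comment above explains how the undo tables are derived from the forward tables; a self-contained check of that relationship, using the 4-bit entries shown here:

# The undo order is the inverse permutation of the forward order,
# i.e. [v.index(i) for i in range(len(v))].
forward = [1, 5, 3, 7, 0, 4, 2, 6]          # forward 4-bit bit_order above
inverse = [forward.index(i) for i in range(len(forward))]
assert inverse == [4, 0, 6, 2, 5, 1, 7, 3]  # matches the 4-bit undo table above

data = list(range(8))
shuffled = [data[j] for j in forward]       # what x[:, bit_order] does to each row
assert [shuffled[j] for j in inverse] == data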
@@ -166,12 +89,8 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
            tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
            tensor_ybit = _pack(tensor_ybit, y)

-            tensor_ybit = (
-                tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
-            )  # Pass 2 from original code
-            tensor_ybit = _bit_interleave(
-                tensor_ybit.flatten(), y
-            )  # Pass 3 from original code
+            tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code
+            tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code
            fragments.append(tensor_ybit)
            used_bits += y

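As a shape-only illustration (toy sizes, bit interleaving omitted), the Pass 2 reshuffle above and its inverse used later in _unpack_tc_floatx cancel out:

import torch

t = torch.arange(32 * 6 * 4)                                     # toy flat fragment
pass2 = t.view(32, -1, 4).permute(1, 0, 2).flip(2)               # Pass 2
undone = pass2.flatten().view(-1, 32, 4).flip(2).permute(1, 0, 2)  # undo Pass 2
assert torch.equal(undone.flatten(), t)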
@@ -205,19 +124,15 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
    return _pack_tc_floatx(tensor, nbits)


-def to_scaled_tc_floatx(
-    tensor: Tensor, ebits: int, mbits: int
-) -> Tuple[Tensor, Tensor]:
+def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]:
    # _n_ones() is not compatible with torch.compile() due to << operator
    # https://github.com/pytorch/pytorch/issues/119152
    # exp_bias = _n_ones(ebits - 1)
    # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

    # workaround: global lookup table
    exp_bias = _ONES_TABLE[ebits - 1]
-    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
-        _ONES_TABLE[mbits + 1] / (2**mbits)
-    )
+    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))

    dtype = tensor.dtype
    tensor = tensor.float()
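A worked instance of the max_normal formula above for FP6 (ebits=3, mbits=2); the inline table mirrors the module-level _ONES_TABLE:

_ONES_TABLE = [(1 << i) - 1 for i in range(8)]   # n-ones lookup, as in the module
ebits, mbits = 3, 2
exp_bias = _ONES_TABLE[ebits - 1]                # 0b11 = 3
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
assert (exp_bias, max_normal) == (3, 28.0)       # 2**4 * 1.75 = 28, the largest e3m2 normal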
@@ -244,10 +159,8 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor:
            tensor_ybit = tensor[offset : offset + size_ybit]
            offset += size_ybit

-            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)  # undo Pass 3
-            tensor_ybit = (
-                tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)
-            )  # undo Pass 2
+            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3
+            tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2

            tensor_ybit = _unpack(tensor_ybit.flatten(), y)
            tensor_ybit = tensor_ybit << (nbits - used_bits - y)
Expand Down Expand Up @@ -318,7 +231,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 5,
14336: 7,
28672: 7,
57344: 7,
57344: 7
},
{ # tokens: [65:128]
3072: 9,
Expand All @@ -329,7 +242,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 5,
14336: 7,
28672: 7,
57344: 6,
57344: 6
},
{ # tokens: [129:192]
3072: 6,
Expand All @@ -340,7 +253,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 5,
14336: 5,
28672: 5,
57344: 4,
57344: 4
},
{ # tokens: [193:256]
3072: 9,
Expand All @@ -351,7 +264,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 4,
14336: 8,
28672: 6,
57344: 4,
57344: 4
},
{ # tokens: [257:320]
3072: 7,
Expand All @@ -362,7 +275,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 1,
14336: 3,
28672: 3,
57344: 4,
57344: 4
},
{ # tokens: [321:384]
3072: 3,
Expand All @@ -373,7 +286,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 8,
14336: 3,
28672: 4,
57344: 3,
57344: 3
},
{ # tokens: [385:448]
3072: 5,
Expand All @@ -384,7 +297,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 3,
14336: 1,
28672: 1,
57344: 3,
57344: 3
},
{ # tokens: [449:512]
3072: 2,
Expand All @@ -395,7 +308,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 2,
14336: 6,
28672: 4,
57344: 1,
57344: 1
},
{ # tokens: [513:576]
3072: 2,
Expand All @@ -406,7 +319,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 3,
14336: 3,
28672: 1,
57344: 1,
57344: 1
},
{ # tokens: [577:640]
3072: 5,
Expand All @@ -417,7 +330,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 1,
14336: 1,
28672: 1,
57344: 1,
57344: 1
},
{ # tokens: [641:704]
3072: 3,
Expand All @@ -428,7 +341,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
10240: 2,
14336: 1,
28672: 1,
57344: 1,
57344: 1
},
{ # tokens: [705:768]
3072: 3,
@@ -439,18 +352,17 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
        10240: 1,
        14336: 1,
        28672: 1,
-        57344: 1,
-    },
+        57344: 1
+    }
]
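The list of dicts above buckets the token count into 64-wide ranges and maps an output-feature size to a split-K factor for the fp6_llm kernel; a hypothetical lookup helper (the name choose_split_k, the clamping, and the default of 1 are assumptions, not shown in this diff):

def choose_split_k(split_k_table, n_tokens: int, out_features: int) -> int:
    # 64-token buckets: [1:64] -> 0, [65:128] -> 1, ...
    bucket = min((n_tokens - 1) // 64, len(split_k_table) - 1)
    return split_k_table[bucket].get(out_features, 1)  # fall back to no split

# e.g. choose_split_k(split_k_table, 100, 14336) == 7 with the values above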


# quantization api integrations


@dataclass(frozen=True)
class FloatxTensorCoreLayout(Layout):
"""Layout type for FloatxTensorCoreAQTTensorImpl"""

"""Layout type for FloatxTensorCoreAQTTensorImpl
"""
ebits: int
mbits: int

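A small usage sketch of the layout dataclass above; (ebits, mbits) = (3, 2) corresponds to fp6, with one sign bit on top of the exponent and mantissa bits:

fp6_layout = FloatxTensorCoreLayout(ebits=3, mbits=2)   # frozen dataclass, so hashable
nbits = 1 + fp6_layout.ebits + fp6_layout.mbits          # 6 bits per element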
@@ -478,7 +390,6 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl):
    it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor
    FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit)
    """
-
    def __new__(
        cls,
        packed_floatx_data: torch.Tensor,
@@ -487,16 +398,11 @@
    ):
        assert packed_floatx_data.ndim == 2
        assert packed_floatx_data.dtype == torch.uint8
-        shape = (
-            packed_floatx_data.shape[0],
-            packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8,
-        )
+        shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8)
        kwargs = {}
        kwargs["device"] = packed_floatx_data.device
        kwargs["layout"] = (
-            kwargs.get("layout")
-            if kwargs.get("layout", False)
-            else packed_floatx_data.layout
+            kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout
        )
        kwargs["dtype"] = packed_floatx_data.dtype
        kwargs["requires_grad"] = False
@@ -519,17 +425,12 @@ def __tensor_flatten__(self):
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
-        packed_floatx_data, scale = (
-            tensor_data_dict["packed_floatx_data"],
-            tensor_data_dict["scale"],
-        )
-        (_layout,) = tensor_attributes
+        packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"]
+        _layout, = tensor_attributes
        return cls(packed_floatx_data, scale, _layout)

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        unpacked_floatx_data = unpack_tc_floatx(
-            self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits
-        )
+        unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits)
        return unpacked_floatx_data, self.scale

    @classmethod
@@ -548,9 +449,7 @@ def from_plain(
        bit, M is mantissa bit
        """
        assert isinstance(_layout, FloatxTensorCoreLayout)
-        packed_floatx_data = pack_tc_floatx(
-            unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits
-        )
+        packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits)
        return cls(packed_floatx_data, scale, _layout)

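A round-trip sketch of the module-level packing helpers that from_plain and get_plain rely on; the uint8 codes stand in for real unpacked fp6 data, and the sizes are chosen as multiples of 64 to satisfy the packer:

import torch
from torchao.dtypes.floatx.floatx_tensor_core_layout import pack_tc_floatx, unpack_tc_floatx

codes = torch.randint(0, 1 << 6, (256, 1024), dtype=torch.uint8)  # unpacked 6-bit codes
packed = pack_tc_floatx(codes, 6)          # expected shape (256, 1024 // 8 * 6) = (256, 768)
assert packed.shape == (256, 768)
assert torch.equal(unpack_tc_floatx(packed, 6), codes)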
    def __repr__(self):
@@ -588,12 +487,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            )
        elif func is aten._to_copy.default:
            return return_and_correct_aliasing(
-                func,
-                args,
-                kwargs,
-                args[0]._apply_fn_to_data(
-                    lambda x: x.to(device=kwargs.pop("device", None))
-                ),
+                func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))),
            )

        raise NotImplementedError(
