diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 73321c0d0aa..ded28b4123b 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -103,6 +103,14 @@ def fasterprune(
             W = W.t()
         W = W.float()
 
+        sparsity = tensor_sparsity(W)
+        preserve_zeros = sparsity >= SPARSITY_THRESHOLD
+        W_nz_mask = (
+            (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float()
+            if preserve_zeros
+            else None
+        )
+
         tick = time.time()
 
         dead = torch.diag(self.H) == 0
@@ -119,17 +127,6 @@
         self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
-        sparsity = tensor_sparsity(W)
-        mask = (
-            torch.where(
-                W == 0,
-                torch.tensor(1, dtype=torch.bool),
-                torch.tensor(0, dtype=torch.bool),
-            )
-            if sparsity >= SPARSITY_THRESHOLD
-            else None
-        )
-
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -141,21 +138,13 @@
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
 
-            if sparsity >= SPARSITY_THRESHOLD:
-                tmp = (
-                    (~mask[:, i1:i2])
-                    * W1**2
-                    / (torch.diag(Hinv1).reshape((1, -1))) ** 2
-                )
-                thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
-                mask1 = tmp <= thresh
+            if preserve_zeros:
+                W1_nz_mask = W_nz_mask[:, i1:i2]
 
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
                 q = w.clone()
-                if sparsity >= SPARSITY_THRESHOLD:
-                    q[mask1[:, i]] = 0
 
                 if hasattr(self.layer, "weight_fake_quant"):
                     scale = self.layer.weight_fake_quant.scale
@@ -216,13 +205,21 @@
                 Losses1[:, i] = (w - q) ** 2 / d**2
 
                 err1 = (w - q) / d
-                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                if preserve_zeros:
+                    W1[:, i:] -= w1_err * W1_nz_mask[:, i:]
+                else:
+                    W1[:, i:] -= w1_err
                 Err1[:, i] = err1
 
             W[:, i1:i2] = Q1
             Losses += torch.sum(Losses1, 1) / 2
 
-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            w_err = Err1.matmul(Hinv[i1:i2, i2:])
+            if preserve_zeros:
+                W[:, i2:] -= w_err * W_nz_mask[:, i2:]
+            else:
+                W[:, i2:] -= w_err
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
diff --git a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
index a068c391431..eca6f5d2379 100644
--- a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
+++ b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
@@ -19,6 +19,7 @@
 import pytest
 
 import sparseml
+from compressed_tensors.compressors.utils import tensor_follows_mask_structure
 from parameterized import parameterized_class
 from tests.testing_utils import parse_params, requires_torch
 
@@ -28,29 +29,6 @@
 )
 
 
-def tensor_follows_mask_structure(tensor, mask: str = "2:4"):
-    """
-    :param tensor: tensor to check
-    :param mask: mask structure to check for, in the format "n:m"
-    :return: True if the tensor follows the mask structure, False otherwise.
-        Note, some weights can incidentally be zero, so we check for
-        atleast n zeros in each chunk of size m
-    """
-    import torch
-
-    n, m = tuple(map(int, mask.split(":")))
-    # Reshape the tensor into chunks of size m
-    tensor = tensor.view(-1, m)
-
-    # Count the number of zeros in each chunk
-    zero_counts = (tensor == 0).sum(dim=1)
-
-    # Check if the number of zeros in each chunk atleast n
-    # Greater than sign is needed as some weights can incidentally
-    # be zero
-    return torch.all(zero_counts >= n)
-
-
 @requires_torch
 @pytest.mark.integration
 @parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY))