Update vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py


Great suggestion to use the reshape operator. I think the change should look like this (an equivalence check against the old code is sketched below, after the commit metadata):
scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

Co-authored-by: Michael Goin <michael@neuralmagic.com>
kewang-xlnx and mgoin authored Dec 17, 2024
1 parent fab2244 commit 4ca3bb8
Showing 1 changed file with 2 additions and 9 deletions.
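
For context, not as part of the commit: a minimal sketch, assuming only PyTorch, that checks the single reshape(-1, 1) agrees with the old two-step unsqueeze normalization for 0-dim (scalar), 1-dim (per-row), and already-2-dim scale tensors. The helper names old_normalize and new_normalize are illustrative only, not vLLM APIs.

# Illustrative sketch only; helper names are not part of vLLM.
import torch

def old_normalize(scale: torch.Tensor) -> torch.Tensor:
    # Previous logic: promote a 0-dim scalar to 1-dim, then a 1-dim
    # vector to a column, yielding shape [1, 1] or [M, 1].
    if scale.dim() == 0:
        scale = scale.unsqueeze(dim=0)
    if scale.dim() == 1:
        scale = scale.unsqueeze(dim=1)
    return scale

def new_normalize(scale: torch.Tensor) -> torch.Tensor:
    # Committed one-liner: a single reshape covers both the 0-dim and 1-dim cases.
    return scale.reshape(-1, 1) if scale.dim() <= 1 else scale

for s in (torch.tensor(0.5), torch.rand(16), torch.rand(16, 1)):
    assert torch.equal(old_normalize(s), new_normalize(s))
    print(s.dim(), "->", tuple(new_normalize(s).shape))
# Expected output: 0 -> (1, 1), 1 -> (16, 1), 2 -> (16, 1)
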
@@ -136,15 +136,8 @@ def triton_scaled_mm(input: torch.Tensor,
     assert weight.shape[0] == K
     assert input.dtype == weight.dtype
 
-    if scale_a.dim() == 0:
-        scale_a.data = scale_a.data.unsqueeze(dim=0)
-    if scale_a.dim() == 1:
-        scale_a.data = scale_a.data.unsqueeze(dim=1)
-
-    if scale_b.dim() == 0:
-        scale_b.data = scale_b.data.unsqueeze(dim=0)
-    if scale_b.dim() == 1:
-        scale_b.data = scale_b.data.unsqueeze(dim=1)
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
 
     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
     assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
