
Commit

more cleanup
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 31, 2025
1 parent 3cdd2ce commit c9d72cb
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -53,7 +53,6 @@ def group_broadcast(t, shape):
 
 
 # Quantize assuming once scale per group of elements with shape group_shape,
-# Defaults to quantizing to correct platform specific float8 type
 # example group shapes:
 # * (-1, -1) for per-tensor quantization
 # * (1, -1) for per-row quantization
@@ -64,14 +63,14 @@ def group_broadcast(t, shape):
 def scaled_quantize(
     x: torch.Tensor,
     group_shape: Tuple[int, int],
-    dtype: torch.dtype,
+    tgt_dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     group_shape = _normalize_quant_group_shape(x, group_shape)
-    assert dtype.is_floating_point, \
+    assert tgt_dtype.is_floating_point, \
         "currently `scaled_quantize` only supports floating point dtypes " \
         "but could be extended to support other dtypes"
 
-    finfo = torch.finfo(dtype)
+    finfo = torch.finfo(tgt_dtype)
 
     # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
     assert x.ndim == 2
@@ -97,7 +96,7 @@ def scaled_quantize(
         .permute(0, 2, 1, 3)\
         .reshape(x.shape)
 
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+    return x_scl_sat.to(tgt_dtype).contiguous(), scale.float().reciprocal()
 
 
 # inverses `scaled_quantize`
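A minimal usage sketch (not part of this commit) showing how `scaled_quantize` might be called after the `dtype` -> `tgt_dtype` rename. The tensor shape, the (1, -1) per-row group shape, and the torch.float8_e4m3fn target dtype are illustrative assumptions, not taken from the diff.

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    scaled_quantize)

# A small activation-like tensor; the shape here is arbitrary.
x = torch.randn(16, 128, dtype=torch.float32)

# (1, -1): one scale per row; (-1, -1) would give a single per-tensor scale.
x_q, scale_inv = scaled_quantize(x, (1, -1), torch.float8_e4m3fn)

# x_q holds the values cast to the target float8 dtype; scale_inv is the
# reciprocal of the per-group scales, matching scale.float().reciprocal()
# in the return statement above.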
