enable kernel tests
bigPYJ1151 committed Aug 17, 2024
1 parent 99739e6 commit 9eeddea
Showing 4 changed files with 44 additions and 17 deletions.
8 changes: 8 additions & 0 deletions .buildkite/run-cpu-test.sh
@@ -25,6 +25,14 @@ docker exec cpu-test bash -c "
pip install pytest
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B are not supported on CPU

# Run Compressed-Tensor test
docker exec cpu-test bash -c "
  pytest -v \
    tests/kernels/test_int8_quant.py::test_dynamic_scaled_int8_quant \
    tests/kernels/test_int8_quant.py::test_static_scaled_int8_quant \
    tests/kernels/test_cutlass.py::test_cutlass_int8_gemm \
    tests/kernels/test_cutlass.py::test_cutlass_int8_gemm_output_dtype"
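For reference, the same four kernel tests can also be run outside the container through pytest's Python API; this is only a sketch, assuming a source checkout with the CPU build of vLLM installed:

```python
# Hypothetical local equivalent of the CI step above: run the same four
# kernel tests directly, without `docker exec`.
import pytest

selected = [
    "tests/kernels/test_int8_quant.py::test_dynamic_scaled_int8_quant",
    "tests/kernels/test_int8_quant.py::test_static_scaled_int8_quant",
    "tests/kernels/test_cutlass.py::test_cutlass_int8_gemm",
    "tests/kernels/test_cutlass.py::test_cutlass_int8_gemm_output_dtype",
]
raise SystemExit(pytest.main(["-v", *selected]))
```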

# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
5 changes: 4 additions & 1 deletion tests/kernels/quant_utils.py
@@ -11,7 +11,10 @@


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
return torch.as_tensor(
    x,
    dtype=torch.float32,
    device='cuda' if torch.cuda.is_available() else 'cpu')

def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype,
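The only change here is a graceful device fallback, so the reference quantization utilities no longer hard-require CUDA. A minimal standalone sketch of the pattern (not the vLLM helper itself):

```python
import torch

# Prefer the accelerator when one is visible, otherwise stay on CPU so the
# reference helpers still run inside the CPU-only CI job. Requesting a CUDA
# tensor on a CPU-only host would instead raise at call time.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.as_tensor([1.5, 2.0], dtype=torch.float32, device=device)
print(x.device)  # cuda:0 on GPU hosts, cpu otherwise
```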
28 changes: 18 additions & 10 deletions tests/kernels/test_cutlass.py
@@ -9,13 +9,21 @@

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
from vllm.utils import is_cpu

if not is_cpu():
    CUDA_DEVICES = [
        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]
    OUTPUT_DTYPES = [torch.bfloat16, torch.float16]
    DEFAULT_DEVICE = "cuda"
else:
    CUDA_DEVICES = ["cpu"]
    OUTPUT_DTYPES = [torch.bfloat16]
    DEFAULT_DEVICE = "cpu"

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
capability = capability[0] * 10 + capability[1] if capability is not None else 0
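Why the new `if capability is not None else 0` guard matters: on a CPU build, `current_platform.get_device_capability()` presumably reports no capability, and the old expression would fail while pytest collects the module. A small self-contained illustration of the behavior (the `None` value is an assumption about what a CPU platform returns):

```python
# With no device capability, the old expression `capability[0] * 10 +
# capability[1]` raises a TypeError at import time. With the `else 0`
# fallback, the existing @pytest.mark.skipif(capability < 89, ...) guards
# simply skip the FP8-only tests instead of crashing collection.
capability = None
capability = capability[0] * 10 + capability[1] if capability is not None else 0
assert capability == 0
```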


def to_fp8(tensor: torch.Tensor):
@@ -84,7 +92,7 @@ def cutlass_int8_gemm_helper(m: int,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
device: str = DEFAULT_DEVICE):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a = to_int8(torch.randn((m, k), device=device) * 5)
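For orientation, the per-token / per-output-channel int8 scheme this helper exercises reduces to a dequantize-then-matmul reference along these lines (illustrative names and shapes, not the actual helper body):

```python
import torch

def ref_int8_gemm(a_q, b_q, scale_a, scale_b, out_dtype, bias=None):
    # a_q: (m, k) int8 activations, b_q: (k, n) int8 weights
    # scale_a: (m, 1) per-token scales, scale_b: (1, n) per-output-channel scales
    out = (a_q.to(torch.float32) @ b_q.to(torch.float32)) * scale_a * scale_b
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
```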
@@ -135,7 +143,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,

@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("out_dtype", OUTPUT_DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype: Type[torch.dtype],
Expand All @@ -151,7 +159,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,

@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("out_dtype", OUTPUT_DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
reason="FP8 is not supported on this GPU type.")
@@ -227,7 +235,7 @@ def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("out_dtype", OUTPUT_DTYPES)
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
out_dtype: torch.dtype):
@@ -278,7 +286,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("out_dtype", OUTPUT_DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
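The azp (asymmetric zero-point) tests in this file rely on the identity that lets the zero-point term be folded into a per-output-channel bias; a quick self-contained check of that algebra (shapes and values chosen arbitrarily):

```python
# If A_q carries an additive zero point azp, then
#   (A_q - azp) @ B == A_q @ B - azp * colsum(B),
# so the azp correction can be precomputed and applied as a bias.
import torch

m, k, n = 4, 8, 3
a_q = torch.randint(-128, 127, (m, k), dtype=torch.int32)
b = torch.randn(k, n)
azp = 7
direct = (a_q - azp).float() @ b
folded = a_q.float() @ b - azp * b.sum(dim=0)
assert torch.allclose(direct, folded, atol=1e-4)
```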
20 changes: 14 additions & 6 deletions tests/kernels/test_int8_quant.py
@@ -3,8 +3,14 @@

from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant
from vllm.utils import is_cpu

DTYPES = [torch.half, torch.bfloat16, torch.float]
if not is_cpu():
    DTYPES = [torch.half, torch.bfloat16, torch.float]
    DEVICE = "cuda"
else:
    DTYPES = [torch.bfloat16, torch.float]
    DEVICE = "cpu"
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
@@ -20,9 +26,10 @@
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device=DEVICE) * 1000

# reference
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
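`ref_dynamic_per_token_quant` lives in tests/kernels/quant_utils.py; conceptually it derives one scale per token from the row-wise max-abs value, roughly as sketched below (a simplified sketch — the real helper covers more dtypes and edge cases):

```python
import torch

def dyn_per_token_int8(x: torch.Tensor):
    # One float32 scale per token (row): max-abs over the hidden dim / 127.
    # Assumes no all-zero rows, which would produce a zero scale.
    int8_max = torch.iinfo(torch.int8).max
    scales = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / int8_max
    q = (x.to(torch.float32) / scales).round().clamp(-128, 127).to(torch.int8)
    return q, scales
```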
@@ -45,11 +52,12 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale = torch.tensor([scale], dtype=torch.float32, device="cuda")
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device=DEVICE) * 1000
scale = torch.tensor([scale], dtype=torch.float32, device=DEVICE)

out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
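As a tiny worked example of the static path being checked here: with a fixed scale, quantization is just divide, round, and clamp to the int8 range (values below are arbitrary):

```python
import torch

x = torch.tensor([300.0, -1.7, 0.4])
scale = torch.tensor([2.0])
q = (x / scale).round().clamp(-128, 127).to(torch.int8)
print(q)  # tensor([127,  -1,   0], dtype=torch.int8)
```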
