[CI][Bugfix] set VllmConfig before calling CustomOP
Signed-off-by: Mengqing Cao <cmq0113@163.com>
MengqingCao committed Nov 18, 2024
1 parent 8796fbc commit 4be3739
Showing 4 changed files with 63 additions and 34 deletions.
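
All four test files apply the same fix: create a VllmConfig and enter set_current_vllm_config before constructing a CustomOp-based layer (the activation ops, RMSNorm, and the rotary embeddings), evidently because these layers now expect a current vLLM config to be set when they are instantiated. A minimal sketch of the pattern, using SiluAndMul from the first file below (the default VllmConfig and the tensor shape are illustrative, not prescribed by the commit):

    import torch

    from vllm.config import VllmConfig
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.plugins import set_current_vllm_config

    # Build the CustomOp-based layer inside the config context so that the
    # layer can consult the current VllmConfig during construction.
    vllm_config = VllmConfig()
    with set_current_vllm_config(vllm_config):
        layer = SiluAndMul()

    # As in the tests, the forward passes run outside the context.
    x = torch.randn(7, 2 * 512, dtype=torch.float16)
    out = layer(x)
    ref_out = layer.forward_native(x)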
36 changes: 21 additions & 15 deletions tests/kernels/test_activation.py
@@ -5,10 +5,12 @@
 import torch
 
 from tests.kernels.utils import opcheck
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                     GeluAndMul, NewGELU,
                                                     QuickGELU, SiluAndMul)
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 
 from .allclose_default import get_default_atol, get_default_rtol
 
@@ -40,19 +42,21 @@ def test_act_and_mul(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    if activation == "silu":
-        layer = SiluAndMul()
-        fn = torch.ops._C.silu_and_mul
-    elif activation == "gelu":
-        layer = GeluAndMul(approximate="none")
-        fn = torch.ops._C.gelu_and_mul
-    elif activation == "gelu_tanh":
-        layer = GeluAndMul(approximate="tanh")
-        fn = torch.ops._C.gelu_tanh_and_mul
-    elif activation == "fatrelu":
-        threshold = random.uniform(0, 1)
-        layer = FatreluAndMul(threshold)
-        fn = torch.ops._C.fatrelu_and_mul
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        if activation == "silu":
+            layer = SiluAndMul()
+            fn = torch.ops._C.silu_and_mul
+        elif activation == "gelu":
+            layer = GeluAndMul(approximate="none")
+            fn = torch.ops._C.gelu_and_mul
+        elif activation == "gelu_tanh":
+            layer = GeluAndMul(approximate="tanh")
+            fn = torch.ops._C.gelu_tanh_and_mul
+        elif activation == "fatrelu":
+            threshold = random.uniform(0, 1)
+            layer = FatreluAndMul(threshold)
+            fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
     # The SiLU, GELU and FatReLU implementations are equivalent to the native
@@ -88,8 +92,10 @@ def test_activation(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = activation[0]()
-    fn = activation[1]
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        layer = activation[0]()
+        fn = activation[1]
     out = layer(x)
     ref_out = layer.forward_native(x)
     torch.testing.assert_close(out,
6 changes: 5 additions & 1 deletion tests/kernels/test_layernorm.py
@@ -3,8 +3,10 @@
 
 from tests.kernels.quant_utils import FP8_DTYPE
 from tests.kernels.utils import opcheck
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
@@ -34,7 +36,9 @@ def test_rms_norm(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
-    layer = RMSNorm(hidden_size).to(dtype=dtype)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        layer = RMSNorm(hidden_size).to(dtype=dtype)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
47 changes: 31 additions & 16 deletions tests/kernels/test_pos_encoding.py
@@ -4,8 +4,10 @@
 import pytest
 import torch
 
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.platforms import current_platform
+from vllm.plugins import set_current_vllm_config
 
 from .allclose_default import get_default_atol, get_default_rtol
 
@@ -52,8 +54,11 @@ def test_rotary_embedding(
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype=dtype)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style)
+    rope = rope.to(dtype=dtype).to(device)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
@@ -104,11 +109,14 @@ def test_batched_rotary_embedding(
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "rope_type": "linear",
-        "factor": (1, )
-    })
-    rope = rope.to(dtype=dtype)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, {
+                            "rope_type": "linear",
+                            "factor": (1, )
+                        })
+    rope = rope.to(dtype=dtype).to(device)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
@@ -165,11 +173,14 @@ def test_batched_rotary_embedding_multi_lora(
     if rotary_dim is None:
         rotary_dim = head_size
     scaling_factors: List[int] = [1, 2, 4]
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "rope_type": "linear",
-        "factor": tuple(scaling_factors)
-    })
-    rope = rope.to(dtype=dtype)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, {
+                            "rope_type": "linear",
+                            "factor": tuple(scaling_factors)
+                        })
+    rope = rope.to(dtype=dtype).to(device)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
@@ -225,8 +236,10 @@ def test_rope_module_cache():
             is_neox_stype, rope_scaling, dtype = setting
         if rotary_dim is None:
             rotary_dim = head_size
-        rope = get_rope(head_size, rotary_dim, max_position, base,
-                        is_neox_stype, rope_scaling, dtype)
+        vllm_config = VllmConfig()
+        with set_current_vllm_config(vllm_config):
+            rope = get_rope(head_size, rotary_dim, max_position, base,
+                            is_neox_stype, rope_scaling, dtype)
         # different settings cannot share the same rope module
         assert id(rope) not in rope_setting_id_map.values()
         assert all(x.dtype == dtype for x in rope.buffers())
@@ -238,7 +251,9 @@ def test_rope_module_cache():
            is_neox_stype, rope_scaling, dtype = setting
         if rotary_dim is None:
             rotary_dim = head_size
-        rope = get_rope(head_size, rotary_dim, max_position, base,
-                        is_neox_stype, rope_scaling, dtype)
+        vllm_config = VllmConfig()
+        with set_current_vllm_config(vllm_config):
+            rope = get_rope(head_size, rotary_dim, max_position, base,
+                            is_neox_stype, rope_scaling, dtype)
         # check if cache take effect
         assert id(rope) == rope_setting_id_map[str(setting)]
8 changes: 6 additions & 2 deletions tests/kernels/test_rotary_embedding.py
@@ -8,7 +8,9 @@
 import torch
 
 from tests.kernels.utils import opcheck
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.plugins import set_current_vllm_config
 
 
 def rotary_embedding_opcheck(rot,
@@ -42,8 +44,10 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     batch_size = 1
     base = 0
     num_heads = 7
-    rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                          is_neox_style, torch.float32)
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                              is_neox_style, torch.float32)
 
     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),