Commit

[Kernel] Change interface to Mamba selective_state_update for continuous batching (vllm-project#8039)
tlrmchlsmth authored and MengqingCao committed Sep 30, 2024
1 parent 50847a3 commit 7ed9b8b
Showing 2 changed files with 174 additions and 3 deletions.
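
In short, selective_state_update gains an optional state_batch_indices argument. When it is supplied, the SSM state tensor may hold more entries than the current batch, and batch row b reads and writes state[state_batch_indices[b]] rather than state[b] — which is what a continuous-batching scheduler needs to keep per-request state rows alive across decode steps. A minimal call sketch, assuming a CUDA device and mirroring the shapes used by the new tests below:

import torch

from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_state_update)

batch_size, dim, dstate, total_entries = 16, 2048, 16, 160

# One state row per potential sequence -- more rows than batch entries.
state = torch.randn(total_entries, dim, dstate, device="cuda")
# The row that each of the batch_size running sequences owns.
state_indices = torch.randperm(total_entries)[:batch_size].to(
    dtype=torch.int32, device="cuda")

x = torch.randn(batch_size, dim, device="cuda")
dt = torch.randn(batch_size, dim, device="cuda")
A = -torch.rand(dim, dstate, device="cuda") - 1.0
B = torch.randn(batch_size, dstate, device="cuda")
C = torch.randn(batch_size, dstate, device="cuda")

# Updates state[state_indices] in place and returns this step's output.
out = selective_state_update(state, x, dt, A, B, C,
                             state_batch_indices=state_indices)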
146 changes: 146 additions & 0 deletions tests/kernels/test_mamba_ssm.py
@@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype):

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 7e-2, 7e-2
        if torch.version.hip:
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16

    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)

    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].detach().clone()
    out = selective_state_update(state,
                                 x,
                                 dt,
                                 A,
                                 B,
                                 C,
                                 D=D,
                                 z=z,
                                 dt_bias=dt_bias,
                                 dt_softplus=True,
                                 state_batch_indices=state_indices)
    out_ref = selective_state_update_ref(state_ref,
                                         x,
                                         dt,
                                         A,
                                         B,
                                         C,
                                         D=D,
                                         z=z,
                                         dt_bias=dt_bias,
                                         dt_softplus=True)

    assert torch.allclose(state[state_indices, :],
                          state_ref,
                          rtol=rtol,
                          atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
dim, dstate, ngroups, has_z, tie_hdim, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 16
headdim = 64
nheads = dim // headdim

total_entries = 10 * batch_size
state = torch.randn(total_entries,
nheads,
headdim,
dstate,
dtype=itype,
device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)

x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
if not tie_hdim:
dt = torch.randn(batch_size,
nheads,
headdim,
device=device,
dtype=itype)
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
D = torch.randn(nheads, headdim, device=device)
else:
dt = repeat(torch.randn(batch_size, nheads, device=device,
dtype=itype),
"b h -> b h p",
p=headdim)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
"h -> h p",
p=headdim)
A = repeat(-torch.rand(nheads, device=device) - 1.0,
"h -> h p n",
p=headdim,
n=dstate)
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
B = torch.randn(batch_size, ngroups, dstate, device=device)
C = torch.randn(batch_size, ngroups, dstate, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
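
Setting kernel fusion aside, the behavior these tests check is a gather, a dense update, and a scatter. A reference sketch in plain PyTorch, reusing the selective_state_update_ref helper defined earlier in this test module (the wrapper name is ours, not vLLM's):

def indexed_selective_state_update_ref(state, state_indices, x, dt, A, B, C,
                                       **kwargs):
    # Gather the state rows owned by the running sequences into a dense
    # (batch, ...) buffer; advanced indexing already returns a copy.
    gathered = state[state_indices]
    # Run the dense reference update; it modifies `gathered` in place.
    out = selective_state_update_ref(gathered, x, dt, A, B, C, **kwargs)
    # Scatter the updated rows back into their slots in the big buffer.
    state[state_indices] = gathered
    return out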
31 changes: 28 additions & 3 deletions vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py

 import torch
 import triton
@@ -27,6 +28,10 @@ def softplus(dt):
     {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
 @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
 @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
+@triton.heuristics({
+    "HAS_STATE_BATCH_INDICES":
+    lambda args: args["state_batch_indices_ptr"] is not None
+})
 @triton.heuristics(
     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
 @triton.jit
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
     D_ptr,
     z_ptr,
     out_ptr,
+    state_batch_indices_ptr,
     # Matrix dimensions
     batch,
     nheads,
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
     HAS_DT_BIAS: tl.constexpr,
     HAS_D: tl.constexpr,
     HAS_Z: tl.constexpr,
+    HAS_STATE_BATCH_INDICES: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
-    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+    # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
+    # is taken from the state_batch_indices_ptr. Otherwise, the state coordinate
+    # is the same as the batch id.
+    if HAS_STATE_BATCH_INDICES:
+        state_batch_indices_ptr += pid_b
+        state_batch_idx = tl.load(state_batch_indices_ptr)
+        state_ptr += (state_batch_idx * stride_state_batch +
+                      pid_h * stride_state_head)
+    else:
+        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
     x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
     if HAS_DT_BIAS:
@@ -177,7 +195,8 @@ def selective_state_update(state,
                            D=None,
                            z=None,
                            dt_bias=None,
-                           dt_softplus=False):
+                           dt_softplus=False,
+                           state_batch_indices=None):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
Expand Down Expand Up @@ -211,7 +230,10 @@ def selective_state_update(state,
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape

_, nheads, dim, dstate = state.shape
batch = x.shape[0]

assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
@@ -225,6 +247,8 @@ def selective_state_update(state,
         assert z.shape == x.shape
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, dim)
+    if state_batch_indices is not None:
+        assert state_batch_indices.shape == (batch, )
     out = torch.empty_like(x)
     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
@@ -249,6 +273,7 @@
             D,
             z,
             out,
+            state_batch_indices,
             batch,
             nheads,
             dim,
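
The interface change is what lets a scheduler keep one long-lived state row per request and pass only the rows of the currently running requests on each decode step. A purely hypothetical slot-allocation sketch (these names do not exist in vLLM):

import torch

class StateSlotPool:
    """Hypothetical: maps live request ids to rows of the SSM state buffer."""

    def __init__(self, total_entries: int):
        self._free = list(range(total_entries))
        self._slot_of: dict[int, int] = {}

    def allocate(self, req_id: int) -> int:
        # Claim a free state row for a newly admitted request.
        slot = self._free.pop()
        self._slot_of[req_id] = slot
        return slot

    def release(self, req_id: int) -> None:
        # Return a finished request's row to the free pool.
        self._free.append(self._slot_of.pop(req_id))

    def batch_indices(self, running: list[int]) -> torch.Tensor:
        # The tensor to pass as state_batch_indices=... each step.
        return torch.tensor([self._slot_of[r] for r in running],
                            dtype=torch.int32, device="cuda")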
