fp8: add calibration scale for decode attention operators (flashinfer…
yzh119 authored Jun 1, 2024
1 parent 64e935a commit 041b63a
Showing 3 changed files with 298 additions and 8 deletions.
126 changes: 119 additions & 7 deletions python/flashinfer/decode.py
@@ -57,6 +57,9 @@ def single_decode_with_kv_cache(
v: torch.Tensor,
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -80,6 +83,12 @@ def single_decode_with_kv_cache(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -120,11 +129,15 @@ def single_decode_with_kv_cache(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
return _kernels.single_decode_with_kv_cache(
out = _kernels.single_decode_with_kv_cache(
q,
k,
v,
@@ -135,6 +148,9 @@ def single_decode_with_kv_cache(
rope_scale,
rope_theta,
)
if v_scale is not None:
out *= v_scale
return out
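
The pattern above is the whole trick: ``q_scale`` and ``k_scale`` are folded into ``sm_scale`` before the kernel launch, and ``v_scale`` rescales the kernel output afterwards, so the kernel itself needs no extra arguments. A minimal torch sketch (not part of this commit; shapes and scale values are illustrative) checking that this folding matches attention over the dequantized tensors:

```python
import math
import torch

# Reference single-query decode attention in fp32 (illustrative only, not the kernel).
def ref_decode(q, k, v, sm_scale):
    # q: [num_heads, head_dim], k/v: [kv_len, num_heads, head_dim]
    logits = torch.einsum("hd,nhd->hn", q, k) * sm_scale
    p = torch.softmax(logits, dim=-1)
    return torch.einsum("hn,nhd->hd", p, v)

torch.manual_seed(0)
num_heads, head_dim, kv_len = 4, 64, 128
q = torch.randn(num_heads, head_dim)
k = torch.randn(kv_len, num_heads, head_dim)
v = torch.randn(kv_len, num_heads, head_dim)
q_scale, k_scale, v_scale = 0.7, 1.3, 0.9  # hypothetical calibration scales

sm_scale = 1.0 / math.sqrt(head_dim)
# Attention over the "dequantized" tensors ...
out_dequant = ref_decode(q * q_scale, k * k_scale, v * v_scale, sm_scale)
# ... equals attention over the raw tensors with the scales folded in,
# which is exactly what the patch does: q_scale and k_scale multiply
# sm_scale, and v_scale rescales the kernel output.
out_folded = ref_decode(q, k, v, sm_scale * q_scale * k_scale) * v_scale
assert torch.allclose(out_dequant, out_folded, atol=1e-5)
```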


def batch_decode_with_padded_kv_cache(
@@ -143,6 +159,9 @@ def batch_decode_with_padded_kv_cache(
v_padded: torch.Tensor,
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -169,6 +188,12 @@ def batch_decode_with_padded_kv_cache(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -208,11 +233,15 @@ def batch_decode_with_padded_kv_cache(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
return _kernels.batch_decode_with_padded_kv_cache(
out = _kernels.batch_decode_with_padded_kv_cache(
q,
k_padded,
v_padded,
@@ -223,6 +252,9 @@ def batch_decode_with_padded_kv_cache(
rope_theta,
False,
)[0]
if v_scale is not None:
out *= v_scale
return out


def batch_decode_with_padded_kv_cache_return_lse(
@@ -231,6 +263,9 @@ def batch_decode_with_padded_kv_cache_return_lse(
v_padded: torch.Tensor,
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
Expand Down Expand Up @@ -258,6 +293,12 @@ def batch_decode_with_padded_kv_cache_return_lse(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -304,11 +345,15 @@ def batch_decode_with_padded_kv_cache_return_lse(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
return _kernels.batch_decode_with_padded_kv_cache(
V, s = _kernels.batch_decode_with_padded_kv_cache(
q,
k_padded,
v_padded,
@@ -319,6 +364,9 @@ def batch_decode_with_padded_kv_cache_return_lse(
rope_theta,
True,
)
if v_scale is not None:
V *= v_scale
return V, s
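
The calibration scales themselves come from whatever fp8 quantization step produced ``k_padded``/``v_padded``. A hedged sketch of per-tensor amax calibration to ``float8_e4m3fn`` — the ``fp8_quantize`` helper and the tensor shapes below are assumptions for illustration, not flashinfer APIs:

```python
import torch

def fp8_quantize(x: torch.Tensor):
    """Per-tensor amax calibration to float8_e4m3fn (hypothetical helper).

    Returns the quantized tensor plus a scale such that
    x ~= x_fp8.to(x.dtype) * scale, which is the dequantization convention
    the q_scale/k_scale/v_scale arguments assume in this sketch.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().clamp(min=1e-12).item() / finfo.max
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_fp8, scale

# Illustrative padded KV cache in fp16, NHD layout (shapes are assumptions;
# real use would place the tensors on a CUDA device).
batch, padded_kv_len, num_kv_heads, num_qo_heads, head_dim = 2, 256, 4, 4, 128
k_padded = torch.randn(batch, padded_kv_len, num_kv_heads, head_dim, dtype=torch.float16)
v_padded = torch.randn_like(k_padded)
q = torch.randn(batch, num_qo_heads, head_dim, dtype=torch.float16)

k_fp8, k_scale = fp8_quantize(k_padded)
v_fp8, v_scale = fp8_quantize(v_padded)

# The scales are then forwarded to the decode operator, e.g.
# batch_decode_with_padded_kv_cache(q, k_fp8, v_fp8,
#                                   k_scale=k_scale, v_scale=v_scale)
```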


class BatchDecodeWithPagedKVCacheWrapper:
@@ -508,6 +556,9 @@ def forward(
q: torch.Tensor,
paged_kv_data: torch.Tensor,
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -527,6 +578,12 @@ def forward(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
Expand All @@ -544,13 +601,17 @@ def forward(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4

paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
out = self._wrapper.forward(
q,
paged_kv_data,
self._paged_kv_indptr,
@@ -562,12 +623,18 @@ def forward(
rope_theta,
False,
)[0]
if v_scale is not None:
out *= v_scale
return out

def forward_return_lse(
self,
q: torch.Tensor,
paged_kv_data: torch.Tensor,
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -588,6 +655,12 @@ def forward_return_lse(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
Expand All @@ -612,12 +685,16 @@ def forward_return_lse(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
V, s = self._wrapper.forward(
q,
paged_kv_data,
self._paged_kv_indptr,
@@ -629,6 +706,9 @@ def forward_return_lse(
rope_theta,
True,
)
if v_scale is not None:
V *= v_scale
return V, s
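
In the ``*_return_lse`` variants, ``v_scale`` is applied to ``V`` only; the returned log-sum-exp ``s`` stays untouched because it describes the softmax normalizer, which does not depend on ``V``. A small torch sketch (natural-log LSE assumed here; the argument is independent of the LSE base) showing that rescaling ``V`` after an LSE-weighted merge equals rescaling each partial ``V`` before the merge, so downstream state merging still works:

```python
import torch

def merge_states(v_a, s_a, v_b, s_b):
    # LSE-weighted combination of two partial attention results over
    # disjoint KV chunks (natural-log LSE assumed for this sketch).
    w_a = torch.exp(s_a - torch.logaddexp(s_a, s_b)).unsqueeze(-1)
    return w_a * v_a + (1.0 - w_a) * v_b

torch.manual_seed(0)
num_heads, head_dim = 4, 64
v_a, v_b = torch.randn(num_heads, head_dim), torch.randn(num_heads, head_dim)
s_a, s_b = torch.randn(num_heads), torch.randn(num_heads)
v_scale = 0.9  # hypothetical calibration scale

# Scaling each partial V before merging ...
merged_pre = merge_states(v_a * v_scale, s_a, v_b * v_scale, s_b)
# ... equals scaling the merged result once, which is why the wrapper can
# simply multiply the kernel output by v_scale and return s unchanged.
merged_post = merge_states(v_a, s_a, v_b, s_b) * v_scale
assert torch.allclose(merged_pre, merged_post, atol=1e-6)
```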


class CUDAGraphBatchDecodeWithPagedKVCacheWrapper:
@@ -788,6 +868,9 @@ def forward(
q: torch.Tensor,
paged_kv_data: torch.Tensor,
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -807,6 +890,12 @@ def forward(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -824,13 +913,17 @@ def forward(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4

paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
out = self._wrapper.forward(
q,
paged_kv_data,
self._paged_kv_indptr_buf,
@@ -842,12 +935,18 @@ def forward(
rope_theta,
False,
)[0]
if v_scale is not None:
out *= v_scale
return out

def forward_return_lse(
self,
q: torch.Tensor,
paged_kv_data: torch.Tensor,
pos_encoding_mode: str = "NONE",
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
@@ -868,6 +967,12 @@ def forward_return_lse(
pos_encoding_mode : str
Whether to apply RoPE on-the-fly inside attention kernels, could be
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
q_scale : Optional[float]
The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
sm_scale : Optional[float]
The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
rope_scale : Optional[float]
@@ -892,12 +997,16 @@ def forward_return_lse(
if sm_scale is None:
head_dim = q.shape[-1]
sm_scale = 1.0 / math.sqrt(head_dim)
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
rope_theta = 1e4
paged_kv_data = expand_5d(paged_kv_data, self._kv_layout)
return self._wrapper.forward(
V, s = self._wrapper.forward(
q,
paged_kv_data,
self._paged_kv_indptr_buf,
@@ -911,3 +1020,6 @@ def forward_return_lse(
rope_theta,
True,
)
if v_scale is not None:
V *= v_scale
return V, s
2 changes: 1 addition & 1 deletion python/setup.py
@@ -71,7 +71,7 @@ def get_instantiation_cu() -> List[str]:
","
)
allow_fp16_qk_reduction_options = os.environ.get(
"FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0,1"
"FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0"
).split(",")
mask_modes = os.environ.get("FLASHINFER_MASK_MODES", "0,1,2").split(",")
# dispatch.inc
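
In ``setup.py`` the default for ``FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS`` is narrowed from ``0,1`` to ``0``, so fewer kernel variants are instantiated by default. A hedged sketch of how these comma-separated option variables expand into the instantiation grid (illustrative only; the actual ``get_instantiation_cu`` emits generated ``.cu`` files rather than printing):

```python
import itertools
import os

# Illustrative expansion of the option grid; the variable names mirror
# setup.py, but this is not the actual build script.
allow_fp16_qk_reduction_options = os.environ.get(
    "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0"
).split(",")
mask_modes = os.environ.get("FLASHINFER_MASK_MODES", "0,1,2").split(",")

for allow_fp16, mask_mode in itertools.product(
    allow_fp16_qk_reduction_options, mask_modes
):
    # Each combination corresponds to one instantiated kernel variant.
    print(f"allow_fp16_qk_reduction={allow_fp16}, mask_mode={mask_mode}")
```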
