Add flash decoding for GQA attention on GPU.
PiperOrigin-RevId: 607766174
zhangqiaorjc authored and pax authors committed Feb 16, 2024
1 parent c8d96d0 commit d384d6a
Showing 2 changed files with 73 additions and 1 deletion.
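
For orientation, the sketch below shows how the new layer added in this commit might be configured. It is illustrative only: it sets just the fields this diff touches plus a layer name, the mesh values are placeholders, and any other GroupedQueryAttention fields a real model needs (head counts, dims, and so on) are omitted.

from praxis import base_layer
from praxis import pax_fiddle
from praxis.layers import gpu_fast_attention

# Hypothetical configuration; mesh values and the layer name are placeholders.
atten_p = pax_fiddle.Config(
    gpu_fast_attention.GpuTritonFusedGroupedQueryAttention,
    name='gqa_flash_decoding',
    use_flash_decoding=True,      # only taken on the decode path (t == 1)
    flash_decoding_k_splits=16,   # tune and measure, per the class comment
    ici_mesh_shape=[1, 8],        # consumed by _get_mesh(); placeholder values
    mesh_axis_names=['replica', 'mdl'],
)
atten_layer = base_layer.instantiate(atten_p)
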
1 change: 1 addition & 0 deletions praxis/layers/BUILD
@@ -111,6 +111,7 @@ pytype_strict_library(
    visibility = JAX_VISIBILITY,
    deps = [
        ":attentions",
        ":grouped_query_attention",
        ":normalizations",
        # Implicit jax dependency.
        # Implicit Pallas GPU dependency.
73 changes: 72 additions & 1 deletion praxis/layers/gpu_fast_attention.py
@@ -20,8 +20,8 @@

import functools
import logging
import math
import os
from typing import Tuple

import jax
from jax.experimental.shard_map import shard_map
@@ -30,6 +30,7 @@
from praxis import py_utils
from praxis import pytypes
from praxis.layers import attentions
from praxis.layers import grouped_query_attention
from praxis.layers import normalizations

# pylint: disable=g-import-not-at-top
@@ -242,6 +243,76 @@ def sharded_decode_mha(q, k, v):
    return encoded, None  # pytype: disable=bad-return-type  # jax-ndarray


class GpuTritonFusedGroupedQueryAttention(
    grouped_query_attention.GroupedQueryAttention
):
  """GroupedQueryAttention that uses flash decoding for decoding on GPU."""

  # Note that flash decoding may speed up only MQA and GQA;
  # XLA MHA may still run faster than Pallas flash decoding.
  # Tune k_splits and measure before enabling flash decoding.
  # https://crfm.stanford.edu/2023/10/12/flashdecoding.html
  use_flash_decoding: bool = False
  flash_decoding_k_splits: int = 16

  def _get_mesh(self) -> jax.sharding.Mesh:
    device_mesh = py_utils.create_device_mesh(
        self.ici_mesh_shape,
        self.dcn_mesh_shape,
        contiguous_submeshes=self.contiguous_submeshes,
    )
    mesh = jax.sharding.Mesh(device_mesh, self.mesh_axis_names)
    return mesh

  def _atten_context(
      self,
      query: JTensor,
      key: JTensor,
      value: JTensor,
      atten_mask: JTensor,
  ) -> Tuple[JTensor, JTensor]:
    """Computes atten context."""
    b, t, n, h = query.shape
    is_decoding = t == 1
    if not is_decoding or not self.use_flash_decoding:
      return super()._atten_context(query, key, value, atten_mask)

    if self.atten_dropout_prob > 0.0 and not self.do_eval:
      raise NotImplementedError
    if self.atten_logit_cap > 0.0:
      raise NotImplementedError

    query = query * (self.dim_per_head**-0.5) / self.atten_temp
    query = query.reshape([b, n, h])

    # Assumes a causal self-attention mask; cross-attention is not supported.
    sh = self.activation_split_dims_mapping
    bnh_pspec = jax.sharding.PartitionSpec(sh.btnh[0], sh.btnh[2], sh.btnh[3])
    blnh_pspec = jax.sharding.PartitionSpec(*sh.bskh)

    @functools.partial(
        shard_map,
        mesh=self._get_mesh(),
        in_specs=(
            bnh_pspec,
            blnh_pspec,
            blnh_pspec,
        ),
        out_specs=bnh_pspec,
        check_rep=False,
    )
    def sharded_decode_gqa(q, k, v):
      return decode_attention.gqa(
          q,  # [batch_size, num_q_heads, head_dim]
          k,  # [batch_size, k_seq_len, num_kv_heads, head_dim]
          v,  # [batch_size, k_seq_len, num_kv_heads, head_dim]
          k_splits=self.flash_decoding_k_splits,
      )

    encoded = sharded_decode_gqa(query, key, value)
    return encoded, None  # pytype: disable=bad-return-type  # jax-ndarray


class GpuTritonFusedLayerNorm(normalizations.LayerNorm):

  def _ble_pspec(self):

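As background on what the Pallas decode_attention.gqa call above computes, here is a minimal reference of the flash-decoding math in plain jax.numpy. It is a sketch under simplifying assumptions (queries already scaled, no padding mask over the KV cache, sequence length divisible by k_splits, query heads grouped contiguously per KV head); the function name and shapes are illustrative and this is not the kernel's API.

import jax.numpy as jnp


def flash_decode_gqa_reference(q, k, v, k_splits=16):
  # q: [b, num_q_heads, head_dim], one decode step, already scaled.
  # k, v: [b, kv_seq_len, num_kv_heads, head_dim] from the KV cache.
  b, n_q, h = q.shape
  _, s, n_kv, _ = k.shape
  group = n_q // n_kv          # query heads per KV head (contiguous grouping).
  qg = q.reshape(b, n_kv, group, h)
  chunk = s // k_splits        # assume s is divisible by k_splits.

  maxes, sums, outs = [], [], []
  for i in range(k_splits):    # chunks are independent -> parallel on GPU.
    ks = k[:, i * chunk:(i + 1) * chunk]               # [b, c, n_kv, h]
    vs = v[:, i * chunk:(i + 1) * chunk]
    logits = jnp.einsum('bkgh,bckh->bkgc', qg, ks)     # [b, n_kv, group, c]
    m = jnp.max(logits, axis=-1, keepdims=True)        # per-chunk max.
    p = jnp.exp(logits - m)
    maxes.append(m)
    sums.append(jnp.sum(p, axis=-1, keepdims=True))    # per-chunk sum of exps.
    outs.append(jnp.einsum('bkgc,bckh->bkgh', p, vs))  # unnormalized output.

  # Merge per-chunk softmax statistics with a log-sum-exp rescale.
  m_all = jnp.stack(maxes)                   # [k_splits, b, n_kv, group, 1]
  m_glob = jnp.max(m_all, axis=0)
  scale = jnp.exp(m_all - m_glob)
  denom = jnp.sum(jnp.stack(sums) * scale, axis=0)
  numer = jnp.sum(jnp.stack(outs) * scale, axis=0)
  return (numer / denom).reshape(b, n_q, h)  # [b, num_q_heads, head_dim]

Splitting the KV sequence into k_splits independent chunks is what exposes parallelism beyond batch-times-heads during single-token decoding; the per-chunk (max, sum, unnormalized output) triples are then merged exactly by the rescale above.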