From d384d6a50e235d3dd89b54e96274a46bbd9a5ce2 Mon Sep 17 00:00:00 2001
From: Qiao Zhang
Date: Fri, 16 Feb 2024 12:22:28 -0800
Subject: [PATCH] Add flash decoding for GQA attention on GPU.

PiperOrigin-RevId: 607766174
---
 praxis/layers/BUILD                 |  1 +
 praxis/layers/gpu_fast_attention.py | 73 ++++++++++++++++++++++++++++-
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/praxis/layers/BUILD b/praxis/layers/BUILD
index 5347627f..90124f8c 100644
--- a/praxis/layers/BUILD
+++ b/praxis/layers/BUILD
@@ -111,6 +111,7 @@ pytype_strict_library(
     visibility = JAX_VISIBILITY,
     deps = [
         ":attentions",
+        ":grouped_query_attention",
         ":normalizations",
         # Implicit jax dependency.
         # Implicit Pallas GPU dependency.
diff --git a/praxis/layers/gpu_fast_attention.py b/praxis/layers/gpu_fast_attention.py
index ee9740c7..6668fa65 100644
--- a/praxis/layers/gpu_fast_attention.py
+++ b/praxis/layers/gpu_fast_attention.py
@@ -20,8 +20,8 @@
 import functools
 import logging
-import math
 import os
+from typing import Tuple
 
 import jax
 from jax.experimental.shard_map import shard_map
 
@@ -30,6 +30,7 @@
 from praxis import py_utils
 from praxis import pytypes
 from praxis.layers import attentions
+from praxis.layers import grouped_query_attention
 from praxis.layers import normalizations
 
 # pylint: disable=g-import-not-at-top
@@ -242,6 +243,76 @@ def sharded_decode_mha(q, k, v):
     return encoded, None  # pytype: disable=bad-return-type  # jax-ndarray
 
 
+class GpuTritonFusedGroupedQueryAttention(
+    grouped_query_attention.GroupedQueryAttention
+):
+  """Using flash decoding for GroupedQueryAttention."""
+
+  # Note that flash decoding may speed up MQA and GQA only.
+  # XLA MHA may still run faster than Pallas flash decoding.
+  # Tune k_splits and measure before toggling on flash decoding.
+  # https://crfm.stanford.edu/2023/10/12/flashdecoding.html
+  use_flash_decoding: bool = False
+  flash_decoding_k_splits: int = 16
+
+  def _get_mesh(self) -> jax.sharding.Mesh:
+    device_mesh = py_utils.create_device_mesh(
+        self.ici_mesh_shape,
+        self.dcn_mesh_shape,
+        contiguous_submeshes=self.contiguous_submeshes,
+    )
+    mesh = jax.sharding.Mesh(device_mesh, self.mesh_axis_names)
+    return mesh
+
+  def _atten_context(
+      self,
+      query: JTensor,
+      key: JTensor,
+      value: JTensor,
+      atten_mask: JTensor,
+  ) -> Tuple[JTensor, JTensor]:
+    """Computes atten context."""
+    b, t, n, h = query.shape
+    is_decoding = t == 1
+    if not is_decoding or not self.use_flash_decoding:
+      return super()._atten_context(query, key, value, atten_mask)
+
+    if self.atten_dropout_prob > 0.0 and not self.do_eval:
+      raise NotImplementedError
+    if self.atten_logit_cap > 0.0:
+      raise NotImplementedError
+
+    query = query * (self.dim_per_head**-0.5) / self.atten_temp
+    query = query.reshape([b, n, h])
+
+    # Assume causal self-attention mask. Not supporting cross_attention.
+    sh = self.activation_split_dims_mapping
+    bnh_pspec = jax.sharding.PartitionSpec(sh.btnh[0], sh.btnh[2], sh.btnh[3])
+    blnh_pspec = jax.sharding.PartitionSpec(sh.bskh)
+
+    @functools.partial(
+        shard_map,
+        mesh=self._get_mesh(),
+        in_specs=(
+            bnh_pspec,
+            blnh_pspec,
+            blnh_pspec,
+        ),
+        out_specs=bnh_pspec,
+        check_rep=False,
+    )
+    def sharded_decode_gqa(q, k, v):
+      return decode_attention.gqa(
+          q,  # [batch_size, num_q_heads, head_dim]
+          k,  # [batch_size, k_seq_len, num_kv_heads, head_dim]
+          v,  # [batch_size, k_seq_len, num_kv_heads, head_dim]
+          k_splits=self.flash_decoding_k_splits,
+      )
+
+    encoded = sharded_decode_gqa(query, key, value)
+    return encoded, None  # pytype: disable=bad-return-type  # jax-ndarray
+
+
 class GpuTritonFusedLayerNorm(normalizations.LayerNorm):
 
   def _ble_pspec(self):
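
Reviewer note (not part of the patch): a minimal configuration sketch of how the new
layer might be enabled. Only `use_flash_decoding` and `flash_decoding_k_splits` come
from the patch above; the `pax_fiddle.Config` wiring and the idea of installing it as
a model's attention template are illustrative assumptions. As the code shows, flash
decoding only takes effect on single-token decode steps (t == 1) and is off by default,
so benchmark against the XLA path before turning it on.

    # Hypothetical wiring: use the flash-decoding GQA variant as the
    # attention template and tune k_splits for the target KV length.
    from praxis import pax_fiddle
    from praxis.layers import gpu_fast_attention

    atten_tpl = pax_fiddle.Config(
        gpu_fast_attention.GpuTritonFusedGroupedQueryAttention,
        use_flash_decoding=True,     # off by default; measure before enabling
        flash_decoding_k_splits=16,  # number of key/value splits per decode step
    )
    # ...then install `atten_tpl` as the attention template of the model
    # config under test (model-specific; not shown here).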