Commit 91c0996

...enable the fully_zero_masks and fix the decoding issue for local window size enabled...

PiperOrigin-RevId: 706247244
The praxis Authors committed Dec 14, 2024
1 parent edc9a6f commit 91c0996
Showing 1 changed file with 77 additions and 9 deletions.
86 changes: 77 additions & 9 deletions praxis/layers/multi_query_attention.py
@@ -32,7 +32,6 @@
from praxis.layers import embedding_softmax
from praxis.layers import stochastics


WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
template_field = base_layer.template_field
@@ -182,9 +181,13 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
right_window_size) for each token. E.g., if local_window_size == (3, 2)
and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
to [3, 4, 5, c, 7, 8].
zero_fully_masked: if True, attention values for fully masked tokens will be
forced to zero. This is particularly useful for cross attention when
keys are all padded.
Note: dconv_qkv and ngrammer are not supported.
"""

input_dim: int | dict[str, int] = 0
hidden_dim: int = 0
num_heads: int = 1
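For reference, a minimal illustration (not part of this change) of the local_window_size semantics described in the docstring above: with (left, right) == (3, 2), position i may attend to positions i - 3 through i + 2, clipped to the sequence bounds.

left, right = 3, 2
seq_len, i = 10, 6
window = [j for j in range(seq_len) if i - left <= j <= i + right]
print(window)  # [3, 4, 5, 6, 7, 8], matching the docstring example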
@@ -215,6 +218,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
scale_query_by_dim_per_head: bool = False
chunked_attn_num_seq_split: int = 1
local_window_size: tuple[int, int] | None = None
zero_fully_masked: bool = False

# SPMD partition related params.
#
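The following is a small sketch (not from the commit) of what the new zero_fully_masked attribute does, assuming an additive attention mask whose masked positions hold a large negative value; jnp.finfo(jnp.float32).min / 2 stands in here for py_utils.get_large_negative_number.

import jax.numpy as jnp

big_neg = jnp.finfo(jnp.float32).min / 2  # stand-in for the praxis helper

# atten_mask: [batch, 1, target_len, source_len]; query position 1 sees no valid keys.
atten_mask = jnp.zeros((1, 1, 2, 3)).at[0, 0, 1, :].set(big_neg)
encoded = jnp.ones((1, 2, 4, 8))  # [batch, target_len, num_heads, dim_per_head]

# A query is "fully masked" when every source position is masked out.
fully_masked = jnp.all(atten_mask < big_neg / 2, axis=-1)[:, 0, :, jnp.newaxis, jnp.newaxis]
encoded = encoded * (1 - fully_masked)  # rows that attend to nothing become zero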
@@ -255,6 +259,12 @@ class ActivationSharding(base_layer.BaseLayer.ActivationSharding):
bld: SplitDimsMapping = None
bd: SplitDimsMapping = None

def _create_rotary_position_emb(
self, layer_tpl: LayerTpl, dim_per_head: int
) -> None:
pos_emb_p = layer_tpl.clone().set(embedding_dims=dim_per_head)
self.create_child('rotary_position_emb', pos_emb_p)

def setup(self) -> None:
wp = self.weight_split_dims_mapping
assert self.input_dim, 'input_dim is {}'.format(self.input_dim)
@@ -323,10 +333,9 @@ def project_input_kv(input_dim, dim_per_head):
)

if self.use_rotary_position_emb:
- pos_emb_p = self.rotary_position_emb_tpl.clone().set(
-     embedding_dims=dim_per_head
self._create_rotary_position_emb(
self.rotary_position_emb_tpl, dim_per_head
)
- self.create_child('rotary_position_emb', pos_emb_p)

if self.relative_bias_tpl is not None:
relative_bias_p = self.relative_bias_tpl.clone()
@@ -528,6 +537,14 @@ def _atten_context(
# Compute the attention context.
encoded = self.pv_einsum('BNTS,BSH->BNTH', probs, value)
encoded = encoded.transpose(0, 2, 1, 3)

if self.zero_fully_masked:
fully_masked = jnp.all(
atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
axis=-1,
)[:, 0, :, jnp.newaxis, jnp.newaxis]
encoded *= 1 - fully_masked

encoded = checkpoint_name(encoded, 'context')
encoded = self._shard_blnh(encoded)
return encoded, probs
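As a shape sketch (illustrative only) of the context computation above: in multi-query attention the probabilities keep a per-head dimension while key/value use a single shared head, so [B, N, T, S] combined with [B, S, H] gives [B, N, T, H], followed by a transpose to [B, T, N, H].

import jax.numpy as jnp

B, N, T, S, H = 2, 4, 5, 7, 16
probs = jnp.ones((B, N, T, S)) / S       # attention probabilities, one per query head
value = jnp.ones((B, S, H))              # single shared value head
encoded = jnp.einsum('BNTS,BSH->BNTH', probs, value)
encoded = encoded.transpose(0, 2, 1, 3)  # -> [B, T, N, H]
assert encoded.shape == (B, T, N, H)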
@@ -543,9 +560,10 @@ def _atten_context_chunked_attn_seq_split(
"""Computes chunked attention context."""
b, t, n, _ = query.shape
_, s, h = value.shape
- assert (
-     s % self.chunked_attn_num_seq_split == 0
- ), 'The number of attn splits must divide the sequence length'
assert s % self.chunked_attn_num_seq_split == 0, (
f'The number of attn splits must divide the sequence length {s}:'
f' {self.chunked_attn_num_seq_split}'
)
w = s // self.chunked_attn_num_seq_split
query = query.transpose(0, 2, 1, 3)
full_encoded = jnp.zeros((b, n, t, h), dtype=value.dtype)
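A tiny sketch (not from the commit) of the chunking constraint asserted above: the source length must be divisible by chunked_attn_num_seq_split, and chunk i covers positions [i * w, (i + 1) * w) with w = s // num_splits.

s, num_splits = 12, 3
assert s % num_splits == 0, (
    f'The number of attn splits must divide the sequence length {s}: {num_splits}'
)
w = s // num_splits
chunks = [(i * w, (i + 1) * w) for i in range(num_splits)]
print(chunks)  # [(0, 4), (4, 8), (8, 12)]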
@@ -602,6 +620,14 @@ def _atten_context_chunked_attn_seq_split(
full_encoded = full_encoded.at[:, :, i * w : (i + 1) * w, :].set(encoded)

full_encoded = full_encoded.transpose(0, 2, 1, 3)

if self.zero_fully_masked:
fully_masked = jnp.all(
atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
axis=-1,
)[:, 0, :, jnp.newaxis, jnp.newaxis]
full_encoded *= 1 - fully_masked

full_encoded = checkpoint_name(full_encoded, 'context')
full_encoded = self._shard_blnh(full_encoded)
full_probs = None
@@ -697,6 +723,12 @@ def _dot_atten_one_step(
encoded, probs = self._dot_atten_one_step_from_qkv(
query, key, value, atten_mask, relative_bias, time_step
)
if self.zero_fully_masked:
fully_masked = jnp.all(
atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
axis=-1,
)[..., jnp.newaxis]
encoded *= 1 - fully_masked
return self._shard_bnh(encoded), probs
else:
b, n, h = query.shape
@@ -719,7 +751,13 @@
)(v_q, key, value, atten_mask, v_rb, time_step)
encoded = self._shard_bnh(jnp.reshape(encoded, (b, n, h)))
probs = jnp.reshape(probs, (b, n, -1))
- return encoded, probs
if self.zero_fully_masked:
fully_masked = jnp.all(
atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
axis=-1,
)[..., jnp.newaxis]
encoded *= 1 - fully_masked
return encoded, probs

def _dot_atten_one_step_from_qkv(
self,
@@ -731,7 +769,6 @@ def _dot_atten_one_step_from_qkv(
time_step: JTensor | None = None,
) -> tuple[JTensor, JTensor]:
"""_dot_atten_one_step with tensors instead of state names."""
- del time_step
# query is 3d.
extend_one_step = len(query.shape) == 3
b, s, h = key.shape
@@ -744,6 +781,29 @@
asserts.in_set(atten_mask.shape[0], [1, b])

base_layer.assert_has_shape(value, [b, s, h])

if self.local_window_size is not None:
l = self.local_window_size[0] + 1
f = self.local_window_size[0] + 1 + self.local_window_size[1]

minus_inf = py_utils.get_large_negative_number(
jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
)

# The padding handles the case where there is not enough data in the
# sequence for the local window size; we add a padding value for the
# missing positions.
key = attentions._padded_slice(key, time_step + 1 - l, f, 1, 0.0) # pylint: disable=protected-access
value = attentions._padded_slice(value, time_step + 1 - l, f, 1, 0.0) # pylint: disable=protected-access
atten_mask = attentions._padded_slice( # pylint: disable=protected-access
atten_mask, time_step + 1 - l, f, -1, minus_inf
)

b, f, h = key.shape
asserts.eq(f, self.local_window_size[0] + 1 + self.local_window_size[1])
base_layer.assert_has_shape(value, [b, f, h])
asserts.in_set(atten_mask.shape[0], [b, 1])

query = self._scale_query(query)
if extend_one_step:
logits = self.qk_einsum('BNH,BSH->BNS', query, key)
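To make the decoding fix easier to follow, here is a hedged re-implementation of the windowed slicing used above (my own stand-in; the real attentions._padded_slice in praxis may differ in signature and behavior): take size positions along axis starting at start, which can fall before the beginning of the sequence at early decode steps, filling out-of-range positions with pad_value.

import jax.numpy as jnp

def padded_slice(x, start, size, axis, pad_value):
  """Slices `size` elements along `axis` from `start`, padding out-of-range slots."""
  length = x.shape[axis]
  idx = start + jnp.arange(size)                   # window indices, may be out of range
  valid = (idx >= 0) & (idx < length)
  gathered = jnp.take(x, jnp.clip(idx, 0, length - 1), axis=axis)
  mask_shape = [1] * x.ndim
  mask_shape[axis] = size
  return jnp.where(valid.reshape(mask_shape), gathered, pad_value)

# Example: local_window_size == (3, 2) at decode step 1; part of the window falls
# before position 0 and is filled with the pad value, as in the comment above.
key = jnp.arange(1, 7, dtype=jnp.float32).reshape(1, 6, 1)  # [b, s, h]
left, right = 3, 2
l, f = left + 1, left + 1 + right
time_step = 1
window = padded_slice(key, time_step + 1 - l, f, 1, 0.0)
print(window[0, :, 0])  # [0. 0. 1. 2. 3. 4.]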
@@ -756,6 +816,14 @@
'relative bias.'
)
base_layer.assert_has_shape(relative_bias, [-1, -1, 1, s])
if self.local_window_size is not None:
relative_bias = attentions._padded_slice( # pylint: disable=protected-access
relative_bias,
time_step - self.local_window_size[0],
self.local_window_size[0] + 1 + self.local_window_size[1],
-1,
0.0,
)
asserts.in_set(relative_bias.shape[0], [1, b])
relative_bias = jnp.squeeze(relative_bias, axis=2)
logits += relative_bias