diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py
index 23a074b6..81c844a1 100644
--- a/praxis/layers/multi_query_attention.py
+++ b/praxis/layers/multi_query_attention.py
@@ -32,7 +32,6 @@
 from praxis.layers import embedding_softmax
 from praxis.layers import stochastics
-
 
 WeightInit = base_layer.WeightInit
 WeightHParams = base_layer.WeightHParams
 template_field = base_layer.template_field
@@ -182,9 +181,13 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
       right_window_size) for each token. E.g., if local_window_size == (3, 2)
       and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
       to [3, 4, 5, c, 7, 8].
+    zero_fully_masked: if True, attention values for fully masked tokens will be
+      forced to zero. This is particularly useful for cross attention when
+      keys are all padded.
 
   Note: dconv_qkv and ngrammer are not supported.
   """
+
   input_dim: int | dict[str, int] = 0
   hidden_dim: int = 0
   num_heads: int = 1
@@ -215,6 +218,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
   scale_query_by_dim_per_head: bool = False
   chunked_attn_num_seq_split: int = 1
   local_window_size: tuple[int, int] | None = None
+  zero_fully_masked: bool = False
 
   # SPMD partition related params.
   #
@@ -255,6 +259,12 @@ class ActivationSharding(base_layer.BaseLayer.ActivationSharding):
     bld: SplitDimsMapping = None
     bd: SplitDimsMapping = None
 
+  def _create_rotary_position_emb(
+      self, layer_tpl: LayerTpl, dim_per_head: int
+  ) -> None:
+    pos_emb_p = layer_tpl.clone().set(embedding_dims=dim_per_head)
+    self.create_child('rotary_position_emb', pos_emb_p)
+
   def setup(self) -> None:
     wp = self.weight_split_dims_mapping
     assert self.input_dim, 'input_dim is {}'.format(self.input_dim)
@@ -323,10 +333,9 @@ def project_input_kv(input_dim, dim_per_head):
     )
 
     if self.use_rotary_position_emb:
-      pos_emb_p = self.rotary_position_emb_tpl.clone().set(
-          embedding_dims=dim_per_head
+      self._create_rotary_position_emb(
+          self.rotary_position_emb_tpl, dim_per_head
       )
-      self.create_child('rotary_position_emb', pos_emb_p)
 
     if self.relative_bias_tpl is not None:
       relative_bias_p = self.relative_bias_tpl.clone()
@@ -528,6 +537,14 @@ def _atten_context(
     # Compute the attention context.
     encoded = self.pv_einsum('BNTS,BSH->BNTH', probs, value)
     encoded = encoded.transpose(0, 2, 1, 3)
+
+    if self.zero_fully_masked:
+      fully_masked = jnp.all(
+          atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
+          axis=-1,
+      )[:, 0, :, jnp.newaxis, jnp.newaxis]
+      encoded *= 1 - fully_masked
+
     encoded = checkpoint_name(encoded, 'context')
     encoded = self._shard_blnh(encoded)
     return encoded, probs
@@ -543,9 +560,10 @@ def _atten_context_chunked_attn_seq_split(
     """Computes chunked attention context."""
     b, t, n, _ = query.shape
     _, s, h = value.shape
-    assert (
-        s % self.chunked_attn_num_seq_split == 0
-    ), 'The number of attn splits must divide the sequence length'
+    assert s % self.chunked_attn_num_seq_split == 0, (
+        f'The number of attn splits must divide the sequence length {s}:'
+        f' {self.chunked_attn_num_seq_split}'
+    )
     w = s // self.chunked_attn_num_seq_split
     query = query.transpose(0, 2, 1, 3)
     full_encoded = jnp.zeros((b, n, t, h), dtype=value.dtype)
@@ -602,6 +620,14 @@ def _atten_context_chunked_attn_seq_split(
       full_encoded = full_encoded.at[:, :, i * w : (i + 1) * w, :].set(encoded)
 
     full_encoded = full_encoded.transpose(0, 2, 1, 3)
+
+    if self.zero_fully_masked:
+      fully_masked = jnp.all(
+          atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
+          axis=-1,
+      )[:, 0, :, jnp.newaxis, jnp.newaxis]
+      full_encoded *= 1 - fully_masked
+
     full_encoded = checkpoint_name(full_encoded, 'context')
     full_encoded = self._shard_blnh(full_encoded)
     full_probs = None
@@ -697,6 +723,12 @@ def _dot_atten_one_step(
       encoded, probs = self._dot_atten_one_step_from_qkv(
          query, key, value, atten_mask, relative_bias, time_step
       )
+      if self.zero_fully_masked:
+        fully_masked = jnp.all(
+            atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
+            axis=-1,
+        )[..., jnp.newaxis]
+        encoded *= 1 - fully_masked
       return self._shard_bnh(encoded), probs
     else:
       b, n, h = query.shape
@@ -719,7 +751,13 @@ def _dot_atten_one_step(
       )(v_q, key, value, atten_mask, v_rb, time_step)
      encoded = self._shard_bnh(jnp.reshape(encoded, (b, n, h)))
       probs = jnp.reshape(probs, (b, n, -1))
-      return encoded, probs
+      if self.zero_fully_masked:
+        fully_masked = jnp.all(
+            atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
+            axis=-1,
+        )[..., jnp.newaxis]
+        encoded *= 1 - fully_masked
+      return encoded, probs
 
   def _dot_atten_one_step_from_qkv(
       self,
@@ -731,7 +769,6 @@ def _dot_atten_one_step_from_qkv(
       time_step: JTensor | None = None,
   ) -> tuple[JTensor, JTensor]:
     """_dot_atten_one_step with tensors instead of state names."""
-    del time_step
     # query is 3d.
     extend_one_step = len(query.shape) == 3
     b, s, h = key.shape
@@ -744,6 +781,29 @@ def _dot_atten_one_step_from_qkv(
 
     asserts.in_set(atten_mask.shape[0], [1, b])
     base_layer.assert_has_shape(value, [b, s, h])
+
+    if self.local_window_size is not None:
+      l = self.local_window_size[0] + 1
+      f = self.local_window_size[0] + 1 + self.local_window_size[1]
+
+      minus_inf = py_utils.get_large_negative_number(
+          jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
+      )
+
+      # The padding handles the case where there is not enough data in the
+      # sequence for the local window size; the missing positions are filled
+      # with a padding value.
+      key = attentions._padded_slice(key, time_step + 1 - l, f, 1, 0.0)  # pylint: disable=protected-access
+      value = attentions._padded_slice(value, time_step + 1 - l, f, 1, 0.0)  # pylint: disable=protected-access
+      atten_mask = attentions._padded_slice(  # pylint: disable=protected-access
+          atten_mask, time_step + 1 - l, f, -1, minus_inf
+      )
+
+      b, f, h = key.shape
+      asserts.eq(f, self.local_window_size[0] + 1 + self.local_window_size[1])
+      base_layer.assert_has_shape(value, [b, f, h])
+      asserts.in_set(atten_mask.shape[0], [b, 1])
+
     query = self._scale_query(query)
     if extend_one_step:
       logits = self.qk_einsum('BNH,BSH->BNS', query, key)
@@ -756,6 +816,14 @@ def _dot_atten_one_step_from_qkv(
             'relative bias.'
         )
       base_layer.assert_has_shape(relative_bias, [-1, -1, 1, s])
+      if self.local_window_size is not None:
+        relative_bias = attentions._padded_slice(  # pylint: disable=protected-access
+            relative_bias,
+            time_step - self.local_window_size[0],
+            self.local_window_size[0] + 1 + self.local_window_size[1],
+            -1,
+            0.0,
+        )
       asserts.in_set(relative_bias.shape[0], [1, b])
       relative_bias = jnp.squeeze(relative_bias, axis=2)
       logits += relative_bias
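
Reviewer note (illustrative, not part of the patch): the zero_fully_masked logic treats a query position as fully masked when every entry of its additive attention mask row is below half of the large negative mask value, and then zeroes that position's attention context. The sketch below restates the idea on plain arrays under assumed shapes ([B, 1, T, S] mask, [B, T, N, H] context); get_large_negative_number is a stand-in for py_utils.get_large_negative_number with an assumed constant, and zero_fully_masked_context is a hypothetical helper, not an API of the layer.

# Illustrative sketch only; not part of the diff above.
import jax.numpy as jnp


def get_large_negative_number(dtype):
  # Stand-in for py_utils.get_large_negative_number; the exact constant is an
  # assumption (a large negative fill value used for masked logits).
  return jnp.asarray(-0.7 * float(jnp.finfo(dtype).max), dtype=dtype)


def zero_fully_masked_context(encoded, atten_mask):
  # A query position is "fully masked" when every key position in its mask
  # row is below half of the large negative fill value.
  fully_masked = jnp.all(
      atten_mask < get_large_negative_number(jnp.float32) / 2, axis=-1
  )  # [B, 1, T]
  # Broadcast to [B, T, 1, 1] and zero out the context of those positions.
  fully_masked = fully_masked[:, 0, :, jnp.newaxis, jnp.newaxis]
  return encoded * (1 - fully_masked)

Without this step, a row whose logits are all at the large negative fill value would still go through softmax, spreading probability uniformly over padded keys and producing an arbitrary context vector; forcing the context to zero avoids propagating those values, which is the cross-attention-with-all-padded-keys case called out in the new docstring.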