
Commit dff8206
Implement local version of relative bias positional encoding, including both trainable (DIET-REL) and non-trainable (ALiBi). Note that this change only supports non-autoregressive decoding.

PiperOrigin-RevId: 607519136
puchinc authored and pax authors committed Feb 16, 2024
1 parent 54d6d86 commit dff8206
Showing 3 changed files with 267 additions and 0 deletions.
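
For orientation, here is a minimal configuration sketch of the new layers, mirroring the test setup added in attentions_test.py below; the parameter values are illustrative only.

  # Non-trainable ALiBi variant; swap in attentions.LocalSelfAttentionRelativeBias
  # for the trainable DIET-REL variant.
  p = pax_fiddle.Config(
      attentions.LocalSelfAttentionAlibi,
      name='alibi',
      input_dim=16,
      hidden_dim=32,
      num_heads=4,
      block_size=4,
      left_context=2,
      right_context=1,
  )
  layer = instantiate(p)  # instantiate() as used in the tests below.
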
2 changes: 2 additions & 0 deletions praxis/layers/__init__.py
@@ -38,6 +38,8 @@
from praxis.layers.attentions import DotProductAttention
from praxis.layers.attentions import DotProductAttentionXL
from praxis.layers.attentions import LocalSelfAttention
from praxis.layers.attentions import LocalSelfAttentionAlibi
from praxis.layers.attentions import LocalSelfAttentionRelativeBias
from praxis.layers.attentions import LocalSelfAttentionXL
from praxis.layers.attentions import PerDimScale
from praxis.layers.attentions import RelativeBias
99 changes: 99 additions & 0 deletions praxis/layers/attentions.py
@@ -16,6 +16,7 @@
"""Attention layers."""

import functools
from functools import partial
import math
import string
from typing import Any, Callable, Mapping, Sequence
@@ -3511,3 +3512,101 @@ def extend_step(
raise NotImplementedError(
'extend_step is not implemented for %s' % self.__name__
)


class LocalSelfAttentionRelativeBias(LocalSelfAttention):
"""Local version of trainable relative bias position encoding.
See DIET-REL in https://arxiv.org/abs/2104.08698.
"""

def setup(self) -> None:
"""Constructs a LocalSelfAttentionRelativeBias object with fixed pos_emb."""
super().setup()
# Number of possible relative positional distance indices =
# C + (W - 1) =
# [(L - 1) + W + R] + (W - 1) =
# L + R + 2 * (W - 1).
#
# Conceptually, the relative positions covered are
# [-(L - 1) - (W - 1), ..., 0, ..., (W - 1) + R].
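# Example with illustrative values: W = 4, L = 2, R = 1 gives
# C = 4 + 2 + 1 - 1 = 6 and num_positions = 6 + 4 - 1 = 9.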
w = self.block_size
c = w + self.left_context + self.right_context - 1
num_positions = c + w - 1

pc = WeightHParams(shape=[self.num_heads, num_positions])
self.create_variable('pos_emb_compressed', pc)

def _atten_logits(self, query, key):
b, u, w, n, _ = query.shape[:5]
c = w + self.left_context + self.right_context - 1

# reconstruct the Toeplitz matrix
# -> [N, W, C]
pos_emb_compressed = self.theta.pos_emb_compressed
self.add_summary('pos_emb_compressed', jnp.sum(pos_emb_compressed))
pos_bias = jnp.tile(pos_emb_compressed, [1, w])[:, : w * (w + c - 2)]
pos_bias = jnp.reshape(pos_bias, [n, w, w + c - 2])
pos_bias = pos_bias[..., w - 2 :]
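# pos_bias[h, i, j] now depends only on the offset (j - i), i.e. it is a
# Toeplitz matrix in the relative distance between query and key positions.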

# -> [B, N, U, W, C]
pos_bias = pos_bias[jnp.newaxis, :, jnp.newaxis, :, :]
pos_bias = jnp.broadcast_to(pos_bias, (b, n, u, w, c))

logits = jnp.einsum('buwnh,bucnh->bnuwc', query, key)
logits += pos_bias

return logits


class LocalSelfAttentionAlibi(LocalSelfAttention):
"""Local version of non-trainable relative bias position encoding.
See ALiBi in https://arxiv.org/abs/2108.12409.
"""

def setup(self) -> None:
"""Constructs a LocalSelfAttentionAlibi object with fixed pos_emb."""
super().setup()

def _atten_logits(self, query, key):
b, u, w, n, _ = query.shape[:5]
c = w + self.left_context + self.right_context - 1

# -> [N, W, C]
# reconstruct the Toeplitz matrix
num_pos = c + w - 1
# Absolute position indices (non-trainable); the tile/reshape below converts
# them into a Toeplitz matrix of shifted relative distances.
abs_pos_indices = jnp.arange(num_pos + 1)
# broadcast to each head to represent "shared" indices value
abs_pos_indices = jnp.broadcast_to(abs_pos_indices, [n, w + c])
pos_bias = jnp.tile(abs_pos_indices, [1, w])[:, : w * (w + c - 1)]
pos_bias = jnp.reshape(pos_bias, [n, w, w + c - 1])
pos_bias = pos_bias[..., w - 1 :]
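# pos_bias[h, i, j] == j - i + w - 1: the relative distance between key and
# query, shifted so that the smallest value in the window is 0.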

# Construct ALiBi
# -> [N, W, C]
@partial(jax.jit, static_argnums=0)
def _get_slopes(n_heads):
# ALiBi slopes: a geometric sequence 2^(-8k/n) for the closest power of
# two n <= n_heads (https://arxiv.org/abs/2108.12409).
n = int(2 ** np.floor(np.log2(n_heads)))
m_0 = 2.0 ** (-8.0 / n)
m = m_0 ** jnp.arange(1, 1 + n)

if n < n_heads:
# For non-power-of-two head counts, extend with every other slope of the
# 2n-length sequence so that len(m) == n_heads.
m_hat_0 = 2.0 ** (-4.0 / n)
m_hat = m_hat_0 ** jnp.arange(1, 1 + 2 * (n_heads - n), 2)
m = jnp.concatenate([m, m_hat])

return m

m = _get_slopes(n)
alibi = pos_bias * m[:, None, None]

# -> [B, N, U, W, C]
pos_bias = alibi[None, :, None, :, :]
pos_bias = jnp.broadcast_to(pos_bias, (b, n, u, w, c))

logits = jnp.einsum('buwnh,bucnh->bnuwc', query, key)
logits += pos_bias

return logits
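
As a sanity check on the tile/truncate/reshape trick used by both classes above, the following standalone sketch (not part of the commit; toy values) rebuilds the [W, C] matrix of shifted relative distances and the per-head ALiBi slopes:

  import jax.numpy as jnp

  w, c, n_heads = 2, 3, 4  # toy block size, context size and head count

  # Toeplitz reconstruction: arange over w + c positions, tile w times,
  # truncate, reshape and slice -> entry [i, j] equals j - i + w - 1.
  idx = jnp.arange(w + c)
  dist = jnp.tile(idx, [w])[: w * (w + c - 1)].reshape(w, w + c - 1)[:, w - 1:]
  # dist == [[1, 2, 3],
  #          [0, 1, 2]]

  # ALiBi slopes for a power-of-two head count: geometric sequence 2^(-8k/n).
  slopes = 2.0 ** (-8.0 * jnp.arange(1, n_heads + 1) / n_heads)
  bias = dist[None, :, :] * slopes[:, None, None]  # -> [n_heads, w, c]
  print(bias.shape)  # (4, 2, 3)
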
166 changes: 166 additions & 0 deletions praxis/layers/attentions_test.py
@@ -1219,6 +1219,172 @@ def test_local_attention_fully_masked(self):
]
self.assertNotEqual(np.amin(np.abs(test_utils.to_np(non_masked_out))), 0)

@parameterized.parameters([
(4, 2, 1, True, True),
(4, 2, 1, False, True),
(8, 3, 5, True, False),
(8, 3, 5, False, False),
(5, 4, 0, False, True),
(5, 4, 0, True, True),
])
def test_local_attention_rel_bias(
self,
block_size,
left_context,
right_context,
is_full,
zero_fully_masked,
):
mdl_dim = 16
hidden_dim = 32
num_heads = 4
test_layer_p = pax_fiddle.Config(
attentions.LocalSelfAttentionRelativeBias,
name='rel_bias',
input_dim=mdl_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
block_size=block_size,
left_context=left_context,
right_context=right_context,
zero_fully_masked=zero_fully_masked,
)
layer = instantiate(test_layer_p)

target_batch_size = 3
source_max_length = 16

query_vec = np.random.normal(
size=[target_batch_size, source_max_length, mdl_dim]
).astype(np.float32)
key_vec = np.random.normal(
size=[target_batch_size, source_max_length, mdl_dim]
).astype(np.float32)
value_vec = np.random.normal(
size=[target_batch_size, source_max_length, mdl_dim]
).astype(np.float32)

paddings = range(source_max_length)[-target_batch_size:]
paddings = [[0] * l + [1] * (source_max_length - l) for l in paddings]
paddings = np.array(paddings)
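# Batch item b has (13 + b) valid frames; the remaining frames are padding (1).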
atten_mask = attentions.convert_paddings_to_mask(paddings, np.float32)
if is_full:
atten_mask = jnp.tile(atten_mask, [1, 1, source_max_length, 1])

with base_layer.JaxContext.new_context():
prng_key = jax.random.PRNGKey(seed=123)
prng_key, init_key = jax.random.split(prng_key)
initial_vars = layer.init(
init_key, query_vec, key_vec, value_vec, atten_mask
)
jax_fprop_out, jax_atten_prob = layer.apply(
initial_vars, query_vec, key_vec, value_vec, atten_mask
)

self.assertEqual(
jax_fprop_out.shape, (target_batch_size, source_max_length, mdl_dim)
)

# -> [B, U, C, ...]
key_block_context = attentions.extract_block_context(
key_vec,
block_size=block_size,
left_context=left_context,
right_context=right_context,
)
_, u, c, _ = key_block_context.shape

# -> [B, U, W, ...]
query_blocks = attentions.convert_to_block(query_vec, block_size=block_size)
_, _, w, _ = query_blocks.shape

self.assertEqual(
jax_atten_prob.shape, (target_batch_size, num_heads, u, w, c)
)

@parameterized.parameters([
(4, 2, 1, True, True),
(4, 2, 1, False, True),
(8, 3, 5, True, False),
(8, 3, 5, False, False),
(5, 4, 0, False, True),
(5, 4, 0, True, True),
])
def test_local_attention_alibi(
self,
block_size,
left_context,
right_context,
is_full,
zero_fully_masked,
):
mdl_dim = 16
hidden_dim = 32
num_heads = 4
test_layer_p = pax_fiddle.Config(
attentions.LocalSelfAttentionAlibi,
name='alibi',
input_dim=mdl_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
block_size=block_size,
left_context=left_context,
right_context=right_context,
zero_fully_masked=zero_fully_masked,
)
layer = instantiate(test_layer_p)

target_batch_size = 3
source_max_length = 16

query_vec = np.random.normal(
size=[target_batch_size, source_max_length, mdl_dim]
).astype(np.float32)
key_vec = np.random.normal(
size=[target_batch_size, source_max_length, mdl_dim]
).astype(np.float32)
value_vec = np.random.normal(
size=[target_batch_size, source_max_length, mdl_dim]
).astype(np.float32)

paddings = range(source_max_length)[-target_batch_size:]
paddings = [[0] * l + [1] * (source_max_length - l) for l in paddings]
paddings = np.array(paddings)
atten_mask = attentions.convert_paddings_to_mask(paddings, np.float32)
if is_full:
atten_mask = jnp.tile(atten_mask, [1, 1, source_max_length, 1])

with base_layer.JaxContext.new_context():
prng_key = jax.random.PRNGKey(seed=123)
prng_key, init_key = jax.random.split(prng_key)
initial_vars = layer.init(
init_key, query_vec, key_vec, value_vec, atten_mask
)
jax_fprop_out, jax_atten_prob = layer.apply(
initial_vars, query_vec, key_vec, value_vec, atten_mask
)

self.assertEqual(
jax_fprop_out.shape, (target_batch_size, source_max_length, mdl_dim)
)

# -> [B, U, C, ...]
key_block_context = attentions.extract_block_context(
key_vec,
block_size=block_size,
left_context=left_context,
right_context=right_context,
)
_, u, c, _ = key_block_context.shape

# -> [B, U, W, ...]
query_blocks = attentions.convert_to_block(query_vec, block_size=block_size)
_, _, w, _ = query_blocks.shape

self.assertEqual(
jax_atten_prob.shape, (target_batch_size, num_heads, u, w, c)
)

@parameterized.parameters(
([1, 2, 3, 4, 5], 1, 0, [0, 1, 2, 3, 4]),
([1, 2, 3, 4, 5], -1, 0, [2, 3, 4, 5, 0]),
