From 84d8fcdb996bfb9cb6813da4e52d07ec5760137b Mon Sep 17 00:00:00 2001
From: dianaml0 <82468439+dianaml0@users.noreply.github.com>
Date: Mon, 3 May 2021 17:06:37 -0400
Subject: [PATCH] [feat] Add causal masking option to Nystrom. (#85)

* add causal masking option

* minor, caching the causal masks and moving them to device

Co-authored-by: Benjamin Lefaudeux
---
 .isort.cfg                               |  2 +-
 xformers/components/attention/nystrom.py | 40 +++++++++++++++++++++---
 2 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/.isort.cfg b/.isort.cfg
index 0b4d779478..2161414798 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,2 @@
 [settings]
-known_third_party =matplotlib,pandas,pytest,seaborn,setuptools,sklearn,torch,tqdm
+known_third_party =matplotlib,numpy,pandas,pytest,seaborn,setuptools,sklearn,torch,tqdm
diff --git a/xformers/components/attention/nystrom.py b/xformers/components/attention/nystrom.py
index a05a9ca957..9201674ae8 100644
--- a/xformers/components/attention/nystrom.py
+++ b/xformers/components/attention/nystrom.py
@@ -18,6 +18,7 @@ class NystromSelfAttentionConfig(AttentionConfig):
     num_heads               Number of heads.
     num_landmarks           Number of landmarks to use for softmax approximation. 64 often sufficient for a good
                             approximation according to https://arxiv.org/pdf/2102.03902.pdf.
+    causal                  Apply a causal mask, in that the attention cannot be applied to the future.
     use_razavi_pinverse     If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose
                             inverse, otherwise use standard torch inverse.
     pinverse_original_init  True if using original initialization when calculating Moore-Penrose pseudo inverse using
@@ -35,6 +36,7 @@ class NystromSelfAttentionConfig(AttentionConfig):
 
     num_heads: int
     num_landmarks: Optional[int]
+    causal: Optional[bool]
     pinverse_original_init: Optional[bool]
     inv_iterations: Optional[int]
     v_skip_connection: Optional[nn.Module]
@@ -50,6 +52,7 @@ def __init__(
         dropout: float,
         num_heads: int,
         num_landmarks: int = 64,
+        causal: bool = False,
         use_razavi_pinverse: bool = True,
         pinverse_original_init: bool = False,
         inv_iterations: int = 6,  # recommended default in paper was 6.
@@ -77,6 +80,7 @@ def __init__(
         self.inv_iterations = inv_iterations
         self.attn_drop = nn.Dropout(dropout)
         self.skip_connection = v_skip_connection
+        self.causal = causal
 
         if self.skip_connection is None and conv_kernel_size is not None:
             self.skip_connection = nn.Conv2d(
@@ -88,16 +92,21 @@ def __init__(
                 groups=self.num_heads,
             )
 
+        # Optional lower triangular masks for causal attention
+        self.causal_mask_1: Optional[torch.Tensor] = None
+        self.causal_mask_2: Optional[torch.Tensor] = None
+        self.causal_mask_3: Optional[torch.Tensor] = None
+
     def forward(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ):
 
+        batched_dim = k.size(0)
         head_dim = k.size(-1)
         seq_len = k.size(-2)
 
@@ -106,7 +115,10 @@ def forward(
         ), "the sequence length needs to be divisible by the number of landmarks"
 
         if self.num_landmarks == seq_len:
-            x = scaled_dot_product_attention(q, k, v, att_mask)
+            mask = None
+            if self.causal:
+                mask = self._tril_mask(batched_dim, seq_len, seq_len)
+            x = scaled_dot_product_attention(q, k, v, mask)
 
         else:
             q_landmarks = q.reshape(
@@ -122,9 +134,24 @@
                 head_dim,
             ).mean(dim=-2)
 
-            kernel_1 = scaled_query_key_softmax(q, k_landmarks, None)
-            kernel_2 = scaled_query_key_softmax(q_landmarks, k_landmarks, None)
-            kernel_3 = scaled_dot_product_attention(q_landmarks, k, v, None)
+            if self.causal and self.causal_mask_1 is None:
+                self.causal_mask_1 = self._tril_mask(
+                    batched_dim, seq_len, self.num_landmarks
+                ).to(q.device)
+                self.causal_mask_2 = self._tril_mask(
+                    batched_dim, self.num_landmarks, self.num_landmarks
+                ).to(q.device)
+                self.causal_mask_3 = self._tril_mask(
+                    batched_dim, self.num_landmarks, seq_len
+                ).to(q.device)
+
+            kernel_1 = scaled_query_key_softmax(q, k_landmarks, self.causal_mask_1)
+            kernel_2 = scaled_query_key_softmax(
+                q_landmarks, k_landmarks, self.causal_mask_2
+            )
+            kernel_3 = scaled_dot_product_attention(
+                q_landmarks, k, v, self.causal_mask_3
+            )
 
             kernel_2_inv = (
                 iterative_pinv(
@@ -151,6 +178,9 @@ def forward(
         x = self.attn_drop(x)
         return x
 
+    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int):
+        return torch.tril(torch.ones(dim_1, dim_2, dim_3, dtype=torch.bool), diagonal=0)
+
     @classmethod
     def from_config(cls, config: AttentionConfig) -> "Attention":
         return cls(**NystromSelfAttentionConfig.as_patchy_dict(config))
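Usage sketch (illustrative, not part of the patch): exercising the new causal option end to end. The import path follows the file touched above, but the class name NystromAttention and the specific tensor sizes are assumptions, since the class definition line is outside the hunks shown here.

import torch

# Assumed import: the attention class defined in xformers/components/attention/nystrom.py
from xformers.components.attention.nystrom import NystromAttention

# Inputs are laid out as (batch * num_heads, seq_len, head_dim); forward()
# asserts that seq_len is divisible by num_landmarks.
batch, heads, seq_len, head_dim = 2, 4, 128, 16
q = k = v = torch.randn(batch * heads, seq_len, head_dim)

attn = NystromAttention(
    dropout=0.0,
    num_heads=heads,
    num_landmarks=64,
    causal=True,  # new flag introduced by this patch
)

# att_mask is no longer a forward() argument; causal masking is driven by the flag.
# On the first call the three lower-triangular masks are built by _tril_mask(),
# moved to q's device, and cached on the module as causal_mask_1/2/3.
out = attn(q, k, v)  # -> (batch * num_heads, seq_len, head_dim)

Because of the `is None` check in the diff, the cached masks are reused on every subsequent forward pass instead of being re-allocated.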