Skip to content

Commit

Permalink
[nit] Making k/v optional to easily handle self-attention (facebookre…
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux authored Jul 6, 2021
1 parent fb9f160 commit 5396c9f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions xformers/components/multi_head_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,20 @@ def _check(self, t, name):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
att_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expected input dimensions are [batch size, sequence length, embed dim]
Output dimensions are [batch size, sequence length, embed dim]
"""

if key is None:
key = query
if value is None:
value = query

# Check the dimensions properly
self._check(query, "query")
self._check(value, "value")
Expand Down

0 comments on commit 5396c9f

Please sign in to comment.