[fix][minor] Q/K/V needs to be divisible by heads, not more (#269)
* Fixing #264, thanks @dnnspark
* changelog addendum
* moving the dimension check to post projection
blefaudeux authored Apr 14, 2022
1 parent b7ca410 commit 7242042
Showing 2 changed files with 12 additions and 10 deletions.
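The gist of the fix, with illustrative sizes not taken from the commit: reshaping an embedding across attention heads only requires its last dimension to be divisible by num_heads, whereas the old assert tested divisibility by the per-head dimension dim_k, which is usually much larger than num_heads and therefore far more restrictive in practice. A hypothetical example:

```python
# Illustrative sizes only, not values from the commit.
dim_model, num_heads = 96, 4
dim_k = dim_model // num_heads  # per-head dimension: 24

emb = 28  # some embedding dimension to split across heads
# Splitting across heads only needs divisibility by num_heads...
print(emb % num_heads == 0)  # True  -> reshapes cleanly to (..., 4, 7)
# ...but the old assert required divisibility by dim_k instead:
print(emb % dim_k == 0)      # False -> rejected, although the reshape is fine
```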
CHANGELOG.md: 2 changes (2 additions & 0 deletions)
@@ -8,9 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - Fix some torchscriptability [#246]
 - Fix FourierMix being compatible with AMP [#258]
+- Better asserts on QKV dimensions [#264]
 
 ### Added
 - Simplicial Embeddings [#259]
+- Mem efficient attention, FW pass [#267]
 
 ## [0.0.10] - 2022-03-14
 ### Fixed
xformers/components/multi_head_dispatch.py: 20 changes (10 additions & 10 deletions)
@@ -116,11 +116,6 @@ def __init__(
         if isinstance(self.proj, nn.Linear) and self.proj.bias is not None:
             constant_(self.proj.bias, 0.0)
 
-    def _check(self, t, name):
-        assert (
-            t.shape[2] % self.dim_k == 0
-        ), f"the {name} embeddings need to be divisible by the number of heads"
-
     def forward(
         self,
         query: torch.Tensor,
@@ -139,11 +134,6 @@ def forward(
         if value is None:
             value = query
 
-        # Check the dimensions properly
-        self._check(query, "query")
-        self._check(value, "value")
-        self._check(key, "key")
-
         if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]:
             max_batch = max((query.shape[0], key.shape[0], value.shape[0]))
             query, key, value = map(
@@ -176,6 +166,16 @@ def forward(
         else:
             k, q, v = key, query, value
 
+        # Check the dimensions properly
+        def check(t, name):
+            assert (
+                t.shape[2] % self.num_heads == 0
+            ), f"the {name} embeddings need to be divisible by the number of heads"
+
+        check(q, "projected query")
+        check(v, "projected value")
+        check(k, "projected key")
+
         # Optional: rotary embedding, add relative positioning information
         if self.rotary_embeddings:
             # rotary requires the head dimension
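For context, a minimal self-contained sketch of the pattern the change lands on; the class and names below are hypothetical, not the xformers MultiHeadDispatch API. The Q/K/V projections run first, and only the projected embedding dimension has to divide evenly across heads, so the raw input dimension is unconstrained:

```python
import torch
import torch.nn as nn


class TinyMultiHead(nn.Module):
    """Toy multi-head block; hypothetical, for illustration only."""

    def __init__(self, dim_in: int, dim_model: int, num_heads: int):
        super().__init__()
        assert dim_model % num_heads == 0, "projected dim must split across heads"
        self.num_heads = num_heads
        self.q_proj = nn.Linear(dim_in, dim_model)
        self.k_proj = nn.Linear(dim_in, dim_model)
        self.v_proj = nn.Linear(dim_in, dim_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project first; the raw input dimension never has to match the heads.
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Post-projection check, mirroring the moved assert in the fix.
        for name, t in (("query", q), ("key", k), ("value", v)):
            assert (
                t.shape[-1] % self.num_heads == 0
            ), f"the projected {name} embeddings need to be divisible by the number of heads"

        B, S, E = q.shape
        head_dim = E // self.num_heads
        # Split into heads, attend, then merge the heads back together.
        q, k, v = (
            t.view(B, S, self.num_heads, head_dim).transpose(1, 2) for t in (q, k, v)
        )
        attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
        return (attn @ v).transpose(1, 2).reshape(B, S, E)


# dim_in = 15 is not divisible by num_heads = 4, and that is fine:
# only the projected dimension (dim_model = 32) has to be.
block = TinyMultiHead(dim_in=15, dim_model=32, num_heads=4)
out = block(torch.randn(2, 7, 15))
print(out.shape)  # torch.Size([2, 7, 32])
```

With the pre-fix behavior the assert ran on the raw inputs against dim_k, so an input dimension like the 15 above would have been rejected even though every projected tensor splits cleanly into heads.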
