
Commit 23aaa58

[fix] FourierMix + AMP (#258)
1 parent 90785c2 commit 23aaa58

File tree

3 files changed: +10 -1 lines changed


CHANGELOG.md

+1
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [0.0.x] - TBD
 ### Fixed
 - Fix some torchscriptability [#246]
+- Fix FourierMix being compatible with AMP [#258]

 ## [0.0.10] - 2022-03-14
 ### Fixed

tests/test_attentions.py

+5
@@ -124,6 +124,11 @@ def test_order_invariance(
     att_2 = multi_head(inputs, inputs_shuffled, inputs)
     assert (att != att_2).any()

+    # Test AMP, if available
+    if device.type == "cuda":
+        with torch.cuda.amp.autocast(enabled=True):
+            _ = multi_head(inputs, inputs_shuffled, inputs)
+

 @pytest.mark.parametrize("heads", [1, 4])
 @pytest.mark.parametrize("attention_name", ["scaled_dot_product"])
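Note on the test addition: the AMP pass is gated on `device.type == "cuda"` because `torch.cuda.amp.autocast` only changes op dtypes on a GPU. A minimal standalone sketch of the same pattern; `model` and `x` here are stand-ins for the test's `multi_head` and `inputs`, not part of this commit:

import torch

# Sketch of the pattern used in the test above: only exercise autocast
# when a CUDA device is present. `model` and `x` are hypothetical
# stand-ins for the real multi_head module and its inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(32, 32).to(device)
x = torch.randn(8, 32, device=device)

if device.type == "cuda":
    # Under autocast, Linear runs in fp16, so this forward pass
    # exercises the half-precision code path end to end.
    with torch.cuda.amp.autocast(enabled=True):
        _ = model(x)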

xformers/components/attention/fourier_mix.py

+4 -1

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+from torch.cuda.amp import autocast

 from xformers.components.attention import Attention, AttentionConfig, register_attention

@@ -22,7 +23,9 @@ def __init__(self, dropout: float, *_, **__):
         self.requires_input_projection = False

     def forward(self, q: torch.Tensor, *_, **__):
-        att = torch.fft.fft2(q).real
+        # Guard against autocast / fp16, not supported by torch.fft.fft2
+        with autocast(enabled=False):
+            att = torch.fft.fft2(q).real

         att = self.attn_drop(att)
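Note on the fix: per the added comment, `torch.fft.fft2` does not support autocast / fp16 (half-precision FFTs on CUDA were limited, e.g. to power-of-two sizes in cuFFT), so the FFT is computed with autocast disabled. A hedged standalone sketch of the failure mode; the explicit `.float()` cast is illustrative here and not part of the commit, which relies on `q` reaching `forward` in fp32:

import torch

# Sketch (assumption: a CUDA device) of the situation this commit
# guards against. The .float() cast below is an illustration, not
# something the commit itself adds.
if torch.cuda.is_available():
    q = torch.randn(2, 24, 32, device="cuda")  # 24 is not a power of two

    with torch.cuda.amp.autocast(enabled=True):
        # A linear op under autocast yields an fp16 tensor...
        h = torch.nn.functional.linear(q, torch.randn(32, 32, device="cuda"))
        assert h.dtype == torch.float16

        # ...which half-precision FFT backends may reject for sizes like
        # these, so the FFT runs with autocast disabled, as in the diff:
        with torch.cuda.amp.autocast(enabled=False):
            att = torch.fft.fft2(h.float()).real

        assert att.dtype == torch.float32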
