Add dilated 2d pattern (facebookresearch#203)
* Adding checkerboard pattern

* Adding documentation

* black reformatting

* Revert "black reformatting"

This reverts commit 032cf331495f63e1439c8334625220afae2d29ea.

* Revert changes to notebook

* Refactor implementation to handle some corner cases

Also re-uses already existing methods to simplify things

* Add tests and bugfix

* Bugfix

Co-authored-by: Marta Gazulla <martatintore@devfair0121.h2.fair>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
3 people authored Aug 18, 2021
1 parent b8fe236 commit c339af1
Showing 2 changed files with 30 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/test_attention_patterns.py
@@ -142,3 +142,21 @@ def test_swin_attention_pattern(H, W, window_size):
    d_padded = d_padded[s, s, s, s].reshape(H * W, H * W)

    assert torch.all(d_padded == d_shifted)


@pytest.mark.parametrize("k", [2, 3])
@pytest.mark.parametrize("W", [8, 15])
@pytest.mark.parametrize("H", [8, 15])
def test_dilated_2d_pattern(H, W, k):
    d = AP.dilated_2d_pattern(H, W, k)
    d = d.reshape(H, W, H, W)
    for h, w in itertools.product(range(H), range(W)):
        i = h % k
        j = w % k
        # every kth element is taken
        assert torch.all(d[h, w][i::k, j::k])
        for ii, jj in itertools.product(range(k), range(k)):
            if ii == i and jj == j:
                continue
            # and the other elements are discarded
            assert torch.all(~d[h, w][ii::k, jj::k])
12 changes: 12 additions & 0 deletions xformers/components/attention/attention_patterns.py
@@ -141,3 +141,15 @@ def swin_attention_pattern(H, W, window_size, shift_size=0):
    anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
    mask = anchor_id[:, None] == anchor_id[None, :]
    return mask


def dilated_2d_pattern(H, W, k=2):
    """
    Returns a 2d pattern that samples one in every k elements of the attention mask.
    Can be seen as a form of downsampling, where every pixel attends to a downsampled
    version of the input.
    """
    # pairwise distances between pixels along the height and width axes, respectively
    d_h = local_nd_distance(H, W, p=1, weights=(1, 0))
    d_w = local_nd_distance(H, W, p=1, weights=(0, 1))
    # keep a query/key pair only when both its row and column distances are multiples of k
    d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0)
    return d
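
For reference, a minimal usage sketch of the new helper. This is an illustration only, assuming xformers at this commit is installed and the module is importable under the path of the file changed above (as in the test file's AP alias):

import torch
import xformers.components.attention.attention_patterns as AP

H, W, k = 4, 4, 2
mask = AP.dilated_2d_pattern(H, W, k)  # boolean mask of shape (H * W, H * W)

# The query at pixel (0, 0) attends exactly to the pixels whose row and
# column indices are both multiples of k.
attended = mask.reshape(H, W, H, W)[0, 0]
expected = torch.zeros(H, W, dtype=torch.bool)
expected[::k, ::k] = True
assert torch.equal(attended, expected)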
