diff --git a/tests/test_attention_patterns.py b/tests/test_attention_patterns.py
index 0e77fac618..4169452d1f 100644
--- a/tests/test_attention_patterns.py
+++ b/tests/test_attention_patterns.py
@@ -142,3 +142,21 @@ def test_swin_attention_pattern(H, W, window_size):
     d_padded = d_padded[s, s, s, s].reshape(H * W, H * W)
 
     assert torch.all(d_padded == d_shifted)
+
+
+@pytest.mark.parametrize("k", [2, 3])
+@pytest.mark.parametrize("W", [8, 15])
+@pytest.mark.parametrize("H", [8, 15])
+def test_dilated_2d_pattern(H, W, k):
+    d = AP.dilated_2d_pattern(H, W, k)
+    d = d.reshape(H, W, H, W)
+    for h, w in itertools.product(range(H), range(W)):
+        i = h % k
+        j = w % k
+        # every kth element is taken
+        assert torch.all(d[h, w][i::k, j::k])
+        for ii, jj in itertools.product(range(k), range(k)):
+            if ii == i and jj == j:
+                continue
+            # and the other elements are discarded
+            assert torch.all(~d[h, w][ii::k, jj::k])
diff --git a/xformers/components/attention/attention_patterns.py b/xformers/components/attention/attention_patterns.py
index ebfdafb340..9707e11315 100644
--- a/xformers/components/attention/attention_patterns.py
+++ b/xformers/components/attention/attention_patterns.py
@@ -141,3 +141,15 @@ def swin_attention_pattern(H, W, window_size, shift_size=0):
     anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
     mask = anchor_id[:, None] == anchor_id[None, :]
     return mask
+
+
+def dilated_2d_pattern(H, W, k=2):
+    """
+    Returns a 2d pattern that samples 1 every k elements in the attention mask.
+    Can be seen as a form of downsampling, where every pixel attends to a downsampled
+    version of the input.
+    """
+    d_h = local_nd_distance(H, W, p=1, weights=(1, 0))
+    d_w = local_nd_distance(H, W, p=1, weights=(0, 1))
+    d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0)
+    return d
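
For context, a minimal usage sketch of the new pattern (not part of the patch). It assumes the module layout shown in the diff, i.e. that `dilated_2d_pattern` is importable from `xformers.components.attention.attention_patterns`; the printed mask follows from the behaviour checked in the test above (a query pixel attends to all pixels sharing its `(h % k, w % k)` phase).

```python
# Illustrative sketch, assuming the import path from the diff above.
import torch
from xformers.components.attention import attention_patterns as AP

H, W, k = 4, 4, 2
mask = AP.dilated_2d_pattern(H, W, k)   # boolean mask of shape (H*W, H*W)
mask = mask.reshape(H, W, H, W)

# Query pixel (0, 0) attends to every k-th pixel along both axes,
# i.e. the sub-grid with the same (h % k, w % k) phase.
print(mask[0, 0].int())
# Expected, given the semantics exercised by test_dilated_2d_pattern:
# tensor([[1, 0, 1, 0],
#         [0, 0, 0, 0],
#         [1, 0, 1, 0],
#         [0, 0, 0, 0]])
```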