efficient_attention.py
import torch
from torch import nn
from torch.nn import functional as f


class EfficientAttention(nn.Module):
    """Efficient attention: attention with linear memory and compute cost
    in the number of spatial positions, instead of the quadratic cost of
    standard dot-product attention."""

    def __init__(self, in_channels, key_channels, head_count, value_channels):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        # 1x1 convolutions project the input feature map to keys, queries, and values.
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        # Projects the concatenated head outputs back to the input channel count.
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        n, _, h, w = input_.size()
        # Flatten the spatial dimensions so each pixel becomes one position.
        keys = self.keys(input_).reshape((n, self.key_channels, h * w))
        queries = self.queries(input_).reshape((n, self.key_channels, h * w))
        values = self.values(input_).reshape((n, self.value_channels, h * w))
        head_key_channels = self.key_channels // self.head_count
        head_value_channels = self.value_channels // self.head_count

        attended_values = []
        for i in range(self.head_count):
            # Keys are normalized over positions, queries over channels; this
            # lets keys and values be aggregated before the queries are applied.
            key = f.softmax(keys[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=2)
            query = f.softmax(queries[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=1)
            value = values[
                :,
                i * head_value_channels: (i + 1) * head_value_channels,
                :
            ]
            # Global context matrix of shape (key_channels, value_channels),
            # independent of the number of positions.
            context = key @ value.transpose(1, 2)
            attended_value = (
                context.transpose(1, 2) @ query
            ).reshape(n, head_value_channels, h, w)
            attended_values.append(attended_value)

        # Concatenate the heads, reproject, and add the residual connection.
        aggregated_values = torch.cat(attended_values, dim=1)
        reprojected_value = self.reprojection(aggregated_values)
        attention = reprojected_value + input_
        return attention
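

# A minimal usage sketch, not part of the original file: instantiate the module
# and run it on a dummy feature map. The channel counts, head count, and tensor
# shape below are illustrative assumptions only.
if __name__ == '__main__':
    attention = EfficientAttention(
        in_channels=64, key_channels=32, head_count=4, value_channels=64)
    x = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
    out = attention(x)
    assert out.shape == x.shape      # the residual connection preserves the input shape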