forked from OPPO-Mente-Lab/Subject-Diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodules.py
79 lines (69 loc) · 2.87 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Two-layer perceptron with a leading LayerNorm and optional skip path.

    The input is layer-normalised, projected to ``hidden_dim``, passed
    through GELU and projected to ``out_dim``.  When ``use_residual`` is
    enabled, the original (un-normalised) input is added to the result.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_dim: Width of the hidden layer.
        use_residual: Add the input back onto the output.  Requires
            ``in_dim == out_dim``.

    Raises:
        ValueError: If ``use_residual`` is true and ``in_dim != out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int,
                 use_residual: bool = True):
        super().__init__()
        # Raise explicitly rather than ``assert``: assertions are removed
        # under ``python -O``, and a mismatched skip connection could then
        # broadcast silently (e.g. out_dim == 1) instead of failing.
        if use_residual and in_dim != out_dim:
            raise ValueError(
                "use_residual=True requires in_dim == out_dim, "
                f"got in_dim={in_dim} and out_dim={out_dim}"
            )
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply LayerNorm -> Linear -> GELU -> Linear (+ input if residual).

        Args:
            x: Tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_dim``.
        """
        hidden = self.act_fn(self.fc1(self.layernorm(x)))
        out = self.fc2(hidden)
        if self.use_residual:
            # The skip path carries the raw input, not the normalised one.
            out = out + x
        return out
class FourierEmbedder:
    """Sinusoidal feature encoder with geometrically spaced frequencies.

    The ``num_freqs`` frequencies are ``temperature ** (k / num_freqs)`` for
    ``k = 0 .. num_freqs - 1``.  This is a plain callable, not an
    ``nn.Module``: it holds no learnable state.
    """

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Encode ``x`` with sin/cos at every frequency band.

        Args:
            x: Tensor of arbitrary shape.
            cat_dim: Dimension along which the encoded pieces are joined.

        Returns:
            The concatenation, along ``cat_dim``, of
            ``sin(f0 * x), cos(f0 * x), sin(f1 * x), cos(f1 * x), ...``;
            that dimension grows by a factor of ``2 * num_freqs``.
            Computed under ``torch.no_grad()``.
        """
        # Outer loop over frequencies, inner over (sin, cos), so the pieces
        # are interleaved per frequency.
        pieces = [
            trig(band * x)
            for band in self.freq_bands
            for trig in (torch.sin, torch.cos)
        ]
        return torch.cat(pieces, cat_dim)
class GroundingNet(nn.Module):
    """Fuse masked image-token features with Fourier-encoded box coordinates.

    Each token's image feature is concatenated with a positional feature
    derived from its object's bounding box and projected by an MLP.
    Positions where the token mask is false use learned "null" features
    instead.

    Args:
        input_dim: Feature size of the incoming image embeddings.
        output_dim: Feature size produced by the projection MLP.
        hidden_dim: Hidden width of the projection MLP.
        fourier_freqs: Number of frequency bands for the box encoder.
        num_token: Token count of the learned null features.
        use_bbox: When false, the token mask used for the box embedding is
            replaced by one derived from summing the mask over dim 1.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: int,
                 fourier_freqs=8,
                 num_token=256,
                 use_bbox=True
                 ):
        super().__init__()
        self.use_bbox = use_bbox
        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        # 4 box coordinates, each encoded as (sin, cos) per frequency band.
        self.position_dim = fourier_freqs * 2 * 4
        self.linears_image = MLP(
            in_dim=input_dim + self.position_dim,
            out_dim=output_dim,
            hidden_dim=hidden_dim,
            use_residual=False,
        )
        # Learned stand-ins used wherever the token mask is false.
        null_image = torch.zeros([1, 1, num_token, input_dim])
        self.null_image_feature = torch.nn.Parameter(null_image)
        null_position = torch.zeros([1, 1, num_token, self.position_dim])
        self.null_position_feature = torch.nn.Parameter(null_position)

    def forward(self, image_embeddings, image_token_idx_mask, bboxes):
        """Project masked image tokens joined with their box embeddings.

        Args:
            image_embeddings: 4-D tensor; the first and last dimensions are
                treated as batch and feature size.  Presumably
                (batch, objects, tokens, dim) -- TODO confirm with caller.
            image_token_idx_mask: Mask that supports ``~`` and broadcasts
                against ``image_embeddings``; expected to be boolean --
                TODO confirm.
            bboxes: Box coordinates; per the original note, B*N*4.

        Returns:
            Output of the projection MLP over tokens flattened to
            ``(batch, -1, feature)``.
        """
        # The 4-way unpack also enforces that the input is rank 4.
        batch, _, _, feat_dim = image_embeddings.size()

        # Swap masked-out image tokens for the learned null image feature.
        masked_out = ~image_token_idx_mask
        image_embeddings = image_embeddings * image_token_idx_mask + \
            masked_out * self.null_image_feature

        # B*N*4 --> B*N*C, with a singleton axis inserted before C so the
        # result broadcasts against the mask.
        box_embedding = self.fourier_embedder(bboxes).unsqueeze(-2)

        position_mask = image_token_idx_mask
        if not self.use_bbox:
            # NOTE(review): keeps positions whose mask sum over dim 1
            # exceeds 1 (not 0) -- confirm this threshold is intended.
            position_mask = image_token_idx_mask.sum(1, keepdim=True) > 1
        box_embedding = box_embedding * position_mask + \
            (~position_mask) * self.null_position_feature

        flat_boxes = box_embedding.reshape(batch, -1, self.position_dim)
        flat_images = image_embeddings.reshape(batch, -1, feat_dim)
        fused = torch.cat([flat_images, flat_boxes], dim=-1)
        return self.linears_image(fused)