-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathimage.py
66 lines (54 loc) · 2.13 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from dataclasses import dataclass
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers.models.vit.modeling_vit import ViTModel
from utils import BaseModule
class DINOSingleImageTokenizer(BaseModule):
    """Tokenize images into per-token ViT features using a DINO ViT backbone.

    The ViT architecture is read from the ``config.json`` of
    ``cfg.pretrained_model_name_or_path`` on the Hugging Face Hub. Only the
    config is downloaded: the model is constructed from it with freshly
    initialised weights. NOTE(review): presumably the real weights are loaded
    later from a full-model checkpoint; confirm with the loading code.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Hub repo id whose config.json defines the ViT architecture.
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        # Enable activation checkpointing in the ViT encoder (saves memory
        # during training at the cost of recomputation).
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the ViT backbone and register the normalisation buffers."""
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(
                    repo_id=self.cfg.pretrained_model_name_or_path,
                    filename="config.json",
                )
            )
        )
        if self.cfg.enable_gradient_checkpointing:
            # Use the public API rather than flipping
            # ``encoder.gradient_checkpointing`` by hand: newer transformers
            # releases also require the checkpoint function that this call
            # installs, otherwise a training-mode forward pass fails.
            self.model.gradient_checkpointing_enable()

        # Standard ImageNet RGB mean/std, shaped to broadcast over
        # (B, N, C, H, W). persistent=False keeps them out of the state dict;
        # they are recreated here on every construction.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Encode images into channels-first ViT token features.

        Args:
            images: ``(B, C, H, W)`` (one view per sample, "packed") or
                ``(B, N, C, H, W)`` (N views per sample). C must match the
                3-channel normalisation buffers. NOTE(review): presumably RGB
                in [0, 1] given the ImageNet statistics; confirm with callers.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``(B, Ct, Nt)`` for 4-D input or ``(B, N, Ct, Nt)`` for 5-D input,
            where Ct is the ViT hidden size and Nt the number of tokens in
            the ViT's ``last_hidden_state``.

        Raises:
            ValueError: if ``images`` is neither 4-D nor 5-D.
        """
        if images.ndim not in (4, 5):
            raise ValueError(
                "expected images of shape (B, C, H, W) or (B, N, C, H, W), "
                f"got {tuple(images.shape)}"
            )
        packed = images.ndim == 4
        if packed:
            # Add a singleton view axis so both layouts share one code path.
            images = images.unsqueeze(1)
        batch_size = images.shape[0]

        images = (images - self.image_mean) / self.image_std
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            # Interpolate position embeddings so inputs need not match the
            # resolution the config was defined for.
            interpolate_pos_encoding=True,
        )
        # last_hidden_state is ((B N), Nt, Ct); move to channels-first.
        # The pooler output is not used by this tokenizer.
        local_features = out.last_hidden_state.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)
        return local_features

    def detokenize(self, *args, **kwargs):
        """Not supported: this tokenizer only encodes images."""
        raise NotImplementedError