-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmodels.py
executable file
·69 lines (52 loc) · 1.84 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import copy
from dataclasses import dataclass
from typing import Any, Dict
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import ToTensor
from mlproject.decorators import configurable
@dataclass
class ModelAndTransform:
    """Pair a network with the input transform that was built alongside it."""

    # The network itself.
    model: nn.Module
    # Transform object/callable paired with ``model``; left untyped (``Any``).
    transform: Any
@configurable
def build_model(
    model_name: str = "google/vit-base-patch16-224-in21k",
    pretrained: bool = True,
    num_classes: int = 100,
):
    """Build a ViT image classifier together with its batch transform.

    Args:
        model_name: Identifier passed to the ``from_pretrained`` loaders of
            the feature extractor and the model/config.
        pretrained: If True, load the checkpoint weights. If False, only the
            checkpoint's configuration is loaded and the model is created
            with freshly initialised weights.
        num_classes: Passed as ``num_labels`` to size the classification head.

    Returns:
        ModelAndTransform: the model, plus a callable that takes a dict with
        ``"image"`` (an iterable of images) and ``"labels"`` and returns a
        dict with ``"pixel_values"`` and ``"labels"``.
    """
    import torch
    from transformers import (
        ViTConfig,
        ViTFeatureExtractor,
        ViTForImageClassification,
    )

    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

    model: nn.Module
    if pretrained:
        model = ViTForImageClassification.from_pretrained(
            model_name, num_labels=num_classes
        )
    else:
        # Build from the configuration alone instead of loading the checkpoint
        # and re-initialising afterwards: this avoids fetching weights that
        # would be thrown away.
        # NOTE(review): on some transformers releases ``init_weights()`` after
        # ``from_pretrained()`` appears to skip already-loaded modules, which
        # would leave pretrained weights in place -- confirm for the pinned
        # version. Constructing from config sidesteps the question.
        config = ViTConfig.from_pretrained(model_name, num_labels=num_classes)
        model = ViTForImageClassification(config)

    class Convert1ChannelTo3Channel(nn.Module):
        """Convert an image to a tensor and expand one channel to three."""

        def forward(self, x):
            # If ``x`` exposes ``pixel_values`` (mapping-like feature output),
            # operate on that entry and hand back a shallow copy of the
            # container with the entry replaced.
            container = None
            if hasattr(x, "pixel_values"):
                container = copy.copy(x)
                x = x["pixel_values"]

            # ``ToTensor`` is only applied to non-tensor inputs; values that
            # are already tensors are used as they are.
            if not torch.is_tensor(x):
                x = ToTensor()(x)

            if len(x.shape) == 3 and x.shape[0] == 1:
                # Single image whose leading (channel) dimension is 1.
                x = x.repeat([3, 1, 1])
            elif len(x.shape) == 4 and x.shape[1] == 1:
                # Batch whose second (channel) dimension is 1.
                x = x.repeat([1, 3, 1, 1])

            if container is None:
                return x
            container["pixel_values"] = x
            return container

    pre_transform = Convert1ChannelTo3Channel()

    def transform_wrapper(input_dict: Dict):
        """Turn ``{"image", "labels"}`` into ``{"pixel_values", "labels"}``."""
        # Convert every image (not just the first) and collect the results in
        # a new list so the caller's ``input_dict`` is left unmodified and the
        # wrapper can safely be called again on the same dict.
        images = [pre_transform(image) for image in input_dict["image"]]
        features = feature_extractor(images=images, return_tensors="pt")
        return {
            "pixel_values": features["pixel_values"],
            "labels": input_dict["labels"],
        }

    return ModelAndTransform(model=model, transform=transform_wrapper)