-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
117 lines (90 loc) · 3.89 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
OpenAI CLIP model from https://github.com/openai/CLIP.
| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import logging
import os
import eta.core.web as etaw
from fiftyone.operators import types
from .zoo import TorchCLIPModelConfig, TorchCLIPModel
logger = logging.getLogger(__name__)
MODEL_URL = "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
TOKENIZER_URL = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz"
DEFAULT_CLASSES = "aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,diningtable,dog,horse,motorbike,person,pottedplant,sheep,sofa,train,tvmonitor"
def download_model(model_name, model_path):
    """Downloads the model.

    Args:
        model_name: the name of the model to download, as declared by the
            ``base_name`` and optional ``version`` fields of the manifest
        model_path: the absolute filename or directory to which to download the
            model, as declared by the ``base_filename`` field of the manifest
    """
    if model_name != "voxel51/clip-vit-base32-torch":
        raise ValueError("Unsupported model name '%s'" % model_name)

    # Fetch the model weights and the BPE tokenizer vocab, in that order
    downloads = (
        ("Downloading model...", MODEL_URL, model_path),
        ("Downloading tokenizer...", TOKENIZER_URL, _get_tokenizer_path(model_path)),
    )
    for message, url, path in downloads:
        logger.info(message)
        etaw.download_file(url, path=path)
def load_model(model_name, model_path, text_prompt="A photo of", classes=None):
    """Loads the model.

    Args:
        model_name: the name of the model to load, as declared by the
            ``base_name`` and optional ``version`` fields of the manifest
        model_path: the absolute filename or directory to which the model was
            downloaded, as declared by the ``base_filename`` field of the
            manifest
        text_prompt ("A photo of"): the text prompt to use
        classes (None): the list of classes to use for zero-shot prediction.
            By default, the VOC classes are used

    Returns:
        a :class:`fiftyone.core.models.Model`
    """
    if model_name != "voxel51/clip-vit-base32-torch":
        raise ValueError("Unsupported model name '%s'" % model_name)

    if classes is None:
        # Fall back to the 20 VOC classes declared in DEFAULT_CLASSES
        classes = DEFAULT_CLASSES.split(",")

    config = TorchCLIPModelConfig(
        dict(
            model_path=model_path,
            # Tokenizer vocab sits next to the weights; see download_model()
            tokenizer_path=_get_tokenizer_path(model_path),
            context_length=77,  # CLIP's fixed text context length
            text_prompt=text_prompt,
            classes=classes,
            output_processor_cls="fiftyone.utils.torch.ClassifierOutputProcessor",
            # Standard CLIP ViT-B/32 preprocessing parameters
            image_size=[224, 224],
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
            embeddings_layer="visual",
        )
    )

    return TorchCLIPModel(config)
def resolve_input(model_name, ctx):
    """Defines any necessary properties to collect the model's custom
    parameters from a user during prompting.

    Args:
        model_name: the name of the model, as declared by the ``base_name`` and
            optional ``version`` fields of the manifest
        ctx: an :class:`fiftyone.operators.ExecutionContext`

    Returns:
        a :class:`fiftyone.operators.types.Property`, or None
    """
    if model_name != "voxel51/clip-vit-base32-torch":
        raise ValueError("Unsupported model name '%s'" % model_name)

    # Single optional input: a custom list of zero-shot class names
    schema = types.Object()
    schema.list(
        "classes",
        types.String(),
        label="Zero shot classes",
        description=(
            "An optional list of custom classes for zero-shot prediction"
        ),
        required=False,
        default=None,
        view=types.AutocompleteView(),
    )

    return types.Property(schema)
def _get_tokenizer_path(model_path):
model_dir = os.path.dirname(model_path)
return os.path.join(model_dir, "clip_bpe_simple_vocab_16e6.txt.gz")