-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathpredict.py
117 lines (102 loc) · 4.51 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator
def perplexity_score(text, lm_model, lm_tokenizer, device):
    """Return the language-model loss of ``text`` as a Python float.

    The text is prefixed with the tokenizer's BOS token and scored by
    ``lm_model`` using its own token ids as labels. The first model output is
    returned. NOTE(review): for Hugging Face causal LMs this is presumably
    the mean token cross-entropy (negative log-likelihood). Lower means more
    fluent; ``exp`` of it would be the perplexity, which ranks identically.

    Args:
        text: Caption text to score.
        lm_model: Callable model accepting ``(input_ids, labels=...)``.
        lm_tokenizer: Tokenizer exposing ``bos_token`` and returning an
            object with ``input_ids`` when called with ``return_tensors='pt'``.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        The scalar first model output, converted with ``.item()``.
    """
    encodings = lm_tokenizer(lm_tokenizer.bos_token + text, return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()
    # Pure scoring: no gradients are needed, so don't build an autograd graph.
    with torch.no_grad():
        outputs = lm_model(input_ids, labels=target_ids)
    # When labels are passed, the loss is taken to be the first output.
    return outputs[0].item()
class Predictor(cog.Predictor):
    """Cog predictor that captions an image with ``CLIPTextGenerator``.

    For each request it generates candidate captions and then reports three
    picks: the best by CLIP image/text similarity, the best by language-model
    loss, and the generator's first candidate.
    """

    def setup(self):
        """Build the default argument namespace and the CLIPTextGenerator."""
        self.args = get_args()
        # Force context-delta resetting on, overriding the CLI default (False).
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image",
        type=Path,
        help="input image"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=5, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factor",
        type=float,
        default=1.01, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_length",
        type=int,
        default=15, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
        """Caption ``image`` and return a three-line summary string.

        Per-request settings are written onto the shared generator before
        generation. Each reported caption is prefixed with ``cond_text``.

        Returns:
            A string naming the caption with the highest CLIP score
            ("Best CLIP"), the lowest LM loss ("Best fluency"), and the
            generator's first candidate ("Best mixed").
        """
        self.args.cond_text = cond_text
        gen = self.text_generator
        gen.end_factor = end_factor
        gen.target_seq_length = max_seq_length
        gen.ce_scale = ce_loss_scale

        image_features = gen.get_img_feature([str(image)], None)
        captions = gen.run(image_features, self.args.cond_text, beam_size=beam_size)

        # Ranking the candidates is inference only: disable autograd so no
        # graphs are built for the CLIP or LM forward passes.
        with torch.no_grad():
            # CLIP score: dot product of each L2-normalised caption embedding
            # with the image features; the highest wins.
            encoded_captions = [
                gen.clip.encode_text(clip.tokenize(c).to(gen.device))
                for c in captions
            ]
            encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
            similarities = torch.cat(encoded_captions) @ image_features.t()
            best_clip_idx = similarities.squeeze().argmax().item()

            # Fluency score: the lowest language-model loss wins.
            ppl_scores = [
                perplexity_score(c, gen.lm_model, gen.lm_tokenizer, gen.device)
                for c in captions
            ]
            best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        # The generator's own ordering is used as-is for the "mixed" pick.
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]
        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
        return final
def get_args(argv=None):
    """Build the argument namespace used to configure the text generator.

    The namespace's ``vars()`` are passed as keyword arguments to
    ``CLIPTextGenerator``, so every option name here is a constructor keyword.

    Args:
        argv: Optional list of command-line style overrides, e.g.
            ``["--beam_size", "3"]``. Defaults to no arguments, so the
            process's real ``sys.argv`` is never consulted.

    Returns:
        argparse.Namespace with the defaults below, plus any overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    # argparse applies `type` only to string defaults, so float options need
    # float defaults to come back as floats.
    parser.add_argument("--clip_scale", type=float, default=1.0)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20.0, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)
    # Parse an explicit list so the host process's sys.argv is ignored.
    args = parser.parse_args([] if argv is None else list(argv))
    return args