-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathr_precision.py
30 lines (21 loc) · 1.14 KB
/
r_precision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import argparse
import sys
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--text``, ``--workspace``,
        ``--latest``, ``--mode`` and ``--clip``, all string-valued with
        defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--text', default="", type=str, help="text prompt")
    # The workspace is the directory under ../results that holds the
    # validation images (see image_path); the old help text was a copy of
    # the --text help.
    parser.add_argument('--workspace', default="trial", type=str,
                        help="workspace directory name under ../results")
    parser.add_argument('--latest', default='ep0001', type=str,
                        help="which epoch result you want to use for image path")
    parser.add_argument('--mode', default='rgb', type=str,
                        help="mode of result, color(rgb) or textureless()")
    parser.add_argument('--clip', default="clip-ViT-B-32", type=str,
                        help="CLIP model to encode the img and prompt")
    return parser


def image_path(opt):
    """Return the path of the validation image selected by ``opt``.

    Args:
        opt: parsed arguments providing ``workspace``, ``latest`` and ``mode``.

    Returns:
        str: path relative to the current working directory. The view index
        ``0005`` is fixed.
    """
    return f'../results/{opt.workspace}/validation/df_{opt.latest}_0005_{opt.mode}.png'


def main(argv=None):
    """Score one rendered image against the text prompt and print the result.

    Args:
        argv: optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: propagated from ``Image.open`` when the selected
            validation image does not exist.
    """
    opt = build_parser().parse_args(argv)

    # Load CLIP model
    model = SentenceTransformer(opt.clip)

    # Encode the image; the context manager closes the underlying file
    # as soon as the embedding has been computed.
    with Image.open(image_path(opt)) as img:
        img_emb = model.encode(img)

    # Encode text descriptions
    text_emb = model.encode([opt.text])

    # Compute cosine similarities; with one image and one prompt the only
    # score of interest is the element at [0][0].
    cos_scores = util.cos_sim(img_emb, text_emb)
    print("The final CLIP R-Precision is:", cos_scores[0][0].cpu().numpy())


if __name__ == '__main__':
    main()