-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompute_clip.py
135 lines (106 loc) · 4.15 KB
/
compute_clip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import gc
import itertools
import os

import numpy as np
import pandas as pd
import torch
from torch.nn.functional import normalize
from torchmetrics.multimodal.clip_score import CLIPScore
from PIL import Image
import clip

from sketch2prototype import load_prompt
def get_image(filepath):
    """Load an image file and return its pixel data as a NumPy array.

    Args:
        filepath: path to an image readable by PIL.

    Returns:
        np.ndarray of the image contents (layout as produced by PIL,
        typically H x W x C channel-last for RGB images).
    """
    # Use a context manager so the underlying file handle is closed;
    # the original leaked it. np.asarray forces the pixel data to be
    # read before the file is closed.
    with Image.open(filepath) as img:
        return np.asarray(img)
def compute_clip(image, text, metric):
    """Score `image` against `text` with the supplied CLIPScore-style metric.

    Args:
        image: NumPy array of the image (converted to a torch tensor here).
        text: prompt string; falsy text short-circuits to a score of 0.
        metric: callable returning a scalar tensor, e.g. a CLIPScore instance.

    Returns:
        The metric value as a Python float, or 0 for empty/missing text.

    NOTE(review): the array is forwarded in whatever layout get_image
    produced (channel-last from PIL) — confirm the metric expects that.
    """
    if not text:
        return 0
    score = metric(torch.from_numpy(image), text)
    return score.detach().item()
def get_features(image, model, preprocess, device=None):
    """Encode an image file into a unit-norm CLIP embedding.

    Args:
        image: path to an image file (despite the name, this is a filepath,
            not pixel data — it is passed to Image.open).
        model: loaded CLIP model exposing `encode_image`.
        preprocess: the matching CLIP preprocessing transform.
        device: optional torch device for the input tensor. Defaults to the
            device the model's parameters live on. (The original read a
            module-level `device` global defined only under __main__.)

    Returns:
        np.ndarray of shape (1, embed_dim), L2-normalized along the last dim.
    """
    if device is None:
        # Infer the device from the model instead of relying on a global
        # that only exists when the file is run as a script.
        device = next(model.parameters()).device
    pixels = preprocess(Image.open(image)).unsqueeze(0).to(device)
    features = model.encode_image(pixels)
    features = normalize(features, p=2.0, dim=-1)
    return features.cpu().detach().numpy()
def compute_image_clip(image1, image2, model, preprocess, device):
    """Image-image CLIP similarity: cosine of unit embeddings scaled to 0-100.

    Args:
        image1, image2: filepaths of the two images to compare.
        model, preprocess: CLIP model and its preprocessing transform.
        device: accepted for interface compatibility but not forwarded —
            NOTE(review): get_features resolves the device itself, so this
            parameter is effectively unused here.

    Returns:
        Similarity as a float in roughly [-100, 100] (embeddings are
        L2-normalized, so the dot product is a cosine).
    """
    feats_a = get_features(image1, model, preprocess)
    feats_b = get_features(image2, model, preprocess)
    cosine = (feats_b @ feats_a.T).item()
    return cosine * 100
def compute_clip_from_fp(image_fp, text_fp, clip_metric=None):
    """Compute the text-image CLIP score for one image file.

    Args:
        image_fp: image path of the form "<dataset>/<sample>/...png"; the
            sample name is taken from the SECOND path component, so the
            path must be relative with at least two components.
        text_fp: CSV file that `load_prompt` reads the prompt from.
        clip_metric: optional CLIPScore metric. When None, falls back to
            the module-level `metric` global created under __main__
            (the original's implicit behavior, kept for compatibility).

    Returns:
        The CLIP score as a float, or 0 when no string prompt exists.
    """
    if clip_metric is None:
        clip_metric = metric  # module global defined under __main__
    image = get_image(image_fp)
    # NOTE(review): assumes image_fp looks like "dataset/<sample>/x.png";
    # an absolute path would make index 1 the wrong component — confirm.
    image_name = image_fp.split("/")[1]
    text = load_prompt(image_name, text_fp)
    if not isinstance(text, str):
        # load_prompt may return NaN/None for missing prompts.
        return 0
    return compute_clip(image, text, clip_metric)
def compute_clip_from_sample(sample_fp, text_fp, model=None, preprocess=None, num_images=4):
    """Average text-image CLIP score over a sample's generated images.

    Args:
        sample_fp: sample directory containing original.png and images/.
        text_fp: prompt CSV forwarded to compute_clip_from_fp.
        model, preprocess: unused; kept for call-site compatibility.
        num_images: number of generated images per sample (default 4,
            generalizing the original hard-coded count).

    Returns:
        [original_clip, mean_synthetic_clip]
    """
    original_clip = compute_clip_from_fp(f"{sample_fp}/original.png", text_fp)
    synthetic_clip = sum(
        compute_clip_from_fp(f"{sample_fp}/images/image_{i}.png", text_fp)
        for i in range(num_images)
    ) / num_images
    return [original_clip, synthetic_clip]
def pairwise_clip(sample_fp, text_fp, model, preprocess, num_images=4):
    """Mean image-image CLIP similarity over all pairs of generated images.

    Args:
        sample_fp: sample directory containing images/image_<i>.png.
        text_fp: unused; kept for call-site compatibility.
        model, preprocess: CLIP model and transform for compute_image_clip.
        num_images: number of generated images (default 4, generalizing
            the original hard-coded count).

    Returns:
        Single-element list [mean_pairwise_similarity] — list form kept
        so existing callers indexing [0] still work.

    NOTE(review): `device` below is the module-level global defined under
    __main__, matching the original's implicit dependency.
    """
    pairs = list(itertools.combinations(range(num_images), 2))
    total = sum(
        compute_image_clip(
            f"{sample_fp}/images/image_{i}.png",
            f"{sample_fp}/images/image_{j}.png",
            model, preprocess, device,
        )
        for i, j in pairs
    )
    return [total / len(pairs)]
def compute_clip_from_dataset(dataset_fp, model, preprocess, ref_file=""):
    """Compute pairwise and text CLIP scores for every sample in a dataset.

    Args:
        dataset_fp: dataset directory; each non-CSV entry is a sample dir.
        model, preprocess: CLIP model and transform for pairwise scoring.
        ref_file: optional CSV of already-processed samples (its
            'Unnamed: 0' index column); matching entries are skipped.

    Returns:
        dict mapping sample name -> [pairwise, original, synthetic] scores.
        Samples whose original image has no prompt (score 0) are omitted.
    """
    already_done = set(pd.read_csv(ref_file)['Unnamed: 0']) if ref_file else set()
    text_csv = f"{dataset_fp}/sketch_drawing.csv"
    clip_log = {}
    for sample in os.listdir(dataset_fp):
        if 'csv' in sample or sample in already_done:
            continue
        pair_scores = pairwise_clip(f"{dataset_fp}/{sample}", text_csv, model, preprocess)
        text_scores = compute_clip_from_sample(f"{dataset_fp}/{sample}", text_csv)
        if text_scores[0] == 0:
            # No prompt found for this sample's original image — skip it.
            continue
        if pair_scores[0]:
            print (f"{sample}: Pairwise clip: {pair_scores[0]} Original clip: {text_scores[0]} Synthetic clip: {text_scores[1]}")
        clip_log[sample] = [pair_scores[0], text_scores[0], text_scores[1]]
    return clip_log
def check_dirs_have_image(input_dir):
    """Verify every sample directory under input_dir has generated images.

    Entries whose name contains "csv" are skipped. Every offending
    directory name is printed (preserving the original's per-dir report).

    Args:
        input_dir: dataset directory whose subdirectories are samples.

    Raises:
        FileNotFoundError: if any sample dir lacks "dalle_response.json".
            (Subclass of the original bare Exception, so existing
            `except Exception` handlers still catch it.)
    """
    missing = []
    for entry in os.listdir(input_dir):
        if "csv" in entry:
            continue
        if "dalle_response.json" not in os.listdir(f"{input_dir}/{entry}"):
            print (entry)
            missing.append(entry)
    if missing:
        raise FileNotFoundError("A directory does not have images")
if __name__ == "__main__":
    # NOTE: `metric`, `device`, `model`, and `preprocess` are read as module
    # globals by several functions above (compute_clip_from_fp, get_features,
    # pairwise_clip), so these names must not be changed.
    # Text-image scoring metric (HuggingFace CLIP ViT-B/16 via torchmetrics).
    metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
    input_dir = "dataset_full"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    save_dir = "clip_scores"
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
    # OpenAI CLIP ViT-B/32 used for image-image similarity.
    model, preprocess = clip.load("ViT-B/32", device=device)
    # Fail fast if any sample directory is missing its generated images.
    check_dirs_have_image(input_dir)
    clip_log = compute_clip_from_dataset(input_dir, model, preprocess)
    # Persist one row per sample: pairwise, original, and synthetic scores.
    df = pd.DataFrame.from_dict(clip_log, orient="index")
    df.columns = ["Pairwise clip", "Original clip", "Synthetic clip"]
    df.to_csv(f"{save_dir}/results.csv")
    print ("Done")