Upload zero_shot_classification.py
zero_shot_classification.py
ADDED
@@ -0,0 +1,339 @@
import os
import clip
import torch
import open_clip

import numpy as np
from torchvision.datasets import CIFAR100
from tqdm import tqdm
import torchvision.transforms as transforms
import warnings
# Suppress noisy FutureWarning/UserWarning messages emitted while importing torchvision.
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=UserWarning)
        import torchvision

import pandas as pd
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import pickle


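# Minimal Dataset wrapper: each sample is an image loaded from
# "<root_dir>/<sample id><file_extension>", returned together with its class label.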
class FACET(Dataset):
    """FACET in-painted images dataset."""

    def __init__(self, paths, labels, root_dir, file_extension=".jpg", transform=None):
        """
        Arguments:
            paths (sequence): Sample ids used to build the image file names.
            labels (sequence): Class label for each sample.
            root_dir (string): Directory with all the images.
            file_extension (string): Suffix appended to each id to form the file name.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.fpaths = paths
        self.extension = file_extension
        self.labels = labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.fpaths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                str(self.fpaths[idx]) + self.extension)
        image = self.transform(Image.open(img_name).convert('RGB'))
        label = self.labels[idx]

        return image, label

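# Prompt templates for prompt ensembling (the 80 templates used in CLIP's ImageNet zero-shot evaluation).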
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

models = (
    # CLIP OpenAI
    "ViT-B/16",
    "ViT-B/32",
    "ViT-L/14",
    "RN50",
    "RN101",

    # CLIP OpenCLIP
    "vit_b_16_400m",
    "vit_b_16_2b",
    "vit_l_14_400m",
    "vit_l_14_2b",
    "vit_b_32_400m",
    "vit_b_32_2b",
)
weights = (
    # CLIP OpenAI
    "OpenAI hub",
    "OpenAI hub",
    "OpenAI hub",
    "OpenAI hub",
    "OpenAI hub",
    # CLIP OpenCLIP
    "OpenCLIP hub",
    "OpenCLIP hub",
    "OpenCLIP hub",
    "OpenCLIP hub",
    "OpenCLIP hub",
    "OpenCLIP hub",
)


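# User-specific paths: the FACET annotations CSV and the root folder holding the
# in-painted image variants for each experiment.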
facet_annotations_file_path = "INSERT_HERE/annotations.csv"
facet_root = "?????"  # where the in-painted images are stored; the following structure is expected:
# facet_root/
#     facet_paper_skin_ours/
#     facet_paper_clothes_only/
#     facet_paper_skin_ours_occupation_prompt/
#     facet_paper_clothes_only_occupation_prompt/
#     facet_paper_whole_body/
#     facet_paper_whole_body_occupation_prompt/

facet = pd.read_csv(facet_annotations_file_path, header=0).rename(columns={'Unnamed: 0': 'sample_idx'})
device = "cuda" if torch.cuda.is_available() else "cpu"


experiments = ["facet_paper_skin_ours", "facet_paper_clothes_only", "facet_paper_skin_ours_occupation_prompt", "facet_paper_clothes_only_occupation_prompt",
               "facet_paper_whole_body", "facet_paper_whole_body_occupation_prompt"
               ]


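# For every in-painting experiment and every CLIP/OpenCLIP backbone, run zero-shot
# occupation classification on FACET and save per-image predictions and accuracy
# for each attribute condition.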
for experiment in experiments:
    for model_name, weight in zip(models, weights):
        print("\n\n", model_name, experiment)
        preprocess = None
        if model_name == "vit_b_16_400m":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
            clip_src = "OpenCLIP"
        elif model_name == "vit_b_16_2b":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion2b_s34b_b88k')
            clip_src = "OpenCLIP"
        elif model_name == "vit_b_32_400m":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion400m_e32')
            clip_src = "OpenCLIP"
        elif model_name == "vit_b_32_2b":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
            clip_src = "OpenCLIP"
        elif model_name == "vit_l_14_400m":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion400m_e32')
            clip_src = "OpenCLIP"
        elif model_name == "vit_l_14_2b":
            model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
            clip_src = "OpenCLIP"
        elif "ViT" in model_name:
            model, preprocess = clip.load(model_name, device)
            clip_src = "OpenAI"
        elif "RN" in model_name:
            model, preprocess = clip.load(model_name, device)
            clip_src = "OpenAI"
        else:
            raise NotImplementedError

        model.cuda()

        occupations = ['backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman', 'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman', 'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer', 'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast', 'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard', 'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer', 'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor', 'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker', 'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter']

        tokens_occupations = clip.tokenize([f"A photo of a {occupation}" for occupation in occupations]).cuda()

        facet_img_root = facet_root + "/" + experiment + "/"
        out_dir = experiment + "_zero_shot"
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)


        fnames = list(os.listdir(facet_img_root))

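        # Evaluate the unmodified images and each in-painted attribute-swap variant in turn.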
        for attribute_value in ["only_original_male", "only_original_female", "original", "male_to_female", "male_to_male", "female_to_female", "female_to_male"]:
            print(f"----{attribute_value}----")
            facet = pd.read_csv(facet_annotations_file_path, header=0).rename(columns={'Unnamed: 0': 'sample_idx'})  # Bounding boxes
            extension = ".png"

            processed_synthetic_samples = set()

            for fname in fnames:
                bbid, target_attr = fname.split("_")[0], "_".join(fname.split("_")[1:]).split(".")[0]

                if "only" in attribute_value:
                    if target_attr == "original" and int(bbid) not in processed_synthetic_samples:
                        processed_synthetic_samples.add(int(bbid))
                elif target_attr == attribute_value and int(bbid) not in processed_synthetic_samples:
                    processed_synthetic_samples.add(int(bbid))

            if attribute_value == "only_original_male":
                facet = facet[facet.person_id.isin(processed_synthetic_samples)]
                facet = facet[facet.gender_presentation_na != 1]
                facet = facet[facet.gender_presentation_non_binary != 1]
                facet = facet[(facet.gender_presentation_masc == 1)]
            elif attribute_value == "only_original_female":
                facet = facet[facet.person_id.isin(processed_synthetic_samples)]
                facet = facet[facet.gender_presentation_na != 1]
                facet = facet[facet.gender_presentation_non_binary != 1]
                facet = facet[(facet.gender_presentation_fem == 1)]
            else:
                facet = facet[facet.person_id.isin(processed_synthetic_samples)]
                facet = facet[facet.gender_presentation_na != 1]
                facet = facet[facet.gender_presentation_non_binary != 1]
                facet = facet[(facet.gender_presentation_masc == 1) | (facet.gender_presentation_fem == 1)]


            facet["class1"] = facet["class1"].apply(lambda val: int(occupations.index(val)))

            bsize = 512
            predictions = []
            acc = my_acc = 0
            n_batches = 0

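            # Prompt ensembling: embed every template for each class name, average the
            # L2-normalised text embeddings, and re-normalise to get one weight vector per class.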
            def zeroshot_classifier(classnames, templates):
                with torch.no_grad():
                    zeroshot_weights = []
                    for classname in tqdm(classnames):
                        texts = [template.format(classname) for template in templates]  # format with class
                        texts = clip.tokenize(texts).cuda()  # tokenize
                        class_embeddings = model.encode_text(texts)  # embed with text encoder
                        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
                        class_embedding = class_embeddings.mean(dim=0)
                        class_embedding /= class_embedding.norm()
                        zeroshot_weights.append(class_embedding)
                    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
                return zeroshot_weights


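            # Build the image dataset for this attribute condition and classify every image
            # against the ensembled class embeddings.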
            if "only" in attribute_value:
                dataset = FACET(facet.person_id.values, torch.tensor(facet.class1.values), facet_img_root, transform=preprocess, file_extension=f"_original.png")
            else:
                dataset = FACET(facet.person_id.values, torch.tensor(facet.class1.values), facet_img_root, transform=preprocess, file_extension=f"_{attribute_value}.png")
            dataloader = DataLoader(dataset, batch_size=bsize, shuffle=False, num_workers=6, drop_last=False,)

            # Note: zero-shot weights are built only for the first 39 occupation classes.
            zeroshot_weights = zeroshot_classifier(occupations[:39], imagenet_templates)


            for imgs, labels in tqdm(dataloader):

                with torch.no_grad(), torch.cuda.amp.autocast():
                    if clip_src == "OpenAI":
                        # CLIP
                        image_features = model.encode_image(imgs.half().cuda())
                        image_features /= image_features.norm(dim=-1, keepdim=True)
                        logits = 100. * image_features @ zeroshot_weights
                        probs = logits.softmax(dim=-1).cpu().numpy()
                    else:
                        # OpenCLIP
                        image_features = model.encode_image(imgs.half().cuda())
                        image_features /= image_features.norm(dim=-1, keepdim=True)
                        probs = (100. * image_features @ zeroshot_weights).softmax(dim=-1).cpu().numpy()

                preds_batch = np.argmax(probs, axis=-1)
                predictions += preds_batch.tolist()
                acc += torch.sum(torch.tensor(preds_batch).cuda() == labels.cuda()) / preds_batch.shape[0]
                n_batches += 1


            print(model_name, "acc: ", acc / n_batches)

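            # Save per-image predictions alongside FACET's age/gender presentation
            # annotations (e.g. for per-group accuracy analysis).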
            results = pd.DataFrame({"person_id": facet.person_id.values,
                                    "inpainted_attribute": attribute_value,
                                    "age_presentation_young": facet.age_presentation_young.values,
                                    "age_presentation_middle": facet.age_presentation_middle.values,
                                    "age_presentation_older": facet.age_presentation_older.values,
                                    "gender_presentation_fem": facet.gender_presentation_fem.values,
                                    "gender_presentation_masc": facet.gender_presentation_masc.values,
                                    "gt_class_label": facet.class1.values,
                                    "class_predictions": predictions
                                    })

            results.to_csv(f'{out_dir}/{model_name.replace("/", "_").replace("-", "_")}_{attribute_value}_predictions.csv')

            with open(f'{out_dir}/{model_name.replace("/", "_").replace("-", "_")}_{attribute_value}_accuracy.txt', "w") as of:
                of.write(str((acc / n_batches).item()))

            with open(f'{out_dir}/{model_name.replace("/", "_").replace("-", "_")}_{attribute_value}.pkl', "wb") as f:
                pickle.dump(predictions, f)