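"""Zero-shot action recognition evaluation.

Loads a trained checkpoint, encodes the 60 action class labels with the
model's CLIP text encoder, and classifies each motion embedding by cosine
similarity against those class embeddings, reporting top-1 and top-5 accuracy.
"""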
import sys
sys.path.append('.')
import os
import torch
from src.utils.get_model_and_data import get_model_and_data
from src.parser.visualize import parser
from src.utils.misc import load_model_wo_clip
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.utils.tensors import collate
import clip
from src.visualize.visualize import get_gpu_device
from src.utils.action_label_to_idx import action_label_to_idx

if __name__ == '__main__':
    parameters, folder, checkpointname, epoch = parser(checkpoint=True)
    gpu_device = get_gpu_device()
    parameters["device"] = f"cuda:{gpu_device}"
    data_split = 'vald'  # Hardcoded
    # evaluate against the action category names themselves, restricted
    # to the 60-class subset
    parameters['use_action_cat_as_text_labels'] = True
    parameters['only_60_classes'] = True
    TOP_K_METRIC = 5

    model, datasets = get_model_and_data(parameters, split=data_split)
    dataset = datasets["train"]  # note: the loaded split is keyed "train" regardless of data_split

    print("Restore weights..")
    checkpointpath = os.path.join(folder, checkpointname)
    state_dict = torch.load(checkpointpath, map_location=parameters["device"])
    load_model_wo_clip(model, state_dict)
    model.eval()

    iterator = DataLoader(dataset, batch_size=parameters["batch_size"],
                          shuffle=False, num_workers=8, collate_fn=collate)
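
    # Build the zero-shot classifier targets: CLIP text embeddings of the
    # 60 action class names, with rows ordered by class index so that row i
    # corresponds to label index i in action_label_to_idx.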
    action_text_labels = list(action_label_to_idx.keys())
    action_text_labels.sort(key=lambda x: action_label_to_idx[x])
    with torch.no_grad():  # no gradients needed for evaluation
        texts = clip.tokenize(action_text_labels[:60]).to(model.device)
        classes_text_emb = model.clip_model.encode_text(texts).float()

    correct_preds_top_5, correct_preds_top_1 = 0, 0
    total_samples = 0
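
    # For each batch: encode the motions, compare them to the class text
    # embeddings, and count how often a ground-truth category lands in the
    # top-1 / top-5 most similar classes.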
    with torch.no_grad():
        for batch in tqdm(iterator, desc="Computing batch"):
            # skip batches that the collate function could not stack into tensors
            if isinstance(batch['x'], list):
                continue
            for key in batch.keys():
                if torch.is_tensor(batch[key]):
                    batch[key] = batch[key].to(parameters['device'])
            batch = model(batch)
            texts = clip.tokenize(batch['clip_text']).to(model.device)
            batch['clip_text_embed'] = model.clip_model.encode_text(texts).float()
            labels = [[action_label_to_idx[cat] for cat in cats]
                      for cats in batch['all_categories']]
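
            # CLIP-style zero-shot scoring: cosine similarity between the motion
            # embedding z and each class text embedding, scaled by CLIP's logit
            # scale (100) and softmaxed over the classes.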
            classes_text_emb_norm = classes_text_emb / classes_text_emb.norm(dim=-1, keepdim=True)
            motion_features_norm = batch['z'] / batch['z'].norm(dim=-1, keepdim=True)
            similarity = (100.0 * motion_features_norm @ classes_text_emb_norm.t()).softmax(dim=-1)
            total_samples += similarity.shape[0]
            for i in range(similarity.shape[0]):
                _, indices = similarity[i].topk(TOP_K_METRIC)
                # TOP-5 CHECK: correct if any ground-truth category is in the top k
                if any(gt_cat_idx in indices for gt_cat_idx in labels[i]):
                    correct_preds_top_5 += 1
                # TOP-1 CHECK: correct if a ground-truth category is the single best match
                if any(gt_cat_idx in indices[:1] for gt_cat_idx in labels[i]):
                    correct_preds_top_1 += 1
# print(f"Current Top-5 Acc. : {100 * correct_preds_top_5 / total_samples:.2f}%")
print(f"Top-5 Acc. : {100 * correct_preds_top_5 / total_samples:.2f}% ({correct_preds_top_5}/{total_samples})")
print(f"Top-1 Acc. : {100 * correct_preds_top_1 / total_samples:.2f}% ({correct_preds_top_1}/{total_samples})")