import sys
sys.path.append('.')  # make the src package importable when run from the repo root

import os
import torch
import clip
from tqdm import tqdm
from torch.utils.data import DataLoader

from src.utils.get_model_and_data import get_model_and_data
from src.parser.visualize import parser
from src.utils.misc import load_model_wo_clip
from src.utils.tensors import collate
from src.visualize.visualize import get_gpu_device
from src.utils.action_label_to_idx import action_label_to_idx

if __name__ == '__main__':
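    # Parse command-line arguments and locate the checkpoint to evaluate.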
    parameters, folder, checkpointname, epoch = parser(checkpoint=True)
    gpu_device = get_gpu_device()
    parameters["device"] = f"cuda:{gpu_device}"
    data_split = 'vald'
    parameters['use_action_cat_as_text_labels'] = True
    parameters['only_60_classes'] = True

    TOP_K_METRIC = 5
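
    # NOTE: the datasets dict comes back keyed "train" even though
    # split=data_split was requested above, so this entry is the 'vald' data.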
    model, datasets = get_model_and_data(parameters, split=data_split)
    dataset = datasets["train"]
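
    # Restore trained weights; load_model_wo_clip loads everything except the
    # CLIP sub-module, as the name suggests.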
    print("Restore weights..")
    checkpointpath = os.path.join(folder, checkpointname)
    state_dict = torch.load(checkpointpath, map_location=parameters["device"])
    load_model_wo_clip(model, state_dict)
    model.eval()
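
    # Deterministic (unshuffled) batched pass over the evaluation set.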
    iterator = DataLoader(dataset, batch_size=parameters["batch_size"],
                          shuffle=False, num_workers=8, collate_fn=collate)
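
    # Order the action names by class index so row k of the text-embedding
    # matrix corresponds to class index k.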
    action_text_labels = list(action_label_to_idx.keys())
    action_text_labels.sort(key=lambda x: action_label_to_idx[x])
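
    # Embed the 60 class names with CLIP once up front; reused for every batch.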
    texts = clip.tokenize(action_text_labels[:60]).to(model.device)
    classes_text_emb = model.clip_model.encode_text(texts).float()
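
    # Running hit counters for top-1 / top-5 accuracy.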
    correct_preds_top_5, correct_preds_top_1 = 0, 0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(iterator, desc="Computing batch"):
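            # Skip batches whose 'x' is still a Python list rather than a stacked tensor.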
            if isinstance(batch['x'], list):
                continue
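
            # Move every tensor in the batch to the evaluation device.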
            for key in batch.keys():
                if torch.is_tensor(batch[key]):
                    batch[key] = batch[key].to(parameters['device'])
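
            # Forward pass; the model fills in batch['z'], used below as the motion embedding.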
            batch = model(batch)
            texts = clip.tokenize(batch['clip_text']).to(model.device)
            batch['clip_text_embed'] = model.clip_model.encode_text(texts).float()
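
            # Each sample may carry several ground-truth categories; gather all their indices.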
            labels = [[action_label_to_idx[cat] for cat in cats] for cats in batch['all_categories']]
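
            # CLIP-style zero-shot scoring: cosine similarity between normalized motion
            # and class-text embeddings, scaled by 100 and softmaxed over classes.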
            classes_text_emb_norm = classes_text_emb / classes_text_emb.norm(dim=-1, keepdim=True)
            motion_features_norm = batch['z'] / batch['z'].norm(dim=-1, keepdim=True)
            similarity = (100.0 * motion_features_norm @ classes_text_emb_norm.t()).softmax(dim=-1)

            total_samples += similarity.shape[0]
            for i in range(similarity.shape[0]):
                _, indices = similarity[i].topk(TOP_K_METRIC)
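
                # Top-5: a hit if any ground-truth category appears among the top-k classes.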
                if any(gt_cat_idx in indices for gt_cat_idx in labels[i]):
                    correct_preds_top_5 += 1
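
                # Top-1: a hit only if the single best class is a ground-truth category.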
                if any(gt_cat_idx in indices[:1] for gt_cat_idx in labels[i]):
                    correct_preds_top_1 += 1

print(f"Top-5 Acc. : {100 * correct_preds_top_5 / total_samples:.2f}% ({correct_preds_top_5}/{total_samples})")
|
|
print(f"Top-1 Acc. : {100 * correct_preds_top_1 / total_samples:.2f}% ({correct_preds_top_1}/{total_samples})")
|
|
|