File size: 3,504 Bytes
fe64bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import sys
sys.path.append('.')
import os
import torch
from src.utils.get_model_and_data import get_model_and_data
from src.parser.visualize import parser
from src.utils.misc import load_model_wo_clip
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.utils.tensors import collate
import clip
from src.visualize.visualize import get_gpu_device
from src.utils.action_label_to_idx import action_label_to_idx

if __name__ == '__main__':
    parameters, folder, checkpointname, epoch = parser(checkpoint=True)
    gpu_device = get_gpu_device()
    parameters["device"] = f"cuda:{gpu_device}"
    data_split = 'vald'  # Hardcoded
    parameters['use_action_cat_as_text_labels'] = True
    parameters['only_60_classes'] = True

    TOP_K_METRIC = 5

    model, datasets = get_model_and_data(parameters, split=data_split)
    dataset = datasets["train"]

    print("Restore weights..")
    checkpointpath = os.path.join(folder, checkpointname)
    state_dict = torch.load(checkpointpath, map_location=parameters["device"])
    load_model_wo_clip(model, state_dict)
    model.eval()

    iterator = DataLoader(dataset, batch_size=parameters["batch_size"],
                          shuffle=False, num_workers=8, collate_fn=collate)

    action_text_labels = list(action_label_to_idx.keys())
    action_text_labels.sort(key=lambda x: action_label_to_idx[x])

    texts = clip.tokenize(action_text_labels[:60]).to(model.device)
    classes_text_emb = model.clip_model.encode_text(texts).float()

    correct_preds_top_5, correct_preds_top_1 = 0,0
    total_samples = 0
    with torch.no_grad():
        for i, batch in tqdm(enumerate(iterator), desc="Computing batch"):
            if isinstance(batch['x'], list):
                continue
            for key in batch.keys():
                if torch.is_tensor(batch[key]):
                    batch[key] = batch[key].to(parameters['device'])
            batch = model(batch)
            texts = clip.tokenize(batch['clip_text']).to(model.device)
            batch['clip_text_embed'] = model.clip_model.encode_text(texts).float()
            labels = list(map(lambda x: [action_label_to_idx[cat] for cat in x], batch['all_categories']))
            classes_text_emb_norm = classes_text_emb / classes_text_emb.norm(dim=-1, keepdim=True)
            motion_features_norm = batch['z'] / batch['z'].norm(dim=-1, keepdim=True)
            scores = motion_features_norm @ classes_text_emb_norm.t()
            similarity = (100.0 * motion_features_norm @ classes_text_emb_norm.t()).softmax(dim=-1)

            total_samples += similarity.shape[0]
            for i in range(similarity.shape[0]):
                values, indices = similarity[i].topk(5)

                # TOP-5 CHECK
                if any([gt_cat_idx in indices for gt_cat_idx in labels[i]]):
                    correct_preds_top_5 += 1

                # TOP-1 CHECK
                values = values[:1]
                indices = indices[:1]
                if any([gt_cat_idx in indices for gt_cat_idx in labels[i]]):
                    correct_preds_top_1 += 1

            # print(f"Current Top-5 Acc. : {100 * correct_preds_top_5 / total_samples:.2f}%")

        print(f"Top-5 Acc. : {100 * correct_preds_top_5 / total_samples:.2f}%  ({correct_preds_top_5}/{total_samples})")
        print(f"Top-1 Acc. : {100 * correct_preds_top_1 / total_samples:.2f}%  ({correct_preds_top_1}/{total_samples})")