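"""Predict classes for a set of test mel-spectrograms with a trained VGGishish
classifier and dump the per-clip predictions to a JSON file.

Positional arguments: a yaml config, the path to the test data, and a model
checkpoint, e.g.:

    python <this_script> config.yaml /path/to/test_mels checkpoint.pt
"""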
import json
import os
import sys

import torch
import torchvision
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import AMT_test
from model import VGGishish
from transforms import Crop, StandardNormalizeAudio, ToTensor


if __name__ == '__main__':
    # positional args: yaml config, test data path, model checkpoint
    cfg_path = sys.argv[1]
    target_path = sys.argv[2]
    model_path = sys.argv[3]
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.set_readonly(cfg, True)

    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')

    # input pipeline: spectrogram standardization, an optional time-axis crop,
    # and conversion to a tensor
    transforms = [
        StandardNormalizeAudio(cfg.mels_path),
    ]
    if cfg.cropped_size not in [None, 'None', 'none']:
        transforms.append(Crop(cfg.cropped_size))
    transforms.append(ToTensor())
    transforms = torchvision.transforms.Compose(transforms)

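    # AMT test set over the target mel-spectrograms; the loader keeps the
    # default shuffle=False so the prediction order matches the dataset order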
    testset = AMT_test(target_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only)
    loader = DataLoader(testset, batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers, pin_memory=True)

    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(testset.label2target))
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    ckpt = torch.load(model_path, map_location=device)['model']
    model.load_state_dict(ckpt, strict=True)
    model = model.to(device)

    model.eval()

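    # run batched forward passes, collecting logits (on CPU) and file paths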
    preds_from_each_batch = []
    file_path_from_each_batch = []
    for batch in tqdm(loader):
        inputs = batch['input'].to(device)
        file_path = batch['file_path']
        with torch.no_grad():
            outputs = model(inputs)
        # stash this batch's logits and file paths
        preds_from_each_batch += [outputs.detach().cpu()]
        file_path_from_each_batch += file_path
    preds_from_each_batch = torch.cat(preds_from_each_batch)
    # top-1 class index per clip
    _, preds = torch.topk(preds_from_each_batch, k=1)
    pred_dict = {fp: int(p.item()) for fp, p in zip(file_path_from_each_batch, preds)}
    # test mels are saved as <dir>/0.npy, <dir>/1.npy, ...; order the
    # predictions by that index before dumping them to JSON
    mel_parent_dir = os.path.dirname(list(pred_dict.keys())[0])
    pred_list = [pred_dict[os.path.join(mel_parent_dir, f'{i}.npy')] for i in range(len(pred_dict))]
    with open(target_path + f'_{cfg.exp_name}_preds.json', 'w') as f:
        json.dump(pred_list, f)