# ymzhang319's picture
# init
# 7f2690b
import os
import sys
import json
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm
from dataset import GreatestHit, AMT_test
import torch
import torch.nn as nn
from metrics import metrics
from omegaconf import OmegaConf
from model import VGGishish
from transforms import Crop, StandardNormalizeAudio, ToTensor
def order_predictions_by_index(file_paths, class_ids):
    """Arrange per-file predictions by the numeric index in the file name.

    The directory is taken from the first path; predictions are then looked up
    as ``<dir>/0.npy``, ``<dir>/1.npy``, ... up to the number of distinct paths.

    Args:
        file_paths: Sample paths, one per prediction, in loader order.
        class_ids: Predicted class index for each entry of ``file_paths``.

    Returns:
        A list whose ``i``-th element is the prediction for ``<dir>/<i>.npy``.
        An empty list when no paths are given.

    Raises:
        KeyError: If ``<dir>/<i>.npy`` is missing for some ``i`` below the
            number of distinct paths (files are not numbered contiguously).
    """
    pred_by_path = dict(zip(file_paths, class_ids))
    if not pred_by_path:
        return []

    mel_parent_dir = os.path.dirname(next(iter(pred_by_path)))
    ordered = []
    for i in range(len(pred_by_path)):
        key = os.path.join(mel_parent_dir, f'{i}.npy')
        if key not in pred_by_path:
            raise KeyError(
                f'expected a prediction for {key!r}: files must be named '
                f'0.npy .. {len(pred_by_path) - 1}.npy in a single directory'
            )
        ordered.append(pred_by_path[key])
    return ordered


def main():
    """Run the classifier over a target path and dump the predictions as JSON.

    Command line: ``<config.yml> <target_path> <model_checkpoint>``.

    The predicted class index of every sample is written, ordered by the
    numeric file name, to ``<target_path>_<cfg.exp_name>_preds.json``.

    Raises:
        SystemExit: If fewer than three command-line arguments are given.
        RuntimeError: If the data loader yields no batches.
        KeyError: If the sample files are not named ``0.npy .. N-1.npy``.
    """
    if len(sys.argv) < 4:
        sys.exit(f'Usage: {sys.argv[0]} <config.yml> <target_path> <model_checkpoint>')
    cfg_cli, target_path, model_path = sys.argv[1:4]

    cfg = OmegaConf.load(cfg_cli)
    OmegaConf.set_readonly(cfg, True)

    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')

    transforms = [
        StandardNormalizeAudio(cfg.mels_path),
    ]
    if cfg.cropped_size not in [None, 'None', 'none']:
        transforms.append(Crop(cfg.cropped_size))
    transforms.append(ToTensor())
    transforms = torchvision.transforms.transforms.Compose(transforms)

    testset = AMT_test(target_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only)
    loader = DataLoader(testset, batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers, pin_memory=True)

    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(testset.label2target))
    # map_location lets a checkpoint saved on GPU be restored when we have
    # fallen back to CPU above; without it torch.load fails on such machines.
    ckpt = torch.load(model_path, map_location=device)['model']
    model.load_state_dict(ckpt, strict=True)
    model = model.to(device)
    model.eval()

    outputs_per_batch = []
    file_paths = []
    # Inference only: no gradients are needed for any batch.
    with torch.no_grad():
        for batch in tqdm(loader):
            inputs = batch['input'].to(device)
            outputs_per_batch.append(model(inputs).cpu())
            file_paths += batch['file_path']

    if not outputs_per_batch:
        raise RuntimeError(f'no samples were loaded from {target_path!r}')

    # Index of the highest-scoring class for every sample.
    _, preds = torch.topk(torch.cat(outputs_per_batch), k=1)
    class_ids = [int(p.item()) for p in preds]
    pred_list = order_predictions_by_index(file_paths, class_ids)

    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(target_path + f'_{cfg.exp_name}_preds.json', 'w') as f:
        json.dump(pred_list, f)


if __name__ == '__main__':
    main()