import os
import sys
import json
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm
from dataset import GreatestHit, AMT_test
import torch
import torch.nn as nn
from metrics import metrics
from omegaconf import OmegaConf
from model import VGGishish
from transforms import Crop, StandardNormalizeAudio, ToTensor
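
# Inference script: loads a trained VGGishish classifier and runs it over the
# precomputed mel spectrograms under `target_path`, dumping the top-1 class
# prediction per file to a JSON file.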
if __name__ == '__main__':
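    # Positional CLI arguments: path to the YAML config, path to the test
    # mel-spectrogram targets, and path to the model checkpoint.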
    cfg_cli = sys.argv[1]
    target_path = sys.argv[2]
    model_path = sys.argv[3]
    cfg_yml = OmegaConf.load(cfg_cli)
    # only the YAML config is used here; no CLI overrides are merged
    cfg = cfg_yml
    OmegaConf.set_readonly(cfg, True)
    # print(OmegaConf.to_yaml(cfg))

    device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')
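
    # Input transform pipeline: standardization (StandardNormalizeAudio is
    # initialised from `cfg.mels_path`), optional cropping to
    # `cfg.cropped_size`, and conversion to tensors.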
    transforms = [
        StandardNormalizeAudio(cfg.mels_path),
    ]
    if cfg.cropped_size not in [None, 'None', 'none']:
        transforms.append(Crop(cfg.cropped_size))
    transforms.append(ToTensor())
    transforms = torchvision.transforms.Compose(transforms)
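
    # Test dataset over the mel spectrograms at `target_path` and its loader.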
    testset = AMT_test(target_path, transforms, action_only=cfg.action_only, material_only=cfg.material_only)
    loader = DataLoader(testset, batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers, pin_memory=True)
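
    # Build the classifier, restore the checkpoint weights, and switch to eval mode.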
    model = VGGishish(cfg.conv_layers, cfg.use_bn, num_classes=len(testset.label2target))
    ckpt = torch.load(model_path, map_location=device)['model']
    model.load_state_dict(ckpt, strict=True)
    model = model.to(device)
    model.eval()
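
    # Class weights mirror the training setup; they are not used further in this script.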
    if cfg.cls_weights_in_loss:
        weights = 1 / testset.class_counts
    else:
        weights = torch.ones(len(testset.label2target))
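
    # Accumulate logits and the corresponding file paths batch by batch,
    # with gradient tracking disabled.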
    preds_from_each_batch = []
    file_path_from_each_batch = []
    for batch in tqdm(loader):
        inputs = batch['input'].to(device)
        file_path = batch['file_path']
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
        # keep logits and file paths for aggregation after the loop
        preds_from_each_batch += [outputs.detach().cpu()]
        file_path_from_each_batch += file_path
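
    # Concatenate the logits, take the top-1 class per file, and order the
    # predictions by the numeric file index (`{i}.npy`) before dumping to JSON.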
    preds_from_each_batch = torch.cat(preds_from_each_batch)
    _, preds = torch.topk(preds_from_each_batch, k=1)
    pred_dict = {fp: int(p.item()) for fp, p in zip(file_path_from_each_batch, preds)}
    mel_parent_dir = os.path.dirname(list(pred_dict.keys())[0])
    pred_list = [pred_dict[os.path.join(mel_parent_dir, f'{i}.npy')] for i in range(len(pred_dict))]
    with open(target_path + f'_{cfg.exp_name}_preds.json', 'w') as f:
        json.dump(pred_list, f)