import os
import csv
import argparse

import numpy as np
import librosa
import torch
import torchaudio
import torchaudio.transforms as AT
import pytorch_lightning as pl
import laion_clap

from model.CLAPSep import LightningModule
from model.CLAPSep_decoder import HTSAT_Decoder
from helpers import utils as local_utils


class AudioCapsTest(torch.utils.data.Dataset):  # type: ignore
    """AudioCaps evaluation set: mixes each target clip with an interfering clip at 0 dB SNR."""

    def __init__(self, eval_csv, input_dir, sr=32000, resample_rate=48000):
        self.data_path = input_dir
        self.data_names = []
        self.data_caps = []
        self.noise_names = []
        self.noise_caps = []
        # CSV columns: target filename, target caption, noise filename, noise caption.
        with open(eval_csv, 'r') as d:
            reader = csv.reader(d, skipinitialspace=True)
            next(reader)  # skip header
            for row in reader:
                self.data_names.append(row[0])
                self.data_caps.append(row[1])
                self.noise_names.append(row[2])
                self.noise_caps.append(row[3])
        self.sr = sr
        if resample_rate is not None:
            # Resampler producing the 48 kHz audio expected by the CLAP encoder.
            self.resampler = AT.Resample(sr, resample_rate)
            self.resample_rate = resample_rate

    def __len__(self):
        return len(self.data_names)

    def load_wav(self, path):
        # Trim or zero-pad every clip to exactly 10 s (the AudioCaps clip length).
        max_length = self.sr * 10
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        tgt_name = self.data_names[idx]
        noise_name = self.noise_names[idx]
        tgt_cap = self.data_caps[idx]
        neg_cap = self.noise_caps[idx]
        assert noise_name != tgt_name

        snr = torch.zeros(1)  # mix target and interference at 0 dB SNR
        tgt = torch.tensor(self.load_wav(os.path.join(self.data_path, tgt_name))).unsqueeze(0)
        noise = torch.tensor(self.load_wav(os.path.join(self.data_path, noise_name))).unsqueeze(0)
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)

        # Rescale target and mixture jointly so the mixture stays within [-1, 1].
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value

        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        # (mixture @ sr, mixture resampled to 48 kHz, target caption, negative caption, ground-truth target)
        return mixed, self.resampler(mixed), tgt_cap, neg_cap, tgt


def main(args):
    torch.set_float32_matmul_precision('highest')

    # Load the evaluation dataset.
    data_test = AudioCapsTest(eval_csv=args.eval_csv, input_dir=args.input_dir,
                              sr=args.sample_rate, resample_rate=48000)
    test_loader = torch.utils.data.DataLoader(data_test, batch_size=1, num_workers=1,
                                              pin_memory=True, shuffle=False)

    # Build the CLAP encoder and the HTSAT-based separation decoder.
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)
    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
                                       use_lora=args.lora, rank=args.lora_rank, nfft=args.nfft)

    trainer = pl.Trainer(
        default_root_dir=os.path.join(args.exp_dir, 'checkpoint'),
        devices=args.gpu_ids if args.use_cuda else "auto",
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=False,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy="ddp",
        logger=False
    )

    # Restore the checkpoint and run evaluation. (Alternatively, load a raw state dict with
    # torch.load + load_state_dict(strict=False) and call test() without ckpt_path.)
    trainer.test(model=lightning_module, dataloaders=test_loader, ckpt_path=args.ckpt_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Data params
    parser.add_argument('exp_dir', type=str, default='experiments',
                        help="Path to save checkpoints and logs.")
    parser.add_argument('--sample_rate', type=int, default=32000)
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--eval_csv', type=str, default='')
    parser.add_argument('--input_dir', type=str, default='')
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use CUDA.")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for evaluation, "
                             "e.g. --gpu_ids 2 4. All GPUs are used by default.")
    args = parser.parse_args()

    # Set the random seed for reproducible experiments.
    pl.seed_everything(114514)

    # Set up the experiment/checkpoint directory.
    os.makedirs(args.exp_dir, exist_ok=True)

    # Load model and training params from config.json and merge them into args.
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v

    main(args)
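
# Example invocation (illustrative only: the script name, paths and checkpoint below are
# placeholders, and config.json inside the experiment directory must supply the fields
# read in main(): clap_path, model, optim, lora, lora_rank, nfft, epochs):
#
#   python eval_audiocaps.py experiments/clapsep \
#       --eval_csv data/audiocaps_eval.csv \
#       --input_dir data/audiocaps/wav \
#       --ckpt_path experiments/clapsep/checkpoint/best.ckpt \
#       --use_cuda --gpu_ids 0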