fffiloni's picture
Migrated from GitHub
406f22d verified
import csv
import torch
import numpy as np
import logging
from torch_mir_eval.separation import bss_eval_sources
import fast_bss_eval
from ..losses import (
PITLossWrapper,
pairwise_neg_sisdr,
pairwise_neg_snr,
singlesrc_neg_sisdr,
PairwiseNegSDR,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class MetricsTracker:
    """Track per-utterance separation metrics: SDR, SI-SNR and their improvements.

    Each call scores one utterance, writes one CSV row (when a results file is
    configured), and stores the values so that running means (``update``) and
    final avg/std summary rows (``final``) can be produced.
    """

    # Column order of the results CSV.
    CSV_COLUMNS = ("snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i")

    def __init__(self, save_file: str = ""):
        """Create the tracker.

        Args:
            save_file: path of the results CSV. The file is opened for writing
                (truncating any existing content) and a header row is written.
                If empty, no CSV is written; only the in-memory accumulators
                are kept.
        """
        self.all_sdrs = []
        self.all_sdrs_i = []
        self.all_sisnrs = []
        self.all_sisnrs_i = []

        if save_file:
            # The csv module requires newline="" so rows are not followed by
            # spurious blank lines on platforms that translate "\n".
            self.results_csv = open(save_file, "w", newline="", encoding="utf-8")
            self.writer = csv.DictWriter(self.results_csv, fieldnames=self.CSV_COLUMNS)
            self.writer.writeheader()
        else:
            self.results_csv = None
            self.writer = None

        # Presumably these return a permutation-invariant *loss*, i.e. the
        # negated metric (PairwiseNegSDR); __call__ negates the result.
        self.pit_sisnr = PITLossWrapper(
            PairwiseNegSDR("sisdr", zero_mean=False), pit_from="pw_mtx"
        )
        # Not used inside this class; kept since callers may reference it.
        self.pit_snr = PITLossWrapper(
            PairwiseNegSDR("snr", zero_mean=False), pit_from="pw_mtx"
        )

    def __call__(self, mix, clean, estimate, key):
        """Score one utterance, record it in the CSV, and accumulate it.

        Args:
            mix: mixture waveform. It is stacked ``clean.shape[0]`` times along
                a new dim 0, so presumably shape (time,) -- TODO confirm.
            clean: reference sources; ``clean.shape[0]`` is taken as the
                number of sources (presumably (n_src, time)).
            estimate: separated sources, presumably the same shape as ``clean``.
            key: utterance identifier written to the ``snt_id`` column.
        """
        # Metrics are evaluation-only; avoid building an autograd graph.
        with torch.no_grad():
            # Loss-valued (negated) SI-SNR of the estimate.
            neg_sisnr = self.pit_sisnr(estimate.unsqueeze(0), clean.unsqueeze(0))
            # Baseline: the unprocessed mixture used as the estimate of every source.
            mix_stack = torch.stack([mix] * clean.shape[0], dim=0)
            neg_sisnr_mix = self.pit_sisnr(mix_stack.unsqueeze(0), clean.unsqueeze(0))
            # sdr_pit_loss is negated here to give SDR, averaged over sources.
            sdr = -fast_bss_eval.sdr_pit_loss(estimate, clean).mean()
            sdr_mix = -fast_bss_eval.sdr_pit_loss(mix_stack, clean).mean()

        sdr_val = sdr.item()
        sdr_i_val = (sdr - sdr_mix).item()
        sisnr_val = -neg_sisnr.item()
        # Improvement over the mixture: SI-SNR(est) - SI-SNR(mix).
        sisnr_i_val = -(neg_sisnr - neg_sisnr_mix).item()

        self._write_row(
            {
                "snt_id": key,
                "sdr": sdr_val,
                "sdr_i": sdr_i_val,
                "si-snr": sisnr_val,
                "si-snr_i": sisnr_i_val,
            }
        )

        # Metric accumulation for update() / final().
        self.all_sdrs.append(sdr_val)
        self.all_sdrs_i.append(sdr_i_val)
        self.all_sisnrs.append(sisnr_val)
        self.all_sisnrs_i.append(sisnr_i_val)

    def update(self):
        """Return the running means of SDR and SI-SNR improvement.

        Returns:
            dict with keys ``"sdr_i"`` and ``"si-snr_i"``. The values are NaN
            (with a NumPy RuntimeWarning) if no utterance has been scored yet.
        """
        return {
            "sdr_i": np.array(self.all_sdrs_i).mean(),
            "si-snr_i": np.array(self.all_sisnrs_i).mean(),
        }

    def final(self):
        """Append ``avg`` and ``std`` summary rows, then close the CSV.

        ``std`` is NumPy's default population standard deviation (ddof=0).
        """
        columns = {
            "sdr": self.all_sdrs,
            "sdr_i": self.all_sdrs_i,
            "si-snr": self.all_sisnrs,
            "si-snr_i": self.all_sisnrs_i,
        }
        for label, reduce in (("avg", np.mean), ("std", np.std)):
            row = {"snt_id": label}
            row.update({name: reduce(np.array(values)) for name, values in columns.items()})
            self._write_row(row)
        self.close()

    def close(self):
        """Close the results CSV, if one is open. Safe to call more than once."""
        if self.results_csv is not None and not self.results_csv.closed:
            self.results_csv.close()

    def _write_row(self, row):
        """Write ``row`` to the results CSV when one is configured."""
        if self.writer is not None:
            self.writer.writerow(row)