Spaces:
Running
Running
File size: 4,366 Bytes
406f22d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import csv
import torch
import numpy as np
import logging
# from torch_mir_eval.separation import bss_eval_sources
from ..losses import (
PITLossWrapper,
pairwise_neg_sisdr,
pairwise_neg_snr,
singlesrc_neg_sisdr,
)
logger = logging.getLogger(__name__)


class SPlitMetricsTracker:
    """Score separated sources per utterance and record the results in a CSV file.

    The estimated sources are split into two groups that are scored
    independently with permutation-invariant (PIT) losses:

    * ``two_*`` columns: sources 0 and 1,
    * ``one_*`` columns: source 2.

    For each group the tracker records SNR and SI-SNR plus their improvement
    over the unprocessed mixture (the ``*_i`` columns). One CSV row is written
    per call, and :meth:`final` appends an ``"avg"`` row with the column means.

    Args:
        save_file: path of the CSV file to create (overwritten if it exists).
            The default ``""`` is not an openable path, so callers must pass one.
    """

    # Source groups scored independently: (column prefix, slice on the source axis).
    _GROUPS = (("one", slice(2, 3)), ("two", slice(0, 2)))

    def __init__(self, save_file: str = ""):
        # Per-utterance metric histories (one entry per __call__), averaged in final().
        self.one_all_snrs = []
        self.one_all_snrs_i = []
        self.one_all_sisnrs = []
        self.one_all_sisnrs_i = []
        self.two_all_snrs = []
        self.two_all_snrs_i = []
        self.two_all_sisnrs = []
        self.two_all_sisnrs_i = []
        # Build the PIT scorers before opening the file so that a failure here
        # does not leave a dangling open file handle behind.
        self.pit_sisnr = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
        self.pit_snr = PITLossWrapper(pairwise_neg_snr, pit_from="pw_mtx")
        # newline="" lets the csv module control line endings itself; without it
        # platforms that translate "\n" get spurious blank rows.
        self.results_csv = open(save_file, "w", newline="", encoding="utf-8")
        self.writer = csv.DictWriter(
            self.results_csv, fieldnames=["snt_id", *self._history()]
        )
        self.writer.writeheader()

    def _history(self):
        """Map each metric CSV column (in output order) to the list accumulating it."""
        return {
            "one_snr": self.one_all_snrs,
            "one_snr_i": self.one_all_snrs_i,
            "one_si-snr": self.one_all_sisnrs,
            "one_si-snr_i": self.one_all_sisnrs_i,
            "two_snr": self.two_all_snrs,
            "two_snr_i": self.two_all_snrs_i,
            "two_si-snr": self.two_all_sisnrs,
            "two_si-snr_i": self.two_all_sisnrs_i,
        }

    def __call__(self, mix, clean, estimate, key):
        """Score one utterance, write its CSV row and accumulate its metrics.

        Args:
            mix: mixture signal. It is stacked once per source and compared
                against ``clean``, so presumably it is shaped like one source
                of ``clean`` (e.g. ``(time,)``) -- TODO confirm with callers.
            clean: reference sources, source axis first; at least 3 sources
                are needed by the group slicing.
            estimate: estimated sources, laid out like ``clean``.
            key: utterance identifier written to the ``snt_id`` column.
        """
        refs = clean.unsqueeze(0)
        # Align the estimates to the references once, with a PIT over all
        # sources, before scoring each group separately.
        _, ests = self.pit_snr(estimate.unsqueeze(0), refs, return_ests=True)
        # Baseline "estimate": the unprocessed mixture repeated for every source.
        mixes = torch.stack([mix] * clean.shape[0], dim=0).unsqueeze(0)

        row = {"snt_id": key}
        for prefix, sources in self._GROUPS:
            est, ref, base = ests[:, sources], refs[:, sources], mixes[:, sources]
            sisnr = self.pit_sisnr(est, ref)
            sisnr_i = sisnr - self.pit_sisnr(base, ref)
            snr = self.pit_snr(est, ref)
            snr_i = snr - self.pit_snr(base, ref)
            # The pairwise_neg_* losses are negated metrics; flip the sign back
            # so the CSV holds SNR / SI-SNR values (higher is better).
            row[f"{prefix}_snr"] = -snr.item()
            row[f"{prefix}_snr_i"] = -snr_i.item()
            row[f"{prefix}_si-snr"] = -sisnr.item()
            row[f"{prefix}_si-snr_i"] = -sisnr_i.item()
        self.writer.writerow(row)

        for column, values in self._history().items():
            values.append(row[column])

    def final(self):
        """Append the per-column means as an ``"avg"`` row, log them, and close the file.

        Columns with no accumulated values are written as NaN. Calling this
        again after the file has been closed does nothing.
        """
        if self.results_csv.closed:
            return
        row = {"snt_id": "avg"}
        for column, values in self._history().items():
            # np.mean of an empty list would also give NaN, but with a RuntimeWarning.
            row[column] = np.array(values).mean() if values else float("nan")
        self.writer.writerow(row)
        for prefix, _ in self._GROUPS:
            logger.info(
                "[%s] mean SNR %s, SNRi %s, SI-SNR %s, SI-SNRi %s",
                prefix,
                row[f"{prefix}_snr"],
                row[f"{prefix}_snr_i"],
                row[f"{prefix}_si-snr"],
                row[f"{prefix}_si-snr_i"],
            )
        self.results_csv.close()
|