""" Code from: https://github.com/DCASE-REPO/DESED_task """ from pathlib import Path import numpy as np import pandas as pd import scipy from sed_scores_eval.base_modules.scores import create_score_dataframe def batched_decode_preds( strong_preds, filenames, encoder, thresholds=[0.5], median_filter=None, pad_indx=None, ): """Decode a batch of predictions to dataframes. Each threshold gives a different dataframe and stored in a dictionary Args: strong_preds: torch.Tensor, batch of strong predictions. filenames: list, the list of filenames of the current batch. encoder: ManyHotEncoder object, object used to decode predictions. thresholds: list, the list of thresholds to be used for predictions. median_filter: int, the number of frames for which to apply median window (smoothing). pad_indx: list, the list of indexes which have been used for padding. Returns: dict of predictions, each keys is a threshold and the value is the DataFrame of predictions. """ # Init a dataframe per threshold scores_raw = {} scores_postprocessed = {} prediction_dfs = {} for threshold in thresholds: prediction_dfs[threshold] = pd.DataFrame() for j in range(strong_preds.shape[0]): # over batches audio_id = Path(filenames[j]).stem filename = audio_id + ".wav" c_scores = strong_preds[j] if pad_indx is not None: true_len = int(c_scores.shape[-1] * pad_indx[j].item()) c_scores = c_scores[:true_len] c_scores = c_scores.transpose(0, 1).detach().cpu().numpy() scores_raw[audio_id] = create_score_dataframe( scores=c_scores, timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)), event_classes=encoder.labels, ) if median_filter is not None: c_scores = scipy.ndimage.filters.median_filter(c_scores, (median_filter, 1)) scores_postprocessed[audio_id] = create_score_dataframe( scores=c_scores, timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)), event_classes=encoder.labels, ) for c_th in thresholds: pred = c_scores > c_th pred = encoder.decode_strong(pred) pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"]) pred["filename"] = filename prediction_dfs[c_th] = pd.concat( [prediction_dfs[c_th], pred], ignore_index=True ) return scores_raw, scores_postprocessed, prediction_dfs