"""
Code from:
https://github.com/DCASE-REPO/DESED_task
"""
from pathlib import Path
import numpy as np
import pandas as pd
import scipy.ndimage
from sed_scores_eval.base_modules.scores import create_score_dataframe
def batched_decode_preds(
    strong_preds,
    filenames,
    encoder,
    thresholds=[0.5],
    median_filter=None,
    pad_indx=None,
):
"""Decode a batch of predictions to dataframes. Each threshold gives a different dataframe and stored in a
dictionary
Args:
strong_preds: torch.Tensor, batch of strong predictions.
filenames: list, the list of filenames of the current batch.
encoder: ManyHotEncoder object, object used to decode predictions.
thresholds: list, the list of thresholds to be used for predictions.
median_filter: int, the number of frames for which to apply median window (smoothing).
pad_indx: list, the list of indexes which have been used for padding.
Returns:
dict of predictions, each keys is a threshold and the value is the DataFrame of predictions.
"""
    # Init a dataframe per threshold
    scores_raw = {}
    scores_postprocessed = {}
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()
    for j in range(strong_preds.shape[0]):  # over batches
        audio_id = Path(filenames[j]).stem
        filename = audio_id + ".wav"
        c_scores = strong_preds[j]
        if pad_indx is not None:
            true_len = int(c_scores.shape[-1] * pad_indx[j].item())
            c_scores = c_scores[:true_len]
        # (n_classes, n_frames) -> (n_frames, n_classes) numpy array
        c_scores = c_scores.transpose(0, 1).detach().cpu().numpy()
        scores_raw[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )
        if median_filter is not None:
            # Temporal smoothing of the class-wise scores
            # (scipy.ndimage.filters is deprecated, so call scipy.ndimage directly).
            c_scores = scipy.ndimage.median_filter(c_scores, (median_filter, 1))
        scores_postprocessed[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )
        for c_th in thresholds:
            # Binarize the scores at this threshold and decode frame activity into events.
            pred = c_scores > c_th
            pred = encoder.decode_strong(pred)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            pred["filename"] = filename
            prediction_dfs[c_th] = pd.concat(
                [prediction_dfs[c_th], pred], ignore_index=True
            )

    return scores_raw, scores_postprocessed, prediction_dfs
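

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how batched_decode_preds might be called. The _ToyEncoder
# below is a hypothetical stand-in for DESED's ManyHotEncoder, implementing only
# the attributes/methods this function relies on (`labels`, `_frame_to_time`,
# `decode_strong`); the frame hop, sample rate, labels, and tensor shapes are
# assumptions made for this sketch, not values from the original code.


class _ToyEncoder:
    """Hypothetical stand-in for ManyHotEncoder, only for the demo below."""

    def __init__(self, labels, frame_hop=256, fs=16000):
        self.labels = labels
        self.frame_hop = frame_hop
        self.fs = fs

    def _frame_to_time(self, frames):
        # Convert frame indices to seconds.
        return np.asarray(frames) * self.frame_hop / self.fs

    def decode_strong(self, binary_preds):
        # binary_preds: (n_frames, n_classes) boolean array.
        events = []
        binary_preds = np.asarray(binary_preds).astype(int)
        for class_idx, label in enumerate(self.labels):
            active = binary_preds[:, class_idx]
            # Rising/falling edges of the activity mask give event onsets/offsets.
            changes = np.diff(np.concatenate(([0], active, [0])))
            onsets = np.where(changes == 1)[0]
            offsets = np.where(changes == -1)[0]
            for onset, offset in zip(onsets, offsets):
                events.append(
                    [label, self._frame_to_time(onset), self._frame_to_time(offset)]
                )
        return events


if __name__ == "__main__":
    import torch  # imported here to keep the demo self-contained

    labels = ["Speech", "Dog"]
    encoder = _ToyEncoder(labels)

    # Fake batch of frame-level scores: (batch, n_classes, n_frames).
    strong_preds = torch.sigmoid(torch.randn(2, len(labels), 156))
    filenames = ["audio/clip_0.wav", "audio/clip_1.wav"]

    scores_raw, scores_postprocessed, prediction_dfs = batched_decode_preds(
        strong_preds, filenames, encoder, thresholds=[0.5], median_filter=7
    )
    print(prediction_dfs[0.5].head())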