File size: 5,072 Bytes
9b0d6c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import librosa
import torch

from data_util import audioset_classes
from helpers.decode import batched_decode_preds
from helpers.encode import ManyHotEncoder
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.utils import NAME_TO_WIDTH


def sound_event_detection(args):
    """
    Running Sound Event Detection on an audio clip.
    """
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
    model_name = args.model_name

    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint="BEATs_strong_1")
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint="ATST-F_strong_1")
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint="fpasst_strong_1")
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint="M2D_strong_1", embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint="ASIT_strong_1")
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=f"{model_name}_strong_1", embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")

    model.eval()
    model.to(device)

    sample_rate = 16_000  # all our models are trained on 16 kHz audio
    segment_duration = 10  # all models are trained on 10-second pieces
    segment_samples = segment_duration * sample_rate

    # load audio
    (waveform, _) = librosa.core.load(args.audio_file, sr=sample_rate, mono=True)
    waveform = torch.from_numpy(waveform[None, :]).to(device)
    waveform_len = waveform.shape[1]

    audio_len = waveform_len / sample_rate  # in seconds
    print("Audio length (seconds): ", audio_len)

    # encoder manages decoding of model predictions into dataframes
    # containing event labels, onsets and offsets
    encoder = ManyHotEncoder(audioset_classes.as_strong_train_classes, audio_len=audio_len)

    # split audio file into 10-second chunks
    num_chunks = waveform_len // segment_samples + (waveform_len % segment_samples != 0)
    all_predictions = []

    # Process each 10-second chunk
    for i in range(num_chunks):
        start_idx = i * segment_samples
        end_idx = min((i + 1) * segment_samples, waveform_len)
        waveform_chunk = waveform[:, start_idx:end_idx]

        # Pad the last chunk if it's shorter than 10 seconds
        if waveform_chunk.shape[1] < segment_samples:
            pad_size = segment_samples - waveform_chunk.shape[1]
            waveform_chunk = torch.nn.functional.pad(waveform_chunk, (0, pad_size))

        # Run inference for each chunk
        with torch.no_grad():
            mel = model.mel_forward(waveform_chunk)
            y_strong, _ = model(mel)

        # Collect predictions
        all_predictions.append(y_strong)

    # Concatenate all predictions along the time axis
    y_strong = torch.cat(all_predictions, dim=2)
    # convert into probabilities
    y_strong = torch.sigmoid(y_strong)

    (
        scores_unprocessed,
        scores_postprocessed,
        decoded_predictions
    ) = batched_decode_preds(
        y_strong.float(),
        [args.audio_file],
        encoder,
        median_filter=args.median_window,
        thresholds=args.detection_thresholds,
    )

    for th in decoded_predictions:
        print("***************************************")
        print(f"Detected events using threshold {th}:")
        print(decoded_predictions[th].sort_values(by="onset"))
        print("***************************************")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Example of parser. ')
    # model names: [BEATs, ASIT, ATST-F, fpasst, M2D]
    parser.add_argument('--model_name', type=str, default='BEATs')
    parser.add_argument('--audio_file', type=str,
                        default='test_files/752547__iscence__milan_metro_coming_in_station.wav')
    parser.add_argument('--detection_thresholds', type=float, default=(0.1, 0.2, 0.5))
    parser.add_argument('--median_window', type=float, default=9)
    parser.add_argument('--cuda', action='store_true', default=False)
    args = parser.parse_args()

    assert args.model_name in ["BEATs", "ASIT", "ATST-F", "fpasst", "M2D"] or args.model_name.startswith("frame_mn")
    sound_event_detection(args)