import argparse

import librosa
import torch

from data_util import audioset_classes
from helpers.decode import batched_decode_preds
from helpers.encode import ManyHotEncoder
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.utils import NAME_TO_WIDTH


def sound_event_detection(args):
    """
    Run sound event detection on a single audio clip.
    """
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')

    model_name = args.model_name
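    # build the selected pretrained model wrapper and load the matching
    # "<model_name>_strong_1" checkpoint via PredictionsWrapper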
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint="BEATs_strong_1")
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint="ATST-F_strong_1")
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint="fpasst_strong_1")
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint="M2D_strong_1", embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint="ASIT_strong_1")
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
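        # read the embedding dimension from the bias length of the final
        # MobileNet feature layer stored in the wrapper's state dict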
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=f"{model_name}_strong_1", embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")

    model.eval()
    model.to(device)
    sample_rate = 16_000  # all our models are trained on 16 kHz audio
    segment_duration = 10  # all models are trained on 10-second segments
    segment_samples = segment_duration * sample_rate

    # load audio
    (waveform, _) = librosa.core.load(args.audio_file, sr=sample_rate, mono=True)
    waveform = torch.from_numpy(waveform[None, :]).to(device)
    waveform_len = waveform.shape[1]
    audio_len = waveform_len / sample_rate  # in seconds
    print("Audio length (seconds): ", audio_len)

    # the encoder manages decoding of model predictions into dataframes
    # containing event labels, onsets and offsets
    encoder = ManyHotEncoder(audioset_classes.as_strong_train_classes, audio_len=audio_len)

    # split the audio file into 10-second chunks
    num_chunks = waveform_len // segment_samples + (waveform_len % segment_samples != 0)
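    # (equivalent to ceil(waveform_len / segment_samples) using integer arithmetic)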
    all_predictions = []

    # Process each 10-second chunk
    for i in range(num_chunks):
        start_idx = i * segment_samples
        end_idx = min((i + 1) * segment_samples, waveform_len)
        waveform_chunk = waveform[:, start_idx:end_idx]

        # Pad the last chunk if it's shorter than 10 seconds
        if waveform_chunk.shape[1] < segment_samples:
            pad_size = segment_samples - waveform_chunk.shape[1]
            waveform_chunk = torch.nn.functional.pad(waveform_chunk, (0, pad_size))

        # Run inference for each chunk
        with torch.no_grad():
            mel = model.mel_forward(waveform_chunk)
            y_strong, _ = model(mel)

        # Collect predictions
        all_predictions.append(y_strong)

    # Concatenate all predictions along the time axis
    y_strong = torch.cat(all_predictions, dim=2)
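    # y_strong now holds frame-level logits for the full clip (batch x classes x frames)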

    # convert into probabilities
    y_strong = torch.sigmoid(y_strong)

    (
        scores_unprocessed,
        scores_postprocessed,
        decoded_predictions
    ) = batched_decode_preds(
        y_strong.float(),
        [args.audio_file],
        encoder,
        median_filter=args.median_window,
        thresholds=args.detection_thresholds,
    )
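
    # decoded_predictions maps each detection threshold to a dataframe of
    # detected events (label, onset, offset)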
    for th in decoded_predictions:
        print("***************************************")
        print(f"Detected events using threshold {th}:")
        print(decoded_predictions[th].sort_values(by="onset"))
        print("***************************************")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run sound event detection on an audio clip.')

    # model names: BEATs, ASIT, ATST-F, fpasst, M2D, or a frame_mn* variant
    parser.add_argument('--model_name', type=str, default='BEATs')
    parser.add_argument('--audio_file', type=str,
                        default='test_files/752547__iscence__milan_metro_coming_in_station.wav')
    parser.add_argument('--detection_thresholds', type=float, nargs='+', default=(0.1, 0.2, 0.5))
    parser.add_argument('--median_window', type=int, default=9)
    parser.add_argument('--cuda', action='store_true', default=False)
    args = parser.parse_args()

    assert args.model_name in ["BEATs", "ASIT", "ATST-F", "fpasst", "M2D"] or args.model_name.startswith("frame_mn")
    sound_event_detection(args)
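
# Example invocation (the script file name below is a placeholder; adjust it to
# however this file is saved locally):
#   python inference.py --model_name BEATs --cuda \
#       --audio_file test_files/752547__iscence__milan_metro_coming_in_station.wav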