Fine-tuned Swin Base Model for Gunshot Detection
A fine-tuned version of microsoft/swin-base-patch4-window7-224 for gunshot detection on log-mel spectrograms of short audio clips.
Model Details
- Trained by: Ranabir Saha
- Fine-tuned on: the tropical forest gunshot classification training audio dataset from "Automated detection of gunshots in tropical forests using convolutional neural networks" (Katsis et al., 2022)
- Dataset Source: https://doi.org/10.17632/x48cwz364j.3
- Input: 4-second .wav audio files resampled to 8000 Hz, converted to 224x224 log-mel spectrograms with 3 channels (RGB-like) during preprocessing
- Output: Binary classification (Background/Gunshot)
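If the checkpoint stores an id2label mapping in its config.json (the usual case for transformers image-classification fine-tunes), the class labels can be inspected before running inference. A minimal sketch under that assumption:

from transformers import AutoConfig

# Inspect the label mapping saved with the checkpoint
# (assumes an id2label entry exists in config.json).
config = AutoConfig.from_pretrained("ranvir-not-found/swin-wda_gunshot-detection")
print(config.id2label)  # e.g. {0: 'Background', 1: 'Gunshot'}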
Usage
The model expects spectrogram images as input. Below is an example of how to preprocess an audio file and use the model:
import torchaudio
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
from transformers import pipeline, AutoImageProcessor
# Preprocess audio to spectrogram
SAMPLE_RATE = 8000
TARGET_LENGTH = 4  # seconds
def preprocess_wav_to_spectrogram(file_path):
    # Load the audio and resample to 8 kHz if needed
    waveform, orig_sample_rate = torchaudio.load(file_path)
    if orig_sample_rate != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_sample_rate, SAMPLE_RATE)
        waveform = resampler(waveform)

    # Downmix multi-channel audio to mono so the spectrogram has a single channel
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Truncate or zero-pad to exactly 4 seconds
    target_length_samples = TARGET_LENGTH * SAMPLE_RATE
    if waveform.shape[1] > target_length_samples:
        waveform = waveform[:, :target_length_samples]
    else:
        padding = target_length_samples - waveform.shape[1]
        waveform = F.pad(waveform, (0, padding))

    # Compute the log-mel spectrogram
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=512,
        hop_length=160,
        n_mels=64,
        power=1.0
    )(waveform)
    log_spec = torchaudio.transforms.AmplitudeToDB()(mel_spec)
    log_spec = ((log_spec + 80) / 80).clamp(0, 1)  # Normalize to [0, 1]

    # Replicate to 3 channels and resize to 224x224 for the Swin backbone
    log_spec = log_spec.repeat(3, 1, 1)
    log_spec = F.interpolate(log_spec.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)

    # Convert to an 8-bit RGB PIL image
    log_spec_np = log_spec.permute(1, 2, 0).numpy()
    log_spec_np = (log_spec_np * 255).astype(np.uint8)
    pil_image = Image.fromarray(log_spec_np, mode='RGB')
    return pil_image
# Load model and processor
image_processor = AutoImageProcessor.from_pretrained("ranvir-not-found/swin-wda_gunshot-detection")
classifier = pipeline(
"image-classification",
model="ranvir-not-found/swin-wda_gunshot-detection",
image_processor=image_processor
)
# Process audio and classify
audio_path = "path/to/your/audio.wav"
spectrogram = preprocess_wav_to_spectrogram(audio_path)
results = classifier(spectrogram)
print(results)
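For recordings longer than 4 seconds, the same clip-level pipeline can be applied over a sliding window. The sketch below is illustrative only and not part of the original pipeline: the 2-second hop (50% overlap) and the function name classify_long_recording are arbitrary choices, and each window is written to a temporary .wav file so that preprocess_wav_to_spectrogram from the example above can be reused unchanged.

import os
import tempfile

def classify_long_recording(file_path, hop_seconds=2):
    # Load, resample to 8 kHz, and downmix to mono
    waveform, sr = torchaudio.load(file_path)
    if sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    window = TARGET_LENGTH * SAMPLE_RATE
    hop = hop_seconds * SAMPLE_RATE
    detections = []
    for start in range(0, waveform.shape[1], hop):
        chunk = waveform[:, start:start + window]
        if chunk.shape[1] < window:  # zero-pad the final partial window
            chunk = F.pad(chunk, (0, window - chunk.shape[1]))
        # Round-trip through a temporary file to reuse the preprocessing above
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        torchaudio.save(tmp.name, chunk, SAMPLE_RATE)
        spectrogram = preprocess_wav_to_spectrogram(tmp.name)
        os.remove(tmp.name)
        detections.append((start / SAMPLE_RATE, classifier(spectrogram)))
    return detections

# Each entry pairs a window start time (in seconds) with the classifier output
print(classify_long_recording("path/to/your/long_recording.wav"))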
Training Details
- Model Checkpoint: microsoft/swin-base-patch4-window7-224
- Training Parameters (see the configuration sketch after this list):
- Batch Size: 8 (with gradient accumulation steps=2)
- Epochs: Up to 25 (with early stopping, patience=7)
- Learning Rate: 3e-5
- Scheduler: Cosine with warmup (10% of steps)
- Weight Decay: 0.05
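The hyperparameters above map onto a standard transformers Trainer setup. The following is a minimal sketch, not the exact script used for the released checkpoint: the output directory, the epoch-level evaluation/saving strategy, and the metric used for early stopping are assumptions added for illustration.

from transformers import TrainingArguments, EarlyStoppingCallback

# Illustrative configuration matching the listed hyperparameters; output_dir,
# eval/save strategy, and metric_for_best_model are assumptions.
training_args = TrainingArguments(
    output_dir="swin-gunshot-detection",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=25,
    learning_rate=3e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.05,
    eval_strategy="epoch",        # use evaluation_strategy on older transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,  # required for EarlyStoppingCallback
    metric_for_best_model="accuracy",
)

# Early stopping after 7 evaluation rounds without improvement; pass to
# Trainer(..., args=training_args, callbacks=[early_stopping]) along with the
# spectrogram dataset and a compute_metrics function.
early_stopping = EarlyStoppingCallback(early_stopping_patience=7)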