Fine-tuned Swin Base Model for Gunshot Detection

A fine-tuned version of microsoft/swin-base-patch4-window7-224 that detects gunshots in audio by classifying log-mel spectrogram images.

Model Details

  • Trained by: Ranabir Saha
  • Fine-tuned on: the tropical forest gunshot classification training audio dataset from "Automated detection of gunshots in tropical forests using convolutional neural networks" (Katsis et al., 2022)
  • Dataset Source: https://doi.org/10.17632/x48cwz364j.3
  • Input: 4-second .wav audio clips, resampled to 8000 Hz and converted during preprocessing into 224x224 log-mel spectrogram images with 3 identical channels (RGB-like)
  • Output: Binary classification (Background/Gunshot)
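
The Input bullet implies specific tensor sizes before the final resize. The sketch below works them out, assuming the preprocessing parameters from the Usage example (n_fft=512, hop_length=160, n_mels=64) and torchaudio's default centered STFT (center=True); the helper name `spectrogram_shape` is hypothetical.

```python
SAMPLE_RATE = 8000   # Hz, from the model card
CLIP_SECONDS = 4     # clip length in seconds, from the model card
N_FFT = 512          # the three values below are taken from the Usage example
HOP_LENGTH = 160
N_MELS = 64


def spectrogram_shape(n_samples, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS):
    """Return (n_mels, n_frames) for a centered STFT (hypothetical helper).

    With center=True the signal is padded by n_fft // 2 on each side, so
    n_frames = 1 + (n_samples + 2 * (n_fft // 2) - n_fft) // hop_length.
    """
    padded = n_samples + 2 * (n_fft // 2)
    n_frames = 1 + (padded - n_fft) // hop_length
    return n_mels, n_frames


n_samples = SAMPLE_RATE * CLIP_SECONDS
print(n_samples)                     # 32000
print(spectrogram_shape(n_samples))  # (64, 201)
```

Under these assumptions the 64 x 201 mel spectrogram is then stretched to 224 x 224 by bilinear interpolation, so the frequency axis is upsampled 3.5x and the time axis roughly 1.1x.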

Usage

The model expects spectrogram images, not raw audio, as input. The example below shows how to convert an audio file into a spectrogram image and classify it:

import torchaudio
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
from transformers import pipeline, AutoImageProcessor

# Preprocess audio to spectrogram
SAMPLE_RATE = 8000   # target sample rate in Hz
TARGET_LENGTH = 4    # clip length in seconds

def preprocess_wav_to_spectrogram(file_path):
    waveform, orig_sample_rate = torchaudio.load(file_path)
    # Downmix multi-channel audio to mono so the final image has exactly 3 channels
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if orig_sample_rate != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_sample_rate, SAMPLE_RATE)
        waveform = resampler(waveform)
    # Truncate or zero-pad to exactly TARGET_LENGTH seconds
    target_length_samples = TARGET_LENGTH * SAMPLE_RATE
    if waveform.shape[1] > target_length_samples:
        waveform = waveform[:, :target_length_samples]
    else:
        padding = target_length_samples - waveform.shape[1]
        waveform = F.pad(waveform, (0, padding))
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=512,
        hop_length=160,
        n_mels=64,
        power=1.0
    )(waveform)
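    log_spec = torchaudio.transforms.AmplitudeToDB()(mel_spec)
    # Map -80..0 dB to [0, 1]; clamp so out-of-range values cannot wrap in the uint8 cast
    log_spec = ((log_spec + 80) / 80).clamp(0.0, 1.0)
    # Copy the single channel into 3 identical channels and resize to the Swin input size
    log_spec = log_spec.repeat(3, 1, 1)
    log_spec = F.interpolate(log_spec.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
    # (C, H, W) float in [0, 1] -> (H, W, C) uint8 image
    log_spec_np = log_spec.permute(1, 2, 0).numpy()
    log_spec_np = (log_spec_np * 255).astype(np.uint8)
    pil_image = Image.fromarray(log_spec_np)  # an (H, W, 3) uint8 array is read as RGB
    return pil_image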

# Load model and processor
image_processor = AutoImageProcessor.from_pretrained("ranvir-not-found/swin-wda_gunshot-detection")
classifier = pipeline(
    "image-classification",
    model="ranvir-not-found/swin-wda_gunshot-detection",
    image_processor=image_processor
)

# Process audio and classify
audio_path = "path/to/your/audio.wav"
spectrogram = preprocess_wav_to_spectrogram(audio_path)
results = classifier(spectrogram)
print(results)
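
The example above classifies only the first 4 seconds of a file; anything longer is truncated. For longer recordings, one option (an assumption, not part of the original workflow) is to cut the waveform into consecutive 4-second windows, classify each window, and read the gunshot probability from each pipeline result. The image-classification pipeline returns a list of {"label", "score"} dicts; the label string "Gunshot" below is assumed, so check `classifier.model.config.id2label` for the actual names. The helpers `window_bounds` and `gunshot_score` are hypothetical.

```python
def window_bounds(n_samples, sample_rate=8000, window_seconds=4, hop_seconds=4):
    """Return (start, end) sample indices of fixed-length windows covering n_samples.

    The last window may extend past n_samples; zero-pad it to full length
    (as the Usage example does) before building its spectrogram.
    """
    window = int(window_seconds * sample_rate)
    hop = int(hop_seconds * sample_rate)
    if hop <= 0 or window <= 0:
        raise ValueError("window and hop must be positive")
    if n_samples <= 0:
        return []
    bounds = []
    start = 0
    while True:
        bounds.append((start, start + window))
        if start + window >= n_samples:  # this window reaches the end of the signal
            break
        start += hop
    return bounds


def gunshot_score(results, positive_label="Gunshot"):
    """Return the score for positive_label from one pipeline result list (0.0 if absent)."""
    for item in results:
        if item["label"].lower() == positive_label.lower():
            return item["score"]
    return 0.0


# Usage sketch with names from the example above (waveform already mono at 8 kHz):
# for start, end in window_bounds(waveform.shape[1]):
#     chunk = waveform[:, start:end]
#     # ...pad chunk and build the image as in preprocess_wav_to_spectrogram...
#     score = gunshot_score(classifier(image))
```

With the default 4-second hop the windows do not overlap; a smaller `hop_seconds` gives overlapping windows, which lowers the chance of a shot being split across a window boundary at the cost of more model calls.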

Training Details

  • Model Checkpoint: microsoft/swin-base-patch4-window7-224
  • Training Parameters:
    • Batch Size: 8 (with 2 gradient accumulation steps)
    • Epochs: Up to 25 (with early stopping, patience=7)
    • Learning Rate: 3e-5
    • Scheduler: Cosine with warmup (10% of steps)
    • Weight Decay: 0.05
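
The hyperparameters above map onto Hugging Face `TrainingArguments` roughly as sketched below. This is a hypothetical reconstruction, not the author's training script: it assumes the batch size is per device, uses eval_loss as the early-stopping metric (the card does not name one), and uses argument names from recent transformers 4.x releases (older releases spell `eval_strategy` as `evaluation_strategy`).

```python
# Hypothetical mapping of the listed hyperparameters to TrainingArguments kwargs
TRAINING_KWARGS = dict(
    per_device_train_batch_size=8,      # "Batch Size: 8" (assumed per device)
    gradient_accumulation_steps=2,
    num_train_epochs=25,                # upper bound; early stopping may end sooner
    learning_rate=3e-5,
    lr_scheduler_type="cosine",         # linear warmup, then cosine decay
    warmup_ratio=0.1,                   # warmup over the first 10% of steps
    weight_decay=0.05,
    # Early stopping needs periodic evaluation and best-checkpoint tracking
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # assumption: metric not stated in the card
)
EARLY_STOPPING_PATIENCE = 7


def effective_batch_size(n_devices=1):
    """Number of samples contributing to each optimizer update."""
    return (TRAINING_KWARGS["per_device_train_batch_size"]
            * TRAINING_KWARGS["gradient_accumulation_steps"]
            * n_devices)


def build_training_setup(output_dir="swin-gunshot-output"):
    """Build TrainingArguments and the early-stopping callback (requires transformers)."""
    # Imported here so the settings above can be inspected without transformers installed
    from transformers import EarlyStoppingCallback, TrainingArguments

    args = TrainingArguments(output_dir=output_dir, **TRAINING_KWARGS)
    callbacks = [EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)]
    return args, callbacks
```

The returned `args` and `callbacks` would then be passed to `Trainer(model=..., args=args, train_dataset=..., eval_dataset=..., callbacks=callbacks)`. On a single device, 8 samples per step with 2 accumulation steps gives 16 samples per optimizer update.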

Model size: 86.8M parameters (Safetensors; tensor types I64, F32)