# Page-header residue captured together with this file (not program code):
# "Spaces: Sleeping / Sleeping"
import logging
import os
import tempfile

from flask import Flask, request, jsonify
import librosa
import numpy as np
import torch
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
app = Flask(__name__)

# Inference device: first CUDA GPU when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model setup.
# Raw audio is turned into model inputs by the HuBERT-base feature extractor;
# the classifier itself is loaded from the ExHuBERT checkpoint.
# NOTE(review): trust_remote_code=True lets transformers run code shipped in
# that repository; consider pinning `revision=` to a known commit.
model_name = 'amiriparian/ExHuBERT'
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
model = AutoModelForAudioClassification.from_pretrained(
    model_name,
    trust_remote_code=True,
).to(device)

# Class-index -> emotion name, indexed by the argmax of the model's logits.
# NOTE(review): this order is hard-coded rather than read from model.config;
# confirm it matches the checkpoint's classifier head.
labels = ['disgust', 'neutral', 'kind', 'anger', 'surprise', 'joy']
def _alert_level(confidence):
    """Map the winning class probability to the API's alert tier.

    Args:
        confidence: Softmax probability of the top class, in [0, 1].

    Returns:
        'High-Risk' above 0.8, 'Medium-Risk' above 0.5, otherwise the string
        'None' (not the Python singleton).
    """
    # NOTE(review): the tier ignores which label won, so a confident
    # 'neutral' prediction is still reported as 'High-Risk'. Confirm intent.
    if confidence > 0.8:
        return 'High-Risk'
    if confidence > 0.5:
        return 'Medium-Risk'
    return 'None'


def _predict(waveform):
    """Classify a single 16 kHz waveform and return its top emotion.

    The feature extractor pads the clip to at least 48000 samples (3 s at
    16 kHz). No ``truncation=`` flag is passed, so longer clips are
    presumably fed through at full length; confirm against the model card.

    Args:
        waveform: 1-D audio signal sampled at 16 kHz.

    Returns:
        ``(label, confidence)``: the name from ``labels`` for the highest
        probability class, and that probability as a Python float.
    """
    inputs = feature_extractor(
        waveform,
        sampling_rate=16000,
        padding="max_length",
        max_length=48000,
        return_tensors="pt",
    )
    input_values = inputs['input_values'].to(device)
    with torch.no_grad():
        logits = model(input_values).logits
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    # A single clip yields a batch of one, so .item() is well defined.
    confidence, predicted = torch.max(probabilities, 1)
    return labels[predicted.item()], float(confidence.item())


def detect_scream():
    """Classify the emotion of an uploaded WAV/MP3 clip.

    Expects a multipart upload in form field ``file``. Responds with JSON:

    * 200: ``{'label', 'confidence', 'alert_level'}`` for the top class.
    * 400: no file was uploaded, or its name does not end in .wav/.mp3
      (case-insensitive).
    * 500: ``{'error': <message>}`` for any failure while saving, decoding,
      or classifying; the traceback is logged.

    NOTE(review): no ``@app.route`` decorator is visible on this function;
    unless it is registered elsewhere, Flask never serves it. Confirm.
    """
    temp_path = None
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400
        audio_file = request.files['file']

        # The client-supplied filename is used ONLY to pick the extension;
        # it must never become part of a filesystem path (path traversal and
        # collisions between concurrent uploads with the same name).
        lowered_name = (audio_file.filename or '').lower()
        extension = next(
            (ext for ext in ('.wav', '.mp3') if lowered_name.endswith(ext)),
            None,
        )
        if extension is None:
            return jsonify({'error': 'Unsupported file format. Use WAV or MP3'}), 400

        # Unique server-chosen path; the suffix lets the audio decoder detect
        # the container format.
        fd, temp_path = tempfile.mkstemp(suffix=extension)
        os.close(fd)
        audio_file.save(temp_path)

        waveform, _ = librosa.load(temp_path, sr=16000)
        label, confidence = _predict(waveform)

        result = {
            'label': label,
            'confidence': confidence,
            'alert_level': _alert_level(confidence),
        }
        return jsonify(result), 200
    except Exception as e:
        # Top-level request boundary: log the traceback, report the message.
        logging.getLogger(__name__).exception('Audio classification failed')
        return jsonify({'error': str(e)}), 500
    finally:
        # Runs on every exit path, so a failed request no longer leaks the file.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
def health_check():
    """Liveness probe: always answer HTTP 200 with a fixed JSON status.

    NOTE(review): no ``@app.route`` decorator is visible here; confirm this
    view is registered somewhere.
    """
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
if __name__ == '__main__':
    # Bind to every interface so the server is reachable from other hosts.
    # Port 7860 is presumably what the hosting platform expects; confirm
    # before changing it.
    app.run(port=7860, host='0.0.0.0')