"""
Audio sentiment analysis model using fine-tuned Wav2Vec2.
"""
import logging
import streamlit as st
from typing import Tuple
import torch
from PIL import Image
import os
from ..config.settings import AUDIO_MODEL_CONFIG
from ..utils.preprocessing import preprocess_audio_for_model
from ..utils.sentiment_mapping import get_sentiment_mapping
from src.utils.simple_model_manager import SimpleModelManager
logger = logging.getLogger(__name__)
@st.cache_resource
def load_audio_model():
    """Load the pre-trained Wav2Vec2 audio sentiment model from Google Drive."""
    try:
        manager = SimpleModelManager()
        # Load the model via the Google Drive manager
        model, device = manager.load_audio_model()
        if model is None:
            logger.error("Failed to load audio model from Google Drive")
            st.error("Failed to load audio model from Google Drive")
            return None, None, None, None
        # For Wav2Vec2 models the number of classes is normally exposed
        # on the model configuration.
        try:
            num_classes = model.config.num_labels
        except AttributeError:
            # Fallback: try to infer it from the classifier head
            try:
                num_classes = model.classifier.out_features
            except AttributeError:
                num_classes = AUDIO_MODEL_CONFIG["num_classes"]  # Default assumption
        # Load the feature extractor that matches the base model
        from transformers import AutoFeatureExtractor

        feature_extractor = AutoFeatureExtractor.from_pretrained(
            AUDIO_MODEL_CONFIG["model_name"]
        )
        logger.info(f"Audio model loaded successfully with {num_classes} classes!")
        st.success(f"Audio model loaded successfully with {num_classes} classes!")
        return model, device, num_classes, feature_extractor
    except Exception as e:
        logger.error(f"Error loading audio model: {str(e)}")
        st.error(f"Error loading audio model: {str(e)}")
        return None, None, None, None
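
# Usage sketch (illustrative only, not part of the app flow): the loader
# returns a 4-tuple and signals failure by returning None for every element,
# so direct callers should check the model before running inference:
#
#     model, device, num_classes, feature_extractor = load_audio_model()
#     if model is not None:
#         ...  # safe to run inference
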
def predict_audio_sentiment(audio_bytes: bytes) -> Tuple[str, float]:
    """
    Analyze audio sentiment using the fine-tuned Wav2Vec2 model.

    Preprocessing matches the CREMA-D + RAVDESS training specifications:
    - Target sampling rate: 16 kHz
    - Max duration: 5.0 seconds
    - Feature extraction: AutoFeatureExtractor with max_length, truncation, padding

    Args:
        audio_bytes: Raw audio bytes

    Returns:
        Tuple of (sentiment, confidence)
    """
    if audio_bytes is None:
        return "No audio provided", 0.0
    try:
        # Load the model (cached by @st.cache_resource after the first call)
        model, device, num_classes, feature_extractor = load_audio_model()
        if model is None:
            return "Model not loaded", 0.0
        # Use the centralized preprocessing function
        input_values = preprocess_audio_for_model(audio_bytes)
        if input_values is None:
            return "Preprocessing failed", 0.0
        # Debug: log the tensor shape
        logger.info(f"Preprocessed audio tensor shape: {input_values.shape}")
        # Ensure the expected tensor shape: [batch_size, sequence_length]
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)  # Add missing batch dimension
        elif input_values.dim() == 3:
            # If we get [batch, sequence, channels], squeeze the channel dimension
            input_values = input_values.squeeze(-1)
        logger.info(f"Final audio tensor shape: {input_values.shape}")
        # Move to the model's device
        input_values = input_values.to(device)
        # Run inference
        with torch.no_grad():
            outputs = model(input_values)
            probabilities = torch.softmax(outputs.logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)
        # Map the predicted class index to a sentiment label
        sentiment_map = get_sentiment_mapping(num_classes)
        sentiment = sentiment_map[predicted.item()]
        confidence_score = confidence.item()
        logger.info(
            f"Audio sentiment analysis completed: {sentiment} (confidence: {confidence_score:.2f})"
        )
        return sentiment, confidence_score
    except ImportError as e:
        logger.error(f"Required library not installed: {str(e)}")
        st.error(f"Required library not installed: {str(e)}")
        st.info("Please install: pip install librosa transformers")
        return "Library not available", 0.0
    except Exception as e:
        logger.error(f"Error in audio sentiment prediction: {str(e)}")
        st.error(f"Error in audio sentiment prediction: {str(e)}")
        return "Error occurred", 0.0

def get_audio_model_info() -> dict:
    """Get information about the audio sentiment model."""
    return {
        "model_name": AUDIO_MODEL_CONFIG["model_name"],
        "description": "Fine-tuned Wav2Vec2 for audio sentiment analysis",
        "capabilities": [
            "Audio sentiment classification",
            "Automatic audio preprocessing",
            "CREMA-D + RAVDESS dataset compatibility",
            "Real-time audio analysis",
        ],
        "input_format": "Audio files (WAV, MP3, M4A, FLAC)",
        "output_format": "Sentiment label + confidence score",
        "preprocessing": {
            "sampling_rate": f"{AUDIO_MODEL_CONFIG['target_sampling_rate']} Hz",
            "max_duration": f"{AUDIO_MODEL_CONFIG['max_duration']} seconds",
            "feature_extraction": "AutoFeatureExtractor",
            "normalization": "Model-specific",
        },
    }
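
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the app). Because this file
    # uses relative imports, it must be run as a module from the project root
    # (python -m <package path>); "sample.wav" is an assumed local file, not
    # part of the repo. Outside "streamlit run", the st.* calls above only
    # log to the console instead of rendering in a page.
    import sys

    audio_path = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"
    with open(audio_path, "rb") as f:
        sentiment, confidence = predict_audio_sentiment(f.read())
    print(f"{audio_path}: {sentiment} (confidence: {confidence:.2f})")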