"""
Audio sentiment analysis model using fine-tuned Wav2Vec2.
"""

import logging
from typing import Tuple

import streamlit as st
import torch

from ..config.settings import AUDIO_MODEL_CONFIG
from ..utils.preprocessing import preprocess_audio_for_model
from ..utils.sentiment_mapping import get_sentiment_mapping
from src.utils.simple_model_manager import SimpleModelManager

logger = logging.getLogger(__name__)


@st.cache_resource
def load_audio_model():
    """Load the pre-trained Wav2Vec2 audio sentiment model from Google Drive."""
    try:
        manager = SimpleModelManager()

        # Load the model via the Google Drive-backed manager; construction
        # failures are handled by the surrounding except block.
        model, device = manager.load_audio_model()

        if model is None:
            logger.error("Failed to load audio model from Google Drive")
            st.error("Failed to load audio model from Google Drive")
            return None, None, None, None

        # For Wav2Vec2 models, the number of classes is typically available
        # in the model configuration.
        try:
            num_classes = model.config.num_labels
        except AttributeError:
            # Fallback: infer from the classification head
            try:
                num_classes = model.classifier.out_features
            except AttributeError:
                num_classes = AUDIO_MODEL_CONFIG["num_classes"]  # Configured default

        # Load feature extractor
        from transformers import AutoFeatureExtractor

        feature_extractor = AutoFeatureExtractor.from_pretrained(
            AUDIO_MODEL_CONFIG["model_name"]
        )

        logger.info(f"Audio model loaded successfully with {num_classes} classes!")
        st.success(f"Audio model loaded successfully with {num_classes} classes!")
        return model, device, num_classes, feature_extractor
    except Exception as e:
        logger.error(f"Error loading audio model: {str(e)}")
        st.error(f"Error loading audio model: {str(e)}")
        return None, None, None, None
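

# Illustrative sketch only: get_sentiment_mapping (imported above) supplies the
# index-to-label mapping used by predict_audio_sentiment below. This stand-in
# is an assumption for illustration, not the real implementation.
def _example_sentiment_mapping(num_classes: int) -> dict:
    """Map class indices to human-readable labels (illustrative only)."""
    if num_classes == 3:
        return {0: "negative", 1: "neutral", 2: "positive"}
    # Emotion-level heads (e.g., six CREMA-D classes) would use
    # dataset-specific labels; fall back to generic names here.
    return {i: f"class_{i}" for i in range(num_classes)}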


def predict_audio_sentiment(audio_bytes: bytes) -> Tuple[str, float]:
    """
    Analyze audio sentiment using the fine-tuned Wav2Vec2 model.

    Preprocessing matches the CREMA-D + RAVDESS training specifications:
    - Target sampling rate: 16 kHz
    - Max duration: 5.0 seconds
    - Feature extraction: AutoFeatureExtractor with max_length, truncation, and padding

    Args:
        audio_bytes: Raw audio bytes

    Returns:
        Tuple of (sentiment, confidence)
    """
    if audio_bytes is None:
        return "No audio provided", 0.0

    try:
        # Load model if not already loaded
        model, device, num_classes, feature_extractor = load_audio_model()
        if model is None:
            return "Model not loaded", 0.0

        # Use our centralized preprocessing function
        input_values = preprocess_audio_for_model(audio_bytes)
        if input_values is None:
            return "Preprocessing failed", 0.0

        # Log the tensor shape for debugging
        logger.debug(f"Preprocessed audio tensor shape: {input_values.shape}")

        # Ensure correct tensor shape: [batch_size, sequence_length]
        if input_values.dim() == 1:
            input_values = input_values.unsqueeze(0)  # Add batch dimension if missing
        elif input_values.dim() == 3:
            # If we get [batch, sequence, channels], squeeze the channels
            input_values = input_values.squeeze(-1)

        logger.info(f"Final audio tensor shape: {input_values.shape}")

        # Move to device
        input_values = input_values.to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(input_values)
            probabilities = torch.softmax(outputs.logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

            # Get sentiment mapping based on number of classes
            sentiment_map = get_sentiment_mapping(num_classes)
            sentiment = sentiment_map[predicted.item()]
            confidence_score = confidence.item()

        logger.info(
            f"Audio sentiment analysis completed: {sentiment} (confidence: {confidence_score:.2f})"
        )
        return sentiment, confidence_score

    except ImportError as e:
        logger.error(f"Required library not installed: {str(e)}")
        st.error(f"Required library not installed: {str(e)}")
        st.info("Please install: pip install librosa transformers")
        return "Library not available", 0.0
    except Exception as e:
        logger.error(f"Error in audio sentiment prediction: {str(e)}")
        st.error(f"Error in audio sentiment prediction: {str(e)}")
        return "Error occurred", 0.0


def get_audio_model_info() -> dict:
    """Get information about the audio sentiment model."""
    return {
        "model_name": AUDIO_MODEL_CONFIG["model_name"],
        "description": "Fine-tuned Wav2Vec2 for audio sentiment analysis",
        "capabilities": [
            "Audio sentiment classification",
            "Automatic audio preprocessing",
            "CREMA-D + RAVDESS dataset compatibility",
            "Real-time audio analysis",
        ],
        "input_format": "Audio files (WAV, MP3, M4A, FLAC)",
        "output_format": "Sentiment label + confidence score",
        "preprocessing": {
            "sampling_rate": f"{AUDIO_MODEL_CONFIG['target_sampling_rate']} Hz",
            "max_duration": f"{AUDIO_MODEL_CONFIG['max_duration']} seconds",
            "feature_extraction": "AutoFeatureExtractor",
            "normalization": "Model-specific",
        },
    }
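

# Usage sketch: a minimal way to exercise the prediction API outside the
# Streamlit UI. "sample.wav" is a hypothetical path; in the app the bytes
# would typically come from an st.file_uploader widget. Note that the st.*
# calls above run in bare mode (with warnings) when executed outside Streamlit.
if __name__ == "__main__":
    with open("sample.wav", "rb") as f:
        audio = f.read()
    label, score = predict_audio_sentiment(audio)
    print(f"Sentiment: {label} (confidence: {score:.2f})")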