File size: 5,436 Bytes
b1acf7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
"""
Vision sentiment analysis model using fine-tuned ResNet-50.
"""
import logging
from typing import Optional, Tuple

import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image

from ..config.settings import VISION_MODEL_CONFIG
from ..utils.preprocessing import detect_and_preprocess_face, get_vision_transforms
from ..utils.sentiment_mapping import get_sentiment_mapping
from src.utils.simple_model_manager import SimpleModelManager
logger = logging.getLogger(__name__)
@st.cache_resource
def get_model_manager():
    """Return the (cached) Google Drive model manager, or None if init fails.

    The failure is reported both to the application log and to the
    Streamlit UI before returning None.
    """
    try:
        return SimpleModelManager()
    except Exception as exc:
        message = f"Failed to initialize model manager: {exc}"
        logger.error(message)
        st.error(message)
        return None
@st.cache_resource
def load_vision_model():
    """Load the fine-tuned ResNet-50 sentiment model via the Drive manager.

    Returns:
        Tuple of (model, device, num_classes) on success, or
        (None, None, None) when the manager or the model is unavailable.
    """

    def _fail(message):
        # Surface the problem in both the log and the Streamlit UI.
        logger.error(message)
        st.error(message)
        return None, None, None

    try:
        manager = get_model_manager()
        if manager is None:
            return _fail("Model manager not available")
        # Delegate the actual download/deserialization to the manager.
        model, device, num_classes = manager.load_vision_model()
        if model is None:
            return _fail("Failed to load vision model from Google Drive")
        success_message = f"Vision model loaded successfully with {num_classes} classes!"
        logger.info(success_message)
        st.success(success_message)
        return model, device, num_classes
    except Exception as exc:
        return _fail(f"Error loading vision model: {str(exc)}")
def predict_vision_sentiment(
    image: Image.Image, crop_tightness: Optional[float] = None
) -> Tuple[str, float]:
    """
    Run ResNet-50 inference for vision sentiment analysis.

    Args:
        image: Input image (PIL Image or numpy array)
        crop_tightness: Padding around face (0.0 = no padding, 0.3 = 30% padding).
            When None, falls back to VISION_MODEL_CONFIG["crop_tightness"].

    Returns:
        Tuple of (sentiment, confidence). On any failure a descriptive
        message is returned as the sentiment with confidence 0.0 —
        this function never raises.
    """
    if image is None:
        return "No image provided", 0.0
    try:
        # Use default crop tightness if not specified
        if crop_tightness is None:
            crop_tightness = VISION_MODEL_CONFIG["crop_tightness"]
        # Load model if not already loaded (cached via st.cache_resource)
        model, device, num_classes = load_vision_model()
        if model is None:
            return "Model not loaded", 0.0
        # Preprocess image to match FER2013 format
        st.info(
            "Detecting face and preprocessing image to match training data format..."
        )
        preprocessed_image = detect_and_preprocess_face(
            image, crop_tightness=crop_tightness
        )
        if preprocessed_image is None:
            return "Image preprocessing failed", 0.0
        # Show preprocessed image so the user can verify the face crop
        st.image(
            preprocessed_image,
            caption="Preprocessed Image (224x224 Grayscale → 3-channel RGB)",
            width=200,
        )
        # Convert preprocessed image to a batched tensor on the model's device
        transform = get_vision_transforms()
        image_tensor = transform(preprocessed_image).unsqueeze(0).to(device)
        # Inference only — disable autograd bookkeeping
        with torch.no_grad():
            outputs = model(image_tensor)
            # Debug: print output shape
            st.info(f"Model output shape: {outputs.shape}")
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)
        # Map the predicted class index to a sentiment label
        sentiment_map = get_sentiment_mapping(num_classes)
        sentiment = sentiment_map[predicted.item()]
        confidence_score = confidence.item()
        logger.info(
            f"Vision sentiment analysis completed: {sentiment} (confidence: {confidence_score:.2f})"
        )
        return sentiment, confidence_score
    except Exception as e:
        # BUG FIX: the previous handler unconditionally referenced
        # num_classes in a second st.error message. If the exception was
        # raised before load_vision_model() assigned it, that raised a
        # NameError inside the handler and masked the real error. It also
        # claimed a "shape mismatch" for every error type, which was
        # misleading; report only the actual exception.
        logger.error(f"Error in vision sentiment prediction: {str(e)}")
        st.error(f"Error in vision sentiment prediction: {str(e)}")
        return "Error occurred", 0.0
def get_vision_model_info() -> dict:
    """Return a static description of the vision sentiment model.

    The dict covers the model's name, capabilities, accepted input
    formats, output format, and the preprocessing pipeline applied
    before inference.
    """
    # Hoist the repeated config lookup used in the preprocessing section.
    side = VISION_MODEL_CONFIG["input_size"]
    info = {
        "model_name": VISION_MODEL_CONFIG["model_name"],
        "description": "Fine-tuned ResNet-50 for facial expression sentiment analysis",
        "capabilities": [
            "Facial expression recognition",
            "Automatic face detection and cropping",
            "FER2013 dataset format compatibility",
            "Real-time image analysis",
        ],
        "input_format": "Images (PNG, JPG, JPEG, BMP, TIFF)",
        "output_format": "Sentiment label + confidence score",
        "preprocessing": {
            "face_detection": "OpenCV Haar Cascade",
            "image_size": f"{side}x{side}",
            "color_format": "Grayscale → 3-channel RGB",
            "normalization": "ImageNet standard",
        },
    }
    return info
|