|
""" |
|
Vision sentiment analysis model using fine-tuned ResNet-50. |
|
""" |
|
|
|
import logging |
|
import streamlit as st |
|
from typing import Tuple |
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
|
|
from ..config.settings import VISION_MODEL_CONFIG |
|
from ..utils.preprocessing import detect_and_preprocess_face, get_vision_transforms |
|
from ..utils.sentiment_mapping import get_sentiment_mapping |
|
from src.utils.simple_model_manager import SimpleModelManager |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@st.cache_resource
def get_model_manager():
    """Return a cached SimpleModelManager instance, or None if construction fails.

    Errors are surfaced to both the log and the Streamlit UI rather than raised,
    so the app keeps running with a degraded state.
    """
    try:
        return SimpleModelManager()
    except Exception as exc:
        # Report once to each channel; caller is expected to check for None.
        logger.error(f"Failed to initialize model manager: {exc}")
        st.error(f"Failed to initialize model manager: {exc}")
        return None
|
|
|
|
|
@st.cache_resource
def load_vision_model():
    """Load the pre-trained ResNet-50 vision sentiment model from Google Drive.

    Returns:
        (model, device, num_classes) on success; (None, None, None) on any
        failure. Failures are logged and shown in the Streamlit UI instead of
        raising.
    """
    failure = (None, None, None)

    def _report(msg):
        # Mirror every failure to the log and the UI.
        logger.error(msg)
        st.error(msg)

    try:
        manager = get_model_manager()
        if manager is None:
            _report("Model manager not available")
            return failure

        model, device, num_classes = manager.load_vision_model()
        if model is None:
            _report("Failed to load vision model from Google Drive")
            return failure

        logger.info(f"Vision model loaded successfully with {num_classes} classes!")
        st.success(f"Vision model loaded successfully with {num_classes} classes!")
        return model, device, num_classes
    except Exception as e:
        _report(f"Error loading vision model: {str(e)}")
        return failure
|
|
|
|
|
def predict_vision_sentiment(
    image: Image.Image, crop_tightness: "float | None" = None
) -> Tuple[str, float]:
    """
    Load ResNet-50 and run inference for vision sentiment analysis.

    Args:
        image: Input image (PIL Image).
        crop_tightness: Padding around the detected face (0.0 = no padding,
            0.3 = 30% padding). Defaults to VISION_MODEL_CONFIG["crop_tightness"].

    Returns:
        Tuple of (sentiment, confidence). On failure, a descriptive message
        and 0.0 are returned instead of raising, so the UI can display it.
    """
    if image is None:
        return "No image provided", 0.0

    # Pre-initialize so the except handler can safely inspect it even when the
    # exception fires before the model is loaded. The original code referenced
    # `num_classes` unconditionally in the handler, raising NameError in that case.
    num_classes = None
    try:
        if crop_tightness is None:
            crop_tightness = VISION_MODEL_CONFIG["crop_tightness"]

        model, device, num_classes = load_vision_model()
        if model is None:
            return "Model not loaded", 0.0

        st.info(
            "Detecting face and preprocessing image to match training data format..."
        )
        preprocessed_image = detect_and_preprocess_face(
            image, crop_tightness=crop_tightness
        )

        if preprocessed_image is None:
            return "Image preprocessing failed", 0.0

        # Show the user exactly what the model will see.
        st.image(
            preprocessed_image,
            caption="Preprocessed Image (224x224 Grayscale → 3-channel RGB)",
            width=200,
        )

        transform = get_vision_transforms()
        image_tensor = transform(preprocessed_image).unsqueeze(0).to(device)

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = model(image_tensor)

        st.info(f"Model output shape: {outputs.shape}")

        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

        sentiment_map = get_sentiment_mapping(num_classes)
        predicted_idx = predicted.item()
        if predicted_idx not in sentiment_map:
            # The model head produced an index the mapping does not cover —
            # a genuine class-count mismatch. Report it explicitly instead of
            # letting a KeyError fall through to the generic handler.
            mismatch_msg = (
                f"Model output shape mismatch. Expected {num_classes} classes but got different."
            )
            logger.error(mismatch_msg)
            st.error(mismatch_msg)
            return "Error occurred", 0.0

        sentiment = sentiment_map[predicted_idx]
        confidence_score = confidence.item()

        logger.info(
            f"Vision sentiment analysis completed: {sentiment} (confidence: {confidence_score:.2f})"
        )
        return sentiment, confidence_score

    except Exception as e:
        logger.error(f"Error in vision sentiment prediction: {str(e)}")
        st.error(f"Error in vision sentiment prediction: {str(e)}")
        # Only suggest a shape mismatch when the model actually loaded; the
        # original code claimed a mismatch for every exception, which was
        # misleading when the failure was in preprocessing or model loading.
        if num_classes is not None:
            st.error(
                f"Model output shape mismatch. Expected {num_classes} classes but got different."
            )
        return "Error occurred", 0.0
|
|
|
|
|
def get_vision_model_info() -> dict:
    """Return a static description of the vision sentiment model and its pipeline."""
    size = VISION_MODEL_CONFIG["input_size"]
    preprocessing = {
        "face_detection": "OpenCV Haar Cascade",
        "image_size": f"{size}x{size}",
        "color_format": "Grayscale → 3-channel RGB",
        "normalization": "ImageNet standard",
    }
    capabilities = [
        "Facial expression recognition",
        "Automatic face detection and cropping",
        "FER2013 dataset format compatibility",
        "Real-time image analysis",
    ]
    return {
        "model_name": VISION_MODEL_CONFIG["model_name"],
        "description": "Fine-tuned ResNet-50 for facial expression sentiment analysis",
        "capabilities": capabilities,
        "input_format": "Images (PNG, JPG, JPEG, BMP, TIFF)",
        "output_format": "Sentiment label + confidence score",
        "preprocessing": preprocessing,
    }
|
|