# app/services/model_service.py
"""
Model service for handling VQA model operations
"""
import os
import json
import logging
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor
from huggingface_hub import login
import requests
from huggingface_hub.utils import build_hf_headers
from app.config import settings
from app.models.vqa_model import VQAModel
logger = logging.getLogger(__name__)
class ModelService:
    """Service for loading and running the VQA model"""

    def __init__(self):
        """Initialize the model service"""
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.config = None
        self.answer_vocab = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Try to log in to Hugging Face if a token is provided
        if settings.HUGGINGFACE_TOKEN:
            try:
                login(token=settings.HUGGINGFACE_TOKEN)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Error logging in to Hugging Face Hub: {e}")
    def _check_model_exists(self):
        """Check if the model file exists locally"""
        return os.path.exists(settings.MODEL_PATH)
    def _download_model_from_hub(self):
        """Download the model from Hugging Face Hub if not present locally"""
        try:
            # Direct download over plain HTTP using requests
            logger.info("Downloading model from Hugging Face Hub")

            # Get Hugging Face token from settings
            token = settings.HUGGINGFACE_TOKEN

            # Build the URL for the model file
            url = f"https://huggingface.co/{settings.HF_MODEL_REPO}/resolve/main/{settings.HF_MODEL_FILENAME}"
            logger.info(f"Downloading from URL: {url}")

            # Download with proper auth headers
            headers = build_hf_headers(token=token)
            response = requests.get(url, headers=headers, stream=True)
            response.raise_for_status()

            # Make sure the target directory exists before writing
            os.makedirs(os.path.dirname(settings.MODEL_PATH) or ".", exist_ok=True)

            # Write the file in chunks to avoid memory issues
            logger.info(f"Writing downloaded content to {settings.MODEL_PATH}")
            with open(settings.MODEL_PATH, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

            logger.info("Model downloaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error downloading model from Hugging Face Hub: {e}")
            return False
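    # NOTE: per the commit message, the model is deliberately fetched over
    # plain HTTP. A hub-client alternative (a sketch only, assuming the same
    # settings fields) would look like:
    #
    #   from huggingface_hub import hf_hub_download
    #   path = hf_hub_download(
    #       repo_id=settings.HF_MODEL_REPO,
    #       filename=settings.HF_MODEL_FILENAME,
    #       token=settings.HUGGINGFACE_TOKEN,
    #   )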
    def load_model(self):
        """Load the VQA model from the specified path, downloading it if not present"""
        try:
            # Check if the model exists locally
            if not self._check_model_exists():
                logger.info(f"Model not found at {settings.MODEL_PATH}")
                # Download the model from Hugging Face Hub
                if not self._download_model_from_hub():
                    logger.error("Failed to download model from Hugging Face Hub")
                    return False

            logger.info(f"Loading model from {settings.MODEL_PATH}")
            # weights_only=False because the checkpoint stores plain Python
            # objects (config and vocab dicts), which newer PyTorch versions
            # refuse to unpickle by default
            checkpoint = torch.load(settings.MODEL_PATH, map_location=self.device, weights_only=False)

            # Extract configuration
            self.config = checkpoint['config']

            # Get vocabulary
            if 'answer_vocab' in checkpoint:
                self.answer_vocab = checkpoint['answer_vocab']
                logger.info("Using vocabulary from model checkpoint")
            else:
                logger.error("No vocabulary found in model checkpoint")
                raise ValueError("No vocabulary found in model checkpoint")

            # Initialize model
            self.model = VQAModel(self.config, len(self.answer_vocab['answer_to_idx']))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            # Initialize preprocessors
            self.processor = ViTImageProcessor.from_pretrained(self.config['vision_model'])
            self.tokenizer = AutoTokenizer.from_pretrained(self.config['text_model'])

            logger.info("Model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False
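    # The checkpoint is assumed (inferred from the keys accessed above, not
    # from a documented format) to be a dict shaped roughly like:
    #
    #   {
    #       'config': {'vision_model': ..., 'text_model': ..., ...},
    #       'model_state_dict': <VQAModel state dict>,
    #       'answer_vocab': {
    #           'answer_to_idx': {answer: index, ...},
    #           'idx_to_answer': {str(index): answer, ...},
    #       },
    #   }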
    def is_model_loaded(self):
        """Check if the model is loaded"""
        return self.model is not None and self.processor is not None and self.tokenizer is not None
    def predict(self, image_path, question):
        """
        Make a prediction for the given image and question

        Args:
            image_path (str): Path to the image file
            question (str): Question about the image

        Returns:
            dict: Prediction results
        """
        if not self.is_model_loaded():
            logger.error("Model not loaded")
            raise RuntimeError("Model not loaded")

        try:
            # Preprocess image
            image = Image.open(image_path).convert('RGB')
            image_encoding = self.processor(images=image, return_tensors="pt")
            image_encoding = {k: v.to(self.device) for k, v in image_encoding.items()}

            # Preprocess question
            question_encoding = self.tokenizer(
                question,
                padding='max_length',
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )
            question_encoding = {k: v.to(self.device) for k, v in question_encoding.items()}

            # Get predictions
            with torch.no_grad():
                outputs = self.model(image_encoding, question_encoding)
                answer_logits = outputs['answer_logits']
                answerable_logits = outputs['answerable_logits']
                answer_idx = torch.argmax(answer_logits, dim=1).item()
                answerable_idx = torch.argmax(answerable_logits, dim=1).item()

            # idx_to_answer keys are strings, so convert the index for lookup
            answer = self.answer_vocab['idx_to_answer'][str(answer_idx)]
            is_answerable = bool(answerable_idx)

            # Get confidence scores
            answer_probs = torch.softmax(answer_logits, dim=1)[0]
            answerable_probs = torch.softmax(answerable_logits, dim=1)[0]
            answer_confidence = float(answer_probs[answer_idx].item())
            answerable_confidence = float(answerable_probs[answerable_idx].item())

            return {
                'answer': answer,
                'answer_confidence': answer_confidence,
                'is_answerable': is_answerable,
                'answerable_confidence': answerable_confidence
            }
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            raise
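
if __name__ == "__main__":
    # Minimal smoke test (a sketch only; the image path and question are
    # hypothetical, and the app presumably calls this service from its API
    # routes rather than from the command line):
    #
    #   python -m app.services.model_service path/to/image.jpg "What is this?"
    import sys

    service = ModelService()
    if not service.load_model():
        sys.exit("Failed to load model")
    print(service.predict(sys.argv[1], sys.argv[2]))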