"""
Model service for handling VQA model operations
"""
import os
import logging
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor
from huggingface_hub import login
import requests
from huggingface_hub.utils import build_hf_headers
from app.config import settings
from app.models.vqa_model import VQAModel

logger = logging.getLogger(__name__)

class ModelService:
    """Service for loading and running the VQA model"""
    
    def __init__(self):
        """Initialize the model service"""
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.config = None
        self.answer_vocab = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        
        # Try to login to Hugging Face if token is provided
        if settings.HUGGINGFACE_TOKEN:
            try:
                login(token=settings.HUGGINGFACE_TOKEN)
                logger.info("Successfully logged in to Hugging Face Hub")
            except Exception as e:
                logger.error(f"Error logging in to Hugging Face Hub: {e}")
    
    def _check_model_exists(self):
        """Check if the model file exists locally"""
        return os.path.exists(settings.MODEL_PATH)
    
    def _download_model_from_hub(self):
        """Download the model from Hugging Face Hub if not present locally"""
        try:
            # Direct download using requests, streaming the file to disk
            logger.info("Downloading model from Hugging Face Hub")
            
            # Get Hugging Face token from settings
            token = settings.HUGGINGFACE_TOKEN
            
            # Build proper URL for the model file
            url = f"https://huggingface.co/{settings.HF_MODEL_REPO}/resolve/main/{settings.HF_MODEL_FILENAME}"
            logger.info(f"Downloading from URL: {url}")
            
            # Download with authentication headers, streaming the response
            headers = build_hf_headers(token=token)
            response = requests.get(url, headers=headers, stream=True, timeout=60)
            response.raise_for_status()
            
            # Make sure the target directory exists before writing
            os.makedirs(os.path.dirname(settings.MODEL_PATH) or ".", exist_ok=True)
            
            # Write the file in chunks to avoid holding it all in memory
            logger.info(f"Writing downloaded content to {settings.MODEL_PATH}")
            with open(settings.MODEL_PATH, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            
            logger.info(f"Model downloaded successfully")
            return True
                
        except Exception as e:
            logger.error(f"Error downloading model from Hugging Face Hub: {e}")
            return False
    
    def load_model(self):
        """Load the VQA model from the specified path or download it if not present"""
        try:
            # Check if model exists locally
            if not self._check_model_exists():
                logger.info(f"Model not found at {settings.MODEL_PATH}")
                
                # Download the model from Hugging Face Hub
                if not self._download_model_from_hub():
                    logger.error("Failed to download model from Hugging Face Hub")
                    return False
            
            logger.info(f"Loading model from {settings.MODEL_PATH}")
            # weights_only=False is needed on newer torch (>= 2.6) because the
            # checkpoint stores the config dict and answer vocabulary, not just tensors
            checkpoint = torch.load(settings.MODEL_PATH, map_location=self.device, weights_only=False)
            
            # Extract configuration
            self.config = checkpoint['config']
            
            # Get vocabulary
            if 'answer_vocab' in checkpoint:
                self.answer_vocab = checkpoint['answer_vocab']
                logger.info("Using vocabulary from model checkpoint")
            else:
                logger.error("Error: No vocabulary found in model checkpoint")
                raise ValueError("No vocabulary found in model checkpoint")
            
            # Initialize model
            self.model = VQAModel(self.config, len(self.answer_vocab['answer_to_idx']))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()
            
            # Initialize preprocessors
            self.processor = ViTImageProcessor.from_pretrained(self.config['vision_model'])
            self.tokenizer = AutoTokenizer.from_pretrained(self.config['text_model'])
            
            logger.info("Model loaded successfully")
            return True
            
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False
    
    def is_model_loaded(self):
        """Check if the model is loaded"""
        return self.model is not None and self.processor is not None and self.tokenizer is not None
    
    def predict(self, image_path, question):
        """
        Make a prediction for the given image and question
        
        Args:
            image_path (str): Path to the image file
            question (str): Question about the image
            
        Returns:
            dict: Prediction results with keys 'answer', 'answer_confidence',
                'is_answerable', and 'answerable_confidence'
        """
        if not self.is_model_loaded():
            logger.error("Model not loaded")
            raise RuntimeError("Model not loaded")
        
        try:
            # Preprocess image
            image = Image.open(image_path).convert('RGB')
            image_encoding = self.processor(images=image, return_tensors="pt")
            image_encoding = {k: v.to(self.device) for k, v in image_encoding.items()}
            
            # Preprocess question
            question_encoding = self.tokenizer(
                question,
                padding='max_length',
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )
            question_encoding = {k: v.to(self.device) for k, v in question_encoding.items()}
            
            # Get predictions
            with torch.no_grad():
                outputs = self.model(image_encoding, question_encoding)
                
                answer_logits = outputs['answer_logits']
                answerable_logits = outputs['answerable_logits']
                
                answer_idx = torch.argmax(answer_logits, dim=1).item()
                answerable_idx = torch.argmax(answerable_logits, dim=1).item()
                
                # Vocabulary keys are stored as strings, so cast the index to str
                answer = self.answer_vocab['idx_to_answer'][str(answer_idx)]
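                # Binary answerability head: index 1 is assumed to mean "answerable"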
                is_answerable = bool(answerable_idx)
                
                # Get confidence scores
                answer_probs = torch.softmax(answer_logits, dim=1)[0]
                answerable_probs = torch.softmax(answerable_logits, dim=1)[0]
                
                answer_confidence = float(answer_probs[answer_idx].item())
                answerable_confidence = float(answerable_probs[answerable_idx].item())
            
            return {
                'answer': answer,
                'answer_confidence': answer_confidence,
                'is_answerable': is_answerable,
                'answerable_confidence': answerable_confidence
            }
            
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            raise
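

# Minimal usage sketch (illustrative only): "example.jpg" and the question are
# placeholders, and this assumes the module is importable as part of the `app`
# package with settings pointing at a valid checkpoint or HF repo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    service = ModelService()
    if service.load_model():
        result = service.predict("example.jpg", "What color is the car?")
        print(result)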