Changed dockerfile

- Dockerfile +2 -0
- app/services/model_service.py +113 -20
Dockerfile
CHANGED
@@ -2,6 +2,7 @@ FROM python:3.9-slim
 
 WORKDIR /app
 
+# Create a user-writable cache directory
 ENV HF_HOME=/app/.cache
 ENV TRANSFORMERS_CACHE=/app/.cache/huggingface
 ENV TORCH_HOME=/app/.cache/torch
@@ -16,6 +17,7 @@ RUN pip install --no-cache-dir -r requirements.txt
 COPY app/ ./app/
 COPY .env .
 
+# Create required directories with proper permissions
 RUN mkdir -p uploads models
 RUN chmod -R 777 uploads models
 
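Note: the ENV lines above point the Hugging Face and Torch caches at /app/.cache, while the runtime os.makedirs calls for those paths are removed from model_service.py below. If nothing else in the image creates that directory, a small startup check can create it and fail fast when it is not writable. The sketch below is illustrative only and is not part of this commit; the paths simply mirror the ENV defaults above.

# Hypothetical startup check (not in this commit): make sure the cache
# directories referenced by the Dockerfile ENV values exist and are writable.
import os

def ensure_cache_dirs():
    for var, default in [
        ("HF_HOME", "/app/.cache"),
        ("TRANSFORMERS_CACHE", "/app/.cache/huggingface"),
        ("TORCH_HOME", "/app/.cache/torch"),
    ]:
        path = os.environ.get(var, default)
        os.makedirs(path, exist_ok=True)       # create the directory if missing
        if not os.access(path, os.W_OK):       # fail fast on read-only mounts
            raise PermissionError(f"{var} points to a non-writable path: {path}")

if __name__ == "__main__":
    ensure_cache_dirs()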
app/services/model_service.py
CHANGED
@@ -27,15 +27,6 @@ class ModelService:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {self.device}")
 
-        # Set up cache directories
-        os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/app/.cache")
-        os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/app/.cache/huggingface")
-        os.environ["TORCH_HOME"] = os.environ.get("TORCH_HOME", "/app/.cache/torch")
-
-        # Ensure cache directories exist and have proper permissions
-        for cache_dir in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["TORCH_HOME"]]:
-            os.makedirs(cache_dir, exist_ok=True)
-
         # Try to login to Hugging Face if token is provided
         if settings.HUGGINGFACE_TOKEN:
             try:
@@ -54,25 +45,19 @@ class ModelService:
             # Create the directory if it doesn't exist
             os.makedirs(os.path.dirname(settings.MODEL_PATH), exist_ok=True)
 
-
-            logger.info(f"HF_HOME: {os.environ.get('HF_HOME')}")
-            logger.info(f"TRANSFORMERS_CACHE: {os.environ.get('TRANSFORMERS_CACHE')}")
-            logger.info(f"Using model repo: {settings.HF_MODEL_REPO}")
+            logger.info(f"Downloading model from {settings.HF_MODEL_REPO} to {settings.MODEL_PATH}")
 
             # Download the model file from Hugging Face
-            downloaded_path = hf_hub_download(
+            hf_hub_download(
                 repo_id=settings.HF_MODEL_REPO,
                 filename=settings.HF_MODEL_FILENAME,
                 local_dir=os.path.dirname(settings.MODEL_PATH),
-                local_dir_use_symlinks=False,
-                cache_dir=os.environ.get("HF_HOME")
+                local_dir_use_symlinks=False
             )
 
-            logger.info(f"Model downloaded to: {downloaded_path}")
-
             # Rename the downloaded file to match the expected path if needed
+            downloaded_path = os.path.join(os.path.dirname(settings.MODEL_PATH), settings.HF_MODEL_FILENAME)
             if downloaded_path != settings.MODEL_PATH:
-                logger.info(f"Renaming {downloaded_path} to {settings.MODEL_PATH}")
                 os.rename(downloaded_path, settings.MODEL_PATH)
 
             logger.info(f"Model downloaded successfully to {settings.MODEL_PATH}")
@@ -81,4 +66,112 @@ class ModelService:
             logger.error(f"Error downloading model from Hugging Face Hub: {e}")
             return False
 
-
+    def load_model(self):
+        """Load the VQA model from the specified path or download it if not present"""
+        try:
+            # Check if model exists locally
+            if not self._check_model_exists():
+                logger.info(f"Model not found at {settings.MODEL_PATH}")
+
+                # Download the model from Hugging Face Hub
+                if not self._download_model_from_hub():
+                    logger.error("Failed to download model from Hugging Face Hub")
+                    return False
+
+            logger.info(f"Loading model from {settings.MODEL_PATH}")
+            checkpoint = torch.load(settings.MODEL_PATH, map_location=self.device)
+
+            # Extract configuration
+            self.config = checkpoint['config']
+
+            # Get vocabulary
+            if 'answer_vocab' in checkpoint:
+                self.answer_vocab = checkpoint['answer_vocab']
+                logger.info("Using vocabulary from model checkpoint")
+            else:
+                logger.error("Error: No vocabulary found in model checkpoint")
+                raise ValueError("No vocabulary found in model checkpoint")
+
+            # Initialize model
+            self.model = VQAModel(self.config, len(self.answer_vocab['answer_to_idx']))
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+            self.model.to(self.device)
+            self.model.eval()
+
+            # Initialize preprocessors
+            self.processor = ViTImageProcessor.from_pretrained(self.config['vision_model'])
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config['text_model'])
+
+            logger.info("Model loaded successfully")
+            return True
+
+        except Exception as e:
+            logger.error(f"Error loading model: {e}")
+            return False
+
+    def is_model_loaded(self):
+        """Check if the model is loaded"""
+        return self.model is not None and self.processor is not None and self.tokenizer is not None
+
+    def predict(self, image_path, question):
+        """
+        Make a prediction for the given image and question
+
+        Args:
+            image_path (str): Path to the image file
+            question (str): Question about the image
+
+        Returns:
+            dict: Prediction results
+        """
+        if not self.is_model_loaded():
+            logger.error("Model not loaded")
+            raise RuntimeError("Model not loaded")
+
+        try:
+            # Preprocess image
+            image = Image.open(image_path).convert('RGB')
+            image_encoding = self.processor(images=image, return_tensors="pt")
+            image_encoding = {k: v.to(self.device) for k, v in image_encoding.items()}
+
+            # Preprocess question
+            question_encoding = self.tokenizer(
+                question,
+                padding='max_length',
+                truncation=True,
+                max_length=128,
+                return_tensors='pt'
+            )
+            question_encoding = {k: v.to(self.device) for k, v in question_encoding.items()}
+
+            # Get predictions
+            with torch.no_grad():
+                outputs = self.model(image_encoding, question_encoding)
+
+            answer_logits = outputs['answer_logits']
+            answerable_logits = outputs['answerable_logits']
+
+            answer_idx = torch.argmax(answer_logits, dim=1).item()
+            answerable_idx = torch.argmax(answerable_logits, dim=1).item()
+
+            # Convert the integer index to a string key for the vocabulary lookup
+            answer = self.answer_vocab['idx_to_answer'][str(answer_idx)]
+            is_answerable = bool(answerable_idx)
+
+            # Get confidence scores
+            answer_probs = torch.softmax(answer_logits, dim=1)[0]
+            answerable_probs = torch.softmax(answerable_logits, dim=1)[0]
+
+            answer_confidence = float(answer_probs[answer_idx].item())
+            answerable_confidence = float(answerable_probs[answerable_idx].item())
+
+            return {
+                'answer': answer,
+                'answer_confidence': answer_confidence,
+                'is_answerable': is_answerable,
+                'answerable_confidence': answerable_confidence
+            }
+
+        except Exception as e:
+            logger.error(f"Error during prediction: {e}")
+            raise
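For orientation, here is a minimal usage sketch of the methods added above. It is not part of the commit: the zero-argument ModelService() construction, the example image path, and the question text are assumptions; the settings values (MODEL_PATH, HF_MODEL_REPO, HF_MODEL_FILENAME) come from the project's configuration.

# Hypothetical usage sketch (not in this commit) exercising the new methods.
from app.services.model_service import ModelService

service = ModelService()          # device selection happens in __init__ (CUDA if available)
if not service.load_model():      # downloads the checkpoint from the Hub on first run
    raise SystemExit("Model could not be loaded")

# 'uploads/example.jpg' and the question are placeholder inputs.
result = service.predict("uploads/example.jpg", "What color is the car?")
print(result['answer'], result['answer_confidence'])
print("answerable:", result['is_answerable'], result['answerable_confidence'])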