dixisouls committed
Commit a7f1017 · 1 Parent(s): 41b88b5

Changed dockerfile

Files changed (2):
  1. Dockerfile +2 -0
  2. app/services/model_service.py +113 -20
Dockerfile CHANGED
@@ -2,6 +2,7 @@ FROM python:3.9-slim
 
 WORKDIR /app
 
+# Create a user-writable cache directory
 ENV HF_HOME=/app/.cache
 ENV TRANSFORMERS_CACHE=/app/.cache/huggingface
 ENV TORCH_HOME=/app/.cache/torch
@@ -16,6 +17,7 @@ RUN pip install --no-cache-dir -r requirements.txt
 COPY app/ ./app/
 COPY .env .
 
+# Create required directories with proper permissions
 RUN mkdir -p uploads models
 RUN chmod -R 777 uploads models
 
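The relocated cache matters because the default Hugging Face cache under the user's home directory is often not writable by the non-root user a container runs as. For a quick startup sanity check, a minimal sketch like the following (not part of this commit) verifies that the directories are writable; the fallback paths are assumptions mirroring the ENV lines above:

import os

# Fallback paths mirror the Dockerfile ENV lines above (assumed defaults).
cache_dirs = [
    os.environ.get("HF_HOME", "/app/.cache"),
    os.environ.get("TRANSFORMERS_CACHE", "/app/.cache/huggingface"),
    os.environ.get("TORCH_HOME", "/app/.cache/torch"),
]

for cache_dir in cache_dirs:
    os.makedirs(cache_dir, exist_ok=True)  # no-op when the image already created it
    if not os.access(cache_dir, os.W_OK):
        raise RuntimeError(f"Cache directory is not writable: {cache_dir}")
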
app/services/model_service.py CHANGED
@@ -27,15 +27,6 @@ class ModelService:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {self.device}")
 
-        # Set up cache directories
-        os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/app/.cache")
-        os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/app/.cache/huggingface")
-        os.environ["TORCH_HOME"] = os.environ.get("TORCH_HOME", "/app/.cache/torch")
-
-        # Ensure cache directories exist and have proper permissions
-        for cache_dir in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["TORCH_HOME"]]:
-            os.makedirs(cache_dir, exist_ok=True)
-
         # Try to login to Hugging Face if token is provided
         if settings.HUGGINGFACE_TOKEN:
             try:
@@ -54,25 +45,19 @@ class ModelService:
             # Create the directory if it doesn't exist
             os.makedirs(os.path.dirname(settings.MODEL_PATH), exist_ok=True)
 
-            # Log cache directories for debugging
-            logger.info(f"HF_HOME: {os.environ.get('HF_HOME')}")
-            logger.info(f"TRANSFORMERS_CACHE: {os.environ.get('TRANSFORMERS_CACHE')}")
-            logger.info(f"Using model repo: {settings.HF_MODEL_REPO}")
+            logger.info(f"Downloading model from {settings.HF_MODEL_REPO} to {settings.MODEL_PATH}")
 
             # Download the model file from Hugging Face
-            downloaded_path = hf_hub_download(
+            hf_hub_download(
                 repo_id=settings.HF_MODEL_REPO,
                 filename=settings.HF_MODEL_FILENAME,
                 local_dir=os.path.dirname(settings.MODEL_PATH),
-                local_dir_use_symlinks=False,
-                cache_dir=os.environ.get("HF_HOME")
+                local_dir_use_symlinks=False
             )
 
-            logger.info(f"Model downloaded to: {downloaded_path}")
-
             # Rename the downloaded file to match the expected path if needed
+            downloaded_path = os.path.join(os.path.dirname(settings.MODEL_PATH), settings.HF_MODEL_FILENAME)
             if downloaded_path != settings.MODEL_PATH:
-                logger.info(f"Renaming {downloaded_path} to {settings.MODEL_PATH}")
                 os.rename(downloaded_path, settings.MODEL_PATH)
 
             logger.info(f"Model downloaded successfully to {settings.MODEL_PATH}")
@@ -81,4 +66,112 @@ class ModelService:
             logger.error(f"Error downloading model from Hugging Face Hub: {e}")
             return False
 
-    # ... rest of the ModelService class remains unchanged
+    def load_model(self):
+        """Load the VQA model from the specified path or download it if not present"""
+        try:
+            # Check if model exists locally
+            if not self._check_model_exists():
+                logger.info(f"Model not found at {settings.MODEL_PATH}")
+
+                # Download the model from Hugging Face Hub
+                if not self._download_model_from_hub():
+                    logger.error("Failed to download model from Hugging Face Hub")
+                    return False
+
+            logger.info(f"Loading model from {settings.MODEL_PATH}")
+            checkpoint = torch.load(settings.MODEL_PATH, map_location=self.device)
+
+            # Extract configuration
+            self.config = checkpoint['config']
+
+            # Get vocabulary
+            if 'answer_vocab' in checkpoint:
+                self.answer_vocab = checkpoint['answer_vocab']
+                logger.info("Using vocabulary from model checkpoint")
+            else:
+                logger.error("Error: No vocabulary found in model checkpoint")
+                raise ValueError("No vocabulary found in model checkpoint")
+
+            # Initialize model
+            self.model = VQAModel(self.config, len(self.answer_vocab['answer_to_idx']))
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+            self.model.to(self.device)
+            self.model.eval()
+
+            # Initialize preprocessors
+            self.processor = ViTImageProcessor.from_pretrained(self.config['vision_model'])
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config['text_model'])
+
+            logger.info("Model loaded successfully")
+            return True
+
+        except Exception as e:
+            logger.error(f"Error loading model: {e}")
+            return False
+
+    def is_model_loaded(self):
+        """Check if the model is loaded"""
+        return self.model is not None and self.processor is not None and self.tokenizer is not None
+
+    def predict(self, image_path, question):
+        """
+        Make a prediction for the given image and question
+
+        Args:
+            image_path (str): Path to the image file
+            question (str): Question about the image
+
+        Returns:
+            dict: Prediction results
+        """
+        if not self.is_model_loaded():
+            logger.error("Model not loaded")
+            raise RuntimeError("Model not loaded")
+
+        try:
+            # Preprocess image
+            image = Image.open(image_path).convert('RGB')
+            image_encoding = self.processor(images=image, return_tensors="pt")
+            image_encoding = {k: v.to(self.device) for k, v in image_encoding.items()}
+
+            # Preprocess question
+            question_encoding = self.tokenizer(
+                question,
+                padding='max_length',
+                truncation=True,
+                max_length=128,
+                return_tensors='pt'
+            )
+            question_encoding = {k: v.to(self.device) for k, v in question_encoding.items()}
+
+            # Get predictions
+            with torch.no_grad():
+                outputs = self.model(image_encoding, question_encoding)
+
+            answer_logits = outputs['answer_logits']
+            answerable_logits = outputs['answerable_logits']
+
+            answer_idx = torch.argmax(answer_logits, dim=1).item()
+            answerable_idx = torch.argmax(answerable_logits, dim=1).item()
+
+            # Vocabulary keys are strings, so convert the integer index for lookup
+            answer = self.answer_vocab['idx_to_answer'][str(answer_idx)]
+            is_answerable = bool(answerable_idx)
+
+            # Get confidence scores
+            answer_probs = torch.softmax(answer_logits, dim=1)[0]
+            answerable_probs = torch.softmax(answerable_logits, dim=1)[0]
+
+            answer_confidence = float(answer_probs[answer_idx].item())
+            answerable_confidence = float(answerable_probs[answerable_idx].item())
+
+            return {
+                'answer': answer,
+                'answer_confidence': answer_confidence,
+                'is_answerable': is_answerable,
+                'answerable_confidence': answerable_confidence
+            }
+
+        except Exception as e:
+            logger.error(f"Error during prediction: {e}")
+            raise
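
Together, the new methods give the service a lazy, self-healing load path: load_model() downloads the checkpoint on first use, and predict() refuses to run until loading has succeeded. A minimal usage sketch, assuming ModelService() takes no constructor arguments (the diff shows only part of __init__) and using a hypothetical image path and question:

from app.services.model_service import ModelService

service = ModelService()

# Downloads the checkpoint from the Hub if missing, then loads weights,
# vocabulary, and the ViT/tokenizer preprocessors.
if not service.load_model():
    raise SystemExit("Model failed to load")

# Both the file path and the question here are placeholders.
result = service.predict("uploads/example.jpg", "What color is the car?")
print(result["answer"], f"({result['answer_confidence']:.2f})")
print("answerable:", result["is_answerable"], f"({result['answerable_confidence']:.2f})")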