dixisouls commited on
Commit
41b88b5
·
1 Parent(s): 5e6dd38

Changed file permissions

Browse files
Files changed (2) hide show
  1. Dockerfile +7 -1
  2. app/services/model_service.py +20 -113
Dockerfile CHANGED
@@ -2,6 +2,12 @@ FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
 
 
 
 
 
 
5
  # Copy requirements first for better caching
6
  COPY requirements.txt .
7
  RUN pip install --no-cache-dir -r requirements.txt
@@ -10,8 +16,8 @@ RUN pip install --no-cache-dir -r requirements.txt
10
  COPY app/ ./app/
11
  COPY .env .
12
 
13
- # Create required directories
14
  RUN mkdir -p uploads models
 
15
 
16
  # Set environment variables
17
  ENV PYTHONPATH=/app
 
2
 
3
  WORKDIR /app
4
 
5
+ ENV HF_HOME=/app/.cache
6
+ ENV TRANSFORMERS_CACHE=/app/.cache/huggingface
7
+ ENV TORCH_HOME=/app/.cache/torch
8
+ RUN mkdir -p /app/.cache/huggingface /app/.cache/torch
9
+ RUN chmod -R 777 /app/.cache
10
+
11
  # Copy requirements first for better caching
12
  COPY requirements.txt .
13
  RUN pip install --no-cache-dir -r requirements.txt
 
16
  COPY app/ ./app/
17
  COPY .env .
18
 
 
19
  RUN mkdir -p uploads models
20
+ RUN chmod -R 777 uploads models
21
 
22
  # Set environment variables
23
  ENV PYTHONPATH=/app
app/services/model_service.py CHANGED
@@ -27,6 +27,15 @@ class ModelService:
27
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  logger.info(f"Using device: {self.device}")
29
 
 
 
 
 
 
 
 
 
 
30
  # Try to login to Hugging Face if token is provided
31
  if settings.HUGGINGFACE_TOKEN:
32
  try:
@@ -45,19 +54,25 @@ class ModelService:
45
  # Create the directory if it doesn't exist
46
  os.makedirs(os.path.dirname(settings.MODEL_PATH), exist_ok=True)
47
 
48
- logger.info(f"Downloading model from {settings.HF_MODEL_REPO} to {settings.MODEL_PATH}")
 
 
 
49
 
50
  # Download the model file from Hugging Face
51
- hf_hub_download(
52
  repo_id=settings.HF_MODEL_REPO,
53
  filename=settings.HF_MODEL_FILENAME,
54
  local_dir=os.path.dirname(settings.MODEL_PATH),
55
- local_dir_use_symlinks=False
 
56
  )
57
 
 
 
58
  # Rename the downloaded file to match the expected path if needed
59
- downloaded_path = os.path.join(os.path.dirname(settings.MODEL_PATH), settings.HF_MODEL_FILENAME)
60
  if downloaded_path != settings.MODEL_PATH:
 
61
  os.rename(downloaded_path, settings.MODEL_PATH)
62
 
63
  logger.info(f"Model downloaded successfully to {settings.MODEL_PATH}")
@@ -66,112 +81,4 @@ class ModelService:
66
  logger.error(f"Error downloading model from Hugging Face Hub: {e}")
67
  return False
68
 
69
- def load_model(self):
70
- """Load the VQA model from the specified path or download it if not present"""
71
- try:
72
- # Check if model exists locally
73
- if not self._check_model_exists():
74
- logger.info(f"Model not found at {settings.MODEL_PATH}")
75
-
76
- # Download the model from Hugging Face Hub
77
- if not self._download_model_from_hub():
78
- logger.error("Failed to download model from Hugging Face Hub")
79
- return False
80
-
81
- logger.info(f"Loading model from {settings.MODEL_PATH}")
82
- checkpoint = torch.load(settings.MODEL_PATH, map_location=self.device)
83
-
84
- # Extract configuration
85
- self.config = checkpoint['config']
86
-
87
- # Get vocabulary
88
- if 'answer_vocab' in checkpoint:
89
- self.answer_vocab = checkpoint['answer_vocab']
90
- logger.info("Using vocabulary from model checkpoint")
91
- else:
92
- logger.error("Error: No vocabulary found in model checkpoint")
93
- raise ValueError("No vocabulary found in model checkpoint")
94
-
95
- # Initialize model
96
- self.model = VQAModel(self.config, len(self.answer_vocab['answer_to_idx']))
97
- self.model.load_state_dict(checkpoint['model_state_dict'])
98
- self.model.to(self.device)
99
- self.model.eval()
100
-
101
- # Initialize preprocessors
102
- self.processor = ViTImageProcessor.from_pretrained(self.config['vision_model'])
103
- self.tokenizer = AutoTokenizer.from_pretrained(self.config['text_model'])
104
-
105
- logger.info("Model loaded successfully")
106
- return True
107
-
108
- except Exception as e:
109
- logger.error(f"Error loading model: {e}")
110
- return False
111
-
112
- def is_model_loaded(self):
113
- """Check if the model is loaded"""
114
- return self.model is not None and self.processor is not None and self.tokenizer is not None
115
-
116
- def predict(self, image_path, question):
117
- """
118
- Make a prediction for the given image and question
119
-
120
- Args:
121
- image_path (str): Path to the image file
122
- question (str): Question about the image
123
-
124
- Returns:
125
- dict: Prediction results
126
- """
127
- if not self.is_model_loaded():
128
- logger.error("Model not loaded")
129
- raise RuntimeError("Model not loaded")
130
-
131
- try:
132
- # Preprocess image
133
- image = Image.open(image_path).convert('RGB')
134
- image_encoding = self.processor(images=image, return_tensors="pt")
135
- image_encoding = {k: v.to(self.device) for k, v in image_encoding.items()}
136
-
137
- # Preprocess question
138
- question_encoding = self.tokenizer(
139
- question,
140
- padding='max_length',
141
- truncation=True,
142
- max_length=128,
143
- return_tensors='pt'
144
- )
145
- question_encoding = {k: v.to(self.device) for k, v in question_encoding.items()}
146
-
147
- # Get predictions
148
- with torch.no_grad():
149
- outputs = self.model(image_encoding, question_encoding)
150
-
151
- answer_logits = outputs['answer_logits']
152
- answerable_logits = outputs['answerable_logits']
153
-
154
- answer_idx = torch.argmax(answer_logits, dim=1).item()
155
- answerable_idx = torch.argmax(answerable_logits, dim=1).item()
156
-
157
- # Convert string index to int for dictionary lookup
158
- answer = self.answer_vocab['idx_to_answer'][str(answer_idx)]
159
- is_answerable = bool(answerable_idx)
160
-
161
- # Get confidence scores
162
- answer_probs = torch.softmax(answer_logits, dim=1)[0]
163
- answerable_probs = torch.softmax(answerable_logits, dim=1)[0]
164
-
165
- answer_confidence = float(answer_probs[answer_idx].item())
166
- answerable_confidence = float(answerable_probs[answerable_idx].item())
167
-
168
- return {
169
- 'answer': answer,
170
- 'answer_confidence': answer_confidence,
171
- 'is_answerable': is_answerable,
172
- 'answerable_confidence': answerable_confidence
173
- }
174
-
175
- except Exception as e:
176
- logger.error(f"Error during prediction: {e}")
177
- raise
 
27
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  logger.info(f"Using device: {self.device}")
29
 
30
+ # Set up cache directories
31
+ os.environ["HF_HOME"] = os.environ.get("HF_HOME", "/app/.cache")
32
+ os.environ["TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "/app/.cache/huggingface")
33
+ os.environ["TORCH_HOME"] = os.environ.get("TORCH_HOME", "/app/.cache/torch")
34
+
35
+ # Ensure cache directories exist and have proper permissions
36
+ for cache_dir in [os.environ["HF_HOME"], os.environ["TRANSFORMERS_CACHE"], os.environ["TORCH_HOME"]]:
37
+ os.makedirs(cache_dir, exist_ok=True)
38
+
39
  # Try to login to Hugging Face if token is provided
40
  if settings.HUGGINGFACE_TOKEN:
41
  try:
 
54
  # Create the directory if it doesn't exist
55
  os.makedirs(os.path.dirname(settings.MODEL_PATH), exist_ok=True)
56
 
57
+ # Log cache directories for debugging
58
+ logger.info(f"HF_HOME: {os.environ.get('HF_HOME')}")
59
+ logger.info(f"TRANSFORMERS_CACHE: {os.environ.get('TRANSFORMERS_CACHE')}")
60
+ logger.info(f"Using model repo: {settings.HF_MODEL_REPO}")
61
 
62
  # Download the model file from Hugging Face
63
+ downloaded_path = hf_hub_download(
64
  repo_id=settings.HF_MODEL_REPO,
65
  filename=settings.HF_MODEL_FILENAME,
66
  local_dir=os.path.dirname(settings.MODEL_PATH),
67
+ local_dir_use_symlinks=False,
68
+ cache_dir=os.environ.get("HF_HOME")
69
  )
70
 
71
+ logger.info(f"Model downloaded to: {downloaded_path}")
72
+
73
  # Rename the downloaded file to match the expected path if needed
 
74
  if downloaded_path != settings.MODEL_PATH:
75
+ logger.info(f"Renaming {downloaded_path} to {settings.MODEL_PATH}")
76
  os.rename(downloaded_path, settings.MODEL_PATH)
77
 
78
  logger.info(f"Model downloaded successfully to {settings.MODEL_PATH}")
 
81
  logger.error(f"Error downloading model from Hugging Face Hub: {e}")
82
  return False
83
 
84
+ # ... rest of the ModelService class remains unchanged