File size: 15,467 Bytes
d919881
 
 
 
 
 
 
 
59d4479
d919881
59d4479
d919881
 
 
 
 
 
 
 
 
b1acf7e
d919881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472585e
d919881
 
472585e
d919881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472585e
d919881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472585e
d919881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472585e
d919881
472585e
d919881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472585e
 
d919881
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import os
import gdown
from pathlib import Path
import logging
from typing import Tuple, Any
import torch
import torch.nn as nn
from torchvision import models
from dotenv import load_dotenv

# Load settings from a .env file into the process environment before anything
# below reads them; SimpleModelManager looks up its model IDs and filenames
# with os.getenv.
load_dotenv()

# Configure logging
# NOTE(review): basicConfig runs at import time, so importing this module
# configures the root logger for the importing application as well (it is a
# no-op when the root logger already has handlers) - confirm this is intended.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by every method of SimpleModelManager.
logger = logging.getLogger(__name__)


class SimpleModelManager:
    """Simple model manager that downloads models from Google Drive using gdown

    The Google Drive source and the local filename of each model are read from
    environment variables (``VISION_MODEL_DRIVE_ID`` / ``VISION_MODEL_FILENAME``
    and ``AUDIO_MODEL_DRIVE_ID`` / ``AUDIO_MODEL_FILENAME``). Downloaded files
    are stored in ``model_dir`` and reused on later calls when caching is on.
    """

    def __init__(self, model_dir: str = "model_weights", cache_models: bool = True):
        """
        Initialize simple model manager

        Args:
            model_dir: Local directory to store models
            cache_models: Whether to cache models locally
        """
        self.model_dir = Path(model_dir)
        # parents=True lets a nested directory such as "cache/models" be created
        # in one go instead of failing when the parent does not exist yet.
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self.cache_models = cache_models

        # Load model links from environment variables
        self.model_links = {
            "vision": {
                "url": os.getenv("VISION_MODEL_DRIVE_ID", ""),
                "filename": os.getenv("VISION_MODEL_FILENAME", "resnet50_model.pth"),
                "description": "Vision sentiment analysis model",
            },
            "audio": {
                "url": os.getenv("AUDIO_MODEL_DRIVE_ID", ""),
                "filename": os.getenv("AUDIO_MODEL_FILENAME", "wav2vec2_model.pth"),
                "description": "Audio sentiment analysis model",
            },
        }

        # Validate that environment variables are set
        self._validate_environment()

    def _validate_environment(self):
        """Warn (without raising) about model sources that are not configured"""
        missing_vars = [
            f"{model_type.upper()}_MODEL_DRIVE_ID"
            for model_type, info in self.model_links.items()
            if not info["url"]
        ]

        if missing_vars:
            logger.warning(f"Missing environment variables: {', '.join(missing_vars)}")
            logger.warning("Please set these in your .env file or environment")
            logger.warning("Models will not be available until these are configured")

    @staticmethod
    def _is_drive_url(source: str) -> bool:
        """Return True when *source* is a full http(s) link rather than a bare file ID"""
        return source.strip().lower().startswith(("http://", "https://"))

    @staticmethod
    def _reject_non_model_header(header: bytes, model_path: str) -> None:
        """Raise ValueError when the leading bytes identify a non-model file

        Args:
            header: First bytes of the file
            model_path: Path of the file, used only in the error message

        Raises:
            ValueError: If the header looks like HTML/XML, PNG or JPEG
        """
        if header.startswith(b"<"):
            # Typically a Google Drive error or confirmation page saved to disk
            raise ValueError(
                f"File appears to be HTML/XML, not a PyTorch model: {model_path}"
            )
        if header.startswith(b"\x89PNG"):
            raise ValueError(f"File appears to be a PNG image: {model_path}")
        if header.startswith(b"\xff\xd8\xff"):
            raise ValueError(f"File appears to be a JPEG image: {model_path}")

    def _validate_model_file(self, model_path: str) -> None:
        """Check that a downloaded file exists, is non-empty and is not an obvious non-model

        Args:
            model_path: Path to the file to check

        Raises:
            FileNotFoundError: If the file does not exist
            ValueError: If the file is empty or its header identifies another format
        """
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found at {model_path}")

        file_size = path.stat().st_size
        if file_size == 0:
            raise ValueError(f"Model file is empty: {model_path}")

        # Check file header to see what type of file it is
        with open(model_path, "rb") as f:
            header = f.read(100)

        logger.info(f"File size: {file_size} bytes")
        logger.info(f"File header (first 50 bytes): {header[:50]}...")

        self._reject_non_model_header(header, model_path)

        # For any other file type (including ZIP), try to load it directly as a PyTorch model
        logger.info(
            "File appears to be a PyTorch model file, attempting to load directly..."
        )

    @staticmethod
    def _load_checkpoint(model_path: str, device: torch.device) -> Any:
        """Load a checkpoint file, preferring the restricted (weights-only) unpickler

        Args:
            model_path: Path to the checkpoint file
            device: Device the tensors are mapped to

        Returns:
            Whatever object the file contains (a state dict or a full model)

        Raises:
            ValueError: If the file cannot be loaded in either mode
        """
        try:
            # Safe mode first: a plain state dict loads without executing any
            # pickled code.
            checkpoint = torch.load(model_path, map_location=device, weights_only=True)
            logger.info("Loaded with weights_only=True (weights only)")
            return checkpoint
        except Exception as safe_error:
            logger.warning(f"Failed to load with weights_only=True: {safe_error}")

        try:
            # SECURITY: weights_only=False unpickles arbitrary objects and can run
            # code embedded in the file. It is kept only so that checkpoints that
            # store a full model object still load; use trusted sources only.
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
            logger.info("Successfully loaded model file directly")
            return checkpoint
        except Exception as load_error:
            logger.error(f"Failed to load model directly: {load_error}")
            raise ValueError(
                f"Cannot load model file {model_path}. File may be corrupted or in wrong format."
            ) from load_error

    def download_from_google_drive(self, share_url: str, filename: str) -> str:
        """
        Download file from Google Drive using gdown

        Args:
            share_url: Google Drive share link or bare file ID
            filename: Name to save the file as

        Returns:
            Path to downloaded file

        Raises:
            FileNotFoundError: If nothing was written to disk
            ValueError: If the downloaded file is empty
        """
        try:
            local_path = self.model_dir / filename

            if self.cache_models and local_path.exists():
                if local_path.stat().st_size > 0:
                    logger.info(f"Model already cached: {local_path}")
                    return str(local_path)
                # A zero-byte file is the residue of a failed download; treating
                # it as a cache hit would make every later load fail the same way.
                logger.warning(f"Cached file is empty, downloading again: {local_path}")
                local_path.unlink()

            logger.info(f"Downloading {filename} from Google Drive using gdown...")

            # Use gdown to download the file
            # gdown automatically handles virus scan warnings and other Google Drive issues
            output_path = str(local_path)
            source = share_url.strip()

            # gdown turns `id` into ".../uc?id=<id>", so a full share link must be
            # passed as `url` (fuzzy=True then extracts the file ID from it).
            if self._is_drive_url(source):
                gdown.download(
                    url=source,
                    output=output_path,
                    quiet=False,  # Show progress bar
                    fuzzy=True,  # Handle various Google Drive URL formats
                )
            else:
                gdown.download(
                    id=source,
                    output=output_path,
                    quiet=False,  # Show progress bar
                    fuzzy=True,
                )

            # Verify the file was downloaded
            if not local_path.exists():
                raise FileNotFoundError(f"Download failed: {output_path} not found")

            file_size = local_path.stat().st_size
            if file_size == 0:
                # Remove the empty file so it is not mistaken for a cached model
                local_path.unlink()
                raise ValueError(f"Downloaded file is empty: {output_path}")

            logger.info(f"Successfully downloaded {filename} ({file_size} bytes)")
            return output_path

        except Exception as e:
            logger.error(f"Failed to download {filename}: {e}")
            raise

    def load_vision_model(self) -> Tuple[Any, torch.device, int]:
        """Load vision sentiment model

        Returns:
            Tuple of (model in eval mode, device it lives on, number of classes)

        Raises:
            ValueError: If the source is not configured or the file is unusable
            FileNotFoundError: If the model file is missing after download
        """
        try:
            model_info = self.model_links["vision"]

            # Check if URL is configured
            if not model_info["url"]:
                raise ValueError("VISION_MODEL_DRIVE_ID environment variable not set")

            model_path = self.download_from_google_drive(
                model_info["url"], model_info["filename"]
            )
            self._validate_model_file(model_path)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            checkpoint = self._load_checkpoint(model_path, device)

            if isinstance(checkpoint, dict):
                # State dictionary: build a ResNet-50 whose head matches it
                model = models.resnet50(weights=None)
                num_ftrs = model.fc.in_features

                # Determine number of classes from checkpoint
                if "fc.weight" in checkpoint:
                    num_classes = checkpoint["fc.weight"].shape[0]
                else:
                    num_classes = 3  # Default fallback

                model.fc = nn.Linear(num_ftrs, num_classes)
                model.load_state_dict(checkpoint)
            else:
                # Full model object (same case load_audio_model supports). Read
                # the class count from its head when it has one, otherwise use
                # the same default as the state-dict path.
                model = checkpoint
                num_classes = getattr(getattr(model, "fc", None), "out_features", 3)

            model.to(device)
            model.eval()

            logger.info(f"Vision model loaded successfully with {num_classes} classes!")
            return model, device, num_classes

        except Exception as e:
            logger.error(f"Failed to load vision model: {e}")
            raise

    def load_audio_model(self) -> Tuple[Any, torch.device]:
        """Load audio sentiment model

        Returns:
            Tuple of (model in eval mode, device it lives on)

        Raises:
            ValueError: If the source is not configured or the file is unusable
            FileNotFoundError: If the model file is missing after download
        """
        try:
            model_info = self.model_links["audio"]

            # Check if URL is configured
            if not model_info["url"]:
                raise ValueError("AUDIO_MODEL_DRIVE_ID environment variable not set")

            model_path = self.download_from_google_drive(
                model_info["url"], model_info["filename"]
            )
            self._validate_model_file(model_path)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            checkpoint = self._load_checkpoint(model_path, device)

            if isinstance(checkpoint, dict):
                if "classifier.weight" not in checkpoint:
                    # Without the classifier head there is no way to size the
                    # model; fail clearly instead of calling .to() on a dict.
                    raise ValueError(
                        f"Audio checkpoint {model_path} is a dictionary without "
                        "'classifier.weight'; expected a state dict or a full model."
                    )

                # This is a state dictionary - we need to initialize the model first
                from transformers import AutoModelForAudioClassification

                num_classes = checkpoint["classifier.weight"].shape[0]

                # Initialize Wav2Vec2 model with the correct number of classes
                model = AutoModelForAudioClassification.from_pretrained(
                    "facebook/wav2vec2-base", num_labels=num_classes
                )
                model.load_state_dict(checkpoint)
                success_message = (
                    f"Audio model loaded successfully with {num_classes} classes!"
                )
            else:
                # This is a full model object
                model = checkpoint
                success_message = "Audio model loaded successfully!"

            model.to(device)
            model.eval()

            logger.info(success_message)
            return model, device

        except Exception as e:
            logger.error(f"Failed to load audio model: {e}")
            raise

    def update_model_links(self, vision_url: str = None, audio_url: str = None):
        """Update Google Drive URLs for models (optional override)

        Args:
            vision_url: New source for the vision model; left unchanged when falsy
            audio_url: New source for the audio model; left unchanged when falsy
        """
        # Keep os.environ in step so managers created later see the same values
        if vision_url:
            self.model_links["vision"]["url"] = vision_url
            os.environ["VISION_MODEL_DRIVE_ID"] = vision_url
        if audio_url:
            self.model_links["audio"]["url"] = audio_url
            os.environ["AUDIO_MODEL_DRIVE_ID"] = audio_url

        logger.info("Model links updated!")

    def list_cached_models(self) -> list:
        """List the file names of all cached ``.pth`` models in the model directory"""
        return [file_path.name for file_path in self.model_dir.glob("*.pth")]

    def clear_cache(self):
        """Delete every ``.pth`` file in the model directory"""
        for file_path in self.model_dir.glob("*.pth"):
            file_path.unlink()
        logger.info("Cache cleared!")

    def get_model_status(self) -> dict:
        """Get status of all models

        Returns:
            Mapping of model type to its ``configured`` flag, ``filename``,
            ``cached`` flag and ``url`` ("Not configured" when unset)
        """
        return {
            model_type: {
                "configured": bool(info["url"]),
                "filename": info["filename"],
                "cached": (self.model_dir / info["filename"]).exists(),
                "url": info["url"] or "Not configured",
            }
            for model_type, info in self.model_links.items()
        }


# Example usage
if __name__ == "__main__":
    # Demo entry point: report which models are configured and cached, then try
    # to load every configured model and print troubleshooting hints on failure.

    # Initialize manager
    manager = SimpleModelManager()

    # Check model status
    status = manager.get_model_status()
    print("Model Status:")
    for model_type, info in status.items():
        print(f"  {model_type}: {'✅' if info['configured'] else '❌'} {info['url']}")
        if info["cached"]:
            print(f"    📁 Cached: {info['filename']}")

    # Load models if configured
    try:
        if status["vision"]["configured"]:
            vision_model, device, num_classes = manager.load_vision_model()
            print(f"✅ Vision model loaded: {num_classes} classes")
        else:
            print("❌ Vision model not configured")

        if status["audio"]["configured"]:
            audio_model, device = manager.load_audio_model()
            print("✅ Audio model loaded")
        else:
            print("❌ Audio model not configured")

        # Reaching this point means every configured model loaded without
        # raising, so "all loaded" only depends on both being configured.
        if status["vision"]["configured"] and status["audio"]["configured"]:
            print("\n🎉 All models loaded successfully!")
        else:
            print("\n⚠️  Some models are not configured")
            print("Please set the following environment variables:")
            print("  VISION_MODEL_DRIVE_ID")
            print("  AUDIO_MODEL_DRIVE_ID")

    except Exception as e:
        # Top-level boundary of the demo: show the error plus setup hints
        print(f"Error loading models: {e}")
        print("\nFor folder structures:")
        print("   1. Navigate to each subfolder (Audio/Vision)")
        print("   2. Right-click on each .pth file")
        print("   3. Share -> Copy link")
        print("   4. Use those direct file links instead of folder links")
        print("\nNote: Downloaded files are used directly as PyTorch models.")
        print("\nOr set environment variables in your .env file:")
        print("  VISION_MODEL_DRIVE_ID=your_vision_model_file_id")
        print("  AUDIO_MODEL_DRIVE_ID=your_audio_model_file_id")