""" Model Loading Utilities Provides comprehensive model loading capabilities for various formats and sources including PyTorch models, Safetensors, and Hugging Face transformers. """ import os import logging import asyncio from typing import Dict, Any, Optional, Union, List from pathlib import Path import json import requests from urllib.parse import urlparse import tempfile import shutil import torch import torch.nn as nn from transformers import ( AutoModel, AutoTokenizer, AutoConfig, AutoImageProcessor, AutoFeatureExtractor, AutoProcessor, AutoModelForCausalLM, AutoModelForSeq2SeqLM ) from safetensors import safe_open from safetensors.torch import load_file as load_safetensors import numpy as np from PIL import Image logger = logging.getLogger(__name__) # Custom model configurations for special architectures CUSTOM_MODEL_CONFIGS = { 'ti2v': { 'model_type': 'ti2v', 'architecture': 'TI2VModel', 'modalities': ['text', 'vision'], 'supports_generation': True, 'is_multimodal': True }, 'diffusion': { 'model_type': 'diffusion', 'architecture': 'DiffusionModel', 'modalities': ['vision', 'text'], 'supports_generation': True, 'is_multimodal': True } } class ModelLoader: """ Comprehensive model loader supporting multiple formats and sources """ def __init__(self): self.supported_formats = { '.pt': 'pytorch', '.pth': 'pytorch', '.bin': 'pytorch', '.safetensors': 'safetensors', '.onnx': 'onnx', '.h5': 'keras', '.pkl': 'pickle', '.joblib': 'joblib' } self.modality_keywords = { 'text': ['bert', 'gpt', 'roberta', 'electra', 'deberta', 'xlm', 'xlnet', 't5', 'bart'], 'vision': ['vit', 'resnet', 'efficientnet', 'convnext', 'swin', 'deit', 'beit'], 'multimodal': ['clip', 'blip', 'albef', 'flava', 'layoutlm', 'donut'], 'audio': ['wav2vec', 'hubert', 'whisper', 'speech_t5'] } async def load_model(self, source: str, **kwargs) -> Dict[str, Any]: """ Load a model from various sources Args: source: Model source (file path, HF repo, URL) **kwargs: Additional loading parameters Returns: Dictionary containing model, tokenizer/processor, and metadata """ try: logger.info(f"Loading model from: {source}") # Determine source type if self._is_url(source): return await self._load_from_url(source, **kwargs) elif self._is_huggingface_repo(source): return await self._load_from_huggingface(source, **kwargs) elif Path(source).exists(): return await self._load_from_file(source, **kwargs) else: raise ValueError(f"Invalid model source: {source}") except Exception as e: logger.error(f"Error loading model from {source}: {str(e)}") raise async def get_model_info(self, source: str) -> Dict[str, Any]: """ Get model information without loading the full model Args: source: Model source Returns: Model metadata and information """ try: info = { 'source': source, 'format': 'unknown', 'modality': 'unknown', 'architecture': None, 'parameters': None, 'size_mb': None } if Path(source).exists(): file_path = Path(source) info['size_mb'] = file_path.stat().st_size / (1024 * 1024) info['format'] = self.supported_formats.get(file_path.suffix, 'unknown') # Try to extract more info based on format if info['format'] == 'safetensors': info.update(await self._get_safetensors_info(source)) elif info['format'] == 'pytorch': info.update(await self._get_pytorch_info(source)) elif self._is_huggingface_repo(source): info.update(await self._get_huggingface_info(source)) # Detect modality from model name/architecture info['modality'] = self._detect_modality(source, info.get('architecture', '')) return info except Exception as e: logger.warning(f"Error getting model info for {source}: {str(e)}") return {'source': source, 'error': str(e)} def _is_url(self, source: str) -> bool: """Check if source is a URL""" try: result = urlparse(source) return all([result.scheme, result.netloc]) except: return False def _is_huggingface_repo(self, source: str) -> bool: """Check if source is a Hugging Face repository""" # Simple heuristic: contains '/' but not a file extension return '/' in source and not any(source.endswith(ext) for ext in self.supported_formats.keys()) def _detect_modality(self, source: str, architecture: str) -> str: """Detect model modality from source and architecture""" text = (source + ' ' + architecture).lower() for modality, keywords in self.modality_keywords.items(): if any(keyword in text for keyword in keywords): return modality return 'unknown' async def _load_from_file(self, file_path: str, **kwargs) -> Dict[str, Any]: """Load model from local file""" file_path = Path(file_path) format_type = self.supported_formats.get(file_path.suffix, 'unknown') if format_type == 'safetensors': return await self._load_safetensors(file_path, **kwargs) elif format_type == 'pytorch': return await self._load_pytorch(file_path, **kwargs) else: raise ValueError(f"Unsupported format: {format_type}") async def _load_from_url(self, url: str, **kwargs) -> Dict[str, Any]: """Load model from URL""" # Download to temporary file with tempfile.NamedTemporaryFile(delete=False) as tmp_file: response = requests.get(url, stream=True) response.raise_for_status() for chunk in response.iter_content(chunk_size=8192): tmp_file.write(chunk) tmp_path = tmp_file.name try: # Load from temporary file result = await self._load_from_file(tmp_path, **kwargs) result['source_url'] = url return result finally: # Cleanup temporary file os.unlink(tmp_path) async def _load_from_huggingface(self, repo_id: str, **kwargs) -> Dict[str, Any]: """Load model from Hugging Face repository""" try: # Get HF token from multiple sources hf_token = ( kwargs.get('token') or os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN') ) logger.info(f"Loading model {repo_id} with token: {'Yes' if hf_token else 'No'}") # Load configuration first with timeout trust_remote_code = kwargs.get('trust_remote_code', False) logger.info(f"Loading config for {repo_id} with trust_remote_code={trust_remote_code}") try: config = AutoConfig.from_pretrained( repo_id, trust_remote_code=trust_remote_code, token=hf_token, timeout=30 # 30 second timeout ) logger.info(f"Successfully loaded config for {repo_id}") except Exception as e: logger.error(f"Failed to load config for {repo_id}: {e}") raise ValueError(f"Could not load model configuration: {str(e)}") # Load model with proper device handling device = 'cuda' if torch.cuda.is_available() else 'cpu' # Check if this is a large model and warn model_size_gb = self._estimate_model_size(config) if model_size_gb > 10: logger.warning(f"Large model detected ({model_size_gb:.1f}GB estimated). This may take several minutes to load.") # Check for custom architectures that need special handling model_type = getattr(config, 'model_type', None) # Try different loading strategies for different model types model = None loading_error = None # Special handling for ti2v and other custom architectures if model_type in CUSTOM_MODEL_CONFIGS: try: logger.info(f"Loading custom architecture {model_type} for {repo_id}...") model = await self._load_custom_architecture(repo_id, config, hf_token, trust_remote_code, **kwargs) except Exception as e: logger.warning(f"Custom architecture loading failed: {e}") loading_error = str(e) # Strategy 1: Try AutoModel (most common) if not already loaded if model is None: try: logger.info(f"Attempting to load {repo_id} with AutoModel...") model = AutoModel.from_pretrained( repo_id, config=config, torch_dtype=kwargs.get('torch_dtype', torch.float32), trust_remote_code=trust_remote_code, token=hf_token, low_cpu_mem_usage=True, timeout=120 # 2 minute timeout for model loading ) logger.info(f"Successfully loaded {repo_id} with AutoModel") except Exception as e: loading_error = str(e) logger.warning(f"AutoModel failed for {repo_id}: {e}") # Strategy 2: Try specific model classes for known types if model is None: model = await self._try_specific_model_classes(repo_id, config, hf_token, trust_remote_code, kwargs) # Strategy 3: Try with trust_remote_code if not already enabled if model is None and not trust_remote_code: try: logger.info(f"Trying {repo_id} with trust_remote_code=True") # For Gemma 3 models, try AutoModelForCausalLM specifically if 'gemma-3' in repo_id.lower() or 'gemma3' in str(config).lower(): from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( repo_id, config=config, torch_dtype=kwargs.get('torch_dtype', torch.float32), trust_remote_code=True, token=hf_token, low_cpu_mem_usage=True ) else: model = AutoModel.from_pretrained( repo_id, config=config, torch_dtype=kwargs.get('torch_dtype', torch.float32), trust_remote_code=True, token=hf_token, low_cpu_mem_usage=True ) logger.info(f"Successfully loaded {repo_id} with trust_remote_code=True") except Exception as e: logger.warning(f"Loading with trust_remote_code=True failed: {e}") if model is None: raise ValueError(f"Could not load model {repo_id}. Last error: {loading_error}") # Move to device manually model = model.to(device) # Load appropriate processor/tokenizer processor = None try: # Try different processor types for processor_class in [AutoTokenizer, AutoImageProcessor, AutoFeatureExtractor, AutoProcessor]: try: processor = processor_class.from_pretrained(repo_id, token=hf_token) break except: continue except Exception as e: logger.warning(f"Could not load processor for {repo_id}: {e}") return { 'model': model, 'processor': processor, 'config': config, 'source': repo_id, 'format': 'huggingface', 'architecture': config.architectures[0] if hasattr(config, 'architectures') and config.architectures else None, 'modality': self._detect_modality(repo_id, str(config.architectures) if hasattr(config, 'architectures') else ''), 'parameters': sum(p.numel() for p in model.parameters()) if hasattr(model, 'parameters') else None } except Exception as e: logger.error(f"Error loading from Hugging Face repo {repo_id}: {str(e)}") raise async def _load_custom_architecture(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, **kwargs): """Load models with custom architectures like ti2v""" try: model_type = getattr(config, 'model_type', None) logger.info(f"Loading custom architecture: {model_type}") if model_type == 'ti2v': # For ti2v models, we need to create a wrapper that can work with our distillation return await self._load_ti2v_model(repo_id, config, hf_token, trust_remote_code, **kwargs) else: # For other custom architectures, try with trust_remote_code logger.info(f"Attempting to load custom model {repo_id} with trust_remote_code=True") # Try different model classes model_classes = [AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM] for model_class in model_classes: try: model = model_class.from_pretrained( repo_id, config=config, trust_remote_code=True, # Force trust_remote_code for custom architectures token=hf_token, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) logger.info(f"Successfully loaded {repo_id} with {model_class.__name__}") return model except Exception as e: logger.warning(f"{model_class.__name__} failed for {repo_id}: {e}") continue raise ValueError(f"All loading strategies failed for custom architecture {model_type}") except Exception as e: logger.error(f"Error loading custom architecture: {e}") raise async def _load_ti2v_model(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, **kwargs): """Special handling for ti2v (Text-to-Image/Video) models""" try: logger.info(f"Loading ti2v model: {repo_id}") # For ti2v models, we'll create a wrapper that extracts text features # This allows us to use them in knowledge distillation # Try to load with trust_remote_code=True (required for custom architectures) model = AutoModel.from_pretrained( repo_id, config=config, trust_remote_code=True, token=hf_token, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) # Create a wrapper that can extract features for distillation class TI2VWrapper(torch.nn.Module): def __init__(self, base_model): super().__init__() self.base_model = base_model self.config = base_model.config def forward(self, input_ids=None, attention_mask=None, **kwargs): # Extract text encoder features if available if hasattr(self.base_model, 'text_encoder'): return self.base_model.text_encoder(input_ids=input_ids, attention_mask=attention_mask) elif hasattr(self.base_model, 'encoder'): return self.base_model.encoder(input_ids=input_ids, attention_mask=attention_mask) else: # Fallback: try to get some meaningful representation return self.base_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) wrapped_model = TI2VWrapper(model) logger.info(f"Successfully wrapped ti2v model: {repo_id}") return wrapped_model except Exception as e: logger.error(f"Error loading ti2v model {repo_id}: {e}") raise async def _load_safetensors(self, file_path: Path, **kwargs) -> Dict[str, Any]: """Load model from Safetensors format""" try: # Load tensors tensors = {} with safe_open(file_path, framework="pt", device="cpu") as f: for key in f.keys(): tensors[key] = f.get_tensor(key) # Try to reconstruct model architecture model = self._reconstruct_model_from_tensors(tensors) return { 'model': model, 'tensors': tensors, 'source': str(file_path), 'format': 'safetensors', 'parameters': sum(tensor.numel() for tensor in tensors.values()), 'tensor_keys': list(tensors.keys()) } except Exception as e: logger.error(f"Error loading Safetensors file {file_path}: {str(e)}") raise async def _load_pytorch(self, file_path: Path, **kwargs) -> Dict[str, Any]: """Load PyTorch model""" try: # Load checkpoint checkpoint = torch.load(file_path, map_location='cpu') # Extract model and metadata if isinstance(checkpoint, dict): model = checkpoint.get('model', checkpoint.get('state_dict', checkpoint)) metadata = {k: v for k, v in checkpoint.items() if k not in ['model', 'state_dict']} else: model = checkpoint metadata = {} return { 'model': model, 'metadata': metadata, 'source': str(file_path), 'format': 'pytorch', 'parameters': sum(tensor.numel() for tensor in model.values()) if isinstance(model, dict) else None } except Exception as e: logger.error(f"Error loading PyTorch file {file_path}: {str(e)}") raise def _reconstruct_model_from_tensors(self, tensors: Dict[str, torch.Tensor]) -> nn.Module: """ Attempt to reconstruct a PyTorch model from tensor dictionary This is a simplified implementation - in practice, this would need more sophisticated architecture detection """ class GenericModel(nn.Module): def __init__(self, tensors): super().__init__() self.tensors = nn.ParameterDict() for name, tensor in tensors.items(): self.tensors[name.replace('.', '_')] = nn.Parameter(tensor) def forward(self, x): # Placeholder forward pass return x return GenericModel(tensors) async def _get_safetensors_info(self, file_path: str) -> Dict[str, Any]: """Get information from Safetensors file""" try: info = {} with safe_open(file_path, framework="pt", device="cpu") as f: keys = list(f.keys()) info['tensor_count'] = len(keys) info['tensor_keys'] = keys[:10] # First 10 keys # Estimate parameters total_params = 0 for key in keys: tensor = f.get_tensor(key) total_params += tensor.numel() info['parameters'] = total_params return info except Exception as e: logger.warning(f"Error getting Safetensors info: {e}") return {} async def _get_pytorch_info(self, file_path: str) -> Dict[str, Any]: """Get information from PyTorch file""" try: checkpoint = torch.load(file_path, map_location='cpu') info = {} if isinstance(checkpoint, dict): info['keys'] = list(checkpoint.keys()) # Look for model/state_dict model_data = checkpoint.get('model', checkpoint.get('state_dict', checkpoint)) if isinstance(model_data, dict): info['parameters'] = sum(tensor.numel() for tensor in model_data.values()) info['layer_count'] = len(model_data) return info except Exception as e: logger.warning(f"Error getting PyTorch info: {e}") return {} async def _get_huggingface_info(self, repo_id: str) -> Dict[str, Any]: """Get information from Hugging Face repository""" try: hf_token = ( os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN') ) config = AutoConfig.from_pretrained(repo_id, token=hf_token) info = { 'architecture': config.architectures[0] if hasattr(config, 'architectures') and config.architectures else None, 'model_type': getattr(config, 'model_type', None), 'hidden_size': getattr(config, 'hidden_size', None), 'num_layers': getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', None)), 'vocab_size': getattr(config, 'vocab_size', None) } return info except Exception as e: logger.warning(f"Error getting Hugging Face info: {e}") return {} async def _try_specific_model_classes(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, kwargs: Dict[str, Any]): """Try loading with specific model classes for known architectures""" from transformers import ( AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering, AutoModelForMaskedLM, AutoModelForImageClassification, AutoModelForObjectDetection, AutoModelForSemanticSegmentation, AutoModelForImageSegmentation, AutoModelForDepthEstimation, AutoModelForZeroShotImageClassification ) # Map model types to appropriate AutoModel classes model_type = getattr(config, 'model_type', '').lower() architecture = getattr(config, 'architectures', []) arch_str = str(architecture).lower() if architecture else '' model_classes_to_try = [] # Determine appropriate model classes based on model type and architecture if 'siglip' in model_type or 'siglip' in arch_str: # SigLIP models - try vision-related classes model_classes_to_try = [ AutoModelForImageClassification, AutoModelForZeroShotImageClassification, AutoModel ] elif 'clip' in model_type or 'clip' in arch_str: model_classes_to_try = [AutoModelForZeroShotImageClassification, AutoModel] elif 'vit' in model_type or 'vision' in model_type: model_classes_to_try = [AutoModelForImageClassification, AutoModel] elif 'bert' in model_type or 'roberta' in model_type: model_classes_to_try = [AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel] elif 'gemma' in model_type or 'gemma' in arch_str: # Gemma models (including Gemma 3) - try causal LM classes model_classes_to_try = [AutoModelForCausalLM, AutoModel] elif 'gpt' in model_type or 'llama' in model_type: model_classes_to_try = [AutoModelForCausalLM, AutoModel] else: # Generic fallback model_classes_to_try = [ AutoModelForCausalLM, # Try causal LM first for newer models AutoModelForSequenceClassification, AutoModelForImageClassification, AutoModel ] # Try each model class for model_class in model_classes_to_try: try: logger.info(f"Trying {repo_id} with {model_class.__name__}") model = model_class.from_pretrained( repo_id, config=config, torch_dtype=kwargs.get('torch_dtype', torch.float32), trust_remote_code=trust_remote_code, token=hf_token, low_cpu_mem_usage=True ) logger.info(f"Successfully loaded {repo_id} with {model_class.__name__}") return model except Exception as e: logger.debug(f"{model_class.__name__} failed for {repo_id}: {e}") continue return None async def load_trained_student(self, model_path: str) -> Dict[str, Any]: """Load a previously trained student model for retraining""" try: # Check if it's a Hugging Face model (starts with organization/) if '/' in model_path and not Path(model_path).exists(): # This is likely a Hugging Face repository return await self._load_student_from_huggingface(model_path) # Local model path model_dir = Path(model_path) # Check if it's a trained student model config_path = model_dir / "config.json" if not config_path.exists(): # Try alternative naming safetensors_files = list(model_dir.glob("*.safetensors")) if safetensors_files: config_path = safetensors_files[0].with_suffix('_config.json') if not config_path.exists(): raise ValueError("No configuration file found for student model") # Load configuration with open(config_path, 'r') as f: config = json.load(f) # Verify it's a student model if not config.get('is_student_model', False): raise ValueError("This is not a trained student model") # Load training history history_path = model_dir / "training_history.json" if not history_path.exists(): # Try alternative naming safetensors_files = list(model_dir.glob("*.safetensors")) if safetensors_files: history_path = safetensors_files[0].with_suffix('_training_history.json') training_history = {} if history_path.exists(): with open(history_path, 'r') as f: training_history = json.load(f) # Load model weights model_file = None for ext in ['.safetensors', '.bin', '.pt']: potential_file = model_dir / f"student_model{ext}" if potential_file.exists(): model_file = potential_file break if not model_file: # Look for any model file for ext in ['.safetensors', '.bin', '.pt']: files = list(model_dir.glob(f"*{ext}")) if files: model_file = files[0] break if not model_file: raise ValueError("No model file found") return { 'type': 'trained_student', 'path': str(model_path), 'config': config, 'training_history': training_history, 'model_file': str(model_file), 'can_be_retrained': config.get('can_be_retrained', True), 'original_teachers': training_history.get('retraining_info', {}).get('original_teachers', []), 'recommended_lr': training_history.get('retraining_info', {}).get('recommended_learning_rate', 1e-5), 'modalities': config.get('modalities', ['text']), 'architecture': config.get('architecture', 'unknown') } except Exception as e: logger.error(f"Error loading trained student model: {e}") raise async def _load_student_from_huggingface(self, repo_id: str) -> Dict[str, Any]: """Load a student model from Hugging Face repository""" try: # Get HF token hf_token = ( os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN') ) logger.info(f"Loading student model from Hugging Face: {repo_id}") # Load configuration config = AutoConfig.from_pretrained(repo_id, token=hf_token) # Try to load the model to verify it exists and is accessible model = await self._load_from_huggingface(repo_id, token=hf_token) # Check if it's marked as a student model (optional) is_student = config.get('is_student_model', False) return { 'type': 'huggingface_student', 'path': repo_id, 'config': config.__dict__ if hasattr(config, '__dict__') else {}, 'training_history': {}, # HF models may not have our training history 'model_file': repo_id, # For HF models, this is the repo ID 'can_be_retrained': True, 'original_teachers': [], # Unknown for external models 'recommended_lr': 1e-5, # Default learning rate 'modalities': ['text'], # Default, could be enhanced 'architecture': getattr(config, 'architectures', ['unknown'])[0] if hasattr(config, 'architectures') else 'unknown', 'is_huggingface': True } except Exception as e: logger.error(f"Error loading student model from Hugging Face: {e}") raise ValueError(f"Could not load student model from Hugging Face: {str(e)}") async def load_trained_student_from_space(self, space_name: str) -> Dict[str, Any]: """Load a student model from a Hugging Face Space""" try: # Get HF token hf_token = ( os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN') ) logger.info(f"Loading student model from Hugging Face Space: {space_name}") from huggingface_hub import HfApi api = HfApi(token=hf_token) # List files in the Space to find model files try: files = api.list_repo_files(space_name, repo_type="space") # Look for model files in models directory model_files = [f for f in files if f.startswith('models/') and f.endswith(('.safetensors', '.bin', '.pt'))] if not model_files: # Look for model files in root model_files = [f for f in files if f.endswith(('.safetensors', '.bin', '.pt'))] if not model_files: raise ValueError(f"No model files found in Space {space_name}") # Use the first model file found model_file = model_files[0] logger.info(f"Found model file in Space: {model_file}") # For now, we'll treat Space models as external HF models # In the future, we could download and cache them locally return { 'type': 'space_student', 'path': space_name, 'config': {}, # Space models may not have our config format 'training_history': {}, # Unknown for space models 'model_file': model_file, 'can_be_retrained': True, 'original_teachers': [], # Unknown for external models 'recommended_lr': 1e-5, # Default learning rate 'modalities': ['text'], # Default, could be enhanced 'architecture': 'unknown', 'is_space': True, 'space_name': space_name, 'available_models': model_files } except Exception as e: logger.error(f"Error accessing Space files: {e}") # Fallback: treat as a regular HF model return await self._load_student_from_huggingface(space_name) except Exception as e: logger.error(f"Error loading student model from Space: {e}") raise ValueError(f"Could not load student model from Space: {str(e)}") def _estimate_model_size(self, config) -> float: """Estimate model size in GB based on configuration""" try: # Get basic parameters hidden_size = getattr(config, 'hidden_size', 768) num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', 12)) vocab_size = getattr(config, 'vocab_size', 50000) # Rough estimation: parameters * 4 bytes (float32) / 1GB # This is a very rough estimate embedding_params = vocab_size * hidden_size layer_params = num_layers * (hidden_size * hidden_size * 4) # Simplified total_params = embedding_params + layer_params # Convert to GB (4 bytes per parameter for float32) size_gb = (total_params * 4) / (1024 ** 3) return max(size_gb, 0.1) # Minimum 0.1GB except Exception: return 1.0 # Default 1GB if estimation fails def validate_model_compatibility(self, models: List[Dict[str, Any]]) -> Dict[str, Any]: """ Validate that multiple models are compatible for knowledge distillation Args: models: List of loaded model dictionaries Returns: Validation result with compatibility information """ if not models: return {'compatible': False, 'reason': 'No models provided'} if len(models) < 2: return {'compatible': False, 'reason': 'At least 2 models required for distillation'} # Check modality compatibility modalities = [model.get('modality', 'unknown') for model in models] unique_modalities = set(modalities) # Allow same modality or multimodal combinations if len(unique_modalities) == 1 and 'unknown' not in unique_modalities: compatibility_type = 'same_modality' elif 'multimodal' in unique_modalities or len(unique_modalities) > 1: compatibility_type = 'cross_modal' else: return {'compatible': False, 'reason': 'Unknown modalities detected'} return { 'compatible': True, 'type': compatibility_type, 'modalities': list(unique_modalities), 'model_count': len(models), 'total_parameters': sum(model.get('parameters', 0) for model in models if model.get('parameters')) }