""" | |
Model Loading Utilities | |
Provides comprehensive model loading capabilities for various formats and sources | |
including PyTorch models, Safetensors, and Hugging Face transformers. | |
""" | |
import os
import logging
import asyncio
from typing import Dict, Any, Optional, Union, List
from pathlib import Path
import json
import requests
from urllib.parse import urlparse
import tempfile
import shutil

import torch
import torch.nn as nn
from transformers import (
    AutoModel, AutoTokenizer, AutoConfig, AutoImageProcessor,
    AutoFeatureExtractor, AutoProcessor, AutoModelForCausalLM,
    AutoModelForSeq2SeqLM
)
from safetensors import safe_open
from safetensors.torch import load_file as load_safetensors
import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)

# Custom model configurations for special architectures
CUSTOM_MODEL_CONFIGS = {
    'ti2v': {
        'model_type': 'ti2v',
        'architecture': 'TI2VModel',
        'modalities': ['text', 'vision'],
        'supports_generation': True,
        'is_multimodal': True
    },
    'diffusion': {
        'model_type': 'diffusion',
        'architecture': 'DiffusionModel',
        'modalities': ['vision', 'text'],
        'supports_generation': True,
        'is_multimodal': True
    }
}
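
# Note: the keys above are matched against `config.model_type` in
# `_load_from_huggingface`, which routes matching repos through
# `_load_custom_architecture` before falling back to the generic AutoModel strategies.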

class ModelLoader:
    """
    Comprehensive model loader supporting multiple formats and sources
    """

    def __init__(self):
        self.supported_formats = {
            '.pt': 'pytorch',
            '.pth': 'pytorch',
            '.bin': 'pytorch',
            '.safetensors': 'safetensors',
            '.onnx': 'onnx',
            '.h5': 'keras',
            '.pkl': 'pickle',
            '.joblib': 'joblib'
        }
        self.modality_keywords = {
            'text': ['bert', 'gpt', 'roberta', 'electra', 'deberta', 'xlm', 'xlnet', 't5', 'bart'],
            'vision': ['vit', 'resnet', 'efficientnet', 'convnext', 'swin', 'deit', 'beit'],
            'multimodal': ['clip', 'blip', 'albef', 'flava', 'layoutlm', 'donut'],
            'audio': ['wav2vec', 'hubert', 'whisper', 'speech_t5']
        }
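
    # For example, a local file "weights.safetensors" resolves to format "safetensors"
    # via `supported_formats`, and a repo id such as "bert-base-uncased" resolves to
    # modality "text" via `modality_keywords` (see `_detect_modality`).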

    async def load_model(self, source: str, **kwargs) -> Dict[str, Any]:
        """
        Load a model from various sources

        Args:
            source: Model source (file path, HF repo, URL)
            **kwargs: Additional loading parameters

        Returns:
            Dictionary containing model, tokenizer/processor, and metadata
        """
        try:
            logger.info(f"Loading model from: {source}")

            # Determine source type
            if self._is_url(source):
                return await self._load_from_url(source, **kwargs)
            elif self._is_huggingface_repo(source):
                return await self._load_from_huggingface(source, **kwargs)
            elif Path(source).exists():
                return await self._load_from_file(source, **kwargs)
            else:
                raise ValueError(f"Invalid model source: {source}")
        except Exception as e:
            logger.error(f"Error loading model from {source}: {str(e)}")
            raise

    async def get_model_info(self, source: str) -> Dict[str, Any]:
        """
        Get model information without loading the full model

        Args:
            source: Model source

        Returns:
            Model metadata and information
        """
        try:
            info = {
                'source': source,
                'format': 'unknown',
                'modality': 'unknown',
                'architecture': None,
                'parameters': None,
                'size_mb': None
            }

            if Path(source).exists():
                file_path = Path(source)
                info['size_mb'] = file_path.stat().st_size / (1024 * 1024)
                info['format'] = self.supported_formats.get(file_path.suffix, 'unknown')

                # Try to extract more info based on format
                if info['format'] == 'safetensors':
                    info.update(await self._get_safetensors_info(source))
                elif info['format'] == 'pytorch':
                    info.update(await self._get_pytorch_info(source))
            elif self._is_huggingface_repo(source):
                info.update(await self._get_huggingface_info(source))

            # Detect modality from model name/architecture
            # (guard against a None architecture, which would break string concatenation)
            info['modality'] = self._detect_modality(source, info.get('architecture') or '')
            return info
        except Exception as e:
            logger.warning(f"Error getting model info for {source}: {str(e)}")
            return {'source': source, 'error': str(e)}

    def _is_url(self, source: str) -> bool:
        """Check if source is a URL"""
        try:
            result = urlparse(source)
            return all([result.scheme, result.netloc])
        except Exception:
            return False

    def _is_huggingface_repo(self, source: str) -> bool:
        """Check if source is a Hugging Face repository"""
        # Simple heuristic: contains '/' but not a file extension
        return '/' in source and not any(source.endswith(ext) for ext in self.supported_formats.keys())

    def _detect_modality(self, source: str, architecture: str) -> str:
        """Detect model modality from source and architecture"""
        text = (source + ' ' + architecture).lower()
        for modality, keywords in self.modality_keywords.items():
            if any(keyword in text for keyword in keywords):
                return modality
        return 'unknown'
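
    # For example, _detect_modality("bert-base-uncased", "") returns "text" because the
    # keyword "bert" appears in the source string; sources matching no keyword list fall
    # through to "unknown". Keyword groups are checked in dictionary order
    # (text, vision, multimodal, audio), so the first matching group wins.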

    async def _load_from_file(self, file_path: str, **kwargs) -> Dict[str, Any]:
        """Load model from local file"""
        file_path = Path(file_path)
        format_type = self.supported_formats.get(file_path.suffix, 'unknown')

        if format_type == 'safetensors':
            return await self._load_safetensors(file_path, **kwargs)
        elif format_type == 'pytorch':
            return await self._load_pytorch(file_path, **kwargs)
        else:
            raise ValueError(f"Unsupported format: {format_type}")

    async def _load_from_url(self, url: str, **kwargs) -> Dict[str, Any]:
        """Load model from URL"""
        # Download to a temporary file
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            response = requests.get(url, stream=True)
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
            tmp_path = tmp_file.name

        try:
            # Load from the temporary file
            result = await self._load_from_file(tmp_path, **kwargs)
            result['source_url'] = url
            return result
        finally:
            # Clean up the temporary file
            os.unlink(tmp_path)

    async def _load_from_huggingface(self, repo_id: str, **kwargs) -> Dict[str, Any]:
        """Load model from Hugging Face repository"""
        try:
            # Get the HF token from multiple sources
            hf_token = (
                kwargs.get('token') or
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            logger.info(f"Loading model {repo_id} with token: {'Yes' if hf_token else 'No'}")

            # Load the configuration first
            trust_remote_code = kwargs.get('trust_remote_code', False)
            logger.info(f"Loading config for {repo_id} with trust_remote_code={trust_remote_code}")
            try:
                config = AutoConfig.from_pretrained(
                    repo_id,
                    trust_remote_code=trust_remote_code,
                    token=hf_token
                )
                logger.info(f"Successfully loaded config for {repo_id}")
            except Exception as e:
                logger.error(f"Failed to load config for {repo_id}: {e}")
                raise ValueError(f"Could not load model configuration: {str(e)}")

            # Load model with proper device handling
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Check whether this is a large model and warn
            model_size_gb = self._estimate_model_size(config)
            if model_size_gb > 10:
                logger.warning(
                    f"Large model detected ({model_size_gb:.1f}GB estimated). "
                    f"This may take several minutes to load."
                )

            # Check for custom architectures that need special handling
            model_type = getattr(config, 'model_type', None)

            # Try different loading strategies for different model types
            model = None
            loading_error = None

            # Special handling for ti2v and other custom architectures
            if model_type in CUSTOM_MODEL_CONFIGS:
                try:
                    logger.info(f"Loading custom architecture {model_type} for {repo_id}...")
                    model = await self._load_custom_architecture(
                        repo_id, config, hf_token, trust_remote_code, **kwargs
                    )
                except Exception as e:
                    logger.warning(f"Custom architecture loading failed: {e}")
                    loading_error = str(e)

            # Strategy 1: Try AutoModel (most common) if not already loaded
            if model is None:
                try:
                    logger.info(f"Attempting to load {repo_id} with AutoModel...")
                    model = AutoModel.from_pretrained(
                        repo_id,
                        config=config,
                        torch_dtype=kwargs.get('torch_dtype', torch.float32),
                        trust_remote_code=trust_remote_code,
                        token=hf_token,
                        low_cpu_mem_usage=True
                    )
                    logger.info(f"Successfully loaded {repo_id} with AutoModel")
                except Exception as e:
                    loading_error = str(e)
                    logger.warning(f"AutoModel failed for {repo_id}: {e}")

            # Strategy 2: Try specific model classes for known types
            if model is None:
                model = await self._try_specific_model_classes(
                    repo_id, config, hf_token, trust_remote_code, kwargs
                )

            # Strategy 3: Try with trust_remote_code if not already enabled
            if model is None and not trust_remote_code:
                try:
                    logger.info(f"Trying {repo_id} with trust_remote_code=True")
                    # For Gemma 3 models, try AutoModelForCausalLM specifically
                    if 'gemma-3' in repo_id.lower() or 'gemma3' in str(config).lower():
                        model = AutoModelForCausalLM.from_pretrained(
                            repo_id,
                            config=config,
                            torch_dtype=kwargs.get('torch_dtype', torch.float32),
                            trust_remote_code=True,
                            token=hf_token,
                            low_cpu_mem_usage=True
                        )
                    else:
                        model = AutoModel.from_pretrained(
                            repo_id,
                            config=config,
                            torch_dtype=kwargs.get('torch_dtype', torch.float32),
                            trust_remote_code=True,
                            token=hf_token,
                            low_cpu_mem_usage=True
                        )
                    logger.info(f"Successfully loaded {repo_id} with trust_remote_code=True")
                except Exception as e:
                    logger.warning(f"Loading with trust_remote_code=True failed: {e}")
                    loading_error = str(e)

            if model is None:
                raise ValueError(f"Could not load model {repo_id}. Last error: {loading_error}")

            # Move to device manually
            model = model.to(device)

            # Load appropriate processor/tokenizer
            processor = None
            try:
                # Try different processor types
                for processor_class in [AutoTokenizer, AutoImageProcessor, AutoFeatureExtractor, AutoProcessor]:
                    try:
                        processor = processor_class.from_pretrained(repo_id, token=hf_token)
                        break
                    except Exception:
                        continue
            except Exception as e:
                logger.warning(f"Could not load processor for {repo_id}: {e}")

            return {
                'model': model,
                'processor': processor,
                'config': config,
                'source': repo_id,
                'format': 'huggingface',
                'architecture': config.architectures[0] if hasattr(config, 'architectures') and config.architectures else None,
                'modality': self._detect_modality(repo_id, str(config.architectures) if hasattr(config, 'architectures') else ''),
                'parameters': sum(p.numel() for p in model.parameters()) if hasattr(model, 'parameters') else None
            }
        except Exception as e:
            logger.error(f"Error loading from Hugging Face repo {repo_id}: {str(e)}")
            raise

    async def _load_custom_architecture(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, **kwargs):
        """Load models with custom architectures like ti2v"""
        try:
            model_type = getattr(config, 'model_type', None)
            logger.info(f"Loading custom architecture: {model_type}")

            if model_type == 'ti2v':
                # For ti2v models, we need to create a wrapper that can work with our distillation
                return await self._load_ti2v_model(repo_id, config, hf_token, trust_remote_code, **kwargs)
            else:
                # For other custom architectures, try with trust_remote_code
                logger.info(f"Attempting to load custom model {repo_id} with trust_remote_code=True")

                # Try different model classes
                model_classes = [AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM]
                for model_class in model_classes:
                    try:
                        model = model_class.from_pretrained(
                            repo_id,
                            config=config,
                            trust_remote_code=True,  # Force trust_remote_code for custom architectures
                            token=hf_token,
                            low_cpu_mem_usage=True,
                            torch_dtype=torch.float32
                        )
                        logger.info(f"Successfully loaded {repo_id} with {model_class.__name__}")
                        return model
                    except Exception as e:
                        logger.warning(f"{model_class.__name__} failed for {repo_id}: {e}")
                        continue

                raise ValueError(f"All loading strategies failed for custom architecture {model_type}")
        except Exception as e:
            logger.error(f"Error loading custom architecture: {e}")
            raise

    async def _load_ti2v_model(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, **kwargs):
        """Special handling for ti2v (Text-to-Image/Video) models"""
        try:
            logger.info(f"Loading ti2v model: {repo_id}")

            # For ti2v models, we create a wrapper that extracts text features so they
            # can be used in knowledge distillation. trust_remote_code=True is required
            # for custom architectures.
            model = AutoModel.from_pretrained(
                repo_id,
                config=config,
                trust_remote_code=True,
                token=hf_token,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            )

            # Create a wrapper that can extract features for distillation
            class TI2VWrapper(torch.nn.Module):
                def __init__(self, base_model):
                    super().__init__()
                    self.base_model = base_model
                    self.config = base_model.config

                def forward(self, input_ids=None, attention_mask=None, **kwargs):
                    # Extract text encoder features if available
                    if hasattr(self.base_model, 'text_encoder'):
                        return self.base_model.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
                    elif hasattr(self.base_model, 'encoder'):
                        return self.base_model.encoder(input_ids=input_ids, attention_mask=attention_mask)
                    else:
                        # Fallback: try to get some meaningful representation
                        return self.base_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

            wrapped_model = TI2VWrapper(model)
            logger.info(f"Successfully wrapped ti2v model: {repo_id}")
            return wrapped_model
        except Exception as e:
            logger.error(f"Error loading ti2v model {repo_id}: {e}")
            raise

    async def _load_safetensors(self, file_path: Path, **kwargs) -> Dict[str, Any]:
        """Load model from Safetensors format"""
        try:
            # Load tensors
            tensors = {}
            with safe_open(file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)

            # Try to reconstruct a model architecture from the raw tensors
            model = self._reconstruct_model_from_tensors(tensors)

            return {
                'model': model,
                'tensors': tensors,
                'source': str(file_path),
                'format': 'safetensors',
                'parameters': sum(tensor.numel() for tensor in tensors.values()),
                'tensor_keys': list(tensors.keys())
            }
        except Exception as e:
            logger.error(f"Error loading Safetensors file {file_path}: {str(e)}")
            raise

    async def _load_pytorch(self, file_path: Path, **kwargs) -> Dict[str, Any]:
        """Load PyTorch model"""
        try:
            # Load checkpoint
            checkpoint = torch.load(file_path, map_location='cpu')

            # Extract model and metadata
            if isinstance(checkpoint, dict):
                model = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
                metadata = {k: v for k, v in checkpoint.items() if k not in ['model', 'state_dict']}
            else:
                model = checkpoint
                metadata = {}

            # Count parameters, skipping non-tensor entries (epoch counters, optimizer state, etc.)
            parameters = None
            if isinstance(model, dict):
                parameters = sum(
                    tensor.numel() for tensor in model.values() if isinstance(tensor, torch.Tensor)
                )

            return {
                'model': model,
                'metadata': metadata,
                'source': str(file_path),
                'format': 'pytorch',
                'parameters': parameters
            }
        except Exception as e:
            logger.error(f"Error loading PyTorch file {file_path}: {str(e)}")
            raise

    def _reconstruct_model_from_tensors(self, tensors: Dict[str, torch.Tensor]) -> nn.Module:
        """
        Attempt to reconstruct a PyTorch model from a tensor dictionary.

        This is a simplified implementation - in practice, this would need
        more sophisticated architecture detection.
        """
        class GenericModel(nn.Module):
            def __init__(self, tensors):
                super().__init__()
                self.tensors = nn.ParameterDict()
                for name, tensor in tensors.items():
                    self.tensors[name.replace('.', '_')] = nn.Parameter(tensor)

            def forward(self, x):
                # Placeholder forward pass
                return x

        return GenericModel(tensors)

    async def _get_safetensors_info(self, file_path: str) -> Dict[str, Any]:
        """Get information from a Safetensors file"""
        try:
            info = {}
            with safe_open(file_path, framework="pt", device="cpu") as f:
                keys = list(f.keys())
                info['tensor_count'] = len(keys)
                info['tensor_keys'] = keys[:10]  # First 10 keys

                # Estimate parameters
                total_params = 0
                for key in keys:
                    tensor = f.get_tensor(key)
                    total_params += tensor.numel()
                info['parameters'] = total_params
            return info
        except Exception as e:
            logger.warning(f"Error getting Safetensors info: {e}")
            return {}

    async def _get_pytorch_info(self, file_path: str) -> Dict[str, Any]:
        """Get information from a PyTorch checkpoint file"""
        try:
            checkpoint = torch.load(file_path, map_location='cpu')
            info = {}
            if isinstance(checkpoint, dict):
                info['keys'] = list(checkpoint.keys())

                # Look for model/state_dict
                model_data = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
                if isinstance(model_data, dict):
                    # Skip non-tensor entries that checkpoints often contain
                    info['parameters'] = sum(
                        t.numel() for t in model_data.values() if isinstance(t, torch.Tensor)
                    )
                    info['layer_count'] = len(model_data)
            return info
        except Exception as e:
            logger.warning(f"Error getting PyTorch info: {e}")
            return {}

    async def _get_huggingface_info(self, repo_id: str) -> Dict[str, Any]:
        """Get information from a Hugging Face repository"""
        try:
            hf_token = (
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            config = AutoConfig.from_pretrained(repo_id, token=hf_token)
            info = {
                'architecture': config.architectures[0] if hasattr(config, 'architectures') and config.architectures else None,
                'model_type': getattr(config, 'model_type', None),
                'hidden_size': getattr(config, 'hidden_size', None),
                'num_layers': getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', None)),
                'vocab_size': getattr(config, 'vocab_size', None)
            }
            return info
        except Exception as e:
            logger.warning(f"Error getting Hugging Face info: {e}")
            return {}

    async def _try_specific_model_classes(self, repo_id: str, config, hf_token: str, trust_remote_code: bool, kwargs: Dict[str, Any]):
        """Try loading with specific model classes for known architectures"""
        from transformers import (
            AutoModelForCausalLM, AutoModelForSequenceClassification,
            AutoModelForTokenClassification, AutoModelForQuestionAnswering,
            AutoModelForMaskedLM, AutoModelForImageClassification,
            AutoModelForObjectDetection, AutoModelForSemanticSegmentation,
            AutoModelForImageSegmentation, AutoModelForDepthEstimation,
            AutoModelForZeroShotImageClassification
        )

        # Map model types to appropriate AutoModel classes
        model_type = getattr(config, 'model_type', '').lower()
        architecture = getattr(config, 'architectures', [])
        arch_str = str(architecture).lower() if architecture else ''

        model_classes_to_try = []

        # Determine appropriate model classes based on model type and architecture
        if 'siglip' in model_type or 'siglip' in arch_str:
            # SigLIP models - try vision-related classes
            model_classes_to_try = [
                AutoModelForImageClassification,
                AutoModelForZeroShotImageClassification,
                AutoModel
            ]
        elif 'clip' in model_type or 'clip' in arch_str:
            model_classes_to_try = [AutoModelForZeroShotImageClassification, AutoModel]
        elif 'vit' in model_type or 'vision' in model_type:
            model_classes_to_try = [AutoModelForImageClassification, AutoModel]
        elif 'bert' in model_type or 'roberta' in model_type:
            model_classes_to_try = [AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel]
        elif 'gemma' in model_type or 'gemma' in arch_str:
            # Gemma models (including Gemma 3) - try causal LM classes
            model_classes_to_try = [AutoModelForCausalLM, AutoModel]
        elif 'gpt' in model_type or 'llama' in model_type:
            model_classes_to_try = [AutoModelForCausalLM, AutoModel]
        else:
            # Generic fallback
            model_classes_to_try = [
                AutoModelForCausalLM,  # Try causal LM first for newer models
                AutoModelForSequenceClassification,
                AutoModelForImageClassification,
                AutoModel
            ]

        # Try each model class
        for model_class in model_classes_to_try:
            try:
                logger.info(f"Trying {repo_id} with {model_class.__name__}")
                model = model_class.from_pretrained(
                    repo_id,
                    config=config,
                    torch_dtype=kwargs.get('torch_dtype', torch.float32),
                    trust_remote_code=trust_remote_code,
                    token=hf_token,
                    low_cpu_mem_usage=True
                )
                logger.info(f"Successfully loaded {repo_id} with {model_class.__name__}")
                return model
            except Exception as e:
                logger.debug(f"{model_class.__name__} failed for {repo_id}: {e}")
                continue

        return None

    async def load_trained_student(self, model_path: str) -> Dict[str, Any]:
        """Load a previously trained student model for retraining"""
        try:
            # Check if it's a Hugging Face model (e.g. "organization/model-name")
            if '/' in model_path and not Path(model_path).exists():
                # This is likely a Hugging Face repository
                return await self._load_student_from_huggingface(model_path)

            # Local model path
            model_dir = Path(model_path)

            # Check whether it's a trained student model
            config_path = model_dir / "config.json"
            if not config_path.exists():
                # Try the alternative naming scheme "<weights name>_config.json"
                safetensors_files = list(model_dir.glob("*.safetensors"))
                if safetensors_files:
                    config_path = safetensors_files[0].with_name(safetensors_files[0].stem + "_config.json")
                if not config_path.exists():
                    raise ValueError("No configuration file found for student model")

            # Load configuration
            with open(config_path, 'r') as f:
                config = json.load(f)

            # Verify it's a student model
            if not config.get('is_student_model', False):
                raise ValueError("This is not a trained student model")

            # Load training history
            history_path = model_dir / "training_history.json"
            if not history_path.exists():
                # Try the alternative naming scheme "<weights name>_training_history.json"
                safetensors_files = list(model_dir.glob("*.safetensors"))
                if safetensors_files:
                    history_path = safetensors_files[0].with_name(
                        safetensors_files[0].stem + "_training_history.json"
                    )

            training_history = {}
            if history_path.exists():
                with open(history_path, 'r') as f:
                    training_history = json.load(f)

            # Load model weights
            model_file = None
            for ext in ['.safetensors', '.bin', '.pt']:
                potential_file = model_dir / f"student_model{ext}"
                if potential_file.exists():
                    model_file = potential_file
                    break

            if not model_file:
                # Look for any model file
                for ext in ['.safetensors', '.bin', '.pt']:
                    files = list(model_dir.glob(f"*{ext}"))
                    if files:
                        model_file = files[0]
                        break

            if not model_file:
                raise ValueError("No model file found")

            return {
                'type': 'trained_student',
                'path': str(model_path),
                'config': config,
                'training_history': training_history,
                'model_file': str(model_file),
                'can_be_retrained': config.get('can_be_retrained', True),
                'original_teachers': training_history.get('retraining_info', {}).get('original_teachers', []),
                'recommended_lr': training_history.get('retraining_info', {}).get('recommended_learning_rate', 1e-5),
                'modalities': config.get('modalities', ['text']),
                'architecture': config.get('architecture', 'unknown')
            }
        except Exception as e:
            logger.error(f"Error loading trained student model: {e}")
            raise

    async def _load_student_from_huggingface(self, repo_id: str) -> Dict[str, Any]:
        """Load a student model from a Hugging Face repository"""
        try:
            # Get HF token
            hf_token = (
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            logger.info(f"Loading student model from Hugging Face: {repo_id}")

            # Load configuration
            config = AutoConfig.from_pretrained(repo_id, token=hf_token)

            # Try to load the model to verify it exists and is accessible
            model = await self._load_from_huggingface(repo_id, token=hf_token)

            # Check whether it's marked as a student model (optional; config objects
            # have no .get(), so use getattr)
            is_student = getattr(config, 'is_student_model', False)

            architectures = getattr(config, 'architectures', None)
            return {
                'type': 'huggingface_student',
                'path': repo_id,
                'config': config.__dict__ if hasattr(config, '__dict__') else {},
                'training_history': {},  # HF models may not have our training history
                'model_file': repo_id,  # For HF models, this is the repo ID
                'can_be_retrained': True,
                'original_teachers': [],  # Unknown for external models
                'recommended_lr': 1e-5,  # Default learning rate
                'modalities': ['text'],  # Default, could be enhanced
                'architecture': architectures[0] if architectures else 'unknown',
                'is_huggingface': True
            }
        except Exception as e:
            logger.error(f"Error loading student model from Hugging Face: {e}")
            raise ValueError(f"Could not load student model from Hugging Face: {str(e)}")

    async def load_trained_student_from_space(self, space_name: str) -> Dict[str, Any]:
        """Load a student model from a Hugging Face Space"""
        try:
            # Get HF token
            hf_token = (
                os.getenv('HF_TOKEN') or
                os.getenv('HUGGINGFACE_TOKEN') or
                os.getenv('HUGGINGFACE_HUB_TOKEN')
            )
            logger.info(f"Loading student model from Hugging Face Space: {space_name}")

            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)

            # List files in the Space to find model files
            try:
                files = api.list_repo_files(space_name, repo_type="space")

                # Look for model files in the models directory
                model_files = [f for f in files if f.startswith('models/') and f.endswith(('.safetensors', '.bin', '.pt'))]
                if not model_files:
                    # Look for model files in the repository root
                    model_files = [f for f in files if f.endswith(('.safetensors', '.bin', '.pt'))]
                if not model_files:
                    raise ValueError(f"No model files found in Space {space_name}")

                # Use the first model file found
                model_file = model_files[0]
                logger.info(f"Found model file in Space: {model_file}")

                # For now, we treat Space models as external HF models.
                # In the future, we could download and cache them locally.
                return {
                    'type': 'space_student',
                    'path': space_name,
                    'config': {},  # Space models may not use our config format
                    'training_history': {},  # Unknown for Space models
                    'model_file': model_file,
                    'can_be_retrained': True,
                    'original_teachers': [],  # Unknown for external models
                    'recommended_lr': 1e-5,  # Default learning rate
                    'modalities': ['text'],  # Default, could be enhanced
                    'architecture': 'unknown',
                    'is_space': True,
                    'space_name': space_name,
                    'available_models': model_files
                }
            except Exception as e:
                logger.error(f"Error accessing Space files: {e}")
                # Fallback: treat it as a regular HF model
                return await self._load_student_from_huggingface(space_name)
        except Exception as e:
            logger.error(f"Error loading student model from Space: {e}")
            raise ValueError(f"Could not load student model from Space: {str(e)}")

    def _estimate_model_size(self, config) -> float:
        """Estimate model size in GB based on configuration"""
        try:
            # Get basic parameters
            hidden_size = getattr(config, 'hidden_size', 768)
            num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'num_layers', 12))
            vocab_size = getattr(config, 'vocab_size', 50000)

            # Very rough estimation: parameter count * 4 bytes (float32)
            embedding_params = vocab_size * hidden_size
            layer_params = num_layers * (hidden_size * hidden_size * 4)  # Simplified
            total_params = embedding_params + layer_params

            # Convert to GB (4 bytes per parameter for float32)
            size_gb = (total_params * 4) / (1024 ** 3)
            return max(size_gb, 0.1)  # Minimum 0.1GB
        except Exception:
            return 1.0  # Default 1GB if estimation fails
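
    # Worked example of the estimate above (illustrative numbers, not tied to any specific model):
    # hidden_size=768, num_layers=12, vocab_size=50,000 gives
    #   embedding_params = 50,000 * 768          ~= 38.4M
    #   layer_params     = 12 * (768 * 768 * 4)  ~= 28.3M
    #   total            ~= 66.7M parameters * 4 bytes ~= 0.25 GB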

    def validate_model_compatibility(self, models: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Validate that multiple models are compatible for knowledge distillation

        Args:
            models: List of loaded model dictionaries

        Returns:
            Validation result with compatibility information
        """
        if not models:
            return {'compatible': False, 'reason': 'No models provided'}
        if len(models) < 2:
            return {'compatible': False, 'reason': 'At least 2 models required for distillation'}

        # Check modality compatibility
        modalities = [model.get('modality', 'unknown') for model in models]
        unique_modalities = set(modalities)

        # Allow same-modality or multimodal combinations
        if len(unique_modalities) == 1 and 'unknown' not in unique_modalities:
            compatibility_type = 'same_modality'
        elif 'multimodal' in unique_modalities or len(unique_modalities) > 1:
            compatibility_type = 'cross_modal'
        else:
            return {'compatible': False, 'reason': 'Unknown modalities detected'}

        return {
            'compatible': True,
            'type': compatibility_type,
            'modalities': list(unique_modalities),
            'model_count': len(models),
            'total_parameters': sum(model.get('parameters', 0) for model in models if model.get('parameters'))
        }
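

# Minimal usage sketch. This block is illustrative only: it assumes the module is run
# directly, that network access to the Hugging Face Hub is available, and that the
# small public repo id used here ("prajjwal1/bert-tiny") is reachable.
if __name__ == "__main__":
    async def _demo():
        loader = ModelLoader()
        # Inspect a repo without downloading the full weights
        info = await loader.get_model_info("prajjwal1/bert-tiny")
        print("Model info:", info)
        # Load the model plus its tokenizer/processor and print basic metadata
        loaded = await loader.load_model("prajjwal1/bert-tiny")
        print("Loaded format:", loaded['format'], "- parameters:", loaded['parameters'])

    asyncio.run(_demo())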