# Source: train-modle/src/distillation.py (author: fokan)
# Commit cca1fa9: "Force Space rebuild v2.1.0 with incremental training"
"""
Knowledge Distillation Engine
Implements multi-modal knowledge distillation algorithms for creating new AI models
from multiple pre-trained teacher models across different modalities.
"""
import logging
import asyncio
from typing import Dict, Any, List, Optional, Callable, Union
import math
import time
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from transformers import get_linear_schedule_with_warmup
from safetensors.torch import save_file
logger = logging.getLogger(__name__)

# Teacher checkpoints known to fail (or to need special handling) when loaded
# as distillation teachers, mapped to the hint that is shown to the user.
PROBLEMATIC_MODELS = {
    'deepseek-ai/DeepSeek-V3.1-Base': (
        'Requires GPU with FP8 quantization support. '
        'Try using a smaller model or different hardware.'
    ),
    'Wan-AI/Wan2.2-TI2V-5B': (
        'Uses ti2v architecture. '
        'Will attempt to load with trust_remote_code=True.'
    ),
    # Both Stable Diffusion repositories share the same advice.
    **dict.fromkeys(
        ('stabilityai/stable-diffusion', 'runwayml/stable-diffusion'),
        'Diffusion models require special handling. '
        'Consider using text encoders only.',
    ),
}
class RealMultiModalDataset(Dataset):
    """
    Multi-modal dataset backed by a Hugging Face dataset or by synthetic data.

    When ``dataset_name`` is set, the first ``size`` rows of that dataset are
    streamed from the Hub. If loading fails, the stream is empty, or no name
    is given, synthetic samples with learnable sinusoidal patterns are used.

    Each synthetic sample is a dict with one entry per requested modality
    plus ``labels``:

    * ``text``   -> float tensor of shape (512,)
    * ``vision`` -> float tensor of shape (3, 224, 224)
    * ``audio``  -> float tensor of shape (1024,)
    * ``labels`` -> float tensor of shape (1,) holding ``i % 10``
    """

    def __init__(
        self,
        size: int = 1000,
        modalities: Optional[List[str]] = None,
        dataset_name: Optional[str] = None,
        split: str = "train",
    ):
        self.size = size
        # Text + vision when the caller does not choose modalities.
        self.modalities = modalities or ['text', 'vision']
        self.dataset_name = dataset_name
        self.split = split
        self.data = self._load_real_data()

    def _load_real_data(self):
        """
        Return the list of samples for this dataset.

        Streams ``dataset_name`` (split ``split``) when a name is set; on any
        loading error, or when the stream yields no rows, falls back to
        synthetic data.
        """
        if self.dataset_name:
            try:
                # Optional dependency, only needed when a real dataset is requested.
                from datasets import load_dataset
                dataset = load_dataset(self.dataset_name, split=self.split, streaming=True)
                samples = list(dataset.take(self.size))
            except Exception as e:
                logger.warning(f"Failed to load real dataset: {e}, using realistic synthetic data")
            else:
                if samples:
                    return samples
                logger.warning(
                    f"Dataset {self.dataset_name} returned no samples, using realistic synthetic data"
                )
        # Deliberately outside the try: an error while generating synthetic data
        # is a bug and must surface once, not be "recovered" by running it again.
        return self._create_realistic_synthetic_data()

    def _create_realistic_synthetic_data(self):
        """Create ``size`` synthetic samples whose amplitude cycles with ``i % 10``."""
        two_pi = 2 * 3.14159
        data = []
        for i in range(self.size):
            amplitude = (i % 10 + 1) / 10
            # Shared 512-point sinusoid; text and audio are derived from it.
            base_pattern = torch.sin(torch.linspace(0, two_pi, 512)) * amplitude
            noise = torch.randn(512) * 0.1
            item = {}
            if 'text' in self.modalities:
                item['text'] = base_pattern + noise
            if 'vision' in self.modalities:
                # The same sinusoid sampled at the image width, broadcast over
                # channels and rows -> (3, 224, 224), matching the added noise.
                row = torch.sin(torch.linspace(0, two_pi, 224)) * amplitude
                item['vision'] = row.expand(3, 224, 224) + torch.randn(3, 224, 224) * 0.1
            if 'audio' in self.modalities:
                item['audio'] = base_pattern.repeat(2) + torch.randn(1024) * 0.1
            # Labels for supervised objectives.
            item['labels'] = torch.tensor([i % 10], dtype=torch.float32)
            data.append(item)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``; indices past the end wrap around."""
        if not self.data:
            raise IndexError("dataset is empty")
        if idx >= len(self.data):
            idx = idx % len(self.data)
        return self.data[idx]
class MultiModalDataset(RealMultiModalDataset):
    """
    Legacy entry point kept for callers written before RealMultiModalDataset.

    Never forwards a Hugging Face dataset name, so the parent always builds
    its synthetic, pattern-based samples.
    """

    def __init__(self, size: int = 1000, modalities: List[str] = None):
        super().__init__(size, modalities, None)
class StudentModel(nn.Module):
    """
    Configurable multi-modal student network for knowledge distillation.

    Every modality in ``config['modalities']`` that this class knows
    (``text``, ``vision``, ``audio``) gets an encoder that projects its input
    to ``hidden_size`` features. The per-modality features are concatenated
    in ``self.modalities`` order and passed through a fusion MLP producing
    ``output_size`` features.

    Input widths implied by the first encoder layers:
      * ``text``:   (..., 512)
      * ``vision``: (batch, 3, H, W)
      * ``audio``:  (..., 1024)

    Config keys (defaults): ``modalities`` (['text']), ``hidden_size`` (768),
    ``num_layers`` (6), ``output_size`` (768).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        self.modalities = config.get('modalities', ['text'])
        self.hidden_size = config.get('hidden_size', 768)
        self.num_layers = config.get('num_layers', 6)
        self.output_size = config.get('output_size', 768)

        # Attribute names and Sequential indices below define the state-dict keys
        # of saved checkpoints, so the layer layout must stay stable.
        self.encoders = nn.ModuleDict()
        if 'text' in self.modalities:
            self.encoders['text'] = nn.Sequential(
                nn.Linear(512, self.hidden_size),
                nn.ReLU(),
                *self._hidden_blocks(),
            )
        if 'vision' in self.modalities:
            self.encoders['vision'] = nn.Sequential(
                nn.Conv2d(3, 64, 7, stride=2, padding=3),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(64, self.hidden_size),
                *self._hidden_blocks(),
            )
        if 'audio' in self.modalities:
            self.encoders['audio'] = nn.Sequential(
                nn.Linear(1024, self.hidden_size),
                nn.ReLU(),
                *self._hidden_blocks(),
            )

        # Fusion layer over the concatenation of one feature vector per modality.
        self.fusion = nn.Sequential(
            nn.Linear(self.hidden_size * len(self.modalities), self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size, self.output_size),
        )

    def _hidden_blocks(self) -> List[nn.Module]:
        """Return ``num_layers - 1`` fresh (Linear, ReLU, Dropout) blocks."""
        return [
            nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.ReLU(),
                nn.Dropout(0.1),
            )
            for _ in range(self.num_layers - 1)
        ]

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Encode each available modality, fuse, and return ``output_size`` features.

        Args:
            inputs: Mapping of modality name to input tensor; keys that are not
                configured modalities are ignored.

        Raises:
            ValueError: If none of the configured modalities can be encoded.
        """
        encoded = {
            modality: self.encoders[modality](inputs[modality])
            for modality in self.modalities
            if modality in inputs and modality in self.encoders
        }
        if not encoded:
            raise ValueError("No valid modality inputs found")

        # The fusion layer always expects hidden_size features per configured
        # modality; a modality absent from ``inputs`` (or without an encoder)
        # contributes zeros instead of breaking the matmul shape.
        reference = next(iter(encoded.values()))
        features = [
            encoded[modality] if modality in encoded
            else reference.new_zeros((*reference.shape[:-1], self.hidden_size))
            for modality in self.modalities
        ]
        fused = features[0] if len(features) == 1 else torch.cat(features, dim=-1)
        return self.fusion(fused)
class KnowledgeDistillationTrainer:
    """
    Multi-modal knowledge distillation trainer.

    Creates a ``StudentModel`` sized from a set of teachers, trains it to match
    the averaged teacher outputs, and saves the result with HF-style metadata.

    Teachers are plain dicts. The keys read here are ``model`` (the teacher
    callable, may be ``None``), ``modality`` (``text``/``vision``/``audio``),
    ``name`` and ``parameters`` (parameter count).
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Average loss of the most recent ``train`` run; written to training_history.json.
        self.final_loss: Optional[float] = None
        logger.info(f"Using device: {self.device}")

    async def create_student_model(
        self,
        teacher_models: List[Dict[str, Any]],
        config: Dict[str, Any]
    ) -> 'StudentModel':
        """
        Create a student model based on the teachers and ``config``.

        Args:
            teacher_models: Teacher dicts; ``modality`` (ignored when ``'unknown'``)
                and ``parameters`` are read.
            config: Optional ``hidden_size`` (768), ``num_layers`` (6) and
                ``output_size`` (768).

        Returns:
            An initialized ``StudentModel`` on ``self.device``.
        """
        try:
            # Analyze teacher models to determine the student architecture.
            modalities = set()
            total_params = 0
            for teacher in teacher_models:
                modality = teacher.get('modality', 'unknown')
                if modality != 'unknown':
                    modalities.add(modality)
                # ``or 0`` tolerates teachers that report ``parameters: None``.
                total_params += teacher.get('parameters') or 0

            student_config = {
                # Sorted: iteration order of a set of strings varies between
                # interpreter runs, and the fusion layer concatenates features in
                # this order, so an unsorted list could rewire reloaded students.
                'modalities': sorted(modalities) if modalities else ['text'],
                'hidden_size': config.get('hidden_size', 768),
                'num_layers': config.get('num_layers', 6),
                'output_size': config.get('output_size', 768)
            }

            # Adjust size based on teacher complexity.
            if total_params > 1e9:  # Large teachers
                student_config['hidden_size'] = min(1024, student_config['hidden_size'])
                student_config['num_layers'] = min(12, student_config['num_layers'])
            elif total_params < 1e8:  # Small teachers
                student_config['hidden_size'] = max(256, student_config['hidden_size'])
                student_config['num_layers'] = max(3, student_config['num_layers'])

            student = StudentModel(student_config)
            student.to(self.device)
            logger.info(f"Created student model with config: {student_config}")
            logger.info(f"Student parameters: {sum(p.numel() for p in student.parameters()):,}")
            return student
        except Exception as e:
            logger.error(f"Error creating student model: {str(e)}")
            raise

    async def train(
        self,
        student_model: 'StudentModel',
        teacher_models: List[Dict[str, Any]],
        training_params: Dict[str, Any],
        progress_callback: Optional[Callable] = None
    ) -> 'StudentModel':
        """
        Train ``student_model`` by distilling from ``teacher_models``.

        Args:
            student_model: Student to train; moved to ``self.device``.
            teacher_models: Teacher dicts (see the class docstring).
            training_params: Optional ``max_steps`` (1000), ``learning_rate``
                (1e-4), ``batch_size`` (8), ``temperature`` (4.0), ``alpha``
                (0.7, weight of the KL term vs. the feature MSE term) and
                ``warmup_steps`` (``max_steps // 10``).
            progress_callback: Called every 10 steps as
                ``callback(step, max_steps, avg_loss, {'learning_rate', 'temperature'})``.
                May be a plain function or a coroutine function.

        Returns:
            The same student model, trained in place. The average loss is also
            stored in ``self.final_loss``.
        """
        import inspect

        try:
            max_steps = training_params.get('max_steps', 1000)
            learning_rate = training_params.get('learning_rate', 1e-4)
            batch_size = training_params.get('batch_size', 8)
            temperature = training_params.get('temperature', 4.0)
            alpha = training_params.get('alpha', 0.7)  # Distillation loss weight
            warmup_steps = training_params.get('warmup_steps', max_steps // 10)

            teacher_models_prepared = await self._prepare_teachers(teacher_models)

            # Synthetic data sized so the loader yields exactly max_steps batches.
            modalities = list(student_model.modalities)
            dataset = MultiModalDataset(size=max_steps * batch_size, modalities=modalities)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

            # Batches are moved to self.device, so the student must live there too.
            student_model.to(self.device)
            optimizer = optim.AdamW(student_model.parameters(), lr=learning_rate, weight_decay=0.01)
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
            )

            student_model.train()
            total_loss = 0.0
            step = 0
            for batch in dataloader:
                if step >= max_steps:
                    break
                batch = {k: v.to(self.device) for k, v in batch.items()}

                student_output = student_model(batch)

                teacher_outputs = []
                for teacher_data in teacher_models_prepared:
                    with torch.no_grad():
                        teacher_output = await self._get_teacher_output(teacher_data, batch)
                    teacher_outputs.append(teacher_output)

                distillation_loss = self._calculate_distillation_loss(
                    student_output, teacher_outputs, temperature, alpha
                )

                optimizer.zero_grad()
                distillation_loss.backward()
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                total_loss += distillation_loss.item()
                step += 1

                if progress_callback and step % 10 == 0:
                    result = progress_callback(step, max_steps, total_loss / step, {
                        'learning_rate': scheduler.get_last_lr()[0],
                        'temperature': temperature
                    })
                    # Support both coroutine functions and plain callables.
                    if inspect.isawaitable(result):
                        await result

                if step % 100 == 0:
                    logger.info(f"Step {step}/{max_steps}, Loss: {total_loss / step:.4f}")

            # Average over the steps actually run (may be fewer than max_steps).
            self.final_loss = total_loss / max(step, 1)
            logger.info(f"Training completed. Final loss: {self.final_loss:.4f}")
            return student_model
        except Exception as e:
            logger.error(f"Error during training: {str(e)}")
            raise

    async def _prepare_teachers(self, teacher_models: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Put teacher models in eval mode on ``self.device``.

        Teachers whose ``model`` is ``None`` are kept; ``_get_teacher_output``
        substitutes synthetic outputs for them.
        """
        prepared = []
        for teacher_data in teacher_models:
            model = teacher_data.get('model')
            if model is not None:
                if hasattr(model, 'eval'):
                    model.eval()
                if hasattr(model, 'to'):
                    model.to(self.device)
            prepared.append(teacher_data)
        return prepared

    @staticmethod
    def _batch_size(batch: Dict[str, torch.Tensor]) -> int:
        """Leading dimension of the first tensor in ``batch`` (1 for an empty batch)."""
        return next(iter(batch.values())).size(0) if batch else 1

    async def _get_teacher_output(
        self,
        teacher_data: Dict[str, Any],
        batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Return a 2-D (batch, features) output for one teacher on ``batch``.

        Falls back to a synthetic output when the teacher has no model, when
        the batch lacks the teacher's modality, or when the call fails.
        """
        # Bound before the try so the except handler can always use them.
        model_name = 'unknown'
        modality = 'text'
        try:
            model = teacher_data.get('model')
            modality = teacher_data.get('modality', 'text')
            model_name = teacher_data.get('name', 'unknown')
            logger.debug(f"Getting output from teacher model: {model_name} (modality: {modality})")

            batch_size = self._batch_size(batch)
            if model is None:
                logger.warning(f"Teacher model {model_name} is None, using synthetic output")
                return self._create_synthetic_teacher_output(batch_size, modality)

            processors = {
                'text': self._process_text_model,
                'vision': self._process_vision_model,
                'audio': self._process_audio_model,
            }
            processor = processors.get(modality)
            if processor is not None and modality in batch:
                output = processor(model, batch[modality], model_name)
            else:
                logger.warning(f"No matching modality for {model_name}, using synthetic output")
                output = self._create_synthetic_teacher_output(batch_size, modality)

            # Normalize to (batch_size, features).
            if output.dim() > 2:
                output = output.view(output.size(0), -1)
            elif output.dim() == 1:
                output = output.unsqueeze(0)
            return output
        except Exception as e:
            logger.error(f"Error getting teacher output from {model_name}: {e}")
            return self._create_synthetic_teacher_output(self._batch_size(batch), modality)

    @staticmethod
    def _extract_features(output: Any) -> torch.Tensor:
        """
        Reduce a teacher's raw output to a single tensor.

        Dict-like outputs use the mean-pooled ``last_hidden_state``, then
        ``pooler_output``, else their first value; tuples use their first
        element; non-tensor array-likes (e.g. numpy arrays returned by some
        ``encode`` methods) are converted with ``torch.as_tensor``.
        """
        if isinstance(output, dict):
            if 'last_hidden_state' in output:
                output = output['last_hidden_state'].mean(dim=1)  # Average pooling
            elif 'pooler_output' in output:
                output = output['pooler_output']
            else:
                output = next(iter(output.values()))
        elif isinstance(output, tuple):
            output = output[0]
        if not isinstance(output, torch.Tensor):
            output = torch.as_tensor(output)
        return output

    def _run_teacher_model(
        self,
        model,
        input_tensor: torch.Tensor,
        model_name: str,
        modality: str,
        unbatched_dim: int,
        use_encode: bool = False,
    ) -> torch.Tensor:
        """
        Call a teacher on ``input_tensor`` and return its features on ``self.device``.

        Args:
            unbatched_dim: Rank of a single unbatched sample; inputs of exactly
                this rank get a batch dimension added.
            use_encode: Prefer ``model.encode`` when present (sentence-transformers
                style text models).

        On failure, logs a warning and returns a synthetic output for ``modality``.
        """
        try:
            if input_tensor.dim() == unbatched_dim:
                input_tensor = input_tensor.unsqueeze(0)
            with torch.no_grad():
                if use_encode and hasattr(model, 'encode'):
                    output = model.encode(input_tensor)
                elif hasattr(model, 'forward') or callable(model):
                    output = model(input_tensor)
                else:
                    raise ValueError(f"{modality.capitalize()} model {model_name} is not callable")
            return self._extract_features(output).to(self.device)
        except Exception as e:
            logger.warning(f"Failed to process {modality} model {model_name}: {e}")
            return self._create_synthetic_teacher_output(input_tensor.size(0), modality)

    def _process_text_model(self, model, input_tensor: torch.Tensor, model_name: str) -> torch.Tensor:
        """Run a text teacher on (features,) or (batch, features) input."""
        return self._run_teacher_model(model, input_tensor, model_name, 'text', unbatched_dim=1, use_encode=True)

    def _process_vision_model(self, model, input_tensor: torch.Tensor, model_name: str) -> torch.Tensor:
        """Run a vision teacher on (C, H, W) or (batch, C, H, W) input."""
        return self._run_teacher_model(model, input_tensor, model_name, 'vision', unbatched_dim=3)

    def _process_audio_model(self, model, input_tensor: torch.Tensor, model_name: str) -> torch.Tensor:
        """Run an audio teacher on (samples,) or (batch, samples) input."""
        return self._run_teacher_model(model, input_tensor, model_name, 'audio', unbatched_dim=1)

    def _create_synthetic_teacher_output(self, batch_size: int, modality: str) -> torch.Tensor:
        """Create a (batch_size, 768) stand-in teacher output with a modality-specific pattern."""
        if modality == 'text':
            base = torch.linspace(0, 1, 768).unsqueeze(0).repeat(batch_size, 1)
            noise = torch.randn(batch_size, 768) * 0.1
            output = base + noise
        elif modality == 'vision':
            base = torch.linspace(0, 1, 768).unsqueeze(0).repeat(batch_size, 1)
            noise = torch.randn(batch_size, 768) * 0.15
            output = base * 0.8 + noise
        elif modality == 'audio':
            base = torch.sin(torch.linspace(0, 10, 768)).unsqueeze(0).repeat(batch_size, 1)
            noise = torch.randn(batch_size, 768) * 0.1
            output = base + noise
        else:
            output = torch.randn(batch_size, 768)
        return output.to(self.device)

    def _calculate_distillation_loss(
        self,
        student_output: torch.Tensor,
        teacher_outputs: List[torch.Tensor],
        temperature: float,
        alpha: float
    ) -> torch.Tensor:
        """
        Calculate the knowledge distillation loss.

        Teacher outputs are averaged into one ensemble target. Student and
        teachers are truncated to the narrowest feature width among them, and
        the loss is ``alpha * KL(softmax(t/T) || softmax(s/T)) + (1 - alpha) * MSE(s, t)``.

        Args:
            student_output: Student output, (batch, features).
            teacher_outputs: Teacher outputs, (batch, features_i) each.
            temperature: Softmax temperature.
            alpha: Weight of the KL term.

        Returns:
            Scalar loss (a zero tensor requiring grad when there are no teachers).
        """
        if not teacher_outputs:
            return torch.tensor(0.0, device=self.device, requires_grad=True)

        # Teachers may emit different widths/dtypes; align them before stacking.
        min_dim = min([student_output.size(-1)] + [t.size(-1) for t in teacher_outputs])
        aligned = [
            t[..., :min_dim].to(device=student_output.device, dtype=student_output.dtype)
            for t in teacher_outputs
        ]
        teacher_logits = torch.stack(aligned).mean(dim=0)
        student_logits = student_output[..., :min_dim]

        student_soft = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
        distillation_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')

        # Feature matching term.
        feature_loss = F.mse_loss(student_logits, teacher_logits)
        return alpha * distillation_loss + (1 - alpha) * feature_loss

    @staticmethod
    def _model_py_source() -> str:
        """
        Return the source of the ``model.py`` written next to saved weights.

        It re-declares the encoder/fusion layout of ``StudentModel`` (same
        attribute names and Sequential indices), so the saved safetensors keys
        match when the folder is loaded with ``trust_remote_code=True``. Keep it
        in sync with ``StudentModel``.
        """
        import textwrap

        return textwrap.dedent('''\
            """
            Custom Student Model for Knowledge Distillation.

            The layer layout mirrors the student network used during distillation
            so the saved weights load by name.
            """
            import torch
            import torch.nn as nn
            from transformers import PretrainedConfig, PreTrainedModel


            class StudentModelConfig(PretrainedConfig):
                model_type = "distilled_student"

                def __init__(
                    self,
                    modalities=None,
                    hidden_size=768,
                    num_layers=6,
                    output_size=768,
                    **kwargs,
                ):
                    super().__init__(**kwargs)
                    self.modalities = list(modalities) if modalities else ["text"]
                    self.hidden_size = hidden_size
                    self.num_layers = num_layers
                    self.output_size = output_size


            def _hidden_blocks(hidden_size, count):
                """Stack of (Linear, ReLU, Dropout) blocks shared by every encoder."""
                return [
                    nn.Sequential(
                        nn.Linear(hidden_size, hidden_size),
                        nn.ReLU(),
                        nn.Dropout(0.1),
                    )
                    for _ in range(count)
                ]


            class StudentModel(PreTrainedModel):
                config_class = StudentModelConfig

                def __init__(self, config):
                    super().__init__(config)
                    self.modalities = list(config.modalities)
                    self.hidden_size = config.hidden_size
                    self.num_layers = config.num_layers
                    self.output_size = config.output_size
                    extra_blocks = self.num_layers - 1

                    self.encoders = nn.ModuleDict()
                    if "text" in self.modalities:
                        self.encoders["text"] = nn.Sequential(
                            nn.Linear(512, self.hidden_size),
                            nn.ReLU(),
                            *_hidden_blocks(self.hidden_size, extra_blocks),
                        )
                    if "vision" in self.modalities:
                        self.encoders["vision"] = nn.Sequential(
                            nn.Conv2d(3, 64, 7, stride=2, padding=3),
                            nn.ReLU(),
                            nn.AdaptiveAvgPool2d((1, 1)),
                            nn.Flatten(),
                            nn.Linear(64, self.hidden_size),
                            *_hidden_blocks(self.hidden_size, extra_blocks),
                        )
                    if "audio" in self.modalities:
                        self.encoders["audio"] = nn.Sequential(
                            nn.Linear(1024, self.hidden_size),
                            nn.ReLU(),
                            *_hidden_blocks(self.hidden_size, extra_blocks),
                        )
                    self.fusion = nn.Sequential(
                        nn.Linear(self.hidden_size * len(self.modalities), self.hidden_size),
                        nn.ReLU(),
                        nn.Dropout(0.1),
                        nn.Linear(self.hidden_size, self.output_size),
                    )
                    self.post_init()

                def forward(self, text=None, vision=None, audio=None, **kwargs):
                    inputs = {"text": text, "vision": vision, "audio": audio}
                    encoded = {
                        m: self.encoders[m](inputs[m])
                        for m in self.modalities
                        if m in self.encoders and inputs.get(m) is not None
                    }
                    if not encoded:
                        raise ValueError("No valid modality inputs found")
                    reference = next(iter(encoded.values()))
                    features = [
                        encoded[m] if m in encoded
                        else reference.new_zeros((*reference.shape[:-1], self.hidden_size))
                        for m in self.modalities
                    ]
                    pooled = self.fusion(torch.cat(features, dim=-1))
                    return {"pooler_output": pooled}
        ''')

    async def save_model(self, model: 'StudentModel', save_path: str, training_metadata: Dict[str, Any] = None) -> None:
        """
        Save a trained model plus the metadata files needed to reload or retrain it.

        Writes the safetensors weights to ``save_path`` and, in the same
        directory, ``config.json``, ``model.py`` (architecture for
        ``trust_remote_code`` loading), ``training_history.json`` and ``README.md``.

        Args:
            model: Trained student model.
            save_path: Path of the weights file (should be a ``.safetensors`` file).
            training_metadata: Optional session info; reads ``session_id``,
                ``teacher_models``, ``strategy`` and ``training_params``.

        Raises:
            Re-raises (after logging) any error from writing the files.
        """
        try:
            from datetime import datetime
            import json

            save_path = Path(save_path)
            save_dir = save_path.parent
            save_dir.mkdir(parents=True, exist_ok=True)

            # One timestamp shared by every file written in this call.
            created_at = datetime.now().isoformat()
            metadata = training_metadata or {}
            training_params = metadata.get('training_params') or {}
            teacher_models = metadata.get('teacher_models') or []
            strategy = metadata.get('strategy', 'ensemble')
            architecture = str(type(model).__name__)
            modalities = list(model.modalities) if hasattr(model, 'modalities') else ["text"]
            hidden_size = getattr(model, 'hidden_size', 768)
            num_layers = getattr(model, 'num_layers', 12)

            # safetensors needs contiguous CPU tensors.
            cpu_state_dict = {key: tensor.cpu().contiguous() for key, tensor in model.state_dict().items()}
            save_file(cpu_state_dict, str(save_path))

            model_config = {
                "architectures": [architecture],
                "model_type": "distilled_student",
                "hidden_size": hidden_size,
                "num_layers": num_layers,
                "output_size": getattr(model, 'output_size', 768),
                "num_hidden_layers": num_layers,
                "num_attention_heads": getattr(model, 'num_attention_heads', 12),
                "intermediate_size": getattr(model, 'intermediate_size', 3072),
                "vocab_size": getattr(model, 'vocab_size', 30522),
                "max_position_embeddings": getattr(model, 'max_position_embeddings', 512),
                "modalities": modalities,
                "torch_dtype": "float32",
                "transformers_version": "4.45.2",
                "created_at": created_at,
                "framework": "pytorch",
                "can_be_retrained": True,
                "is_student_model": True,
                "supports_incremental_training": True,
                "auto_map": {
                    # "distilled_student" is not a registered model_type, so the
                    # config class must be remapped to model.py as well.
                    "AutoConfig": "model.StudentModelConfig",
                    "AutoModel": "model.StudentModel"
                }
            }
            # The student's own config dict (modalities, sizes) takes precedence.
            model_settings = getattr(model, 'config', None)
            if isinstance(model_settings, dict) and model_settings:
                model_config.update(model_settings)
            with open(save_dir / "config.json", 'w', encoding='utf-8') as f:
                json.dump(model_config, f, indent=2)

            with open(save_dir / "model.py", 'w', encoding='utf-8') as f:
                f.write(self._model_py_source())

            training_history = {
                "model_info": {
                    "type": "student",
                    "architecture": architecture,
                    "modalities": modalities,
                    "hidden_size": hidden_size,
                    "num_layers": num_layers
                },
                "training_sessions": [
                    {
                        "session_id": metadata.get('session_id'),
                        "timestamp": created_at,
                        "teacher_models": teacher_models,
                        "distillation_strategy": strategy,
                        "training_params": training_params,
                        "final_loss": getattr(self, 'final_loss', None)
                    }
                ],
                "retraining_info": {
                    "can_be_used_as_student": True,
                    "can_accept_new_teachers": True,
                    "original_teachers": teacher_models,
                    "recommended_learning_rate": (
                        training_params.get('learning_rate', 1e-4) * 0.1 if metadata else 1e-5
                    ),
                    "supports_teacher_addition": True
                }
            }
            with open(save_dir / "training_history.json", 'w', encoding='utf-8') as f:
                json.dump(training_history, f, indent=2)

            # Example call for the README, one keyword argument per modality.
            example_shapes = {'text': '(1, 512)', 'vision': '(1, 3, 224, 224)', 'audio': '(1, 1024)'}
            example_args = ", ".join(
                f"{m}=torch.randn{example_shapes[m]}" for m in modalities if m in example_shapes
            ) or "text=torch.randn(1, 512)"
            weights_note = []
            if save_path.name != "model.safetensors":
                # transformers' from_pretrained looks for weights named model.safetensors.
                weights_note = [
                    "",
                    f"Note: `from_pretrained` expects the weights file to be named "
                    f"`model.safetensors`; rename `{save_path.name}` before loading with transformers.",
                ]
            readme_lines = [
                "---",
                "license: apache-2.0",
                "tags:",
                "- knowledge-distillation",
                "- pytorch",
                "- transformers",
                "- student-model",
                f"base_model: {teacher_models[0] if teacher_models else 'unknown'}",
                "---",
                "# Distilled Student Model",
                "",
                "This is a student model created through knowledge distillation.",
                "",
                "## Model Details",
                "",
                f"- **Architecture**: {architecture}",
                f"- **Hidden Size**: {hidden_size}",
                f"- **Number of Layers**: {num_layers}",
                f"- **Modalities**: {modalities}",
                f"- **Created**: {created_at}",
                "",
                "## Teacher Models",
                "",
                *[f"- {teacher}" for teacher in teacher_models],
                "",
                "## Training Details",
                "",
                f"- **Strategy**: {strategy}",
                f"- **Training Steps**: {training_params.get('max_steps', 'unknown')}",
                f"- **Learning Rate**: {training_params.get('learning_rate', 'unknown')}",
                "",
                "## Usage",
                "",
                "```python",
                "import torch",
                "from transformers import AutoModel, AutoConfig",
                "# Load the model (model.py defines the architecture, hence trust_remote_code)",
                'model = AutoModel.from_pretrained("path/to/model", trust_remote_code=True)',
                'config = AutoConfig.from_pretrained("path/to/model", trust_remote_code=True)',
                "# Each modality is passed as a keyword argument",
                f"outputs = model({example_args})",
                "```",
                "",
                "## Retraining",
                "",
                "This model can be used as a student model for incremental training:",
                "",
                "```python",
                "# Load as existing student for further distillation",
                'existing_student = "path/to/this/model"',
                "# Add new teachers and continue training",
                "```",
                "",
                "## Files",
                "",
                f"- `{save_path.name}`: Model weights",
                "- `config.json`: Model configuration",
                "- `model.py`: Custom model architecture",
                "- `training_history.json`: Complete training history",
                "- `README.md`: This file",
                *weights_note,
            ]
            with open(save_dir / "README.md", 'w', encoding='utf-8') as f:
                f.write("\n".join(readme_lines) + "\n")

            logger.info(f"Complete model package saved to {save_dir}")
        except Exception as e:
            logger.error(f"Error saving model: {str(e)}")
            raise

    def _is_problematic_model(self, model_path: str) -> bool:
        """Return True if ``model_path`` is listed in PROBLEMATIC_MODELS."""
        return model_path in PROBLEMATIC_MODELS

    def _get_model_error_message(self, model_path: str) -> str:
        """Return the user-facing hint for a problematic model (generic text otherwise)."""
        return PROBLEMATIC_MODELS.get(model_path, "Unknown compatibility issue")

    def _should_retry_with_trust_remote_code(self, model_path: str, error_msg: str) -> bool:
        """Return True if ``error_msg`` suggests a retry with ``trust_remote_code=True``.

        Matching is case-insensitive; ``model_path`` is currently unused.
        """
        trust_indicators = [
            'ti2v', 'does not recognize this architecture',
            'trust_remote_code', 'custom architecture'
        ]
        return any(indicator in error_msg.lower() for indicator in trust_indicators)