Spaces:
Running
Running
""" | |
Multi-Modal Knowledge Distillation Web Application | |
A FastAPI-based web application for creating new AI models through knowledge distillation | |
from multiple pre-trained models across different modalities. | |
""" | |
import os | |
import asyncio | |
import logging | |
import uuid | |
from typing import List, Dict, Any, Optional, Union | |
from pathlib import Path | |
import json | |
import shutil | |
from datetime import datetime | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect, Request | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel, Field | |
import uvicorn | |
from src.model_loader import ModelLoader | |
from src.distillation import KnowledgeDistillationTrainer | |
from src.utils import setup_logging, validate_file, cleanup_temp_files, get_system_info | |
# Import new core components | |
from src.core.memory_manager import AdvancedMemoryManager | |
from src.core.chunk_loader import AdvancedChunkLoader | |
from src.core.cpu_optimizer import CPUOptimizer | |
from src.core.token_manager import TokenManager | |
# Import medical components | |
from src.medical.medical_datasets import MedicalDatasetManager | |
from src.medical.dicom_handler import DicomHandler | |
from src.medical.medical_preprocessing import MedicalPreprocessor | |
# Import database components | |
from database.database import DatabaseManager | |
from src.database_manager import DatabaseManager as PlatformDatabaseManager | |
from src.models_manager import ModelsManager | |
# Configure logging. If the project's setup helper raises, fall back to the
# standard library's basic configuration so the module still gets a logger.
try:
    setup_logging()
except Exception as setup_error:
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.warning(f"Failed to setup advanced logging: {setup_error}")
else:
    logger = logging.getLogger(__name__)
# Custom JSON encoder for handling Path objects and other non-serializable types | |
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands paths, tensors, arrays and plain objects.

    Conversion rules, tried in order:

    1. ``pathlib.Path``            -> its string form
    2. objects exposing ``detach`` -> ``obj.detach().cpu().numpy().tolist()``
    3. objects exposing ``tolist`` -> ``obj.tolist()``
    4. objects with a ``__dict__`` -> that attribute dictionary
    5. anything else               -> ``json.JSONEncoder.default`` (raises ``TypeError``)
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        Raises:
            TypeError: if none of the conversion rules applies.
        """
        if isinstance(obj, Path):
            return str(obj)
        # The tensor/array checks must come before the ``__dict__`` check:
        # such objects may also carry instance attributes, and would otherwise
        # be emitted as their attribute dictionary instead of their values.
        if hasattr(obj, 'detach'):
            return obj.detach().cpu().numpy().tolist()
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        return super().default(obj)
def safe_json_serialize(data):
    """Return a JSON-safe copy of *data*.

    The value is round-tripped through ``json.dumps``/``json.loads`` using
    ``CustomJSONEncoder``. If that fails as a whole:

    * for a dict, each value is converted on its own and any value that still
      cannot be encoded is replaced by ``str(value)``; keys that JSON cannot
      represent are replaced by ``str(key)``;
    * for anything else, ``str(data)`` is returned.
    """
    try:
        return json.loads(json.dumps(data, cls=CustomJSONEncoder))
    except Exception as e:
        logger.warning(f"Failed to serialize data: {e}")

    if not isinstance(data, dict):
        return str(data)

    safe_data = {}
    for k, v in data.items():
        # JSON object keys must be str/int/float/bool/None.
        key = k if (k is None or isinstance(k, (str, int, float, bool))) else str(k)
        try:
            # Store the converted value, not the raw one: a raw value that is
            # only encodable through CustomJSONEncoder (e.g. a Path) would
            # still break a plain JSON serializer downstream.
            safe_data[key] = json.loads(json.dumps(v, cls=CustomJSONEncoder))
        except Exception:
            safe_data[key] = str(v)
    return safe_data
def cleanup_training_session(session_id: str):
    """Release everything held for a training session.

    Deletes the session's saved model artifact (if any), removes the session
    from ``training_sessions`` and drops its entry in ``active_connections``.
    Errors are logged and never propagated.

    Args:
        session_id: Identifier of the session to clean up.
    """
    try:
        if session_id in training_sessions:
            session = training_sessions[session_id]
            model_path = session.get("model_path")
            if model_path and Path(model_path).exists():
                try:
                    target = Path(model_path)
                    # The stored path may be a single file or a directory;
                    # shutil.rmtree only works on directories.
                    if target.is_dir():
                        shutil.rmtree(target)
                    else:
                        target.unlink()
                    logger.info(f"Cleaned up model files for session {session_id}")
                except Exception as e:
                    logger.warning(f"Failed to clean up model files: {e}")
            # Remove from active sessions
            del training_sessions[session_id]

        # Remove WebSocket connection if exists
        active_connections.pop(session_id, None)
        logger.info(f"Cleaned up training session: {session_id}")
    except Exception as e:
        logger.error(f"Error cleaning up session {session_id}: {e}")
def cleanup_old_sessions():
    """Remove sessions that finished more than one hour ago.

    A session qualifies when its status is "completed", "failed" or
    "cancelled" and it has an ``end_time``. ``end_time`` is recorded with the
    event loop clock (``loop.time()``), so the age is measured on that same
    clock; comparing it against a wall-clock timestamp would be meaningless.
    Errors are logged and never propagated.
    """
    import time  # local import: only needed for the no-running-loop fallback

    try:
        try:
            current_time = asyncio.get_running_loop().time()
        except RuntimeError:
            # No loop is running; the default asyncio loop clock is monotonic.
            current_time = time.monotonic()

        finished_states = ("completed", "failed", "cancelled")
        max_age_seconds = 3600  # 1 hour

        sessions_to_remove = []
        for session_id, session in training_sessions.items():
            end_time = session.get("end_time")
            if session.get("status", "unknown") not in finished_states or not end_time:
                continue
            if current_time - end_time > max_age_seconds:
                sessions_to_remove.append(session_id)

        # Deletion happens after the scan so the dict is not mutated mid-iteration.
        for session_id in sessions_to_remove:
            cleanup_training_session(session_id)
            logger.info(f"Auto-cleaned old session: {session_id}")
    except Exception as e:
        logger.error(f"Error during automatic cleanup: {e}")
# FastAPI application object; interactive docs are served at /docs and /redoc.
app = FastAPI(
    title="Multi-Modal Knowledge Distillation",
    description="Create new AI models through knowledge distillation from multiple pre-trained models",
    version="2.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): a wildcard origin together with credentials is very permissive;
# confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static assets are served from ./static; HTML templates are read from ./templates.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# In-memory registries, both keyed by session id:
#   training_sessions  -> mutable state dict of each training session
#   active_connections -> WebSocket used to push updates for that session
training_sessions: Dict[str, Dict[str, Any]] = {}
active_connections: Dict[str, WebSocket] = {}
# Startup event to clean old sessions | |
async def startup_event():
    """Run start-up housekeeping: purge stale sessions and log system information.

    Any exception raised along the way is logged and swallowed.
    """
    try:
        logger.info("Starting Multi-Modal Knowledge Distillation Platform")

        # Drop finished sessions that are past their retention window.
        cleanup_old_sessions()

        logger.info("Initializing core components...")

        logger.info(f"System Info: {get_system_info()}")
        logger.info("Application startup completed successfully")
    except Exception as e:
        logger.error(f"Error during startup: {e}")
# Shutdown event to clean up resources | |
async def shutdown_event():
    """Tear down every known training session, then remove temporary files.

    Any exception raised along the way is logged and swallowed.
    """
    try:
        logger.info("Shutting down application...")

        # Snapshot the ids first: cleanup_training_session mutates the registry.
        pending_ids = list(training_sessions.keys())
        for pending_id in pending_ids:
            cleanup_training_session(pending_id)

        cleanup_temp_files()
        logger.info("Application shutdown completed")
    except Exception as e:
        logger.error(f"Error during shutdown: {e}")
# Pydantic models for API | |
class TrainingConfig(BaseModel):
    """Request body describing one knowledge-distillation training run."""

    # --- identification ---
    session_id: str = Field(..., description="Unique session identifier")

    # --- models ---
    teacher_models: List[Union[str, Dict[str, Any]]] = Field(
        ...,
        description="List of teacher model paths/URLs or model configs",
    )
    student_config: Dict[str, Any] = Field(
        default_factory=dict,
        description="Student model configuration",
    )

    # --- training behaviour ---
    training_params: Dict[str, Any] = Field(
        default_factory=dict,
        description="Training parameters",
    )
    distillation_strategy: str = Field(
        default="ensemble",
        description="Distillation strategy",
    )

    # --- Hugging Face access ---
    hf_token: Optional[str] = Field(default=None, description="Hugging Face token")
    trust_remote_code: bool = Field(default=False, description="Trust remote code execution")

    # --- incremental (re)training ---
    existing_student_model: Optional[str] = Field(
        default=None,
        description="Path to existing trained student model for retraining",
    )
    incremental_training: bool = Field(
        default=False,
        description="Whether this is incremental training",
    )
class TrainingStatus(BaseModel):
    """Snapshot of a training session as returned to API clients."""

    session_id: str
    status: str
    progress: float
    current_step: int
    total_steps: int
    # Optional details; absent until the session has recorded them.
    loss: Optional[float] = None
    eta: Optional[str] = None
    message: str = ""
class ModelInfo(BaseModel):
    """Descriptive metadata for a model: name, size, format, modality, architecture."""

    name: str
    size: int
    format: str
    modality: str
    # Not always known, hence optional.
    architecture: Optional[str] = None
class DatabaseInfo(BaseModel):
    """Descriptive metadata for a dataset entry.

    Only ``name`` and ``dataset_id`` are required; every other field has a default.
    The ``*_ar`` fields are presumably Arabic variants of the name and
    description - TODO confirm against callers.
    """

    # Required
    name: str
    dataset_id: str

    # Display text
    name_ar: Optional[str] = ""
    description: str = ""
    description_ar: Optional[str] = ""

    # Classification
    category: str = "general"
    modality: str = "text"

    # Free-form properties; "Unknown" when not supplied
    size: Optional[str] = "Unknown"
    language: Optional[str] = "Unknown"
    license: Optional[str] = "Unknown"
class DatabaseSearchRequest(BaseModel):
    """Search request for datasets: a query, a result limit and an optional category."""

    query: str
    limit: int = 20  # default number of results requested
    category: Optional[str] = None
class DatabaseSelectionRequest(BaseModel):
    """Request carrying the ids of the datasets chosen by the client."""

    database_ids: List[str]
class ModelSearchRequest(BaseModel):
    """Search request for models: a query, a result limit and an optional model type."""

    query: str
    limit: int = 20  # default number of results requested
    model_type: Optional[str] = None
class ModelSelectionRequest(BaseModel):
    """Request carrying the selected teacher models and, optionally, a student model."""

    # default_factory yields a fresh empty list per instance, matching the
    # Field(...) style already used by TrainingConfig.
    teacher_models: List[str] = Field(default_factory=list)
    student_model: Optional[str] = None
# ---------------------------------------------------------------------------
# Module-level service singletons (created once at import time).
# The memory manager is created before the components that receive it.
# ---------------------------------------------------------------------------

# Model loading and distillation
model_loader = ModelLoader()
distillation_trainer = KnowledgeDistillationTrainer()

# Resource management; the 14 GB cap is meant for 16 GB systems
memory_manager = AdvancedMemoryManager(max_memory_gb=14.0)
chunk_loader = AdvancedChunkLoader(memory_manager)
cpu_optimizer = CPUOptimizer(memory_manager)
token_manager = TokenManager()

# Persistence / catalogue managers
platform_db_manager = PlatformDatabaseManager()
models_manager = ModelsManager()
database_manager = DatabaseManager()

# Medical data components
medical_dataset_manager = MedicalDatasetManager(memory_manager)
dicom_handler = DicomHandler(memory_limit_mb=1000.0)
medical_preprocessor = MedicalPreprocessor()
async def startup_event():
    """Create the working directories and log system information at start-up.

    Directory and system-info failures are logged as warnings; nothing is raised.
    """
    logger.info("Starting Multi-Modal Knowledge Distillation application")

    # Working directories the application expects to exist.
    required_dirs = ["uploads", "models", "temp", "logs"]
    for directory in required_dirs:
        try:
            Path(directory).mkdir(exist_ok=True)
            logger.info(f"Created/verified directory: {directory}")
        except PermissionError:
            logger.warning(f"Cannot create directory {directory}, using temp directory")
        except Exception as e:
            logger.warning(f"Error creating directory {directory}: {e}")

    try:
        logger.info(f"System info: {get_system_info()}")
    except Exception as e:
        logger.warning(f"Could not get system info: {e}")
async def shutdown_event():
    """Log the shutdown and remove temporary files."""
    logger.info("Shutting down application")
    # Delegates to the project's temp-file cleanup helper with its defaults.
    cleanup_temp_files()
async def read_root():
    """Render ``index.html`` as the main web interface."""
    # NOTE(review): the template context carries an empty dict under "request"
    # rather than the incoming request object - confirm the template never
    # relies on request attributes.
    context = {"request": {}}
    return templates.TemplateResponse("index.html", context)
async def health_check():
    """Health check endpoint for Docker and monitoring.

    Returns:
        On success, a dict with ``status="healthy"``, the application version,
        a timestamp, memory figures, token availability, feature flags and
        system information. If anything raises, a dict with
        ``status="unhealthy"``, the error text, a timestamp and the version.
    """
    # Report the version declared on the FastAPI app instead of a second
    # hard-coded literal, so the two cannot drift apart.
    version = app.version
    try:
        memory_info = memory_manager.get_memory_info()
        default_token = token_manager.get_token()

        return {
            "status": "healthy",
            "version": version,
            "timestamp": datetime.now().isoformat(),
            "memory": {
                "usage_percent": memory_info.get("process_memory_percent", 0),
                "available_gb": memory_info.get("system_memory_available_gb", 0),
                "status": memory_manager.check_memory_status()
            },
            "tokens": {
                "default_available": bool(default_token),
                "total_tokens": len(token_manager.list_tokens())
            },
            "features": {
                "memory_management": True,
                "chunk_loading": True,
                "cpu_optimization": True,
                "medical_datasets": True,
                "token_management": True
            },
            "system_info": get_system_info()
        }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
            "version": version
        }
async def test_token():
    """Test if the Hugging Face token from the environment is working.

    The token is read from ``HF_TOKEN``, ``HUGGINGFACE_TOKEN`` or
    ``HUGGINGFACE_HUB_TOKEN`` (first non-empty wins).

    Returns:
        A dict with ``token_available`` and ``message``; when a token is
        present it also contains ``token_valid``.
    """
    hf_token = (
        os.getenv('HF_TOKEN') or
        os.getenv('HUGGINGFACE_TOKEN') or
        os.getenv('HUGGINGFACE_HUB_TOKEN')
    )
    if not hf_token:
        return {
            "token_available": False,
            "message": "No HF token found in environment variables"
        }

    try:
        # Validate the token by fetching the config of a gated model. Only
        # success or failure matters, so the returned config is not kept.
        # NOTE(review): this is a blocking call inside a coroutine.
        from transformers import AutoConfig
        AutoConfig.from_pretrained("google/gemma-2b", token=hf_token)
    except Exception as e:
        return {
            "token_available": True,
            "token_valid": False,
            "message": f"Token validation failed: {str(e)}"
        }

    return {
        "token_available": True,
        "token_valid": True,
        "message": "Token is working correctly"
    }
async def test_model_loading(request: Dict[str, Any]):
    """Check whether a model can be loaded by fetching its info.

    Reads ``model_path`` (required), ``trust_remote_code``, ``access_type``
    and ``token`` from *request*.

    Returns:
        ``{"success": True, "model_info", "message"}`` on success, otherwise
        ``{"success": False, "error"}`` plus ``suggestions`` when the failure
        came from an exception.
    """
    try:
        target_model = request.get('model_path')
        # NOTE(review): trust_remote_code and the resolved token are computed
        # but not forwarded to get_model_info - confirm whether they should be.
        trust_remote_code = request.get('trust_remote_code', False)

        if not target_model:
            return {"success": False, "error": "model_path is required"}

        access_type = request.get('access_type', 'read')
        hf_token = request.get('token')
        needs_lookup = not hf_token or hf_token == 'auto'
        if needs_lookup:
            # Ask the token manager for a token suited to this access type.
            hf_token = token_manager.get_token_for_task(access_type)
            if not hf_token:
                logger.warning(f"No suitable token found for {access_type} access")
                # Last resort: the environment variables.
                hf_token = (
                    os.getenv('HF_TOKEN') or
                    os.getenv('HUGGINGFACE_TOKEN') or
                    os.getenv('HUGGINGFACE_HUB_TOKEN')
                )
            else:
                logger.info(f"Using {access_type} token for model testing")

        info = await model_loader.get_model_info(target_model)
        return {
            "success": True,
            "model_info": info,
            "message": f"Model {target_model} can be loaded"
        }
    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        suggestions = []

        # Only the first matching category contributes suggestions.
        if 'trust_remote_code' in lowered:
            suggestions.append("فعّل 'Trust Remote Code' للنماذج التي تتطلب كود مخصص")
        elif 'gated' in lowered:
            suggestions.append("النموذج يتطلب إذن وصول خاص - استخدم رمز مخصص")
        elif 'siglip' in lowered:
            suggestions.append("جرب تفعيل 'Trust Remote Code' لنماذج SigLIP")
        elif '401' in error_msg or 'authentication' in lowered:
            suggestions.extend([
                "تحقق من رمز Hugging Face الخاص بك",
                "تأكد من أن الرمز له صلاحية الوصول لهذا النموذج",
            ])
        elif '404' in error_msg or 'not found' in lowered:
            suggestions.extend([
                "تحقق من اسم مستودع النموذج",
                "تأكد من وجود النموذج على Hugging Face",
            ])

        return {
            "success": False,
            "error": error_msg,
            "suggestions": suggestions
        }
async def upload_model(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    model_names: List[str] = Form(...)
):
    """Upload model files and return their metadata.

    Each file is validated, stored under ``uploads/`` with a generated UUID
    file name (original extension kept) and inspected with
    ``model_loader.get_model_info``. Files and names are paired positionally.

    Raises:
        HTTPException: 400 when a file fails validation; 500 for any other error.
    """
    try:
        uploaded_models = []
        for file, name in zip(files, model_names):
            validation_result = validate_file(file)
            if not validation_result["valid"]:
                raise HTTPException(status_code=400, detail=validation_result["error"])

            # Store under a generated name so client-supplied file names never
            # reach the file system; a missing filename yields no extension.
            file_id = str(uuid.uuid4())
            file_extension = Path(file.filename or "").suffix
            safe_filename = f"{file_id}{file_extension}"
            file_path = Path("uploads") / safe_filename

            content = await file.read()
            with open(file_path, "wb") as buffer:
                buffer.write(content)

            model_info = await model_loader.get_model_info(str(file_path))
            uploaded_models.append({
                "id": file_id,
                "name": name,
                "filename": file.filename,
                "path": str(file_path),
                "size": len(content),
                "info": model_info
            })
            logger.info(f"Uploaded model: {name} ({file.filename})")

        # Schedule cleanup of old files
        background_tasks.add_task(cleanup_temp_files, max_age_hours=24)

        return {
            "success": True,
            "models": uploaded_models,
            "message": f"Successfully uploaded {len(uploaded_models)} model(s)"
        }
    except HTTPException:
        # Keep the intended status code (e.g. the 400 above) instead of
        # letting the generic handler below rewrap it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error uploading models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def start_training(
    background_tasks: BackgroundTasks,
    config: TrainingConfig
):
    """Start knowledge distillation training in the background.

    If a session with the same id exists: finished sessions ("failed",
    "completed", "cancelled") are cleaned up and restarted; sessions in any
    in-flight state are rejected; sessions with an unrecognised status are
    cleaned up and restarted.

    Returns:
        ``{"success": True, "session_id", "message"}`` once the task is scheduled.

    Raises:
        HTTPException: 400 when the session is still in progress; 500 for any
            other error.
    """
    # Statuses a session passes through while run_training is still working.
    in_progress_statuses = {
        "initializing", "running", "loading_student", "loading_models",
        "initializing_student", "training", "saving",
    }
    finished_statuses = {"failed", "completed", "cancelled"}

    try:
        session_id = config.session_id

        if session_id in training_sessions:
            existing_status = training_sessions[session_id].get("status", "unknown")
            if existing_status in finished_statuses:
                logger.info(f"Restarting session {session_id} (previous status: {existing_status})")
                cleanup_training_session(session_id)
            elif existing_status in in_progress_statuses:
                raise HTTPException(
                    status_code=400,
                    detail=f"Training session already running (status: {existing_status})"
                )
            else:
                # Unknown status, clean up and restart
                logger.warning(f"Unknown session status {existing_status}, cleaning up")
                cleanup_training_session(session_id)

        # Mirror a token found under either variable into HF_TOKEN.
        hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN')
        if hf_token:
            os.environ['HF_TOKEN'] = hf_token
            logger.info("Using Hugging Face token from environment")

        # Flag teachers whose name hints at a large parameter count.
        size_indicators = ('27b', '70b', '13b')
        large_models = []
        for model_info in config.teacher_models:
            model_path = model_info if isinstance(model_info, str) else model_info.get('path', '')
            lowered_path = str(model_path or '').lower()
            if any(indicator in lowered_path for indicator in size_indicators):
                large_models.append(model_path)

        # Initialize training session with safe config serialization
        safe_config = safe_json_serialize(config.dict())
        large_note = (
            f" (Large models detected: {', '.join(large_models)})" if large_models else ""
        )
        training_sessions[session_id] = {
            "status": "initializing",
            "progress": 0.0,
            "current_step": 0,
            "total_steps": config.training_params.get("max_steps", 1000),
            "config": safe_config,
            "start_time": None,
            "end_time": None,
            "model_path": None,
            "logs": [],
            "large_models": large_models,
            "message": "Initializing training session..." + large_note
        }

        background_tasks.add_task(run_training, session_id, config)
        logger.info(f"Started training session: {session_id}")

        return {
            "success": True,
            "session_id": session_id,
            "message": "Training started successfully"
        }
    except HTTPException:
        # Preserve the 400 raised above instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error starting training: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
def _teacher_model_path(model_info):
    """Return the path of a teacher entry given as a plain string or a config dict."""
    return model_info if isinstance(model_info, str) else model_info.get('path', '')


def _model_load_suggestions(error_str: str) -> List[str]:
    """Return user-facing hints for a lower-cased model-loading error message.

    Only the first matching category contributes a hint.
    """
    if 'trust_remote_code' in error_str:
        return ["Try enabling 'Trust Remote Code' option"]
    if 'gated' in error_str or 'access' in error_str:
        return ["This model requires access permission and a valid HF token"]
    if 'siglip' in error_str or 'unknown' in error_str:
        return ["This model may require special loading. Try enabling 'Trust Remote Code'"]
    if 'connection' in error_str or 'network' in error_str:
        return ["Check your internet connection"]
    if 'ti2v' in error_str:
        return ["This ti2v model requires trust_remote_code=True"]
    return []


async def run_training(session_id: str, config: TrainingConfig):
    """Run knowledge distillation training in background.

    Stages: resolve the Hugging Face token, optionally load an existing
    student (incremental training) and merge its original teachers with the
    requested ones, load every teacher model, create the student, train, and
    save the result under ``models/distilled_model_<session_id>/``.

    Progress and failures are reported through ``update_training_status``.
    The coroutine never raises; on failure the session is marked "failed"
    and its ``end_time`` is recorded.
    """
    try:
        session = training_sessions[session_id]
        session["status"] = "running"
        session["start_time"] = asyncio.get_event_loop().time()

        # NOTE(review): the original code declared a 30-minute overall timeout
        # here but never enforced it; no timeout is applied.

        # Set HF token for this session - prioritize config token
        config_token = getattr(config, 'hf_token', None)
        env_token = (
            os.getenv('HF_TOKEN') or
            os.getenv('HUGGINGFACE_TOKEN') or
            os.getenv('HUGGINGFACE_HUB_TOKEN')
        )
        hf_token = config_token or env_token
        if hf_token:
            logger.info(f"Using Hugging Face token from {'config' if config_token else 'environment'}")
            # Exported process-wide through the environment.
            os.environ['HF_TOKEN'] = hf_token
        else:
            logger.warning("No Hugging Face token found - private models may fail")

        # Handle existing student model for incremental training
        if config.existing_student_model and config.incremental_training:
            try:
                await update_training_status(session_id, "loading_student", 0.05, "Loading existing student model...")

                student_source = getattr(config, 'student_source', 'local')
                student_path = config.existing_student_model

                if student_source == 'huggingface' or ('/' in student_path and not Path(student_path).exists()):
                    logger.info(f"Loading student model from Hugging Face: {student_path}")
                    existing_student = await model_loader.load_trained_student(student_path)
                elif student_source == 'space':
                    logger.info(f"Loading student model from Hugging Face Space: {student_path}")
                    existing_student = await model_loader.load_trained_student_from_space(student_path)
                else:
                    logger.info(f"Loading student model from local path: {student_path}")
                    existing_student = await model_loader.load_trained_student(student_path)

                logger.info(f"Successfully loaded existing student model: {existing_student.get('type', 'unknown')}")

                # Merge original teachers with new teachers (avoid duplicates)
                original_teachers = existing_student.get('original_teachers', [])
                new_teachers = [_teacher_model_path(entry) for entry in config.teacher_models]
                all_teachers = original_teachers.copy()
                for teacher in new_teachers:
                    if teacher not in all_teachers:
                        all_teachers.append(teacher)

                logger.info(f"Incremental training: Original teachers: {original_teachers}")
                logger.info(f"Incremental training: New teachers: {new_teachers}")
                logger.info(f"Incremental training: All teachers: {all_teachers}")

                # Update config with all teachers
                config.teacher_models = all_teachers
            except Exception as e:
                logger.error(f"Error loading existing student model: {e}")
                session["end_time"] = asyncio.get_event_loop().time()
                await update_training_status(session_id, "failed", session.get("progress", 0), f"Failed to load existing student: {str(e)}")
                return

        # Load teacher models
        await update_training_status(session_id, "loading_models", 0.1, "Loading teacher models...")
        teacher_models = []
        trust_remote_code = config.training_params.get('trust_remote_code', False)
        total_models = len(config.teacher_models)

        for i, model_info in enumerate(config.teacher_models):
            # Bind the per-model values before the try block so the except
            # handler below can always reference them.
            model_path = model_info
            model_token = hf_token
            model_trust_code = trust_remote_code
            progress = 0.1 + (i * 0.3 / total_models)  # 0.1 to 0.4
            try:
                # Handle both old format (string) and new format (dict)
                if not isinstance(model_info, str):
                    model_path = model_info.get('path', model_info)
                    model_token = model_info.get('token') or hf_token
                    model_trust_code = model_info.get('trust_remote_code', trust_remote_code)

                await update_training_status(
                    session_id,
                    "loading_models",
                    progress,
                    f"Loading model {i+1}/{total_models}: {model_path}..."
                )
                logger.info(f"Loading model {model_path} with trust_remote_code={model_trust_code}")

                # Special handling for known problematic models
                if model_path == 'Wan-AI/Wan2.2-TI2V-5B':
                    logger.info(f"Detected ti2v model {model_path}, forcing trust_remote_code=True")
                    model_trust_code = True
                elif model_path == 'deepseek-ai/DeepSeek-V3.1-Base':
                    logger.warning(f"Skipping {model_path}: Requires GPU with FP8 quantization support")
                    await update_training_status(
                        session_id,
                        "loading_models",
                        progress,
                        f"Skipping {model_path}: Requires GPU with FP8 quantization"
                    )
                    continue

                model = await model_loader.load_model(
                    model_path,
                    token=model_token,
                    trust_remote_code=model_trust_code
                )
                teacher_models.append(model)
                logger.info(f"Successfully loaded model: {model_path}")

                # Update progress after successful load
                progress = 0.1 + ((i + 1) * 0.3 / total_models)
                await update_training_status(
                    session_id,
                    "loading_models",
                    progress,
                    f"Loaded {i+1}/{total_models} models successfully"
                )
            except Exception as e:
                error_msg = f"Failed to load model {model_path}: {str(e)}"
                logger.error(error_msg)
                error_str = str(e).lower()

                # Check if we should retry with trust_remote_code=True
                if not model_trust_code and ('ti2v' in error_str or 'does not recognize this architecture' in error_str):
                    try:
                        logger.info(f"Retrying {model_path} with trust_remote_code=True")
                        await update_training_status(
                            session_id,
                            "loading_models",
                            progress,
                            f"Retrying {model_path} with trust_remote_code=True..."
                        )
                        model = await model_loader.load_model(
                            model_path,
                            token=model_token,
                            trust_remote_code=True
                        )
                        teacher_models.append(model)
                        logger.info(f"Successfully loaded model on retry: {model_path}")

                        # Update progress after successful retry
                        progress = 0.1 + ((i + 1) * 0.3 / total_models)
                        await update_training_status(
                            session_id,
                            "loading_models",
                            progress,
                            f"Loaded {i+1}/{total_models} models successfully (retry)"
                        )
                        continue
                    except Exception as retry_e:
                        logger.error(f"Retry also failed for {model_path}: {str(retry_e)}")
                        error_msg = f"Failed even with trust_remote_code=True: {str(retry_e)}"

                # Hints are derived from the first error, even after a failed retry.
                suggestions = _model_load_suggestions(error_str)
                if suggestions:
                    error_msg += f". Suggestions: {'; '.join(suggestions)}"

                session["end_time"] = asyncio.get_event_loop().time()
                await update_training_status(session_id, "failed", session.get("progress", 0), error_msg)
                return

        # Initialize student model
        # NOTE(review): reported progress steps back here (loading ends at up
        # to 0.4, the next stages report 0.2 and 0.3); values kept as they were.
        await update_training_status(session_id, "initializing_student", 0.2, "Initializing student model...")
        student_model = await distillation_trainer.create_student_model(
            teacher_models, config.student_config
        )

        # Run distillation training
        await update_training_status(session_id, "training", 0.3, "Starting knowledge distillation...")

        async def progress_callback(step: int, total_steps: int, loss: float, metrics: Dict[str, Any]):
            # Maps training steps onto the 30%..90% band; a zero total would
            # otherwise divide by zero.
            fraction = (step / total_steps) if total_steps else 0.0
            progress = 0.3 + fraction * 0.6
            await update_training_status(
                session_id, "training", progress,
                f"Training step {step}/{total_steps}, Loss: {loss:.4f}",
                current_step=step, loss=loss
            )

        trained_model = await distillation_trainer.train(
            student_model, teacher_models, config.training_params, progress_callback
        )

        # Save trained model with metadata
        await update_training_status(session_id, "saving", 0.9, "Saving trained model...")
        model_dir = Path("models") / f"distilled_model_{session_id}"
        model_dir.mkdir(parents=True, exist_ok=True)
        model_path = model_dir / "pytorch_model.safetensors"

        training_metadata = {
            'session_id': session_id,
            'teacher_models': [_teacher_model_path(entry) for entry in config.teacher_models],
            'strategy': config.distillation_strategy,
            'training_params': config.training_params,
            'incremental_training': config.incremental_training,
            'existing_student_model': config.existing_student_model
        }
        await distillation_trainer.save_model(trained_model, str(model_path), training_metadata)

        # Complete training
        session["status"] = "completed"
        session["progress"] = 1.0
        session["end_time"] = asyncio.get_event_loop().time()
        session["model_path"] = model_path
        session["training_metadata"] = training_metadata
        await update_training_status(session_id, "completed", 1.0, "Training completed successfully!")
        logger.info(f"Training session {session_id} completed successfully")
    except Exception as e:
        logger.error(f"Training session {session_id} failed: {str(e)}")
        # The session may be missing entirely; fall back to a throwaway dict.
        session = training_sessions.get(session_id, {})
        session["status"] = "failed"
        session["error"] = str(e)
        session["end_time"] = asyncio.get_event_loop().time()
        await update_training_status(session_id, "failed", session.get("progress", 0), f"Training failed: {str(e)}")
async def update_training_status(
    session_id: str,
    status: str,
    progress: float,
    message: str,
    current_step: int = None,
    loss: float = None
):
    """Record a status change on a session and push it to its WebSocket client.

    Does nothing when *session_id* is not a known session. ``current_step``
    and ``loss`` are only stored when given. While progress is between 0 and
    1, an ETA string is derived from the elapsed time since ``start_time``.
    If sending the update fails, the connection is dropped from
    ``active_connections``.
    """
    if session_id not in training_sessions:
        return

    state = training_sessions[session_id]
    state["status"] = status
    state["progress"] = progress
    state["message"] = message
    if current_step is not None:
        state["current_step"] = current_step
    if loss is not None:
        state["loss"] = loss

    # ETA: linear extrapolation of the time spent so far.
    if state.get("start_time") and progress > 0:
        elapsed = asyncio.get_event_loop().time() - state["start_time"]
        if progress < 1.0:
            remaining = (elapsed / progress) * (1.0 - progress)
            minutes = int(remaining // 60)
            seconds = int(remaining % 60)
            state["eta"] = f"{minutes}m {seconds}s"

    # Push the new state to the client watching this session, if any.
    if session_id not in active_connections:
        return
    try:
        payload = {
            "type": "training_update",
            "data": safe_json_serialize(state)
        }
        await active_connections[session_id].send_json(payload)
    except Exception as e:
        logger.warning(f"Failed to send WebSocket update: {e}")
        # Remove disconnected client
        if session_id in active_connections:
            del active_connections[session_id]
async def get_training_progress(session_id: str):
    """Return the current ``TrainingStatus`` of a session.

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    if session_id not in training_sessions:
        raise HTTPException(status_code=404, detail="Training session not found")

    state = training_sessions[session_id]
    # The first four fields are always present on a session; the rest are optional.
    fields = {
        "session_id": session_id,
        "status": state["status"],
        "progress": state["progress"],
        "current_step": state["current_step"],
        "total_steps": state["total_steps"],
        "loss": state.get("loss"),
        "eta": state.get("eta"),
        "message": state.get("message", ""),
    }
    return TrainingStatus(**fields)
async def download_model(session_id: str):
    """Download the trained model of a completed session.

    A model stored as a single file is returned as-is; a model directory is
    zipped into a temporary archive that is deleted after the response is sent.

    Raises:
        HTTPException: 404 when the session or its model is missing, 400 when
            training has not completed, 500 for any other error.
    """
    import tempfile
    import zipfile

    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")

        session = training_sessions[session_id]
        if session["status"] != "completed":
            raise HTTPException(status_code=400, detail="Training not completed")

        model_path = session.get("model_path")
        if not model_path:
            # Try to find model in models directory
            models_dir = Path("models")
            possible_paths = [
                models_dir / f"distilled_model_{session_id}",
                models_dir / f"distilled_model_{session_id}.safetensors",
                models_dir / f"model_{session_id}",
                models_dir / f"student_model_{session_id}"
            ]
            model_path = next((str(path) for path in possible_paths if path.exists()), None)

        if not model_path or not Path(model_path).exists():
            raise HTTPException(status_code=404, detail="Model file not found. The model may not have been saved properly.")

        model_dir = Path(model_path)
        if model_dir.is_file():
            # Single file
            return FileResponse(
                model_path,
                media_type="application/octet-stream",
                filename=f"distilled_model_{session_id}.safetensors"
            )

        # Directory with multiple files: reserve a temp name, and close the
        # handle right away so no file descriptor is left open.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_zip:
            zip_path = temp_zip.name
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file_path in model_dir.rglob('*'):
                if file_path.is_file():
                    zipf.write(file_path, file_path.relative_to(model_dir))

        # Delete the archive once the response has been sent.
        remove_archive = BackgroundTasks()
        remove_archive.add_task(os.unlink, zip_path)
        return FileResponse(
            zip_path,
            media_type="application/zip",
            filename=f"distilled_model_{session_id}.zip",
            background=remove_archive
        )
    except HTTPException:
        # Keep the 404/400 raised above instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error downloading model: {e}")
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
async def upload_to_huggingface(
    session_id: str,
    repo_name: str = Form(...),
    description: str = Form(""),
    private: bool = Form(False),
    hf_token: str = Form(...)
):
    """Upload the trained model of a completed session to the Hugging Face Hub.

    Creates (or reuses) the repository, uploads the files found in the model
    directory, then uploads a generated README.md model card.

    Args:
        session_id: Key of the session in ``training_sessions``.
        repo_name: Target repository in ``username/model-name`` form.
        description: Optional text for the model card.
        private: Whether the repository is created as private.
        hf_token: Hugging Face token used for every Hub call.

    Returns:
        Dict with ``success``, ``repo_url``, ``uploaded_files`` and ``message``.

    Raises:
        HTTPException: 404 for an unknown session or missing model files, 400
            for an unfinished training, a malformed repository name or a
            repository creation failure, 401/403 for token or permission
            problems, 500 for anything unexpected.
    """
    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")
        session = training_sessions[session_id]
        if session["status"] != "completed":
            raise HTTPException(status_code=400, detail="Training not completed")
        model_path = session.get("model_path")
        if not model_path or not Path(model_path).exists():
            raise HTTPException(status_code=404, detail="Model file not found")

        # Import huggingface_hub lazily so the app still starts without it.
        try:
            from huggingface_hub import HfApi, create_repo
        except ImportError:
            raise HTTPException(status_code=500, detail="huggingface_hub not installed")

        # Initialize HF API
        api = HfApi(token=hf_token)

        # Validate repository name format
        if '/' not in repo_name:
            raise HTTPException(status_code=400, detail="Repository name must be in format 'username/model-name'")
        username = repo_name.split('/', 1)[0]

        # Create repository, mapping common Hub failures to precise statuses.
        try:
            repo_url = create_repo(
                repo_id=repo_name,
                token=hf_token,
                private=private,
                exist_ok=True
            )
            logger.info(f"Created/accessed repository: {repo_url}")
        except Exception as e:
            error_msg = str(e)
            if "403" in error_msg or "Forbidden" in error_msg:
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied. Please check: 1) Your token has 'Write' permissions, 2) You own the namespace '{username}', 3) The repository name is correct. Error: {error_msg}"
                )
            elif "401" in error_msg or "Unauthorized" in error_msg:
                raise HTTPException(
                    status_code=401,
                    detail=f"Invalid token. Please check your Hugging Face token. Error: {error_msg}"
                )
            else:
                raise HTTPException(status_code=400, detail=f"Failed to create repository: {error_msg}")

        # Upload model files
        model_path_obj = Path(model_path)
        uploaded_files = []
        # A file path means the model lives in that file's directory.
        model_dir = model_path_obj.parent if model_path_obj.is_file() else model_path_obj

        essential_files = [
            'pytorch_model.safetensors', 'config.json', 'model.py',
            'training_history.json', 'README.md'
        ]
        # Upload essential files first; a failed file is logged and skipped.
        for file_name in essential_files:
            file_path = model_dir / file_name
            if file_path.exists():
                try:
                    api.upload_file(
                        path_or_fileobj=str(file_path),
                        path_in_repo=file_name,
                        repo_id=repo_name,
                        token=hf_token
                    )
                    uploaded_files.append(file_name)
                    logger.info(f"Uploaded {file_name}")
                except Exception as e:
                    logger.warning(f"Failed to upload {file_name}: {e}")

        # Upload any additional files
        for file_path in model_dir.rglob('*'):
            if file_path.is_file() and file_path.name not in essential_files:
                # Computed outside the try so the except branch can always log
                # it; POSIX separators are what the Hub expects in repo paths.
                relative_path = file_path.relative_to(model_dir).as_posix()
                try:
                    api.upload_file(
                        path_or_fileobj=str(file_path),
                        path_in_repo=relative_path,
                        repo_id=repo_name,
                        token=hf_token
                    )
                    uploaded_files.append(relative_path)
                    logger.info(f"Uploaded additional file: {relative_path}")
                except Exception as e:
                    logger.warning(f"Failed to upload {relative_path}: {e}")

        # Create README.md
        config_info = session.get("config", {})
        training_params = config_info.get('training_params', {})
        # Teacher entries may be plain strings, dicts with a 'path', or other
        # objects; normalise all of them to strings so joining cannot fail.
        teacher_models = []
        for model in config_info.get("teacher_models", []):
            if isinstance(model, str):
                teacher_models.append(model)
            elif isinstance(model, dict):
                teacher_models.append(str(model.get('path', model)))
            else:
                teacher_models.append(str(model))
        teacher_bullets = "\n".join(f"- {model}" for model in teacher_models)

        readme_content = f"""---
license: apache-2.0
tags:
- knowledge-distillation
- pytorch
- transformers
base_model: {teacher_models[0] if teacher_models else 'unknown'}
---
# {repo_name}
This model was created using knowledge distillation from the following teacher model(s):
{teacher_bullets}
## Model Description
{description if description else 'A distilled model created using multi-modal knowledge distillation.'}
## Training Details
- **Teacher Models**: {', '.join(teacher_models)}
- **Distillation Strategy**: {config_info.get('distillation_strategy', 'ensemble')}
- **Training Steps**: {training_params.get('max_steps', 'unknown')}
- **Learning Rate**: {training_params.get('learning_rate', 'unknown')}
## Usage
```python
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("{repo_name}")
tokenizer = AutoTokenizer.from_pretrained("{teacher_models[0] if teacher_models else 'bert-base-uncased'}")
```
## Created with
This model was created using the Multi-Modal Knowledge Distillation platform.
"""
        # Upload README (replaces any README.md uploaded from the model dir).
        api.upload_file(
            path_or_fileobj=readme_content.encode(),
            path_in_repo="README.md",
            repo_id=repo_name,
            token=hf_token
        )
        uploaded_files.append("README.md")

        return {
            "success": True,
            "repo_url": f"https://huggingface.co/{repo_name}",
            "uploaded_files": uploaded_files,
            "message": f"Model successfully uploaded to {repo_name}"
        }
    except HTTPException:
        # Preserve the specific 4xx/5xx statuses raised above.
        raise
    except Exception as e:
        logger.error(f"Error uploading to Hugging Face: {e}")
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
async def validate_repo_name(request: Dict[str, Any]):
    """Validate a Hub repository name against the owner of the given token.

    Args:
        request: Mapping with ``repo_name`` (``username/model-name``) and
            ``hf_token``.

    Returns:
        ``{"valid": True, "message": ..., "username": ...}`` when the namespace
        matches the token owner; otherwise ``{"valid": False, "error": ...}``,
        plus ``suggested_name`` when only the namespace is wrong. This function
        never raises: every failure is reported in the returned dict.
    """
    format_error = "Repository name must be in format 'username/model-name'"
    try:
        repo_name = request.get('repo_name', '').strip()
        hf_token = request.get('hf_token', '').strip()
        if not repo_name or not hf_token:
            return {"valid": False, "error": "Repository name and token are required"}
        if '/' not in repo_name:
            return {"valid": False, "error": format_error}
        username, model_name = repo_name.split('/', 1)
        if not model_name.strip():
            # "user/" has no model segment and must not be reported as valid.
            return {"valid": False, "error": format_error}

        # Check if username matches token owner
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
            # Ask the Hub who owns the token.
            user_info = api.whoami()
            token_username = user_info.get('name', '')
            if username != token_username:
                return {
                    "valid": False,
                    "error": f"Username mismatch. Token belongs to '{token_username}' but trying to create repo under '{username}'. Use '{token_username}/{model_name}' instead.",
                    "suggested_name": f"{token_username}/{model_name}"
                }
            return {
                "valid": True,
                "message": f"Repository name '{repo_name}' is valid for your account",
                "username": token_username
            }
        except Exception as e:
            return {"valid": False, "error": f"Token validation failed: {str(e)}"}
    except Exception as e:
        return {"valid": False, "error": f"Validation error: {str(e)}"}
async def test_space(request: Dict[str, Any]):
    """Check that a Hugging Face Space is reachable and report its model files.

    Reads ``space_name`` (``username/space-name``) and an optional ``hf_token``
    from ``request``. Always returns a dict: ``success`` is False with an
    ``error`` message on failure, True with ``space_info``, ``models`` and
    ``message`` otherwise.
    """
    try:
        space_name = request.get('space_name', '').strip()
        hf_token = request.get('hf_token', '').strip()
        if not space_name:
            return {"success": False, "error": "Space name is required"}
        if '/' not in space_name:
            return {"success": False, "error": "Space name must be in format 'username/space-name'"}

        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token if hf_token else None)

            # Step 1: the Space must exist and be accessible.
            try:
                api.space_info(space_name)
                logger.info(f"Found Space: {space_name}")
            except Exception as lookup_error:
                return {"success": False, "error": f"Space not found or not accessible: {str(lookup_error)}"}

            # Step 2: list its files to look for model weights.
            try:
                files = api.list_repo_files(space_name, repo_type="space")
            except Exception:
                # Space exists but we can't list files (might be private or no access)
                return {
                    "success": True,
                    "space_info": {"name": space_name},
                    "models": [],
                    "message": f"Space {space_name} exists but file listing not available (might be private)"
                }

            weight_suffixes = ('.safetensors', '.bin', '.pt')
            model_files = [name for name in files if name.endswith(weight_suffixes)]
            # Anything under models/ counts as having a models directory.
            models_dir_files = [name for name in files if name.startswith('models/')]
            return {
                "success": True,
                "space_info": {
                    "name": space_name,
                    "model_files": model_files,
                    "models_directory": len(models_dir_files) > 0,
                    "total_files": len(files)
                },
                "models": model_files,
                "message": f"Space {space_name} is accessible"
            }
        except Exception as hub_error:
            return {"success": False, "error": f"Error accessing Hugging Face: {str(hub_error)}"}
    except Exception as exc:
        logger.error(f"Error testing Space: {exc}")
        return {"success": False, "error": f"Test failed: {str(exc)}"}
async def list_trained_students():
    """List trained student models under ``models/`` that can be retrained.

    A sub-directory qualifies when its config JSON has a truthy
    ``is_student_model`` flag. Training history, when present, supplies the
    original teachers and the session statistics.

    Returns:
        ``{"trained_students": [...]}`` with one info dict per student model.
        Directories that cannot be read are logged and skipped.

    Raises:
        HTTPException: 500 if the models directory itself cannot be scanned.
    """
    def _exact_name_first(paths, exact_name):
        # glob() order is arbitrary: put the exact file name ahead of look-alikes
        # such as "tokenizer_config.json" so the choice is deterministic.
        return sorted(paths, key=lambda p: (p.name != exact_name, p.name))

    try:
        models_dir = Path("models")
        trained_students = []
        if not models_dir.exists():
            return {"trained_students": trained_students}

        for model_dir in models_dir.iterdir():
            if not model_dir.is_dir():
                continue
            try:
                # Check if it's a trained student model
                config_files = _exact_name_first(model_dir.glob("*config.json"), "config.json")
                if not config_files:
                    continue
                # JSON is UTF-8 by specification; do not rely on the locale.
                with open(config_files[0], 'r', encoding='utf-8') as f:
                    config = json.load(f)
                if not config.get('is_student_model', False):
                    continue

                history = {}
                history_files = _exact_name_first(
                    model_dir.glob("*training_history.json"), "training_history.json"
                )
                if history_files:
                    with open(history_files[0], 'r', encoding='utf-8') as f:
                        history = json.load(f)

                sessions = history.get('training_sessions') or []
                trained_students.append({
                    "id": model_dir.name,
                    "name": model_dir.name,
                    "path": str(model_dir),
                    "type": "trained_student",
                    "created_at": config.get('created_at', 'unknown'),
                    "architecture": config.get('architecture', 'unknown'),
                    "modalities": config.get('modalities', ['text']),
                    "can_be_retrained": config.get('can_be_retrained', True),
                    "original_teachers": history.get('retraining_info', {}).get('original_teachers', []),
                    "training_sessions": len(sessions),
                    "last_training": sessions[-1].get('timestamp', 'unknown') if sessions else 'unknown'
                })
            except Exception as e:
                # One unreadable model must not hide the others.
                logger.warning(f"Error reading model {model_dir}: {e}")
        return {"trained_students": trained_students}
    except Exception as e:
        logger.error(f"Error listing trained students: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def list_models():
    """Return a ``ModelInfo`` entry for every file in the ``uploads`` directory.

    Files whose metadata cannot be read are logged and left out. An absent
    uploads directory yields an empty list.
    """
    models = []
    uploads_dir = Path("uploads")
    if not uploads_dir.exists():
        return models

    # List uploaded models
    for candidate in uploads_dir.iterdir():
        if not candidate.is_file():
            continue
        try:
            details = await model_loader.get_model_info(str(candidate))
            entry = ModelInfo(
                name=candidate.stem,
                size=candidate.stat().st_size,
                format=candidate.suffix[1:],
                modality=details.get("modality", "unknown"),
                architecture=details.get("architecture")
            )
            models.append(entry)
        except Exception as exc:
            logger.warning(f"Error getting info for {candidate}: {exc}")
    return models
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Serve real-time training updates for one session over a WebSocket.

    The socket is registered in ``active_connections`` under ``session_id``,
    receives the current session snapshot if one exists, and is unregistered
    when the client disconnects or an error occurs.
    """
    await websocket.accept()
    active_connections[session_id] = websocket
    try:
        # Send current status if session exists
        if session_id in training_sessions:
            snapshot = {
                "type": "training_update",
                "data": training_sessions[session_id]
            }
            await websocket.send_json(snapshot)
        # Block on incoming frames so the connection stays open.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        active_connections.pop(session_id, None)
    except Exception as exc:
        logger.error(f"WebSocket error for session {session_id}: {exc}")
        active_connections.pop(session_id, None)
# ==================== NEW ADVANCED ENDPOINTS ==================== | |
# Token Management Endpoints | |
async def token_management_page(request: Request):
    """Render the token management HTML page."""
    context = {"request": request}
    return templates.TemplateResponse("token-management.html", context)
async def save_token(
    name: str = Form(...),
    token: str = Form(...),
    token_type: str = Form("read"),
    description: str = Form(""),
    is_default: bool = Form(False)
):
    """Store a Hugging Face token through the token manager.

    Returns:
        ``{"success": True, "message": ...}`` when the token was saved.

    Raises:
        HTTPException: 400 if the token manager reports failure, 500 on an
            unexpected error.
    """
    try:
        success = token_manager.save_token(name, token, token_type, description, is_default)
        if success:
            return {"success": True, "message": f"Token '{name}' saved successfully"}
        raise HTTPException(status_code=400, detail="Failed to save token")
    except HTTPException:
        # Keep the 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error saving token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def list_tokens():
    """Return every saved token under the ``tokens`` key.

    Raises ``HTTPException`` (500) if the token manager fails.
    """
    try:
        return {"tokens": token_manager.list_tokens()}
    except Exception as exc:
        logger.error(f"Error listing tokens: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def delete_token(token_name: str):
    """Delete a saved token by name.

    Returns:
        ``{"success": True, "message": ...}`` when the token was deleted.

    Raises:
        HTTPException: 404 if the token manager reports no such token, 500 on
            an unexpected error.
    """
    try:
        success = token_manager.delete_token(token_name)
        if success:
            return {"success": True, "message": f"Token '{token_name}' deleted"}
        raise HTTPException(status_code=404, detail="Token not found")
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error deleting token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def set_default_token(token_name: str):
    """Mark a saved token as the default one.

    Returns:
        ``{"success": True, "message": ...}`` when the default was changed.

    Raises:
        HTTPException: 404 if the token manager reports no such token, 500 on
            an unexpected error.
    """
    try:
        success = token_manager.set_default_token(token_name)
        if success:
            return {"success": True, "message": f"Token '{token_name}' set as default"}
        raise HTTPException(status_code=404, detail="Token not found")
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error setting default token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_token(token: str = Form(...)):
    """Validate a Hugging Face token and return the token manager's verdict.

    Raises ``HTTPException`` (500) if validation itself fails.
    """
    try:
        return token_manager.validate_token(token)
    except Exception as exc:
        logger.error(f"Error validating token: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def get_token_for_task(task_type: str):
    """Describe the token the token manager selects for ``task_type``.

    The selected token value is matched against the saved tokens to recover
    its metadata; when no saved token matches, placeholder metadata is built
    for a token assumed to come from the environment.

    Raises ``HTTPException``: 404 when no token suits the task, 500 on an
    unexpected error.
    """
    try:
        selected = token_manager.get_token_for_task(task_type)
        if not selected:
            raise HTTPException(status_code=404, detail=f"No suitable token found for task: {task_type}")

        # Find which saved token carries the selected value (first match wins).
        saved_tokens = token_manager.list_tokens()
        token_info = next(
            (entry for entry in saved_tokens
             if token_manager.get_token(entry['name']) == selected),
            None
        )
        if not token_info:
            # Token from environment variable
            token_info = {
                'name': f'{task_type}_token',
                'type': task_type,
                'description': f'رمز من متغيرات البيئة للمهمة: {task_type}',
                'last_used': None,
                'usage_count': 0
            }

        # Metadata registered for this token type, if any.
        type_info = token_manager.token_types.get(token_info['type'], {})
        details = {
            "token_name": token_info['name'],
            "type": token_info['type'],
            "type_name": type_info.get('name', token_info['type']),
            "description": token_info['description'],
            "security_level": type_info.get('security_level', 'medium'),
            "recommended_for": type_info.get('recommended_for', 'general'),
            "last_used": token_info.get('last_used'),
            "usage_count": token_info.get('usage_count', 0)
        }
        return {
            "success": True,
            "task_type": task_type,
            "token_info": details
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error getting token for task {task_type}: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# Medical Dataset Endpoints | |
async def medical_datasets_page(request: Request):
    """Render the medical datasets management HTML page."""
    context = {"request": request}
    return templates.TemplateResponse("medical-datasets.html", context)
async def list_medical_datasets():
    """Return the supported medical datasets under the ``datasets`` key.

    Raises ``HTTPException`` (500) if the dataset manager fails.
    """
    try:
        return {"datasets": medical_dataset_manager.list_supported_datasets()}
    except Exception as exc:
        logger.error(f"Error listing medical datasets: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def load_medical_dataset(
    dataset_name: str = Form(...),
    streaming: bool = Form(True),
    split: str = Form("train")
):
    """Load a medical dataset and return a short summary of it.

    The token registered for the ``medical`` task is used when available,
    otherwise the token manager's default token. Raises ``HTTPException``
    (500) on any failure.
    """
    try:
        # Get appropriate token for medical datasets (fine-grained preferred)
        hf_token = token_manager.get_token_for_task('medical')
        if not hf_token:
            logger.warning("No suitable token found for medical datasets, trying default")
            hf_token = token_manager.get_token()

        loaded = await medical_dataset_manager.load_dataset(
            dataset_name=dataset_name,
            streaming=streaming,
            split=split,
            token=hf_token
        )
        dataset_config = loaded['config']
        summary = {
            "name": dataset_config['name'],
            "size_gb": dataset_config['size_gb'],
            "num_samples": dataset_config['num_samples'],
            "streaming": loaded['streaming']
        }
        return {"success": True, "dataset_info": summary}
    except Exception as exc:
        logger.error(f"Error loading medical dataset: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# Memory and Performance Endpoints | |
async def get_memory_info():
    """Return the memory manager's current memory report.

    Raises ``HTTPException`` (500) if the report cannot be produced.
    """
    try:
        return memory_manager.get_memory_info()
    except Exception as exc:
        logger.error(f"Error getting memory info: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def get_performance_info():
    """Return memory status, memory recommendations and CPU optimizer state.

    Raises ``HTTPException`` (500) if any of the figures cannot be collected.
    """
    try:
        report = {
            "memory": memory_manager.get_memory_info(),
            "recommendations": memory_manager.get_memory_recommendations(),
            "cpu_cores": cpu_optimizer.cpu_count,
            "optimizations_applied": cpu_optimizer.optimizations_applied
        }
        return report
    except Exception as exc:
        logger.error(f"Error getting performance info: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def force_memory_cleanup():
    """Ask the memory manager for an immediate cleanup.

    Raises ``HTTPException`` (500) if the cleanup fails.
    """
    try:
        memory_manager.force_cleanup()
    except Exception as exc:
        logger.error(f"Error during memory cleanup: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"success": True, "message": "Memory cleanup completed"}
# Google Models Support | |
async def list_google_models():
    """Return the static catalogue of supported Google models.

    Returns:
        ``{"models": [...]}`` where each entry has ``name``, ``description``,
        ``type``, ``size_gb``, ``modality`` and ``medical_specialized``. A new
        list is built on every call, so callers may mutate the result.
    """
    # Only literals are built here, so nothing can raise and no error
    # handler is needed.
    google_models = [
        {
            "name": "google/medsiglip-448",
            "description": "Medical SigLIP model for medical image-text understanding",
            "type": "vision-language",
            "size_gb": 1.1,
            "modality": "multimodal",
            "medical_specialized": True
        },
        {
            "name": "google/gemma-3n-E4B-it",
            "description": "Gemma 3 model for instruction following",
            "type": "language",
            "size_gb": 8.5,
            "modality": "text",
            "medical_specialized": False
        }
    ]
    return {"models": google_models}
# Database Management API Endpoints | |
async def get_all_databases():
    """Return every configured database together with the current selection.

    Raises ``HTTPException`` (500) if the database manager fails.
    """
    try:
        configured = platform_db_manager.get_all_databases()
        chosen = platform_db_manager.get_selected_databases()
        payload = {
            "success": True,
            "databases": configured,
            "selected": chosen,
            "total": len(configured)
        }
        return payload
    except Exception as exc:
        logger.error(f"Error getting databases: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def search_databases(request: DatabaseSearchRequest):
    """Search Hugging Face datasets using the request's query and limit.

    Returns the matches with their count and the echoed query. Raises
    ``HTTPException`` (500) if the search fails.
    """
    try:
        matches = await platform_db_manager.search_huggingface_datasets(
            query=request.query,
            limit=request.limit
        )
        payload = {
            "success": True,
            "results": matches,
            "count": len(matches),
            "query": request.query
        }
        return payload
    except Exception as exc:
        logger.error(f"Error searching databases: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def add_database(database_info: DatabaseInfo):
    """Add a new database to the platform configuration.

    Returns:
        ``{"success": True, "message": ...}`` when the database was added.

    Raises:
        HTTPException: 400 if the database manager reports failure, 500 on an
            unexpected error.
    """
    try:
        success = await platform_db_manager.add_database(database_info.dict())
        if success:
            return {
                "success": True,
                "message": f"Database {database_info.dataset_id} added successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to add database")
    except HTTPException:
        # Keep the 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error adding database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_database(dataset_id: str):
    """Validate a dataset and return the database manager's result.

    Raises ``HTTPException`` (500) if validation fails with an error.
    """
    try:
        outcome = await platform_db_manager.validate_dataset(dataset_id)
    except Exception as exc:
        logger.error(f"Error validating database: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {
        "success": True,
        "validation": outcome,
        "dataset_id": dataset_id
    }
async def select_databases(request: DatabaseSelectionRequest):
    """Select each requested database and report the per-database outcome.

    Returns the individual results plus the resulting selection. Raises
    ``HTTPException`` (500) if the database manager fails.
    """
    try:
        outcomes = [
            {
                "database_id": database_id,
                "success": platform_db_manager.select_database(database_id)
            }
            for database_id in request.database_ids
        ]
        return {
            "success": True,
            "results": outcomes,
            "selected": platform_db_manager.get_selected_databases()
        }
    except Exception as exc:
        logger.error(f"Error selecting databases: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def remove_database(database_id: str):
    """Remove a database from the platform configuration.

    Returns:
        ``{"success": True, "message": ...}`` when the database was removed.

    Raises:
        HTTPException: 400 if the database manager reports failure, 500 on an
            unexpected error.
    """
    try:
        success = platform_db_manager.remove_database(database_id)
        if success:
            return {
                "success": True,
                "message": f"Database {database_id} removed successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to remove database")
    except HTTPException:
        # Keep the 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error removing database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_database_info(database_id: str):
    """Return detailed information about one configured database.

    Returns:
        ``{"success": True, "database": ...}`` when the database is known.

    Raises:
        HTTPException: 404 if the database manager has no such database, 500
            on an unexpected error.
    """
    try:
        database_info = platform_db_manager.get_database_info(database_id)
        if database_info:
            return {
                "success": True,
                "database": database_info
            }
        raise HTTPException(status_code=404, detail="Database not found")
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting database info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_databases_by_category(category: str):
    """Return the configured databases that belong to ``category``.

    Raises ``HTTPException`` (500) if the database manager fails.
    """
    try:
        matching = platform_db_manager.get_databases_by_category(category)
        payload = {
            "success": True,
            "databases": matching,
            "category": category,
            "count": len(matching)
        }
        return payload
    except Exception as exc:
        logger.error(f"Error getting databases by category: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def load_selected_databases(max_samples: int = 1000):
    """Load data from the selected databases, passing ``max_samples`` through.

    Raises ``HTTPException`` (500) if loading fails.
    """
    try:
        datasets = await platform_db_manager.load_selected_datasets(max_samples)
        payload = {
            "success": True,
            "loaded_datasets": datasets,
            "total_datasets": len(datasets)
        }
        return payload
    except Exception as exc:
        logger.error(f"Error loading selected databases: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# Models Management API Endpoints | |
async def get_all_models():
    """Return every configured model plus the selected teachers and student.

    Raises ``HTTPException`` (500) if the models manager fails.
    """
    try:
        configured = models_manager.get_all_models()
        payload = {
            "success": True,
            "models": configured,
            "selected_teachers": models_manager.get_selected_teachers(),
            "selected_student": models_manager.get_selected_student(),
            "total": len(configured)
        }
        return payload
    except Exception as exc:
        logger.error(f"Error getting models: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def get_teacher_models():
    """Return all teacher models together with the currently selected ones.

    Raises ``HTTPException`` (500) if the models manager fails.
    """
    try:
        available = models_manager.get_teacher_models()
        payload = {
            "success": True,
            "teachers": available,
            "selected": models_manager.get_selected_teachers(),
            "total": len(available)
        }
        return payload
    except Exception as exc:
        logger.error(f"Error getting teacher models: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def get_student_models():
    """Return all student models together with the currently selected one.

    Raises ``HTTPException`` (500) if the models manager fails.
    """
    try:
        available = models_manager.get_student_models()
        payload = {
            "success": True,
            "students": available,
            "selected": models_manager.get_selected_student(),
            "total": len(available)
        }
        return payload
    except Exception as exc:
        logger.error(f"Error getting student models: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def search_models(request: ModelSearchRequest):
    """Search Hugging Face models using the request's query, limit and type.

    Returns the matches with their count and the echoed query. Raises
    ``HTTPException`` (500) if the search fails.
    """
    try:
        matches = await models_manager.search_huggingface_models(
            query=request.query,
            limit=request.limit,
            model_type=request.model_type
        )
        payload = {
            "success": True,
            "results": matches,
            "count": len(matches),
            "query": request.query
        }
        return payload
    except Exception as exc:
        logger.error(f"Error searching models: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def add_model(model_info: Dict[str, Any]):
    """Add a new model to the platform configuration.

    Args:
        model_info: Model description; its ``model_id`` is echoed in the
            success message.

    Returns:
        ``{"success": True, "message": ...}`` when the model was added.

    Raises:
        HTTPException: 400 if the models manager reports failure, 500 on an
            unexpected error.
    """
    try:
        success = await models_manager.add_model(model_info)
        if success:
            return {
                "success": True,
                "message": f"Model {model_info.get('model_id')} added successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to add model")
    except HTTPException:
        # Keep the 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error adding model: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_model(model_id: str):
    """Validate a model and return the models manager's result.

    Raises ``HTTPException`` (500) if validation fails with an error.
    """
    try:
        outcome = await models_manager.validate_model(model_id)
    except Exception as exc:
        logger.error(f"Error validating model: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {
        "success": True,
        "validation": outcome,
        "model_id": model_id
    }
async def select_models(request: ModelSelectionRequest):
    """Select the requested teacher models and, if given, the student model.

    Returns one outcome per model plus the resulting selection. Raises
    ``HTTPException`` (500) if the models manager fails.
    """
    try:
        # Teachers first, in request order.
        outcomes = [
            {
                "model_id": teacher_id,
                "type": "teacher",
                "success": models_manager.select_teacher(teacher_id)
            }
            for teacher_id in request.teacher_models
        ]
        # The student is optional.
        student_id = request.student_model
        if student_id is not None:
            selected_ok = models_manager.select_student(student_id)
            outcomes.append({
                "model_id": student_id,
                "type": "student",
                "success": selected_ok
            })
        return {
            "success": True,
            "results": outcomes,
            "selected_teachers": models_manager.get_selected_teachers(),
            "selected_student": models_manager.get_selected_student()
        }
    except Exception as exc:
        logger.error(f"Error selecting models: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def remove_model(model_id: str):
    """Remove a model from the platform configuration.

    Returns:
        ``{"success": True, "message": ...}`` when the model was removed.

    Raises:
        HTTPException: 400 if the models manager reports failure, 500 on an
            unexpected error.
    """
    try:
        success = models_manager.remove_model(model_id)
        if success:
            return {
                "success": True,
                "message": f"Model {model_id} removed successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to remove model")
    except HTTPException:
        # Keep the 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error removing model: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_model_info(model_id: str):
    """Return detailed information about one configured model.

    Returns:
        ``{"success": True, "model": ...}`` when the model is known.

    Raises:
        HTTPException: 404 if the models manager has no such model, 500 on an
            unexpected error.
    """
    try:
        model_info = models_manager.get_model_info(model_id)
        if model_info:
            return {
                "success": True,
                "model": model_info
            }
        raise HTTPException(status_code=404, detail="Model not found")
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting model info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, on the port given
    # by the PORT environment variable (7860 when unset).
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,
        log_level="info",
    )