Spaces:
Running
Running
""" | |
Models Management System for Knowledge Distillation Platform
نظام إدارة النماذج لمنصة تقطير المعرفة
""" | |
import json | |
import logging | |
import os | |
from pathlib import Path | |
from typing import Dict, List, Any, Optional | |
from datetime import datetime | |
import asyncio | |
from huggingface_hub import list_models, model_info | |
logger = logging.getLogger(__name__) | |
class ModelsManager: | |
""" | |
Comprehensive models management system for the platform | |
نظام إدارة النماذج الشامل للمنصة | |
""" | |
def __init__(self, storage_path: str = "data/models"):
    """Set up the storage directory and load persisted state.

    Args:
        storage_path: Directory holding the manager's JSON files. It is
            created (including parents) if it does not exist.
    """
    root = self.storage_path = Path(storage_path)
    root.mkdir(parents=True, exist_ok=True)

    # All persisted state lives in three JSON files under the storage root.
    self.config_file = root / "models_config.json"
    self.selected_teachers_file = root / "selected_teachers.json"
    self.selected_student_file = root / "selected_student.json"

    # The configuration is loaded first; the loaders rely on the file
    # paths assigned above.
    self.models_config = self._load_config()
    self.selected_teachers = self._load_selected_teachers()
    self.selected_student = self._load_selected_student()

    configured_count = len(self.models_config)
    logger.info(f"Models Manager initialized with {configured_count} configured models")
def _load_config(self) -> Dict[str, Any]:
    """Load the models configuration from ``self.config_file``.

    When the file does not exist yet, the defaults from
    ``_get_default_models`` are written to disk and returned.

    Returns:
        The configuration mapping keyed by model ID. An empty dict is
        returned when the file cannot be read or parsed, or when its
        top-level JSON value is not an object.
    """
    try:
        if not self.config_file.exists():
            # First run: seed the store with the default models.
            default_config = self._get_default_models()
            self._save_config(default_config)
            return default_config
        with open(self.config_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        logger.error(f"Error loading models config: {e}")
        return {}

    # Callers use len() and .items() on the result, so a file holding
    # e.g. ``null`` or a list must not be handed back as-is.
    if not isinstance(loaded, dict):
        logger.error(
            f"Error loading models config: expected a JSON object, got {type(loaded).__name__}"
        )
        return {}
    return loaded
def _save_config(self, config: Dict[str, Any]):
    """Persist ``config`` as pretty-printed UTF-8 JSON in ``self.config_file``.

    Failures are logged and swallowed; nothing is raised to the caller.

    Args:
        config: Mapping to serialize.
    """
    try:
        # Serialize before opening the file: mode 'w' truncates on open,
        # so a serialization error raised while the file is already open
        # would destroy the previously saved configuration.
        payload = json.dumps(config, indent=2, ensure_ascii=False)
        with open(self.config_file, 'w', encoding='utf-8') as f:
            f.write(payload)
    except Exception as e:
        logger.error(f"Error saving models config: {e}")
def _load_selected_teachers(self) -> List[str]:
    """Load the persisted list of selected teacher model IDs.

    Returns:
        The list stored in ``self.selected_teachers_file``. An empty list
        is returned when the file is missing, cannot be read or parsed,
        or does not contain a JSON array.
    """
    try:
        if not self.selected_teachers_file.exists():
            return []
        with open(self.selected_teachers_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except Exception as e:
        logger.error(f"Error loading selected teachers: {e}")
        return []

    # The selection is used with ``in``, ``append`` and ``remove``; any
    # other JSON type would make every later selection call fail.
    if not isinstance(loaded, list):
        logger.error(
            f"Error loading selected teachers: expected a JSON array, got {type(loaded).__name__}"
        )
        return []
    return loaded
def _save_selected_teachers(self):
    """Write ``self.selected_teachers`` to its JSON file.

    Errors are logged and swallowed; nothing is raised to the caller.
    """
    target = self.selected_teachers_file
    try:
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(self.selected_teachers, handle, indent=2, ensure_ascii=False)
    except Exception as e:
        logger.error(f"Error saving selected teachers: {e}")
def _load_selected_student(self) -> Optional[str]:
    """Read the persisted student selection.

    Returns:
        The value stored under ``'student_model'`` in
        ``self.selected_student_file``, or ``None`` when the file is
        missing, the key is absent, or reading/parsing fails.
    """
    source = self.selected_student_file
    try:
        if not source.exists():
            return None
        with open(source, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        return payload.get('student_model')
    except Exception as e:
        logger.error(f"Error loading selected student: {e}")
        return None
def _save_selected_student(self):
    """Write the current student selection to its JSON file.

    The value is stored under the ``'student_model'`` key. Errors are
    logged and swallowed; nothing is raised to the caller.
    """
    target = self.selected_student_file
    try:
        with open(target, 'w', encoding='utf-8') as handle:
            record = {'student_model': self.selected_student}
            json.dump(record, handle, indent=2, ensure_ascii=False)
    except Exception as e:
        logger.error(f"Error saving selected student: {e}")
def _get_default_models(self) -> Dict[str, Any]:
    """Build the default model catalogue.

    Returns:
        A mapping from model ID to its configuration entry for the three
        built-in teacher models. Each entry's ``added_date`` is the
        current local time in ISO format.
    """
    # NOTE(review): verify this repo ID still resolves on the Hub — TODO confirm.
    bert_entry = {
        "name": "BERT Base Uncased",
        "name_ar": "بيرت الأساسي",
        "model_id": "google/bert-base-uncased",
        "category": "text",
        "type": "teacher",
        "description": "BERT base model for text understanding",
        "description_ar": "نموذج بيرت الأساسي لفهم النصوص",
        "size": "~440MB",
        "language": "English",
        "modality": "text",
        "architecture": "transformer",
        "license": "Apache 2.0",
        "added_date": datetime.now().isoformat(),
        "status": "available",
        "parameters": "110M",
    }
    dialogpt_entry = {
        "name": "DialoGPT Medium",
        "name_ar": "ديالو جي بي تي متوسط",
        "model_id": "microsoft/DialoGPT-medium",
        "category": "text",
        "type": "teacher",
        "description": "Conversational AI model",
        "description_ar": "نموذج ذكاء اصطناعي للمحادثة",
        "size": "~1.2GB",
        "language": "English",
        "modality": "text",
        "architecture": "gpt",
        "license": "MIT",
        "added_date": datetime.now().isoformat(),
        "status": "available",
        "parameters": "345M",
    }
    vit_entry = {
        "name": "Vision Transformer Base",
        "name_ar": "محول الرؤية الأساسي",
        "model_id": "google/vit-base-patch16-224",
        "category": "vision",
        "type": "teacher",
        "description": "Vision Transformer for image classification",
        "description_ar": "محول الرؤية لتصنيف الصور",
        "size": "~330MB",
        "language": "Universal",
        "modality": "vision",
        "architecture": "transformer",
        "license": "Apache 2.0",
        "added_date": datetime.now().isoformat(),
        "status": "available",
        "parameters": "86M",
    }
    # Every entry is keyed by its own model ID.
    return {
        entry["model_id"]: entry
        for entry in (bert_entry, dialogpt_entry, vit_entry)
    }
async def search_huggingface_models(self, query: str, limit: int = 20, model_type: Optional[str] = None) -> List[Dict[str, Any]]:
    """Search the Hugging Face Hub and return summary records for the matches.

    Args:
        query: Search string passed to ``list_models``.
        limit: Maximum number of models requested from the Hub. The type
            filter is applied afterwards, so fewer records may be returned.
        model_type: Optional filter: ``'text'``, ``'vision'`` or
            ``'audio'``. Any other value disables filtering.

    Returns:
        One dict per model with the keys ``id``, ``name``, ``author``,
        ``description``, ``tags``, ``downloads``, ``likes``,
        ``created_at``, ``last_modified``, ``pipeline_tag`` and
        ``library_name``. An empty list is returned when the search
        itself fails; models whose details cannot be fetched are skipped
        with a warning.
    """
    # Pipeline tags accepted for each supported ``model_type`` value.
    tags_by_type = {
        'text': ('text-classification', 'text-generation', 'fill-mask', 'question-answering'),
        'vision': ('image-classification', 'object-detection', 'image-segmentation'),
        'audio': ('automatic-speech-recognition', 'audio-classification'),
    }
    try:
        logger.info(f"Searching Hugging Face for models: {query}")
        loop = asyncio.get_running_loop()

        # The Hub client calls are blocking, so they run in the default
        # executor instead of stalling the event loop. list() drains the
        # listing inside the worker thread in case it is produced lazily.
        models = await loop.run_in_executor(
            None, lambda: list(list_models(search=query, limit=limit))
        )

        allowed_tags = tags_by_type.get(model_type) if isinstance(model_type, str) else None
        results = []
        for model in models:
            # NOTE(review): ``modelId`` is what this code has always read;
            # ``id`` is a fallback in case the installed huggingface_hub
            # release exposes only that attribute — confirm.
            model_id = getattr(model, 'modelId', None) or getattr(model, 'id', None)
            try:
                if not model_id:
                    raise ValueError("listing entry has no model ID")

                info = await loop.run_in_executor(None, model_info, model_id)
                pipeline_tag = getattr(info, 'pipeline_tag', 'unknown')

                # A present-but-None pipeline tag is treated as "no tag"
                # instead of failing on ``None.lower()``.
                if allowed_tags is not None and (pipeline_tag or '').lower() not in allowed_tags:
                    continue

                results.append({
                    "id": model_id,
                    "name": model_id.split('/')[-1],
                    "author": model_id.split('/')[0] if '/' in model_id else 'unknown',
                    "description": getattr(info, 'description', 'No description available'),
                    "tags": getattr(info, 'tags', []),
                    "downloads": getattr(info, 'downloads', 0),
                    "likes": getattr(info, 'likes', 0),
                    "created_at": getattr(info, 'created_at', None),
                    "last_modified": getattr(info, 'last_modified', None),
                    "pipeline_tag": pipeline_tag,
                    "library_name": getattr(info, 'library_name', 'unknown')
                })
            except Exception as e:
                # One bad model must not abort the whole search.
                logger.warning(f"Error processing model {model_id}: {e}")
                continue

        logger.info(f"Found {len(results)} models")
        return results
    except Exception as e:
        logger.error(f"Error searching Hugging Face models: {e}")
        return []
async def add_model(self, model_info: Dict[str, Any]) -> bool:
    """Validate a model and register it in the configuration.

    Args:
        model_info: Model description. The ID is read from ``'model_id'``
            or, failing that, ``'id'``. Other recognised keys are
            optional and fall back to defaults.

    Returns:
        ``True`` when the model was validated and stored. ``False`` when
        the ID is missing, validation fails, or any other error occurs
        (the error is logged).
    """
    try:
        model_id = model_info.get('model_id') or model_info.get('id')
        if not model_id:
            raise ValueError("Model ID is required")

        # The model must pass validation before it is registered.
        validation_result = await self.validate_model(model_id)
        if not validation_result['valid']:
            raise ValueError(f"Model validation failed: {validation_result['error']}")

        pick = model_info.get
        short_name = model_id.split('/')[-1]
        entry = {
            "name": pick('name', short_name),
            "name_ar": pick('name_ar', ''),
            "model_id": model_id,
            "category": pick('category', 'text'),
            "type": pick('type', 'teacher'),
            "description": pick('description', ''),
            "description_ar": pick('description_ar', ''),
            "size": pick('size', 'Unknown'),
            "language": pick('language', 'Unknown'),
            "modality": pick('modality', 'text'),
            "architecture": pick('architecture', 'unknown'),
            "license": pick('license', 'Unknown'),
            "added_date": datetime.now().isoformat(),
            "status": "available",
            "parameters": pick('parameters', 'Unknown'),
            "validation": validation_result,
        }

        # Register (or overwrite) the entry, then persist the whole config.
        self.models_config[model_id] = entry
        self._save_config(self.models_config)
        logger.info(f"Added model: {model_id}")
        return True
    except Exception as e:
        logger.error(f"Error adding model: {e}")
        return False
async def validate_model(self, model_id: str) -> Dict[str, Any]:
    """Check that a model can be looked up through ``model_info``.

    Args:
        model_id: Identifier passed to ``model_info``.

    Returns:
        A dict with the keys ``valid``, ``pipeline_tag``,
        ``library_name``, ``accessible`` and ``error``. On success
        ``valid`` and ``accessible`` are ``True`` and ``error`` is
        ``None``. On any failure they are ``False``, both tag fields are
        ``None`` and ``error`` holds the exception text. This method does
        not raise.
    """
    try:
        logger.info(f"Validating model: {model_id}")
        # The lookup is a blocking call, so it runs in the default
        # executor instead of stalling the event loop.
        loop = asyncio.get_running_loop()
        info = await loop.run_in_executor(None, model_info, model_id)
        return {
            "valid": True,
            "pipeline_tag": getattr(info, 'pipeline_tag', 'unknown'),
            "library_name": getattr(info, 'library_name', 'unknown'),
            "accessible": True,
            "error": None
        }
    except Exception as e:
        logger.warning(f"Model validation failed for {model_id}: {e}")
        return {
            "valid": False,
            "pipeline_tag": None,
            "library_name": None,
            "accessible": False,
            "error": str(e)
        }
def get_all_models(self) -> Dict[str, Any]:
    """Return the configuration mapping of every registered model.

    The returned object is the manager's own dict, not a copy.
    """
    configured = self.models_config
    return configured
def get_teacher_models(self) -> Dict[str, Any]:
    """Return the configured models whose ``'type'`` is ``'teacher'``.

    A new dict is built; its values are the same entry objects held in
    the configuration.
    """
    teachers = {}
    for key, entry in self.models_config.items():
        if entry.get('type') == 'teacher':
            teachers[key] = entry
    return teachers
def get_student_models(self) -> Dict[str, Any]:
    """Return the configured models whose ``'type'`` is ``'student'``.

    A new dict is built; its values are the same entry objects held in
    the configuration.
    """
    students = {}
    for key, entry in self.models_config.items():
        if entry.get('type') != 'student':
            continue
        students[key] = entry
    return students
def get_selected_teachers(self) -> List[str]:
    """Return the IDs of the currently selected teacher models.

    The returned object is the manager's own list, not a copy.
    """
    chosen = self.selected_teachers
    return chosen
def get_selected_student(self) -> Optional[str]:
    """Return the ID of the selected student model, or ``None`` if unset."""
    current = self.selected_student
    return current
def select_teacher(self, model_id: str) -> bool:
    """Add a configured teacher model to the selection.

    Args:
        model_id: ID of a model whose configured ``'type'`` is
            ``'teacher'``.

    Returns:
        ``True`` when the model is selected after the call (including
        when it already was). ``False`` when it is unknown, not a
        teacher, or another error occurs (the error is logged).
    """
    try:
        catalogue = self.models_config
        if model_id not in catalogue:
            raise ValueError(f"Model {model_id} not found in configuration")

        entry = catalogue[model_id]
        if entry.get('type') != 'teacher':
            raise ValueError(f"Model {model_id} is not a teacher model")

        # Persist only when the selection actually changes.
        already_selected = model_id in self.selected_teachers
        if not already_selected:
            self.selected_teachers.append(model_id)
            self._save_selected_teachers()

        logger.info(f"Selected teacher model: {model_id}")
        return True
    except Exception as e:
        logger.error(f"Error selecting teacher model: {e}")
        return False
def deselect_teacher(self, model_id: str) -> bool:
    """Remove a model from the teacher selection.

    Args:
        model_id: ID to remove.

    Returns:
        ``True`` when the call completes, whether or not the model was
        selected. ``False`` only when an error occurs (it is logged).
    """
    try:
        currently_selected = model_id in self.selected_teachers
        if currently_selected:
            self.selected_teachers.remove(model_id)
            # Persist only when the selection actually changes.
            self._save_selected_teachers()

        logger.info(f"Deselected teacher model: {model_id}")
        return True
    except Exception as e:
        logger.error(f"Error deselecting teacher model: {e}")
        return False
def select_student(self, model_id: Optional[str] = None) -> bool:
    """Select the student model, or clear it to train from scratch.

    Args:
        model_id: ID of a configured model of type ``'student'`` or
            ``'teacher'``. ``None`` or an empty string means training
            from scratch with no base student model.

    Returns:
        ``True`` when the selection was stored. ``False`` when the model
        is unknown, has an unsupported type, or another error occurs
        (the error is logged).
    """
    try:
        # Store every "no model" input as None so that the persisted
        # value and get_selected_student() have one canonical form;
        # an empty string used to be stored as-is.
        model_id = model_id or None

        if model_id is not None:
            if model_id not in self.models_config:
                raise ValueError(f"Model {model_id} not found in configuration")
            entry = self.models_config[model_id]
            # Teachers can be used as base for students
            if entry.get('type') not in ('student', 'teacher'):
                raise ValueError(f"Model {model_id} cannot be used as student model")

        self.selected_student = model_id
        self._save_selected_student()

        if model_id is not None:
            logger.info(f"Selected student model: {model_id}")
        else:
            logger.info("Selected training from scratch (no base student model)")
        return True
    except Exception as e:
        logger.error(f"Error selecting student model: {e}")
        return False
def remove_model(self, model_id: str) -> bool:
    """Remove a model from the configuration and from any selection.

    Args:
        model_id: ID to remove.

    Returns:
        ``True`` when the call completes, whether or not the model was
        configured. ``False`` only when an error occurs (it is logged).
    """
    try:
        is_configured = model_id in self.models_config
        if is_configured:
            del self.models_config[model_id]
            self._save_config(self.models_config)

        # Drop the model from the teacher selection, if present.
        is_selected_teacher = model_id in self.selected_teachers
        if is_selected_teacher:
            self.selected_teachers.remove(model_id)
            self._save_selected_teachers()

        # Clear the student selection if it pointed at this model.
        is_selected_student = self.selected_student == model_id
        if is_selected_student:
            self.selected_student = None
            self._save_selected_student()

        logger.info(f"Removed model: {model_id}")
        return True
    except Exception as e:
        logger.error(f"Error removing model: {e}")
        return False
def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
    """Look up the configuration entry of one model.

    Args:
        model_id: ID to look up.

    Returns:
        The stored entry, or ``None`` when the ID is not configured.
    """
    entry = self.models_config.get(model_id)
    return entry
def get_models_by_category(self, category: str) -> Dict[str, Any]:
    """Return the configured models whose ``'category'`` equals ``category``.

    A new dict is built; its values are the same entry objects held in
    the configuration.
    """
    matches = {}
    for key, entry in self.models_config.items():
        if entry.get('category') == category:
            matches[key] = entry
    return matches
def get_models_by_modality(self, modality: str) -> Dict[str, Any]:
    """Return the configured models whose ``'modality'`` equals ``modality``.

    A new dict is built; its values are the same entry objects held in
    the configuration.
    """
    matches = {}
    for key, entry in self.models_config.items():
        if entry.get('modality') != modality:
            continue
        matches[key] = entry
    return matches