import gradio as gr
import subprocess
import sys
import logging
import importlib

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def install_package(package_name):
    """Install a Python package (the spec may carry extra pip arguments)."""
    try:
        # Split the spec so entries like "pkg>=1.0 --index-url https://..."
        # become separate pip arguments instead of one opaque string.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", *package_name.split(), "--quiet"]
        )
        logger.info(f"✅ {package_name} installed successfully")
        return True
    except subprocess.CalledProcessError as e:
        logger.error(f"❌ Failed to install {package_name}: {e}")
        return False
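

# Installing into a live process does not make the package importable on its
# own; reload_module() below takes care of the (re)import. For illustration
# (the package name here is arbitrary):
#
#   if install_package("numpy"):
#       reload_module("numpy")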


def reload_module(module_name):
    """(Re)import a module after it has been installed."""
    try:
        if module_name in sys.modules:
            importlib.reload(sys.modules[module_name])
        else:
            __import__(module_name)
        return True
    except Exception as e:
        logger.error(f"Failed to reload {module_name}: {e}")
        return False


def check_and_import_dependencies():
    """Probe each dependency and publish availability flags as globals."""
    global numpy, torch, NUMPY_AVAILABLE, TORCH_AVAILABLE, TRANSFORMERS_AVAILABLE
    global DATASETS_AVAILABLE, HF_HUB_AVAILABLE, PIL_AVAILABLE, LIBROSA_AVAILABLE, CV2_AVAILABLE
    global AutoTokenizer, AutoModel, AutoProcessor, AutoModelForCausalLM, AutoConfig
    global TrainingArguments, Trainer, DataCollatorForLanguageModeling
    global Dataset, load_dataset, concatenate_datasets, HfApi, Image, librosa, cv2

    try:
        import numpy
        NUMPY_AVAILABLE = True
    except ImportError:
        numpy = None
        NUMPY_AVAILABLE = False

    try:
        import torch
        TORCH_AVAILABLE = True
    except ImportError:
        torch = None
        TORCH_AVAILABLE = False

    try:
        from transformers import (
            AutoTokenizer, AutoModel, AutoProcessor, AutoConfig,
            AutoModelForCausalLM, TrainingArguments, Trainer,
            DataCollatorForLanguageModeling
        )
        TRANSFORMERS_AVAILABLE = True
    except ImportError:
        TRANSFORMERS_AVAILABLE = False
        AutoTokenizer = AutoModel = AutoProcessor = AutoConfig = None
        AutoModelForCausalLM = TrainingArguments = Trainer = None
        DataCollatorForLanguageModeling = None

    try:
        from datasets import Dataset, load_dataset, concatenate_datasets
        DATASETS_AVAILABLE = True
    except ImportError:
        DATASETS_AVAILABLE = False
        Dataset = load_dataset = concatenate_datasets = None

    try:
        from huggingface_hub import HfApi
        HF_HUB_AVAILABLE = True
    except ImportError:
        HF_HUB_AVAILABLE = False
        HfApi = None

    try:
        from PIL import Image
        PIL_AVAILABLE = True
    except ImportError:
        PIL_AVAILABLE = False
        Image = None

    try:
        import librosa
        LIBROSA_AVAILABLE = True
    except ImportError:
        LIBROSA_AVAILABLE = False
        librosa = None

    try:
        import cv2
        CV2_AVAILABLE = True
    except ImportError:
        CV2_AVAILABLE = False
        cv2 = None


check_and_import_dependencies()


class MultimodalTrainer:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.current_processor = None
        self.training_data = None  # becomes a datasets.Dataset once data is loaded

        if TORCH_AVAILABLE and torch:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = "cpu"

        self.hf_api = HfApi() if (HF_HUB_AVAILABLE and HfApi) else None

    def install_dependencies(self, packages_to_install):
        """Install missing dependencies."""
        installation_results = []

        # Map short names to full pip specs. The "torch" entry is handled by the
        # dedicated branch below because it needs a custom index URL.
        package_mapping = {
            "torch": "torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cpu",
            "transformers": "transformers>=4.46.2",
            "datasets": "datasets>=2.21.0",
            "accelerate": "accelerate>=1.1.0",
            "pillow": "pillow>=10.1.0",
            "librosa": "librosa>=0.10.1",
            "opencv": "opencv-python-headless>=4.8.1.78",
            "huggingface_hub": "huggingface_hub>=0.26.0",
            "qwen": "qwen-vl-utils>=0.0.8"
        }

        for package in packages_to_install:
            installation_results.append(f"📦 Installing {package}...")

            install_cmd = package_mapping.get(package.lower(), package)

            if package.lower() == "torch":
                # CPU wheels come from PyTorch's own index, not PyPI.
                try:
                    subprocess.check_call([
                        sys.executable, "-m", "pip", "install",
                        "torch==2.1.0", "torchvision==0.16.0", "torchaudio==2.1.0",
                        "--index-url", "https://download.pytorch.org/whl/cpu",
                        "--quiet"
                    ])
                    success = True
                except subprocess.CalledProcessError:
                    success = False
            else:
                success = install_package(install_cmd)

            if success:
                installation_results.append(f"✅ {package} installed successfully!")
            else:
                installation_results.append(f"❌ Failed to install {package}")

        installation_results.append("\n🔄 Reloading modules...")
        check_and_import_dependencies()
        self.__init__()  # re-run setup so the device and API handles see the new imports

        installation_results.append("✅ Modules reloaded!")
        return "\n".join(installation_results)

    def check_dependencies(self):
        """Check and report the state of every dependency."""
        check_and_import_dependencies()

        deps = {
            "PyTorch": TORCH_AVAILABLE,
            "Transformers": TRANSFORMERS_AVAILABLE,
            "Datasets": DATASETS_AVAILABLE,
            "NumPy": NUMPY_AVAILABLE,
            "HuggingFace Hub": HF_HUB_AVAILABLE,
            "PIL": PIL_AVAILABLE,
            "Librosa": LIBROSA_AVAILABLE,
            "OpenCV": CV2_AVAILABLE
        }

        status = "📦 Dependency status:\n\n"

        critical_deps = ["PyTorch", "Transformers", "Datasets"]
        status += "🔥 CRITICAL:\n"
        for dep in critical_deps:
            icon = "✅" if deps.get(dep) else "❌"
            status += f"{icon} {dep}\n"

        status += "\n🔧 OPTIONAL:\n"
        optional_deps = ["NumPy", "HuggingFace Hub", "PIL", "Librosa", "OpenCV"]
        for dep in optional_deps:
            icon = "✅" if deps.get(dep) else "⚠️"
            status += f"{icon} {dep}\n"

        status += "\n💻 SYSTEM:\n"
        status += f"🐍 Python: {sys.version.split()[0]}\n"
        status += f"💾 Device: {self.device}\n"

        if TORCH_AVAILABLE and torch and torch.cuda.is_available():
            status += f"🚀 GPU: {torch.cuda.get_device_name()}\n"
            status += f"🔋 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB\n"

        if TRANSFORMERS_AVAILABLE:
            import transformers
            status += f"🤗 Transformers: {transformers.__version__}\n"

        return status

    def load_model_safe(self, model_name: str):
        """Load a model defensively, trying several strategies before giving up.

        Always returns a 4-tuple: (message, model, tokenizer, processor).
        """
        if not TRANSFORMERS_AVAILABLE:
            return "❌ Transformers is not installed! Use the installer tool.", None, None, None

        if not TORCH_AVAILABLE or not torch:
            return "❌ PyTorch is not installed! Use the installer tool.", None, None, None

        try:
            logger.info(f"Safely loading model: {model_name}")

            try:
                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
                logger.info(f"Configuration loaded: {config.model_type}")
            except Exception as e:
                return f"❌ Configuration error: {str(e)}", None, None, None

            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    use_fast=False
                )
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                logger.info("Tokenizer loaded successfully")
            except Exception as e:
                logger.warning(f"Tokenizer not found: {e}")
                return f"❌ Tokenizer error: {str(e)}", None, None, None

            # Try progressively more permissive loaders until one succeeds.
            model = None
            loading_strategies = [
                {
                    "name": "standard AutoModelForCausalLM",
                    "loader": lambda: AutoModelForCausalLM.from_pretrained(
                        model_name,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto" if torch.cuda.is_available() else None,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True
                    )
                },
                {
                    "name": "AutoModelForCausalLM with explicit config",
                    "loader": lambda: AutoModelForCausalLM.from_pretrained(
                        model_name,
                        config=config,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto" if torch.cuda.is_available() else None,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True,
                        attn_implementation="eager"
                    )
                },
                {
                    "name": "generic AutoModel",
                    "loader": lambda: AutoModel.from_pretrained(
                        model_name,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto" if torch.cuda.is_available() else None,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True
                    )
                }
            ]

            last_error = None
            for strategy in loading_strategies:
                try:
                    logger.info(f"Trying: {strategy['name']}")
                    model = strategy["loader"]()
                    logger.info(f"✅ Succeeded with: {strategy['name']}")
                    break
                except Exception as e:
                    last_error = str(e)
                    logger.warning(f"❌ {strategy['name']} failed: {e}")
                    continue

            if model is None:
                return f"❌ All strategies failed. Last error: {last_error}", None, None, None

            # The processor is optional; many text-only models ship without one.
            processor = None
            try:
                processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
                logger.info("Processor loaded successfully")
            except Exception as e:
                logger.warning(f"Processor not available: {e}")

            return "✅ Model loaded successfully!", model, tokenizer, processor

        except Exception as e:
            error_msg = f"❌ Critical error: {str(e)}"
            logger.error(error_msg)
            return error_msg, None, None, None

    def load_model(self, model_name: str, model_type: str = "causal"):
        """Load a model from Hugging Face with improved error handling.

        model_type is kept for UI compatibility; the safe loader picks the
        model class itself.
        """
        if not model_name.strip():
            return "❌ Please enter a model name"

        message, model, tokenizer, processor = self.load_model_safe(model_name)

        if model is None:
            return message

        self.current_model = model
        self.current_tokenizer = tokenizer
        self.current_processor = processor

        info = f"{message}\n"
        info += f"🏷️ Type: {type(model).__name__}\n"
        if hasattr(model, 'config'):
            architectures = getattr(model.config, 'architectures', None)
            info += f"🏗️ Architecture: {architectures[0] if architectures else 'Unknown'}\n"
            info += f"📋 Model type: {getattr(model.config, 'model_type', 'Not set')}\n"

        if TORCH_AVAILABLE and torch:
            info += f"💾 Device: {next(model.parameters()).device}\n"
            total_params = sum(p.numel() for p in model.parameters())
            info += f"🔢 Parameters: {total_params:,}\n"

        return info

    def diagnose_model(self, model_name: str):
        """Run an in-depth diagnostic on a model."""
        if not model_name.strip():
            return "❌ Please enter a model name"

        if not TRANSFORMERS_AVAILABLE or AutoConfig is None:
            return "❌ Transformers is not installed! Use the installer tool."

        try:
            result = f"🔍 IN-DEPTH DIAGNOSTIC: {model_name}\n\n"

            try:
                config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
                result += "✅ Model reachable on Hugging Face\n\n"

                result += "📋 CONFIGURATION:\n"
                result += f"🏷️ Model type: {getattr(config, 'model_type', '❌ NOT SET')}\n"
                result += f"🏗️ Architectures: {getattr(config, 'architectures', ['❌ NOT SET'])}\n"
                # vocab_size may be absent, so only apply thousands formatting when it exists
                vocab_size = getattr(config, 'vocab_size', None)
                result += f"📚 Vocab size: {f'{vocab_size:,}' if isinstance(vocab_size, int) else 'Unknown'}\n"
                result += f"🧠 Hidden size: {getattr(config, 'hidden_size', 'Unknown')}\n"
                result += f"🔢 Layers: {getattr(config, 'num_hidden_layers', 'Unknown')}\n"
                result += f"🎯 Attention heads: {getattr(config, 'num_attention_heads', 'Unknown')}\n"

                result += "\n🔧 ISSUE ANALYSIS:\n"

                if not hasattr(config, 'model_type') or config.model_type is None:
                    result += "⚠️ ISSUE: model_type missing\n"
                    if hasattr(config, 'architectures') and config.architectures:
                        arch = config.architectures[0].lower()
                        suggested_type = None
                        if 'qwen' in arch:
                            suggested_type = 'qwen2' if 'qwen2' in arch else 'qwen'
                        elif 'llama' in arch:
                            suggested_type = 'llama'
                        elif 'mistral' in arch:
                            suggested_type = 'mistral'
                        elif 'phi' in arch:
                            suggested_type = 'phi'

                        if suggested_type:
                            result += f"💡 Suggested type: {suggested_type}\n"
                else:
                    result += f"✅ Model type set: {config.model_type}\n"

                if hasattr(config, 'architectures') and config.architectures:
                    arch = config.architectures[0]
                    if 'Qwen2_5OmniForCausalLM' in arch:
                        result += "⚠️ Qwen2.5-Omni architecture detected\n"
                        result += "💡 Requires Transformers >= 4.45.0\n"
                        if TRANSFORMERS_AVAILABLE:
                            import transformers
                            result += f"📦 Current version: {transformers.__version__}\n"

                result += "\n🎯 LOADING STRATEGIES:\n"
                result += "1️⃣ AutoModelForCausalLM with trust_remote_code=True\n"
                result += "2️⃣ Explicit config if model_type is missing\n"
                result += "3️⃣ Fallback to generic AutoModel\n"

                result += "\n✅ Diagnostic complete - loading is possible with adaptations"

            except Exception as e:
                result += f"❌ Access error: {str(e)}\n"

                if "404" in str(e):
                    result += "💡 The model does not exist or is not public\n"
                elif "token" in str(e).lower():
                    result += "💡 An authentication token may be required\n"
                else:
                    result += "💡 Check the model name and your connection\n"

            return result

        except Exception as e:
            return f"❌ Diagnostic error: {str(e)}"

    def load_single_dataset(self, dataset_name: str, split: str = "train", config_name: str = None):
        """Load a single dataset and append it to the training pool."""
        if not DATASETS_AVAILABLE or not load_dataset:
            return "❌ Datasets is not installed! Use the installer tool."

        if not dataset_name.strip():
            return "❌ Please enter a dataset name"

        try:
            # A dataset configuration (e.g. "wikitext-2-raw-v1") is the second
            # positional argument of load_dataset, not part of the dataset path.
            if config_name:
                dataset = load_dataset(dataset_name, config_name, split=split)
            else:
                dataset = load_dataset(dataset_name, split=split)

            if self.training_data:
                self.training_data = concatenate_datasets([self.training_data, dataset])
            else:
                self.training_data = dataset

            return f"✅ Dataset {dataset_name} added!\n📊 Total: {len(self.training_data)} examples\n🔍 Columns: {list(self.training_data.column_names)}"

        except Exception as e:
            error_msg = f"❌ Dataset error: {str(e)}"
            logger.error(error_msg)
            return error_msg
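
    # A real training run would first need a tokenization pass over the loaded
    # data. A minimal sketch, assuming a tokenizer is loaded and the dataset
    # has a plain-text column (the "text" default and this helper are
    # illustrative additions, not part of the original app):
    def tokenize_training_data(self, text_column: str = "text", max_length: int = 512):
        if self.training_data is None or self.current_tokenizer is None:
            return "❌ Load a dataset and a model/tokenizer first"
        if text_column not in self.training_data.column_names:
            return f"❌ Column '{text_column}' not found in {self.training_data.column_names}"

        def _tokenize(batch):
            return self.current_tokenizer(
                batch[text_column], truncation=True, max_length=max_length
            )

        # map() replaces the raw columns with input_ids/attention_mask
        self.training_data = self.training_data.map(
            _tokenize, batched=True, remove_columns=self.training_data.column_names
        )
        return f"✅ Tokenized {len(self.training_data)} examples"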

    def simulate_training(self, output_dir: str, num_epochs: int, learning_rate: float, batch_size: int):
        """Simulate a training run (demo mode)."""
        if not self.current_model and not self.training_data:
            return "❌ No model or data loaded!"

        steps = ["🏗️ Preparing data", "🔧 Configuring model", "🚀 Starting training"]
        result = "📋 TRAINING SIMULATION:\n\n"
        result += f"📂 Output: {output_dir}\n"
        result += f"🔄 Epochs: {num_epochs}\n"
        result += f"📚 Learning rate: {learning_rate}\n"
        result += f"📦 Batch size: {batch_size}\n\n"

        for i, step in enumerate(steps):
            result += f"Step {i+1}: {step} ✅\n"

        if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
            result += "\n✅ Ready for real training!"
        else:
            result += "\n⚠️ DEMO MODE - install PyTorch + Transformers for real training"
        return result
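
    # simulate_training() above only narrates the steps. For reference, a
    # minimal sketch of real wiring with the HF Trainer, assuming a loaded
    # causal LM, its tokenizer, and an already-tokenized dataset; this helper
    # is an illustrative addition and is not wired into the UI:
    def build_real_trainer(self, output_dir: str, num_epochs: int, learning_rate: float, batch_size: int):
        if not (TRANSFORMERS_AVAILABLE and self.current_model and self.current_tokenizer and self.training_data):
            return None
        args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            logging_steps=10,
        )
        # mlm=False selects the standard next-token (causal LM) objective
        collator = DataCollatorForLanguageModeling(self.current_tokenizer, mlm=False)
        return Trainer(
            model=self.current_model,
            args=args,
            train_dataset=self.training_data,
            data_collator=collator,
        )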

    def get_model_info(self):
        """Return information about the current model."""
        if not self.current_model:
            return "❌ No model loaded"

        info = "📋 MODEL INFORMATION:\n\n"
        info += f"🏷️ Type: {type(self.current_model).__name__}\n"

        if TORCH_AVAILABLE and torch:
            info += f"💾 Device: {next(self.current_model.parameters()).device}\n"

            total_params = sum(p.numel() for p in self.current_model.parameters())
            trainable_params = sum(p.numel() for p in self.current_model.parameters() if p.requires_grad)

            info += f"🔢 Total parameters: {total_params:,}\n"
            info += f"🎯 Trainable parameters: {trainable_params:,}\n"

        if self.training_data:
            info += "\n📊 DATA:\n"
            info += f"📈 Examples: {len(self.training_data):,}\n"
            info += f"📝 Columns: {list(self.training_data.column_names)}\n"

        return info


trainer = MultimodalTrainer()


def create_interface():
    with gr.Blocks(title="🔥 Multimodal Training Hub", theme=gr.themes.Soft()) as app:

        gr.Markdown("""
        # 🔥 Multimodal Training Hub
        ### A multimodal model training platform optimized for Qwen2.5-Omni

        🤖 Models • 📊 Datasets • 🏋️ Training • 🛠️ Tools
        """)

        with gr.Tab("🔧 Diagnostic"):
            gr.Markdown("### 🩺 System check and installation")

            with gr.Row():
                check_deps_btn = gr.Button("🔍 Check dependencies", variant="primary")
                install_core_btn = gr.Button("📦 Install core packages", variant="secondary")
                install_qwen_btn = gr.Button("🎯 Qwen2.5 support", variant="secondary")

            deps_status = gr.Textbox(
                label="Dependency status",
                lines=15,
                interactive=False
            )

            with gr.Row():
                install_transformers_btn = gr.Button("🤗 Install Transformers")
                install_torch_btn = gr.Button("🔥 Install PyTorch")
                install_datasets_btn = gr.Button("📊 Install Datasets")

            install_status = gr.Textbox(
                label="Installation status",
                lines=8,
                interactive=False
            )

            check_deps_btn.click(trainer.check_dependencies, outputs=deps_status)

            install_transformers_btn.click(
                lambda: trainer.install_dependencies(["transformers"]),
                outputs=install_status
            )
            install_torch_btn.click(
                lambda: trainer.install_dependencies(["torch"]),
                outputs=install_status
            )
            install_datasets_btn.click(
                lambda: trainer.install_dependencies(["datasets"]),
                outputs=install_status
            )
            install_core_btn.click(
                lambda: trainer.install_dependencies(["torch", "transformers", "datasets", "accelerate"]),
                outputs=install_status
            )
            install_qwen_btn.click(
                lambda: trainer.install_dependencies(["transformers", "qwen"]),
                outputs=install_status
            )

        with gr.Tab("🤖 Model"):
            with gr.Row():
                with gr.Column():
                    model_input = gr.Textbox(
                        label="HuggingFace model name",
                        placeholder="kvn420/Tenro_V4.1",
                        value="kvn420/Tenro_V4.1"
                    )
                    model_type = gr.Dropdown(
                        label="Model type",
                        choices=["causal", "base"],
                        value="causal"
                    )

                    with gr.Row():
                        load_model_btn = gr.Button("🔄 Load model", variant="primary")
                        diagnose_btn = gr.Button("🔍 Diagnose", variant="secondary")

                    gr.Markdown("""
                    💡 **Tested models:**
                    - `kvn420/Tenro_V4.1` (Qwen2.5-Omni)
                    - `Qwen/Qwen2.5-7B-Instruct`
                    - `microsoft/DialoGPT-medium`
                    """)

                with gr.Column():
                    model_status = gr.Textbox(
                        label="Model status",
                        interactive=False,
                        lines=10
                    )

                    info_btn = gr.Button("ℹ️ Model info")
                    model_info = gr.Textbox(
                        label="Detailed information",
                        interactive=False,
                        lines=8
                    )

            load_model_btn.click(
                trainer.load_model,
                inputs=[model_input, model_type],
                outputs=model_status
            )

            diagnose_btn.click(
                trainer.diagnose_model,
                inputs=[model_input],
                outputs=model_status
            )

            info_btn.click(trainer.get_model_info, outputs=model_info)

        with gr.Tab("📊 Data"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### 📝 Single dataset")
                    dataset_input = gr.Textbox(
                        label="Dataset name",
                        placeholder="wikitext",
                        value="wikitext"
                    )
                    dataset_config = gr.Textbox(
                        label="Configuration (optional)",
                        placeholder="wikitext-2-raw-v1"
                    )
                    dataset_split = gr.Textbox(
                        label="Split",
                        value="train"
                    )
                    load_dataset_btn = gr.Button("➕ Add dataset", variant="primary")

                with gr.Column():
                    data_status = gr.Textbox(
                        label="Data status",
                        interactive=False,
                        lines=12
                    )

            def load_dataset_with_config(dataset_name, config_name, split):
                # The configuration is a separate load_dataset() argument, not
                # a path segment appended to the dataset name.
                return trainer.load_single_dataset(dataset_name, split, config_name.strip() or None)

            load_dataset_btn.click(
                load_dataset_with_config,
                inputs=[dataset_input, dataset_config, dataset_split],
                outputs=data_status
            )

        with gr.Tab("🏋️ Training"):
            with gr.Row():
                with gr.Column():
                    output_dir = gr.Textbox(
                        label="Output directory",
                        value="./trained_model"
                    )

                    with gr.Row():
                        num_epochs = gr.Number(
                            label="Epochs",
                            value=3,
                            minimum=1
                        )
                        batch_size = gr.Number(
                            label="Batch size",
                            value=4,
                            minimum=1
                        )

                    learning_rate = gr.Number(
                        label="Learning rate",
                        value=5e-5,
                        step=1e-6
                    )

                    train_btn = gr.Button("🚀 Simulate training", variant="primary", size="lg")

                with gr.Column():
                    training_status = gr.Textbox(
                        label="Training status",
                        interactive=False,
                        lines=15
                    )

            train_btn.click(
                trainer.simulate_training,
                inputs=[output_dir, num_epochs, learning_rate, batch_size],
                outputs=training_status
            )

        with gr.Tab("📈 Monitoring"):
            gr.Markdown("### 📊 Training progress")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 🎯 Metrics")
                    metrics_display = gr.Textbox(
                        label="Current metrics",
                        value="📊 No training in progress",
                        interactive=False,
                        lines=8
                    )

                    refresh_metrics_btn = gr.Button("🔄 Refresh metrics")

                with gr.Column():
                    gr.Markdown("#### 📝 Logs")
                    logs_display = gr.Textbox(
                        label="Training logs",
                        value="📋 No logs available",
                        interactive=False,
                        lines=8
                    )

                    clear_logs_btn = gr.Button("🧹 Clear logs")

            def get_dummy_metrics():
                return "📊 METRICS (SIMULATION):\n\n🔥 Loss: 2.34\n📈 Accuracy: 0.78\n⚡ Speed: 1.2 steps/sec\n💾 Memory: 4.2GB"

            def clear_logs():
                return "📋 Logs cleared"

            refresh_metrics_btn.click(get_dummy_metrics, outputs=metrics_display)
            clear_logs_btn.click(clear_logs, outputs=logs_display)

        with gr.Tab("🛠️ Tools"):
            gr.Markdown("### 🔧 Utilities and advanced tools")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 💾 Model management")

                    model_path = gr.Textbox(
                        label="Local model path",
                        placeholder="/path/to/model"
                    )

                    with gr.Row():
                        save_model_btn = gr.Button("💾 Save model")
                        load_local_btn = gr.Button("📂 Load local")

                    gr.Markdown("#### 🧹 Cleanup")
                    with gr.Row():
                        clear_cache_btn = gr.Button("🗑️ Clear cache")
                        reset_all_btn = gr.Button("🔄 Full reset", variant="stop")

                with gr.Column():
                    tools_status = gr.Textbox(
                        label="Tools status",
                        interactive=False,
                        lines=12
                    )

            def save_model_placeholder():
                return "💾 Save function (implementation required)"

            def load_local_placeholder():
                return "📂 Local load function (implementation required)"

            def clear_cache():
                return "🗑️ Cache cleared (simulation)"

            def reset_all():
                return "🔄 System reset (simulation)"
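
            # A minimal sketch of what a real save handler could look like,
            # assuming a loaded model/tokenizer. save_model_real is an
            # illustrative addition; it is not wired to the button, which still
            # calls the placeholder above:
            def save_model_real(path):
                if not trainer.current_model:
                    return "❌ No model loaded"
                # save_pretrained() writes the weights and config to the folder
                trainer.current_model.save_pretrained(path)
                if trainer.current_tokenizer:
                    trainer.current_tokenizer.save_pretrained(path)
                return f"💾 Model saved to {path}"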

            save_model_btn.click(save_model_placeholder, outputs=tools_status)
            load_local_btn.click(load_local_placeholder, outputs=tools_status)
            clear_cache_btn.click(clear_cache, outputs=tools_status)
            reset_all_btn.click(reset_all, outputs=tools_status)

        gr.Markdown("""
        ---
        🔥 **Multimodal Training Hub** | Optimized for Qwen2.5-Omni and multimodal models

        💡 **Tips:**
        - Check dependencies before you start
        - Use the diagnostic to analyze models
        - Training runs are simulated when no suitable GPU is available
        """)

    return app


if __name__ == "__main__":
    app = create_interface()

    launch_kwargs = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "quiet": False
    }

    print("\n" + "=" * 60)
    print("🔥 MULTIMODAL TRAINING HUB")
    print("=" * 60)
    print(trainer.check_dependencies())
    print("=" * 60)
    print("🚀 Launching the interface...")

    try:
        app.launch(**launch_kwargs)
    except Exception as e:
        print(f"❌ Launch error: {e}")
        print("💡 Try changing the port or the network settings")
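
        # A possible fallback, assuming the failure was a busy port: retry
        # without pinning server_port so Gradio can scan upward from 7860 for
        # a free one (an illustrative sketch, not part of the original script):
        try:
            app.launch(share=False, server_name="0.0.0.0", show_error=True)
        except Exception as retry_error:
            print(f"❌ Retry failed as well: {retry_error}")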