kvn420 commited on
Commit
f142553
·
verified ·
1 Parent(s): 5bba009

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -16
app.py CHANGED
@@ -246,20 +246,96 @@ class MultimodalTrainer:
246
  try:
247
  logger.info(f"Chargement du modèle: {model_name}")
248
 
249
- if model_type == "causal":
250
- self.current_model = AutoModelForCausalLM.from_pretrained(
251
- model_name,
252
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
253
- device_map="auto" if torch.cuda.is_available() else None,
254
- trust_remote_code=True
255
- )
256
- else:
257
- self.current_model = AutoModel.from_pretrained(
258
- model_name,
259
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
260
- device_map="auto" if torch.cuda.is_available() else None,
261
- trust_remote_code=True
262
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  # Charge le tokenizer
265
  try:
@@ -270,6 +346,14 @@ class MultimodalTrainer:
270
  self.current_tokenizer.pad_token = self.current_tokenizer.eos_token
271
  except Exception as e:
272
  logger.warning(f"Tokenizer non trouvé: {e}")
 
 
 
 
 
 
 
 
273
 
274
  # Charge le processor
275
  try:
@@ -279,7 +363,7 @@ class MultimodalTrainer:
279
  except Exception as e:
280
  logger.warning(f"Processor non trouvé: {e}")
281
 
282
- return f"✅ Modèle {model_name} chargé avec succès!\nType: {type(self.current_model).__name__}"
283
 
284
  except Exception as e:
285
  error_msg = f"❌ Erreur lors du chargement: {str(e)}"
@@ -354,7 +438,55 @@ class MultimodalTrainer:
354
  info += f"📈 Exemples: {len(self.training_data):,}\n"
355
  info += f"📝 Colonnes: {list(self.training_data.column_names)}\n"
356
 
357
- return info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
  # Initialisation
360
  trainer = MultimodalTrainer()
@@ -428,6 +560,7 @@ def create_interface():
428
  value="causal"
429
  )
430
  load_model_btn = gr.Button("🔄 Charger le modèle", variant="primary")
 
431
 
432
  with gr.Column():
433
  model_status = gr.Textbox(
@@ -449,6 +582,12 @@ def create_interface():
449
  outputs=model_status
450
  )
451
 
 
 
 
 
 
 
452
  info_btn.click(trainer.get_model_info, outputs=model_info)
453
 
454
  with gr.Tab("📊 Données"):
 
246
  try:
247
  logger.info(f"Chargement du modèle: {model_name}")
248
 
249
+ # Stratégies de chargement multiples
250
+ model_loaded = False
251
+ error_messages = []
252
+
253
+ # Stratégie 1: AutoModelForCausalLM avec trust_remote_code
254
+ if model_type == "causal" and not model_loaded:
255
+ try:
256
+ self.current_model = AutoModelForCausalLM.from_pretrained(
257
+ model_name,
258
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
259
+ device_map="auto" if torch.cuda.is_available() else None,
260
+ trust_remote_code=True
261
+ )
262
+ model_loaded = True
263
+ except Exception as e:
264
+ error_messages.append(f"AutoModelForCausalLM: {str(e)}")
265
+
266
+ # Stratégie 2: AutoModel générique
267
+ if not model_loaded:
268
+ try:
269
+ self.current_model = AutoModel.from_pretrained(
270
+ model_name,
271
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
272
+ device_map="auto" if torch.cuda.is_available() else None,
273
+ trust_remote_code=True
274
+ )
275
+ model_loaded = True
276
+ except Exception as e:
277
+ error_messages.append(f"AutoModel: {str(e)}")
278
+
279
+ # Stratégie 3: Détection automatique basée sur le nom
280
+ if not model_loaded and any(x in model_name.lower() for x in ['llama', 'mistral', 'qwen', 'phi']):
281
+ try:
282
+ # Pour les modèles de type LLaMA/Mistral/Qwen
283
+ from transformers import LlamaForCausalLM, MistralForCausalLM
284
+
285
+ if 'llama' in model_name.lower():
286
+ self.current_model = LlamaForCausalLM.from_pretrained(
287
+ model_name,
288
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
289
+ device_map="auto" if torch.cuda.is_available() else None,
290
+ trust_remote_code=True
291
+ )
292
+ elif 'mistral' in model_name.lower():
293
+ self.current_model = MistralForCausalLM.from_pretrained(
294
+ model_name,
295
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
296
+ device_map="auto" if torch.cuda.is_available() else None,
297
+ trust_remote_code=True
298
+ )
299
+ model_loaded = True
300
+ except Exception as e:
301
+ error_messages.append(f"Modèle spécifique: {str(e)}")
302
+
303
+ # Stratégie 4: Configuration manuelle
304
+ if not model_loaded:
305
+ try:
306
+ # Télécharge la configuration d'abord
307
+ from transformers import AutoConfig
308
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
309
+
310
+ # Force le model_type si manquant
311
+ if not hasattr(config, 'model_type') or config.model_type is None:
312
+ # Détection basée sur l'architecture
313
+ if hasattr(config, 'architectures') and config.architectures:
314
+ arch = config.architectures[0].lower()
315
+ if 'llama' in arch:
316
+ config.model_type = 'llama'
317
+ elif 'mistral' in arch:
318
+ config.model_type = 'mistral'
319
+ elif 'qwen' in arch:
320
+ config.model_type = 'qwen2'
321
+ elif 'phi' in arch:
322
+ config.model_type = 'phi'
323
+ else:
324
+ config.model_type = 'llama' # Par défaut
325
+
326
+ self.current_model = AutoModelForCausalLM.from_pretrained(
327
+ model_name,
328
+ config=config,
329
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
330
+ device_map="auto" if torch.cuda.is_available() else None,
331
+ trust_remote_code=True
332
+ )
333
+ model_loaded = True
334
+ except Exception as e:
335
+ error_messages.append(f"Configuration manuelle: {str(e)}")
336
+
337
+ if not model_loaded:
338
+ return f"❌ Impossible de charger le modèle. Erreurs:\n" + "\n".join(error_messages)
339
 
340
  # Charge le tokenizer
341
  try:
 
346
  self.current_tokenizer.pad_token = self.current_tokenizer.eos_token
347
  except Exception as e:
348
  logger.warning(f"Tokenizer non trouvé: {e}")
349
+ try:
350
+ # Essaye avec un tokenizer générique
351
+ from transformers import LlamaTokenizer
352
+ self.current_tokenizer = LlamaTokenizer.from_pretrained(
353
+ model_name, trust_remote_code=True
354
+ )
355
+ except:
356
+ logger.warning("Aucun tokenizer trouvé")
357
 
358
  # Charge le processor
359
  try:
 
363
  except Exception as e:
364
  logger.warning(f"Processor non trouvé: {e}")
365
 
366
+ return f"✅ Modèle {model_name} chargé avec succès!\nType: {type(self.current_model).__name__}\nArchitecture: {getattr(self.current_model.config, 'architectures', ['Inconnue'])[0] if hasattr(self.current_model, 'config') else 'Inconnue'}"
367
 
368
  except Exception as e:
369
  error_msg = f"❌ Erreur lors du chargement: {str(e)}"
 
438
  info += f"📈 Exemples: {len(self.training_data):,}\n"
439
  info += f"📝 Colonnes: {list(self.training_data.column_names)}\n"
440
 
441
+ def diagnose_model(self, model_name: str):
442
+ """Diagnostique un modèle avant chargement"""
443
+ if not model_name.strip():
444
+ return "❌ Veuillez entrer un nom de modèle"
445
+
446
+ try:
447
+ from transformers import AutoConfig
448
+ import requests
449
+
450
+ result = f"🔍 DIAGNOSTIC DU MODÈLE: {model_name}\n\n"
451
+
452
+ # Vérification de l'existence
453
+ try:
454
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
455
+ result += "✅ Modèle accessible\n"
456
+
457
+ # Informations sur la configuration
458
+ result += f"📋 Type de modèle: {getattr(config, 'model_type', 'Non défini')}\n"
459
+ result += f"🏗️ Architecture: {getattr(config, 'architectures', ['Inconnue'])}\n"
460
+ result += f"📚 Vocabulaire: {getattr(config, 'vocab_size', 'Inconnu')}\n"
461
+ result += f"🧠 Couches cachées: {getattr(config, 'hidden_size', 'Inconnu')}\n"
462
+ result += f"🔢 Nombre de couches: {getattr(config, 'num_hidden_layers', 'Inconnu')}\n"
463
+
464
+ # Recommandations
465
+ if not hasattr(config, 'model_type') or config.model_type is None:
466
+ result += "\n⚠️ PROBLÈME: model_type manquant\n"
467
+ result += "💡 SOLUTION: Le chargeur essaiera de détecter automatiquement\n"
468
+
469
+ if hasattr(config, 'architectures') and config.architectures:
470
+ arch = config.architectures[0].lower()
471
+ if 'llama' in arch:
472
+ result += "🎯 Type détecté: LLaMA\n"
473
+ elif 'mistral' in arch:
474
+ result += "🎯 Type détecté: Mistral\n"
475
+ elif 'qwen' in arch:
476
+ result += "🎯 Type détecté: Qwen\n"
477
+ elif 'phi' in arch:
478
+ result += "🎯 Type détecté: Phi\n"
479
+
480
+ result += "\n✅ Chargement possible avec les stratégies multiples"
481
+
482
+ except Exception as e:
483
+ result += f"❌ Erreur d'accès: {str(e)}\n"
484
+ result += "💡 Vérifiez que le modèle existe et est public\n"
485
+
486
+ return result
487
+
488
+ except Exception as e:
489
+ return f"❌ Erreur diagnostic: {str(e)}"
490
 
491
  # Initialisation
492
  trainer = MultimodalTrainer()
 
560
  value="causal"
561
  )
562
  load_model_btn = gr.Button("🔄 Charger le modèle", variant="primary")
563
+ diagnose_btn = gr.Button("🔍 Diagnostiquer le modèle", variant="secondary")
564
 
565
  with gr.Column():
566
  model_status = gr.Textbox(
 
582
  outputs=model_status
583
  )
584
 
585
+ diagnose_btn.click(
586
+ trainer.diagnose_model,
587
+ inputs=[model_input],
588
+ outputs=model_status
589
+ )
590
+
591
  info_btn.click(trainer.get_model_info, outputs=model_info)
592
 
593
  with gr.Tab("📊 Données"):