kvn420 commited on
Commit
4c516c0
·
verified ·
1 Parent(s): bf13d26

Create run_qwen_model.py

Browse files
Files changed (1) hide show
  1. run_qwen_model.py +123 -0
run_qwen_model.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers import __version__ as transformers_version
6
+
7
+ MODEL_NAME = "kvn420/Tenro_V4.1"
8
+ RECOMMENDED_TRANSFORMERS_VERSION = "4.37.0"
9
+
10
+ print(f"Version de Transformers : {transformers_version}")
11
+ if transformers_version < RECOMMENDED_TRANSFORMERS_VERSION:
12
+ print(f"Attention : Version Transformers ({transformers_version}) < Recommandée ({RECOMMENDED_TRANSFORMERS_VERSION}). Mettez à jour.")
13
+
14
+ # --- Chargement du modèle et du tokenizer (une seule fois) ---
15
+ print(f"Chargement du tokenizer pour : {MODEL_NAME}")
16
+ try:
17
+ tokenizer = AutoTokenizer.from_pretrained(
18
+ MODEL_NAME,
19
+ trust_remote_code=True
20
+ )
21
+ print("Tokenizer chargé.")
22
+
23
+ print(f"Chargement du modèle : {MODEL_NAME}")
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ MODEL_NAME,
26
+ torch_dtype=torch.bfloat16, # ou "auto"
27
+ trust_remote_code=True,
28
+ device_map="auto" # Utilise le GPU si disponible, sinon CPU
29
+ )
30
+ print(f"Modèle chargé sur {model.device}.")
31
+
32
+ # Définir pad_token_id si manquant (important pour la génération)
33
+ if tokenizer.pad_token_id is None:
34
+ tokenizer.pad_token_id = tokenizer.eos_token_id
35
+ print(f"tokenizer.pad_token_id défini sur eos_token_id: {tokenizer.eos_token_id}")
36
+
37
+
38
+ except Exception as e:
39
+ print(f"Erreur critique lors du chargement du modèle/tokenizer : {e}")
40
+ # Lever l'erreur pour que Gradio l'affiche ou la loggue
41
+ raise gr.Error(f"Impossible de charger le modèle ou le tokenizer: {e}. Vérifiez les logs du Space.")
42
+ # --- Fin du chargement ---
43
+
44
+ def chat_interaction(user_input, history):
45
+ """
46
+ Fonction appelée par Gradio pour chaque interaction de chat.
47
+ history est une liste de paires [user_message, assistant_message]
48
+ """
49
+ if model is None or tokenizer is None:
50
+ return "Erreur: Modèle ou tokenizer non initialisé."
51
+
52
+ # Construire le prompt avec l'historique pour le modèle
53
+ messages_for_template = []
54
+ # Ajouter un message système par défaut si l'historique est vide et le premier message n'est pas un système
55
+ # Ou si vous voulez toujours un message système spécifique.
56
+ # Note: Le chat_template de Qwen ajoute déjà un message système par défaut.
57
+ # Adaptez selon le comportement exact souhaité.
58
+ # messages_for_template.append({"role": "system", "content": "Tu es Qwen, un assistant IA serviable."})
59
+
60
+ for user_msg, assistant_msg in history:
61
+ messages_for_template.append({"role": "user", "content": user_msg})
62
+ messages_for_template.append({"role": "assistant", "content": assistant_msg})
63
+ messages_for_template.append({"role": "user", "content": user_input})
64
+
65
+ try:
66
+ prompt_tokenized = tokenizer.apply_chat_template(
67
+ messages_for_template,
68
+ tokenize=True,
69
+ add_generation_prompt=True,
70
+ return_tensors="pt"
71
+ ).to(model.device)
72
+
73
+ outputs = model.generate(
74
+ prompt_tokenized,
75
+ max_new_tokens=512, # Augmenté pour des réponses potentiellement plus longues
76
+ do_sample=True,
77
+ temperature=0.7,
78
+ top_p=0.9,
79
+ pad_token_id=tokenizer.pad_token_id
80
+ )
81
+
82
+ response_text = tokenizer.decode(outputs[0][prompt_tokenized.shape[-1]:], skip_special_tokens=True)
83
+
84
+ # Nettoyage simple (optionnel, dépend du modèle)
85
+ response_text = response_text.replace("<|im_end|>", "").strip()
86
+ if response_text.startswith("assistant\n"): # Parfois Qwen ajoute cela
87
+ response_text = response_text.split("assistant\n", 1)[-1].strip()
88
+
89
+ return response_text
90
+
91
+ except Exception as e:
92
+ print(f"Erreur pendant la génération : {e}")
93
+ return f"Désolé, une erreur est survenue : {e}"
94
+
95
+ # Création de l'interface Gradio
96
+ # Utilisation de `gr.ChatInterface` qui gère l'historique automatiquement.
97
+ iface = gr.ChatInterface(
98
+ fn=chat_interaction,
99
+ title=f"Chat avec {MODEL_NAME}",
100
+ description=f"Interface de démonstration pour le modèle {MODEL_NAME}. Le modèle est hébergé sur Hugging Face et chargé ici.",
101
+ chatbot=gr.Chatbot(height=600),
102
+ textbox=gr.Textbox(placeholder="Posez votre question ici...", container=False, scale=7),
103
+ retry_btn="Réessayer",
104
+ undo_btn="Annuler",
105
+ clear_btn="Effacer la conversation",
106
+ submit_btn="Envoyer"
107
+ )
108
+
109
+ # Lancer l'application (pour un test local, ou si vous n'utilisez pas `if __name__ == "__main__":` dans Spaces)
110
+ # iface.launch() # Décommentez pour un test local facile
111
+
112
+ # Pour Spaces, il est souvent préférable de laisser Spaces gérer le lancement
113
+ # si vous utilisez le SDK Gradio directement dans la configuration du Space.
114
+ # Si vous exécutez ce script directement avec python app.py, il faut iface.launch().
115
+ # Dans le contexte d'un Space Gradio, le iface est généralement ce qui est "retourné" implicitement.
116
+ # Si vous voulez contrôler le lancement (ex: pour des options), utilisez :
117
+ # if __name__ == "__main__":
118
+ # iface.launch()
119
+ # Mais pour un Space Gradio simple, juste définir `iface` peut suffire.
120
+ # La convention est de lancer si le script est exécuté directement.
121
+ if __name__ == '__main__':
122
+ iface.launch() # Permet de tester localement `python app.py`
123
+