File size: 5,309 Bytes
4c516c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import __version__ as transformers_version

MODEL_NAME = "kvn420/Tenro_V4.1"
RECOMMENDED_TRANSFORMERS_VERSION = "4.37.0"

print(f"Version de Transformers : {transformers_version}")
if transformers_version < RECOMMENDED_TRANSFORMERS_VERSION:
    print(f"Attention : Version Transformers ({transformers_version}) < Recommandée ({RECOMMENDED_TRANSFORMERS_VERSION}). Mettez à jour.")

# --- Chargement du modèle et du tokenizer (une seule fois) ---
print(f"Chargement du tokenizer pour : {MODEL_NAME}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    print("Tokenizer chargé.")

    print(f"Chargement du modèle : {MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16, # ou "auto"
        trust_remote_code=True,
        device_map="auto" # Utilise le GPU si disponible, sinon CPU
    )
    print(f"Modèle chargé sur {model.device}.")

    # Définir pad_token_id si manquant (important pour la génération)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        print(f"tokenizer.pad_token_id défini sur eos_token_id: {tokenizer.eos_token_id}")


except Exception as e:
    print(f"Erreur critique lors du chargement du modèle/tokenizer : {e}")
    # Lever l'erreur pour que Gradio l'affiche ou la loggue
    raise gr.Error(f"Impossible de charger le modèle ou le tokenizer: {e}. Vérifiez les logs du Space.")
# --- Fin du chargement ---

def chat_interaction(user_input, history):
    """
    Fonction appelée par Gradio pour chaque interaction de chat.
    history est une liste de paires [user_message, assistant_message]
    """
    if model is None or tokenizer is None:
        return "Erreur: Modèle ou tokenizer non initialisé."

    # Construire le prompt avec l'historique pour le modèle
    messages_for_template = []
    # Ajouter un message système par défaut si l'historique est vide et le premier message n'est pas un système
    # Ou si vous voulez toujours un message système spécifique.
    # Note: Le chat_template de Qwen ajoute déjà un message système par défaut.
    # Adaptez selon le comportement exact souhaité.
    # messages_for_template.append({"role": "system", "content": "Tu es Qwen, un assistant IA serviable."})

    for user_msg, assistant_msg in history:
        messages_for_template.append({"role": "user", "content": user_msg})
        messages_for_template.append({"role": "assistant", "content": assistant_msg})
    messages_for_template.append({"role": "user", "content": user_input})

    try:
        prompt_tokenized = tokenizer.apply_chat_template(
            messages_for_template,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            prompt_tokenized,
            max_new_tokens=512, # Augmenté pour des réponses potentiellement plus longues
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id
        )

        response_text = tokenizer.decode(outputs[0][prompt_tokenized.shape[-1]:], skip_special_tokens=True)

        # Nettoyage simple (optionnel, dépend du modèle)
        response_text = response_text.replace("<|im_end|>", "").strip()
        if response_text.startswith("assistant\n"): # Parfois Qwen ajoute cela
            response_text = response_text.split("assistant\n", 1)[-1].strip()

        return response_text

    except Exception as e:
        print(f"Erreur pendant la génération : {e}")
        return f"Désolé, une erreur est survenue : {e}"

# Création de l'interface Gradio
# Utilisation de `gr.ChatInterface` qui gère l'historique automatiquement.
iface = gr.ChatInterface(
    fn=chat_interaction,
    title=f"Chat avec {MODEL_NAME}",
    description=f"Interface de démonstration pour le modèle {MODEL_NAME}. Le modèle est hébergé sur Hugging Face et chargé ici.",
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Posez votre question ici...", container=False, scale=7),
    retry_btn="Réessayer",
    undo_btn="Annuler",
    clear_btn="Effacer la conversation",
    submit_btn="Envoyer"
)

# Lancer l'application (pour un test local, ou si vous n'utilisez pas `if __name__ == "__main__":` dans Spaces)
# iface.launch() # Décommentez pour un test local facile

# Pour Spaces, il est souvent préférable de laisser Spaces gérer le lancement
# si vous utilisez le SDK Gradio directement dans la configuration du Space.
# Si vous exécutez ce script directement avec python app.py, il faut iface.launch().
# Dans le contexte d'un Space Gradio, le iface est généralement ce qui est "retourné" implicitement.
# Si vous voulez contrôler le lancement (ex: pour des options), utilisez :
# if __name__ == "__main__":
#     iface.launch()
# Mais pour un Space Gradio simple, juste définir `iface` peut suffire.
# La convention est de lancer si le script est exécuté directement.
if __name__ == '__main__':
     iface.launch() # Permet de tester localement `python app.py`