import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import __version__ as transformers_version
from packaging import version

MODEL_NAME = "kvn420/Tenro_V4.1"
RECOMMENDED_TRANSFORMERS_VERSION = "4.37.0"

print(f"Transformers version: {transformers_version}")
# Compare versions numerically; comparing the raw strings would misorder
# releases such as "4.9.0" and "4.37.0".
if version.parse(transformers_version) < version.parse(RECOMMENDED_TRANSFORMERS_VERSION):
    print(f"Warning: Transformers version ({transformers_version}) is older than the recommended {RECOMMENDED_TRANSFORMERS_VERSION}. Please upgrade.")

# Define both names up front so the None check in chat_interaction() works
# even if loading fails partway through.
model = None
tokenizer = None

print(f"Loading tokenizer for: {MODEL_NAME}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    print("Tokenizer loaded.")
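
    # device_map="auto" lets Accelerate place the weights on whatever
    # hardware is available; bfloat16 halves memory use relative to float32.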
    print(f"Loading model: {MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )
    print(f"Model loaded on {model.device}.")
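
    # Many causal LMs ship without a dedicated pad token; fall back to the
    # EOS token so generate() can pad safely.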
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        print(f"tokenizer.pad_token_id set to eos_token_id: {tokenizer.eos_token_id}")

except Exception as e:
    print(f"Critical error while loading the model/tokenizer: {e}")
    raise gr.Error(f"Could not load the model or tokenizer: {e}. Check the Space logs.")


def chat_interaction(user_input, history):
    """
    Called by Gradio for each chat turn.

    `history` is a list of [user_message, assistant_message] pairs.
    """
    if model is None or tokenizer is None:
        return "Error: model or tokenizer not initialised."

    # Rebuild the conversation in the role/content format expected by
    # tokenizer.apply_chat_template().
    messages_for_template = []
    for user_msg, assistant_msg in history:
        messages_for_template.append({"role": "user", "content": user_msg})
        messages_for_template.append({"role": "assistant", "content": assistant_msg})
    messages_for_template.append({"role": "user", "content": user_input})

    try:
        prompt_tokenized = tokenizer.apply_chat_template(
            messages_for_template,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
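
        # Sampling settings: temperature and top_p control randomness;
        # passing pad_token_id explicitly avoids the "no pad token" warning.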
        outputs = model.generate(
            prompt_tokenized,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id
        )
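
        # generate() returns the prompt followed by the new tokens, so slice
        # off the prompt before decoding.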
        response_text = tokenizer.decode(outputs[0][prompt_tokenized.shape[-1]:], skip_special_tokens=True)

        # Strip ChatML-style markers that some chat templates leave behind.
        response_text = response_text.replace("<|im_end|>", "").strip()
        if response_text.startswith("assistant\n"):
            response_text = response_text.split("assistant\n", 1)[-1].strip()

        return response_text

    except Exception as e:
        print(f"Error during generation: {e}")
        return f"Sorry, an error occurred: {e}"


iface = gr.ChatInterface(
    fn=chat_interaction,
    title=f"Chat with {MODEL_NAME}",
    description=f"Demo interface for the {MODEL_NAME} model, hosted on Hugging Face and loaded here.",
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Ask your question here...", container=False, scale=7),
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear conversation",
    submit_btn="Send"
)
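
# Note: retry_btn, undo_btn and clear_btn are Gradio 4.x ChatInterface
# options; if the Space runs a newer Gradio that has removed them, drop
# these kwargs or pin gradio<5 in requirements.txt.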


if __name__ == '__main__':
    iface.launch()