"""Gradio front-end for the AfriAI Solutions French medical chatbot.

Pipeline for every user message:

1. Short greetings / thanks / goodbyes are answered with canned replies.
2. Anything else is translated French -> English with NLLB.
3. The English conversation is sent to a Bio-Medical Llama chat model.
4. The English answer is translated back to French for display.
"""

import os
import re

import gradio as gr
import torch
from huggingface_hub import login
from PIL import Image
from quanto import quantize_model
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# === [AUTHENTICATION] ===
# Environment variable names are case-sensitive on Linux. The error message
# has always asked for HF_TOKEN, so read that first; the lowercase spelling the
# script used to read is kept as a fallback for existing deployments.
hf_token = os.getenv("HF_TOKEN") or os.getenv("hf_token")
if not hf_token:
    raise ValueError(
        "Please set HF_TOKEN environment variable with your Hugging Face token"
    )
login(token=hf_token)

# === [TRANSLATOR] ===
translator = pipeline("translation", model="facebook/nllb-200-distilled-600M")

# === [LOAD & QUANTIZE MODEL] ===
model_name = "ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025"
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

print("Quantizing model...")
# NOTE(review): the quanto documentation I know of exposes
# `quantize(model, weights=qint4)` followed by `freeze(model)` rather than
# `quantize_model(model, bits=4)`. Confirm this call against the installed
# quanto version before deploying.
quantized_model = quantize_model(model, bits=4)

print("Initializing pipeline...")
text_gen_pipeline = pipeline(
    "text-generation",
    model=quantized_model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)

# === [SYSTEM MESSAGE] ===
system_message = {
    "role": "system",
    "content": (
        "You are a helpful, respectful, and knowledgeable medical assistant developed by the AI team at AfriAI Solutions, Senegal. "
        "Provide brief, clear definitions when answering medical questions. After giving a concise response, ask the user if they would like more information about symptoms, causes, or treatments. "
        "Always encourage users to consult healthcare professionals for personalized advice."
    ),
}

# Conversation sent to the model: the system prompt followed by alternating
# user/assistant turns, all stored in English.
# NOTE: this is module-level state, so it is shared by every browser session
# connected to the app.
messages = [system_message]
max_history = 10  # maximum number of user/assistant exchanges kept

salutations = ["bonjour", "salut", "bonsoir", "coucou"]
remerciements = ["merci", "je vous remercie", "thanks"]
au_revoir = ["au revoir", "à bientôt", "bye", "bonne journée", "à la prochaine"]

# A message is only treated as small talk when, once the small-talk phrases
# are removed, at most this many other words remain ("Bonjour docteur").
_SMALLTALK_MAX_EXTRA_WORDS = 2
_WORD_RE = re.compile(r"[\w'’-]+")
# Sentence boundary: whitespace that follows terminal punctuation.
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")

_EMPTY_INPUT_REPLY = "Veuillez saisir une question médicale."
_NO_ANSWER_REPLY = (
    "Désolé, je n'ai pas pu générer de réponse. "
    "Pouvez-vous reformuler votre question ?"
)


def _phrase_pattern(phrases):
    """Compile a regex matching any of *phrases* as whole words/phrases."""
    # Longest first so "je vous remercie" wins over a shorter alternative.
    ordered = sorted(phrases, key=len, reverse=True)
    alternatives = "|".join(re.escape(phrase) for phrase in ordered)
    return re.compile(rf"(?<!\w)(?:{alternatives})(?!\w)")


# Checked in this order; the first category that matches decides the reply.
_SMALLTALK_RULES = (
    (
        _phrase_pattern(salutations),
        "Bonjour ! Comment puis-je vous aider aujourd'hui ?",
    ),
    (
        _phrase_pattern(remerciements),
        "Avec plaisir ! Souhaitez-vous poser une autre question médicale ?",
    ),
    (
        _phrase_pattern(au_revoir),
        "Au revoir ! Prenez soin de votre santé et n'hésitez pas à revenir si besoin.",
    ),
)


def detect_smalltalk(user_input):
    """Return a canned reply when *user_input* is only small talk.

    Phrases are matched as whole words, so "salutaire" or "remerciements" do
    not trigger a reply. A message that also carries real content, such as
    "Bonjour, quels sont les symptômes du paludisme ?", is not treated as
    small talk, so the question still reaches the model.

    Args:
        user_input: Raw text typed by the user.

    Returns:
        A ``(reply, handled)`` tuple. ``handled`` is ``True`` when ``reply``
        should be shown directly; otherwise ``reply`` is ``""``.
    """
    normalized = (user_input or "").lower().strip()

    reply = next(
        (
            canned
            for pattern, canned in _SMALLTALK_RULES
            if pattern.search(normalized)
        ),
        None,
    )
    if reply is None:
        return "", False

    # Strip every small-talk phrase and see how much real content is left.
    remainder = normalized
    for pattern, _ in _SMALLTALK_RULES:
        remainder = pattern.sub(" ", remainder)
    if len(_WORD_RE.findall(remainder)) > _SMALLTALK_MAX_EXTRA_WORDS:
        return "", False

    return reply, True


def _translate(text, src_lang, tgt_lang):
    """Translate *text* sentence by sentence, preserving line breaks.

    Sending a long multi-sentence answer to the translation pipeline in one
    call risks truncation at the pipeline's generation length limit, so each
    line is split into sentences that are translated as a batch.
    """
    translated_lines = []
    for line in text.splitlines():
        sentences = [s for s in _SENTENCE_SPLIT_RE.split(line.strip()) if s]
        if not sentences:
            translated_lines.append("")
            continue
        results = translator(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
        translated_lines.append(
            " ".join(result["translation_text"] for result in results)
        )
    return "\n".join(translated_lines).strip()


def _trim_history():
    """Cap the shared history at ``max_history`` exchanges, in place."""
    turns = messages[1:]
    limit = max_history * 2
    if len(turns) <= limit:
        return
    turns = turns[-limit:]
    # Keep the window starting on a user turn so the model never sees an
    # assistant reply whose question has been cut off.
    if turns[0]["role"] != "user":
        turns = turns[1:]
    messages[:] = [system_message] + turns


def medical_chatbot(user_input):
    """Answer a French medical question and return the reply in French.

    Side effects: appends the English user turn and the English assistant
    turn to the module-level ``messages`` history.

    Args:
        user_input: The user's message in French.

    Returns:
        The French reply: a canned small-talk answer, a fallback message for
        empty input or an empty generation, or the translated model answer.
    """
    if not user_input or not user_input.strip():
        return _EMPTY_INPUT_REPLY

    smalltalk_response, handled = detect_smalltalk(user_input)
    if handled:
        return smalltalk_response

    translated = _translate(user_input, "fra_Latn", "eng_Latn")
    messages.append({"role": "user", "content": translated})
    _trim_history()

    # add_generation_prompt=True appends the assistant header so the model
    # produces the assistant's answer instead of continuing the user turn.
    prompt = text_gen_pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    response = text_gen_pipeline(
        prompt,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.4,
        top_k=150,
        top_p=0.75,
        return_full_text=False,  # only the newly generated text
        eos_token_id=[
            text_gen_pipeline.tokenizer.eos_token_id,
            text_gen_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ],
    )
    output = response[0]["generated_text"].strip()

    if not output:
        # Drop the unanswered question so the history keeps alternating.
        messages.pop()
        return _NO_ANSWER_REPLY

    # Store the ENGLISH answer: user turns are stored in English and the model
    # is prompted in English, so the history must stay in one language.
    messages.append({"role": "assistant", "content": output})
    return _translate(output, "eng_Latn", "fra_Latn")


# === [LOGO LOAD] ===
try:
    logo = Image.open("AfriAI Solutions.jpg")
except OSError:
    # A missing or unreadable logo should not prevent the chatbot from starting.
    logo = None

# === [GRADIO UI] ===
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo")) as demo:
    with gr.Row():
        if logo is not None:
            gr.Image(
                value=logo,
                show_label=False,
                show_download_button=False,
                interactive=False,
                height=150,
            )
        gr.Markdown(
            """
            # 🤖 Chatbot Médical AfriAI Solutions
            **Posez votre question médicale en français.**
            Le chatbot vous répondra brièvement et avec bienveillance, puis vous demandera si vous souhaitez plus de détails.
            """,
            elem_id="title",
        )

    chatbot = gr.Chatbot(label="Chat avec le Médecin Virtuel")
    msg = gr.Textbox(
        label="Votre question",
        placeholder="Exemple : Quels sont les symptômes du paludisme ?",
    )
    clear = gr.Button("Effacer la conversation", variant="secondary")

    def respond(message, history):
        """Handle a submitted message: clear the textbox, extend the chat."""
        history = history or []
        if not message or not message.strip():
            # Nothing to answer; leave the conversation untouched.
            return "", history
        response = medical_chatbot(message)
        history.append((message, response))
        return "", history

    def reset_conversation():
        """Clear the UI and the model-side history together."""
        messages[:] = [system_message]
        return "", []

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(reset_conversation, None, [msg, chatbot])

if __name__ == "__main__":
    demo.launch()