import os

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# ——— CONFIG ———
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical"
SUBFOLDER = "checkpoint-45000"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")

# ——— 1) Download only the files in checkpoint-45000/ ———
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
    allow_patterns=[
        f"{SUBFOLDER}/*.json",
        f"{SUBFOLDER}/*.safetensors",
    ],
)

# ——— 2) Point MODEL_DIR at that subfolder ———
MODEL_DIR = os.path.join(local_cache, SUBFOLDER)

# ——— 3) Load tokenizer & model from disk ———
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=False,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()


# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction):
        self.role_instruction = role_instruction

    def act(self, input_text):
        prompt = (
            f"Instruction: {self.role_instruction}\n"
            f"Input: {input_text}\n"
            "Output:"
        )
        encoding = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in encoding.items()}
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens so the prompt is not
        # echoed back into the answer.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        # Split the completion into THINKING/ANSWER parts when the model
        # emits all three tags; otherwise return the raw text as the answer.
        thinking = ""
        answer = response
        if all(tag in response for tag in ("THINKING:", "ANSWER:", "END")):
            block = response.split("THINKING:")[1].split("END")[0]
            thinking = block.split("ANSWER:")[0].strip()
            answer = block.split("ANSWER:")[1].strip()

        return {"thinking": thinking, "output": answer}


# === Agents ===
summarizer = RoleAgent(
    "You are a clinical summarizer trained to extract structured vignettes "
    "from doctor–patient dialogues."
)
diagnoser = RoleAgent(
    "You are a board-certified diagnostician that diagnoses patients."
)
questioner = RoleAgent("You are a physician asking questions to diagnose a patient.")
treatment_agent = RoleAgent(
    "You are a board-certified clinician. Based on the diagnosis and patient "
    "vignette provided below, suggest a concise treatment plan that could "
    "realistically be initiated by a primary care physician or psychiatrist."
)
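# A sketch of the completion format that the tag handling in RoleAgent.act
# assumes (inferred from the parsing logic above, not from the model card):
# the fine-tuned checkpoint is expected to answer in the form
#
#   THINKING: <step-by-step reasoning> ANSWER: <final answer> END
#
# e.g. diagnoser.act("45-year-old with fever and productive cough") might
# return {"thinking": "...", "output": "Community-acquired pneumonia"}.
# If any tag is missing, the raw completion is returned unchanged in "output".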
# === Inference State ===
# Module-level state; the Gradio handler below keeps its own local copies.
conversation_history = []
summary = ""
diagnosis = ""


# === Gradio Inference ===
def simulate_interaction(user_input, iterations=1):
    history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    summary, diagnosis = "", ""

    for _ in range(iterations):
        # Summarize the dialogue so far into a structured vignette.
        sum_in = "\n".join(history) + f"\nPrevious Vignette: {summary}"
        sum_out = summarizer.act(sum_in)
        summary = sum_out["output"]

        # Diagnose from the current vignette.
        diag_out = diagnoser.act(summary)
        diagnosis = diag_out["output"]

        # Ask a follow-up question informed by the working diagnosis.
        q_in = (
            f"Vignette: {summary}\n"
            f"Current Estimated Diagnosis: {diag_out['thinking']} {diagnosis}"
        )
        q_out = questioner.act(q_in)
        history.append(f"Doctor: {q_out['output']}")
        history.append("Patient: (awaiting response)")

    # Propose a treatment plan for the final diagnosis.
    treatment_out = treatment_agent.act(
        f"Diagnosis: {diagnosis}\nVignette: {summary}"
    )

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": "\n".join(history),
    }


# === Gradio UI ===
def ui_fn(user_input):
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
ANSWER: {res['summary']['output']}

🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
ANSWER: {res['diagnosis']['output']}

❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
ANSWER: {res['question']['output']}

💊 Treatment Plan:
{res['treatment']['output']}

💬 Conversation:
{res['conversation']}
"""


demo = gr.Interface(
    fn=ui_fn,
    inputs=gr.Textbox(label="Patient Response"),
    outputs=gr.Textbox(label="Doctor Simulation Output"),
    title="🧠 AI Doctor Multi-Agent Reasoning",
)

if __name__ == "__main__":
    demo.launch(share=True)
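# Minimal sketch of exercising the pipeline without the UI (hypothetical
# patient complaint; runs one summarize → diagnose → question → treat pass):
#
#   result = simulate_interaction("I've had a sore throat and fever for three days")
#   print(result["conversation"])
#   print(result["treatment"]["output"])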