import os
import pathlib

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# ——— CONFIG ———
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical"
SUBFOLDER = "checkpoint-45000"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")

# ——— 1) Download the full repo ———
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
)
print("[DEBUG] snapshot_download → local_cache:", local_cache)
print(
    "[DEBUG] MODEL root contents:",
    list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)

# ——— 2) Repo root contains tokenizer.json; model shards live in the checkpoint subfolder ———
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))

# ——— 3) Load tokenizer & model from disk ———
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))

# Confirm tokenizer files are present
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))

# Inspect tokenizer's initialization arguments
try:
    print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
    print("[DEBUG] No init_kwargs attribute on tokenizer.")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    subfolder=MODEL_SUBFOLDER,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
print(
    "[DEBUG] Loaded model object:",
    model.__class__.__name__,
    "device:",
    next(model.parameters()).device,
)


# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.role_instruction = role_instruction

    def act(self, input_text):
        # Build the instruction/input/output prompt format the checkpoint expects.
        prompt = (
            f"Instruction: {self.role_instruction}\n"
            f"Input: {input_text}\n"
            f"Output:"
        )
        print("[DEBUG] prompt:", prompt)

        encoding = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("[DEBUG] response:", response)

        # The decoded text echoes the prompt; drop it before parsing.
        if response.startswith(prompt):
            response = response[len(prompt):].strip()

        thinking = ""
        answer = response
        if all(tag in response for tag in ("THINKING:", "ANSWER:", "END")):
            print("[FIX] response:", response)
            block = response.split("THINKING:")[1].split("END")[0]
            thinking = block.split("ANSWER:")[0].strip()
            answer = block.split("ANSWER:")[1].strip()

        return {"thinking": thinking, "output": answer}


# === Agents ===
summarizer = RoleAgent(
    role_instruction="You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.",
    tokenizer=tokenizer,
    model=model,
)
diagnoser = RoleAgent(
    role_instruction="You are a board-certified diagnostician that diagnoses patients.",
    tokenizer=tokenizer,
    model=model,
)
questioner = RoleAgent(
    role_instruction="You are a physician asking questions to diagnose a patient.",
    tokenizer=tokenizer,
    model=model,
)
treatment_agent = RoleAgent(
    role_instruction=(
        "You are a board-certified clinician. Based on the diagnosis and patient "
        "vignette provided below, suggest a concise treatment plan that could "
        "realistically be initiated by a primary care physician or psychiatrist."
    ),
    tokenizer=tokenizer,
    model=model,
)

# === Inference State ===
conversation_history = []
summary = ""
diagnosis = ""


# === Gradio Inference ===
def simulate_interaction(user_input, iterations=1):
    history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    summary, diagnosis = "", ""

    for _ in range(iterations):
        # Summarize
        sum_in = "\n".join(history) + f"\nPrevious Vignette: {summary}"
        sum_out = summarizer.act(sum_in)
        summary = sum_out["output"]

        # Diagnose
        diag_out = diagnoser.act(summary)
        diagnosis = diag_out["output"]

        # Question
        q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diag_out['thinking']} {diagnosis}"
        q_out = questioner.act(q_in)
        history.append(f"Doctor: {q_out['output']}")
        history.append("Patient: (awaiting response)")

    # Treatment
    treatment_out = treatment_agent.act(
        f"Diagnosis: {diagnosis}\nVignette: {summary}"
    )

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": "\n".join(history),
    }


# === Gradio UI ===
def ui_fn(user_input):
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
ANSWER: {res['summary']['output']}

🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
ANSWER: {res['diagnosis']['output']}

❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
ANSWER: {res['question']['output']}

💊 Treatment Plan:
{res['treatment']['output']}

💬 Conversation:
{res['conversation']}
"""


demo = gr.Interface(
    fn=ui_fn,
    inputs=gr.Textbox(label="Patient Response"),
    outputs=gr.Textbox(label="Doctor Simulation Output"),
    title="🧠 AI Doctor Multi-Agent Reasoning",
)

if __name__ == "__main__":
    demo.launch(share=True)
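
# --- Usage sketch (illustrative only) ---
# A minimal way to exercise the agent pipeline from a Python REPL without the
# Gradio UI, assuming the model and tokenizer above loaded successfully. The
# module name "app" and the sample complaint are assumptions, not part of the
# original script.
#
#   from app import simulate_interaction
#   result = simulate_interaction("I've had a dry cough and mild fever for three days.")
#   print(result["conversation"])
#   print(result["treatment"]["output"])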