# medical-test / app.py
import os
import pathlib

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# ——— CONFIG ———
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical"
SUBFOLDER = "checkpoint-45000"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")
# ——— 1) Download the full repo ———
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
)
print("[DEBUG] snapshot_download → local_cache:", local_cache)
print(
    "[DEBUG] MODEL root contents:",
    list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)
# ——— 2) Repo root contains tokenizer.json; model shards live in the checkpoint subfolder ———
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))
# ——— 3) Load tokenizer & model from disk ———
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))
# Confirm tokenizer files are present
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))
# Inspect tokenizer's initialization arguments
try:
    print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
    print("[DEBUG] No init_kwargs attribute on tokenizer.")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    subfolder=MODEL_SUBFOLDER,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
print(
    "[DEBUG] Loaded model object:",
    model.__class__.__name__,
    "device:",
    next(model.parameters()).device,
)
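# Optional smoke test (sketch, gated behind a hypothetical RUN_SMOKE_TEST env
# var): generate a few tokens to verify the checkpoint decodes end to end.
if os.getenv("RUN_SMOKE_TEST"):
    _ids = tokenizer("Instruction: say hi\nOutput:", return_tensors="pt").to(model.device)
    with torch.no_grad():
        _out = model.generate(**_ids, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id)
    print("[DEBUG] Smoke test:", tokenizer.decode(_out[0], skip_special_tokens=True))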
# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.role_instruction = role_instruction

    def act(self, input_text):
        prompt = (
            f"Instruction: {self.role_instruction}\n"
            f"Input: {input_text}\n"
            "Output:"
        )
        encoding = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        thinking = ""
        print(f"[RESPONSE]: {response}")
        answer = response
        if "Output:" in response:
            # Split on the last occurrence of 'Output:' in case it's repeated
            answer = response.rsplit("Output:", 1)[-1].strip()
        else:
            # Fallback: if THINKING/ANSWER/END tags exist, parse them instead
            tags = ("THINKING:", "ANSWER:", "END")
            if all(tag in response for tag in tags):
                print("[FIX] tagged response detected:", response)
                block = response.split("THINKING:", 1)[1].split("END", 1)[0]
                thinking = block.split("ANSWER:", 1)[0].strip()
                answer = block.split("ANSWER:", 1)[1].strip()
        return {"thinking": thinking, "output": answer}
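
# Usage sketch (illustrative only; the real agents are built below): each
# RoleAgent wraps the same base model with a different instruction string.
#   nurse = RoleAgent("You are a triage nurse.", tokenizer, model)
#   print(nurse.act("Patient reports chest pain.")["output"])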
# === Agents ===
summarizer = RoleAgent(
    role_instruction="You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.",
    tokenizer=tokenizer,
    model=model,
)
diagnoser = RoleAgent(
    role_instruction="You are a board-certified diagnostician that diagnoses patients.",
    tokenizer=tokenizer,
    model=model,
)
questioner = RoleAgent(
    role_instruction="You are a physician asking questions to diagnose a patient.",
    tokenizer=tokenizer,
    model=model,
)
treatment_agent = RoleAgent(
    role_instruction="You are a board-certified clinician. Based on the diagnosis and patient vignette provided below, suggest a concise treatment plan that could realistically be initiated by a primary care physician or psychiatrist.",
    tokenizer=tokenizer,
    model=model,
)
"""[DEBUG] prompt: Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output:
Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output: The patient is a 15-year-old male presenting with knee pain."""
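# Note: as the example above shows, the model echoes the prompt before its
# completion, which is why RoleAgent.act() splits on the last "Output:".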
# === Inference State ===
conversation_history = []
summary = ""
diagnosis = ""
# === Gradio Inference ===
def simulate_interaction(user_input, iterations=1):
    history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    summary, diagnosis = "", ""
    for i in range(iterations):
        # Summarize
        sum_in = "\n".join(history) + f"\nPrevious Vignette: {summary}"
        sum_out = summarizer.act(sum_in)
        summary = sum_out["output"]
        # Diagnose
        diag_out = diagnoser.act(summary)
        diagnosis = diag_out["output"]
        # Question
        q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diag_out['thinking']} {diagnosis}"
        q_out = questioner.act(q_in)
        history.append(f"Doctor: {q_out['output']}")
        history.append("Patient: (awaiting response)")
    # Treatment
    treatment_out = treatment_agent.act(
        f"Diagnosis: {diagnosis}\nVignette: {summary}"
    )
    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": "\n".join(history),
    }
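
# Example call (illustrative): one full pass over a single patient utterance.
#   result = simulate_interaction("I am 15 and my knee hurts.")
#   print(result["diagnosis"]["output"])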
# === Gradio UI ===
def ui_fn(user_input):
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
ANSWER: {res['summary']['output']}
🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
ANSWER: {res['diagnosis']['output']}
❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
ANSWER: {res['question']['output']}
💊 Treatment Plan:
{res['treatment']['output']}
💬 Conversation:
{res['conversation']}
"""
demo = gr.Interface(
    fn=ui_fn,
    inputs=gr.Textbox(label="Patient Response"),
    outputs=gr.Textbox(label="Doctor Simulation Output"),
    title="🧠 AI Doctor Multi-Agent Reasoning",
)
if __name__ == "__main__":
    demo.launch(share=True)