import os
import pathlib

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# ——— CONFIG ———
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical"
SUBFOLDER = "checkpoint-45000"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")
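# The token is read from the environment (on a Hugging Face Space it would typically be
# set as a repository secret); the hard check above assumes the model repo is private or
# gated and therefore unusable without it.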
# ——— 1) Download the full repo ———
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
)
print("[DEBUG] snapshot_download → local_cache:", local_cache)
print(
    "[DEBUG] MODEL root contents:",
    list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)
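# Optional narrowing (sketch, not required for the Space to run): snapshot_download pulls
# the whole repo, so if only this checkpoint plus the tokenizer files are needed, the
# download could be limited with allow_patterns, e.g.:
#
#   local_cache = snapshot_download(
#       repo_id=REPO_ID,
#       token=HF_TOKEN,
#       allow_patterns=[f"{SUBFOLDER}/*", "tokenizer*", "*.json"],
#   )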
# ——— 2) Repo root contains tokenizer.json; model shards live in the checkpoint subfolder ———
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))
# ——— 3) Load tokenizer & model from disk ———
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))

# Confirm tokenizer files are present
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))

# Inspect the tokenizer's initialization arguments
try:
    print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
    print("[DEBUG] No init_kwargs attribute on tokenizer.")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    subfolder=MODEL_SUBFOLDER,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
print(
    "[DEBUG] Loaded model object:",
    model.__class__.__name__,
    "device:",
    next(model.parameters()).device,
)
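# Note (assumption about the runtime): device_map="auto" relies on `accelerate` and is
# aimed at GPU hardware. On a CPU-only Space, a plain float32 load is the safer fallback,
# e.g. (sketch):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_DIR, subfolder=MODEL_SUBFOLDER, torch_dtype=torch.float32
#   )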
# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.role_instruction = role_instruction

    def act(self, input_text):
        prompt = (
            f"Instruction: {self.role_instruction}\n"
            f"Input: {input_text}\n"
            f"Output:"
        )
        encoding = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # decode() returns the prompt plus the continuation, so the answer is whatever
        # follows the last "Output:" marker
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        thinking = ""
        print(f"[RESPONSE]: {response}")

        answer = response
        if "Output:" in response:
            # Split on the last occurrence of 'Output:' in case it's repeated
            answer = response.rsplit("Output:", 1)[-1].strip()
        else:
            # Fallback: if THINKING/ANSWER/END tags exist, use the tagged format instead
            tags = ("THINKING:", "ANSWER:", "END")
            if all(tag in response for tag in tags):
                print("[FIX] tagged response detected:", response)
                block = response.split("THINKING:", 1)[1].split("END", 1)[0]
                thinking = block.split("ANSWER:", 1)[0].strip()
                answer = block.split("ANSWER:", 1)[1].strip()

        return {"thinking": thinking, "output": answer}
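# Illustrative usage (sketch, not executed): each agent wraps one prompt round-trip
#
#   agent = RoleAgent("You are a clinical summarizer.", tokenizer, model)
#   result = agent.act("Doctor: What brings you in today?\nPatient: My knee hurts.")
#   result["output"]   # -> text generated after the "Output:" marker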
# === Agents ===
summarizer = RoleAgent(
    role_instruction="You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.",
    tokenizer=tokenizer,
    model=model,
)
diagnoser = RoleAgent(
    role_instruction="You are a board-certified diagnostician that diagnoses patients.",
    tokenizer=tokenizer,
    model=model,
)
questioner = RoleAgent(
    role_instruction="You are a physician asking questions to diagnose a patient.",
    tokenizer=tokenizer,
    model=model,
)
treatment_agent = RoleAgent(
    role_instruction="You are a board-certified clinician. Based on the diagnosis and patient vignette provided below, suggest a concise treatment plan that could realistically be initiated by a primary care physician or psychiatrist.",
    tokenizer=tokenizer,
    model=model,
)
"""[DEBUG] prompt: Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues. | |
Input: Doctor: What brings you in today? | |
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee? | |
Previous Vignette: | |
Output: | |
Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues. | |
Input: Doctor: What brings you in today? | |
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee? | |
Previous Vignette: | |
Output: The patient is a 15-year-old male presenting with knee pain.""" | |
# === Inference State ===
conversation_history = []
summary = ""
diagnosis = ""
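# Note: simulate_interaction below keeps its own local summary/diagnosis per call,
# so these module-level variables are currently unused placeholders.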
# === Gradio Inference ===
def simulate_interaction(user_input, iterations=1):
    history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    summary, diagnosis = "", ""

    for i in range(iterations):
        # Summarize
        sum_in = "\n".join(history) + f"\nPrevious Vignette: {summary}"
        sum_out = summarizer.act(sum_in)
        summary = sum_out["output"]

        # Diagnose
        diag_out = diagnoser.act(summary)
        diagnosis = diag_out["output"]

        # Question
        q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diag_out['thinking']} {diagnosis}"
        q_out = questioner.act(q_in)
        history.append(f"Doctor: {q_out['output']}")
        history.append("Patient: (awaiting response)")

        # Treatment
        treatment_out = treatment_agent.act(
            f"Diagnosis: {diagnosis}\nVignette: {summary}"
        )

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": "\n".join(history),
    }
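# Quick check outside Gradio (illustrative; assumes the model loaded above):
#
#   res = simulate_interaction("I am a 15-year-old male and my knee hurts.")
#   print(res["diagnosis"]["output"])
#   print(res["treatment"]["output"])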
# === Gradio UI ===
def ui_fn(user_input):
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
ANSWER: {res['summary']['output']}
🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
ANSWER: {res['diagnosis']['output']}
❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
ANSWER: {res['question']['output']}
💊 Treatment Plan:
{res['treatment']['output']}
💬 Conversation:
{res['conversation']}
"""
demo = gr.Interface(
    fn=ui_fn,
    inputs=gr.Textbox(label="Patient Response"),
    outputs=gr.Textbox(label="Doctor Simulation Output"),
    title="🧠 AI Doctor Multi-Agent Reasoning",
)
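# Note (assumption about deployment): on a hosted Hugging Face Space the app is already
# served publicly, so share=True mainly matters when running this script locally.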
if __name__ == "__main__":
    demo.launch(share=True)