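"""Gradio app for a two-stage chest X-ray reporting pipeline.

Stage 1: a frozen BiomedCLIP image encoder is projected into a soft prefix for
BioGPT, which generates an English radiology findings report.
Stage 2: the English report is translated to Hindi with IndicTrans2.
"""
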
import os
import torch
import torch.nn as nn
from torchvision import transforms
from typing import Dict, Any
from PIL import Image
import open_clip
from transformers import (
BioGptTokenizer,
BioGptForCausalLM,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
import gradio as gr
# NOTE: Ensure this library is installed on the Hugging Face Space
from IndicTransToolkit import IndicProcessor
from huggingface_hub import hf_hub_download  # used to fetch checkpoint files from the Hub
# --- 1. CONFIGURATION (Stage 1: Report Generation) ---
# NOTE: Update this REPO_ID to the actual Hugging Face repository where you upload your .pth files!
REPO_ID = "Robinhood135/biogptm1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- MODEL/DECODING PARAMS ---
BIOMEDCLIP_MODEL_NAME = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
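# Standard OpenAI CLIP RGB normalisation statistics, applied during image preprocessing.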
PREFIX_LENGTH = 10
PROMPT_TEXT = "You are a Radiologist.The chest image findings are:"
# --- BEST DECODING STRATEGY (Beam Search) ---
BEST_STRATEGY_PARAMS = {
"num_beams": 4,
"do_sample": False,
"repetition_penalty": 1.2,
"max_new_tokens": 100,
"min_new_tokens": 10,
}
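# Deterministic beam search (4 beams, no sampling) with a mild repetition penalty;
# reports are constrained to between 10 and 100 newly generated tokens.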
# --- 2. MODEL CLASS (Stage 1) ---
def freeze_module(module: nn.Module):
for param in module.parameters(): param.requires_grad = False
class BiomedCLIPBioGPTGenerator(nn.Module):
def __init__(self, tokenizer, model_name=BIOMEDCLIP_MODEL_NAME, prefix_length=PREFIX_LENGTH):
super().__init__()
self.tokenizer = tokenizer
self.prefix_length = prefix_length
self.clip_model, _, _ = open_clip.create_model_and_transforms(model_name)
        # Use the dedicated visual tower when available; otherwise fall back to encode_image
        # (freezing below assumes the encoder is an nn.Module)
self.image_encoder = self.clip_model.visual if hasattr(self.clip_model, 'visual') else self.clip_model.encode_image
freeze_module(self.image_encoder)
with torch.no_grad():
dummy_features = self.image_encoder(torch.randn(1, 3, 224, 224))
if isinstance(dummy_features, tuple): dummy_features = dummy_features[0]
self.embed_dim = dummy_features.shape[-1]
config = BioGptForCausalLM.from_pretrained('microsoft/biogpt').config
self.biogpt = BioGptForCausalLM.from_pretrained('microsoft/biogpt', config=config)
self.biogpt.resize_token_embeddings(len(self.tokenizer))
self.gpt_hidden_dim = self.biogpt.config.hidden_size
self.biogpt.config.pad_token_id = self.tokenizer.pad_token_id
self.projection_head = nn.Sequential(
nn.Linear(self.embed_dim, self.prefix_length * self.gpt_hidden_dim),
nn.Tanh(),
nn.Linear(self.prefix_length * self.gpt_hidden_dim, self.prefix_length * self.gpt_hidden_dim)
)
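        # Architecture: frozen BiomedCLIP image features -> two-layer MLP projection ->
        # PREFIX_LENGTH soft "prefix" embeddings prepended to the BioGPT prompt embeddings.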
@torch.no_grad()
def get_prefix_embeddings(self, images):
clip_features = self.image_encoder(images).float()
prefix_embeds = self.projection_head(clip_features)
return prefix_embeds.view(-1, self.prefix_length, self.gpt_hidden_dim)
def get_text_embeddings(self, input_ids):
return self.biogpt.get_input_embeddings()(input_ids)
# --- 3. INFERENCE FUNCTION (Stage 1) ---
@torch.no_grad()
def generate_report(model, pil_image: Image.Image, method_params: Dict[str, Any]):
model.eval()
# 3.1 Apply image transformation
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
])
image_tensor = transform(pil_image.convert('RGB')).unsqueeze(0).to(device)
# 3.2 Get prefix embeddings
prefix_embeds = model.get_prefix_embeddings(image_tensor)
# 3.3 Encode prompt text
prompt_data = model.tokenizer(PROMPT_TEXT, return_tensors="pt").to(device)
prompt_embeds = model.get_text_embeddings(prompt_data["input_ids"])
combined_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
prefix_att_mask = torch.ones(prefix_embeds.shape[:2], dtype=torch.long, device=device)
combined_att_mask = torch.cat([prefix_att_mask, prompt_data["attention_mask"]], dim=1)
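    # The attention mask must cover both the visual prefix tokens and the prompt tokens.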
# 3.4 Generation parameters
generation_args = {
"inputs_embeds": combined_embeds,
"attention_mask": combined_att_mask,
"pad_token_id": model.tokenizer.pad_token_id,
"eos_token_id": model.tokenizer.eos_token_id,
"use_cache": True,
}
generation_args.update(method_params)
# 3.5 Generate
generated_ids = model.biogpt.generate(**generation_args)
# 3.6 Decode and clean
full_text = model.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
if full_text.startswith(PROMPT_TEXT):
text = full_text[len(PROMPT_TEXT):].strip()
else:
text = full_text
return text if text.strip() else "[BLANK/FAILED GENERATION]"
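# Example usage (illustrative sketch; "some_cxr.jpg" is a hypothetical path):
#   gen = load_trained_generator()
#   english_report = generate_report(gen, Image.open("some_cxr.jpg"), BEST_STRATEGY_PARAMS)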
# --- 4. MODEL LOADING (Stage 1) - checkpoints downloaded from the Hugging Face Hub ---
def load_trained_generator():
print(f"Loading Report Generator model from {REPO_ID}...")
# Load from Hugging Face Hub
try:
clip_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biomedclipp.pth")
gpt_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biogptt.pth")
proj_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="projectorr.pth")
except Exception as e:
raise FileNotFoundError(f"Failed to download one or more checkpoint files from {REPO_ID}. Error: {e}")
# Initialize tokenizer
base_tokenizer = BioGptTokenizer.from_pretrained('microsoft/biogpt')
if base_tokenizer.pad_token is None:
base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Initialize model
model = BiomedCLIPBioGPTGenerator(base_tokenizer).to(device)
# Load CLIP encoder
clip_checkpoint = torch.load(clip_ckpt_path, map_location=device)
state_dict = clip_checkpoint.get('model_state_dict', clip_checkpoint.get('state_dict', clip_checkpoint))
# Filter state dict for the visual encoder and clean keys
visual_state = {k.replace('model.visual.', '').replace('visual.', ''): v for k, v in state_dict.items() if 'visual' in k}
model.image_encoder.load_state_dict(visual_state, strict=False)
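    # strict=False tolerates the non-visual keys dropped by the filtering above
    # and any naming differences left after stripping the 'visual.' prefixes.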
# Load trained BioGPT and Projection weights
model.biogpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
model.projection_head.load_state_dict(torch.load(proj_ckpt_path, map_location=device))
model.eval()
print("✅ Report Generator loaded successfully.")
return model
# --- 5. MODEL LOADING (Stage 2: Translation) ---
def load_translator():
# IndicTrans2 models are typically loaded directly from their HF repos (ai4bharat/...)
print("Loading Translation model (IndicTrans2)...")
try:
# IndicTransToolkit library is assumed to be installed
ip = IndicProcessor(inference=True)
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Note: If memory is an issue on the Space, you might need to use a smaller model or lower precision.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(device)
print("✅ Translator loaded successfully.")
return ip, tokenizer, model
except Exception as e:
print(f"Error loading translation model: {e}")
        # Return None placeholders so the app can still launch if the translator fails to load
return None, None, None
# Load models globally
GENERATOR_MODEL = load_trained_generator()
IP, TRANS_TOKENIZER, TRANS_MODEL = load_translator()
# --- 6. TRANSLATION FUNCTION (Stage 2) ---
@torch.no_grad()
def translate_report(english_text: str, target_lang: str = "hin_Deva") -> str:
if TRANS_MODEL is None or not english_text:
return "[Translation Model Not Available or No Text to Translate]"
# 6.1 Preprocessing
batch = IP.preprocess_batch([english_text], src_lang="eng_Latn", tgt_lang=target_lang, visualize=False)
batch = TRANS_TOKENIZER(batch, padding="longest", truncation=True, max_length=256, return_tensors="pt").to(device)
# 6.2 Generation
outputs = TRANS_MODEL.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256, use_cache=False)
# 6.3 Postprocessing
outputs = TRANS_TOKENIZER.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
translated_text = IP.postprocess_batch(outputs, lang=target_lang)[0]
return translated_text
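# Example (illustrative, assuming the translator loaded successfully):
#   translate_report("No acute cardiopulmonary abnormality.", target_lang="hin_Deva")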
# --- 7. GRADIO WRAPPER FUNCTION ---
def inference_wrapper(input_image: Image.Image):
if input_image is None:
return "Please upload a chest X-ray image.", "[No English Report]"
# STAGE 1: GENERATE RAW ENGLISH REPORT
try:
raw_english_report = generate_report(GENERATOR_MODEL, input_image, BEST_STRATEGY_PARAMS)
except Exception as e:
raw_english_report = f"An error occurred during generation: {e}"
return raw_english_report, "[Translation Skipped]"
# STAGE 2: TRANSLATE RAW ENGLISH REPORT
try:
hindi_report = translate_report(raw_english_report, target_lang="hin_Deva")
except Exception as e:
hindi_report = f"[Translation failed: {e}]"
return raw_english_report, hindi_report
# --- 8. GRADIO INTERFACE SETUP ---
if __name__ == "__main__":
# Define example image filenames
EXAMPLE_FILENAMES = [
"001c3589-7aed3964-f06ba8d5-03882592-d77f222c.jpg",
"004438db-4a5d6ab3-acc6c408-5dce0934-7d30b269.jpg",
"0006f2ea-d44c6b5e-aeea6fd2-a974657c-90a39211.jpg",
"0008ba07-4e43d6f4-fc692a96-c18a27a8-10eea0cd.jpg",
"001526e1-0d0b8a2d-87e74f7e-72646210-c635fee4.jpg",
"00438e51-4f75714b-943c8edd-6740491f-f8307602.jpg",
"001c78df-8ce750bd-c100a8e0-2874ea0e-09cdbd4e.jpg",
"000b9235-69b5b7e2-1ec32996-50f79b97-46f939cf.jpg",
# "0041603e-059f400f-c509c746-0da5c413-ee889ec1.jpg",
"001198e2-a2adcc23-7253eb78-0dcb5eaa-b10ed183.jpg",
"0003fc7c-3dfce751-9ff36dc3-8fa4f6d9-0515ce50.jpg",
"0018ff6b-8ad1196f-823030d0-1141b667-2a1a117a.jpg",
"00068d26-8d583659-af7de1da-fc6c0476-d94aada1.jpg",
"00196af8-50d17b31-b1b5a7be-da90b7e6-fd3a8004.jpg",
"004017bd-6506697c-3ead0e70-548114b7-2af62447.jpg",
"00059571-ade80b6c-7931ddb8-b486c6c1-1e543b22.jpg",
"00419c98-6f4860a1-3dee986d-8e2ceadc-d2fd30ae.jpg",
"000ffbff-3d93bcef-da8b17cd-fbcede53-51728df9.jpg",
"0016e39b-d0cad5f2-eecb7ae8-4db8b8f2-0b366f1a.jpg",
"00469c3d-4ebf8374-055428f7-d798daca-3e37d354.jpg",
"0013ac79-5eea664c-7ef52c71-7e5a25f3-013715fc.jpg"
]
# Create examples list with only image paths
examples = [[os.path.join("examples", f)] for f in EXAMPLE_FILENAMES]
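    # The example images are expected under an "examples/" directory in the Space repository.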
# Interface components
input_image = gr.Image(type="pil", label="Upload Chest X-ray Image")
output_en = gr.Textbox(label="Generated Radiology Report (English)", lines=5)
output_hi = gr.Textbox(label="Generated Radiology Report (Hindi/हिन्दी)", lines=5)
# Gradio app setup
app = gr.Interface(
fn=inference_wrapper,
inputs=input_image,
outputs=[output_en, output_hi],
title="🔬 Cascading BiomedCLIP-BioGPT & IndicTrans2 Report Generator",
        description="Upload a chest X-ray image to generate radiology findings in English and automatically translate them to Hindi.",
# allow_flagging="never",
examples=examples,
cache_examples=False
# cache_examples=True
)
print("\nStarting Gradio interface...")
app.launch() # Removed share=True for typical Hugging Face Space deployment