import os
import torch
import torch.nn as nn
from torchvision import transforms
from typing import Dict, Any
from PIL import Image
import open_clip
from transformers import (
    BioGptTokenizer,
    BioGptForCausalLM,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import gradio as gr
# NOTE: Ensure this library is installed on the Hugging Face Space
from IndicTransToolkit import IndicProcessor
from huggingface_hub import hf_hub_download # New import for HF deployment
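# Dependency note (a sketch, not a pinned list): this script assumes the Space installs
# torch, torchvision, transformers, sacremoses (required by BioGptTokenizer),
# open_clip_torch, gradio, huggingface_hub, and IndicTransToolkit via requirements.txt.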
# --- 1. CONFIGURATION (Stage 1: Report Generation) ---
# NOTE: Update this REPO_ID to the actual Hugging Face repository where you upload your .pth files!
REPO_ID = "Robinhood135/biogptm1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- MODEL/DECODING PARAMS ---
BIOMEDCLIP_MODEL_NAME = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
PREFIX_LENGTH = 10
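# PREFIX_LENGTH is the number of pseudo-token embeddings the image features are projected
# into (via projection_head below) and prepended to the text prompt before generation.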
PROMPT_TEXT = "You are a Radiologist.The chest image findings are:"
# --- BEST DECODING STRATEGY (Beam Search) ---
BEST_STRATEGY_PARAMS = {
    "num_beams": 4,
    "do_sample": False,
    "repetition_penalty": 1.2,
    "max_new_tokens": 100,
    "min_new_tokens": 10,
}
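# These kwargs are merged into the generation_args dict in generate_report() and passed
# directly to model.biogpt.generate().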
# --- 2. MODEL CLASS (Stage 1) - Kept the same ---
def freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = False
class BiomedCLIPBioGPTGenerator(nn.Module):
    def __init__(self, tokenizer, model_name=BIOMEDCLIP_MODEL_NAME, prefix_length=PREFIX_LENGTH):
        super().__init__()
        self.tokenizer = tokenizer
        self.prefix_length = prefix_length
        self.clip_model, _, _ = open_clip.create_model_and_transforms(model_name)
        # Handle cases where the image encoder is exposed as .visual or only via encode_image
        self.image_encoder = self.clip_model.visual if hasattr(self.clip_model, 'visual') else self.clip_model.encode_image
        freeze_module(self.image_encoder)
        # Infer the CLIP embedding dimension with a dummy forward pass
        with torch.no_grad():
            dummy_features = self.image_encoder(torch.randn(1, 3, 224, 224))
            if isinstance(dummy_features, tuple):
                dummy_features = dummy_features[0]
        self.embed_dim = dummy_features.shape[-1]
        # Load BioGPT and resize the embedding table to account for the added [PAD] token
        self.biogpt = BioGptForCausalLM.from_pretrained('microsoft/biogpt')
        self.biogpt.resize_token_embeddings(len(self.tokenizer))
        self.gpt_hidden_dim = self.biogpt.config.hidden_size
        self.biogpt.config.pad_token_id = self.tokenizer.pad_token_id
        # Project the single CLIP image embedding to prefix_length * hidden_dim,
        # later reshaped into prefix_length pseudo-token embeddings
        self.projection_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.prefix_length * self.gpt_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.prefix_length * self.gpt_hidden_dim, self.prefix_length * self.gpt_hidden_dim)
        )

    @torch.no_grad()
    def get_prefix_embeddings(self, images):
        clip_features = self.image_encoder(images).float()
        prefix_embeds = self.projection_head(clip_features)
        return prefix_embeds.view(-1, self.prefix_length, self.gpt_hidden_dim)

    def get_text_embeddings(self, input_ids):
        return self.biogpt.get_input_embeddings()(input_ids)
# --- 3. INFERENCE FUNCTION (Stage 1) - Kept the same ---
@torch.no_grad()
def generate_report(model, pil_image: Image.Image, method_params: Dict[str, Any]):
    model.eval()
    # 3.1 Apply image transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
    ])
    image_tensor = transform(pil_image.convert('RGB')).unsqueeze(0).to(device)
    # 3.2 Get prefix embeddings
    prefix_embeds = model.get_prefix_embeddings(image_tensor)
    # 3.3 Encode prompt text and concatenate image prefix + prompt embeddings
    prompt_data = model.tokenizer(PROMPT_TEXT, return_tensors="pt").to(device)
    prompt_embeds = model.get_text_embeddings(prompt_data["input_ids"])
    combined_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
    prefix_att_mask = torch.ones(prefix_embeds.shape[:2], dtype=torch.long, device=device)
    combined_att_mask = torch.cat([prefix_att_mask, prompt_data["attention_mask"]], dim=1)
    # 3.4 Generation parameters
    generation_args = {
        "inputs_embeds": combined_embeds,
        "attention_mask": combined_att_mask,
        "pad_token_id": model.tokenizer.pad_token_id,
        "eos_token_id": model.tokenizer.eos_token_id,
        "use_cache": True,
    }
    generation_args.update(method_params)
    # 3.5 Generate
    generated_ids = model.biogpt.generate(**generation_args)
    # 3.6 Decode and strip the prompt text if it is echoed back
    full_text = model.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    if full_text.startswith(PROMPT_TEXT):
        text = full_text[len(PROMPT_TEXT):].strip()
    else:
        text = full_text
    return text if text.strip() else "[BLANK/FAILED GENERATION]"
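# Quick local sanity check (a sketch; "sample_cxr.jpg" is a hypothetical path):
#   print(generate_report(GENERATOR_MODEL, Image.open("sample_cxr.jpg"), BEST_STRATEGY_PARAMS))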
# --- 4. MODEL LOADING (Stage 1) - MODIFIED FOR HF HUB ---
def load_trained_generator():
print(f"Loading Report Generator model from {REPO_ID}...")
# Load from Hugging Face Hub
try:
clip_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biomedclipp.pth")
gpt_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biogptt.pth")
proj_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="projectorr.pth")
except Exception as e:
raise FileNotFoundError(f"Failed to download one or more checkpoint files from {REPO_ID}. Error: {e}")
# Initialize tokenizer
base_tokenizer = BioGptTokenizer.from_pretrained('microsoft/biogpt')
if base_tokenizer.pad_token is None:
base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Initialize model
model = BiomedCLIPBioGPTGenerator(base_tokenizer).to(device)
# Load CLIP encoder
clip_checkpoint = torch.load(clip_ckpt_path, map_location=device)
state_dict = clip_checkpoint.get('model_state_dict', clip_checkpoint.get('state_dict', clip_checkpoint))
# Filter state dict for the visual encoder and clean keys
visual_state = {k.replace('model.visual.', '').replace('visual.', ''): v for k, v in state_dict.items() if 'visual' in k}
model.image_encoder.load_state_dict(visual_state, strict=False)
# Load trained BioGPT and Projection weights
model.biogpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
model.projection_head.load_state_dict(torch.load(proj_ckpt_path, map_location=device))
model.eval()
print("✅ Report Generator loaded successfully.")
return model
# --- 5. MODEL LOADING (Stage 2: Translation) - Kept the same ---
def load_translator():
    # IndicTrans2 models are typically loaded directly from their HF repos (ai4bharat/...)
    print("Loading Translation model (IndicTrans2)...")
    try:
        # IndicTransToolkit library is assumed to be installed
        ip = IndicProcessor(inference=True)
        model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Note: If memory is an issue on the Space, you might need to use a smaller model or lower precision.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(device)
        print("✅ Translator loaded successfully.")
        return ip, tokenizer, model
    except Exception as e:
        print(f"Error loading translation model: {e}")
        # Return dummy values if loading fails to prevent crash
        return None, None, None
# Load models globally
GENERATOR_MODEL = load_trained_generator()
IP, TRANS_TOKENIZER, TRANS_MODEL = load_translator()
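# Both models are loaded at import time, so the Space downloads all checkpoints on startup;
# the first launch may take several minutes before the Gradio UI becomes responsive.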
# --- 6. TRANSLATION FUNCTION (Stage 2) - Kept the same ---
@torch.no_grad()
def translate_report(english_text: str, target_lang: str = "hin_Deva") -> str:
    if TRANS_MODEL is None or not english_text:
        return "[Translation Model Not Available or No Text to Translate]"
    # 6.1 Preprocessing
    batch = IP.preprocess_batch([english_text], src_lang="eng_Latn", tgt_lang=target_lang, visualize=False)
    batch = TRANS_TOKENIZER(batch, padding="longest", truncation=True, max_length=256, return_tensors="pt").to(device)
    # 6.2 Generation
    outputs = TRANS_MODEL.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256, use_cache=False)
    # 6.3 Postprocessing
    outputs = TRANS_TOKENIZER.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    translated_text = IP.postprocess_batch(outputs, lang=target_lang)[0]
    return translated_text
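# Usage sketch: translate_report("No acute cardiopulmonary abnormality.", target_lang="hin_Deva")
# returns the Hindi (Devanagari) rendering when the translator loaded successfully.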
# --- 7. GRADIO WRAPPER FUNCTION (Simplified) - Kept the same ---
def inference_wrapper(input_image: Image.Image):
    if input_image is None:
        return "Please upload a chest X-ray image.", "[No English Report]"
    # STAGE 1: GENERATE RAW ENGLISH REPORT
    try:
        raw_english_report = generate_report(GENERATOR_MODEL, input_image, BEST_STRATEGY_PARAMS)
    except Exception as e:
        raw_english_report = f"An error occurred during generation: {e}"
        return raw_english_report, "[Translation Skipped]"
    # STAGE 2: TRANSLATE RAW ENGLISH REPORT
    try:
        hindi_report = translate_report(raw_english_report, target_lang="hin_Deva")
    except Exception as e:
        hindi_report = f"[Translation failed: {e}]"
    return raw_english_report, hindi_report
# --- 8. GRADIO INTERFACE SETUP ---
if __name__ == "__main__":
    # Define example image filenames
    EXAMPLE_FILENAMES = [
        "001c3589-7aed3964-f06ba8d5-03882592-d77f222c.jpg",
        "004438db-4a5d6ab3-acc6c408-5dce0934-7d30b269.jpg",
        "0006f2ea-d44c6b5e-aeea6fd2-a974657c-90a39211.jpg",
        "0008ba07-4e43d6f4-fc692a96-c18a27a8-10eea0cd.jpg",
        "001526e1-0d0b8a2d-87e74f7e-72646210-c635fee4.jpg",
        "00438e51-4f75714b-943c8edd-6740491f-f8307602.jpg",
        "001c78df-8ce750bd-c100a8e0-2874ea0e-09cdbd4e.jpg",
        "000b9235-69b5b7e2-1ec32996-50f79b97-46f939cf.jpg",
        # "0041603e-059f400f-c509c746-0da5c413-ee889ec1.jpg",
        "001198e2-a2adcc23-7253eb78-0dcb5eaa-b10ed183.jpg",
        "0003fc7c-3dfce751-9ff36dc3-8fa4f6d9-0515ce50.jpg",
        "0018ff6b-8ad1196f-823030d0-1141b667-2a1a117a.jpg",
        "00068d26-8d583659-af7de1da-fc6c0476-d94aada1.jpg",
        "00196af8-50d17b31-b1b5a7be-da90b7e6-fd3a8004.jpg",
        "004017bd-6506697c-3ead0e70-548114b7-2af62447.jpg",
        "00059571-ade80b6c-7931ddb8-b486c6c1-1e543b22.jpg",
        "00419c98-6f4860a1-3dee986d-8e2ceadc-d2fd30ae.jpg",
        "000ffbff-3d93bcef-da8b17cd-fbcede53-51728df9.jpg",
        "0016e39b-d0cad5f2-eecb7ae8-4db8b8f2-0b366f1a.jpg",
        "00469c3d-4ebf8374-055428f7-d798daca-3e37d354.jpg",
        "0013ac79-5eea664c-7ef52c71-7e5a25f3-013715fc.jpg"
    ]
    # Create examples list with only image paths
    examples = [[os.path.join("examples", f)] for f in EXAMPLE_FILENAMES]
    # Interface components
    input_image = gr.Image(type="pil", label="Upload Chest X-ray Image")
    output_en = gr.Textbox(label="Generated Radiology Report (English)", lines=5)
    output_hi = gr.Textbox(label="Generated Radiology Report (Hindi/हिन्दी)", lines=5)
    # Gradio app setup
    app = gr.Interface(
        fn=inference_wrapper,
        inputs=input_image,
        outputs=[output_en, output_hi],
        title="🔬 Cascading BiomedCLIP-BioGPT & IndicTrans2 Report Generator",
        description="Upload a chest X-ray image to generate a radiology finding in English and automatically translate it to Hindi.",
        # allow_flagging="never",
        examples=examples,
        cache_examples=False
        # cache_examples=True
    )
    print("\nStarting Gradio interface...")
    app.launch()  # Removed share=True for typical Hugging Face Space deployment