# image / app.py — author: mgbam
# Commit ac1cd8a ("Update app.py"), 3.65 kB
import torch
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
import gradio as gr
# Pick the compute device and the matching weight dtype: bfloat16 on a CUDA
# GPU (half the memory of float32), plain float32 on CPU. Attention-kernel
# selection is done where the model is loaded, not here.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
print(f"Using device: {device}")
# Initialize medical imaging components
def load_medical_models(
    *,
    model_id="deepseek-ai/Janus-1.3B",
    vae_id="stabilityai/sdxl-vae",
):
    """Load the Janus chat processor, the Janus model and a VAE.

    Both models are moved to the module-level ``device`` in ``torch_dtype``
    and switched to eval mode.

    Args:
        model_id: Hugging Face repo id (or local path) of the Janus checkpoint.
        vae_id: Hugging Face repo id (or local path) of the VAE.

    Returns:
        A ``(processor, model, vae)`` tuple.

    Raises:
        Whatever the underlying ``from_pretrained`` calls raise; the error is
        printed first and then re-raised unchanged.
    """
    try:
        # NOTE(review): ``medical_mode`` is not a documented Janus processor
        # option and looks like it is silently ignored — confirm before
        # relying on it.
        processor = VLChatProcessor.from_pretrained(model_id, medical_mode=True)

        # Janus constructs its inner Llama model from ``config.language_config``,
        # so the top-level ``attn_implementation`` kwarg is not guaranteed to
        # reach it. Force eager attention on that sub-config as well so that
        # loading does not require flash-attn (e.g. on CPU-only hosts).
        config = MultiModalityCausalLM.config_class.from_pretrained(model_id)
        config.language_config._attn_implementation = "eager"
        model = MultiModalityCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=torch_dtype,
            attn_implementation="eager",
            low_cpu_mem_usage=True,
        ).to(device).eval()

        # NOTE(review): nothing in this file's analysis path uses the VAE; it
        # only costs memory. It is still loaded so the 3-tuple return value
        # stays compatible with existing callers.
        vae = AutoencoderKL.from_pretrained(
            vae_id,
            torch_dtype=torch_dtype,
        ).to(device).eval()

        return processor, model, vae
    except Exception as e:
        print(f"Error loading medical models: {e}")
        raise
# Load every model exactly once at import time; the request handler reuses
# these module-level objects instead of reloading weights per request.
(processor, model, vae) = load_medical_models()
# Medical image analysis function
def medical_analysis(image, question, seed=42):
    """Ask the Janus model a question about one image and return its answer.

    Args:
        image: The uploaded image. The Gradio ``gr.Image`` below uses
            ``type="pil"``, so this is normally a ``PIL.Image``; a NumPy array
            is also accepted. ``None`` (nothing uploaded) returns a short
            message instead of running the model.
        question: Free-text question typed by the user.
        seed: Seed for the ``torch`` and ``numpy`` RNGs so that the sampled
            answer is reproducible for identical inputs.

    Returns:
        The decoded answer with any ``##MEDICAL_REPORT##`` marker removed, or
        a string starting with ``"Radiology analysis error: "`` if any step of
        the pipeline raised.
    """
    if image is None:
        return "Please upload an image before running the analysis."

    try:
        # Set random seeds for reproducible sampling.
        torch.manual_seed(seed)
        np.random.seed(seed)

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Grayscale (e.g. mode "L" / "I;16") or RGBA scans must be turned into
        # 3-channel RGB before the vision encoder's preprocessing.
        image = image.convert("RGB")

        # Janus expects a chat-formatted conversation; ``<image_placeholder>``
        # marks where the processor splices in the image tokens.
        conversation = [
            {
                "role": "User",
                "content": (
                    f"<image_placeholder>\n<medical_query>{question}</medical_query>"
                ),
            },
            {"role": "Assistant", "content": ""},
        ]
        inputs = processor(
            conversations=conversation,
            images=[image],
            force_batchify=True,
        ).to(device)
        # Match the pixel tensor to the model weights (float32 on CPU).
        inputs.pixel_values = inputs.pixel_values.to(torch_dtype)

        tokenizer = processor.tokenizer
        with torch.no_grad():
            # Fuse text + image features, then generate with the language head.
            inputs_embeds = model.prepare_inputs_embeds(**inputs)
            outputs = model.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.1,
                top_p=0.95,
                do_sample=True,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Clean and return the report text.
        report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return report.replace("##MEDICAL_REPORT##", "").strip()
    except Exception as e:
        return f"Radiology analysis error: {e}"
# Medical interface
with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# AI Radiology Assistant
**CT/MRI/X-ray Analysis System**""")
    with gr.Tab("Diagnostic Imaging"):
        with gr.Row():
            # NOTE(review): with type="pil" uploads are decoded by Pillow,
            # which has no DICOM reader, so raw .dcm files likely fail despite
            # this label — consider converting DICOM to PNG first.
            med_image = gr.Image(label="DICOM Image", type="pil")
            med_question = gr.Textbox(
                label="Clinical Query",
                placeholder="Describe findings in this CT scan...",
            )
        analysis_btn = gr.Button("Analyze", variant="primary")
        report_output = gr.Textbox(label="Radiology Report", interactive=False)

    # Connect components: pressing Enter in the query box and clicking the
    # button both run the same analysis.
    med_question.submit(
        medical_analysis,
        inputs=[med_image, med_question],
        outputs=report_output,
    )
    analysis_btn.click(
        medical_analysis,
        inputs=[med_image, med_question],
        outputs=report_output,
    )

# Launch with CPU optimization. ``launch(enable_queue=...)`` was removed in
# Gradio 4; ``queue()`` is the supported way to enable request queuing.
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    max_threads=2,
)