"""Gradio demo: medical-image question answering and synthetic image generation
with DeepSeek Janus-Pro.

For research and education only -- model outputs are not clinical diagnoses.
"""
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor  # noqa: F401
from janus.utils.io import load_pil_images  # noqa: F401
from PIL import Image
import numpy as np
import os  # noqa: F401
import time  # noqa: F401
import spaces

# ---------------------------------------------------------------------------
# Model / processor loading
# ---------------------------------------------------------------------------
model_path = "deepseek-ai/Janus-Pro-1B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'

# A single dtype is used for both the weights and every tensor fed to the
# model, so CPU (float16) runs do not hit a bfloat16/float16 mismatch.
model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
).to(model_dtype)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Image-generation geometry, matching the upstream Janus-Pro demo:
# a 384x384 canvas of 16 px patches -> (384 // 16) ** 2 == 576 image tokens.
_GEN_SIZE = 384
_GEN_PATCH_SIZE = 16
_GEN_TOKENS_PER_IMAGE = (_GEN_SIZE // _GEN_PATCH_SIZE) ** 2
_GEN_PARALLEL_SIZE = 3   # images produced per request
_GEN_OUTPUT_SIZE = 512   # gallery display size (upsampled from _GEN_SIZE)


# inference_mode sits innermost so it wraps the function body directly,
# wherever spaces.GPU ends up executing it.
@spaces.GPU(duration=120)
@torch.inference_mode()
def medical_image_analysis(medical_image, clinical_question, seed, top_p, temperature):
    """Analyze a medical image (CT, MRI, X-ray, histopathology) with clinical context.

    Args:
        medical_image: Uploaded image as a NumPy array (Gradio ``gr.Image``
            default), or None when nothing was uploaded.
        clinical_question: Free-text clinical question / context.
        seed: RNG seed for reproducible sampling; None leaves RNG state as is.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature; 0 selects greedy decoding.

    Returns:
        The decoded model answer prefixed with ``"Clinical Findings:\\n"``.

    Raises:
        gr.Error: if no image was provided.
    """
    if medical_image is None:
        raise gr.Error("Please upload a medical image to analyze.")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(int(seed))  # gr.Number may deliver a float

    # Role strings follow the Janus SFT template; <image_placeholder> marks
    # where the processor splices in the image tokens.
    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\nClinical Context: {clinical_question}",
            "images": [medical_image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    pil_images = [Image.fromarray(medical_image).convert("RGB")]
    inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=model_dtype)
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        # temperature/top_p are only honored when sampling is enabled.
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.2,  # discourage repetitive output
        use_cache=True,
    )
    findings = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return f"Clinical Findings:\n{findings}"


def _sample_image_patches(input_ids, *, parallel_size, cfg_weight, temperature):
    """Sample image tokens with classifier-free guidance and decode them to pixels.

    Returns the raw decoder output for ``parallel_size`` images.
    """
    # Even rows carry the full (conditional) prompt; odd rows keep only the
    # first/last token and pad the rest, forming the unconditional branch.
    tokens = input_ids.repeat(parallel_size * 2, 1).to(cuda_device)
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, _GEN_TOKENS_PER_IMAGE), dtype=torch.int, device=cuda_device
    )
    past_key_values = None
    for step in range(_GEN_TOKENS_PER_IMAGE):
        outputs = vl_gpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values
        logits = vl_gpt.gen_head(outputs.last_hidden_state[:, -1, :])
        logit_cond, logit_uncond = logits[0::2, :], logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # (parallel_size, 1)
        generated_tokens[:, step] = next_token.squeeze(dim=-1)
        # Feed the sampled token to both the conditional and unconditional row.
        next_token = torch.cat([next_token, next_token], dim=1).view(-1)
        inputs_embeds = vl_gpt.prepare_gen_img_embeds(next_token).unsqueeze(dim=1)

    grid = _GEN_SIZE // _GEN_PATCH_SIZE
    return vl_gpt.gen_vision_model.decode_code(
        generated_tokens, shape=[parallel_size, 8, grid, grid]
    )


def _patches_to_uint8(patches):
    """Convert decoder output (N, C, H, W) to a uint8 (N, H, W, C) array.

    Assumes decoder values lie in [-1, 1], as in the upstream Janus demo.
    """
    arr = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    return np.clip((arr + 1) / 2 * 255, 0, 255).astype(np.uint8)


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_medical_image(prompt, seed=None, guidance=5, t2i_temperature=0.5,
                           modality=None, anatomy=None):
    """Generate synthetic medical images for educational/research purposes.

    Args:
        prompt: Text description of the desired image.
        seed: Optional RNG seed for reproducible sampling.
        guidance: Classifier-free guidance weight.
        t2i_temperature: Sampling temperature; must be > 0.
        modality: Optional imaging modality tag (e.g. "CT", "MRI") appended to
            the prompt.
        anatomy: Optional target-anatomy tag appended to the prompt.

    Returns:
        A list of ``_GEN_PARALLEL_SIZE`` PIL images, each
        ``_GEN_OUTPUT_SIZE`` x ``_GEN_OUTPUT_SIZE``.

    Raises:
        gr.Error: on an empty prompt or a non-positive temperature.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a synthesis prompt.")
    if t2i_temperature <= 0:
        raise gr.Error("Synthesis temperature must be greater than 0.")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(int(seed))

    # Only add tags the caller actually supplied, so they never contradict
    # the modality/anatomy named in the prompt itself.
    tags = ", ".join(
        f"{label}: {value}"
        for label, value in (("Modality", modality), ("Anatomy", anatomy))
        if value
    )
    content = f"{prompt} [{tags}]" if tags else prompt

    messages = [
        {'role': '<|User|>', 'content': content},
        {'role': '<|Assistant|>', 'content': ''},
    ]
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='Generate education-quality medical imaging data',
    )
    # The image-start tag cues the model to emit image tokens next.
    text += vl_chat_processor.image_start_tag
    input_ids = torch.LongTensor(tokenizer.encode(text))

    patches = _sample_image_patches(
        input_ids,
        parallel_size=_GEN_PARALLEL_SIZE,
        cfg_weight=guidance,
        temperature=t2i_temperature,
    )
    images = _patches_to_uint8(patches)
    return [
        Image.fromarray(img).resize((_GEN_OUTPUT_SIZE, _GEN_OUTPUT_SIZE), Image.LANCZOS)
        for img in images
    ]


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Medical Imaging AI Suite") as demo:
    gr.Markdown(
        "## Medical Image Analysis Suite v2.1\n"
        "*For research use only - not for clinical diagnosis*"
    )

    with gr.Tab("Clinical Image Analysis"):
        with gr.Row():
            medical_image_input = gr.Image(label="Upload Medical Scan")
            clinical_question = gr.Textbox(
                label="Clinical Query",
                placeholder="E.g.: 'Assess tumor progression in this MRI series'",
            )
        with gr.Accordion("Advanced Parameters", open=False):
            # precision=0 makes Gradio deliver an integer seed.
            und_seed = gr.Number(42, label="Reproducibility Seed", precision=0)
            analysis_top_p = gr.Slider(0.8, 1.0, 0.95, label="Diagnostic Certainty")
            analysis_temp = gr.Slider(0.1, 0.5, 0.2, label="Analysis Precision")
        analysis_btn = gr.Button("Analyze Scan", variant="primary")
        clinical_report = gr.Textbox(label="AI Analysis Report", interactive=False)
        gr.Examples(
            examples=[
                ["Identify pulmonary nodules in this CT scan", "ct_chest.png"],
                ["Assess MRI for multiple sclerosis lesions", "brain_mri.jpg"],
                ["Histopathology analysis: tumor grading", "biopsy_slide.png"],
            ],
            inputs=[clinical_question, medical_image_input],
        )

    with gr.Tab("Medical Imaging Synthesis"):
        gr.Markdown("**Educational Image Generation**")
        synth_prompt = gr.Textbox(
            label="Synthesis Prompt",
            placeholder="E.g.: 'Synthetic brain MRI showing glioblastoma multiforme'",
        )
        with gr.Row():
            synth_guidance = gr.Slider(3, 7, 5, label="Anatomical Accuracy")
            synth_temp = gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability")
        synth_btn = gr.Button("Generate Educational Images", variant="secondary")
        synthetic_gallery = gr.Gallery(
            label="Synthetic Medical Images", columns=3, object_fit="contain"
        )
        gr.Examples(
            examples=[
                "High-resolution CT of healthy lung parenchyma",
                "T2-weighted MRI of lumbar spine with herniated disc",
                "Histopathology slide of benign breast tissue",
            ],
            inputs=synth_prompt,
        )

    # Connect functionality
    analysis_btn.click(
        medical_image_analysis,
        inputs=[medical_image_input, clinical_question, und_seed, analysis_top_p, analysis_temp],
        outputs=clinical_report,
    )
    synth_btn.click(
        generate_medical_image,
        inputs=[synth_prompt, und_seed, synth_guidance, synth_temp],
        outputs=synthetic_gallery,
    )

# NOTE(review): share=True opens a public tunnel to this app; uploaded scans
# would be reachable through it. Confirm that is intended for medical data.
demo.launch(share=True, server_port=7860)