# Spaces status: Runtime error
import importlib.util

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
# Configure device and attention implementation.
device = "cuda" if torch.cuda.is_available() else "cpu"
# FlashAttention-2 needs both a CUDA device and the optional `flash_attn`
# package; requesting it without the package makes model loading fail, so
# only ask for it when the package is importable and use eager otherwise.
_has_flash_attn = importlib.util.find_spec("flash_attn") is not None
attn_implementation = (
    "flash_attention_2" if device == "cuda" and _has_flash_attn else "eager"
)
print(f"Using device: {device} with {attn_implementation}")
# Initialize medical imaging components.
def load_medical_models():
    """Load the Janus processor and model plus the SDXL VAE onto ``device``.

    Weights use bfloat16 on CUDA and float32 on CPU. If the configured
    ``attn_implementation`` is rejected while loading the Janus model, the
    load is retried once with ``"eager"`` attention.

    Returns:
        A ``(processor, model, vae)`` tuple; both models are in eval mode.

    Raises:
        Whatever the underlying ``from_pretrained`` calls raise. The error is
        printed first so it is visible in the application log.
    """
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    try:
        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
        try:
            # `attn_implementation` alone selects the backend; the old
            # `use_flash_attention_2` flag is deprecated and redundant.
            model = MultiModalityCausalLM.from_pretrained(
                "deepseek-ai/Janus-1.3B",
                torch_dtype=dtype,
                attn_implementation=attn_implementation,
            )
        except (ImportError, ValueError) as err:
            if attn_implementation == "eager":
                raise
            # transformers raises ImportError when the backend's package is
            # missing and ValueError when the model class does not declare
            # support for it; eager attention is always available.
            print(
                f"{attn_implementation} unavailable ({err}); "
                "falling back to eager attention"
            )
            model = MultiModalityCausalLM.from_pretrained(
                "deepseek-ai/Janus-1.3B",
                torch_dtype=dtype,
                attn_implementation="eager",
            )
        model = model.to(device).eval()

        # NOTE(review): the VAE is not referenced by medical_analysis or the
        # UI in this file; it is still loaded to keep the 3-tuple contract.
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sdxl-vae",
            torch_dtype=dtype,
        ).to(device).eval()
        return processor, model, vae
    except Exception as e:
        print(f"Error loading medical models: {str(e)}")
        raise


processor, model, vae = load_medical_models()
# Medical image analysis function wired to the Gradio UI.
def medical_analysis(image, question, seed=42):
    """Answer a free-text question about one image using the Janus model.

    Errors are returned as text instead of raised so they appear in the
    report box rather than breaking the UI.

    Args:
        image: A PIL image (the UI uses ``gr.Image(type="pil")``), a numpy
            array, or ``None``, which is reported as an error.
        question: The clinical query to ask about the image.
        seed: Seed applied to the torch and numpy RNGs before generation so
            the sampled answer is repeatable for the same inputs.

    Returns:
        The decoded model answer, or a message starting with
        ``"Radiology analysis error:"`` on failure.
    """
    if image is None:
        return "Radiology analysis error: please upload an image first."
    try:
        torch.manual_seed(seed)
        np.random.seed(seed)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # The vision encoder expects 3 channels; uploads may be grayscale
        # or carry an alpha channel.
        image = image.convert("RGB")

        # Janus takes a chat-style conversation; <image_placeholder> marks
        # where the image features are spliced into the prompt.
        conversation = [
            {"role": "User", "content": f"<image_placeholder>\n{question}"},
            {"role": "Assistant", "content": ""},
        ]
        # pixel_values must match the weights' dtype (bfloat16 on CUDA,
        # float32 on CPU, as chosen by load_medical_models).
        inputs = processor(
            conversations=conversation,
            images=[image],
            force_batchify=True,
        ).to(device, dtype=model.dtype)

        tokenizer = processor.tokenizer
        with torch.no_grad():
            # Merge text token embeddings with the encoded image so the
            # language model actually sees the picture.
            inputs_embeds = model.prepare_inputs_embeds(**inputs)
            outputs = model.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs.attention_mask,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=512,
                # temperature/top_p are only honoured when sampling is on.
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                use_cache=True,
            )
        return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    except Exception as e:
        return f"Radiology analysis error: {str(e)}"
# Medical interface: one tab with an image upload, a query box and a report.
with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# AI Radiology Assistant\n"
        "**CT/MRI/X-ray Analysis System**"
    )
    with gr.Tab("Diagnostic Imaging"):
        with gr.Row():
            # gr.Image decodes uploads with PIL, which has no DICOM reader,
            # so the label asks for an exported raster image, not a .dcm file.
            med_image = gr.Image(
                label="Medical Image (PNG/JPEG export of CT/MRI/X-ray)",
                type="pil",
            )
            med_question = gr.Textbox(
                label="Clinical Query",
                placeholder="Describe findings in this CT scan...",
            )
        analysis_btn = gr.Button("Analyze", variant="primary")
        report_output = gr.Textbox(label="Radiology Report", interactive=False)

        # Pressing Enter in the query box and clicking Analyze run the same
        # analysis on the same inputs.
        for trigger in (med_question.submit, analysis_btn.click):
            trigger(
                medical_analysis,
                inputs=[med_image, med_question],
                outputs=report_output,
            )

demo.launch(server_name="0.0.0.0", server_port=7860)