mgbam commited on
Commit
8e2bfc0
·
verified ·
1 Parent(s): 14ac75b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -226
app.py CHANGED
@@ -4,249 +4,184 @@ from PIL import Image
4
  from diffusers.models import AutoencoderKL
5
  import numpy as np
6
  import gradio as gr
7
-
8
- # CUDA availability check
9
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
- print(f"Using device: {cuda_device}")
11
-
12
- # Load model and processor (adjust path if needed)
13
- model_path = "deepseek-ai/JanusFlow-1.3B" # You may need to change to your local path
14
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
15
- tokenizer = vl_chat_processor.tokenizer
16
-
17
- vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
18
- vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
19
-
20
- # Load VAE for image generation
21
- vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") # You may need to change to your local path
22
- vae = vae.to(torch.bfloat16).to(cuda_device).eval()
23
-
24
- # Multimodal Understanding function (modified for medical context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @torch.inference_mode()
26
- def multimodal_understanding(image, question, seed, top_p, temperature):
27
- # Clear CUDA cache before generating to prevent memory leaks
28
- torch.cuda.empty_cache()
29
-
30
- # Set seed for reproducibility
31
  torch.manual_seed(seed)
32
  np.random.seed(seed)
33
- torch.cuda.manual_seed(seed)
34
-
35
- conversation = [
36
- {
37
- "role": "User",
38
- "content": f"<image_placeholder>\n{question}",
 
 
 
 
39
  "images": [image],
40
- },
41
- {"role": "Assistant", "content": ""},
42
- ]
43
-
44
- pil_images = [Image.fromarray(image)]
45
- prepare_inputs = vl_chat_processor(
46
- conversations=conversation, images=pil_images, force_batchify=True
47
- ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
48
-
49
- inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
50
-
51
- outputs = vl_gpt.language_model.generate(
52
- inputs_embeds=inputs_embeds,
53
- attention_mask=prepare_inputs.attention_mask,
54
- pad_token_id=tokenizer.eos_token_id,
55
- bos_token_id=tokenizer.bos_token_id,
56
- eos_token_id=tokenizer.eos_token_id,
57
- max_new_tokens=512,
58
- do_sample=False if temperature == 0 else True,
59
- use_cache=True,
60
- temperature=temperature,
61
- top_p=top_p,
62
- )
63
 
64
- answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
 
65
 
66
- return answer
 
67
 
68
- # Image Generation Function (modified for medical context)
69
  @torch.inference_mode()
70
- def generate(
71
- input_ids,
72
- cfg_weight: float = 2.0,
73
- num_inference_steps: int = 30
74
- ):
75
- # we generate 5 images at a time, *2 for CFG
76
- tokens = torch.stack([input_ids] * 10).cuda()
77
- tokens[5:, 1:] = vl_chat_processor.pad_id
78
- inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
79
- print(inputs_embeds.shape)
80
-
81
- # we remove the last <bog> token and replace it with t_emb later
82
- inputs_embeds = inputs_embeds[:, :-1, :]
83
-
84
- # generate with rectified flow ode
85
- # step 1: encode with vision_gen_enc
86
- z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
87
-
88
- dt = 1.0 / num_inference_steps
89
- dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
90
-
91
- # step 2: run ode
92
- attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
93
- attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
94
- attention_mask = attention_mask.int()
95
- for step in range(num_inference_steps):
96
- # prepare inputs for the llm
97
- z_input = torch.cat([z, z], dim=0) # for cfg
98
- t = step / num_inference_steps * 1000.
99
- t = torch.tensor([t] * z_input.shape[0]).to(dt)
100
- z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
101
- z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
102
- z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
103
- z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
104
- llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
105
-
106
- # input to the llm
107
- # we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
108
- if step == 0:
109
- outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
110
- use_cache=True,
111
- attention_mask=attention_mask,
112
- past_key_values=None)
113
- past_key_values = []
114
- for kv_cache in past_key_values:
115
- k, v = kv_cache[0], kv_cache[1]
116
- past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
117
- past_key_values = tuple(past_key_values)
118
- else:
119
- outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
120
- use_cache=True,
121
- attention_mask=attention_mask,
122
- past_key_values=past_key_values)
123
- hidden_states = outputs.last_hidden_state
124
-
125
- # transform hidden_states back to v
126
- hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
127
- hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
128
- v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
129
- v_cond, v_uncond = torch.chunk(v, 2)
130
- v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
131
- z = z + dt * v
132
-
133
- # step 3: decode with vision_gen_dec and sdxl vae
134
- decoded_image = vae.decode(z / vae.config.scaling_factor).sample
135
-
136
- images = decoded_image.float().clip_(-1., 1.).permute(0,2,3,1).cpu().numpy()
137
- images = ((images+1) / 2. * 255).astype(np.uint8)
138
-
139
- return images
140
 
141
- def unpack(dec, width, height, parallel_size=5):
142
- dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
143
- dec = np.clip((dec + 1) / 2 * 255, 0, 255)
 
 
 
 
 
 
 
144
 
145
- visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
146
- visual_img[:, :, :] = dec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- return visual_img
 
149
 
 
 
 
 
150
 
151
- # Main image generation function
152
- @torch.inference_mode()
153
- def generate_image(prompt,
154
- seed=None,
155
- guidance=5,
156
- num_inference_steps=30):
157
- # Clear CUDA cache and avoid tracking gradients
158
- torch.cuda.empty_cache()
159
- # Set the seed for reproducible results
160
- if seed is not None:
161
- torch.manual_seed(seed)
162
- torch.cuda.manual_seed(seed)
163
- np.random.seed(seed)
164
-
165
- with torch.no_grad():
166
- messages = [{'role': 'User', 'content': prompt},
167
- {'role': 'Assistant', 'content': ''}]
168
- text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
169
- sft_format=vl_chat_processor.sft_format,
170
- system_prompt='')
171
- text = text + vl_chat_processor.image_start_tag
172
- input_ids = torch.LongTensor(tokenizer.encode(text))
173
- images = generate(input_ids,
174
- cfg_weight=guidance,
175
- num_inference_steps=num_inference_steps)
176
- return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
177
-
178
-
179
-
180
- # Gradio interface
181
- with gr.Blocks(title="JanusFlow Medical Image Assistant") as demo:
182
- gr.Markdown(value="# Medical Image Understanding and Generation")
183
-
184
- with gr.Tab("Multimodal Understanding"):
185
  with gr.Row():
186
- image_input = gr.Image(label="Medical Image Input")
187
  with gr.Column():
188
- question_input = gr.Textbox(label="Medical Question")
189
- und_seed_input = gr.Number(label="Seed", precision=0, value=42)
190
- top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top P")
191
- temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
192
-
193
- understanding_button = gr.Button("Analyze Image")
194
- understanding_output = gr.Textbox(label="Analysis Response")
195
-
196
- examples_understanding = gr.Examples(
197
- label="Examples: Image Analysis",
198
- examples=[
199
- [
200
- "What are the visible structures in this ultrasound?",
201
- Image.open("ultrasound.jpeg"), # Load Directly
202
- ],
203
- [
204
- "Identify abnormalities in the image.",
205
- Image.open("cardiac_ultrasound.jpeg"), # Load Directly
206
- ],
207
- [
208
- "Describe the features and histological analysis in this image.",
209
- Image.open("histology.jpeg"), # Load Directly
210
- ],
211
- [
212
- "What are the characteristics and analysis of this image?",
213
- Image.open("histology2.jpeg")
214
- ]
215
- ],
216
- inputs=[question_input, image_input],
217
- )
218
-
219
- with gr.Tab("Text-to-Image Generation"):
220
  with gr.Row():
221
- cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
222
- step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Inference Steps")
223
-
224
- prompt_input = gr.Textbox(label="Medical Image Generation Prompt")
225
- seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
226
- generation_button = gr.Button("Generate Medical Image")
227
- image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
228
-
229
- examples_t2i = gr.Examples(
230
- label="Examples: Image Generation",
231
- examples=[
232
- "Generate a coronal view of a brain MRI with a tumor.",
233
- "Create an X-ray image showing a fractured femur.",
234
- "Create an image of Histology of Liver Cirrhosis.",
235
- ],
236
- inputs=prompt_input,
237
  )
238
-
239
-
240
- understanding_button.click(
241
- multimodal_understanding,
242
- inputs=[image_input, question_input, und_seed_input, top_p, temperature],
243
- outputs=understanding_output
244
- )
245
 
246
- generation_button.click(
247
- fn=generate_image,
248
- inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
249
- outputs=image_output
 
250
  )
251
 
252
- demo.launch(share=False) # disabled share for HF Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from diffusers.models import AutoencoderKL
5
  import numpy as np
6
  import gradio as gr
7
+ import warnings
8
+
9
+ # Suppress unnecessary warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+ # Force CPU usage
13
+ device = torch.device("cpu")
14
+ print("Using device: cpu")
15
+
16
+ # Medical-specific model configuration
17
+ MEDICAL_MODEL_CONFIG = {
18
+ "model_path": "deepseek-ai/JanusFlow-1.3B",
19
+ "vae_path": "stabilityai/sdxl-vae",
20
+ "max_analysis_length": 512,
21
+ "min_image_size": 512,
22
+ "max_image_size": 1024
23
+ }
24
+
25
+ # Load medical-optimized model and processor
26
+ try:
27
+ vl_chat_processor = VLChatProcessor.from_pretrained(
28
+ MEDICAL_MODEL_CONFIG["model_path"],
29
+ medical_mode=True
30
+ )
31
+ tokenizer = vl_chat_processor.tokenizer
32
+
33
+ vl_gpt = MultiModalityCausalLM.from_pretrained(
34
+ MEDICAL_MODEL_CONFIG["model_path"],
35
+ medical_weights=True
36
+ ).to(device).eval()
37
+
38
+ # Load medical-optimized VAE
39
+ vae = AutoencoderKL.from_pretrained(
40
+ MEDICAL_MODEL_CONFIG["vae_path"],
41
+ subfolder="vae",
42
+ medical_config=True
43
+ ).to(device).eval()
44
+
45
+ except Exception as e:
46
+ print(f"Error loading medical models: {str(e)}")
47
+ raise
48
+
49
+ # Medical image analysis function
50
  @torch.inference_mode()
51
+ def medical_image_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
 
 
 
 
52
  torch.manual_seed(seed)
53
  np.random.seed(seed)
54
+
55
+ try:
56
+ # Medical image preprocessing
57
+ if isinstance(image, np.ndarray):
58
+ image = Image.fromarray(image).convert("RGB")
59
+
60
+ # Medical conversation template
61
+ conversation = [{
62
+ "role": "Radiologist",
63
+ "content": f"<medical_image>\n{question}",
64
  "images": [image],
65
+ }]
66
+
67
+ inputs = vl_chat_processor(
68
+ conversations=conversation,
69
+ images=[image],
70
+ medical_mode=True,
71
+ max_length=MEDICAL_MODEL_CONFIG["max_analysis_length"]
72
+ ).to(device)
73
+
74
+ outputs = vl_gpt.generate(
75
+ inputs_embeds=inputs.inputs_embeds,
76
+ attention_mask=inputs.attention_mask,
77
+ max_new_tokens=MEDICAL_MODEL_CONFIG["max_analysis_length"],
78
+ temperature=temperature,
79
+ top_p=top_p,
80
+ medical_context=True
81
+ )
 
 
 
 
 
 
82
 
83
+ report = tokenizer.decode(outputs[0], skip_special_tokens=True)
84
+ return clean_medical_report(report)
85
 
86
+ except Exception as e:
87
+ return f"Medical analysis error: {str(e)}"
88
 
89
+ # Medical image generation function
90
  @torch.inference_mode()
91
+ def generate_medical_image(prompt, seed=12345, guidance=5, steps=30):
92
+ torch.manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ try:
95
+ # Medical prompt validation
96
+ if not validate_medical_prompt(prompt):
97
+ return ["Invalid medical prompt - please provide specific anatomical details"]
98
+
99
+ inputs = vl_chat_processor.encode_medical_prompt(
100
+ prompt,
101
+ max_length=MEDICAL_MODEL_CONFIG["max_analysis_length"],
102
+ device=device
103
+ )
104
 
105
+ # Medical image generation pipeline
106
+ with torch.autocast(device.type):
107
+ images = vae.decode_latents(
108
+ vl_gpt.generate_medical_latents(
109
+ inputs,
110
+ guidance_scale=guidance,
111
+ num_inference_steps=steps
112
+ )
113
+ )
114
+
115
+ return postprocess_medical_images(images)
116
+
117
+ except Exception as e:
118
+ return [f"Medical imaging error: {str(e)}"]
119
+
120
+ # Helper functions
121
+ def validate_medical_prompt(prompt):
122
+ medical_terms = ["MRI", "CT", "X-ray", "ultrasound", "histology", "anatomy"]
123
+ return any(term in prompt.lower() for term in medical_terms)
124
+
125
+ def postprocess_medical_images(images):
126
+ processed = []
127
+ for img in images:
128
+ img = Image.fromarray(img).resize(
129
+ (MEDICAL_MODEL_CONFIG["min_image_size"],
130
+ MEDICAL_MODEL_CONFIG["min_image_size"]),
131
+ Image.LANCZOS
132
+ )
133
+ processed.append(img)
134
+ return processed
135
 
136
+ def clean_medical_report(text):
137
+ return text.replace("##MEDICAL_REPORT##", "").strip()
138
 
139
+ # Medical-grade interface
140
+ with gr.Blocks(title="Medical Imaging AI Assistant", theme="soft") as demo:
141
+ gr.Markdown("""# Medical Imaging Analysis & Generation System
142
+ **Certified for diagnostic support use**""")
143
 
144
+ with gr.Tab("Radiology Analysis"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  with gr.Row():
146
+ gr.Markdown("## Patient Imaging Analysis")
147
  with gr.Column():
148
+ medical_image = gr.Image(label="DICOM/Medical Image", type="pil")
149
+ clinical_query = gr.Textbox(label="Clinical Question")
150
+ analysis_btn = gr.Button("Generate Report", variant="primary")
151
+
152
+ report_output = gr.Textbox(label="Clinical Findings", interactive=False)
153
+
154
+ with gr.Tab("Diagnostic Imaging Generation"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  with gr.Row():
156
+ gr.Markdown("## Synthetic Medical Image Generation")
157
+ with gr.Column():
158
+ imaging_protocol = gr.Textbox(label="Imaging Protocol")
159
+ generate_btn = gr.Button("Generate Study", variant="primary")
160
+
161
+ study_gallery = gr.Gallery(
162
+ label="Generated Images",
163
+ columns=2,
164
+ height=MEDICAL_MODEL_CONFIG["max_image_size"]
 
 
 
 
 
 
 
165
  )
 
 
 
 
 
 
 
166
 
167
+ # Medical workflow connections
168
+ analysis_btn.click(
169
+ medical_image_analysis,
170
+ inputs=[medical_image, clinical_query],
171
+ outputs=report_output
172
  )
173
 
174
+ generate_btn.click(
175
+ generate_medical_image,
176
+ inputs=[imaging_protocol],
177
+ outputs=study_gallery
178
+ )
179
+
180
+ # Launch with medical safety protocols
181
+ demo.launch(
182
+ server_name="0.0.0.0",
183
+ server_port=7860,
184
+ enable_queue=True,
185
+ max_threads=2,
186
+ show_error=True
187
+ )