mgbam committed on
Commit a85b9e3 · verified · 1 Parent(s): 46a11a2

Update app.py

Files changed (1)
  1. app.py +20 -204
app.py CHANGED
@@ -1,18 +1,7 @@
- import sys
- sys.path.append('./LLAUS')
-
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
  import torch
- from llava import LlavaLlamaForCausalLM
- from llava.conversation import conv_templates
- from llava.utils import disable_torch_init
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
  from PIL import Image
- from torch.cuda.amp import autocast
  import gradio as gr
  import spaces
- from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
- import os
  from transformers import AutoProcessor, AutoModel
  import torch.nn.functional as F

@@ -20,69 +9,6 @@ import torch.nn.functional as F
  #++++++++ Model ++++++++++
  #---------------------------------

- DEFAULT_IMAGE_TOKEN = "<image>"
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
- DEFAULT_IM_START_TOKEN = "<im_start>"
- DEFAULT_IM_END_TOKEN = "<im_end>"
-
- def patch_config(config_path):
-     """Applies necessary patches to the model config."""
-     patch_dict = {
-         "use_mm_proj": True,
-         "mm_vision_tower": "openai/clip-vit-large-patch14",
-         "mm_hidden_size": 1024
-     }
-     cfg = AutoConfig.from_pretrained(config_path)
-     if not hasattr(cfg, "mm_vision_tower"):
-         print(f'`mm_vision_tower` not found in `{config_path}`, applying patch and save to disk.')
-         for k, v in patch_dict.items():
-             setattr(cfg, k, v)
-         cfg.save_pretrained(config_path)
-
- def load_llava_model():
-     """Loads and initializes the LLaVA model."""
-     model_name = "Baron-GG/LLaVA-Med" # Change this to your model if you uploaded a new one
-     disable_torch_init()
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     patch_config(model_name)
-     model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
-     model.model.requires_grad_(False)
-
-     image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
-
-     model.config.use_cache = False
-     model.config.tune_mm_mlp_adapter = False
-     model.config.freeze_mm_mlp_adapter = False
-     model.config.mm_use_im_start_end = True
-
-     mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
-     tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
-     if mm_use_im_start_end:
-         tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
-
-     vision_tower = model.model.vision_tower[0]
-     vision_tower.to(device='cuda', dtype=torch.float16)
-     vision_config = vision_tower.config
-     vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
-     vision_config.use_im_start_end = mm_use_im_start_end
-     if mm_use_im_start_end:
-         vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
-     image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
-
-     model = prepare_model_for_int8_training(model)
-     lora_config = LoraConfig(
-         r=64,
-         lora_alpha=16,
-         target_modules=["q_proj", "v_proj","k_proj","o_proj"],
-         lora_dropout=0.05,
-         bias="none",
-         task_type="CAUSAL_LM",
-     )
-     model = get_peft_model(model, lora_config).cuda()
-
-     model.eval()
-     return model, tokenizer, image_processor, image_token_len, mm_use_im_start_end
-
  def load_biomedclip_model():
      """Loads the BiomedCLIP model and tokenizer."""
      biomedclip_model_name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
@@ -90,25 +16,6 @@ def load_biomedclip_model():
      model = AutoModel.from_pretrained(biomedclip_model_name).cuda().eval()
      return model, processor

-
- class KeywordsStoppingCriteria(StoppingCriteria):
-     """Custom stopping criteria for generation."""
-     def __init__(self, keywords, tokenizer, input_ids):
-         self.keywords = keywords
-         self.tokenizer = tokenizer
-         self.start_len = None
-         self.input_ids = input_ids
-
-     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         if self.start_len is None:
-             self.start_len = self.input_ids.shape[1]
-         else:
-             outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
-             for keyword in self.keywords:
-                 if keyword in outputs:
-                     return True
-         return False
-
  def compute_similarity(image, text, biomedclip_model, biomedclip_processor):
      """Computes similarity scores using BiomedCLIP."""
      with torch.no_grad():
@@ -121,91 +28,24 @@ def compute_similarity(image, text, biomedclip_model, biomedclip_processor):
          similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()
      return similarity

- @torch.no_grad()
- def eval_llava_model(llava_model, llava_tokenizer, llava_image_processor, image, question, image_token_len, mm_use_im_start_end, max_new_tokens, temperature):
-     """Evaluates the LLaVA model for a given image and question."""
-
-     image_list = []
-     image_tensor = llava_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] # 3, 224, 224
-     image_list.append(image_tensor)
-     image_idx = 1
-
-     if mm_use_im_start_end:
-         qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len * image_idx + DEFAULT_IM_END_TOKEN + question
-     else:
-         qs = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len * image_idx + '\n' + question
-
-     conv = conv_templates["simple"].copy()
-     conv.append_message(conv.roles[0], qs)
-     prompt = conv.get_prompt()
-     inputs = llava_tokenizer([prompt])
-
-     image_tensor = torch.stack(image_list, dim=0).half().cuda()
-     input_ids = torch.as_tensor(inputs.input_ids).cuda()
-
-     keywords = ['###']
-     stopping_criteria = KeywordsStoppingCriteria(keywords, llava_tokenizer, input_ids)
-
-     with autocast():
-         output_ids = llava_model.generate(
-             input_ids=input_ids,
-             images=image_tensor,
-             do_sample=True,
-             temperature=temperature,
-             max_new_tokens=max_new_tokens,
-             stopping_criteria=[stopping_criteria]
-         )
-
-     input_token_len = input_ids.shape[1]
-     n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
-     if n_diff_input_output > 0:
-         print(f'[Warning] Sample: {n_diff_input_output} output_ids are not the same as the input_ids')
-     outputs = llava_tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
-
-     while True:
-         cur_len = len(outputs)
-         outputs = outputs.strip()
-         for pattern in ['###', 'Assistant:', 'Response:']:
-             if outputs.startswith(pattern):
-                 outputs = outputs[len(pattern):].strip()
-         if len(outputs) == cur_len:
-             break
-
-     try:
-         index = outputs.index(conv.sep)
-     except ValueError:
-         outputs += conv.sep
-         index = outputs.index(conv.sep)
-
-     outputs = outputs[:index].strip()
-     print(outputs)
-     return outputs
-
  #---------------------------------
  #++++++++ Gradio ++++++++++
  #---------------------------------

- SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
- You can duplicate and use it with a paid private GPU.
- <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
- Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
- '''
-
- def gradio_reset(chat_state, img_list):
+ def gradio_reset(chat_state, img_list, similarity_output):
      """Resets the chat state and image list."""
      if chat_state is not None:
          chat_state.messages = []
      if img_list is not None:
          img_list = []
-     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your medical image first', interactive=False), gr.update(value="Upload & Start Analysis", interactive=True), chat_state, img_list
+     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your medical image first', interactive=False), gr.update(value="Upload & Start Analysis", interactive=True), chat_state, img_list, gr.update(value="", visible=False)

- def upload_img(gr_img, text_input, chat_state):
+ def upload_img(gr_img, text_input, chat_state, similarity_output):
      """Handles image upload."""
      if gr_img is None:
-         return None, None, gr.update(interactive=True), chat_state, None
+         return None, None, gr.update(interactive=True), chat_state, None, gr.update(visible=False)
      img_list = [gr_img]
-     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Analysis", interactive=False), chat_state, img_list
-
+     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Analysis", interactive=False), chat_state, img_list, gr.update(visible=True)

  def gradio_ask(user_message, chatbot, chat_state):
      """Handles user input."""
@@ -215,24 +55,21 @@ def gradio_ask(user_message, chatbot, chat_state):
      return '', chatbot, chat_state

  @spaces.GPU
- def gradio_answer(chatbot, chat_state, img_list, llava_model, llava_tokenizer, llava_image_processor, image_token_len, mm_use_im_start_end, max_new_token, temperature, biomedclip_model, biomedclip_processor):
-     """Generates and adds the bot's response to the chatbot using LLaVA"""
+ def gradio_answer(chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output):
+     """Computes and displays similarity scores."""
      if not img_list:
-         return chatbot, chat_state, img_list
-
-     # compute similarity using biomedclip
-     similarity_score = compute_similarity(img_list[0],chatbot[-1][0], biomedclip_model, biomedclip_processor)
+         return chatbot, chat_state, img_list, similarity_output
+
+     similarity_score = compute_similarity(img_list[0], chatbot[-1][0], biomedclip_model, biomedclip_processor)
      print(f'Similarity Score is: {similarity_score}')
-
-     # prepare the input for LLAVA
-     llava_input_text = f"Based on the image and query provided the similarity score is {similarity_score:.3f}. " + chatbot[-1][0]
-     llm_message = eval_llava_model(llava_model, llava_tokenizer, llava_image_processor, img_list[0], llava_input_text, image_token_len, mm_use_im_start_end, max_new_token, temperature)
-
-     chatbot[-1][1] = llm_message
-     return chatbot, chat_state, img_list
+
+     similarity_text = f"Similarity Score: {similarity_score:.3f}"
+     chatbot[-1][1] = similarity_text
+     return chatbot, chat_state, img_list, gr.update(value=similarity_text, visible=True)
+

  title = """<h1 align="center">Medical Image Analysis Tool</h1>"""
- description = """<h3>Upload medical images, ask questions, and receive analysis.</h3>"""
+ description = """<h3>Upload medical images, ask questions, and receive a similarity score.</h3>"""
  examples_list=[
      ["./case1.png", "Analyze the X-ray for any abnormalities."],
      ["./case2.jpg", "What type of disease may be present?"],
@@ -240,13 +77,10 @@ examples_list=[
  ]

  # Load models and related resources outside of the Gradio block for loading on startup
- llava_model, llava_tokenizer, llava_image_processor, image_token_len, mm_use_im_start_end = load_llava_model()
  biomedclip_model, biomedclip_processor = load_biomedclip_model()

-
  with gr.Blocks() as demo:
      gr.Markdown(title)
-     # gr.Markdown(SHARED_UI_WARNING)
      gr.Markdown(description)

      with gr.Row():
@@ -255,37 +89,19 @@ with gr.Blocks() as demo:
              upload_button = gr.Button(value="Upload & Start Analysis", interactive=True, variant="primary")
              clear = gr.Button("Restart")

-             max_new_token = gr.Slider(
-                 minimum=1,
-                 maximum=512,
-                 value=128,
-                 step=1,
-                 interactive=True,
-                 label="Max new tokens"
-             )
-
-             temperature = gr.Slider(
-                 minimum=0.1,
-                 maximum=2.0,
-                 value=0.3,
-                 step=0.1,
-                 interactive=True,
-                 label="Temperature",
-             )
-
          with gr.Column():
              chat_state = gr.State()
              img_list = gr.State()
              chatbot = gr.Chatbot(label='Medical Analysis')
              text_input = gr.Textbox(label='Analysis Query', placeholder='Please upload your medical image first', interactive=False)
+             similarity_output = gr.Textbox(label="Similarity Score", visible=False, interactive=False)
      gr.Examples(examples=examples_list, inputs=[image, text_input])

-     upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
+     upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])

      text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-         gradio_answer, [chatbot, chat_state, img_list, llava_model, llava_tokenizer, llava_image_processor, image_token_len, mm_use_im_start_end, max_new_token, temperature, biomedclip_model, biomedclip_processor], [chatbot, chat_state, img_list]
+         gradio_answer, [chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output], [chatbot, chat_state, img_list, similarity_output]
      )
-     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
-
+     clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)

  demo.launch()
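
For reference, the similarity path that this commit leaves behind can be exercised outside Gradio with a short script. This is a minimal sketch rather than code from the repository: it assumes the BiomedCLIP checkpoint resolves through AutoProcessor/AutoModel exactly as load_biomedclip_model() does, and that the forward pass exposes CLIP-style image_embeds/text_embeds; the body of compute_similarity is mostly unchanged context in this diff, so the preprocessing call and output attribute names below are assumptions.

# Minimal sketch (assumptions: the checkpoint loads via AutoProcessor/AutoModel
# as in app.py, and the model output exposes image_embeds / text_embeds).
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, AutoModel

name = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
processor = AutoProcessor.from_pretrained(name)
model = AutoModel.from_pretrained(name).eval()

image = Image.open('./case1.png')                     # one of the bundled examples
text = "Analyze the X-ray for any abnormalities."

with torch.no_grad():
    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    # Normalize and score, mirroring the matmul kept as context in the diff above.
    image_embeds = F.normalize(outputs.image_embeds, dim=-1)   # assumed attribute name
    text_embeds = F.normalize(outputs.text_embeds, dim=-1)     # assumed attribute name
    similarity = (text_embeds @ image_embeds.transpose(-1, -2)).squeeze()

print(f"Similarity Score: {similarity.item():.3f}")

If this checkpoint only ships an open_clip config rather than a transformers one, the AutoModel call may not load it; the scoring step itself is unchanged once embeddings are obtained by other means.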
 
 
 
 
 