Cicici1109 committed (verified)
Commit b7130f0 · 1 Parent(s): 5fb32e4

Update utils.py

Files changed (1):
  1. utils.py +46 -43
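Summary of the change: model loading moves out of the per-call paths into process-wide caches. `init_flux_pipeline()` lazily builds a single global FLUX.1-dev `pipe`, and `get_model()` memoizes one `(model, tokenizer)` pair per checkpoint path, with `load_model()` kept as a thin compatibility wrapper. The two GPU entry points, `infer_with_DiT` and `roi_localization`, gain the `@spaces.GPU` decorator (`import spaces` is added), stale LoRA adapters are now unloaded before each `load_lora_weights`, the duplicate `from PIL import Image` is trimmed to `from PIL import ImageFilter`, and dead commented-out code is removed.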
utils.py CHANGED
@@ -5,11 +5,12 @@ from src.flux.condition import Condition
 from PIL import Image
 import argparse
 import os
+import spaces
 import json
 import base64
 import io
 import re
-from PIL import Image, ImageFilter
+from PIL import ImageFilter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from scipy.ndimage import binary_dilation
 import cv2
@@ -27,6 +28,31 @@ except ImportError:
 
 import re
 
+pipe = None
+model_dict = {}
+
+def init_flux_pipeline():
+    global pipe
+    if pipe is None:
+        pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            torch_dtype=torch.bfloat16
+        )
+        pipe = pipe.to("cuda")
+
+def get_model(model_path):
+    global model_dict
+    if model_path not in model_dict:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype="auto",
+            device_map="auto",
+            trust_remote_code=True
+        ).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model_dict[model_path] = (model, tokenizer)
+    return model_dict[model_path]
+
 def encode_image_to_datauri(path, size=(512, 512)):
     with Image.open(path).convert('RGB') as img:
         img = img.resize(size, Image.LANCZOS)
@@ -34,8 +60,6 @@ def encode_image_to_datauri(path, size=(512, 512)):
         img.save(buffer, format='PNG')
         b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
         return b64
-        # return f"data:image/png;base64,{b64}"
-
 
 @retry(
     reraise=True,
@@ -93,7 +117,6 @@ def cot_with_gpt(image_uri, instruction):
     categories, instructions = extract_instructions(text)
     return categories, instructions
 
-
 def extract_instructions(text):
     categories = []
     instructions = []
@@ -134,9 +157,9 @@ def extract_last_bbox(result):
     x0, y0, x1, y1 = map(int, last_match[1:])
     return x0, y0, x1, y1
 
-
+@spaces.GPU
 def infer_with_DiT(task, image, instruction, category):
-    # seed_everything(3407)
+    init_flux_pipeline()
 
     if task == 'RoI Inpainting':
         if category == 'Add' or category == 'Replace':
@@ -180,18 +203,14 @@ def infer_with_DiT(task, image, instruction, category):
         condition = Condition("scene", image, position_delta=(0, -32))
     else:
         raise ValueError(f"Invalid task: '{task}'")
-    pipe = FluxPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-dev",
-        torch_dtype=torch.bfloat16
-    )
-
-    pipe = pipe.to("cuda")
-
+
+    pipe.unload_lora_weights()
     pipe.load_lora_weights(
        "Cicici1109/IEAP",
        weight_name=lora_path,
        adapter_name="scene",
    )
+
    result_img = generate(
        pipe,
        prompt=instruction_dit,
@@ -201,15 +220,13 @@ def infer_with_DiT(task, image, instruction, category):
        height=512,
        width=512,
    ).images[0]
-    # result_img
+
    if task == 'RoI Editing' and category == 'Action Change':
        text_roi = extract_object_with_gpt(instruction)
        instruction_loc = f"<image>Please segment {text_roi}."
-        # (model, tokenizer, image_path, instruction, work_dir, dilate):
        img = result_img
-        # print(f"Instruction: {instruction_loc}")
 
-        model, tokenizer = load_model("ByteDance/Sa2VA-8B")
+        model, tokenizer = get_model("ByteDance/Sa2VA-8B")
 
        result = model.predict_forward(
            image=img,
@@ -218,13 +235,11 @@ def infer_with_DiT(task, image, instruction, category):
        )
 
        prediction = result['prediction']
-        # print(f"Model Output: {prediction}")
 
        if '[SEG]' in prediction and 'prediction_masks' in result:
            pred_mask = result['prediction_masks'][0]
            pred_mask_np = np.squeeze(np.array(pred_mask))
 
-            ## obtain region bbox
            rows = np.any(pred_mask_np, axis=1)
            cols = np.any(pred_mask_np, axis=0)
            if not np.any(rows) or not np.any(cols):
@@ -238,18 +253,10 @@ def infer_with_DiT(task, image, instruction, category):
 
            return changed_instance, x0, y1, 1
 
-
    return result_img
 
 def load_model(model_path):
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype="auto",
-        device_map="auto",
-        trust_remote_code=True
-    ).eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    return model, tokenizer
+    return get_model(model_path)
 
 def extract_object_with_gpt(instruction):
    system_prompt = (
@@ -304,7 +311,6 @@ def extract_region_with_gpt(instruction):
            max_tokens=20,
        )
        object_phrase = response.choices[0].message['content'].strip().strip('"')
-        # print(f"Identified object: {object_phrase}")
        return object_phrase
    except Exception as e:
        print(f"GPT extraction failed: {e}")
@@ -372,8 +378,9 @@ def crop_masked_region(image, pred_mask_np):
 
    return Image.fromarray(cropped_image, mode='RGBA')
 
-def roi_localization(image, instruction, category): # add, remove, replace, action change, move, resize
-    model, tokenizer = load_model("ByteDance/Sa2VA-8B")
+@spaces.GPU
+def roi_localization(image, instruction, category):
+    model, tokenizer = get_model("ByteDance/Sa2VA-8B")
    if category == 'Add':
        text_roi = extract_region_with_gpt(instruction)
    else:
@@ -389,13 +396,11 @@ def roi_localization(image, instruction, category): # add, remove, replace, action change, move, resize
    )
 
    prediction = result['prediction']
-    # print(f"Model Output: {prediction}")
 
    if '[SEG]' in prediction and 'prediction_masks' in result:
        pred_mask = result['prediction_masks'][0]
        pred_mask_np = np.squeeze(np.array(pred_mask))
        if category == 'Add':
-            ## obtain region bbox
            rows = np.any(pred_mask_np, axis=1)
            cols = np.any(pred_mask_np, axis=0)
            if not np.any(rows) or not np.any(cols):
@@ -405,17 +410,14 @@ def roi_localization(image, instruction, category): # add, remove, replace, action change, move, resize
            y0, y1 = np.where(rows)[0][[0, -1]]
            x0, x1 = np.where(cols)[0][[0, -1]]
 
-            ## obtain inpainting bbox
-            bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
-            # print(bbox)
+            bbox = combine_bbox(text_roi, x0, y0, x1, y1)
            x0, y0, x1, y1 = layout_add(bbox, instruction)
            mask = bbox_to_mask(x0, y0, x1, y1)
-            ## make it black
            masked_img = get_masked(mask, img)
        elif category == 'Move' or category == 'Resize':
            dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
            masked_img = get_masked(dilated_original_mask, img)
-            ## obtain region bbox
+
            rows = np.any(pred_mask_np, axis=1)
            cols = np.any(pred_mask_np, axis=0)
            if not np.any(rows) or not np.any(cols):
@@ -425,12 +427,10 @@ def roi_localization(image, instruction, category): # add, remove, replace, action change, move, resize
            y0, y1 = np.where(rows)[0][[0, -1]]
            x0, x1 = np.where(cols)[0][[0, -1]]
 
-            ## obtain inpainting bbox
-            bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
-            # print(bbox)
+            bbox = combine_bbox(text_roi, x0, y0, x1, y1)
            x0_new, y0_new, x1_new, y1_new, = layout_change(bbox, instruction)
            scale = (y1_new - y0_new) / (y1 - y0)
-            # print(scale)
+
            changed_instance = crop_masked_region(img, pred_mask_np)
 
            return masked_img, changed_instance, x0_new, y1_new, scale
@@ -588,4 +588,7 @@ def layout_change(bbox, instruction):
    result = response.choices[0].message.content.strip()
 
    bbox = extract_last_bbox(result)
-    return bbox
+    return bbox
+
+if __name__ == "__main__":
+    init_flux_pipeline()
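Notes on the change:

1. Cached loaders. Both new helpers memoize heavyweight objects in module-level globals, so repeated calls reuse one pipeline and one (model, tokenizer) pair per checkpoint instead of re-initializing them. A minimal usage sketch; the `warm_up` helper and the `utils` import path are illustrative assumptions, not part of the repo:

    from utils import init_flux_pipeline, get_model

    def warm_up():
        init_flux_pipeline()   # first call: builds FLUX.1-dev and moves it to CUDA
        init_flux_pipeline()   # later calls: no-op, the global `pipe` is already set
        m1, tok = get_model("ByteDance/Sa2VA-8B")   # loads and caches the pair
        m2, _ = get_model("ByteDance/Sa2VA-8B")     # cache hit: same object back
        assert m1 is m2

The new `if __name__ == "__main__": init_flux_pipeline()` block at the bottom only warms the pipeline when the file is run directly, so importing utils.py stays cheap.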
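2. ZeroGPU scheduling. `@spaces.GPU` comes from the `spaces` package that backs Hugging Face ZeroGPU Spaces: a GPU is attached for the duration of each decorated call and released afterwards. That is why `init_flux_pipeline()`, which calls `.to("cuda")`, runs inside the decorated `infer_with_DiT` rather than at import time, where no GPU exists yet. A sketch of the pattern; `demo_infer` is a made-up name and the call is simplified (utils.py goes through `generate(...)` with a Condition):

    import spaces

    @spaces.GPU  # a GPU is granted only while this call runs
    def demo_infer(prompt):
        init_flux_pipeline()                   # safe here: CUDA is available inside the call
        return pipe(prompt=prompt).images[0]   # simplified stand-in for generate(...)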
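3. Adapter hygiene. Because `pipe` is now long-lived, `infer_with_DiT` must drop whatever LoRA a previous request attached before loading the task-specific one; otherwise adapters would pile up across calls. This is the standard diffusers swap idiom, as in the hunk above:

    pipe.unload_lora_weights()      # remove the adapter left by the previous call
    pipe.load_lora_weights(
        "Cicici1109/IEAP",
        weight_name=lora_path,      # task-specific LoRA file
        adapter_name="scene",
    )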
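4. Mask-to-bbox idiom. The segmentation branches repeatedly reduce a boolean mask to a tight bounding box with the same `np.any`/`np.where` pattern. Factored out here for clarity; `mask_to_bbox` is a sketch, not a helper in the repo:

    import numpy as np

    def mask_to_bbox(mask):
        # Tight (x0, y0, x1, y1) around the True pixels, or None if empty.
        rows = np.any(mask, axis=1)          # rows containing mask pixels
        cols = np.any(mask, axis=0)          # columns containing mask pixels
        if not rows.any() or not cols.any():
            return None                      # empty mask: nothing to box
        y0, y1 = np.where(rows)[0][[0, -1]]  # first/last occupied row
        x0, x1 = np.where(cols)[0][[0, -1]]  # first/last occupied column
        return x0, y0, x1, y1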