Update utils.py
utils.py CHANGED
```diff
@@ -5,11 +5,12 @@ from src.flux.condition import Condition
 from PIL import Image
 import argparse
 import os
+import spaces
 import json
 import base64
 import io
 import re
-from PIL import
+from PIL import ImageFilter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from scipy.ndimage import binary_dilation
 import cv2
```
```diff
@@ -27,6 +28,31 @@ except ImportError:
 
 import re
 
+pipe = None
+model_dict = {}
+
+def init_flux_pipeline():
+    global pipe
+    if pipe is None:
+        pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            torch_dtype=torch.bfloat16
+        )
+        pipe = pipe.to("cuda")
+
+def get_model(model_path):
+    global model_dict
+    if model_path not in model_dict:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype="auto",
+            device_map="auto",
+            trust_remote_code=True
+        ).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model_dict[model_path] = (model, tokenizer)
+    return model_dict[model_path]
+
 def encode_image_to_datauri(path, size=(512, 512)):
     with Image.open(path).convert('RGB') as img:
         img = img.resize(size, Image.LANCZOS)
```
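The two helpers added above turn the heavyweight objects into process-wide singletons: `init_flux_pipeline()` builds the FLUX pipeline once on first use, and `get_model()` loads each `transformers` checkpoint at most once, memoized in `model_dict`. A minimal sketch of the intended call pattern, assuming this file is importable as `utils`:

```python
# Hedged sketch; assumes utils.py exposes the helpers added in this commit.
from utils import get_model, init_flux_pipeline

# First call pays the full from_pretrained() load and caches the pair.
model, tokenizer = get_model("ByteDance/Sa2VA-8B")

# Subsequent calls are dict lookups returning the same objects.
model_again, _ = get_model("ByteDance/Sa2VA-8B")
assert model is model_again

init_flux_pipeline()  # builds the pipeline on first call ...
init_flux_pipeline()  # ... and is a no-op afterwards
```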
```diff
@@ -34,8 +60,6 @@ def encode_image_to_datauri(path, size=(512, 512)):
         img.save(buffer, format='PNG')
         b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
         return b64
-        # return f"data:image/png;base64,{b64}"
-
 
 @retry(
     reraise=True,
@@ -93,7 +117,6 @@ def cot_with_gpt(image_uri, instruction):
     categories, instructions = extract_instructions(text)
     return categories, instructions
 
-
 def extract_instructions(text):
     categories = []
     instructions = []
```
```diff
@@ -134,9 +157,9 @@ def extract_last_bbox(result):
     x0, y0, x1, y1 = map(int, last_match[1:])
     return x0, y0, x1, y1
 
-
+@spaces.GPU
 def infer_with_DiT(task, image, instruction, category):
-
+    init_flux_pipeline()
 
     if task == 'RoI Inpainting':
         if category == 'Add' or category == 'Replace':
```
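`@spaces.GPU` is the ZeroGPU entry point: a ZeroGPU Space owns no GPU at import time, and a device is attached only for the duration of a decorated call, which is why the CUDA-touching `init_flux_pipeline()` now runs inside `infer_with_DiT` rather than at module import. A sketch of the decorator's use (function names here are illustrative, not from the diff):

```python
import spaces

@spaces.GPU  # a GPU is attached only while this call is running
def generate_image(prompt):
    init_flux_pipeline()  # safe here: CUDA is available inside the call
    return pipe(prompt).images[0]

# Longer jobs can request a bigger time slice, in seconds:
@spaces.GPU(duration=120)
def generate_many(prompts):
    ...
```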
```diff
@@ -180,18 +203,14 @@ def infer_with_DiT(task, image, instruction, category):
         condition = Condition("scene", image, position_delta=(0, -32))
     else:
         raise ValueError(f"Invalid task: '{task}'")
-    pipe = FluxPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-dev",
-        torch_dtype=torch.bfloat16
-    )
-
-    pipe = pipe.to("cuda")
-
+
+    pipe.unload_lora_weights()
     pipe.load_lora_weights(
         "Cicici1109/IEAP",
         weight_name=lora_path,
         adapter_name="scene",
     )
+
     result_img = generate(
         pipe,
         prompt=instruction_dit,
```
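Since `pipe` now lives across requests, an adapter from a previous call would still be attached the next time `infer_with_DiT` runs; calling `unload_lora_weights()` first resets the pipeline to the base FLUX weights before the task-specific LoRA is loaded. A sketch of the failure this guards against (repo and file names are placeholders):

```python
# Without the reset, a second request reuses the adapter name "scene":
pipe.load_lora_weights("user/repo", weight_name="roi.safetensors", adapter_name="scene")
pipe.load_lora_weights("user/repo", weight_name="scene.safetensors", adapter_name="scene")
# -> diffusers rejects the duplicate adapter name (or, under a different
#    name, the stale adapter keeps steering generation).

# Pattern from the diff: drop whatever is attached, then load exactly one LoRA.
pipe.unload_lora_weights()
pipe.load_lora_weights("user/repo", weight_name="scene.safetensors", adapter_name="scene")
```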
```diff
@@ -201,15 +220,13 @@ def infer_with_DiT(task, image, instruction, category):
         height=512,
         width=512,
     ).images[0]
-
+
     if task == 'RoI Editing' and category == 'Action Change':
         text_roi = extract_object_with_gpt(instruction)
         instruction_loc = f"<image>Please segment {text_roi}."
-        # (model, tokenizer, image_path, instruction, work_dir, dilate):
         img = result_img
-        # print(f"Instruction: {instruction_loc}")
 
-        model, tokenizer =
+        model, tokenizer = get_model("ByteDance/Sa2VA-8B")
 
         result = model.predict_forward(
             image=img,
@@ -218,13 +235,11 @@ def infer_with_DiT(task, image, instruction, category):
         )
 
         prediction = result['prediction']
-        # print(f"Model Output: {prediction}")
 
         if '[SEG]' in prediction and 'prediction_masks' in result:
             pred_mask = result['prediction_masks'][0]
             pred_mask_np = np.squeeze(np.array(pred_mask))
 
-            ## obtain region bbox
             rows = np.any(pred_mask_np, axis=1)
             cols = np.any(pred_mask_np, axis=0)
             if not np.any(rows) or not np.any(cols):
```
```diff
@@ -238,18 +253,10 @@ def infer_with_DiT(task, image, instruction, category):
 
             return changed_instance, x0, y1, 1
 
-
     return result_img
 
 def load_model(model_path):
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype="auto",
-        device_map="auto",
-        trust_remote_code=True
-    ).eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    return model, tokenizer
+    return get_model(model_path)
 
 def extract_object_with_gpt(instruction):
     system_prompt = (
```
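Keeping `load_model` as a one-line shim over `get_model` preserves the old call sites while routing every load through the cache, so repeated calls no longer re-instantiate the checkpoint.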
```diff
@@ -304,7 +311,6 @@ def extract_region_with_gpt(instruction):
             max_tokens=20,
         )
         object_phrase = response.choices[0].message['content'].strip().strip('"')
-        # print(f"Identified object: {object_phrase}")
         return object_phrase
     except Exception as e:
         print(f"GPT extraction failed: {e}")
```
```diff
@@ -372,8 +378,9 @@ def crop_masked_region(image, pred_mask_np):
 
     return Image.fromarray(cropped_image, mode='RGBA')
 
-
-def roi_localization(image, instruction, category): # add, remove, replace, action change
+@spaces.GPU
+def roi_localization(image, instruction, category):
+    model, tokenizer = get_model("ByteDance/Sa2VA-8B")
     if category == 'Add':
         text_roi = extract_region_with_gpt(instruction)
     else:
```
```diff
@@ -389,13 +396,11 @@ def roi_localization(image, instruction, category): # add, remove, replace, acti
     )
 
     prediction = result['prediction']
-    # print(f"Model Output: {prediction}")
 
     if '[SEG]' in prediction and 'prediction_masks' in result:
         pred_mask = result['prediction_masks'][0]
         pred_mask_np = np.squeeze(np.array(pred_mask))
         if category == 'Add':
-            ## obtain region bbox
             rows = np.any(pred_mask_np, axis=1)
             cols = np.any(pred_mask_np, axis=0)
             if not np.any(rows) or not np.any(cols):
@@ -405,17 +410,14 @@ def roi_localization(image, instruction, category): # add, remove, replace, acti
             y0, y1 = np.where(rows)[0][[0, -1]]
             x0, x1 = np.where(cols)[0][[0, -1]]
 
-
-            bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
-            # print(bbox)
+            bbox = combine_bbox(text_roi, x0, y0, x1, y1)
             x0, y0, x1, y1 = layout_add(bbox, instruction)
             mask = bbox_to_mask(x0, y0, x1, y1)
-            ## make it black
             masked_img = get_masked(mask, img)
         elif category == 'Move' or category == 'Resize':
             dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
             masked_img = get_masked(dilated_original_mask, img)
-
+
             rows = np.any(pred_mask_np, axis=1)
             cols = np.any(pred_mask_np, axis=0)
             if not np.any(rows) or not np.any(cols):
@@ -425,12 +427,10 @@ def roi_localization(image, instruction, category): # add, remove, replace, acti
             y0, y1 = np.where(rows)[0][[0, -1]]
             x0, x1 = np.where(cols)[0][[0, -1]]
 
-
-            bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
-            # print(bbox)
+            bbox = combine_bbox(text_roi, x0, y0, x1, y1)
             x0_new, y0_new, x1_new, y1_new, = layout_change(bbox, instruction)
             scale = (y1_new - y0_new) / (y1 - y0)
-
+
             changed_instance = crop_masked_region(img, pred_mask_np)
 
             return masked_img, changed_instance, x0_new, y1_new, scale
```
```diff
@@ -588,4 +588,7 @@ def layout_change(bbox, instruction):
     result = response.choices[0].message.content.strip()
 
     bbox = extract_last_bbox(result)
-    return bbox
+    return bbox
+
+if __name__ == "__main__":
+    init_flux_pipeline()
```
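The new `__main__` guard warms the FLUX pipeline when the file is executed directly, e.g. for local testing on a machine with a resident GPU; on ZeroGPU the same initialization instead happens lazily inside the `@spaces.GPU`-decorated `infer_with_DiT`.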