import os
import sys
import base64
from io import BytesIO
from typing import Optional

from PIL import Image

import numpy as np
import torch

from segment_anything import SamPredictor, sam_model_registry
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from huggingface_hub import hf_hub_download

sys.path.append(os.getcwd())
import BEN2


## Utility functions
def resize(image: Image.Image, 
            target_width: int, 
            target_height: int,
            interpolate: Image.Resampling = Image.Resampling.LANCZOS,
            return_type: str = "pil") -> Image.Image | np.ndarray:
    """
    Crops and resizes an image while preserving the aspect ratio.

    Args:
        image (Image.Image): Input PIL image to be cropped and resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.
        interpolate (Image.Resampling): The interpolation method.
        return_type (str): The type of the output image.

    Returns:
        Image.Image: Cropped and resized image.
    """
    # Original dimensions
    resized_image = image.resize((target_width, target_height), interpolate)
    if return_type == "pil":
        return resized_image
    elif return_type == "np":
        return np.array(resized_image)
    else:
        raise ValueError(f"Invalid return type: {return_type}")


def resize_long_edge(
    image: Image.Image, 
    long_edge_size: int, 
    interpolate: Image.Resampling = Image.Resampling.LANCZOS,
    return_type: str = "pil"
    ) -> np.ndarray | Image.Image:
    """
    Resize the long edge of the image to the long_edge_size.

    Args:
        image (Image.Image): The image to resize.
        long_edge_size (int): The size of the long edge.
        interpolate (Image.Resampling): The interpolation method.
        return_type (str): The type of the output image ("pil" or "np").

    Returns:
        np.ndarray | Image.Image: The resized image.
    """
    w, h = image.size
    scale_ratio = long_edge_size / max(h, w)
    output_w = int(w * scale_ratio)
    output_h = int(h * scale_ratio)
    image = resize(image, target_width=int(output_w), target_height=int(output_h), interpolate=interpolate, return_type=return_type)
    return image


def ensure_divisible_by_value(
    image: Image.Image | np.ndarray, 
    value: int = 8, 
    interpolate: Image.Resampling = Image.Resampling.NEAREST,
    return_type: str = "np"
    ) -> np.ndarray | Image.Image:
    """
    Ensure the image dimensions are divisible by value.

    Args:
        image (Image.Image | np.ndarray): The image to make divisible by `value`.
        value (int): The value to ensure divisible by.
        interpolate (Image.Resampling): The interpolation method.
        return_type (str): The type of the output image.

    Returns:
        np.ndarray | Image.Image: The resized image.
    """

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    w, h = image.size

    # Round each dimension down to the nearest multiple of `value`
    w = (w // value) * value
    h = (h // value) * value
    image = resize(image, w, h, interpolate=interpolate, return_type=return_type)
    return image


def resize_paired_image(
    image_reference: np.ndarray, 
    image_target: np.ndarray, 
    mask_target: np.ndarray,
    force_resize_long_edge: bool = False,
    return_type: str = "np"
    ) -> tuple[np.ndarray | Image.Image, np.ndarray | Image.Image, np.ndarray | Image.Image]:
    """
    Resize a (reference image, target image, target mask) triplet.

    If force_resize_long_edge is True, all inputs are first resized so their
    long edge is 1024 pixels. The reference image is then rescaled to the
    target image's height while keeping its own aspect ratio.
    """

    if isinstance(image_reference, np.ndarray):
        image_reference = Image.fromarray(image_reference)
    if isinstance(image_target, np.ndarray):
        image_target = Image.fromarray(image_target)
    if isinstance(mask_target, np.ndarray):
        mask_target = Image.fromarray(mask_target)

    if force_resize_long_edge:
        image_reference = resize_long_edge(image_reference, 1024, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        image_target = resize_long_edge(image_target, 1024, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        mask_target = resize_long_edge(mask_target, 1024, interpolate=Image.Resampling.NEAREST, return_type=return_type)

    if isinstance(image_reference, Image.Image):
        ref_width, ref_height = image_reference.size
        target_width, target_height = image_target.size
    else:
        # np.ndarray shape is (height, width, channels)
        ref_height, ref_width = image_reference.shape[:2]
        target_height, target_width = image_target.shape[:2]

    # resize the ref image to the same height as the target image and ensure the ratio remains the same
    if ref_height != target_height:
        scale_ratio = target_height / ref_height
        # resize() expects a PIL image; the long-edge resize above may have
        # produced an np.ndarray depending on return_type.
        if isinstance(image_reference, np.ndarray):
            image_reference = Image.fromarray(image_reference)
        image_reference = resize(image_reference, int(ref_width * scale_ratio), target_height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)

    if return_type == "pil": 
        image_reference = Image.fromarray(image_reference) if isinstance(image_reference, np.ndarray) else image_reference
        image_target = Image.fromarray(image_target) if isinstance(image_target, np.ndarray) else image_target
        mask_target = Image.fromarray(mask_target) if isinstance(mask_target, np.ndarray) else mask_target
        return image_reference, image_target, mask_target
    else:
        image_reference = np.array(image_reference) if isinstance(image_reference, Image.Image) else image_reference
        image_target = np.array(image_target) if isinstance(image_target, Image.Image) else image_target
        mask_target = np.array(mask_target) if isinstance(mask_target, Image.Image) else mask_target
        return image_reference, image_target, mask_target


def prepare_input_images(
    img_ref: np.ndarray,
    custmization_mode: str,
    img_target: Optional[np.ndarray] = None,
    mask_target: Optional[np.ndarray] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    force_resize_long_edge: bool = False,
    return_type: str = "np"
    ) -> tuple[np.ndarray | Image.Image, np.ndarray | Image.Image, np.ndarray | Image.Image]:
    """
    Prepare the reference image, target image, and target mask for inference.

    In "position-free" mode the target image and mask are synthesized (an
    all-white image and an all-zero mask). If width and height are given, all
    inputs are resized to that size; otherwise they are resized as a pair.
    All outputs are finally made divisible by 16.
    """

    if custmization_mode.lower() == "position-free":
        img_target = np.ones_like(img_ref) * 255
        mask_target = np.zeros_like(img_ref)

    if isinstance(width, int) and isinstance(height, int):
        img_ref = resize(Image.fromarray(img_ref), width, height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        img_target = resize(Image.fromarray(img_target), width, height, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
        mask_target = resize(Image.fromarray(mask_target), width, height, interpolate=Image.Resampling.NEAREST, return_type=return_type)
    else:
        img_ref, img_target, mask_target = resize_paired_image(img_ref, img_target, mask_target, force_resize_long_edge, return_type=return_type)

    img_ref = ensure_divisible_by_value(img_ref, value=16, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
    img_target = ensure_divisible_by_value(img_target, value=16, interpolate=Image.Resampling.LANCZOS, return_type=return_type)
    mask_target = ensure_divisible_by_value(mask_target, value=16, interpolate=Image.Resampling.NEAREST, return_type=return_type)

    return img_ref, img_target, mask_target
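
# Illustrative usage sketch (not part of the original module): how the resize
# helpers above compose for a position-aware edit. The file names are
# hypothetical placeholders.
#
#   ref = np.array(Image.open("reference.png").convert("RGB"))
#   tgt = np.array(Image.open("target.png").convert("RGB"))
#   msk = np.array(Image.open("target_mask.png").convert("L"))
#   ref, tgt, msk = prepare_input_images(
#       ref, "position-aware", img_target=tgt, mask_target=msk,
#       force_resize_long_edge=True, return_type="np")
#   # All three outputs now share the same height and have dimensions
#   # divisible by 16.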


def get_mask_type_ids(custmization_mode: str, input_mask_mode: str) -> torch.Tensor:
    """
    Map the customization mode and input mask mode to a mask-type id tensor:
    0 = position-free, 1 = position-aware with a precise mask,
    2 = position-aware with any other mask mode.
    """
    if custmization_mode.lower() == "position-free":
        return torch.tensor([0])
    elif custmization_mode.lower() == "position-aware":
        if "precise" in input_mask_mode.lower():
            return torch.tensor([1])
        else:
            return torch.tensor([2])
    else: 
        raise ValueError(f"Invalid custmization mode: {custmization_mode}")


def scale_image(image_np, is_mask: bool = False):
    """
    Scale the image to the range of [-1, 1] if not a mask, otherwise scale to [0, 1].

    Args:
        image_np (np.ndarray): Input image.
        is_mask (bool): Whether the image is a mask.
    Returns:
        np.ndarray: Scaled image.
    """
    # Masks are scaled to [0, 1]; images are further scaled to [-1, 1].
    image_np = image_np / 255.0
    if not is_mask:
        image_np = image_np * 2 - 1
    return image_np
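
# Worked example (illustrative): scale_image maps 8-bit pixel values as
#   scale_image(np.array([0.0, 127.5, 255.0]))        -> [-1.0, 0.0, 1.0]
#   scale_image(np.array([0.0, 255.0]), is_mask=True) -> [0.0, 1.0]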


def get_sam_predictor(sam_ckpt_path, device):
    """
    Get the SAM predictor.
    Args:
        sam_ckpt_path (str): The path to the SAM checkpoint.
        device (str): The device to load the model on.
    Returns:
        SamPredictor: The SAM predictor.
    """
    if not os.path.exists(sam_ckpt_path):
        sam_ckpt_path = hf_hub_download(repo_id="HCMUE-Research/SAM-vit-h", filename="sam_vit_h_4b8939.pth")

    sam = sam_model_registry['vit_h'](checkpoint=sam_ckpt_path).to(device)
    sam.eval()
    predictor = SamPredictor(sam)

    return predictor
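
# Illustrative usage sketch (assumptions: a single positive click at (x, y),
# a hypothetical checkpoint path, and an HWC uint8 RGB array `image_np`):
#
#   predictor = get_sam_predictor("checkpoints/sam_vit_h_4b8939.pth", "cuda")
#   predictor.set_image(image_np)
#   masks, scores, _ = predictor.predict(
#       point_coords=np.array([[x, y]]),
#       point_labels=np.array([1]),          # 1 = foreground click
#       multimask_output=True)
#   best_mask = masks[np.argmax(scores)]     # boolean HxW mask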


def image_to_base64(img):
    """
    Convert an image to a base64 string.
    Args:
        img (PIL.Image.Image): The image to convert.
    Returns:
        str: The base64 string.
    """
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def get_vlm(vlm_ckpt_path, device, torch_dtype):
    """
    Get the VLM pipeline.
    Args:
        vlm_ckpt_path (str): The path to the VLM checkpoint.
        device (str): The device to load the model on.
        torch_dtype (torch.dtype): The data type of the model.
    Returns:
        tuple: The processor and model.
    """
    if not os.path.exists(vlm_ckpt_path):
        vlm_ckpt_path = "Qwen/Qwen2.5-VL-7B-Instruct"

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        vlm_ckpt_path, torch_dtype=torch_dtype
    ).to(device)
    processor = AutoProcessor.from_pretrained(vlm_ckpt_path)

    return processor, model


def construct_vlm_gen_prompt(image_target, image_reference, target_mask, custmization_mode):
    """
    Construct the VLM generation prompt.
    Args:
        image_target (np.ndarray): The target image.
        image_reference (np.ndarray): The reference image.
        target_mask (np.ndarray): The target mask.
        custmization_mode (str): The customization mode.
    Returns:
        list: The messages.
    """
    if custmization_mode.lower() == "position-free":
        image_reference_pil = Image.fromarray(image_reference.astype(np.uint8))
        image_reference_base_64 = image_to_base64(image_reference_pil)
        messages = [
            {
                "role": "system", 
                "content": "I will input a reference image. Please identify the main subject/object in this image and generate a new description by placing this subject in a completely different scene or context. For example, if the reference image shows a rabbit sitting in a garden surrounded by green leaves and roses, you could generate a description like 'The rabbit is sitting on a rocky cliff overlooking a serene ocean, with the sun setting behind it, casting a warm glow over the scene'. Please directly output the new description without explaining your thought process. The description should not exceed 256 tokens."
            },
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image",
                        "image": f"data:image;base64,{image_reference_base_64}"
                    },
                ],
            }
        ]
        return messages
    else:
        image_reference_pil = Image.fromarray(image_reference.astype(np.uint8))
        image_reference_base_64 = image_to_base64(image_reference_pil)

        target_mask_binary = target_mask > 127.5
        masked_image_target = image_target * target_mask_binary
        masked_image_target_pil = Image.fromarray(masked_image_target.astype(np.uint8))
        masked_image_target_base_64 = image_to_base64(masked_image_target_pil)
        

        messages = [
            {
                "role": "system", 
                "content": "I will input a reference image and a target image with its main subject area masked (in black). Please directly describe the scene where the main subject/object from the reference image is placed into the masked area of the target image. Focus on describing the final combined scene, making sure to clearly describe both the object from the reference image and the background/environment from the target image. For example, if the reference shows a white cat with orange stripes on a beach and the target shows a masked area in a garden with blooming roses and tulips, directly describe 'A white cat with orange stripes sits elegantly among the vibrant red roses and yellow tulips in the lush garden, surrounded by green foliage.' The description should not exceed 256 tokens."
            },
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image",
                        "image": f"data:image;base64,{image_reference_base_64}"
                    },
                    {
                        "type": "image",
                        "image": f"data:image;base64,{masked_image_target_base_64}"
                    }
                ],
            }
        ]
        return messages


def construct_vlm_polish_prompt(prompt):
    """
    Construct the VLM polish prompt.
    Args:
        prompt (str): The prompt to polish.
    Returns:
        list: The messages.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that can polish the text prompt to make it more specific, detailed, and complete. Please directly output the polished prompt without explaining your thought process. The prompt should not exceed 256 tokens."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]
    return messages


def run_vlm(vlm_processor, vlm_model, messages, device):
    """
    Run the VLM.
    Args:
        vlm_processor (AutoProcessor): The VLM processor.
        vlm_model (Qwen2_5_VLForConditionalGeneration): The VLM model.
        messages (list): The messages.
        device (str): The device to run the model on.
    Returns:
        str: The output text.
    """
    text = vlm_processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
    
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    # Inference
    generated_ids = vlm_model.generate(**inputs, do_sample=True, num_beams=4, temperature=1.5, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = vlm_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return output_text
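
# Illustrative end-to-end sketch (device, dtype, and the input arrays
# image_target / image_reference / target_mask are assumptions):
#
#   processor, model = get_vlm("Qwen/Qwen2.5-VL-7B-Instruct", "cuda", torch.bfloat16)
#   messages = construct_vlm_gen_prompt(image_target, image_reference,
#                                       target_mask, "position-aware")
#   caption = run_vlm(processor, model, messages, "cuda")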


def get_ben2_model(ben2_model_path, device):
    """
    Get the BEN2 model.
    Args:
        ben2_model_path (str): The path to the BEN2 model.
        device (str): The device to load the model on.
    Returns:
        BEN2: The BEN2 model.
    """
    if not os.path.exists(ben2_model_path):
        ben2_model_path = hf_hub_download(repo_id="PramaLLC/BEN2", filename="BEN2_Base.pth")

    ben2_model = BEN2.BEN_Base().to(device)
    ben2_model.loadcheckpoints(model_path=ben2_model_path)
    return ben2_model
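
# Illustrative usage sketch. The `inference` call below follows the BEN2
# model card and is an assumption, not something defined in this module;
# the paths are hypothetical:
#
#   ben2_model = get_ben2_model("checkpoints/BEN2_Base.pth", "cuda")
#   foreground = ben2_model.inference(Image.open("reference.png").convert("RGB"))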


def make_dict_img_mask(img_path, mask_path):
    """
    Make a dictionary of the image and mask for gr.ImageEditor.
    Keep interface, not used in the gradio app.
    Args:
        img_path (str): The path to the image.
        mask_path (str): The path to the mask.
    Returns:
        dict: The dictionary of the image and mask.
    """
    from PIL import ImageOps
    background = Image.open(img_path).convert("RGBA")
    layers = [
        Image.merge("RGBA", (
            Image.new("L", Image.open(mask_path).size, 255),  # R channel
            Image.new("L", Image.open(mask_path).size, 255),  # G channel
            Image.new("L", Image.open(mask_path).size, 255),  # B channel
            ImageOps.invert(Image.open(mask_path).convert("L"))  # Inverted alpha channel
        ))
    ]
    # Composite: paint the layer's opaque pixels (alpha > 0) white over the
    # background and keep the background colors elsewhere.
    background = np.array(background.convert("RGB"))
    _, _, _, layer_alpha = layers[0].split()
    layer_alpha = np.array(layer_alpha)[:, :, np.newaxis]
    composite = background * (1 - (layer_alpha > 0)) + np.ones_like(background) * (layer_alpha > 0) * 255

    
    composite = Image.fromarray(composite.astype("uint8")).convert("RGBA")
    return {
        'background': background,
        'layers': layers,
        'composite': composite
    }