File size: 4,375 Bytes
f08d17a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import argparse
from PIL import Image
import openai
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
from utils import encode_image_to_datauri, cot_with_gpt, extract_instructions, infer_with_DiT, roi_localization, fusion
from src.flux.generate import generate, seed_everything

# Edit categories grouped by the pipeline each one is routed through.
_ROI_EDIT_CATEGORIES = (
    'Add',
    'Appearance Change',
    'Background Change',
    'Color Change',
    'Material Change',
    'Expression Change',
)
_INPAINT_CATEGORIES = ('Remove', 'Replace')
_LAYOUT_CATEGORIES = ('Move', 'Resize')
_GLOBAL_CATEGORIES = ('Tone Transfer', 'Style Change')


def _apply_edit(category, image, instruction):
    """Run one edit step of the given category and return the edited image.

    Args:
        category: Edit category name for this step.
        image: Path of the image this step starts from.
        instruction: Instruction text for this step.

    Returns:
        Whatever the final `infer_with_DiT` call of the selected pipeline
        returns (the caller invokes `.save()` on it).

    Raises:
        ValueError: If `category` is not one of the supported categories.
    """
    if category in _ROI_EDIT_CATEGORIES:
        return infer_with_DiT('RoI Editing', image, instruction, category)

    if category in _GLOBAL_CATEGORIES:
        return infer_with_DiT('Global Transformation', image, instruction, category)

    if category in _INPAINT_CATEGORIES:
        # RoI Localization, then RoI Inpainting on the localized mask.
        mask_image = roi_localization(image, instruction, category)
        return infer_with_DiT('RoI Inpainting', mask_image, instruction, category)

    if category == 'Action Change':
        # Inpaint with the 'Remove' category first, then edit the instance
        # separately; the two results are fused and composited below.
        mask_image = roi_localization(image, instruction, category)
        background = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
        changed_instance, x0, y1, scale = infer_with_DiT('RoI Editing', image, instruction, category)
    elif category in _LAYOUT_CATEGORIES:
        # For Move/Resize, localization itself also yields the instance and
        # its placement values, so no separate editing call is made.
        mask_image, changed_instance, x0, y1, scale = roi_localization(image, instruction, category)
        background = infer_with_DiT('RoI Inpainting', mask_image, instruction, 'Remove')
    else:
        raise ValueError(f"Invalid category: '{category}'")

    fusion_image = fusion(background, changed_instance, x0, y1, scale)
    if category in _LAYOUT_CATEGORIES:
        # Intermediate fused image is written to the working directory,
        # overwritten on every Move/Resize step.
        fusion_image.save("fusion.png")
    # RoI Compositioning on the fused image.
    return infer_with_DiT('RoI Compositioning', fusion_image, instruction, None)


def main():
    """Command-line entry point: split an instruction into steps and apply them.

    Parses `image_path`, `prompt` and `--seed` from the command line, asks
    `cot_with_gpt` for a list of edit categories and matching instructions,
    then applies the steps one after another. Step `i` writes
    `results/{i}.png`, and every step after the first reads the previous
    step's output file as its input.

    Raises:
        ValueError: If the OPENAI_API_KEY environment variable is not set,
            if the numbers of categories and instructions differ, or if a
            category is not supported.
    """
    parser = argparse.ArgumentParser(description="Evaluate single image + instruction using GPT-4o")
    parser.add_argument("image_path", help="Path to input image")
    parser.add_argument("prompt", help="Original instruction")
    parser.add_argument("--seed", type=int, default=3407, help="Random seed for reproducibility")
    args = parser.parse_args()

    seed_everything(args.seed)

    # Read the key from the environment instead of hard-coding it in the
    # source, so the check below can actually fail when no key is provided.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable not set.")
    openai.api_key = api_key

    os.makedirs("results", exist_ok=True)

    ###########################################
    ###         CoT -> instructions         ###
    ###########################################

    uri = encode_image_to_datauri(args.image_path)
    categories, instructions = cot_with_gpt(uri, args.prompt)
    print(categories)
    print(instructions)

    # Fail with a clear message up front rather than with an IndexError in
    # the middle of the run, after earlier steps have already executed.
    if len(categories) != len(instructions):
        raise ValueError(
            f"Got {len(categories)} categories but {len(instructions)} instructions"
        )

    ###########################################
    ###      Neural Program Interpreter     ###
    ###########################################
    for i, (category, instruction) in enumerate(zip(categories, instructions)):
        # The first step starts from the user's image; later steps chain on
        # the file written by the previous step.
        image = args.image_path if i == 0 else f"results/{i-1}.png"
        edited_image = _apply_edit(category, image, instruction)
        edited_image.save(f"results/{i}.png")


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()