import torch
from mmgp import offload
from wan.utils.utils import calculate_new_dimensions
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.modules.layers import get_linear_split_map
from flux.util import load_ae, load_clip, load_flow_model, load_t5

from PIL import Image

def stitch_images(img1, img2):
    """Paste img2 to the right of img1, rescaling img2 to match img1's height."""
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)

    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
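
# A minimal sketch of stitch_images on hypothetical sizes: a 512x512 image
# stitched with a 128x256 image yields a 768x512 canvas, since the second
# image is first rescaled to 256x512 to match the first image's height:
#
#   left = Image.new("RGB", (512, 512))
#   right = Image.new("RGB", (128, 256))
#   assert stitch_images(left, right).size == (768, 512)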

class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename = None,
        model_type = None, 
        model_def = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,
        save_quantized = False,
        dtype = torch.bfloat16,
        VAE_dtype = torch.float32,
        mixed_precision_transformer = False
    ):
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        # generate() checks this flag; the host pipeline sets it to abort early
        self._interrupt = False
        torch_device = "cpu"
        
        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = model_def.get("flux-model", "flux-dev")
        source = model_def.get("source", None)
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)

        self.vae = load_ae(self.name, device=torch_device)

        if source is not None:
            # the weights came from an external source checkpoint: persist them
            # in the standard layout via the host wgp helper
            from wgp import save_model
            save_model(self.model, model_type, dtype, None)

        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

        # split large fused linear layers so the mmgp offloader can work on
        # smaller weight chunks
        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)

    def generate(
            self,
            seed: int | None = None,
            input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
            sampling_steps: int = 20,
            input_ref_images = None,
            width= 832,
            height=480,
            embedded_guidance_scale: float = 2.5,
            fit_into_canvas = None,
            callback = None,
            loras_slists = None,
            batch_size = 1,
            video_prompt_type = "",
            **bbargs
    ):
            
            if self._interrupt:
                return None

            device="cuda"
            if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0: 
                if "K" in video_prompt_type and False :
                    # image latents tiling method
                    w, h = input_ref_images[0].size
                    height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
                else:
                    # image stiching method
                    stiched = input_ref_images[0]
                    if "K" in video_prompt_type :
                        w, h = input_ref_images[0].size
                        height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)

                    for new_img in input_ref_images[1:]:
                        stiched = stitch_images(stiched, new_img)
                    input_ref_images  = [stiched]
            else:
                input_ref_images = None

            inp, height, width = prepare_kontext(
                t5=self.t5,
                clip=self.clip,
                prompt=input_prompt,
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
            )
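            # inp bundles the packed noise latents, the T5/CLIP text embeddings
            # and, when reference images are given, their encoded latents for
            # Kontext conditioning (inferred from the arguments above)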

            timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

            def unpack_latent(x):
                return unpack(x.float(), height, width)

            # denoise the initial noise
            x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, callback=callback, pipeline=self, loras_slists=loras_slists, unpack_latent=unpack_latent)
            if x is None:
                return None
            # decode latents to pixel space
            x = unpack_latent(x)
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                x = self.vae.decode(x)

            x = x.clamp(-1, 1)
            # swap batch and channel dims: (B, C, H, W) -> (C, B, H, W)
            x = x.transpose(0, 1)
            return x

def query_model_def(model_type, model_def):
    flux_model = model_def.get("flux-model", "flux-dev")
    flux_schnell = flux_model == "flux-schnell"
    model_def_output = {
        "image_outputs": True,
    }
    if flux_schnell:
        # the distilled schnell variant ignores guidance, so hide that option
        model_def_output["no_guidance"] = True

    return model_def_output
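
if __name__ == "__main__":
    # A minimal usage sketch, not part of the pipeline. The checkpoint paths,
    # file names and model_def contents below are assumptions for illustration;
    # in the real application they are supplied by the host wgp pipeline.
    factory = model_factory(
        checkpoint_dir="ckpts",                                  # hypothetical
        model_filename=["ckpts/flux1-dev-kontext.safetensors"],  # hypothetical
        model_def={"flux-model": "flux-dev-kontext"},
        text_encoder_filename="ckpts/T5_xxl.safetensors",        # hypothetical
    )
    ref = Image.open("reference.png").convert("RGB")             # hypothetical input
    out = factory.generate(
        seed=42,
        input_prompt="replace the logo with the text 'Black Forest Labs'",
        input_ref_images=[ref],
        video_prompt_type="I",
    )
    if out is not None:
        # out is (C, B, H, W) in [-1, 1]; map the first sample to [0, 255]
        import numpy as np
        frame = ((out[:, 0].permute(1, 2, 0).float() + 1) * 127.5).cpu().numpy().astype(np.uint8)
        Image.fromarray(frame).save("output.png")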