import os, sys, time
import math
import torch
import spaces
import numpy as np
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps

from my_utils.solvers import gurobi_solver


def get_next_size(curr_size, final_size, keep_ratio):
    """Calculate next size for progressive pruning during denoising.
    
    Args:
        curr_size: Current number of candidates
        final_size: Target final size
        keep_ratio: Fraction of candidates to keep at each step
    """
    if curr_size < final_size:
        raise ValueError(f"curr_size ({curr_size}) must be at least final_size ({final_size})!")
    elif curr_size == final_size:
        return curr_size
    else:
        next_size = math.ceil(curr_size * keep_ratio)
        return max(next_size, final_size)
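

# For example, with keep_ratio = 0.75 (i.e. pruning_ratio = 0.25) and final_size = 4, a pool of
# 16 candidates shrinks as 16 -> 12 -> 9 -> 7 -> 6 -> 5 -> 4 over successive calls, then stays at 4.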


@torch.no_grad()
def decode_latent(z, pipe, height, width):
    """Decode latent tensor to image using VAE decoder.
    
    Args:
        z: Latent tensor to decode
        pipe: Diffusion pipeline with VAE
        height: Image height
        width: Image width
    """
    z = pipe._unpack_latents(z, height, width, pipe.vae_scale_factor)
    z = (z / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    z = pipe.vae.decode(z, return_dict=False)[0].clamp(-1,1)
    return z


@torch.no_grad()
@spaces.GPU(duration=300)
def run_group_inference(pipe, model_name=None, prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, 
        true_cfg_scale=1.0, height=None, width=None, num_inference_steps=28, sigmas=None, guidance_scale=3.5, 
        l_generator=None, max_sequence_length=512,
        # group inference arguments
        unary_score_fn=None, binary_score_fn=None,
        starting_candidates=None, output_group_size=None, pruning_ratio=None, lambda_score=None,
        # control arguments
        control_image=None,
        # input image for flux-kontext
        input_image=None,
        skip_first_cfg=True
):
    """Run group inference with progressive pruning for diverse, high-quality image generation.
    
    Args:
        pipe: Diffusion pipeline
        model_name: Model type (flux-schnell, flux-dev, flux-depth, flux-canny, flux-kontext)
        prompt: Text prompt for generation
        unary_score_fn: Function to compute image quality scores
        binary_score_fn: Function to compute pairwise diversity scores  
        starting_candidates: Initial number of noise samples
        output_group_size: Final number of images to generate
        pruning_ratio: Fraction to prune at each denoising step
        lambda_score: Weight between quality and diversity terms
        control_image: Control image for depth/canny models
        input_image: Input image for flux-kontext editing
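        l_generator: Optional list of torch.Generator objects, one per starting candidate
        skip_first_cfg: If True, use a guidance value of 1.0 for the first denoising step
            (only applies when the transformer uses guidance embeddings)

    Returns:
        List of PIL images for the selected output group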
    """
    if l_generator is None:
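        # default: one CPU generator per starting candidate, seeded 42, 43, ... for reproducible noise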
        l_generator = [torch.Generator("cpu").manual_seed(42+_seed) for _seed in range(starting_candidates)]

    # use the default height and width if not provided
    height = height or pipe.default_sample_size * pipe.vae_scale_factor
    width = width or pipe.default_sample_size * pipe.vae_scale_factor

    pipe._guidance_scale = guidance_scale
    pipe._current_timestep = None
    pipe._interrupt = False
    pipe._joint_attention_kwargs = {}

    device = pipe._execution_device

    lora_scale = None
    has_neg_prompt = negative_prompt is not None
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

    # 3. Encode prompts
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(prompt=prompt, prompt_2=prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, device=device, max_sequence_length=max_sequence_length, lora_scale=lora_scale)
    
    if do_true_cfg:
        negative_prompt_embeds, negative_pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=negative_prompt, prompt_2=negative_prompt_2, prompt_embeds=None, pooled_prompt_embeds=None, device=device, max_sequence_length=max_sequence_length, lora_scale=lora_scale)

    # 4. Prepare latent variables
    if model_name in ["flux-depth", "flux-canny"]:
        # for control models, the pipe.transformer.config.in_channels is doubled
        num_channels_latents = pipe.transformer.config.in_channels // 8
    else:
        num_channels_latents = pipe.transformer.config.in_channels // 4
    
    # Handle different model types
    image_latents = None
    image_ids = None
    if model_name == "flux-kontext":
        processed_image = pipe.image_processor.preprocess(input_image, height=height, width=width)
        l_latents = []
        for _gen in l_generator:
            latents, img_latents, latent_ids, img_ids = pipe.prepare_latents(
                processed_image, 1, num_channels_latents, height, width, 
                prompt_embeds.dtype, device, _gen
            )
            l_latents.append(latents)
        # Use the image_latents and image_ids from the first generator
        _, image_latents, latent_image_ids, image_ids = pipe.prepare_latents(
            processed_image, 1, num_channels_latents, height, width, 
            prompt_embeds.dtype, device, l_generator[0]
        )
        # Combine latent_ids with image_ids
        if image_ids is not None:
            latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0)
    else:
        # For other models (flux-schnell, flux-dev, flux-depth, flux-canny)
        l_latents = [pipe.prepare_latents(1, num_channels_latents, height, width, prompt_embeds.dtype, device, _gen)[0] for _gen in l_generator]
        _, latent_image_ids = pipe.prepare_latents(1, num_channels_latents, height, width, prompt_embeds.dtype, device, l_generator[0])
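        # the positional ids are deterministic (they do not depend on the generator),
        # so a single latent_image_ids tensor is shared by all candidates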
    
    # 4.5. Prepare control image if provided
    control_latents = None
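    # control latents are VAE-encoded and packed into the same (batch, sequence, channels) layout as
    # the noise latents so they can be concatenated channel-wise with each candidate in the loop below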
    if model_name in ["flux-depth", "flux-canny"]:
        control_image_processed = pipe.prepare_image(image=control_image, width=width, height=height, batch_size=1, num_images_per_prompt=1, device=device, dtype=pipe.vae.dtype,)
        if control_image_processed.ndim == 4:
            control_latents = pipe.vae.encode(control_image_processed).latents
            control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
            height_control_image, width_control_image = control_latents.shape[2:]
            control_latents = pipe._pack_latents(control_latents, 1, num_channels_latents, height_control_image, width_control_image)

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latent_image_ids.shape[0]
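    # mu is the resolution-dependent timestep shift for the flow-matching scheduler,
    # derived from the latent sequence length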
    mu = calculate_shift(image_seq_len, pipe.scheduler.config.get("base_image_seq_len", 256), pipe.scheduler.config.get("max_image_seq_len", 4096), pipe.scheduler.config.get("base_shift", 0.5), pipe.scheduler.config.get("max_shift", 1.15))
    timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
    pipe._num_timesteps = len(timesteps)
    _dtype = l_latents[0].dtype

    # handle guidance
    if pipe.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(1)
    else:
        guidance = None
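    # a guidance value of 1.0 is used for the first denoising step when skip_first_cfg is enabled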
    guidance_1 = torch.full([1], 1.0, device=device, dtype=torch.float32).expand(1)

    # 6. Denoising loop
    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipe.interrupt:
                continue
            if guidance is not None and skip_first_cfg and i == 0:
                curr_guidance = guidance_1
            else:
                curr_guidance = guidance

            pipe._current_timestep = t
            timestep = t.expand(1).to(_dtype)
            next_latents = []
            x0_preds = []
            # do 1 denoising step
            for _latent in l_latents:
                # prepare input for transformer based on model type
                if model_name in ["flux-depth", "flux-canny"]:
                    # Control models: concatenate control latents along dim=2
                    latent_model_input = torch.cat([_latent, control_latents], dim=2)
                elif model_name == "flux-kontext":
                    # Kontext model: concatenate image latents along dim=1
                    latent_model_input = torch.cat([_latent, image_latents], dim=1)
                else:
                    # Standard models (flux-schnell, flux-dev): use latents as is
                    latent_model_input = _latent
                    
                noise_pred = pipe.transformer(hidden_states=latent_model_input, timestep=timestep / 1000, guidance=curr_guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=pipe.joint_attention_kwargs, return_dict=False)[0]
                
                # For flux-kontext, we need to slice the noise_pred to match the latents size
                if model_name == "flux-kontext":
                    noise_pred = noise_pred[:, : _latent.size(1)]
                
                if do_true_cfg:
                    neg_noise_pred = pipe.transformer(hidden_states=latent_model_input, timestep=timestep / 1000, guidance=curr_guidance, pooled_projections=negative_pooled_prompt_embeds, encoder_hidden_states=negative_prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=pipe.joint_attention_kwargs, return_dict=False)[0]
                    if model_name == "flux-kontext":
                        neg_noise_pred = neg_noise_pred[:, : _latent.size(1)]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                # compute the previous noisy sample x_t -> x_t-1
                _latent = pipe.scheduler.step(noise_pred, t, _latent, return_dict=False)[0]
                # the scheduler is not stateless: it keeps an internal step index that is incremented
                # after every call to step(). All candidates share one scheduler at this timestep, so
                # undo the increment here and advance it once after the loop.
                if hasattr(pipe.scheduler, "_step_index"):
                    pipe.scheduler._step_index -= 1
                
                # estimate the clean sample x0 by stepping the predicted velocity all the way to sigma = 0
                if isinstance(pipe.scheduler, FlowMatchEulerDiscreteScheduler):
                    dt = 0.0 - pipe.scheduler.sigmas[i]
                    x0_pred = _latent + dt * noise_pred
                else:
                    raise NotImplementedError("Only Flow Scheduler is supported for now! For other schedulers, you need to manually implement the x0 prediction step.")
                
                x0_preds.append(x0_pred)
                next_latents.append(_latent)
            
            # all candidates have been stepped; now advance the shared scheduler's step index once
            # for this timestep (undoing the per-candidate decrements above)
            if hasattr(pipe.scheduler, "_step_index"):
                pipe.scheduler._step_index += 1
            
            # if the size of next_latents > output_group_size, prune the latents
            if len(next_latents) > output_group_size:
                next_size = get_next_size(len(next_latents), output_group_size, 1 - pruning_ratio)
                print(f"Pruning from {len(next_latents)} to {next_size}")
                # decode the x0 predictions to pixel space so they can be scored
                l_x0_decoded = [decode_latent(_latent, pipe, height, width) for _latent in x0_preds]
                # compute the unary and binary scores
                l_unary_scores = unary_score_fn(l_x0_decoded, target_caption=prompt)
                M_binary_scores = binary_score_fn(l_x0_decoded) # upper triangular matrix
                # select the subset to keep with a Quadratic Integer Programming solver,
                # trading off quality (unary) and diversity (binary) via lambda_score
                t_start = time.time()
                selected_indices = gurobi_solver(l_unary_scores, M_binary_scores, next_size, lam=lambda_score)
                t_end = time.time()
                print(f"Time taken for QIP: {t_end - t_start:.2f} seconds")
                l_latents = [next_latents[_i] for _i in selected_indices]
            else:
                l_latents = next_latents
            
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    pipe._current_timestep = None

    l_images = [pipe._unpack_latents(_latent, height, width, pipe.vae_scale_factor) for _latent in l_latents]
    l_images = [(latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor for latents in l_images]
    l_images = [pipe.vae.decode(_image, return_dict=False)[0] for _image in l_images]
    l_images_tensor = [image.clamp(-1, 1) for image in l_images]  # Keep tensor version for scoring
    l_images = [pipe.image_processor.postprocess(image, output_type="pil")[0] for image in l_images]
    
    # Compute and print final scores
    print(f"\n=== Final Scores for {len(l_images)} generated images ===")
    
    # Compute unary scores
    final_unary_scores = unary_score_fn(l_images_tensor, target_caption=prompt)
    print(f"Unary scores (quality): {final_unary_scores}")
    print(f"Mean unary score: {np.mean(final_unary_scores):.4f}")
    
    # Compute binary scores
    final_binary_scores = binary_score_fn(l_images_tensor)
    print("Binary score matrix (diversity):")
    print(final_binary_scores)
    
    print("=" * 50)
    
    pipe.maybe_free_model_hooks()
    return l_images
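

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; the model id, score functions, and
# argument values below are assumptions, not part of this module). Any score
# functions matching the signatures used above will work: unary_score_fn takes
# a list of image tensors plus target_caption and returns per-image quality
# scores; binary_score_fn returns a pairwise diversity matrix.
#
#   from diffusers import FluxPipeline
#   from my_utils.scores import clip_unary_score, dino_binary_score  # hypothetical helpers
#
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   images = run_group_inference(
#       pipe,
#       model_name="flux-schnell",
#       prompt="a photo of a corgi surfing a wave",
#       num_inference_steps=4,
#       guidance_scale=0.0,
#       unary_score_fn=clip_unary_score,
#       binary_score_fn=dino_binary_score,
#       starting_candidates=16,
#       output_group_size=4,
#       pruning_ratio=0.25,
#       lambda_score=0.5,
#   )
#   for idx, img in enumerate(images):
#       img.save(f"group_{idx}.png")
# ---------------------------------------------------------------------------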