import os import re import time from dataclasses import dataclass from glob import iglob from mmgp import offload as offload import torch from wan.utils.utils import calculate_new_dimensions from flux.sampling import denoise, get_schedule, prepare_kontext, unpack from flux.modules.layers import get_linear_split_map from flux.util import ( aspect_ratio_to_height_width, load_ae, load_clip, load_flow_model, load_t5, save_image, ) from PIL import Image def stitch_images(img1, img2): # Resize img2 to match img1's height width1, height1 = img1.size width2, height2 = img2.size new_width2 = int(width2 * height1 / height2) img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS) stitched = Image.new('RGB', (width1 + new_width2, height1)) stitched.paste(img1, (0, 0)) stitched.paste(img2_resized, (width1, 0)) return stitched class model_factory: def __init__( self, checkpoint_dir, model_filename = None, model_type = None, model_def = None, base_model_type = None, text_encoder_filename = None, quantizeTransformer = False, save_quantized = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False ): self.device = torch.device(f"cuda") self.VAE_dtype = VAE_dtype self.dtype = dtype torch_device = "cpu" # model_filename = ["c:/temp/flux1-schnell.safetensors"] self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512) self.clip = load_clip(torch_device) self.name = model_def.get("flux-model", "flux-dev") # self.name= "flux-dev-kontext" # self.name= "flux-dev" # self.name= "flux-schnell" source = model_def.get("source", None) self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device) self.vae = load_ae(self.name, device=torch_device) # offload.change_dtype(self.model, dtype, True) # offload.save_model(self.model, "flux-dev.safetensors") if not source is None: from wgp import save_model save_model(self.model, model_type, dtype, None) if save_quantized: from wgp import save_quantized_model save_quantized_model(self.model, model_type, model_filename[0], dtype, None) split_linear_modules_map = get_linear_split_map() self.model.split_linear_modules_map = split_linear_modules_map offload.split_linear_modules(self.model, split_linear_modules_map ) def generate( self, seed: int | None = None, input_prompt: str = "replace the logo with the text 'Black Forest Labs'", sampling_steps: int = 20, input_ref_images = None, width= 832, height=480, embedded_guidance_scale: float = 2.5, fit_into_canvas = None, callback = None, loras_slists = None, batch_size = 1, video_prompt_type = "", **bbargs ): if self._interrupt: return None device="cuda" if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0: if "K" in video_prompt_type and False : # image latents tiling method w, h = input_ref_images[0].size height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) else: # image stiching method stiched = input_ref_images[0] if "K" in video_prompt_type : w, h = input_ref_images[0].size height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] else: input_ref_images = None inp, height, width = prepare_kontext( t5=self.t5, clip=self.clip, prompt=input_prompt, ae=self.vae, img_cond_list=input_ref_images, target_width=width, target_height=height, bs=batch_size, seed=seed, device=device, ) timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")) def unpack_latent(x): return unpack(x.float(), height, width) # denoise initial noise x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent) if x==None: return None # decode latents to pixel space x = unpack_latent(x) with torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) x = x.clamp(-1, 1) x = x.transpose(0, 1) return x def query_model_def(model_type, model_def): flux_model = model_def.get("flux-model", "flux-dev") flux_schnell = flux_model == "flux-schnell" model_def_output = { "image_outputs" : True, } if flux_schnell: model_def_output["no_guidance"] = True return model_def_output