import argparse
import os
import time

import numpy as np
import torch
from accelerate import Accelerator
from diffusers.utils import load_image
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, strategy_flux, train_util
from library.device_utils import clean_memory_on_device
from library.utils import setup_logging
from networks import lora_flux

setup_logging()
import logging

logger = logging.getLogger(__name__)

|
|
def load_target_model( |
|
fp8_base: bool, |
|
pretrained_model_name_or_path: str, |
|
disable_mmap_load_safetensors: bool, |
|
clip_l_path: str, |
|
fp8_base_unet: bool, |
|
t5xxl_path: str, |
|
ae_path: str, |
|
weight_dtype: torch.dtype, |
|
accelerator: Accelerator |
|
): |
|
|
|
loading_dtype = None if fp8_base else weight_dtype |
|
|
|
|
|
_, model = flux_utils.load_flow_model( |
|
pretrained_model_name_or_path, |
|
|
|
torch.float8_e4m3fn, |
|
|
|
"cpu", |
|
disable_mmap=disable_mmap_load_safetensors |
|
) |
|
|
|
if fp8_base: |
|
|
|
if model.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: |
|
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") |
|
elif model.dtype == torch.float8_e4m3fn: |
|
logger.info("Loaded fp8 FLUX model") |
|
|
|
|
|
clip_l = flux_utils.load_clip_l( |
|
clip_l_path, |
|
weight_dtype, |
|
|
|
"cpu", |
|
disable_mmap=disable_mmap_load_safetensors |
|
) |
|
clip_l.eval() |
|
|
|
|
|
if fp8_base and not fp8_base_unet: |
|
loading_dtype_t5xxl = None |
|
else: |
|
loading_dtype_t5xxl = weight_dtype |
|
|
|
|
|
t5xxl = flux_utils.load_t5xxl( |
|
t5xxl_path, |
|
loading_dtype_t5xxl, |
|
|
|
"cpu", |
|
disable_mmap=disable_mmap_load_safetensors |
|
) |
|
t5xxl.eval() |
|
|
|
if fp8_base and not fp8_base_unet: |
|
|
|
if t5xxl.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: |
|
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") |
|
elif t5xxl.dtype == torch.float8_e4m3fn: |
|
logger.info("Loaded fp8 T5XXL model") |
|
|
|
|
|
ae = flux_utils.load_ae( |
|
ae_path, |
|
weight_dtype, |
|
|
|
"cpu", |
|
disable_mmap=disable_mmap_load_safetensors |
|
) |
|
|
|
|
|
|
|
|
|
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model |
|
|
|
|
|
import torchvision.transforms as transforms |
|
|
|
|
class ResizeWithPadding:
    """Pad a non-square image to a square with a solid fill color, then resize to `size`."""

    def __init__(self, size, fill=255):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image or a NumPy array")

        width, height = img.size
        if width == height:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        else:
            # paste the image centered on a square canvas before resizing
            max_dim = max(width, height)
            new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
            new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
            img = new_img.resize((self.size, self.size), Image.LANCZOS)

        return img

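# Minimal usage sketch (hypothetical file name): a 300x200 image is centered on a
# white 300x300 canvas, then resized, so the result is always size x size:
#
#   pad = ResizeWithPadding(size=512, fill=255)
#   square = pad(Image.open("example.jpg"))  # -> 512x512 PIL.Image
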

def sample(args, accelerator, vae, text_encoder, flux, output_dir, sample_images, sample_prompts):
    def encode_images_to_latents(vae, images):
        # split a horizontal strip of frames into equal-width bands, encode each band
        # separately, and concatenate the latents along the width axis
        # (kept for reference; the single-pass variant below is the one used here)
        b, c, h, w = images.shape
        num_split = 2 if args.frame_num == 4 else 3
        img_parts = [images[:, :, :, i * w // num_split:(i + 1) * w // num_split] for i in range(num_split)]
        latents = [vae.encode(img) for img in img_parts]
        return torch.cat(latents, dim=-1)

    def encode_images_to_latents2(vae, images):
        # encode the whole conditioning image in a single pass
        return vae.encode(images)
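    # Shape sketch for the split variant above, assuming the standard FLUX AE
    # (8x spatial downsampling, 16 latent channels): a (1, 3, 512, 1024) strip with
    # num_split=2 becomes two (1, 3, 512, 512) bands, each encoded to
    # (1, 16, 64, 64); concatenating on dim=-1 restores (1, 16, 64, 128), the same
    # shape as a single-pass encode, but with each band encoded independently.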
    # cache the AE latents of each conditioning image, keyed by its prompt
    conditions = {}
    with torch.no_grad():
        for image_path, prompt_dict in zip(sample_images, sample_prompts):
            prompt = prompt_dict.get("prompt", "")
            if prompt not in conditions:
                logger.info(f"Cache conditions for image: {image_path} with prompt: {prompt}")
                resize_transform = ResizeWithPadding(size=512 if args.frame_num == 4 else 352, fill=255)
                img_transforms = transforms.Compose([
                    resize_transform,
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ])

                image = img_transforms(np.array(load_image(image_path), dtype=np.uint8)).unsqueeze(0).to(
                    vae.device, dtype=vae.dtype
                )
                latents = encode_images_to_latents2(vae, image)
                logger.debug(f"Encoded latents shape for prompt '{prompt}': {latents.shape}")
                conditions[prompt] = latents.to("cpu")

    # cache Text Encoder outputs once per unique prompt / negative prompt
    sample_prompts_te_outputs = {}
    text_encoder[0].to(accelerator.device)
    text_encoder[1].to(accelerator.device)

    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
    text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)

    with accelerator.autocast(), torch.no_grad():
        for prompt_dict in sample_prompts:
            for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                if p not in sample_prompts_te_outputs:
                    logger.info(f"Cache Text Encoder outputs for prompt: {p}")
                    tokens_and_masks = tokenize_strategy.tokenize(p)
                    sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                        tokenize_strategy, text_encoder, tokens_and_masks, True
                    )

    logger.info("Generating images")
    save_dir = output_dir
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad(), accelerator.autocast():
        for prompt_dict in sample_prompts:
            sample_image_inference(
                args,
                accelerator,
                flux,
                text_encoder,
                vae,
                save_dir,
                prompt_dict,
                sample_prompts_te_outputs,
                None,  # prompt_replacement
                conditions,
            )

    clean_memory_on_device(accelerator.device)

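# Keys read from each prompt entry by sample_image_inference below (anything else
# produced by train_util.load_prompts is ignored); the values shown are the
# defaults this script assumes:
#
#   {
#       "prompt": "...",          # also keys the cached AE latents above
#       "negative_prompt": "",    # encoded and cached, but not passed to denoise here
#       "width": 1024,            # 1056 when --frame_num 9
#       "height": 1024,           # 1056 when --frame_num 9
#       "sample_steps": 20,
#       "scale": 1.0,             # guidance scale
#       "seed": None,             # optional int
#       "enum": 0,                # index used in the output filename
#   }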
def sample_image_inference(
    args,
    accelerator: Accelerator,
    flux: flux_models.Flux,
    text_encoder,
    ae: flux_models.AutoEncoder,
    save_dir,
    prompt_dict,
    sample_prompts_te_outputs,
    prompt_replacement,
    sample_images_ae_outputs,
):
    sample_steps = prompt_dict.get("sample_steps", 20)
    # 4-frame grids default to 1024x1024, 9-frame grids to 1056x1056
    default_size = 1024 if args.frame_num == 4 else 1056
    width = prompt_dict.get("width", default_size)
    height = prompt_dict.get("height", default_size)
    scale = prompt_dict.get("scale", 1.0)
    seed = prompt_dict.get("seed")
    prompt: str = prompt_dict.get("prompt", "")

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    else:
        # no seed given: reseed the RNGs non-deterministically
        torch.seed()
        torch.cuda.seed()

    # FLUX packs latents into 2x2 patches, so the image size must be a multiple of 16
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)
    logger.info(f"prompt: {prompt}")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    logger.info(f"scale: {scale}")
    if seed is not None:
        logger.info(f"seed: {seed}")

    text_encoder_conds = []
    if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
        text_encoder_conds = sample_prompts_te_outputs[prompt]
        logger.info(f"Using cached text encoder outputs for prompt: {prompt}")

    if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
        ae_outputs = sample_images_ae_outputs[prompt]
    else:
        ae_outputs = None

    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
    logger.debug(
        f"l_pooled shape: {l_pooled.shape}, t5_out shape: {t5_out.shape}, "
        f"txt_ids shape: {txt_ids.shape}, t5_attn_mask shape: {t5_attn_mask.shape}"
    )

    weight_dtype = ae.dtype
    packed_latent_height = height // 16
    packed_latent_width = width // 16
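    # Packed-latent geometry: the AE downsamples by 8 and packing trades a 2x2
    # latent patch for one token, so each token covers 16x16 pixels. For a
    # 1024x1024 sample that is 64 * 64 = 4096 tokens, each of dimension
    # 16 channels * 2 * 2 = 64, matching the noise shape below.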
    logger.debug(f"packed_latent_height: {packed_latent_height}, packed_latent_width: {packed_latent_width}")

    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=accelerator.device,
        dtype=weight_dtype,
        generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
    )

    timesteps = flux_train_utils.get_schedule(sample_steps, noise.shape[1], shift=True)
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(
        accelerator.device, dtype=weight_dtype
    )
    t5_attn_mask = t5_attn_mask.to(accelerator.device)

    # offload the text encoders before moving the flow model onto the GPU
    clip_l, t5xxl = text_encoder
    clip_l.to("cpu")
    t5xxl.to("cpu")
    clean_memory_on_device(accelerator.device)
    flux.to(accelerator.device)

    # inference only: freeze the flow model
    for param in flux.parameters():
        param.requires_grad = False

    with accelerator.autocast(), torch.no_grad():
        x = flux_train_utils.denoise(
            args, flux, noise, img_ids, t5_out, txt_ids, l_pooled,
            timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs,
        )
    logger.debug(f"x shape after denoise: {x.shape}")

    x = x.float()
    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)

    # decode on the GPU, then offload the AE again
    ae.to(accelerator.device)
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    ae.to("cpu")
    clean_memory_on_device(accelerator.device)

    # map [-1, 1] model output to a uint8 image
    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict.get("enum", 0)
    img_filename = f"{ts_str}{seed_suffix}_{i}.png"
    image.save(os.path.join(save_dir, img_filename))

def setup_argparse():
    parser = argparse.ArgumentParser(description="FLUX LoRA inference script for step-diagram generation")

    parser.add_argument('--base_flux_checkpoint', type=str, required=True,
                        help='Path to the base FLUX model checkpoint')
    parser.add_argument('--lora_weights_path', type=str, required=True,
                        help='Path to the LoRA weights (.safetensors)')
    parser.add_argument('--clip_l_path', type=str, required=True,
                        help='Path to the CLIP-L text encoder checkpoint')
    parser.add_argument('--t5xxl_path', type=str, required=True,
                        help='Path to the T5-XXL text encoder checkpoint')
    parser.add_argument('--ae_path', type=str, required=True,
                        help='Path to the FLUX autoencoder checkpoint')
    parser.add_argument('--sample_images_file', type=str, required=True,
                        help="Text file with one conditioning image path per line ('#' lines are skipped)")
    parser.add_argument('--sample_prompts_file', type=str, required=True,
                        help='Prompt file in the format accepted by train_util.load_prompts')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the generated images')
    parser.add_argument('--frame_num', type=int, choices=[4, 9], required=True,
                        help="The number of steps in the generated step diagram (choose 4 or 9)")

    return parser.parse_args()

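# Example invocation (hypothetical script and checkpoint paths; adjust to your setup):
#
#   python flux_inference.py \
#       --base_flux_checkpoint models/flux1-dev.safetensors \
#       --lora_weights_path models/step_diagram_lora.safetensors \
#       --clip_l_path models/clip_l.safetensors \
#       --t5xxl_path models/t5xxl_fp16.safetensors \
#       --ae_path models/ae.safetensors \
#       --sample_images_file sample_images.txt \
#       --sample_prompts_file sample_prompts.txt \
#       --output_dir outputs \
#       --frame_num 4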
def main(args):
    accelerator = Accelerator(mixed_precision='bf16', device_placement=True)

    # read the conditioning image list, skipping blank lines and '#' comments
    with open(args.sample_images_file, "r", encoding="utf-8") as f:
        sample_images = [line.strip() for line in f if line.strip() and not line.strip().startswith("#")]

    sample_prompts = train_util.load_prompts(args.sample_prompts_file)

    _, [clip_l, t5xxl], ae, model = load_target_model(
        fp8_base=True,
        pretrained_model_name_or_path=args.base_flux_checkpoint,
        disable_mmap_load_safetensors=False,
        clip_l_path=args.clip_l_path,
        fp8_base_unet=False,
        t5xxl_path=args.t5xxl_path,
        ae_path=args.ae_path,
        weight_dtype=torch.bfloat16,
        accelerator=accelerator,
    )

    model.eval()
    clip_l.eval()
    t5xxl.eval()
    ae.eval()

    # build the LoRA network from the saved weights and attach it to the models
    multiplier = 1.0
    weights_sd = load_file(args.lora_weights_path)
    lora_model, _ = lora_flux.create_network_from_weights(
        multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True
    )
    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {args.lora_weights_path}: {info}")
    lora_model.eval()
    lora_model.to(accelerator.device)

    text_encoder = [clip_l, t5xxl]

    sample(args, accelerator, vae=ae, text_encoder=text_encoder, flux=model,
           output_dir=args.output_dir, sample_images=sample_images, sample_prompts=sample_prompts)


if __name__ == "__main__":
    args = setup_argparse()
    main(args)