import os from dataclasses import dataclass from typing import Union, Optional import torch from huggingface_hub import hf_hub_download from accelerate.logging import get_logger from accelerate import state from safetensors import safe_open from safetensors.torch import load_file as load_sft from safetensors.torch import save_file as save_sft from ..models.model import Flux, FluxParams, IC_Custom from ..modules.autoencoder import AutoEncoder, AutoEncoderParams from ..modules.conditioner import HFEmbedder from ..modules.image_embedders import ReduxImageEncoder from .process_util import print_load_warning # Initialize logger with a try-except to handle cases where accelerate state isn't initialized if state.is_initialized(): logger = get_logger(__name__, log_level="INFO") else: import logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # ------------------------------------------------------------------------- # 1) model definition # ------------------------------------------------------------------------- DIT_PARAMS = FluxParams( in_channels=384, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ) AE_PARAMS = AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ) @dataclass class HFModelSpec: repo_id: str filename: Optional[str] = None ckpt_path: Optional[str] = None configs = { "flux-fill-dev-dit": HFModelSpec( repo_id="black-forest-labs/FLUX.1-Fill-dev", filename="flux1-fill-dev.safetensors", ckpt_path=os.getenv("FLUX_DEV_FILL"), ), "flux-fill-dev-ae": HFModelSpec( repo_id="black-forest-labs/FLUX.1-Fill-dev", filename="ae.safetensors", ckpt_path=os.getenv("AE"), ), "t5-v1_1-xxl": HFModelSpec( repo_id="DeepFloyd/t5-v1_1-xxl", ckpt_path=os.getenv("T5_XXL"), ), "clip-vit-large-patch14": HFModelSpec( repo_id="openai/clip-vit-large-patch14", ckpt_path=os.getenv("CLIP_VIT_LARGE_PATCH14"), ), "siglip-so400m-patch14-384": HFModelSpec( repo_id="google/siglip-so400m-patch14-384", ckpt_path=os.getenv("SIGLIP_SO400M_PATCH14_384"), ), "flux1-redux-dev": HFModelSpec( repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", ckpt_path=os.getenv("FLUX1_REDUX_DEV"), ), "dit_lora_0x1561": HFModelSpec( repo_id="TencentARC/IC-Custom", filename="dit_lora_0x1561.safetensors", ckpt_path=os.getenv("DIT_LORA"), ), "dit_txt_img_in_0x1561": HFModelSpec( repo_id="TencentARC/IC-Custom", filename="dit_txt_img_in_0x1561.safetensors", ckpt_path=os.getenv("DIT_TXT_IMG_IN"), ), "dit_boundary_embeddings_0x1561": HFModelSpec( repo_id="TencentARC/IC-Custom", filename="dit_boundary_embeddings_0x1561.safetensors", ckpt_path=os.getenv("DIT_BOUNDARY_EMBEDDINGS"), ), "dit_task_register_embeddings_0x1561": HFModelSpec( repo_id="TencentARC/IC-Custom", filename="dit_task_register_embeddings_0x1561.safetensors", ckpt_path=os.getenv("DIT_TASK_REGISTER_EMBEDDINGS"), ) } # ------------------------------------------------------------------------- # 2) load model func # ------------------------------------------------------------------------- def resolve_model_path( name: str, repo_id_field: str = "repo_id", filename_field: str = "filename", ckpt_path_field: str = "ckpt_path", hf_download: bool = True, ) -> str: """ Resolve a model path from name, handling local paths, config paths, and HF downloads. Args: name: Model name or path repo_id_field: Field name in configs for repo_id filename_field: Field name in configs for filename (if download needed) ckpt_path_field: Field name in configs for checkpoint path hf_download: Whether to download from HF if not found locally replace_suffix: Whether to replace suffix in filename suffix_map: Mapping of suffixes to replace Explanation: 1) Resolve from CLI 2) Resolve from ENV 3) Resolve from online HF Returns: Resolved path to the model """ # If it's a direct path, return it if os.path.exists(name): return name # Try to get from configs if name in configs: # Get local path from configs path = getattr(configs[name], ckpt_path_field) # If local path exists, use it if path is not None and os.path.exists(path): return path # If download is allowed and we have repo info if (hf_download and hasattr(configs[name], repo_id_field) and getattr(configs[name], repo_id_field) is not None): # If we need a specific file (not just the repo) if filename_field and hasattr(configs[name], filename_field): filename = getattr(configs[name], filename_field) # Download the file logger.info(f"Downloading {getattr(configs[name], repo_id_field)}/{filename}") return hf_hub_download( getattr(configs[name], repo_id_field), filename, ) # If we just need the repo ID return getattr(configs[name], repo_id_field) # If all else fails, assume name is the path/repo_id return name def load_dit( name: str, device: Union[str, torch.device] = "cuda", dtype: torch.dtype = torch.bfloat16, ): """ Load a Flux model. Args: name: Model name or path hf_download: Whether to download from HF if not found locally device: Device to load model on dtype: Data type for model Returns: model: Loaded Flux model """ # Loading Flux if not os.path.exists(name): name = "flux-fill-dev-dit" logger.info("Initializing Flux model") # Resolve checkpoint path ckpt_path = resolve_model_path( name=name, repo_id_field="repo_id", filename_field="filename", ckpt_path_field="ckpt_path", hf_download=True, ) # Convert device string to torch.device if needed if isinstance(device, str): device = torch.device(device) # Initialize model with device: model = Flux(DIT_PARAMS).to(dtype=dtype) # Load weights model = load_model_weights(model, ckpt_path, device=device) return model def load_ic_custom( name: str, device: Union[str, torch.device] = "cuda", dtype: torch.dtype = torch.bfloat16, ): """ Function to load the IC-Custom (FLUX.1-Fill-dev + LoRA weights) model. Args: name: Model config name or path hf_download: Whether to download from HF if not found locally device: Device to load model on dtype: Data type for model Returns: model: Loaded IC_Custom model """ logger.info("Initializing IC-Custom model") # Resolve checkpoint path if not os.path.exists(name): name = "flux-fill-dev-dit" ckpt_path = resolve_model_path( name=name, repo_id_field="repo_id", filename_field="filename", ckpt_path_field="ckpt_path", hf_download=True, ) # Convert device string to torch.device if needed if isinstance(device, str): device = torch.device(device) # Initialize model on the specified device with device: model = IC_Custom(DIT_PARAMS).to(dtype=dtype) # Load weights model = load_model_weights(model, ckpt_path, device=device) return model def load_embedder( name: str, is_clip: bool, device: Union[str, torch.device], max_length: int, dtype: torch.dtype, ) -> HFEmbedder: """ Generic function to load an embedder model (T5 or CLIP). Args: name: Model name or path is_clip: Whether this is a CLIP model device: Device to load model on max_length: Maximum sequence length dtype: Data type for model Returns: model: Loaded embedder model """ # Convert device string to torch.device if needed if isinstance(device, str): device = torch.device(device) # Resolve model path - for embedders we don't need to download specific files, # just need the repo_id or local path path = resolve_model_path( name=name, repo_id_field="repo_id", filename_field=None, # No specific file needed ckpt_path_field="ckpt_path", hf_download=True, # HFEmbedder handles downloads itself ) # Initialize and return the model model = HFEmbedder( path, max_length=max_length, is_clip=is_clip, ).to(device).to(dtype) return model def load_t5( name: str = "t5-v1_1-xxl", device: Union[str, torch.device] = "cuda", max_length: int = 512, dtype: torch.dtype = torch.bfloat16, ) -> HFEmbedder: """ Load a T5 text encoder model. Args: name: Model name or path device: Device to load model on max_length: Maximum sequence length dtype: Data type for model Returns: model: Loaded T5 model """ if not os.path.exists(name): name = "t5-v1_1-xxl" logger.info(f"Loading T5 model: {name}") return load_embedder( name=name, is_clip=False, device=device, max_length=max_length, dtype=dtype, ) def load_clip( name: str = "clip-vit-large-patch14", device: Union[str, torch.device] = "cuda", dtype: torch.dtype = torch.bfloat16, ) -> HFEmbedder: """ Load a CLIP text encoder model. Args: name: Model name or path device: Device to load model on dtype: Data type for model Returns: model: Loaded CLIP model """ if not os.path.exists(name): name = "clip-vit-large-patch14" logger.info(f"Loading CLIP model: {name}") return load_embedder( name=name, is_clip=True, device=device, max_length=77, # Standard for CLIP dtype=dtype, ) def load_ae( name: str, device: Union[str, torch.device] = "cuda", ) -> AutoEncoder: """ Load an AutoEncoder model. Args: name: Model name or path pretrained_ckpt_path: Path to checkpoint (overrides name) device: Device to load model on Returns: model: Loaded AutoEncoder model """ if not os.path.exists(name): name = "flux-fill-dev-ae" logger.info(f"Loading AutoEncoder model: {name}") # Convert device string to torch.device if needed if isinstance(device, str): device = torch.device(device) # Resolve checkpoint path ckpt_path = resolve_model_path( name=name, repo_id_field="repo_id", filename_field="filename", ckpt_path_field="ckpt_path", hf_download=True, ) # Initialize model with device: ae = AutoEncoder(AE_PARAMS) # Load weights model = load_model_weights(ae, ckpt_path, device=device, strict=False) return model def load_redux( redux_name: str = "flux1-redux-dev", siglip_name: str = "siglip-so400m-patch14-384", device: Union[str, torch.device] = "cuda", dtype: torch.dtype = torch.bfloat16, ) -> ReduxImageEncoder: """ Load a Redux Image Encoder model. Args: redux_name: Redux model name or path siglip_name: SigLIP model name or path device: Device to load model on dtype: Data type for model Returns: model: Loaded Redux Image Encoder model """ if not os.path.exists(redux_name): redux_name = "flux1-redux-dev" if not os.path.exists(siglip_name): siglip_name = "siglip-so400m-patch14-384" logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}") # Convert device string to torch.device if needed if isinstance(device, str): device = torch.device(device) # Resolve Redux path redux_path = resolve_model_path( name=redux_name, repo_id_field="repo_id", filename_field="filename", ckpt_path_field="ckpt_path", hf_download=True, ) # Resolve SigLIP path - for SigLIP we don't need to download specific files, # just need the repo_id or local path siglip_path = resolve_model_path( name=siglip_name, repo_id_field="repo_id", filename_field=None, # No specific file needed ckpt_path_field="ckpt_path", hf_download=True, # ReduxImageEncoder handles SigLIP downloads itself ) # Initialize and return the model with device: image_encoder = ReduxImageEncoder( redux_path=redux_path, siglip_path=siglip_path, device=device, ).to(dtype=dtype) return image_encoder # ------------------------------------------------------------------------- # 3) load and save weights func # ------------------------------------------------------------------------- def save_lora_weights(model, save_path): """ Extracts LoRA weights from the given model and saves them as a safetensors file. Args: model (torch.nn.Module): The model containing LoRA weights. save_path (str): The path to save the safetensors file. """ # Collect LoRA weights (commonly containing '_lora' in their names) lora_state_dict = {} for name, param in model.state_dict().items(): if '_lora' in name: lora_state_dict[name] = param.cpu() if not lora_state_dict: logger.warning("No LoRA weights found in the model to save.") save_sft(lora_state_dict, save_path) logger.info(f"LoRA weights saved to {save_path}") def save_txt_img_in_weights(model, save_path): """ Save the weights and biases of 'txt_in' and 'img_in' layers from the model. This function extracts parameters whose names are: - 'txt_in.weight' - 'txt_in.bias' - 'img_in.weight' - 'img_in.bias' and saves them to a safetensors file. Args: model (torch.nn.Module): The model containing the parameters. save_path (str): The file path to save the extracted weights. """ target_keys = ['txt_in.weight', 'txt_in.bias', 'img_in.weight', 'img_in.bias'] selected_state_dict = {} for name, param in model.state_dict().items(): if name in target_keys: selected_state_dict[name] = param.cpu() if not selected_state_dict: logger.warning("No txt_in/img_in weights or biases found in the model to save.") save_sft(selected_state_dict, save_path) logger.info(f"txt_in/img_in weights and biases saved to {save_path}") def save_task_rigister_embeddings(weights, save_path): """ Save the weights and biases of 'mask_type_embedding' layer from the model. """ target_keys = ['task_register_embeddings.weight'] selected_state_dict = {} for name, param in weights.items(): if name in target_keys: selected_state_dict[name] = param.cpu() if not selected_state_dict: logger.warning("No task_register_embeddings weights found in the model to save.") save_sft(selected_state_dict, save_path) logger.info(f"task_register_embeddings weights saved to {save_path}") def save_boundary_embeddings(weights, save_path): """ Save the weights and biases of 'boundary_embedding' layer from the model. """ target_keys = ['cond_embedding.weight', 'target_embedding.weight', 'idx_embedding.weight'] selected_state_dict = {} for name, param in weights.items(): if name in target_keys: selected_state_dict[name] = param.cpu() if not selected_state_dict: logger.warning("No boundary_embedding weights found in the model to save.") save_sft(selected_state_dict, save_path) logger.info(f"boundary_embedding weights saved to {save_path}") def load_model_weights( model, weights_path, device=None, strict=False, assign=False, filter_keys=False ): """ Unified function to load weights into a model from a safetensors file. Args: model (torch.nn.Module): The model to update with weights. weights_path (str): Path to the safetensors file containing weights. device (str or torch.device, optional): Device to load weights on. If None, uses CPU. strict (bool): Whether to strictly enforce that the keys match. assign (bool): Whether to assign weights (used by some models). filter_keys (bool): If True, only loads keys that exist in the model. Returns: model: The model with weights loaded """ if weights_path is None: logger.info("No weights path provided, skipping weight loading") return model logger.info(f"Loading weights from {weights_path}") # Load the state dict if device is not None: # load_sft doesn't support torch.device objects device_str = str(device) if not isinstance(device, str) else device state_dict = load_sft(weights_path, device=device_str) else: state_dict = load_sft(weights_path) # Handle different loading strategies if filter_keys: # Filter keys to only those in the model model_state_dict = model.state_dict() update_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} missing_keys = [k for k in state_dict if k not in model_state_dict] if missing_keys: logger.warning(f"Some keys in the file are not found in the model: {missing_keys}") missing, unexpected = [], [] model.load_state_dict(update_dict, strict=strict) else: # Standard loading missing, unexpected = model.load_state_dict(state_dict, strict=strict, assign=assign) # Report any issues with loading if len(unexpected) > 0: print_load_warning(unexpected=unexpected) return model def load_safetensors(path): tensors = {} with safe_open(path, framework="pt", device="cpu") as f: for key in f.keys(): tensors[key] = f.get_tensor(key) return tensors