import argparse
import json
import os
import warnings

import imageio
import torch
import torchvision
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator


# ---------------------------------------------------------------------------
# Private helpers shared by ImageDataset and VideoDataset
# ---------------------------------------------------------------------------


def _resolve_dynamic_resolution(height, width):
    """Decide whether samples are resized dynamically.

    Returns False when both ``height`` and ``width`` are given (fixed
    resolution) and True when both are ``None`` (dynamic resolution).

    Raises:
        ValueError: if exactly one of ``height`` / ``width`` is given. Without
            this check ``dynamic_resolution`` would be left unset and the
            dataset would only fail later with an ``AttributeError``.
    """
    if height is not None and width is not None:
        print("Height and width are fixed. Setting `dynamic_resolution` to False.")
        return False
    if height is None and width is None:
        print("Height and width are none. Setting `dynamic_resolution` to True.")
        return True
    raise ValueError(
        "`height` and `width` must be either both set (fixed resolution) "
        f"or both None (dynamic resolution); got height={height!r}, width={width!r}."
    )


def _dataframe_to_records(metadata):
    """Convert a DataFrame into a list of per-row dicts (one dict per sample)."""
    # Row-wise ``iloc`` conversion is kept on purpose: it preserves the exact
    # value types the rest of the pipeline has always received.
    return [metadata.iloc[i].to_dict() for i in range(len(metadata))]


def _load_metadata(metadata_path, base_path, generate_metadata):
    """Load the dataset's metadata as a list of per-sample dicts.

    Args:
        metadata_path: ``None`` to auto-generate metadata from ``base_path``,
            a path ending in ``.json`` (parsed with ``json.load`` and returned
            as-is — presumably a list of dicts, TODO confirm), or any other
            path, which is read with ``pandas.read_csv``.
        base_path: Folder scanned when ``metadata_path`` is ``None``.
        generate_metadata: Callable ``folder -> DataFrame`` used for
            auto-generation.
    """
    if metadata_path is None:
        print("No metadata. Trying to generate it.")
        metadata = generate_metadata(base_path)
        print(f"{len(metadata)} lines in metadata.")
        return _dataframe_to_records(metadata)
    if metadata_path.endswith(".json"):
        # JSON text is UTF-8 by spec; don't depend on the platform default.
        with open(metadata_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return _dataframe_to_records(pd.read_csv(metadata_path))


def _find_prompted_files(folder, extensions):
    """List files in ``folder`` with an allowed extension and a sibling prompt.

    A file ``name.ext`` is kept when ``ext`` (lower-cased) is in
    ``extensions`` and ``name.txt`` exists in the same folder. The stripped
    contents of that ``.txt`` file become the prompt.

    Returns:
        ``(file_names, prompts)`` — two parallel lists, ordered by file name
        so the generated metadata is reproducible across runs.
    """
    file_set = set(os.listdir(folder))
    file_names, prompts = [], []
    # Iterate in sorted order: iterating the set directly would make row
    # order depend on string-hash randomization.
    for file_name in sorted(file_set):
        if "." not in file_name:
            continue
        file_ext_name = file_name.split(".")[-1].lower()
        file_base_name = file_name[: -len(file_ext_name) - 1]
        if file_ext_name not in extensions:
            continue
        prompt_file_name = file_base_name + ".txt"
        if prompt_file_name not in file_set:
            continue
        with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
            prompt = f.read().strip()
        file_names.append(file_name)
        prompts.append(prompt)
    return file_names, prompts


def _crop_and_resize(image, target_height, target_width):
    """Scale a PIL image to cover the target size, then center-crop to it."""
    width, height = image.size
    # Use the larger scale so that both sides are at least the target size.
    scale = max(target_width / width, target_height / height)
    image = torchvision.transforms.functional.resize(
        image,
        (round(height * scale), round(width * scale)),
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    return torchvision.transforms.functional.center_crop(
        image, (target_height, target_width)
    )


def _fit_to_pixel_budget(
    height, width, max_pixels, height_division_factor, width_division_factor
):
    """Shrink ``(height, width)`` to at most ``max_pixels`` and snap to factors.

    The aspect ratio is preserved (up to truncation), and each side is
    rounded down to a multiple of its division factor.
    """
    if width * height > max_pixels:
        scale = (width * height / max_pixels) ** 0.5
        height, width = int(height / scale), int(width / scale)
    height = height // height_division_factor * height_division_factor
    width = width // width_division_factor * width_division_factor
    return height, width


# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------


class ImageDataset(torch.utils.data.Dataset):
    """Image dataset driven by a metadata table.

    Each sample is a copy of one metadata dict. For every key in
    ``data_file_keys`` present in the sample, the value is joined onto
    ``base_path`` and replaced by the loaded, cropped and resized PIL image.
    The dataset length is ``len(metadata) * repeat``.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("image",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        repeat=1,
        args=None,
    ):
        """Create the dataset.

        When ``args`` (an argparse namespace) is given, its ``dataset_*``,
        ``height``, ``width``, ``max_pixels`` and ``data_file_keys`` values
        override the corresponding keyword arguments.

        Raises:
            ValueError: if exactly one of ``height`` / ``width`` is set.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat

        self.base_path = base_path
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.repeat = repeat

        self.dynamic_resolution = _resolve_dynamic_resolution(height, width)
        self.data = _load_metadata(metadata_path, base_path, self.generate_metadata)

    def generate_metadata(self, folder):
        """Build a DataFrame with ``image`` and ``prompt`` columns from ``folder``.

        Only images with a matching ``<name>.txt`` prompt file are included.
        """
        image_list, prompt_list = _find_prompted_files(
            folder, tuple(self.image_file_extension)
        )
        metadata = pd.DataFrame()
        metadata["image"] = image_list
        metadata["prompt"] = prompt_list
        return metadata

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` to cover the target size, then center-crop it."""
        return _crop_and_resize(image, target_height, target_width)

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be resized to.

        With a fixed resolution, this is the configured size. Otherwise it is
        the image's own size, shrunk to ``max_pixels`` and snapped to the
        division factors.
        """
        if not self.dynamic_resolution:
            return self.height, self.width
        width, height = image.size
        return _fit_to_pixel_budget(
            height,
            width,
            self.max_pixels,
            self.height_division_factor,
            self.width_division_factor,
        )

    def load_image(self, file_path):
        """Open an image as RGB and crop/resize it to the target size."""
        # Context manager closes the file handle (Image.open is lazy and would
        # otherwise keep it open); ``convert`` returns an independent copy.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")
        return self.crop_and_resize(image, *self.get_height_width(image))

    def load_data(self, file_path):
        """Load one data file (always treated as an image here)."""
        return self.load_image(file_path)

    def __getitem__(self, data_id):
        """Return the sample dict for ``data_id``, or None if a file fails to load."""
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path; ``data[key]`` is None at this point.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat


class VideoDataset(torch.utils.data.Dataset):
    """Video (or still-image) dataset driven by a metadata table.

    Each sample is a copy of one metadata dict. For every key in
    ``data_file_keys`` present in the sample, the value is joined onto
    ``base_path`` and replaced by a list of cropped/resized PIL frames. A
    video yields up to ``num_frames`` frames, and an image yields a
    one-element list. The dataset length is ``len(metadata) * repeat``.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        num_frames=81,
        time_division_factor=4,
        time_division_remainder=1,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("video",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        repeat=1,
        args=None,
    ):
        """Create the dataset.

        When ``args`` (an argparse namespace) is given, its ``dataset_*``,
        ``height``, ``width``, ``max_pixels``, ``num_frames`` and
        ``data_file_keys`` values override the keyword arguments.

        Raises:
            ValueError: if exactly one of ``height`` / ``width`` is set.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            num_frames = args.num_frames
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat

        self.base_path = base_path
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.video_file_extension = video_file_extension
        self.repeat = repeat

        self.dynamic_resolution = _resolve_dynamic_resolution(height, width)
        self.data = _load_metadata(metadata_path, base_path, self.generate_metadata)

    def generate_metadata(self, folder):
        """Build a DataFrame with ``video`` and ``prompt`` columns from ``folder``.

        Both image and video files are accepted, as long as a matching
        ``<name>.txt`` prompt file exists.
        """
        video_list, prompt_list = _find_prompted_files(
            folder,
            tuple(self.image_file_extension) + tuple(self.video_file_extension),
        )
        metadata = pd.DataFrame()
        metadata["video"] = video_list
        metadata["prompt"] = prompt_list
        return metadata

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` to cover the target size, then center-crop it."""
        return _crop_and_resize(image, target_height, target_width)

    def get_height_width(self, image):
        """Return the ``(height, width)`` that a frame should be resized to.

        With a fixed resolution, this is the configured size. Otherwise it is
        the frame's own size, shrunk to ``max_pixels`` and snapped to the
        division factors.
        """
        if not self.dynamic_resolution:
            return self.height, self.width
        width, height = image.size
        return _fit_to_pixel_budget(
            height,
            width,
            self.max_pixels,
            self.height_division_factor,
            self.width_division_factor,
        )

    def get_num_frames(self, reader):
        """Return how many leading frames to read from ``reader``.

        The count is capped at the video's length, then decremented until
        ``n % time_division_factor == time_division_remainder`` (or ``n <= 1``).
        """
        # count_frames() may have to decode the whole stream, so call it once.
        num_frames = min(self.num_frames, int(reader.count_frames()))
        while (
            num_frames > 1
            and num_frames % self.time_division_factor != self.time_division_remainder
        ):
            num_frames -= 1
        return num_frames

    def load_video(self, file_path):
        """Read the leading frames of a video as cropped/resized PIL images."""
        reader = imageio.get_reader(file_path)
        try:
            num_frames = self.get_num_frames(reader)
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(frame_id))
                frames.append(self.crop_and_resize(frame, *self.get_height_width(frame)))
        finally:
            # Release the decoder even if reading a frame raises.
            reader.close()
        return frames

    def load_image(self, file_path):
        """Load a still image as a single-frame list."""
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.crop_and_resize(image, *self.get_height_width(image))
        return [image]

    def is_image(self, file_path):
        """True if ``file_path`` has one of the configured image extensions."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.image_file_extension

    def is_video(self, file_path):
        """True if ``file_path`` has one of the configured video extensions."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.video_file_extension

    def load_data(self, file_path):
        """Load an image or video as a frame list; None for unknown extensions."""
        if self.is_image(file_path):
            return self.load_image(file_path)
        if self.is_video(file_path):
            return self.load_video(file_path)
        return None

    def __getitem__(self, data_id):
        """Return the sample dict for ``data_id``, or None if a file fails to load."""
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path; ``data[key]`` is None at this point.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat


# ---------------------------------------------------------------------------
# Training module base class and checkpoint logger
# ---------------------------------------------------------------------------


class DiffusionTrainingModule(torch.nn.Module):
    """Base ``nn.Module`` with helpers for trainable parameters and LoRA."""

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Move each direct child module with ``child.to(...)`` and return self."""
        for child in self.children():
            child.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return an iterator over parameters with ``requires_grad=True``.

        Despite the name, this yields parameters (e.g. for an optimizer).
        """
        return (param for param in self.parameters() if param.requires_grad)

    def trainable_param_names(self):
        """Return the set of names of parameters with ``requires_grad=True``."""
        return {name for name, param in self.named_parameters() if param.requires_grad}

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        """Inject LoRA adapters into ``model`` via peft and return it.

        ``lora_alpha`` defaults to ``lora_rank`` when not given.
        """
        if lora_alpha is None:
            lora_alpha = lora_rank
        lora_config = LoraConfig(
            r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules
        )
        return inject_adapter_in_model(lora_config, model)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Keep only trainable entries of ``state_dict``, optionally stripping a prefix.

        Args:
            state_dict: Mapping of parameter name to tensor.
            remove_prefix: If given, this prefix is removed from names that
                start with it. Other names are kept unchanged.
        """
        trainable_param_names = self.trainable_param_names()
        state_dict = {
            name: param
            for name, param in state_dict.items()
            if name in trainable_param_names
        }
        if remove_prefix is not None:
            state_dict = {
                (name[len(remove_prefix):] if name.startswith(remove_prefix) else name): param
                for name, param in state_dict.items()
            }
        return state_dict


class ModelLogger:
    """Saves the trainable weights to ``output_path`` at the end of every epoch."""

    def __init__(
        self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x
    ):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        # Hook applied to the filtered state dict right before saving.
        self.state_dict_converter = state_dict_converter

    def on_step_end(self, loss):
        """Per-step hook; intentionally a no-op here."""
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        """Save ``epoch-{epoch_id}.safetensors`` with the trainable weights."""
        accelerator.wait_for_everyone()
        # Gathering the full state dict may be a collective operation
        # (e.g. DeepSpeed ZeRO-3 / FSDP), so every process must call it.
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(
                state_dict, remove_prefix=self.remove_prefix_in_ckpt
            )
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)


# ---------------------------------------------------------------------------
# Task launchers
# ---------------------------------------------------------------------------


def launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
):
    """Train ``model`` on ``dataset`` with Accelerate.

    The DataLoader uses its default batch size of 1, and ``collate_fn`` unwraps
    the single sample, so ``model(data)`` receives one raw sample and must
    return the loss.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=True, collate_fn=lambda x: x[0]
    )
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            # Inside ``accumulate`` the prepared optimizer/scheduler only act
            # on gradient-sync steps.
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(loss)
                scheduler.step()
        model_logger.on_epoch_end(accelerator, model, epoch_id)


def launch_data_process_task(
    model: DiffusionTrainingModule, dataset, output_path="./models"
):
    """Run ``model.forward_preprocess`` on each sample and cache the inputs.

    Only the keys listed in ``model.model_input_keys`` are kept. Each sample is
    saved to ``<output_path>/data_cache/<index>.pth``.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=False, collate_fn=lambda x: x[0]
    )
    accelerator = Accelerator()
    model, dataloader = accelerator.prepare(model, dataloader)
    # A distributed wrapper (e.g. DDP) does not expose custom attributes such as
    # ``forward_preprocess``; access them on the underlying module.
    core_model = accelerator.unwrap_model(model)
    cache_dir = os.path.join(output_path, "data_cache")
    os.makedirs(cache_dir, exist_ok=True)
    # NOTE(review): the index is per process; with several processes the file
    # names collide — confirm this task is only run single-process.
    for data_id, data in enumerate(tqdm(dataloader)):
        with torch.no_grad():
            inputs = core_model.forward_preprocess(data)
            inputs = {
                key: inputs[key] for key in core_model.model_input_keys if key in inputs
            }
            torch.save(inputs, os.path.join(cache_dir, f"{data_id}.pth"))


# ---------------------------------------------------------------------------
# Argument parsers
# ---------------------------------------------------------------------------


def wan_parser():
    """Build (but do not parse) the argument parser for Wan training scripts."""
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_base_path",
        type=str,
        default="",
        required=True,
        help="Base path of the dataset.",
    )
    parser.add_argument(
        "--dataset_metadata_path",
        type=str,
        default=None,
        help="Path to the metadata file of the dataset.",
    )
    parser.add_argument(
        "--max_pixels",
        type=int,
        default=1280 * 720,
        help="Maximum number of pixels per frame, used for dynamic resolution.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames per video. Frames are sampled from the video prefix.",
    )
    parser.add_argument(
        "--data_file_keys",
        type=str,
        default="image,video",
        help="Data file keys in the metadata. Comma-separated.",
    )
    parser.add_argument(
        "--dataset_repeat",
        type=int,
        default=1,
        help="Number of times to repeat the dataset per epoch.",
    )
    parser.add_argument(
        "--model_paths",
        type=str,
        default=None,
        help="Paths to load models. In JSON format.",
    )
    parser.add_argument(
        "--model_id_with_origin_paths",
        type=str,
        default=None,
        help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="Learning rate."
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--output_path", type=str, default="./models", help="Output save path."
    )
    parser.add_argument(
        "--remove_prefix_in_ckpt",
        type=str,
        default="pipe.dit.",
        help="Remove prefix in ckpt.",
    )
    parser.add_argument(
        "--trainable_models",
        type=str,
        default=None,
        help="Models to train, e.g., dit, vae, text_encoder.",
    )
    parser.add_argument(
        "--lora_base_model",
        type=str,
        default=None,
        help="Which model LoRA is added to.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Which layers LoRA is added to.",
    )
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument(
        "--extra_inputs", default=None, help="Additional model inputs, comma-separated."
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to offload gradient checkpointing to CPU memory.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--max_timestep_boundary",
        type=float,
        default=1.0,
        help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).",
    )
    parser.add_argument(
        "--min_timestep_boundary",
        type=float,
        default=0.0,
        help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).",
    )
    return parser


def flux_parser():
    """Build (but do not parse) the argument parser for FLUX image training scripts."""
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_base_path",
        type=str,
        default="",
        required=True,
        help="Base path of the dataset.",
    )
    parser.add_argument(
        "--dataset_metadata_path",
        type=str,
        default=None,
        help="Path to the metadata file of the dataset.",
    )
    parser.add_argument(
        "--max_pixels",
        type=int,
        default=1024 * 1024,
        help="Maximum number of pixels per frame, used for dynamic resolution.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--data_file_keys",
        type=str,
        default="image",
        help="Data file keys in the metadata. Comma-separated.",
    )
    parser.add_argument(
        "--dataset_repeat",
        type=int,
        default=1,
        help="Number of times to repeat the dataset per epoch.",
    )
    parser.add_argument(
        "--model_paths",
        type=str,
        default=None,
        help="Paths to load models. In JSON format.",
    )
    parser.add_argument(
        "--model_id_with_origin_paths",
        type=str,
        default=None,
        help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="Learning rate."
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--output_path", type=str, default="./models", help="Output save path."
    )
    parser.add_argument(
        "--remove_prefix_in_ckpt",
        type=str,
        default="pipe.dit.",
        help="Remove prefix in ckpt.",
    )
    parser.add_argument(
        "--trainable_models",
        type=str,
        default=None,
        help="Models to train, e.g., dit, vae, text_encoder.",
    )
    parser.add_argument(
        "--lora_base_model",
        type=str,
        default=None,
        help="Which model LoRA is added to.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Which layers LoRA is added to.",
    )
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument(
        "--extra_inputs", default=None, help="Additional model inputs, comma-separated."
    )
    parser.add_argument(
        "--align_to_opensource_format",
        default=False,
        action="store_true",
        help="Whether to align the lora format to opensource format. Only for DiT's LoRA.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to offload gradient checkpointing to CPU memory.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    return parser