# Stand-In / trainers/utils.py
# Last commit 26557da (verified) by fffiloni: "Migrated from GitHub".
import imageio, os, torch, warnings, torchvision, argparse, json
from peft import LoraConfig, inject_adapter_in_model
from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
class ImageDataset(torch.utils.data.Dataset):
    """Dataset of images described by a metadata table.

    Every sample is one metadata row (a dict). For each key in
    ``data_file_keys`` that the row contains, the value is treated as a path
    relative to ``base_path`` and replaced by the loaded image, cropped and
    resized by :meth:`crop_and_resize`.

    Metadata sources:

    * ``metadata_path`` ending in ``.json``: loaded with :func:`json.load` and
      used as-is (presumably a list of dicts; not validated here);
    * any other ``metadata_path``: read with :func:`pandas.read_csv`;
    * ``metadata_path is None``: generated from ``base_path`` by pairing each
      image with a same-stem ``.txt`` prompt file (see :meth:`generate_metadata`).

    The dataset length is ``len(metadata) * repeat``; indices wrap around the
    metadata.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("image",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        repeat=1,
        args=None,
    ):
        """Configure the dataset and load (or generate) its metadata.

        Args:
            base_path: Directory that file paths in the metadata are relative to.
            metadata_path: ``.json`` or CSV metadata file, or ``None`` to
                generate metadata from ``base_path``.
            max_pixels: Pixel budget per image in dynamic-resolution mode.
            height: Fixed output height (give together with ``width``).
            width: Fixed output width (give together with ``height``).
                Leave both ``None`` for dynamic resolution.
            height_division_factor: In dynamic mode the output height is
                rounded down to a multiple of this.
            width_division_factor: Same as above for the width.
            data_file_keys: Metadata keys whose values are image paths to load.
            image_file_extension: Lower-case extensions accepted when
                generating metadata.
            repeat: Number of times the metadata is repeated per epoch.
            args: Optional parsed CLI namespace. When given, its
                ``dataset_base_path``, ``dataset_metadata_path``, ``height``,
                ``width``, ``max_pixels``, ``data_file_keys`` (comma-separated)
                and ``dataset_repeat`` override the corresponding arguments.

        Raises:
            ValueError: If exactly one of ``height`` and ``width`` is given.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.repeat = repeat

        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Without this check `dynamic_resolution` stays unset and the first
            # image load fails with an unhelpful AttributeError.
            raise ValueError(
                "`height` and `width` must either both be set (fixed resolution) "
                f"or both be None (dynamic resolution); got height={height!r}, "
                f"width={width!r}."
            )

        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
            self.data = metadata.to_dict("records")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        else:
            self.data = pd.read_csv(metadata_path).to_dict("records")

    def generate_metadata(self, folder):
        """Build metadata by pairing images in ``folder`` with ``.txt`` prompts.

        A file ``<stem>.<ext>`` is kept when ``ext`` (case-insensitive) is in
        ``image_file_extension`` and ``<stem>.txt`` exists in the same folder.
        The prompt is that text file's content with surrounding whitespace
        stripped.

        Returns:
            pandas.DataFrame with columns ``image`` (file name relative to
            ``folder``) and ``prompt``, in sorted file-name order.
        """
        file_set = set(os.listdir(folder))
        image_list, prompt_list = [], []
        # Sorted so the row order does not depend on set (hash) iteration order.
        for file_name in sorted(file_set):
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            if file_ext_name not in self.image_file_extension:
                continue
            file_base_name = file_name[: -len(file_ext_name) - 1]
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(
                os.path.join(folder, prompt_file_name), "r", encoding="utf-8"
            ) as f:
                prompt = f.read().strip()
            image_list.append(file_name)
            prompt_list.append(prompt)
        return pd.DataFrame({"image": image_list, "prompt": prompt_list})

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` to cover the target size, then center-crop to it."""
        width, height = image.size
        # Use the larger ratio so both sides reach at least the target size.
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        image = torchvision.transforms.functional.center_crop(
            image, (target_height, target_width)
        )
        return image

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be resized to.

        Fixed mode returns the configured size. Dynamic mode shrinks the
        image's own size (keeping the aspect ratio) to at most ``max_pixels``
        and rounds each side down to its division factor.
        """
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def load_image(self, file_path):
        """Open ``file_path`` as RGB and crop/resize it to the target size."""
        # Close the underlying file handle; `convert` returns an independent copy.
        with Image.open(file_path) as raw:
            image = raw.convert("RGB")
        return self.crop_and_resize(image, *self.get_height_width(image))

    def load_data(self, file_path):
        """Load one data file (always an image for this dataset)."""
        return self.load_image(file_path)

    def __getitem__(self, data_id):
        """Return metadata row ``data_id % len(data)`` with its files loaded.

        Returns ``None`` (after emitting a warning) if a file cannot be loaded.
        """
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path; `data[key]` has already been overwritten.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat
class VideoDataset(torch.utils.data.Dataset):
    """Dataset of videos (and still images) described by a metadata table.

    Every sample is one metadata row (a dict). For each key in
    ``data_file_keys`` that the row contains, the value is treated as a path
    relative to ``base_path`` and replaced by a list of cropped/resized PIL
    frames:

    * a video file yields its first :meth:`get_num_frames` frames;
    * an image file yields a one-element list.

    Metadata sources are the same as for an image dataset: a ``.json`` file
    (used as-is), any other file read with :func:`pandas.read_csv`, or, when
    ``metadata_path`` is None, generated from ``base_path`` (see
    :meth:`generate_metadata`). The dataset length is
    ``len(metadata) * repeat``; indices wrap around the metadata.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        num_frames=81,
        time_division_factor=4,
        time_division_remainder=1,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("video",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        repeat=1,
        args=None,
    ):
        """Configure the dataset and load (or generate) its metadata.

        Args:
            base_path: Directory that file paths in the metadata are relative to.
            metadata_path: ``.json`` or CSV metadata file, or ``None`` to
                generate metadata from ``base_path``.
            num_frames: Maximum number of leading frames read per video.
            time_division_factor: Frame counts are reduced until
                ``count % time_division_factor == time_division_remainder``.
            time_division_remainder: See ``time_division_factor``.
            max_pixels: Pixel budget per frame in dynamic-resolution mode.
            height: Fixed output height (give together with ``width``).
            width: Fixed output width (give together with ``height``).
                Leave both ``None`` for dynamic resolution.
            height_division_factor: In dynamic mode the output height is
                rounded down to a multiple of this.
            width_division_factor: Same as above for the width.
            data_file_keys: Metadata keys whose values are media paths to load.
            image_file_extension: Lower-case extensions treated as images.
            video_file_extension: Lower-case extensions treated as videos.
            repeat: Number of times the metadata is repeated per epoch.
            args: Optional parsed CLI namespace. When given, its
                ``dataset_base_path``, ``dataset_metadata_path``, ``height``,
                ``width``, ``max_pixels``, ``num_frames``, ``data_file_keys``
                (comma-separated) and ``dataset_repeat`` override the
                corresponding arguments.

        Raises:
            ValueError: If exactly one of ``height`` and ``width`` is given.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            num_frames = args.num_frames
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.video_file_extension = video_file_extension
        self.repeat = repeat

        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Without this check `dynamic_resolution` stays unset and the first
            # frame load fails with an unhelpful AttributeError.
            raise ValueError(
                "`height` and `width` must either both be set (fixed resolution) "
                f"or both be None (dynamic resolution); got height={height!r}, "
                f"width={width!r}."
            )

        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
            self.data = metadata.to_dict("records")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        else:
            self.data = pd.read_csv(metadata_path).to_dict("records")

    def generate_metadata(self, folder):
        """Build metadata by pairing media files in ``folder`` with ``.txt`` prompts.

        A file ``<stem>.<ext>`` is kept when ``ext`` (case-insensitive) is an
        image or video extension and ``<stem>.txt`` exists in the same folder.
        The prompt is that text file's content with surrounding whitespace
        stripped.

        Returns:
            pandas.DataFrame with columns ``video`` (file name relative to
            ``folder``; may also be an image) and ``prompt``, in sorted
            file-name order.
        """
        file_set = set(os.listdir(folder))
        video_list, prompt_list = [], []
        # Sorted so the row order does not depend on set (hash) iteration order.
        for file_name in sorted(file_set):
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            if (
                file_ext_name not in self.image_file_extension
                and file_ext_name not in self.video_file_extension
            ):
                continue
            file_base_name = file_name[: -len(file_ext_name) - 1]
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(
                os.path.join(folder, prompt_file_name), "r", encoding="utf-8"
            ) as f:
                prompt = f.read().strip()
            video_list.append(file_name)
            prompt_list.append(prompt)
        return pd.DataFrame({"video": video_list, "prompt": prompt_list})

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` to cover the target size, then center-crop to it."""
        width, height = image.size
        # Use the larger ratio so both sides reach at least the target size.
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        image = torchvision.transforms.functional.center_crop(
            image, (target_height, target_width)
        )
        return image

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be resized to.

        Fixed mode returns the configured size. Dynamic mode shrinks the
        image's own size (keeping the aspect ratio) to at most ``max_pixels``
        and rounds each side down to its division factor.
        """
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def get_num_frames(self, reader):
        """Return how many leading frames of ``reader`` to decode.

        The count starts at ``min(self.num_frames, reader.count_frames())``.
        It is then decreased until
        ``count % time_division_factor == time_division_remainder``; the loop
        stops decreasing once the count reaches 1.
        """
        # `count_frames()` may have to scan the file, so call it only once.
        num_frames = min(self.num_frames, int(reader.count_frames()))
        while (
            num_frames > 1
            and num_frames % self.time_division_factor
            != self.time_division_remainder
        ):
            num_frames -= 1
        return num_frames

    def load_video(self, file_path):
        """Decode the leading frames of a video into cropped/resized PIL images."""
        reader = imageio.get_reader(file_path)
        try:
            num_frames = self.get_num_frames(reader)
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(frame_id))
                frames.append(
                    self.crop_and_resize(frame, *self.get_height_width(frame))
                )
        finally:
            # Release the decoder even if a frame fails to decode.
            reader.close()
        return frames

    def load_image(self, file_path):
        """Load an image as a one-frame list of cropped/resized RGB images."""
        # Close the underlying file handle; `convert` returns an independent copy.
        with Image.open(file_path) as raw:
            image = raw.convert("RGB")
        return [self.crop_and_resize(image, *self.get_height_width(image))]

    def is_image(self, file_path):
        """Whether ``file_path`` has an image extension (case-insensitive)."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.image_file_extension

    def is_video(self, file_path):
        """Whether ``file_path`` has a video extension (case-insensitive)."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.video_file_extension

    def load_data(self, file_path):
        """Load ``file_path`` as frames, or return ``None`` for unknown types."""
        if self.is_image(file_path):
            return self.load_image(file_path)
        elif self.is_video(file_path):
            return self.load_video(file_path)
        else:
            return None

    def __getitem__(self, data_id):
        """Return metadata row ``data_id % len(data)`` with its files loaded.

        Returns ``None`` (after emitting a warning) if a file cannot be loaded,
        e.g. because its extension is neither an image nor a video extension.
        """
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path; `data[key]` has already been overwritten.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat
class DiffusionTrainingModule(torch.nn.Module):
    """Base ``nn.Module`` for diffusion training wrappers.

    Adds helpers to enumerate trainable parameters, inject LoRA adapters via
    ``peft`` and export only the trainable entries of a state dict.
    """

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Forward ``.to(...)`` to every direct child module and return ``self``.

        NOTE: only child modules are moved; parameters or buffers registered
        directly on this module are left where they are.
        """
        for child in self.children():
            child.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return an iterator over parameters with ``requires_grad=True``.

        Despite the name, this yields parameters (suitable for an optimizer),
        not modules.
        """
        return (param for param in self.parameters() if param.requires_grad)

    def trainable_param_names(self):
        """Return the set of fully-qualified names of trainable parameters."""
        return {
            name for name, param in self.named_parameters() if param.requires_grad
        }

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        """Inject LoRA adapters into ``model`` and return the result.

        Args:
            model: Module to modify.
            target_modules: Module names that receive adapters (forwarded to
                ``peft.LoraConfig``).
            lora_rank: LoRA rank ``r``.
            lora_alpha: LoRA alpha; defaults to ``lora_rank`` when ``None``.
        """
        config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank if lora_alpha is None else lora_alpha,
            target_modules=target_modules,
        )
        return inject_adapter_in_model(config, model)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Keep only trainable entries of ``state_dict``, optionally renamed.

        Args:
            state_dict: Mapping from parameter name to tensor.
            remove_prefix: If given, this prefix is stripped from every kept
                name that starts with it; other names are kept unchanged.

        Returns:
            A new dict, in the same order as ``state_dict``.
        """
        keep = self.trainable_param_names()
        exported = {}
        for name, param in state_dict.items():
            if name not in keep:
                continue
            if remove_prefix is not None and name.startswith(remove_prefix):
                name = name[len(remove_prefix):]
            exported[name] = param
        return exported
class ModelLogger:
    """Training hook that checkpoints the trainable weights after every epoch.

    Args:
        output_path: Directory that receives ``epoch-{epoch_id}.safetensors``.
        remove_prefix_in_ckpt: Prefix stripped from parameter names before
            saving (forwarded to ``export_trainable_state_dict``). ``None``
            keeps names unchanged.
        state_dict_converter: Callable applied to the exported state dict just
            before saving. The default is the identity.
    """

    def __init__(
        self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x
    ):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter

    def on_step_end(self, loss):
        """Per-step hook; currently does nothing."""
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        """Save the trainable part of ``model`` for epoch ``epoch_id``.

        Must be called on every process. All processes synchronise and fetch
        the state dict; only the main process writes the checkpoint.
        """
        accelerator.wait_for_everyone()
        # `get_state_dict` gathers sharded weights under DeepSpeed ZeRO-3 / FSDP,
        # which is a collective: every rank must take part or the main rank hangs.
        state_dict = accelerator.get_state_dict(model)
        if not accelerator.is_main_process:
            return
        state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(
            state_dict, remove_prefix=self.remove_prefix_in_ckpt
        )
        state_dict = self.state_dict_converter(state_dict)
        os.makedirs(self.output_path, exist_ok=True)
        path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
        accelerator.save(state_dict, path, safe_serialization=True)
def launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
):
    """Train ``model`` on ``dataset`` using Hugging Face Accelerate.

    Samples are fed one at a time: the DataLoader keeps its default batch size
    and the collate function unwraps the single-element batch. ``model(sample)``
    is expected to return the loss. Gradient accumulation is delegated to
    ``accelerator.accumulate``. ``model_logger.on_step_end`` runs after every
    optimizer step and ``model_logger.on_epoch_end`` after every epoch.
    """

    def take_single_sample(batch):
        # Default batch size, so the collated list holds exactly one sample.
        return batch[0]

    loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, collate_fn=take_single_sample
    )
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
    model, optimizer, loader, scheduler = accelerator.prepare(
        model, optimizer, loader, scheduler
    )

    for epoch_id in range(num_epochs):
        for sample in tqdm(loader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(sample)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(loss)
                scheduler.step()
        model_logger.on_epoch_end(accelerator, model, epoch_id)
def launch_data_process_task(
    model: DiffusionTrainingModule, dataset, output_path="./models"
):
    """Precompute model inputs for every sample and cache them with ``torch.save``.

    For each sample, ``forward_preprocess(sample)`` is run without gradients.
    Only the entries whose keys appear in ``model_input_keys`` are kept. The
    results are written to ``<output_path>/data_cache/<data_id>.pth``. With
    more than one process the name is ``<process_index>_<data_id>.pth``.

    Args:
        model: Training module providing ``forward_preprocess`` and
            ``model_input_keys``.
        dataset: Map-style dataset; samples are processed one at a time, in
            order (no shuffling).
        output_path: Root directory for the ``data_cache`` folder.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=False, collate_fn=lambda x: x[0]
    )
    accelerator = Accelerator()
    model, dataloader = accelerator.prepare(model, dataloader)
    # `prepare` may wrap the model (e.g. in DistributedDataParallel), which does
    # not forward custom attributes such as `forward_preprocess`.
    core_model = accelerator.unwrap_model(model)
    cache_dir = os.path.join(output_path, "data_cache")
    os.makedirs(cache_dir, exist_ok=True)
    # Each process enumerates its own shard starting at 0, so with several
    # processes the rank is prefixed to keep file names unique. Single-process
    # runs keep the plain `{data_id}.pth` naming.
    multi_process = accelerator.num_processes > 1
    for data_id, data in enumerate(tqdm(dataloader)):
        with torch.no_grad():
            inputs = core_model.forward_preprocess(data)
            inputs = {
                key: inputs[key]
                for key in core_model.model_input_keys
                if key in inputs
            }
            file_name = (
                f"{accelerator.process_index}_{data_id}.pth"
                if multi_process
                else f"{data_id}.pth"
            )
            torch.save(inputs, os.path.join(cache_dir, file_name))
def wan_parser():
    """Build the command-line parser used by the Wan training scripts.

    Returns:
        argparse.ArgumentParser: parser covering dataset, model-loading,
        optimisation, checkpoint, LoRA and timestep-boundary options.
        ``--dataset_base_path`` is required.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    add = parser.add_argument

    # Dataset.
    add("--dataset_base_path", help="Base path of the dataset.",
        type=str, default="", required=True)
    add("--dataset_metadata_path", help="Path to the metadata file of the dataset.",
        type=str, default=None)
    add("--max_pixels", help="Maximum number of pixels per frame, used for dynamic resolution..",
        type=int, default=1280 * 720)
    add("--height", help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.",
        type=int, default=None)
    add("--width", help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.",
        type=int, default=None)
    add("--num_frames", help="Number of frames per video. Frames are sampled from the video prefix.",
        type=int, default=81)
    add("--data_file_keys", help="Data file keys in the metadata. Comma-separated.",
        type=str, default="image,video")
    add("--dataset_repeat", help="Number of times to repeat the dataset per epoch.",
        type=int, default=1)

    # Model loading.
    add("--model_paths", help="Paths to load models. In JSON format.",
        type=str, default=None)
    add("--model_id_with_origin_paths",
        help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.",
        type=str, default=None)

    # Optimisation and checkpointing.
    add("--learning_rate", help="Learning rate.", type=float, default=1e-4)
    add("--num_epochs", help="Number of epochs.", type=int, default=1)
    add("--output_path", help="Output save path.", type=str, default="./models")
    add("--remove_prefix_in_ckpt", help="Remove prefix in ckpt.",
        type=str, default="pipe.dit.")
    add("--trainable_models", help="Models to train, e.g., dit, vae, text_encoder.",
        type=str, default=None)

    # LoRA.
    add("--lora_base_model", help="Which model LoRA is added to.",
        type=str, default=None)
    add("--lora_target_modules", help="Which layers LoRA is added to.",
        type=str, default="q,k,v,o,ffn.0,ffn.2")
    add("--lora_rank", help="Rank of LoRA.", type=int, default=32)

    # Miscellaneous.
    add("--extra_inputs", help="Additional model inputs, comma-separated.", default=None)
    add("--use_gradient_checkpointing_offload",
        help="Whether to offload gradient checkpointing to CPU memory.",
        action="store_true", default=False)
    add("--gradient_accumulation_steps", help="Gradient accumulation steps.",
        type=int, default=1)
    add("--max_timestep_boundary",
        help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).",
        type=float, default=1.0)
    add("--min_timestep_boundary",
        help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).",
        type=float, default=0.0)
    return parser
def flux_parser():
    """Build the command-line parser used by the FLUX training scripts.

    Returns:
        argparse.ArgumentParser: parser covering dataset, model-loading,
        optimisation, checkpoint, LoRA and gradient-checkpointing options.
        ``--dataset_base_path`` is required.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    add = parser.add_argument

    # Dataset.
    add("--dataset_base_path", help="Base path of the dataset.",
        type=str, default="", required=True)
    add("--dataset_metadata_path", help="Path to the metadata file of the dataset.",
        type=str, default=None)
    add("--max_pixels", help="Maximum number of pixels per frame, used for dynamic resolution..",
        type=int, default=1024 * 1024)
    add("--height", help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.",
        type=int, default=None)
    add("--width", help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.",
        type=int, default=None)
    add("--data_file_keys", help="Data file keys in the metadata. Comma-separated.",
        type=str, default="image")
    add("--dataset_repeat", help="Number of times to repeat the dataset per epoch.",
        type=int, default=1)

    # Model loading.
    add("--model_paths", help="Paths to load models. In JSON format.",
        type=str, default=None)
    add("--model_id_with_origin_paths",
        help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.",
        type=str, default=None)

    # Optimisation and checkpointing.
    add("--learning_rate", help="Learning rate.", type=float, default=1e-4)
    add("--num_epochs", help="Number of epochs.", type=int, default=1)
    add("--output_path", help="Output save path.", type=str, default="./models")
    add("--remove_prefix_in_ckpt", help="Remove prefix in ckpt.",
        type=str, default="pipe.dit.")
    add("--trainable_models", help="Models to train, e.g., dit, vae, text_encoder.",
        type=str, default=None)

    # LoRA.
    add("--lora_base_model", help="Which model LoRA is added to.",
        type=str, default=None)
    add("--lora_target_modules", help="Which layers LoRA is added to.",
        type=str, default="q,k,v,o,ffn.0,ffn.2")
    add("--lora_rank", help="Rank of LoRA.", type=int, default=32)

    # Miscellaneous.
    add("--extra_inputs", help="Additional model inputs, comma-separated.", default=None)
    add("--align_to_opensource_format",
        help="Whether to align the lora format to opensource format. Only for DiT's LoRA.",
        action="store_true", default=False)
    add("--use_gradient_checkpointing", help="Whether to use gradient checkpointing.",
        action="store_true", default=False)
    add("--use_gradient_checkpointing_offload",
        help="Whether to offload gradient checkpointing to CPU memory.",
        action="store_true", default=False)
    add("--gradient_accumulation_steps", help="Gradient accumulation steps.",
        type=int, default=1)
    return parser