Spaces:
Running
on
Zero
Running
on
Zero
import imageio, os, torch, warnings, torchvision, argparse, json | |
from peft import LoraConfig, inject_adapter_in_model | |
from PIL import Image | |
import pandas as pd | |
from tqdm import tqdm | |
from accelerate import Accelerator | |
class ImageDataset(torch.utils.data.Dataset):
    """Image dataset backed by a metadata file or by a scan of ``base_path``.

    Each sample is one metadata row returned as a dict. For every key in
    ``data_file_keys`` that is present in the row, the value is joined onto
    ``base_path`` and replaced with the loaded, resized and center-cropped
    PIL image. The dataset length is ``len(metadata) * repeat``; indices wrap
    around the metadata.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("image",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        repeat=1,
        args=None,
    ):
        """Configure the dataset and load its metadata.

        Args:
            base_path: Directory that file paths in the metadata are joined onto.
            metadata_path: ``.json`` file (loaded with ``json.load``, presumably a
                list of row dicts) or a CSV file; when None, metadata is built by
                ``generate_metadata(base_path)``.
            max_pixels: Pixel budget per image in dynamic-resolution mode.
            height, width: Fixed output size. Both None enables dynamic resolution.
            height_division_factor, width_division_factor: Dynamic sizes are
                floored to multiples of these.
            data_file_keys: Metadata keys whose values are image paths to load.
            image_file_extension: Extensions accepted by ``generate_metadata``.
            repeat: Number of passes over the metadata per epoch.
            args: Optional argparse namespace; when given it overrides the
                corresponding arguments above.

        Raises:
            ValueError: If exactly one of ``height`` / ``width`` is given.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.repeat = repeat

        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Previously this case left `dynamic_resolution` unset and the
            # first sample failed with an AttributeError; fail fast instead.
            raise ValueError(
                "`height` and `width` must both be set (fixed resolution) "
                "or both be None (dynamic resolution)."
            )

        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
            self.data = metadata.to_dict("records")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        else:
            self.data = pd.read_csv(metadata_path).to_dict("records")

    def generate_metadata(self, folder):
        """Pair each image in ``folder`` with a same-named ``.txt`` prompt file.

        Images without a matching prompt file are skipped. Files are visited
        in sorted order so every process and run sees the same sample order
        (iterating a ``set`` of strings depends on the per-process hash seed).

        Returns:
            pandas.DataFrame with columns ``image`` (file name) and ``prompt``.
        """
        image_list, prompt_list = [], []
        file_set = set(os.listdir(folder))
        for file_name in sorted(file_set):
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            file_base_name = file_name[: -len(file_ext_name) - 1]
            if file_ext_name not in self.image_file_extension:
                continue
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
                prompt = f.read().strip()
            image_list.append(file_name)
            prompt_list.append(prompt)
        return pd.DataFrame({"image": image_list, "prompt": prompt_list})

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` to cover the target size, then center-crop to it."""
        width, height = image.size
        # Use the larger ratio so both sides are >= target before cropping.
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return torchvision.transforms.functional.center_crop(
            image, (target_height, target_width)
        )

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be cropped to.

        With a fixed resolution this is ``(self.height, self.width)``. In
        dynamic mode the image size is scaled down (never up) to fit within
        ``max_pixels`` and floored to the division factors.
        """
        if not self.dynamic_resolution:
            return self.height, self.width
        width, height = image.size
        if width * height > self.max_pixels:
            scale = (width * height / self.max_pixels) ** 0.5
            height, width = int(height / scale), int(width / scale)
        height = height // self.height_division_factor * self.height_division_factor
        width = width // self.width_division_factor * self.width_division_factor
        return height, width

    def load_image(self, file_path):
        """Open ``file_path`` as RGB and crop/resize it; the file is closed."""
        # `convert` returns a new in-memory image, so the handle can be closed.
        with Image.open(file_path) as raw:
            image = raw.convert("RGB")
        return self.crop_and_resize(image, *self.get_height_width(image))

    def load_data(self, file_path):
        """Load one data file (always an image for this dataset)."""
        return self.load_image(file_path)

    def __getitem__(self, data_id):
        """Return a copy of row ``data_id % len(self.data)`` with files loaded.

        Returns None (after emitting a warning) if a file loads as None.
        """
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path; data[key] has already been overwritten.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat
class VideoDataset(torch.utils.data.Dataset):
    """Video/image dataset backed by a metadata file or a scan of ``base_path``.

    Each sample is one metadata row returned as a dict. For every key in
    ``data_file_keys`` present in the row, the value is joined onto
    ``base_path`` and replaced with a list of PIL frames: videos yield their
    leading frames (see ``get_num_frames``), still images a one-frame list.
    The dataset length is ``len(metadata) * repeat``; indices wrap around.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        num_frames=81,
        time_division_factor=4,
        time_division_remainder=1,
        max_pixels=1920 * 1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        data_file_keys=("video",),
        image_file_extension=("jpg", "jpeg", "png", "webp"),
        video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"),
        repeat=1,
        args=None,
    ):
        """Configure the dataset and load its metadata.

        Args:
            base_path: Directory that file paths in the metadata are joined onto.
            metadata_path: ``.json`` file (loaded with ``json.load``, presumably a
                list of row dicts) or a CSV file; when None, metadata is built by
                ``generate_metadata(base_path)``.
            num_frames: Maximum number of leading frames read per video.
            time_division_factor, time_division_remainder: Frame counts are
                reduced until ``n % factor == remainder`` (or ``n <= 1``).
            max_pixels: Pixel budget per frame in dynamic-resolution mode.
            height, width: Fixed output size. Both None enables dynamic resolution.
            height_division_factor, width_division_factor: Dynamic sizes are
                floored to multiples of these.
            data_file_keys: Metadata keys whose values are media paths to load.
            image_file_extension, video_file_extension: Extensions (lowercase)
                used to classify files.
            repeat: Number of passes over the metadata per epoch.
            args: Optional argparse namespace overriding the arguments above.

        Raises:
            ValueError: If exactly one of ``height`` / ``width`` is given.
        """
        if args is not None:
            base_path = args.dataset_base_path
            metadata_path = args.dataset_metadata_path
            height = args.height
            width = args.width
            max_pixels = args.max_pixels
            num_frames = args.num_frames
            data_file_keys = args.data_file_keys.split(",")
            repeat = args.dataset_repeat
        self.base_path = base_path
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height = height
        self.width = width
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.data_file_keys = data_file_keys
        self.image_file_extension = image_file_extension
        self.video_file_extension = video_file_extension
        self.repeat = repeat

        if height is not None and width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif height is None and width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True
        else:
            # Previously this case left `dynamic_resolution` unset and the
            # first sample failed with an AttributeError; fail fast instead.
            raise ValueError(
                "`height` and `width` must both be set (fixed resolution) "
                "or both be None (dynamic resolution)."
            )

        if metadata_path is None:
            print("No metadata. Trying to generate it.")
            metadata = self.generate_metadata(base_path)
            print(f"{len(metadata)} lines in metadata.")
            self.data = metadata.to_dict("records")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        else:
            self.data = pd.read_csv(metadata_path).to_dict("records")

    def generate_metadata(self, folder):
        """Pair each image/video in ``folder`` with a same-named ``.txt`` prompt.

        Media files without a matching prompt file are skipped. Files are
        visited in sorted order so every process and run sees the same sample
        order (iterating a ``set`` of strings depends on the hash seed).

        Returns:
            pandas.DataFrame with columns ``video`` (file name) and ``prompt``.
        """
        video_list, prompt_list = [], []
        file_set = set(os.listdir(folder))
        for file_name in sorted(file_set):
            if "." not in file_name:
                continue
            file_ext_name = file_name.split(".")[-1].lower()
            file_base_name = file_name[: -len(file_ext_name) - 1]
            if (
                file_ext_name not in self.image_file_extension
                and file_ext_name not in self.video_file_extension
            ):
                continue
            prompt_file_name = file_base_name + ".txt"
            if prompt_file_name not in file_set:
                continue
            with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f:
                prompt = f.read().strip()
            video_list.append(file_name)
            prompt_list.append(prompt)
        return pd.DataFrame({"video": video_list, "prompt": prompt_list})

    def crop_and_resize(self, image, target_height, target_width):
        """Scale ``image`` to cover the target size, then center-crop to it."""
        width, height = image.size
        # Use the larger ratio so both sides are >= target before cropping.
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return torchvision.transforms.functional.center_crop(
            image, (target_height, target_width)
        )

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be cropped to.

        With a fixed resolution this is ``(self.height, self.width)``. In
        dynamic mode the size is scaled down (never up) to fit within
        ``max_pixels`` and floored to the division factors.
        """
        if not self.dynamic_resolution:
            return self.height, self.width
        width, height = image.size
        if width * height > self.max_pixels:
            scale = (width * height / self.max_pixels) ** 0.5
            height, width = int(height / scale), int(width / scale)
        height = height // self.height_division_factor * self.height_division_factor
        width = width // self.width_division_factor * self.width_division_factor
        return height, width

    def get_num_frames(self, reader):
        """Return how many leading frames of ``reader`` to load.

        The count is capped by ``self.num_frames`` and by the reader's frame
        count, then decremented until ``n % time_division_factor ==
        time_division_remainder`` or ``n <= 1``.
        """
        # count_frames() may be costly for some backends; query it only once.
        num_frames = min(self.num_frames, int(reader.count_frames()))
        while (
            num_frames > 1
            and num_frames % self.time_division_factor != self.time_division_remainder
        ):
            num_frames -= 1
        return num_frames

    def load_video(self, file_path):
        """Read the leading frames of a video as cropped/resized PIL images."""
        reader = imageio.get_reader(file_path)
        try:
            num_frames = self.get_num_frames(reader)
            frames = []
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(frame_id))
                frames.append(self.crop_and_resize(frame, *self.get_height_width(frame)))
        finally:
            # Release the reader even if decoding or resizing raises.
            reader.close()
        return frames

    def load_image(self, file_path):
        """Load a still image as a one-frame list; the file is closed."""
        # `convert` returns a new in-memory image, so the handle can be closed.
        with Image.open(file_path) as raw:
            image = raw.convert("RGB")
        return [self.crop_and_resize(image, *self.get_height_width(image))]

    def is_image(self, file_path):
        """True if the text after the last ``.`` is an image extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.image_file_extension

    def is_video(self, file_path):
        """True if the text after the last ``.`` is a video extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in self.video_file_extension

    def load_data(self, file_path):
        """Load ``file_path`` as frames, or return None for unknown extensions."""
        if self.is_image(file_path):
            return self.load_image(file_path)
        if self.is_video(file_path):
            return self.load_video(file_path)
        return None

    def __getitem__(self, data_id):
        """Return a copy of row ``data_id % len(self.data)`` with media loaded.

        Returns None (after emitting a warning) if a file cannot be loaded.
        """
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key in data:
                path = os.path.join(self.base_path, data[key])
                data[key] = self.load_data(path)
                if data[key] is None:
                    # Report the path; data[key] has already been overwritten.
                    warnings.warn(f"cannot load file {path}.")
                    return None
        return data

    def __len__(self):
        return len(self.data) * self.repeat
class DiffusionTrainingModule(torch.nn.Module):
    """Base ``nn.Module`` for diffusion training wrappers.

    Provides helpers to move child modules, enumerate trainable parameters,
    inject PEFT LoRA adapters, and export only the trainable weights.
    """

    def __init__(self):
        super().__init__()

    def to(self, *args, **kwargs):
        """Forward ``to(...)`` to every direct child module and return ``self``.

        Only children are moved by this override; parameters or buffers
        registered directly on this module are not touched.
        """
        for _child_name, child in self.named_children():
            child.to(*args, **kwargs)
        return self

    def trainable_modules(self):
        """Return a lazy iterator over parameters whose ``requires_grad`` is set.

        Despite the name, the items are parameters (suitable for an optimizer).
        """
        return (param for param in self.parameters() if param.requires_grad)

    def trainable_param_names(self):
        """Return the set of qualified names of trainable parameters."""
        return {
            param_name
            for param_name, param in self.named_parameters()
            if param.requires_grad
        }

    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None):
        """Inject PEFT LoRA adapters into ``model`` and return the result.

        ``lora_alpha`` defaults to ``lora_rank`` when not given.
        """
        alpha = lora_rank if lora_alpha is None else lora_alpha
        config = LoraConfig(r=lora_rank, lora_alpha=alpha, target_modules=target_modules)
        return inject_adapter_in_model(config, model)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        """Keep only trainable entries of ``state_dict``, optionally renaming keys.

        Keys starting with ``remove_prefix`` have that prefix stripped; other
        keys are kept unchanged.
        """
        keep = self.trainable_param_names()
        filtered = {key: value for key, value in state_dict.items() if key in keep}
        if remove_prefix is None:
            return filtered
        return {
            (key[len(remove_prefix):] if key.startswith(remove_prefix) else key): value
            for key, value in filtered.items()
        }
class ModelLogger:
    """Training-loop callback that writes a safetensors checkpoint per epoch.

    Args:
        output_path: Directory that receives ``epoch-{epoch_id}.safetensors``.
        remove_prefix_in_ckpt: Prefix stripped from parameter names on export.
        state_dict_converter: Function applied to the exported state dict
            before saving (identity by default).
    """

    def __init__(
        self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x
    ):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter

    def on_step_end(self, loss):
        """Per-step hook; intentionally does nothing."""
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        """Export the model's trainable weights and save them on the main process."""
        accelerator.wait_for_everyone()
        # `get_state_dict` can be a collective operation (e.g. DeepSpeed ZeRO-3
        # or FSDP gather sharded weights), so every process must call it;
        # calling it only on the main process can hang the job.
        state_dict = accelerator.get_state_dict(model)
        if not accelerator.is_main_process:
            return
        state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(
            state_dict, remove_prefix=self.remove_prefix_in_ckpt
        )
        state_dict = self.state_dict_converter(state_dict)
        os.makedirs(self.output_path, exist_ok=True)
        path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
        accelerator.save(state_dict, path, safe_serialization=True)
def launch_training_task(
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    num_epochs: int = 1,
    gradient_accumulation_steps: int = 1,
):
    """Train ``model`` with Accelerate and notify ``model_logger`` as it goes.

    The dataloader uses the default batch size of one, and its collate
    function unwraps that single sample, so ``model`` receives one dataset
    item per call and returns the loss to backpropagate (presumably a scalar
    tensor). ``model_logger.on_step_end`` runs after every step and
    ``model_logger.on_epoch_end`` after every epoch.
    """

    def take_single_sample(batch):
        # With batch_size=1 the collated "batch" is a one-element list.
        return batch[0]

    loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, collate_fn=take_single_sample
    )
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
    model, optimizer, loader, scheduler = accelerator.prepare(
        model, optimizer, loader, scheduler
    )

    for epoch in range(num_epochs):
        for sample in tqdm(loader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                step_loss = model(sample)
                accelerator.backward(step_loss)
                optimizer.step()
                model_logger.on_step_end(step_loss)
                scheduler.step()
        model_logger.on_epoch_end(accelerator, model, epoch)
def launch_data_process_task(
    model: DiffusionTrainingModule, dataset, output_path="./models"
):
    """Run ``model.forward_preprocess`` over ``dataset`` and cache the results.

    Each sample's preprocessed inputs are filtered to ``model.model_input_keys``
    and written with ``torch.save`` to ``<output_path>/data_cache/<index>.pth``.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=False, collate_fn=lambda x: x[0]
    )
    accelerator = Accelerator()
    model, dataloader = accelerator.prepare(model, dataloader)
    # `prepare` may wrap the model (e.g. in DistributedDataParallel), which
    # does not expose custom methods/attributes; call them on the bare module.
    core_model = accelerator.unwrap_model(model)
    cache_dir = os.path.join(output_path, "data_cache")
    os.makedirs(cache_dir, exist_ok=True)
    num_processes = accelerator.num_processes
    process_index = accelerator.process_index
    for data_id, data in enumerate(tqdm(dataloader)):
        # Each process iterates only its own shard starting at 0, so a plain
        # local index would make processes overwrite each other's files. With
        # batch_size=1, Accelerate's default sharding gives process p samples
        # p, p + n, p + 2n, ... so this maps back to the dataset index; with a
        # single process it equals `data_id`.
        # NOTE(review): padded trailing samples (even_batches) get indices
        # >= len(dataset) and duplicate early samples.
        global_id = data_id * num_processes + process_index
        with torch.no_grad():
            inputs = core_model.forward_preprocess(data)
            inputs = {
                key: inputs[key] for key in core_model.model_input_keys if key in inputs
            }
            torch.save(inputs, os.path.join(cache_dir, f"{global_id}.pth"))
def wan_parser():
    """Build the command-line parser for Wan image/video training.

    Returns:
        argparse.ArgumentParser with dataset, model-loading, optimization,
        LoRA, gradient-checkpointing and timestep-boundary options.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # Dataset.
    parser.add_argument(
        "--dataset_base_path",
        type=str,
        default="",
        required=True,
        help="Base path of the dataset.",
    )
    parser.add_argument(
        "--dataset_metadata_path",
        type=str,
        default=None,
        help="Path to the metadata file of the dataset.",
    )
    parser.add_argument(
        "--max_pixels",
        type=int,
        default=1280 * 720,
        help="Maximum number of pixels per frame, used for dynamic resolution.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames per video. Frames are sampled from the video prefix.",
    )
    parser.add_argument(
        "--data_file_keys",
        type=str,
        default="image,video",
        help="Data file keys in the metadata. Comma-separated.",
    )
    parser.add_argument(
        "--dataset_repeat",
        type=int,
        default=1,
        help="Number of times to repeat the dataset per epoch.",
    )
    # Model loading.
    parser.add_argument(
        "--model_paths",
        type=str,
        default=None,
        help="Paths to load models. In JSON format.",
    )
    parser.add_argument(
        "--model_id_with_origin_paths",
        type=str,
        default=None,
        help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.",
    )
    # Optimization and output.
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="Learning rate."
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--output_path", type=str, default="./models", help="Output save path."
    )
    parser.add_argument(
        "--remove_prefix_in_ckpt",
        type=str,
        default="pipe.dit.",
        help="Remove prefix in ckpt.",
    )
    parser.add_argument(
        "--trainable_models",
        type=str,
        default=None,
        help="Models to train, e.g., dit, vae, text_encoder.",
    )
    # LoRA.
    parser.add_argument(
        "--lora_base_model",
        type=str,
        default=None,
        help="Which model LoRA is added to.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Which layers LoRA is added to.",
    )
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument(
        "--extra_inputs", default=None, help="Additional model inputs, comma-separated."
    )
    # Memory / accumulation.
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to offload gradient checkpointing to CPU memory.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    # Timestep range.
    parser.add_argument(
        "--max_timestep_boundary",
        type=float,
        default=1.0,
        help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).",
    )
    parser.add_argument(
        "--min_timestep_boundary",
        type=float,
        default=0.0,
        help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).",
    )
    return parser
def flux_parser():
    """Build the command-line parser for FLUX image training.

    Returns:
        argparse.ArgumentParser with dataset, model-loading, optimization,
        LoRA and gradient-checkpointing options.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # Dataset.
    parser.add_argument(
        "--dataset_base_path",
        type=str,
        default="",
        required=True,
        help="Base path of the dataset.",
    )
    parser.add_argument(
        "--dataset_metadata_path",
        type=str,
        default=None,
        help="Path to the metadata file of the dataset.",
    )
    parser.add_argument(
        "--max_pixels",
        type=int,
        default=1024 * 1024,
        help="Maximum number of pixels per frame, used for dynamic resolution.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.",
    )
    parser.add_argument(
        "--data_file_keys",
        type=str,
        default="image",
        help="Data file keys in the metadata. Comma-separated.",
    )
    parser.add_argument(
        "--dataset_repeat",
        type=int,
        default=1,
        help="Number of times to repeat the dataset per epoch.",
    )
    # Model loading.
    parser.add_argument(
        "--model_paths",
        type=str,
        default=None,
        help="Paths to load models. In JSON format.",
    )
    parser.add_argument(
        "--model_id_with_origin_paths",
        type=str,
        default=None,
        help="Model ID with origin paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Comma-separated.",
    )
    # Optimization and output.
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="Learning rate."
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--output_path", type=str, default="./models", help="Output save path."
    )
    parser.add_argument(
        "--remove_prefix_in_ckpt",
        type=str,
        default="pipe.dit.",
        help="Remove prefix in ckpt.",
    )
    parser.add_argument(
        "--trainable_models",
        type=str,
        default=None,
        help="Models to train, e.g., dit, vae, text_encoder.",
    )
    # LoRA.
    parser.add_argument(
        "--lora_base_model",
        type=str,
        default=None,
        help="Which model LoRA is added to.",
    )
    # NOTE(review): this default mirrors the Wan parser; confirm these module
    # names exist in the FLUX DiT, otherwise callers must pass their own list.
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Which layers LoRA is added to.",
    )
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument(
        "--extra_inputs", default=None, help="Additional model inputs, comma-separated."
    )
    parser.add_argument(
        "--align_to_opensource_format",
        default=False,
        action="store_true",
        help="Whether to align the lora format to opensource format. Only for DiT's LoRA.",
    )
    # Memory / accumulation.
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to offload gradient checkpointing to CPU memory.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    return parser