# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
import random
from glob import glob
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from diffusers import FlowMatchEulerDiscreteScheduler, MochiTransformer3DModel
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import create_repo, upload_folder
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from args import get_args  # isort:skip
from dataset_simple import LatentEmbedDataset
from pipeline_mochi_rgba import *  # noqa: F403 -- provides the RGBA-aware MochiPipeline
from rgba_utils import *  # noqa: F403 -- provides the RGBA processor helpers
from utils import print_memory, reset_memory  # isort:skip
# Taken from
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
def get_cosine_annealing_lr_scheduler(
optimizer: torch.optim.Optimizer,
warmup_steps: int,
total_steps: int,
):
def lr_lambda(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
else:
return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
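# A minimal usage sketch with illustrative values (not the training defaults):
# the multiplier ramps linearly from 0 to 1 over `warmup_steps`, then decays
# along a half-cosine from 1 to 0 at `total_steps`.
#
#   optimizer = torch.optim.AdamW(params, lr=1e-4)
#   lr_scheduler = get_cosine_annealing_lr_scheduler(optimizer, warmup_steps=100, total_steps=1000)
#   # step 50  -> lr = 1e-4 * 0.5                       = 5e-5 (mid-warmup)
#   # step 100 -> lr = 1e-4 * 1.0                       = 1e-4 (warmup done)
#   # step 550 -> lr = 1e-4 * 0.5 * (1 + cos(pi * 0.5)) = 5e-5 (halfway decayed)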
def save_model_card(
repo_id: str,
videos=None,
    base_model: Optional[str] = None,
validation_prompt=None,
repo_folder=None,
fps=30,
):
widget_dict = []
if videos is not None and len(videos) > 0:
for i, video in enumerate(videos):
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
widget_dict.append(
{
"text": validation_prompt if validation_prompt else " ",
"output": {"url": f"final_video_{i}.mp4"},
}
)
model_description = f"""
# Mochi-1 Preview LoRA Finetune
<Gallery />
## Model description
This is a LoRA finetune of the Mochi-1 preview model `{base_model}`.
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adapted from the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
## Download model
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
## Usage
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
```py
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
import torch
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
pipe.load_lora_weights("CHANGE_ME")
pipe.enable_model_cpu_offload()
with torch.autocast("cuda", torch.bfloat16):
video = pipe(
prompt="CHANGE_ME",
guidance_scale=6.0,
num_inference_steps=64,
height=480,
width=848,
max_sequence_length=256,
output_type="np"
).frames[0]
export_to_video(video, "output.mp4", fps=30)
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="apache-2.0",
base_model=base_model,
prompt=validation_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-video",
"diffusers-training",
"diffusers",
"lora",
"mochi-1-preview",
"mochi-1-preview-diffusers",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
pipe: MochiPipeline,
args: Dict[str, Any],
pipeline_args: Dict[str, Any],
step: int,
    wandb_run: Optional[Any] = None,
is_final_validation: bool = False,
):
print(
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
)
phase_name = "test" if is_final_validation else "validation"
if not args.enable_model_cpu_offload:
pipe = pipe.to("cuda")
# run inference
    generator = torch.manual_seed(args.seed) if args.seed is not None else None
videos = []
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
for _ in range(args.num_validation_videos):
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
videos.append(video)
video_filenames = []
for i, video in enumerate(videos):
        prompt = (
            pipeline_args["prompt"][:25]
            .replace(" ", "_")
            .replace("'", "_")
            .replace('"', "_")
            .replace("/", "_")
        )
filename = os.path.join(args.output_dir, f"{phase_name}_{str(step)}_video_{i}_{prompt}.mp4")
export_to_video(video, filename, fps=30)
video_filenames.append(filename)
if wandb_run:
wandb.log(
{
phase_name: [
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30)
for i, filename in enumerate(video_filenames)
]
}
)
return videos
# Adapted from the original code:
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
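# Casts the DiT's Linear layers (asserted to live in the expected submodules)
# and any Conv2d layers to `dtype`, e.g. bf16, while leaving other modules
# (such as norm layers) untouched.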
def cast_dit(model, dtype):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
assert any(
n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
), f"Unexpected linear layer: {name}"
module.to(dtype=dtype)
elif isinstance(module, torch.nn.Conv2d):
module.to(dtype=dtype)
return model
def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path):
# lora_state_dict = get_peft_model_state_dict(model)
processor_state_dict = get_processor_state_dict(model)
torch.save(
{
"state_dict": processor_state_dict,
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"global_step": global_step,
},
checkpoint_path,
)
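# A checkpoint saved above can be restored with, e.g. (mirroring the resume
# logic in `main` below):
#   ckpt = torch.load("checkpoint-1000.pt", map_location="cpu")
#   load_processor_state_dict(model, ckpt["state_dict"])
#   optimizer.load_state_dict(ckpt["optimizer"])
#   lr_scheduler.load_state_dict(ckpt["lr_scheduler"])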
class CollateFunction:
    def __init__(self, caption_dropout: Optional[float] = None) -> None:
self.caption_dropout = caption_dropout
def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
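        # "ldist" holds the VAE posterior parameters (mean and logvar);
        # sampling from DiagonalGaussianDistribution draws latents z from it.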
z = DiagonalGaussianDistribution(ldists).sample()
assert torch.isfinite(z).all()
# Sample noise which we will add to the samples.
eps = torch.randn_like(z)
sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
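        # Caption dropout: with probability `caption_dropout`, drop the text
        # conditioning so the model also learns the unconditional branch needed
        # for classifier-free guidance at inference time.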
if self.caption_dropout and random.random() < self.caption_dropout:
prompt_embeds.zero_()
prompt_attention_mask = prompt_attention_mask.long()
prompt_attention_mask.zero_()
prompt_attention_mask = prompt_attention_mask.bool()
return dict(
z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
)
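# Each collated batch contains:
#   z      -- latents sampled from the VAE posterior
#   eps    -- Gaussian noise with the same shape as z
#   sigma  -- one flow-matching time per sample, uniform in [0, 1)
#   prompt_embeds / prompt_attention_mask -- precomputed text conditioning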
def main(args):
    if not torch.cuda.is_available():
        raise ValueError("This training script requires a CUDA device.")
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
# Handle the repository creation
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Prepare models and scheduler
transformer = MochiTransformer3DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
revision=args.revision,
variant=args.variant,
)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
transformer.requires_grad_(False)
transformer.to("cuda")
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.cast_dit:
transformer = cast_dit(transformer, torch.bfloat16)
if args.compile_dit:
transformer.compile()
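    # Install the RGBA attention processors (from the rgba helpers imported
    # above); their parameters are the only weights trained by this script.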
prepare_for_rgba_inference(
model=transformer,
device=torch.device("cuda"),
dtype=torch.bfloat16,
# seq_length=seq_length,
)
processor_params = get_all_processor_params(transformer)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = args.learning_rate * args.train_batch_size
    # Only upcast trainable parameters (the RGBA processor params) into fp32.
    if not isinstance(processor_params, list):
        processor_params = [processor_params]
    for param in processor_params:
        if param.requires_grad:
            param.data = param.to(torch.float32)
    # Prepare optimizer over the trainable parameters.
    transformer_trainable_parameters = processor_params
    num_trainable_parameters = sum(param.numel() for param in transformer_trainable_parameters)
    optimizer = torch.optim.AdamW(transformer_trainable_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
# Dataset and DataLoader
train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
print(f"Found {len(train_vids)} training videos in {args.data_root}")
assert len(train_vids) > 0, f"No training data found in {args.data_root}"
collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
train_dataset = LatentEmbedDataset(train_vids, repeat=1)
train_dataloader = DataLoader(
train_dataset,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
pin_memory=args.pin_memory,
)
# LR scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = len(train_dataloader)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_cosine_annealing_lr_scheduler(
optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = len(train_dataloader)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    # Initialize the wandb tracker (if requested) and store our training configuration.
wandb_run = None
if args.report_to == "wandb":
tracker_name = args.tracker_name or "mochi-1-rgba-lora"
wandb_run = wandb.init(project=tracker_name, config=vars(args))
# Resume from checkpoint if specified
if args.resume_from_checkpoint:
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
if "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer"])
if "lr_scheduler" in checkpoint:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # set_peft_model_state_dict(transformer, checkpoint["state_dict"])  # replaced by processor state dict loading below
processor_state_dict = checkpoint["state_dict"]
load_processor_state_dict(transformer, processor_state_dict)
print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
print(f"Resuming from global step: {global_step}")
else:
global_step = 0
print("===== Memory before training =====")
reset_memory("cuda")
print_memory("cuda")
# Train!
total_batch_size = args.train_batch_size
print("***** Running training *****")
print(f" Num trainable parameters = {num_trainable_parameters}")
print(f" Num examples = {len(train_dataset)}")
print(f" Num batches each epoch = {len(train_dataloader)}")
print(f" Num epochs = {args.num_train_epochs}")
print(f" Instantaneous batch size per device = {args.train_batch_size}")
print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f" Total optimization steps = {args.max_train_steps}")
first_epoch = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=global_step,
desc="Steps",
)
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
for step, batch in enumerate(train_dataloader):
with torch.no_grad():
z = batch["z"].to("cuda")
eps = batch["eps"].to("cuda")
sigma = batch["sigma"].to("cuda")
prompt_embeds = batch["prompt_embeds"].to("cuda")
prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")
all_attention_mask = prepare_attention_mask(
prompt_attention_mask=prompt_attention_mask,
latents=z
)
sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
ut = z - eps
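                # Note: the target z - eps is the negative of the usual flow
                # velocity (eps - z), matching Mochi's sign convention; see the
                # (1 - sigma) timestep flip referenced below.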
# (1 - sigma) because of
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
# Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
timesteps = (1 - sigma) * scheduler.config.num_train_timesteps
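                # e.g. with the default num_train_timesteps=1000, sigma=0.0
                # (clean latents) maps to timestep 1000 and sigma=1.0 (pure
                # noise) maps to timestep 0.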
with torch.autocast("cuda", torch.bfloat16):
model_pred = transformer(
hidden_states=z_sigma,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=all_attention_mask,
timestep=timesteps,
return_dict=False,
)[0]
assert model_pred.shape == z.shape
loss = F.mse_loss(model_pred.float(), ut.float())
loss.backward()
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
progress_bar.update(1)
global_step += 1
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
logs = {"loss": loss.detach().item(), "lr": last_lr}
progress_bar.set_postfix(**logs)
if wandb_run:
wandb_run.log(logs, step=global_step)
if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0:
print(f"Saving checkpoint at step {global_step}")
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
save_checkpoint(
transformer,
optimizer,
lr_scheduler,
global_step,
checkpoint_path,
)
# if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
print("===== Memory before validation =====")
print_memory("cuda")
transformer.eval()
pipe = MochiPipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=transformer,
scheduler=scheduler,
revision=args.revision,
variant=args.variant,
)
if args.enable_slicing:
pipe.vae.enable_slicing()
if args.enable_tiling:
pipe.vae.enable_tiling()
if args.enable_model_cpu_offload:
pipe.enable_model_cpu_offload()
# validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
validation_prompts = [
"A boy in a white shirt and shorts is seen bouncing a ball, isolated background",
]
for validation_prompt in validation_prompts:
pipeline_args = {
"prompt": validation_prompt,
"guidance_scale": 6.0,
"num_frames": 37,
"num_inference_steps": 64,
"height": args.height,
"width": args.width,
"max_sequence_length": 256,
}
log_validation(
pipe=pipe,
args=args,
pipeline_args=pipeline_args,
step=global_step,
wandb_run=wandb_run,
)
print("===== Memory after validation =====")
print_memory("cuda")
reset_memory("cuda")
del pipe.text_encoder
del pipe.vae
del pipe
gc.collect()
torch.cuda.empty_cache()
transformer.train()
if global_step >= args.max_train_steps:
break
if global_step >= args.max_train_steps:
break
transformer.eval()
# saving lora weights
# transformer_lora_layers = get_peft_model_state_dict(transformer)
# MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)
# Cleanup trained models to save memory
del transformer
gc.collect()
torch.cuda.empty_cache()
# Final test inference
# validation_outputs = []
# if args.validation_prompt and args.num_validation_videos > 0:
# print("===== Memory before testing =====")
# print_memory("cuda")
# reset_memory("cuda")
# pipe = MochiPipeline.from_pretrained(
# args.pretrained_model_name_or_path,
# revision=args.revision,
# variant=args.variant,
# )
# if args.enable_slicing:
# pipe.vae.enable_slicing()
# if args.enable_tiling:
# pipe.vae.enable_tiling()
# if args.enable_model_cpu_offload:
# pipe.enable_model_cpu_offload()
# # Load LoRA weights
# # lora_scaling = args.lora_alpha / args.rank
# # pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
# # pipe.set_adapters(["mochi-lora"], [lora_scaling])
# # Run inference
# validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
# for validation_prompt in validation_prompts:
# pipeline_args = {
# "prompt": validation_prompt,
# "guidance_scale": 6.0,
# "num_inference_steps": 64,
# "height": args.height,
# "width": args.width,
# "max_sequence_length": 256,
# }
# video = log_validation(
# pipe=pipe,
# args=args,
# pipeline_args=pipeline_args,
    #         step=global_step,
# wandb_run=wandb_run,
# is_final_validation=True,
# )
# validation_outputs.extend(video)
# print("===== Memory after testing =====")
# print_memory("cuda")
# reset_memory("cuda")
# torch.cuda.synchronize("cuda")
if __name__ == "__main__":
args = get_args()
main(args)