# KeyframesAI - main.py
import gc
import glob
import json
import logging
import math
import os
import uuid
from multiprocessing import Pool, Process, Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import gradio as gr
import numpy as np
import rembg
import requests
import spaces
from numba import cuda
from PIL import Image, ImageOps
from tqdm.auto import tqdm

import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torchvision import transforms

import transformers
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, Dinov2Model

import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDIMScheduler,
    DDPMScheduler,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
# In older diffusers releases this class lived at diffusers.models.controlnet
from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

from huggingface_hub import hf_hub_download, HfApi

from libs.easy_dwpose import DWposeDetector
from libs.easy_dwpose.draw import draw_openpose
from libs.film import Predictor

from src.configs.stage2_config import args
from src.dataset.stage2_dataset import InpaintDataset, InpaintCollate_fn
from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
from src.pipelines.PCDMs_pipeline import PCDMsPipeline
#from single_extract_pose import inference_pose
# Inputs ===================================================================================================
input_img = "sm.png"
train_imgs = ["target.png"]
in_vid = "walk.mp4"
out_vid = 'out'  # file extension (.mp4 / .webm) is appended where the video is written
"""
train_steps = 100
inference_steps = 10
fps = 12
"""
debug = False
save_model = True
should_gen_vid = False
max_batch_size = 8
max_frame_count = 200
no_bg_final = True
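
# A minimal end-to-end sketch of how the inputs above are consumed (assumes the sample
# files listed above exist locally; elsewhere in this file they are read from an imgs/
# directory). run() and gen_vid() are defined further down in this module:
#
#   source = Image.open(input_img)
#   targets = [Image.open(p) for p in train_imgs]
#   frames = run([source] + targets, in_vid, train_steps=100, inference_steps=10, fps=12)
#   gen_vid(frames, out_vid + '.webm', 12, 'webm')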
def save_temp_imgs(imgs):
os.makedirs('temp', exist_ok=True)
results = []
api = HfApi()
for i, img in enumerate(imgs):
#img_name = 'temp/'+str(uuid.uuid4())+'.png'
img_name = 'temp/'+str(i)+'.png'
img.save(img_name)
"""
url = 'https://tmpfiles.org/api/v1/upload'
try:
response = requests.post(url, files={'file': open(img_name, 'rb')})
# Check for successful response (status code 200)
response.raise_for_status()
# Print the server's response
print("Status Code:", response.status_code)
data = response.json()
print("Response JSON:", data)
results.append(data['data']['url'])
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
"""
results.append('https://huggingface.co/datasets/acmyu/KeyframesAIFiles/resolve/main/'+img_name)
    # Upload the whole temp/ folder in one call; upload_file expects a single file,
    # so upload_folder is used for the directory of saved frames.
    api.upload_folder(
        folder_path='temp',
        path_in_repo='temp',
        repo_id="acmyu/KeyframesAIFiles",
        repo_type="dataset",
    )
return results
def getThumbnails(imgs):
thumbs = []
thumb_size = (512, 512)
for img in imgs:
th = img.copy()
th.thumbnail(thumb_size)
thumbs.append(th)
return thumbs
# Pose detection ==============================================================================================
def load_models():
dwpose = DWposeDetector(device="cuda")
rembg_session = rembg.new_session("u2netp")
pcdms_model = hf_hub_download(repo_id="acmyu/PCDMs", filename="pcdms_ckpt.pt")
# Load scheduler
noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
# Load model
image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant')
image_encoder_g = CLIPVisionModelWithProjection.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')#("openai/clip-vit-base-patch32")
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae")
unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
torch_dtype=torch.float16,
subfolder="unet",
in_channels=9,
low_cpu_mem_usage=False,
ignore_mismatched_sizes=True)
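    # Returned components: DWPose for pose detection, a rembg session for background
    # removal, the downloaded PCDMs stage-2 checkpoint path, the DDPM noise scheduler,
    # DINOv2-giant (per-patch source-image encoder), CLIP ViT-H (global target-image
    # encoder), the SD 2.1 VAE, and the 9-channel inpainting UNet.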
return dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet
#load_models()
def img_pad(img, tw, th, transparent=False):
#print('pad', tw, th)
img.thumbnail((tw, th))
if transparent:
new_img = Image.new('RGBA', (tw, th), (0, 0, 0, 0))
else:
new_img = Image.new("RGB", (tw, th), (0, 0, 0))
left = (tw - img.width) // 2
top = (th - img.height) // 2
#print(left, top)
new_img.paste(img, (left, top))
return new_img
def resize_pad(img, tw, th, transparent):
w, h = img.size
orig_tw = tw
orig_th = th
if tw/th > w/h:
tw = int(th * w/h)
elif tw/th < w/h:
th = int(tw * h/w)
img = img.resize((tw, th), Image.BICUBIC)
    # Pad back out to the requested size, honouring the caller's transparency choice.
    return img_pad(img, orig_tw, orig_th, transparent)
def resize_and_pad(img, target_img):
tw, th = target_img.size
return resize_pad(img, tw, th, False)
def remove_zero_pad(image):
image = np.array(image)
    dummy = np.argwhere(image != 0)  # assume the background is zero (black)
max_y = dummy[:, 0].max()
min_y = dummy[:, 0].min()
min_x = dummy[:, 1].min()
max_x = dummy[:, 1].max()
    crop_image = image[min_y:max_y + 1, min_x:max_x + 1]  # +1 so the last row/column is kept
return Image.fromarray(crop_image)
def get_pose(img, dwpose, outfile, crop=False):
#pil_image = Image.open("imgs/"+img).convert("RGB")
#skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
img.thumbnail((512,512))
out_img, pose = dwpose(img, include_hands=True, include_face=False)
#print(pose['bodies'])
if crop:
bbox = out_img.getbbox()
out_img = out_img.crop(bbox)
out_img = ImageOps.expand(out_img, border=int(out_img.width*0.2), fill=(0,0,0))
return out_img, pose
def extract_frames(video_path, fps):
video_capture = cv2.VideoCapture(video_path)
frame_count = 0
frames = []
fps_in = video_capture.get(cv2.CAP_PROP_FPS)
fps_out = fps
index_in = -1
index_out = -1
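    # Rate conversion: a frame is retrieved only when its timestamp (index_in / fps_in)
    # advances the output index at fps_out, e.g. fps_in=24 with fps_out=12 keeps roughly
    # every other frame.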
while True:
success = video_capture.grab()
if not success: break
index_in += 1
if frame_count > max_frame_count:
break
out_due = int(index_in / fps_in * fps_out)
if out_due > index_out:
success, frame = video_capture.retrieve()
if not success:
break
index_out += 1
frame_count += 1
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
video_capture.release()
print(f"Extracted {frame_count} frames")
return frames
def removebg(img, rembg_session, transparent=False):
if transparent:
result = Image.new('RGBA', img.size, (0, 0, 0, 0))
else:
result = Image.new("RGB", img.size, "#ffffff")
out = rembg.remove(img, session=rembg_session)
result.paste(out, mask=out)
return result
def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
print("remove background", bg_remove)
if bg_remove:
images = [removebg(img, rembg_session) for img in images]
in_img = images[0]
in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
train_poses = []
train_imgs = [resize_and_pad(img, in_img) for img in images[1:]]
for i, img in enumerate(train_imgs):
train_pose, _ = get_pose(img, dwpose, "tr_pose"+str(i)+".png")
train_poses.append(train_pose)
return in_img, in_pose, train_imgs, train_poses
def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False, target_poses=None):
progress=gr.Progress(track_tqdm=True)
print("prepare_inputs_inference")
in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
print(in_vid)
print(frames)
if in_vid:
frames = extract_frames(in_vid, fps)
for f in frames:
f.thumbnail((512,512))
print("remove background", bg_remove)
if bg_remove:
in_img = removebg(in_img, rembg_session)
#frames = [removebg(img, rembg_session) for img in frames]
if debug:
for i, frame in enumerate(frames):
frame.save("out/frame_"+str(i)+".png")
print("vid: ", in_vid, fps)
progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
if not target_poses:
target_poses = []
target_poses_coords = []
max_left = max_top = 999999
max_right = max_bottom = 0
it = frames
if is_app:
it = progress.tqdm(frames, desc="Pose Detection")
for f in it:
tpose, tpose_coords = get_pose(f, dwpose, "tar_pose"+str(len(target_poses))+".png")
#print(tpose_coords)
coords = {}
for k in tpose_coords:
if k == 'bodies_multi':
coords['bodies'] = tpose_coords[k].tolist()
elif k in ['hands']:
coords[k] = tpose_coords[k].tolist()
elif k in ['num_candidates']:
coords[k] = tpose_coords[k]
#print(coords)
target_poses.append(tpose)
target_poses_coords.append(json.dumps(coords))
progress_bar.update(1)
target_poses_cropped = []
for tpose in target_poses:
if resize_inputs:
bbox = tpose.getbbox()
left, top, right, bottom = bbox
max_left = min(max_left, left)
max_top = min(max_top, top)
max_right = max(max_right, right)
max_bottom = max(max_bottom, bottom)
tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
tpose = resize_and_pad(tpose, in_img)
if debug:
tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
target_poses_cropped.append(tpose)
#target_poses_cropped[0].save("pose.png")
return in_img, target_poses_cropped, in_pose, target_poses_coords, frames
def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize_inputs, is_app=False):
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
in_img, target_poses_cropped, _, _, _ = prepare_inputs_inference(in_img, in_vid, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
# Training ===================================================================================================
# Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.18.0.dev0")
logger = get_logger(__name__)
class ImageProjModel_p(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, out_dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class ImageProjModel_g(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, out_dim),
nn.Dropout(dropout)
)
def forward(self, x): # b, 257,1280
return self.net(x)
class SDModel(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, unet) -> None:
super().__init__()
self.image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
self.unet = unet
self.pose_proj = ControlNetConditioningEmbedding(
conditioning_embedding_channels=320,
block_out_channels=(16, 32, 96, 256),
conditioning_channels=3)
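
    # Rough shapes flowing through forward() (a sketch, assuming the encoders loaded in
    # load_models; not a contract):
    #   simg_f_p: DINOv2-giant last_hidden_state, (B, tokens, 1536) -> projected to (B, tokens, 1024)
    #   timg_f_g: CLIP ViT-H image_embeds unsqueezed to (B, 1, 1024)
    #   encoder_image_hidden_states: concatenation of the two along the token axis
    #   pose_f: side-by-side source|target pose image -> ControlNet-style conditioning embedding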
def forward(self, noisy_latents, timesteps, simg_f_p, timg_f_g, pose_f):
extra_image_embeddings_p = self.image_proj_model_p(simg_f_p)
extra_image_embeddings_g = timg_f_g
print(extra_image_embeddings_p.size())
print(extra_image_embeddings_g.size())
encoder_image_hidden_states = torch.cat([extra_image_embeddings_p ,extra_image_embeddings_g], dim=1)
pose_cond = self.pose_proj(pose_f)
pred_noise = self.unet(noisy_latents, timesteps, class_labels=timg_f_g, encoder_hidden_states=encoder_image_hidden_states,my_pose_cond=pose_cond).sample
return pred_noise
def load_training_checkpoint(model, pcdms_model, tag=None, **kwargs):
#model_sd = torch.load(load_dir, map_location="cpu")["module"]
model_sd = torch.load(
pcdms_model,
map_location="cpu"
)["module"]
image_proj_model_dict = {}
pose_proj_dict = {}
unet_dict = {}
for k in model_sd.keys():
if k.startswith("pose_proj"):
pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
elif k.startswith("image_proj_model_p"):
image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
elif k.startswith("image_proj_model."):
image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
elif k.startswith("unet"):
unet_dict[k.replace("unet.", "")] = model_sd[k]
else:
print(k)
model.pose_proj.load_state_dict(pose_proj_dict)
model.image_proj_model_p.load_state_dict(image_proj_model_dict)
model.unet.load_state_dict(unet_dict)
return model, 0, 0
def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
"""Utility function for checkpointing model + optimizer dictionaries
The main purpose for this is to be able to resume training from that instant again
"""
checkpoint_state_dict = {
"epoch": epoch,
"last_global_step": last_global_step,
}
# Add extra kwargs too
checkpoint_state_dict.update(kwargs)
success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
if success:
logging.info(f"Success {status_msg}")
else:
logging.warning(f"Failure {status_msg}")
return
@spaces.GPU(duration=600)
def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune=True, is_app=False):
logging_dir = 'outputs/logging'
print('start train')
progress=gr.Progress(track_tqdm=True)
accelerator = Accelerator(
log_with=args.report_to,
project_dir=logging_dir,
mixed_precision=args.mixed_precision,
gradient_accumulation_steps=args.gradient_accumulation_steps
)
# Make one log on every process with the configuration for debugging.
#logging.basicConfig(
# format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
# datefmt="%m/%d/%Y %H:%M:%S",
# level=logging.INFO, )
print(accelerator.state)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
set_seed(42)
# Handle the repository creation
if accelerator.is_main_process:
os.makedirs('outputs', exist_ok=True)
"""
unet = Stage2_InapintUNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet",
in_channels=9, class_embed_type="projection" ,projection_class_embeddings_input_dim=1024,
low_cpu_mem_usage=False, ignore_mismatched_sizes=True)
"""
image_encoder_p.requires_grad_(False)
image_encoder_g.requires_grad_(False)
vae.requires_grad_(False)
sd_model = SDModel(unet=unet)
sd_model.train()
if args.gradient_checkpointing:
sd_model.enable_gradient_checkpointing()
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
learning_rate = 1e-4
    train_batch_size = min(len(train_images), max_batch_size)
# Optimizer creation
params_to_optimize = sd_model.parameters()
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
inputs = [{
"source_image": in_image,
"source_pose": in_pose,
"target_image": timg,
"target_pose": tpose,
} for timg, tpose in zip(train_images, train_poses)]
"""
inputs = {[
"source_image": Image.open('imgs/sm.png'),
"source_pose": Image.open('imgs/sm_pose.jpg'),
"target_image": Image.open('imgs/target.png'),
"target_pose": Image.open('imgs/target_pose.jpg'),
]}
"""
#print(inputs)
dataset = InpaintDataset(
inputs,
'imgs/',
size=(args.img_width, args.img_height), # w h
imgp_drop_rate=0.1,
imgg_drop_rate=0.1,
)
"""
dataset = InpaintDataset(
args.json_path,
args.image_root_path,
size=(args.img_width, args.img_height), # w h
imgp_drop_rate=0.1,
imgg_drop_rate=0.1,
)
"""
train_sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True)
train_dataloader = torch.utils.data.DataLoader(
dataset,
sampler=train_sampler,
collate_fn=InpaintCollate_fn,
batch_size=train_batch_size,
num_workers=0,)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
args.max_train_steps = train_steps
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
sd_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(sd_model, optimizer, train_dataloader, lr_scheduler)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
"""
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
"""
# Move vae, unet and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
sd_model.unet.to(accelerator.device, dtype=weight_dtype)
image_encoder_p.to(accelerator.device, dtype=weight_dtype)
image_encoder_g.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
args.num_train_epochs = train_steps
# Train!
total_batch_size = (
train_batch_size
* accelerator.num_processes
* args.gradient_accumulation_steps
)
print("***** Running training *****")
print(f" Num batches each epoch = {len(train_dataloader)}")
print(f" Num Epochs = {args.num_train_epochs}")
print(f" Instantaneous batch size per device = {train_batch_size}")
print(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f" Total optimization steps = {args.max_train_steps}")
if args.resume_from_checkpoint:
# New Code #
# Loads the DeepSpeed checkpoint from the specified path
prior_model, last_epoch, last_global_step = load_training_checkpoint(
sd_model,
pcdms_model,
**{"load_optimizer_states": True, "load_lr_scheduler_states": True},
)
print(f"Resumed from checkpoint: {args.resume_from_checkpoint}, global step: {last_global_step}")
starting_epoch = last_epoch
global_steps = last_global_step
else:
global_steps = 0
starting_epoch = 0
progress_bar = tqdm(range(global_steps, args.max_train_steps), initial=global_steps, desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process, )
bsz = train_batch_size
if not finetune or train_steps == 0:
accelerator.wait_for_everyone()
accelerator.end_training()
checkpoint_state_dict = {
"epoch": 0,
"module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
}
torch.save(checkpoint_state_dict, modelId+".pt")
del sd_model
gc.collect()
torch.cuda.empty_cache()
return
#return {k: v.cpu() for k, v in sd_model.state_dict().items()}
it = range(starting_epoch, args.num_train_epochs)
if is_app:
it = progress.tqdm(it, desc="Fine-tuning")
for epoch in it:
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(sd_model):
with torch.no_grad():
# Convert images to latent space
latents = vae.encode(batch["source_target_image"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Get the masked image latents
masked_latents = vae.encode(batch["vae_source_mask_image"].to(dtype=weight_dtype)).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
bsz = batch["target_image"].size(dim=0)
# mask
mask1 = torch.ones((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
mask0 = torch.zeros((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
mask = torch.cat([mask1, mask0], dim=3)
# Get the image embedding for conditioning
cond_image_feature_p = image_encoder_p(batch["source_image"].to(accelerator.device, dtype=weight_dtype))
cond_image_feature_p = (cond_image_feature_p.last_hidden_state)
cond_image_feature_g = image_encoder_g(batch["target_image"].to(accelerator.device, dtype=weight_dtype), ).image_embeds
cond_image_feature_g =cond_image_feature_g.unsqueeze(1)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)
# Sample a random timestep for each image
#timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (train_batch_size,),device=latents.device, )
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),device=latents.device, )
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
#print(noisy_latents.size(), mask.size(), masked_latents.size())
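                    # The UNet was loaded with in_channels=9: 4 noisy latent channels, 1 mask
                    # channel laid out to match the side-by-side source|target image (ones over
                    # the source half, zeros over the target half), and 4 masked-image latent channels.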
noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)
# Get the text embedding for conditioning
cond_pose = batch["source_target_pose"].to(dtype=weight_dtype)
#print(noisy_latents.size())
#print(cond_image_feature_p.size())
#print(cond_image_feature_g.size())
#print(cond_pose.size())
# Predict the noise residual
model_pred = sd_model(noisy_latents, timesteps, cond_image_feature_p,cond_image_feature_g, cond_pose, )
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = sd_model.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
global_steps += 1
if global_steps >= args.max_train_steps:
break
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
print(logs)
progress_bar.set_postfix(**logs)
progress_bar.update(1)
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
accelerator.end_training()
sd_model.unet.cpu()
sd_model.cpu()
del vae
del image_encoder_p
del image_encoder_g
if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
print('saving', modelId)
checkpoint_state_dict = {
"epoch": 0,
"module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
}
print(list(sd_model.state_dict().keys())[:20])
torch.save(checkpoint_state_dict, modelId+".pt")
del sd_model
gc.collect()
torch.cuda.empty_cache()
print('done train')
print(torch.cuda.memory_allocated()/1024**2)
return
    # Move weights to CPU before releasing the GPU copy, then return the state dict.
    state_dict_cpu = {k: v.cpu() for k, v in sd_model.state_dict().items()}
    del sd_model
    gc.collect()
    torch.cuda.empty_cache()
    return state_dict_cpu
# Pose-transfer ===================================================================================================
device = "cuda"
class ImageProjModel(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, out_dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
print(w, h)
grid = Image.new("RGB", size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def load_mydict(modelId, finetuned_model):
if save_model:
model_ckpt_path = modelId+'.pt'
model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
else:
model_sd = finetuned_model #torch.load(model_ckpt_path, map_location="cpu")["module"]
image_proj_model_dict = {}
pose_proj_dict = {}
unet_dict = {}
for k in model_sd.keys():
if k.startswith("pose_proj"):
pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
elif k.startswith("image_proj_model_p"):
image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
elif k.startswith("image_proj_model"):
image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
elif k.startswith("unet"):
unet_dict[k.replace("unet.", "")] = model_sd[k]
else:
print(k)
return image_proj_model_dict, pose_proj_dict, unet_dict
@spaces.GPU(duration=600)
def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder, is_app=False):
print('start inference')
progress=gr.Progress(track_tqdm=True)
if not save_model:
finetuned_model = {k: v.cuda() for k, v in finetuned_model.items()}
device = "cuda"
pretrained_model_name_or_path ="stabilityai/stable-diffusion-2-1-base"
image_encoder_path = "facebook/dinov2-giant"
#model_ckpt_path = "./pcdms_ckpt.pt" # ckpt path
model_ckpt_path = modelId+'.pt'
clip_image_processor = CLIPImageProcessor()
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
generator = torch.Generator(device=device).manual_seed(42)
"""
unet = Stage2_InapintUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16,subfolder="unet",in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path,subfolder="vae").to(device, dtype=torch.float16)
image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
"""
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
unet = unet.to(device, dtype=torch.float16)
vae = vae.to(device, dtype=torch.float16)
image_encoder = image_encoder.to(device, dtype=torch.float16)
image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)
pose_proj_model = ControlNetConditioningEmbedding(
conditioning_embedding_channels=320,
block_out_channels=(16, 32, 96, 256),
conditioning_channels=3).to(device).to(dtype=torch.float16)
# load weight
print('loading', modelId)
image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(modelId, finetuned_model)
print('loaded', modelId)
image_proj_model.load_state_dict(image_proj_model_dict)
pose_proj_model.load_state_dict(pose_proj_dict)
unet.load_state_dict(unet_dict)
pipe = PCDMsPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", unet=unet, torch_dtype=torch.float16, scheduler=noise_scheduler,feature_extractor=None,safety_checker=None).to(device)
print('====================== model load finish ===================')
results = []
progress_bar = tqdm(range(len(target_poses)), initial=0, desc="Frames")
it = target_poses
if is_app:
it = progress.tqdm(it, desc="Pose Transfer")
for pose in it:
num_samples = 1
image_size = (512, 512)
s_img_path = 'imgs/'+input_img # input image 1
#target_pose_img = 'imgs/pose_'+str(n)+'.png' # input image 2
#t_pose = inference_pose(target_pose_img, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
#t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)
t_pose = pose.convert("RGB").resize((image_size), Image.BICUBIC)
#t_pose = resize_and_pad(pose.convert("RGB"))
#s_img = Image.open(s_img_path)
width_orig, height_orig = in_image.size
s_img = in_image.convert("RGB").resize(image_size, Image.BICUBIC)
#s_img = resize_and_pad(in_image.convert("RGB"))
black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)
s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
s_img_t_mask.paste(s_img, (0, 0))
s_img_t_mask.paste(black_image, (s_img.width, 0))
#s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
#s_pose = Image.open('imgs/sm_pose.jpg').convert("RGB").resize(image_size, Image.BICUBIC)
s_pose = in_pose.convert("RGB").resize(image_size, Image.BICUBIC)
#s_pose = resize_and_pad(in_pose.convert("RGB"))
print('source image width: {}, height: {}'.format(s_pose.width, s_pose.height))
#t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)
st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
st_pose.paste(s_pose, (0, 0))
st_pose.paste(t_pose, (s_pose.width, 0))
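        # PCDMs stage 2 works on side-by-side canvases: the left half holds the source
        # image/pose and the right half the target pose (image region blacked out), so the
        # pipeline below generates at width*2 and the result is cropped from the right half.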
clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
mask1 = torch.ones((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
mask0 = torch.zeros((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
mask = torch.cat([mask1, mask0], dim=3)
with torch.inference_mode():
cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))
simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
simg_mask_latents = simg_mask_latents * 0.18215
images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
image_prompt_embeds = image_proj_model(images_embeds)
uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
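            # Classifier-free guidance: the unconditional branch reuses the projection of
            # zeroed image embeddings and is passed as negative_prompt_embeds; guidance_scale=2.0
            # in the pipe call below blends the two predictions.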
output, _ = pipe(
simg_mask_latents= simg_mask_latents,
mask = mask,
cond_pose = cond_pose,
prompt_embeds=image_prompt_embeds,
negative_prompt_embeds=uncond_image_prompt_embeds,
height=image_size[1],
width=image_size[0]*2,
num_images_per_prompt=num_samples,
guidance_scale=2.0,
generator=generator,
num_inference_steps=inference_steps,
)
output = output.images[-1]
result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))
result = result.resize((width_orig, height_orig), Image.BICUBIC)
#result = remove_zero_pad(result)
if debug:
result.save('out/'+str(len(results))+'.png')
results.append(result)
progress_bar.update(1)
del unet
del vae
del image_encoder
del image_proj_model
del pose_proj_model
if not save_model:
del finetuned_model
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated()/1024**2)
return results
def gen_vid(frames, video_name, fps, codec):
progress=gr.Progress(track_tqdm=True)
frame = cv2.cvtColor(np.array(frames[0]), cv2.COLOR_RGB2BGR)
height, width, layers = frame.shape
#video = cv2.VideoWriter(video_name, 0, 1, (width,height))
if codec == 'mp4':
video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
else:
video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'VP90'), fps, (width, height))
for r in progress.tqdm(frames, desc="Creating video"):
image = cv2.cvtColor(np.array(r), cv2.COLOR_RGB2BGR)
video.write(image)
    #cv2.destroyAllWindows()
    video.release()  # finalize the file so it is playable
def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, finetune=True, is_app=False):
print("==== Load Models ====")
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
print("==== Pose Detection ====")
in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize_inputs, is_app=is_app)
if save_model:
train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
        print("==== Pose Transfer ====")
results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
else:
print("==== Finetuning ====")
finetuned_model = train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
print("==== Pose Transfer ====")
results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder_p, is_app)
return results
def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True, finetune=True):
finetune=True
is_app=True
images = [img[0] for img in images]
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
if resize_inputs:
resize = 'target'
else:
resize = 'none'
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
gc.collect()
torch.cuda.empty_cache()
def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
run_train_impl(images, train_steps, modelId, bg_remove, resize_inputs)
"""
mp.set_start_method('spawn', force=True)
p = mp.Process(target=run_train_impl, args=(images, train_steps, modelId, bg_remove, resize_inputs))
p.start()
p.join()
"""
def run_inference_impl(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
finetune=True
is_app=True
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
if not os.path.exists(modelId+".pt"):
run_train(images, train_steps, modelId, bg_remove, resize_inputs)
images = [img[0] for img in images]
in_img = images[0]
if frames:
frames = [img[0] for img in frames]
in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
#target_poses[0].save('inf_pose.png')
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
#urls = save_temp_imgs(results)
if should_gen_vid:
if debug:
gen_vid(results, out_vid+'.mp4', fps, 'mp4')
else:
gen_vid(results, out_vid+'.webm', fps, 'webm')
# postprocessing
if no_bg_final:
results = [removebg(img, rembg_session, True) for img in results]
#results = [img_pad(img, img_width, img_height, True) for img in results]
print("Done!")
gc.collect()
torch.cuda.empty_cache()
return out_vid+'.webm', results, getThumbnails(results), target_poses_coords, orig_frames
def run_inference(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
return run_inference_impl(images, video_path, frames, train_steps, inference_steps, fps, modelId, img_width, img_height, bg_remove, resize_inputs)
def generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
finetune=True
is_app=True
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
if not os.path.exists(modelId+".pt"):
run_train(images, train_steps, modelId, bg_remove, resize_inputs)
images = [img[0] for img in images]
in_img = images[0]
in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
print(target_poses)
target_poses = json.loads(target_poses)
target_poses = [Image.fromarray(draw_openpose(pose, height=img_height, width=img_width, include_hands=True, include_face=False)) for pose in target_poses]
in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, None, [], 12, dwpose, rembg_session, bg_remove, resize_inputs, is_app, target_poses)
#target_poses[0].save('gen_pose.png')
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
#urls = save_temp_imgs(results)
# postprocessing
if no_bg_final:
results = [removebg(img, rembg_session, True) for img in results]
#results = [img_pad(img, img_width, img_height, True) for img in results]
print("Done!")
gc.collect()
torch.cuda.empty_cache()
results[0].save('result.png')
return results, getThumbnails(results)
def run_generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
return generate_frame(images, target_poses, train_steps, inference_steps, modelId, img_width, img_height, bg_remove, resize_inputs)
def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
images = [img[0] for img in images]
results = run(images, video_path, train_steps, inference_steps, fps, bg_remove, resize_inputs, finetune=True, is_app=True)
print("==== Video generation ====")
out_vid = f"out_{uuid.uuid4()}"
if debug:
gen_vid(results, out_vid+'.mp4', fps, 'mp4')
else:
gen_vid(results, out_vid+'.webm', fps, 'webm')
print("Done!")
return out_vid+'.webm', results
def run_eval(images_orig, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False):
is_app=False
dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
images = [img[0] for img in images_orig]
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
target_poses = target_poses[:max_frame_count]
#train_steps = 3
finetune = False
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
finetune = True
train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
gc.collect()
torch.cuda.empty_cache()
return results, results_base
@spaces.GPU(duration=30)
def interpolate_frames(frame1, frame2, times_to_interp, remove_bg):
film = Predictor()
film.setup()
thumb_size = (512, 512)
width, height = frame1.size
frame1.thumbnail(thumb_size)
frame2.thumbnail(thumb_size)
out_vid = film.predict(frame1, frame2, int(times_to_interp))
print(out_vid)
if str(out_vid).endswith('.mp4'):
results = extract_frames(out_vid, 30)
results = results[1:-1]
else:
results = [Image.open(out_vid)]
print(results)
if remove_bg:
rembg_session = rembg.new_session("u2netp")
results = [removebg(img, rembg_session, True) for img in results]
for r in results:
r.thumbnail((width, height))
del film
return results, getThumbnails(results)
def run_interpolate_frames(frame1, frame2, times_to_interp, remove_bg):
with Pool() as pool:
results = pool.starmap(interpolate_frames, [(frame1, frame2, times_to_interp, remove_bg)])
return results[0]
def resize_images(images, width, height):
images = [img[0] for img in images]
return [resize_pad(img, width, height, True) for img in images]