import logging
import math
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm
from src.configs.stage2_config import args
import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
)
# from diffusers.models.controlnet import ControlNetConditioningEmbedding  (old location)
from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding
from diffusers import UniPCMultistepScheduler
from diffusers import UNet2DConditionModel, ControlNetModel, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from src.dataset.stage2_dataset import InpaintDataset, InpaintCollate_fn
from transformers import CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor
from transformers import Dinov2Model
from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
from src.pipelines.PCDMs_pipeline import PCDMsPipeline
# from single_extract_pose import inference_pose
import glob
from PIL import Image, ImageOps
import numpy as np
from torchvision import transforms
import spaces
from libs.easy_dwpose import DWposeDetector
from libs.easy_dwpose.draw import draw_openpose
from libs.film import Predictor
import cv2
import gradio as gr
import rembg
import uuid
import gc
from numba import cuda
import requests
import json
from huggingface_hub import hf_hub_download, HfApi
from multiprocessing import Pool, Process, Queue
import torch.multiprocessing as mp


# Inputs ===================================================================================================
input_img = "sm.png"
train_imgs = ["target.png"]
in_vid = "walk.mp4"
out_vid = 'out.mp4'

"""
train_steps = 100
inference_steps = 10
fps = 12
"""

debug = False
save_model = True
should_gen_vid = False
max_batch_size = 8
max_frame_count = 200


def save_temp_imgs(imgs):
    os.makedirs('temp', exist_ok=True)
    results = []
    api = HfApi()
    for i, img in enumerate(imgs):
        # img_name = 'temp/' + str(uuid.uuid4()) + '.png'
        img_name = 'temp/' + str(i) + '.png'
        img.save(img_name)
        """
        url = 'https://tmpfiles.org/api/v1/upload'
        try:
            response = requests.post(url, files={'file': open(img_name, 'rb')})
            # Check for successful response (status code 200)
            response.raise_for_status()
            # Print the server's response
            print("Status Code:", response.status_code)
            data = response.json()
            print("Response JSON:", data)
            results.append(data['data']['url'])
        except requests.exceptions.RequestException as e:
            print(f"An error occurred: {e}")
        """
        results.append('https://huggingface.co/datasets/acmyu/KeyframesAIFiles/resolve/main/' + img_name)
    # Upload the whole temp/ folder in one call; upload_file cannot take a directory.
    api.upload_folder(
        folder_path='temp',
        path_in_repo='temp',
        repo_id="acmyu/KeyframesAIFiles",
        repo_type="dataset",
    )
    return results


def getThumbnails(imgs):
    thumbs = []
    thumb_size = (512, 512)
    for img in imgs:
        th = img.copy()
        th.thumbnail(thumb_size)
        thumbs.append(th)
    return thumbs

# Pose detection ==============================================================================================

def load_models():
    dwpose = DWposeDetector(device="cuda")
    rembg_session = rembg.new_session("u2netp")
    pcdms_model = hf_hub_download(repo_id="acmyu/PCDMs", filename="pcdms_ckpt.pt")

    # Load scheduler
    noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
    # Load models
    image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant')
    image_encoder_g = CLIPVisionModelWithProjection.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')  # ("openai/clip-vit-base-patch32")
    vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae")
    unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        torch_dtype=torch.float16,
        subfolder="unet",
        in_channels=9,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True)

    return dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet

# load_models()


def img_pad(img, tw, th, transparent=False):
    # print('pad', tw, th)
    img.thumbnail((tw, th))
    if transparent:
        new_img = Image.new('RGBA', (tw, th), (0, 0, 0, 0))
    else:
        new_img = Image.new("RGB", (tw, th), (0, 0, 0))
    left = (tw - img.width) // 2
    top = (th - img.height) // 2
    # print(left, top)
    new_img.paste(img, (left, top))
    return new_img


def resize_pad(img, tw, th, transparent):
    w, h = img.size
    orig_tw = tw
    orig_th = th
    if tw / th > w / h:
        tw = int(th * w / h)
    elif tw / th < w / h:
        th = int(tw * h / w)
    img = img.resize((tw, th), Image.BICUBIC)
    return img_pad(img, orig_tw, orig_th, True)


def resize_and_pad(img, target_img):
    tw, th = target_img.size
    return resize_pad(img, tw, th, False)


def remove_zero_pad(image):
    image = np.array(image)
    dummy = np.argwhere(image != 0)  # assume background is zero
    max_y = dummy[:, 0].max()
    min_y = dummy[:, 0].min()
    min_x = dummy[:, 1].min()
    max_x = dummy[:, 1].max()
    crop_image = image[min_y:max_y, min_x:max_x]
    return Image.fromarray(crop_image)


def get_pose(img, dwpose, outfile, crop=False):
    # pil_image = Image.open("imgs/" + img).convert("RGB")
    # skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
    img.thumbnail((512, 512))
    out_img, pose = dwpose(img, include_hands=True, include_face=False)
    # print(pose['bodies'])
    if crop:
        bbox = out_img.getbbox()
        out_img = out_img.crop(bbox)
        out_img = ImageOps.expand(out_img, border=int(out_img.width * 0.2), fill=(0, 0, 0))
    return out_img, pose


def extract_frames(video_path, fps):
    video_capture = cv2.VideoCapture(video_path)
    frame_count = 0
    frames = []
    fps_in = video_capture.get(cv2.CAP_PROP_FPS)
    fps_out = fps
    index_in = -1
    index_out = -1
    while True:
        success = video_capture.grab()
        if not success:
            break
        index_in += 1
        if frame_count > max_frame_count:
            break
        out_due = int(index_in / fps_in * fps_out)
        if out_due > index_out:
            success, frame = video_capture.retrieve()
            if not success:
                break
            index_out += 1
            frame_count += 1
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video_capture.release()
    print(f"Extracted {frame_count} frames")
    return frames


def removebg(img, rembg_session, transparent=False):
    if transparent:
        result = Image.new('RGBA', img.size, (0, 0, 0, 0))
    else:
        result = Image.new("RGB", img.size, "#ffffff")
    out = rembg.remove(img, session=rembg_session)
    result.paste(out, mask=out)
    return result


def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
    print("remove background", bg_remove)
    if bg_remove:
        images = [removebg(img, rembg_session) for img in images]
    in_img = images[0]
    in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
    train_poses = []
    train_imgs = [resize_and_pad(img, in_img) for img in images[1:]]
    for i, img in enumerate(train_imgs):
        train_pose, _ = get_pose(img, dwpose, "tr_pose" + str(i) + ".png")
        train_poses.append(train_pose)
    return in_img, in_pose, train_imgs, train_poses


def prepare_inputs_inference(in_img, in_vid, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app=False, target_poses=None):
    progress = gr.Progress(track_tqdm=True)
    print("prepare_inputs_inference")
    in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")
    print(in_vid)
    print(frames)
    if in_vid:
        frames = extract_frames(in_vid, fps)
    for f in frames:
        f.thumbnail((512, 512))
    print("remove background", bg_remove)
    if bg_remove:
        in_img = removebg(in_img, rembg_session)
        # frames = [removebg(img, rembg_session) for img in frames]
    if debug:
        for i, frame in enumerate(frames):
            frame.save("out/frame_" + str(i) + ".png")
    print("vid: ", in_vid, fps)

    progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
    # The crop bounds and coords list are used below even when target_poses are supplied,
    # so initialize them unconditionally.
    target_poses_coords = []
    max_left = max_top = 999999
    max_right = max_bottom = 0
    if not target_poses:
        target_poses = []
        it = frames
        if is_app:
            it = progress.tqdm(frames, desc="Pose Detection")
        for f in it:
            tpose, tpose_coords = get_pose(f, dwpose, "tar_pose" + str(len(target_poses)) + ".png")
            # print(tpose_coords)
            coords = {}
            for k in tpose_coords:
                if k == 'bodies_multi':
                    coords['bodies'] = tpose_coords[k].tolist()
                elif k in ['hands']:
                    coords[k] = tpose_coords[k].tolist()
                elif k in ['num_candidates']:
                    coords[k] = tpose_coords[k]
            # print(coords)
            target_poses.append(tpose)
            target_poses_coords.append(json.dumps(coords))
            progress_bar.update(1)

    target_poses_cropped = []
    for tpose in target_poses:
        if resize_inputs:
            bbox = tpose.getbbox()
            left, top, right, bottom = bbox
            max_left = min(max_left, left)
            max_top = min(max_top, top)
            max_right = max(max_right, right)
            max_bottom = max(max_bottom, bottom)
            tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
            tpose = ImageOps.expand(tpose, border=int(tpose.width * 0.2), fill=(0, 0, 0))
        tpose = resize_and_pad(tpose, in_img)
        if debug:
            tpose.save("out/" + "tar_pose" + str(len(target_poses_cropped)) + ".png")
        target_poses_cropped.append(tpose)
    # target_poses_cropped[0].save("pose.png")
    return in_img, target_poses_cropped, in_pose, target_poses_coords, frames


def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize_inputs, is_app=False):
    in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
    in_img, target_poses_cropped, _, _, _ = prepare_inputs_inference(in_img, in_vid, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
    return in_img, in_pose, train_imgs, train_poses, target_poses_cropped


# Training ===================================================================================================

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.18.0.dev0")

logger = get_logger(__name__)


class ImageProjModel_p(torch.nn.Module):
    """Projects DINOv2 patch features into the SD prompt space"""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class ImageProjModel_g(torch.nn.Module):
    """Projects the global (CLIP) image embedding into the SD prompt space"""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # b, 257, 1280
        return self.net(x)


class SDModel(torch.nn.Module):
    """SD model with image prompt"""

    def __init__(self, unet) -> None:
        super().__init__()
        self.image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
        self.unet = unet
        self.pose_proj = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=320,
            block_out_channels=(16, 32, 96, 256),
            conditioning_channels=3)

    def forward(self, noisy_latents, timesteps, simg_f_p, timg_f_g, pose_f):
        extra_image_embeddings_p = self.image_proj_model_p(simg_f_p)
        extra_image_embeddings_g = timg_f_g
        print(extra_image_embeddings_p.size())
        print(extra_image_embeddings_g.size())
        encoder_image_hidden_states = torch.cat([extra_image_embeddings_p, extra_image_embeddings_g], dim=1)
        pose_cond = self.pose_proj(pose_f)

        pred_noise = self.unet(noisy_latents, timesteps,
                               class_labels=timg_f_g,
                               encoder_hidden_states=encoder_image_hidden_states,
                               my_pose_cond=pose_cond).sample
        return pred_noise


def load_training_checkpoint(model, pcdms_model, tag=None, **kwargs):
    # model_sd = torch.load(load_dir, map_location="cpu")["module"]
    model_sd = torch.load(pcdms_model, map_location="cpu")["module"]

    image_proj_model_dict = {}
    pose_proj_dict = {}
    unet_dict = {}
    for k in model_sd.keys():
        if k.startswith("pose_proj"):
            pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
        elif k.startswith("image_proj_model_p"):
            image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
        elif k.startswith("image_proj_model."):
            image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = model_sd[k]
        else:
            print(k)

    model.pose_proj.load_state_dict(pose_proj_dict)
    model.image_proj_model_p.load_state_dict(image_proj_model_dict)
    model.unet.load_state_dict(unet_dict)

    return model, 0, 0


def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
    """Utility function for checkpointing model + optimizer dictionaries
    The main purpose for this is to be able to resume training from that instant again
    """
    checkpoint_state_dict = {
        "epoch": epoch,
        "last_global_step": last_global_step,
    }
    # Add extra kwargs too
    checkpoint_state_dict.update(kwargs)

    success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
    status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
    if success:
        logging.info(f"Success {status_msg}")
    else:
        logging.warning(f"Failure {status_msg}")
    return


@spaces.GPU(duration=600)
def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune=True, is_app=False):
    logging_dir = 'outputs/logging'
    print('start train')
    progress = gr.Progress(track_tqdm=True)

    accelerator = Accelerator(
        log_with=args.report_to,
        project_dir=logging_dir,
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.gradient_accumulation_steps
    )

    # Make one log on every process with the configuration for debugging.
    # logging.basicConfig(
    #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    #     datefmt="%m/%d/%Y %H:%M:%S",
    #     level=logging.INFO,
    # )
    print(accelerator.state)

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    set_seed(42)

    # Handle the repository creation
    if accelerator.is_main_process:
        os.makedirs('outputs', exist_ok=True)

    """
    unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", subfolder="unet", in_channels=9,
        class_embed_type="projection", projection_class_embeddings_input_dim=1024,
        low_cpu_mem_usage=False, ignore_mismatched_sizes=True)
    """

    image_encoder_p.requires_grad_(False)
    image_encoder_g.requires_grad_(False)
    vae.requires_grad_(False)

    sd_model = SDModel(unet=unet)
    sd_model.train()

    if args.gradient_checkpointing:
        sd_model.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    learning_rate = 1e-4
    train_batch_size = min(len(train_images), max_batch_size)  # len(train_images) % 16

    # Optimizer creation
    params_to_optimize = sd_model.parameters()
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    inputs = [{
        "source_image": in_image,
        "source_pose": in_pose,
        "target_image": timg,
        "target_pose": tpose,
    } for timg, tpose in zip(train_images, train_poses)]
    """
    inputs = [{
        "source_image": Image.open('imgs/sm.png'),
        "source_pose": Image.open('imgs/sm_pose.jpg'),
        "target_image": Image.open('imgs/target.png'),
        "target_pose": Image.open('imgs/target_pose.jpg'),
    }]
    """
    # print(inputs)

    dataset = InpaintDataset(
        inputs,
        'imgs/',
        size=(args.img_width, args.img_height),  # w h
        imgp_drop_rate=0.1,
        imgg_drop_rate=0.1,
    )
    """
    dataset = InpaintDataset(
        args.json_path,
        args.image_root_path,
        size=(args.img_width, args.img_height),  # w h
        imgp_drop_rate=0.1,
        imgg_drop_rate=0.1,
    )
    """

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=train_sampler,
        collate_fn=InpaintCollate_fn,
        batch_size=train_batch_size,
        num_workers=0,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    args.max_train_steps = train_steps

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    sd_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(sd_model, optimizer, train_dataloader, lr_scheduler)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    """
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    """

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    sd_model.unet.to(accelerator.device, dtype=weight_dtype)
    image_encoder_p.to(accelerator.device, dtype=weight_dtype)
    image_encoder_g.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    args.num_train_epochs = train_steps

    # Train!
    total_batch_size = (
        train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    )
    print("***** Running training *****")
    print(f"  Num batches each epoch = {len(train_dataloader)}")
    print(f"  Num Epochs = {args.num_train_epochs}")
    print(f"  Instantaneous batch size per device = {train_batch_size}")
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    print(f"  Total optimization steps = {args.max_train_steps}")

    if args.resume_from_checkpoint:
        # New Code #
        # Loads the DeepSpeed checkpoint from the specified path
        prior_model, last_epoch, last_global_step = load_training_checkpoint(
            sd_model,
            pcdms_model,
            **{"load_optimizer_states": True, "load_lr_scheduler_states": True},
        )
        print(f"Resumed from checkpoint: {args.resume_from_checkpoint}, global step: {last_global_step}")
        starting_epoch = last_epoch
        global_steps = last_global_step
    else:
        global_steps = 0
        starting_epoch = 0

    progress_bar = tqdm(
        range(global_steps, args.max_train_steps),
        initial=global_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    bsz = train_batch_size

    if not finetune or train_steps == 0:
        accelerator.wait_for_everyone()
        accelerator.end_training()
        checkpoint_state_dict = {
            "epoch": 0,
            "module": {k: v.cpu() for k, v in sd_model.state_dict().items()},  # sd_model.state_dict(),
        }
        torch.save(checkpoint_state_dict, modelId + ".pt")
        del sd_model
        gc.collect()
        torch.cuda.empty_cache()
        return
        # return {k: v.cpu() for k, v in sd_model.state_dict().items()}

    it = range(starting_epoch, args.num_train_epochs)
    if is_app:
        it = progress.tqdm(it, desc="Fine-tuning")
    for epoch in it:
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(sd_model):
                with torch.no_grad():
                    # Convert images to latent space
                    latents = vae.encode(batch["source_target_image"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                    # Get the masked image latents
                    masked_latents = vae.encode(batch["vae_source_mask_image"].to(dtype=weight_dtype)).latent_dist.sample()
                    masked_latents = masked_latents * vae.config.scaling_factor

                    bsz = batch["target_image"].size(dim=0)

                    # mask
                    mask1 = torch.ones((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
                    mask0 = torch.zeros((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
                    mask = torch.cat([mask1, mask0], dim=3)

                    # Get the image embedding for conditioning
                    cond_image_feature_p = image_encoder_p(batch["source_image"].to(accelerator.device, dtype=weight_dtype))
                    cond_image_feature_p = cond_image_feature_p.last_hidden_state

                    cond_image_feature_g = image_encoder_g(batch["target_image"].to(accelerator.device, dtype=weight_dtype)).image_embeds
                    cond_image_feature_g = cond_image_feature_g.unsqueeze(1)

                    # Sample noise that we'll add to the latents
                    noise = torch.randn_like(latents)
                    if args.noise_offset:
                        # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                        noise += args.noise_offset * torch.randn(
                            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                        )

                    # Sample a random timestep for each image
                    # timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (train_batch_size,), device=latents.device)
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                    # print(noisy_latents.size(), mask.size(), masked_latents.size())
                    noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)

                    # Get the pose embedding for conditioning
                    cond_pose = batch["source_target_pose"].to(dtype=weight_dtype)

                # print(noisy_latents.size())
                # print(cond_image_feature_p.size())
                # print(cond_image_feature_g.size())
                # print(cond_pose.size())

                # Predict the noise residual
                model_pred = sd_model(noisy_latents, timesteps, cond_image_feature_p, cond_image_feature_g, cond_pose)

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = sd_model.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                global_steps += 1
                if global_steps >= args.max_train_steps:
                    break

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            print(logs)
            progress_bar.set_postfix(**logs)
            progress_bar.update(1)

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()

    sd_model.unet.cpu()
    sd_model.cpu()
    del vae
    del image_encoder_p
    del image_encoder_g

    if save_model:
        # if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
        print('saving', modelId)
        checkpoint_state_dict = {
            "epoch": 0,
            "module": {k: v.cpu() for k, v in sd_model.state_dict().items()},  # sd_model.state_dict(),
        }
        print(list(sd_model.state_dict().keys())[:20])
        torch.save(checkpoint_state_dict, modelId + ".pt")
        del sd_model
        gc.collect()
        torch.cuda.empty_cache()
        print('done train')
        print(torch.cuda.memory_allocated() / 1024 ** 2)
        return

    # Grab the weights before releasing the model; deleting sd_model first and then
    # reading its state_dict (as the original did) would raise a NameError.
    state_dict_cpu = {k: v.cpu() for k, v in sd_model.state_dict().items()}
    del sd_model
    gc.collect()
    torch.cuda.empty_cache()
    return state_dict_cpu


# Pose-transfer ===================================================================================================

device = "cuda"


class ImageProjModel(torch.nn.Module):
    """Image feature projection used at inference time"""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    print(w, h)
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def load_mydict(modelId, finetuned_model):
    if save_model:
        model_ckpt_path = modelId + '.pt'
        model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
    else:
        model_sd = finetuned_model  # torch.load(model_ckpt_path, map_location="cpu")["module"]

    image_proj_model_dict = {}
    pose_proj_dict = {}
    unet_dict = {}
    for k in model_sd.keys():
        if k.startswith("pose_proj"):
            pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
        elif k.startswith("image_proj_model_p"):
            image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
        elif k.startswith("image_proj_model"):
            image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = model_sd[k]
        else:
            print(k)
    return image_proj_model_dict, pose_proj_dict, unet_dict


@spaces.GPU(duration=600)
def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder, is_app=False):
    print('start inference')
    progress = gr.Progress(track_tqdm=True)

    if not save_model:
        finetuned_model = {k: v.cuda() for k, v in finetuned_model.items()}

    device = "cuda"
    pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
    image_encoder_path = "facebook/dinov2-giant"
    # model_ckpt_path = "./pcdms_ckpt.pt"  # ckpt path
    model_ckpt_path = modelId + '.pt'

    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    generator = torch.Generator(device=device).manual_seed(42)
    """
    unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, torch_dtype=torch.float16, subfolder="unet", in_channels=9,
        low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device, dtype=torch.float16)
    image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
    """

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )

    unet = unet.to(device, dtype=torch.float16)
    vae = vae.to(device, dtype=torch.float16)
    image_encoder = image_encoder.to(device, dtype=torch.float16)

    image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)

    pose_proj_model = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=320,
        block_out_channels=(16, 32, 96, 256),
        conditioning_channels=3).to(device).to(dtype=torch.float16)

    # load weights
    print('loading', modelId)
    image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(modelId, finetuned_model)
    print('loaded', modelId)

    image_proj_model.load_state_dict(image_proj_model_dict)
    pose_proj_model.load_state_dict(pose_proj_dict)
    unet.load_state_dict(unet_dict)

    pipe = PCDMsPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        unet=unet,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        feature_extractor=None,
        safety_checker=None).to(device)
    print('====================== model load finish ===================')

    results = []
    progress_bar = tqdm(range(len(target_poses)), initial=0, desc="Frames")
    it = target_poses
    if is_app:
        it = progress.tqdm(it, desc="Pose Transfer")
    for pose in it:
        num_samples = 1
        image_size = (512, 512)
        s_img_path = 'imgs/' + input_img  # input image 1
        # target_pose_img = 'imgs/pose_' + str(n) + '.png'  # input image 2
        # t_pose = inference_pose(target_pose_img, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
        # t_pose = Image.open(target_pose_img).convert("RGB").resize(image_size, Image.BICUBIC)
        t_pose = pose.convert("RGB").resize(image_size, Image.BICUBIC)
        # t_pose = resize_and_pad(pose.convert("RGB"))

        # s_img = Image.open(s_img_path)
        width_orig, height_orig = in_image.size
        s_img = in_image.convert("RGB").resize(image_size, Image.BICUBIC)
        # s_img = resize_and_pad(in_image.convert("RGB"))
        black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)

        s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
        s_img_t_mask.paste(s_img, (0, 0))
        s_img_t_mask.paste(black_image, (s_img.width, 0))

        # s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
        # s_pose = Image.open('imgs/sm_pose.jpg').convert("RGB").resize(image_size, Image.BICUBIC)
        s_pose = in_pose.convert("RGB").resize(image_size, Image.BICUBIC)
        # s_pose = resize_and_pad(in_pose.convert("RGB"))
        print('source image width: {}, height: {}'.format(s_pose.width, s_pose.height))

        # t_pose = Image.open(target_pose_img).convert("RGB").resize(image_size, Image.BICUBIC)
        st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
        st_pose.paste(s_pose, (0, 0))
        st_pose.paste(t_pose, (s_pose.width, 0))

        clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
        vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
        cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
        mask1 = torch.ones((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
        mask0 = torch.zeros((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
        mask = torch.cat([mask1, mask0], dim=3)

        with torch.inference_mode():
            cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))

            simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
            simg_mask_latents = simg_mask_latents * 0.18215

            images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
            image_prompt_embeds = image_proj_model(images_embeds)
            uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))

            bs_embed, seq_len, _ = image_prompt_embeds.shape
            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        output, _ = pipe(
            simg_mask_latents=simg_mask_latents,
            mask=mask,
            cond_pose=cond_pose,
            prompt_embeds=image_prompt_embeds,
            negative_prompt_embeds=uncond_image_prompt_embeds,
            height=image_size[1],
            width=image_size[0] * 2,
            num_images_per_prompt=num_samples,
            guidance_scale=2.0,
            generator=generator,
            num_inference_steps=inference_steps,
        )
        output = output.images[-1]

        result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))
        result = result.resize((width_orig, height_orig), Image.BICUBIC)
        # result = remove_zero_pad(result)
        if debug:
            result.save('out/' + str(len(results)) + '.png')
        results.append(result)
        progress_bar.update(1)

    del unet
    del vae
    del image_encoder
    del image_proj_model
    del pose_proj_model
    if not save_model:
        del finetuned_model
    gc.collect()
    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated() / 1024 ** 2)
    return results


def gen_vid(frames, video_name, fps, codec):
    progress = gr.Progress(track_tqdm=True)
    frame = cv2.cvtColor(np.array(frames[0]), cv2.COLOR_RGB2BGR)
    height, width, layers = frame.shape

    # video = cv2.VideoWriter(video_name, 0, 1, (width, height))
    if codec == 'mp4':
        video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    else:
        video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'VP90'), fps, (width, height))

    for r in progress.tqdm(frames, desc="Creating video"):
        image = cv2.cvtColor(np.array(r), cv2.COLOR_RGB2BGR)
        video.write(image)

    # cv2.destroyAllWindows()
    # Release the writer so the container is finalized and the file is playable.
    video.release()


def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, finetune=True, is_app=False):
    print("==== Load Models ====")
    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    print("==== Pose Detection ====")
    in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize_inputs, is_app=is_app)

    if save_model:
        train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
        print('next')
        results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
    else:
        print("==== Finetuning ====")
        finetuned_model = train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
        print("==== Pose Transfer ====")
        results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder_p, is_app)

    return results


def run_train_impl(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True, finetune=True):
    finetune = True
    is_app = True
    images = [img[0] for img in images]
    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    if resize_inputs:
        resize = 'target'
    else:
        resize = 'none'

    in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
    train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)

    gc.collect()
    torch.cuda.empty_cache()


def run_train(images, train_steps=100, modelId="fine_tuned_pcdms", bg_remove=True, resize_inputs=True):
    run_train_impl(images, train_steps, modelId, bg_remove, resize_inputs)
    """
    mp.set_start_method('spawn', force=True)
    p = mp.Process(target=run_train_impl, args=(images, train_steps, modelId, bg_remove, resize_inputs))
    p.start()
    p.join()
    """


def run_inference_impl(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
    finetune = True
    is_app = True
    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    if not os.path.exists(modelId + ".pt"):
        run_train(images, train_steps, modelId, bg_remove, resize_inputs)

    images = [img[0] for img in images]
    in_img = images[0]
    if frames:
        frames = [img[0] for img in frames]

    in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, video_path, frames, fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
    # target_poses[0].save('inf_pose.png')

    results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
    # urls = save_temp_imgs(results)

    if should_gen_vid:
        if debug:
            gen_vid(results, out_vid + '.mp4', fps, 'mp4')
        else:
            gen_vid(results, out_vid + '.webm', fps, 'webm')

    # postprocessing
    results = [removebg(img, rembg_session, True) for img in results]
    # results = [img_pad(img, img_width, img_height, True) for img in results]

    print("Done!")
    gc.collect()
    torch.cuda.empty_cache()
    return out_vid + '.webm', results, getThumbnails(results), target_poses_coords, orig_frames


def run_inference(images, video_path, frames, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
    return run_inference_impl(images, video_path, frames, train_steps, inference_steps, fps, modelId, img_width, img_height, bg_remove, resize_inputs)


def generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
    finetune = True
    is_app = True
    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    if not os.path.exists(modelId + ".pt"):
        run_train(images, train_steps, modelId, bg_remove, resize_inputs)

    images = [img[0] for img in images]
    in_img = images[0]
    in_pose, _ = get_pose(in_img, dwpose, "in_pose.png")

    print(target_poses)
    target_poses = json.loads(target_poses)
    target_poses = [Image.fromarray(draw_openpose(pose, height=img_height, width=img_width, include_hands=True, include_face=False)) for pose in target_poses]
    in_img, target_poses, in_pose, target_poses_coords, orig_frames = prepare_inputs_inference(in_img, None, [], 12, dwpose, rembg_session, bg_remove, resize_inputs, is_app, target_poses)
    # target_poses[0].save('gen_pose.png')

    results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
    # urls = save_temp_imgs(results)

    # postprocessing
    results = [removebg(img, rembg_session, True) for img in results]
    # results = [img_pad(img, img_width, img_height, True) for img in results]

    print("Done!")
    gc.collect()
    torch.cuda.empty_cache()
    results[0].save('result.png')
    return results, getThumbnails(results)


def run_generate_frame(images, target_poses, train_steps=100, inference_steps=10, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=True, resize_inputs=True):
    return generate_frame(images, target_poses, train_steps, inference_steps, modelId, img_width, img_height, bg_remove, resize_inputs)


def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
    images = [img[0] for img in images]
    results = run(images, video_path, train_steps, inference_steps, fps, bg_remove, resize_inputs, finetune=True, is_app=True)

    print("==== Video generation ====")
    out_vid = f"out_{uuid.uuid4()}"
    if debug:
        gen_vid(results, out_vid + '.mp4', fps, 'mp4')
    else:
        gen_vid(results, out_vid + '.webm', fps, 'webm')
    print("Done!")
    return out_vid + '.webm', results


def run_eval(images_orig, video_path, train_steps=100, inference_steps=10, fps=12, modelId="fine_tuned_pcdms", img_width=1920, img_height=1080, bg_remove=False, resize_inputs=False):
    is_app = False
    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    images = [img[0] for img in images_orig]
    in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
    in_img, target_poses, in_pose, _, _ = prepare_inputs_inference(in_img, video_path, [], fps, dwpose, rembg_session, bg_remove, resize_inputs, is_app)
    target_poses = target_poses[:max_frame_count]

    # train_steps = 3
    finetune = False
    train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
    results_base = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)

    finetune = True
    train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
    results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)

    gc.collect()
    torch.cuda.empty_cache()
    return results, results_base


@spaces.GPU(duration=30)
def interpolate_frames(frame1, frame2, times_to_interp, remove_bg):
    film = Predictor()
    film.setup()

    thumb_size = (512, 512)
    width, height = frame1.size
    frame1.thumbnail(thumb_size)
    frame2.thumbnail(thumb_size)

    out_vid = film.predict(frame1, frame2, int(times_to_interp))
    print(out_vid)
    if str(out_vid).endswith('.mp4'):
        results = extract_frames(out_vid, 30)
        results = results[1:-1]
    else:
        results = [Image.open(out_vid)]
    print(results)

    if remove_bg:
        rembg_session = rembg.new_session("u2netp")
        results = [removebg(img, rembg_session, True) for img in results]
    for r in results:
        r.thumbnail((width, height))
    del film
    return results, getThumbnails(results)


def run_interpolate_frames(frame1, frame2, times_to_interp, remove_bg):
    with Pool() as pool:
        results = pool.starmap(interpolate_frames, [(frame1, frame2, times_to_interp, remove_bg)])
    return results[0]


def resize_images(images, width, height):
    images = [img[0] for img in images]
    return [resize_pad(img, width, height, True) for img in images]
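

# Minimal local-run sketch (an assumption, not part of the original app entry points):
# it wires the module-level defaults (input_img, train_imgs, in_vid, out_vid, should_gen_vid)
# into run() and gen_vid(). The 'imgs/' prefix mirrors the path used in inference(), and the
# step/fps values are illustrative rather than the author's settings.
if __name__ == "__main__":
    # run() expects a flat list of PIL images: the source image first, then the fine-tuning targets.
    source = Image.open('imgs/' + input_img).convert("RGB")
    targets = [Image.open('imgs/' + name).convert("RGB") for name in train_imgs]

    frames_out = run([source] + targets, in_vid, train_steps=100, inference_steps=10, fps=12,
                     bg_remove=False, resize_inputs=True, finetune=True, is_app=False)

    if should_gen_vid:
        gen_vid(frames_out, out_vid, 12, 'mp4')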