"""Apply AdaIN style transfer to every frame of a video and write the stylized video."""
import os
import argparse
import torch
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from utils import transform, adaptive_instance_normalization, linear_histogram_matching, Range
import cv2
import imageio
import numpy as np
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--content_video', type=str, required=True, help='Content video file path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file path')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)],
                    help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()

# Fall back to CPU when --cuda is requested but no CUDA device is available.
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given content image and style image, generate feature maps with encoder,
    apply neural style transfer with adaptive instance normalization, and
    generate the output image with decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the stylized features; the
            decoder input is alpha * stylized + (1 - alpha) * content features.

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """
    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    # Linear blend in feature space: alpha=1 is full style, alpha=0 reconstructs the content.
    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def _output_path(content_video_pth, style_image_pth, color_control):
    """Return ./results_video/<video>_style_<style>[_colorcontrol]<video suffix>, creating the directory."""
    out_dir = Path('./results_video')
    out_dir.mkdir(parents=True, exist_ok=True)
    name = content_video_pth.stem + '_style_' + style_image_pth.stem
    if color_control:
        name += '_colorcontrol'
    return out_dir / (name + content_video_pth.suffix)


def _to_output_frame(out_array, width, height):
    """Convert a (1, C, H, W) decoder output array into a uint8 (height, width, C) frame."""
    frame = np.squeeze(out_array, axis=0)       # drop the batch dimension
    frame = np.transpose(frame, (1, 2, 0))      # CHW -> HWC for OpenCV / imageio
    # Stretch each frame to the full 0..255 range.
    # NOTE(review): per-frame min-max scaling can cause brightness flicker
    # between frames; consider clamping instead once the decoder's output
    # range is confirmed.
    frame = cv2.normalize(src=frame, dst=None, alpha=0, beta=255,
                          norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    # The network ran at the transform's resolution; restore the source size.
    return cv2.resize(frame, (width, height), interpolation=cv2.INTER_CUBIC)


def main():
    """Stylize every frame of --content_video with --style_image and write the result to ./results_video/."""
    content_video_pth = Path(args.content_video)
    style_image_pth = Path(args.style_image)

    content_video = cv2.VideoCapture(str(content_video_pth))
    if not content_video.isOpened():
        raise OSError(f'Could not open content video: {content_video_pth}')

    # Force 3 channels so RGBA / greyscale style images work with the encoder.
    style_image = Image.open(style_image_pth).convert('RGB')

    # Read video info. Keep fps as a float so e.g. 29.97 fps is not truncated.
    fps = content_video.get(cv2.CAP_PROP_FPS)
    frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
    video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))

    out_pth = _output_path(content_video_pth, style_image_pth, args.color_control)

    # Load AdaIN model; map_location lets GPU-saved weights load on a CPU-only machine.
    vgg = torch.load('vgg_normalized.pth', map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight, map_location=device))
    model.eval()

    t = transform(512)
    style_tensor = t(style_image).unsqueeze(0).to(device)

    writer = imageio.get_writer(out_pth, mode='I', fps=fps)
    # frame_count is passed as total= (positionally it would be taken as an iterable).
    progress = tqdm(total=frame_count if frame_count > 0 else None)
    try:
        while content_video.isOpened():
            ret, frame_bgr = content_video.read()
            # Failed to read a frame (end of stream or read error)
            if not ret:
                break

            # OpenCV decodes frames as BGR; PIL and imageio work in RGB.
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            content_tensor = t(Image.fromarray(frame_rgb)).unsqueeze(0).to(device)

            with torch.no_grad():
                # Match against the ORIGINAL style tensor every frame; reusing the
                # previous frame's matched result would compound the matching.
                frame_style = (linear_histogram_matching(content_tensor, style_tensor)
                               if args.color_control else style_tensor)
                out_array = style_transfer(content_tensor, frame_style, model.encoder,
                                           model.decoder, args.alpha).cpu().numpy()

            writer.append_data(_to_output_frame(out_array, video_width, video_height))
            progress.update(1)
    finally:
        # Always finalize the output file and release the capture, even on error.
        progress.close()
        writer.close()
        content_video.release()

    print("\nContent: " + content_video_pth.stem + ". Style: " + style_image_pth.stem + '\n')


if __name__ == '__main__':
    main()