import os
import subprocess
import tempfile
from pathlib import Path
from typing import Union
import shutil

import cv2
import imageio
import numpy as np
import torch
import torchvision

# decord is only needed by add_text_to_video and is imported lazily there
# from decord import VideoReader, cpu
from einops import rearrange, repeat
from .iimage import IImage
from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import save_image

channel_first = 0
channel_last = -1


def video_naming(prompt, extension, batch_idx, idx):
    prompt_identifier = prompt.replace(" ", "_")
    prompt_identifier = prompt_identifier.replace("/", "_")
    if len(prompt_identifier) > 40:
        prompt_identifier = prompt_identifier[:40]
    filename = f"{batch_idx:04d}_{idx:04d}_{prompt_identifier}.{extension}"
    return filename


def video_naming_chunk(prompt, extension, batch_idx, idx, chunk_idx):
    prompt_identifier = prompt.replace(" ", "_")
    prompt_identifier = prompt_identifier.replace("/", "_")
    if len(prompt_identifier) > 40:
        prompt_identifier = prompt_identifier[:40]
    filename = f"{batch_idx}_{idx}_{chunk_idx}_{prompt_identifier}.{extension}"
    return filename


class ResultProcessor:
    def __init__(self, fps: int, n_frames: int, logger=None) -> None:
        self.fps = fps
        self.logger = logger
        self.n_frames = n_frames

    def set_logger(self, logger):
        self.logger = logger

    def _create_video(
        self,
        video,
        prompt,
        filename: Union[str, Path],
        append_video: torch.FloatTensor = None,
        input_flow=None,
    ):
        if video.ndim == 5:
            # can be batches if we provide list of filenames
            assert video.shape[0] == 1
            video = video[0]

        if video.shape[0] == 3 and video.shape[1] == self.n_frames:
            video = rearrange(video, "C F W H -> F C W H")
        assert video.shape[1] == 3, f"Wrong video format. Got {video.shape}"
        if isinstance(filename, Path):
            filename = filename.as_posix()
        # assert video.max() <= 1 and video.min() >= 0
        assert (
            video.max() <= 1.1 and video.min() >= -0.1
        ), f"video has unexpected range: [{video.min()}, {video.max()}]"
        vid_obj = IImage(video, vmin=0, vmax=1)

        if prompt is not None:
            vid_obj = vid_obj.append_text(prompt, padding=(0, 50, 0, 0))

        if append_video is not None:
            if append_video.ndim == 5:
                assert append_video.shape[0] == 1
                append_video = append_video[0]
            if append_video.ndim == 3 and video.ndim == 4:
                # A single conditioning image (C W H): repeat it for every frame.
                # This must happen before the length check below, which assumes F C W H.
                append_video = repeat(
                    append_video, "C W H -> F C W H", F=video.shape[0]
                )
            if append_video.shape[0] < video.shape[0]:
                # Conditioning video is shorter: pad it by repeating its last frame
                append_video = torch.concat(
                    [
                        append_video,
                        repeat(
                            append_video[-1, None],
                            "F C W H -> (rep F) C W H",
                            rep=video.shape[0] - append_video.shape[0],
                        ),
                    ],
                    dim=0,
                )

            append_video = IImage(append_video, vmin=-1, vmax=1)
            if prompt is not None:
                append_video = append_video.append_text(
                    "input_frame", padding=(0, 50, 0, 0)
                )
            vid_obj = vid_obj | append_video
        vid_obj = vid_obj.setFps(self.fps)
        vid_obj.save(filename)

    def _create_prompt_file(self, prompt, filename, video_path: str = None):
        filename = Path(filename)
        filename = filename.parent / (filename.stem + ".txt")

        with open(filename.as_posix(), "w") as file_writer:
            file_writer.write(prompt)
            file_writer.write("\n")
            if video_path is not None:
                file_writer.write(video_path)
            else:
                file_writer.write(" no_source")

    def log_video(
        self,
        video: torch.FloatTensor,
        prompt: str,
        video_id: str,
        log_folder: str,
        input_flow=None,
        video_path_input: str = None,
        extension: str = "gif",
        prompt_on_vid: bool = True,
        append_video: torch.FloatTensor = None,
    ):
        with tempfile.TemporaryDirectory() as tmpdirname:
            storage_fol = Path(tmpdirname)
            filename = f"{video_id}.{extension}".replace("/", "_")
            vid_filename = storage_fol / filename
            self._create_video(
                video,
                prompt if prompt_on_vid else None,
                vid_filename,
                append_video,
                input_flow=input_flow,
            )

            prompt_file = storage_fol / f"{video_id}.txt"
            self._create_prompt_file(prompt, prompt_file, video_path_input)

            if self.logger.experiment.__class__.__name__ == "_DummyExperiment":
                # No tracking backend: copy the artifacts into the local run folder
                run_fol = (
                    Path(self.logger.save_dir)
                    / self.logger.experiment_id
                    / self.logger.run_id
                    / "artifacts"
                    / log_folder
                )
                run_fol.mkdir(parents=True, exist_ok=True)
                shutil.copy(
                    prompt_file.as_posix(), (run_fol / f"{video_id}.txt").as_posix()
                )
                shutil.copy(vid_filename.as_posix(), (run_fol / filename).as_posix())
            else:
                self.logger.experiment.log_artifact(
                    self.logger.run_id, prompt_file.as_posix(), log_folder
                )
                self.logger.experiment.log_artifact(
                    self.logger.run_id, vid_filename.as_posix(), log_folder
                )

    def save_to_file(
        self,
        video: torch.FloatTensor,
        prompt: str,
        video_filename: Union[str, Path],
        input_flow=None,
        conditional_video_path: str = None,
        prompt_on_vid: bool = True,
        conditional_video: torch.FloatTensor = None,
    ):
        self._create_video(
            video,
            prompt if prompt_on_vid else None,
            video_filename,
            conditional_video,
            input_flow=input_flow,
        )
        self._create_prompt_file(prompt, video_filename, conditional_video_path)


def add_text_to_image(
    image_array, text, position, font_size, text_color, font_path=None
):
    # Convert the NumPy array to PIL Image
    image_pil = Image.fromarray(image_array)

    # Create a drawing object
    draw = ImageDraw.Draw(image_pil)

    if font_path is not None:
        font = ImageFont.truetype(font_path, font_size)
    else:
        try:
            # Load the font
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf",
                font_size,
            )
        except OSError:
            # Font file not available on this system: use PIL's built-in default font
            font = ImageFont.load_default()

    # Draw the text on the image
    draw.text(position, text, font=font, fill=text_color)

    # Convert the PIL Image back to NumPy array
    modified_image_array = np.array(image_pil)

    return modified_image_array


def add_text_to_video(video_path, prompt):
    # decord is an optional dependency, so import it only when this function is used
    from decord import VideoReader, cpu

    outputs_with_overlay = []
    with open(video_path, "rb") as f:
        vr = VideoReader(f, ctx=cpu(0))

        for i in range(len(vr)):
            # decord returns its own NDArray type; convert it to NumPy for PIL
            frame = vr[i].asnumpy()
            frame = add_text_to_image(
                frame,
                prompt,
                position=(10, 10),
                font_size=15,
                text_color=(255, 0, 0),
            )
            outputs_with_overlay.append(frame)
    outputs = outputs_with_overlay
    # Swap only the file extension (str.replace would also rewrite "mp4" elsewhere in the path)
    video_path = Path(video_path).with_suffix(".gif").as_posix()
    imageio.mimsave(video_path, outputs, duration=100, loop=0)


def save_videos_grid(
    videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=30, prompt=None
):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # C H W -> H W C
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    # os.makedirs("") raises, so only create the folder when the path has one
    output_dir = os.path.dirname(path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if prompt is not None:
        outputs_with_overlay = []
        for frame in outputs:
            frame_out = add_text_to_image(
                frame,
                prompt,
                position=(10, 10),
                font_size=10,
                text_color=(255, 0, 0),
            )
            outputs_with_overlay.append(frame_out)
        outputs = outputs_with_overlay
    # Per-frame duration in milliseconds
    imageio.mimsave(path, outputs, duration=round(1 / fps * 1000), loop=0)
    # iio.imwrite(path, outputs)
    # optimize(path)


def set_channel_pos(data, shape_dict, channel_pos):
    # Locate each axis by its size (taken from shape_dict) and move the channel
    # axis to `channel_pos`, which is either channel_first or channel_last.
    assert data.ndim == 5 or data.ndim == 4
    batch_dim = data.shape[0]
    frame_dim = shape_dict["frame_dim"]
    channel_dim = shape_dict["channel_dim"]
    width_dim = shape_dict["width_dim"]
    height_dim = shape_dict["height_dim"]

    assert batch_dim != frame_dim
    assert channel_dim != frame_dim
    assert channel_dim != batch_dim

    video_shape = list(data.shape)
    batch_pos = video_shape.index(batch_dim)

    # Current position of the channel axis (kept separate from the requested `channel_pos`)
    c_pos = video_shape.index(channel_dim)
    w_pos = video_shape.index(width_dim)
    h_pos = video_shape.index(height_dim)
    if w_pos == h_pos:
        # Square frames: width and height share a size, so use the next match for height
        video_shape[w_pos] = -1
        h_pos = video_shape.index(height_dim)
    pattern_order = {}
    pattern_order[batch_pos] = "B"
    pattern_order[c_pos] = "C"
    pattern_order[w_pos] = "W"
    pattern_order[h_pos] = "H"

    if data.ndim == 5:
        frame_pos = video_shape.index(frame_dim)
        pattern_order[frame_pos] = "F"
        if channel_pos == channel_first:
            pattern = " -> B F C W H"
        else:
            pattern = " -> B F W H C"
    else:
        if channel_pos == channel_first:
            pattern = " -> B C W H"
        else:
            pattern = " -> B W H C"
    pattern_input = [pattern_order[idx] for idx in range(data.ndim)]
    pattern_input = " ".join(pattern_input)
    pattern = pattern_input + pattern
    data = rearrange(data, pattern)
    return data


def merge_first_two_dimensions(tensor):
    dims = tensor.ndim
    letters = []
    for letter_idx in range(dims - 2):
        letters.append(chr(letter_idx + 67))
    letters_pattern = " ".join(letters)
    tensor = rearrange(
        tensor, "A B " + letters_pattern + " -> (A B) " + letters_pattern
    )
    # TODO merging the first two dimensions is simpler without building letters:
    # 'tensor.flatten(0, 1)' (equivalently 'tensor.reshape(-1, *tensor.shape[2:])')
    return tensor


def apply_spatial_function_to_video_tensor(video, shape, func):
    # TODO detect batch, frame, channel, width, and height
    assert video.ndim == 5
    batch_dim = shape["batch_dim"]
    frame_dim = shape["frame_dim"]
    channel_dim = shape["channel_dim"]
    width_dim = shape["width_dim"]
    height_dim = shape["height_dim"]

    assert batch_dim != frame_dim
    assert channel_dim != frame_dim
    assert channel_dim != batch_dim

    video_shape = list(video.shape)
    batch_pos = video_shape.index(batch_dim)
    frame_pos = video_shape.index(frame_dim)
    channel_pos = video_shape.index(channel_dim)
    w_pos = video_shape.index(width_dim)
    h_pos = video_shape.index(height_dim)
    if w_pos == h_pos:
        video_shape[w_pos] = -1
        h_pos = video_shape.index(height_dim)

    pattern_order = {}
    pattern_order[batch_pos] = "B"
    pattern_order[channel_pos] = "C"
    pattern_order[frame_pos] = "F"
    pattern_order[w_pos] = "W"
    pattern_order[h_pos] = "H"
    # Order the axis letters by their position in the tensor
    pattern_order = sorted(pattern_order.items(), key=lambda x: x[0])
    pattern_order = [x[1] for x in pattern_order]

    input_pattern = " ".join(pattern_order)
    video = rearrange(video, input_pattern + " -> (B F) C W H")
    video = func(video)
    video = rearrange(video, "(B F) C W H -> " + input_pattern, F=frame_dim)
    return video


def dump_frames(videos, as_mosaik, storage_fol, save_image_kwargs):
    # assume videos is in format B F C H W, range [0,1]
    num_frames = videos.shape[1]
    num_videos = videos.shape[0]

    if videos.shape[2] != 3 and videos.shape[-1] == 3:
        # Channels-last input: move the channel axis next to the frame axis
        videos = rearrange(videos, "B F W H C -> B F C W H")

    frame_counter = 0
    if not isinstance(storage_fol, Path):
        storage_fol = Path(storage_fol)

    for frame_idx in range(num_frames):
        print(f" Creating frame {frame_idx}")
        batch_frame = videos[:, frame_idx, ...]

        if as_mosaik:
            filename = storage_fol / f"frame_{frame_counter:03d}.png"
            save_image(batch_frame, fp=filename.as_posix(), **save_image_kwargs)
            frame_counter += 1
        else:
            for video_idx in range(num_videos):
                frame = batch_frame[video_idx]

                filename = storage_fol / f"frame_{frame_counter:03d}.png"
                save_image(frame, fp=filename.as_posix(), **save_image_kwargs)
                frame_counter += 1


def gif_from_videos(videos):
    assert videos.dim() == 5
    assert videos.min() >= 0
    assert videos.max() <= 1
    gif_file = Path("tmp.gif").absolute()

    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_fol = Path(tmpdirname)
        nrows = min(4, videos.shape[0])
        dump_frames(
            videos=videos,
            storage_fol=storage_fol,
            as_mosaik=True,
            save_image_kwargs={"nrow": nrows},
        )
        # Pass the arguments as a list (no shell) so paths containing spaces are safe
        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "image2",
            "-framerate",
            "4",
            "-i",
            (storage_fol / "frame_%03d.png").as_posix(),
            gif_file.as_posix(),
        ]
        subprocess.check_call(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
        )
    return gif_file


def add_margin(pil_img, top, right, bottom, left, color):
    width, height = pil_img.size
    new_width = width + right + left
    new_height = height + top + bottom
    result = Image.new(pil_img.mode, (new_width, new_height), color)
    result.paste(pil_img, (left, top))
    return result


def resize_to_fit(image, size):
    # Scale `image` to fit inside size = (W, H) while preserving its aspect ratio
    W, H = size
    w, h = image.size
    if H / h > W / w:
        H_ = int(h * W / w)
        W_ = W
    else:
        W_ = int(w * H / h)
        H_ = H
    return image.resize((W_, H_))


def pad_to_fit(image, size):
    W, H = size
    w, h = image.size
    pad_top = (H - h) // 2
    pad_left = (W - w) // 2
    # Give any odd leftover pixel to the bottom/right so the result is exactly W x H
    pad_bottom = H - h - pad_top
    pad_right = W - w - pad_left
    return add_margin(image, pad_top, pad_right, pad_bottom, pad_left, (0, 0, 0))
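

# Illustrative sketch, not part of the original module: set_channel_pos and
# apply_spatial_function_to_video_tensor identify each axis purely by its size
# (looked up in the shape dict) and then build an einops pattern from the axis
# positions. The hypothetical helper below reproduces only that pattern-building
# step in plain Python. Assumption: every named size is distinct except possibly
# width == height (square frames). Unlike the functions above, it marks every
# matched axis as used, so a repeated size always maps to the next unused axis.
from typing import Dict, List, Sequence


def _example_axis_pattern(shape: Sequence[int], sizes: Dict[str, int]) -> str:
    """Describe the axis order of `shape`, e.g. 'B F W H C'.

    `sizes` maps an axis letter ('B', 'F', 'C', 'W', 'H') to its size; 'F' may
    be omitted for 4D image batches.
    """
    remaining: List[int] = list(shape)
    letters: Dict[int, str] = {}
    for letter in ("B", "F", "C", "W", "H"):
        if letter not in sizes:
            continue
        pos = remaining.index(sizes[letter])  # first unused axis with this size
        remaining[pos] = -1  # mark as used so a repeated size maps to the next axis
        letters[pos] = letter
    # Sort by axis position (not by letter) to obtain the input pattern
    return " ".join(letters[pos] for pos in sorted(letters))


# Usage: a channels-last batch of shape (B=2, F=8, W=64, H=64, C=3) yields
# "B F W H C", so rearrange(video, "B F W H C -> (B F) C W H") exposes every
# frame to a per-image function, and the inverse pattern restores the layout.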
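

# Illustrative sketch, not part of the original module: merge_first_two_dimensions
# names the trailing axes with consecutive capitals starting at 'C' (chr(67)), so a
# 5D input gets the einops pattern "A B C D E -> (A B) C D E". The hypothetical
# helper below builds that pattern string and the resulting shape in plain Python;
# the shape is the same one torch.Tensor.flatten(0, 1) would produce.
from typing import Sequence, Tuple


def _example_merge_pattern(shape: Sequence[int]) -> Tuple[str, Tuple[int, ...]]:
    assert len(shape) >= 3, "sketch assumes at least one trailing axis"
    rest = " ".join(chr(67 + i) for i in range(len(shape) - 2))  # "C D E ..."
    pattern = f"A B {rest} -> (A B) {rest}"
    merged_shape = (shape[0] * shape[1], *shape[2:])
    return pattern, merged_shape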
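

# Illustrative sketch, not part of the original module: resize_to_fit scales an
# image so it fits inside (W, H) with its aspect ratio preserved, and pad_to_fit
# then centres it on a black canvas. The hypothetical helper below computes only
# the sizes involved, assuming any odd leftover pixel goes to the bottom/right so
# the canvas is exactly W x H. Padding is returned in add_margin's argument order
# (top, right, bottom, left).
from typing import Tuple


def _example_fit_and_pad(
    image_size: Tuple[int, int], target_size: Tuple[int, int]
) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    W, H = target_size
    w, h = image_size
    # Same test as resize_to_fit: the smaller scale factor decides which side touches the border
    if H / h > W / w:
        new_w, new_h = W, int(h * W / w)
    else:
        new_w, new_h = int(w * H / h), H
    top = (H - new_h) // 2
    left = (W - new_w) // 2
    bottom = H - new_h - top
    right = W - new_w - left
    return (new_w, new_h), (top, right, bottom, left)


# Usage: a 640x480 frame fitted into 512x512 is resized to 512x384 and padded
# with 64 rows above and below.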
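

# Illustrative sketch, not part of the original module: gif_from_videos writes
# numbered PNG frames (frame_000.png, frame_001.png, ...) and lets ffmpeg's image2
# demuxer read them back through the printf-style pattern "frame_%03d.png".
# Building the command as an argument list (no shell) keeps paths with spaces
# intact. The hypothetical helper below only builds the command; running it
# assumes an ffmpeg binary is available on PATH.
from pathlib import Path
from typing import List


def _example_ffmpeg_gif_cmd(frame_dir: Path, gif_file: Path, framerate: int = 4) -> List[str]:
    return [
        "ffmpeg",
        "-y",  # overwrite the output file if it already exists
        "-f", "image2",
        "-framerate", str(framerate),
        "-i", (frame_dir / "frame_%03d.png").as_posix(),
        gif_file.as_posix(),
    ]


# Usage (assuming ffmpeg is installed):
# subprocess.check_call(_example_ffmpeg_gif_cmd(Path("/tmp/frames"), Path("out.gif")))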
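

# Illustrative sketch, not part of the original module: dump_frames loops over
# frame indices in the outer loop. With as_mosaik=True it writes one grid image
# per frame index; otherwise it writes one image per (frame, video) pair, so the
# videos end up interleaved frame by frame in the numbered sequence. The
# hypothetical helper below lists (filename, frame_idx, video_idx) for each file
# that would be written, with video_idx None for a grid image.
from typing import List, Optional, Tuple


def _example_dump_frame_plan(
    num_videos: int, num_frames: int, as_mosaik: bool
) -> List[Tuple[str, int, Optional[int]]]:
    plan: List[Tuple[str, int, Optional[int]]] = []
    counter = 0
    for frame_idx in range(num_frames):
        if as_mosaik:
            plan.append((f"frame_{counter:03d}.png", frame_idx, None))  # grid of all videos
            counter += 1
        else:
            for video_idx in range(num_videos):
                plan.append((f"frame_{counter:03d}.png", frame_idx, video_idx))
                counter += 1
    return plan


# This is why gif_from_videos uses as_mosaik=True: each numbered file is then one
# complete GIF frame showing every video side by side.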
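

# Illustrative sketch, not part of the original module: when the conditioning
# video passed to _create_video has fewer frames than the generated video, its
# last frame is repeated until both lengths match, so the two clips can be shown
# side by side. A longer conditioning video is left untouched. The hypothetical
# helper below applies the same rule to a plain Python list of frames.
from typing import List, Sequence, TypeVar

_FrameT = TypeVar("_FrameT")


def _example_pad_with_last_frame(frames: Sequence[_FrameT], target_len: int) -> List[_FrameT]:
    padded = list(frames)
    if not padded or len(padded) >= target_len:
        return padded  # nothing to repeat, or already long enough
    return padded + [padded[-1]] * (target_len - len(padded))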
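

# Illustrative sketch, not part of the original module: save_videos_grid passes
# imageio a per-frame duration of round(1 / fps * 1000), i.e. milliseconds per
# frame, assuming a recent imageio version where GIF durations are given in
# milliseconds. The GIF format itself stores frame delays in hundredths of a
# second, so the playback rate can differ slightly from `fps` after rounding.
def _example_gif_frame_duration_ms(fps: float) -> int:
    return round(1 / fps * 1000)


# Usage: fps=30 gives 33 ms per frame, fps=8 gives 125 ms per frame.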