import random
import numpy as np
from functools import partial
from torch.utils.data import Dataset, WeightedRandomSampler
import torch.nn.functional as F
import torch
import math
import decord
from einops import rearrange
from more_itertools import sliding_window
from omegaconf import ListConfig
import torchaudio
import soundfile as sf
from torchvision.transforms import RandomHorizontalFlip
from audiomentations import Compose, AddGaussianNoise, PitchShift
from safetensors.torch import load_file
from tqdm import tqdm
import cv2

from sgm.data.data_utils import (
    create_masks_from_landmarks_full_size,
    create_face_mask_from_landmarks,
    create_masks_from_landmarks_box,
    create_masks_from_landmarks_mouth,
)
from sgm.data.mask import face_mask_cheeks_batch

torchaudio.set_audio_backend("sox_io")
decord.bridge.set_bridge("torch")


def exists(x):
    return x is not None


def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    """Trim or right-pad `audio` along its last dimension to a maximum length
    given either in seconds (`max_len_sec`) or in raw samples (`max_len_raw`)."""
    len_file = audio.shape[-1]

    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav


# Similar to the regular video dataset, but trades flexibility for speed
class VideoDataset(Dataset):
    def __init__(
        self,
        filelist,
        resize_size=None,
        audio_folder="Audio",
        video_folder="CroppedVideos",
        emotions_folder="emotions",
        landmarks_folder=None,
        audio_emb_folder=None,
        video_extension=".avi",
        audio_extension=".wav",
        audio_rate=16000,
        latent_folder=None,
        audio_in_video=False,
        fps=25,
        num_frames=5,
        need_cond=True,
        step=1,
        mode="prediction",
        scale_audio=False,
        augment=False,
        augment_audio=False,
        use_latent=False,
        latent_type="stable",
        latent_scale=1,  # For backwards compatibility
        from_audio_embedding=False,
        load_all_possible_indexes=False,
        audio_emb_type="wavlm",
        cond_noise=[-3.0, 0.5],
        motion_id=255.0,
        data_mean=None,
        data_std=None,
        use_latent_condition=False,
        skip_frames=0,
        get_separate_id=False,
        virtual_increase=1,
        filter_by_length=False,
        select_randomly=False,
        balance_datasets=True,
        use_emotions=False,
        get_original_frames=False,
        add_extra_audio_emb=False,
        expand_box=0.0,
        nose_index=28,
        what_mask="full",
        get_masks=False,
    ):
        self.audio_folder = audio_folder
        self.from_audio_embedding = from_audio_embedding
        self.audio_emb_type = audio_emb_type
        self.cond_noise = cond_noise
        self.latent_condition = use_latent_condition
        precomputed_latent = latent_type  # Truthy whenever a latent type is given
        self.audio_emb_folder = (
            audio_emb_folder if audio_emb_folder is not None else audio_folder
        )
        self.skip_frames = skip_frames
        self.get_separate_id = get_separate_id
        self.fps = fps
        self.virtual_increase = virtual_increase
        self.select_randomly = select_randomly
        self.use_emotions = use_emotions
        self.emotions_folder = emotions_folder
        self.get_original_frames = get_original_frames
        self.add_extra_audio_emb = add_extra_audio_emb
        self.expand_box = expand_box
        self.nose_index = nose_index
        self.landmarks_folder = landmarks_folder
        self.what_mask = what_mask
        self.get_masks = get_masks

        assert not (exists(data_mean) ^ exists(data_std)), (
            "Both data_mean and data_std should be provided"
        )

        if data_mean is not None:
            data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
            data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")

        self.data_mean = data_mean
        self.data_std = data_std
        self.motion_id = motion_id
        self.latent_folder = (
            latent_folder if latent_folder is not None else video_folder
        )
        self.audio_in_video = audio_in_video

        self.filelist = []
        self.audio_filelist = []
        self.landmark_filelist = [] if get_masks else None
        with open(filelist, "r") as files:
            for f in files.readlines():
                f = f.rstrip()
                audio_path = f.replace(video_folder, audio_folder).replace(
                    video_extension, audio_extension
                )
                self.filelist += [f]
                self.audio_filelist += [audio_path]

                if self.get_masks:
                    landmark_path = f.replace(video_folder, landmarks_folder).replace(
                        video_extension, ".npy"
                    )
                    self.landmark_filelist += [landmark_path]

        self.resize_size = resize_size
        if use_latent and not precomputed_latent:
            self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
        self.scale_audio = scale_audio
        self.step = step
        self.use_latent = use_latent
        self.precomputed_latent = precomputed_latent
        self.latent_type = latent_type
        self.latent_scale = latent_scale
        self.video_ext = video_extension
        self.video_folder = video_folder

        self.augment = augment
        self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
        self.maybe_augment_audio = (
            Compose(
                [
                    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
                    # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
                    PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
                    # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
                ]
            )
            if augment_audio
            else lambda x, sample_rate: x
        )
        self.maybe_augment_audio = partial(
            self.maybe_augment_audio, sample_rate=audio_rate
        )

        self.mode = mode
        if mode == "interpolation":
            need_cond = False  # Interpolation needs no extra condition: the first and last frames become the condition
        self.need_cond = need_cond  # If need_cond, extract one more frame than num_frames
        if get_separate_id:
            self.need_cond = True  # Used by the conditional model when the condition is not on the temporal dimension
        num_frames = num_frames if not self.need_cond else num_frames + 1

        vr = decord.VideoReader(self.filelist[0])
        self.video_rate = math.ceil(vr.get_avg_fps())
        print(f"Video rate: {self.video_rate}")
        self.audio_rate = audio_rate
        a2v_ratio = fps / float(self.audio_rate)
        self.samples_per_frame = math.ceil(1 / a2v_ratio)  # Audio samples per video frame

        if get_separate_id:
            assert mode == "prediction", (
                "Separate identity frame is only supported for prediction mode"
            )
            # No need for an extra frame if we are getting a separate identity frame
            self.need_cond = True
            num_frames -= 1
        self.num_frames = num_frames

        self.load_all_possible_indexes = load_all_possible_indexes
        if load_all_possible_indexes:
            self._indexes = self._get_indexes(
                self.filelist, self.audio_filelist, self.landmark_filelist
            )
        else:
            if filter_by_length:
                self._indexes = self.filter_by_length(
                    self.filelist, self.audio_filelist, self.landmark_filelist
                )
            else:
                if self.get_masks:
                    self._indexes = list(
                        zip(self.filelist, self.audio_filelist, self.landmark_filelist)
                    )
                else:
                    self._indexes = list(
                        zip(
                            self.filelist,
                            self.audio_filelist,
                            [None] * len(self.filelist),
                        )
                    )

        self.balance_datasets = balance_datasets
        if self.balance_datasets:
            self.weights = self._calculate_weights()
            self.sampler = WeightedRandomSampler(
                self.weights, num_samples=len(self._indexes), replacement=True
            )

    def __len__(self):
        return len(self._indexes) * self.virtual_increase

    def _load_landmarks(self, filename, original_size, target_size, indexes):
        landmarks = np.load(filename, allow_pickle=True)[indexes, :]
        if self.what_mask == "full":
            mask = create_masks_from_landmarks_full_size(
                landmarks,
                original_size[0],
                original_size[1],
                offset=self.expand_box,
                nose_index=self.nose_index,
            )
        elif self.what_mask == "box":
            mask = create_masks_from_landmarks_box(
                landmarks,
                (original_size[0], original_size[1]),
                box_expand=self.expand_box,
                nose_index=self.nose_index,
            )
        elif self.what_mask == "heart":
            mask = face_mask_cheeks_batch(
                original_size, landmarks, box_expand=0.0, show_nose=True
            )
        elif self.what_mask == "mouth":
            mask = create_masks_from_landmarks_mouth(
                landmarks,
                (original_size[0], original_size[1]),
                box_expand=0.01,
                nose_index=self.nose_index,
            )
        else:
            mask = create_face_mask_from_landmarks(
                landmarks, original_size[0], original_size[1], mask_expand=0.05
            )
        # Interpolate the mask to the target size
        mask = F.interpolate(
            mask.unsqueeze(1).float(), size=target_size, mode="nearest"
        )
        return mask, landmarks

    def get_emotions(self, video_file, video_indexes):
        emotions_path = video_file.replace(
            self.video_folder, self.emotions_folder
        ).replace(self.video_ext, ".pt")
        emotions = torch.load(emotions_path)
        return (
            emotions["valence"][video_indexes],
            emotions["arousal"][video_indexes],
            emotions["labels"][video_indexes],
        )

    def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
        if select_randomly:
            # Randomly select self.num_frames indices from the available range
            available_indices = list(range(start_idx, total_video_frames))
            if len(available_indices) < self.num_frames:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )
            indexes = random.sample(available_indices, self.num_frames)
            return sorted(indexes)  # Sort to maintain temporal order
        else:
            # Calculate the maximum possible start index
            max_start_idx = total_video_frames - (
                (self.num_frames - 1) * (self.skip_frames + 1) + 1
            )

            # Generate a random start index
            if max_start_idx > 0:
                start_idx = np.random.randint(start_idx, max_start_idx)
            else:
                raise ValueError(
                    "Not enough frames in the video to sample with given parameters."
                )

            # Generate the indices
            indexes = [
                start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
            ]
            return indexes

    def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
        audio, sr = sf.read(
            filename,
            start=math.ceil(start * self.audio_rate),
            frames=math.ceil(self.audio_rate * max_len_sec),
            always_2d=True,
        )  # e.g (16000, 1)
        audio = audio.T  # (1, 16000)
        assert sr == self.audio_rate, (
            f"Audio rate is {sr} but should be {self.audio_rate}"
        )
        audio = audio.mean(0, keepdims=True)
        audio = self.maybe_augment_audio(audio)
        audio = torch.from_numpy(audio).float()
        # audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
        audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
        return audio[0]

    def ensure_shape(self, tensors):
        target_length = self.samples_per_frame
        processed_tensors = []
        for tensor in tensors:
            current_length = tensor.shape[1]
            diff = current_length - target_length
            assert abs(diff) <= 5, (
                f"Expected shape {target_length}, but got {current_length}"
            )
            if diff < 0:
                # Calculate how much padding is needed
                padding_needed = target_length - current_length
                # Pad the tensor
                padded_tensor = F.pad(tensor, (0, padding_needed))
                processed_tensors.append(padded_tensor)
            elif diff > 0:
                # Trim the tensor
                trimmed_tensor = tensor[:, :target_length]
                processed_tensors.append(trimmed_tensor)
            else:
                # If it's already the correct size
                processed_tensors.append(tensor)
        return torch.cat(processed_tensors)

    def normalize_latents(self, latents):
        if self.data_mean is not None:
            # Normalize latents to 0 mean and 0.5 std
            latents = ((latents - self.data_mean) / self.data_std) * 0.5
        return latents

    def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
        ratio = fps_to / fps_from
        indexes_60fps = [int(index * ratio) for index in indexes_25fps]
        return indexes_60fps

    def _get_frames_and_audio(self, idx):
        if self.load_all_possible_indexes:
            indexes, video_file, audio_file, land_file = self._indexes[idx]
            if self.audio_in_video:
                vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
            else:
                vr = decord.VideoReader(video_file)
            len_video = len(vr)
            if "AA_processed" in video_file or "1000actors_nsv" in video_file:
                len_video *= 25 / 60
                len_video = int(len_video)
        else:
            video_file, audio_file, land_file = self._indexes[idx]
            if self.audio_in_video:
                vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
            else:
                vr = decord.VideoReader(video_file)
            len_video = len(vr)
            if "AA_processed" in video_file or "1000actors_nsv" in video_file:
                len_video *= 25 / 60
                len_video = int(len_video)

            indexes = self.get_frame_indices(
                len_video,
                select_randomly=self.select_randomly,
                start_idx=120 if "1000actors_nsv" in video_file else 0,
            )
        if self.get_separate_id:
            id_idx = np.random.randint(0, len_video)
            indexes.insert(0, id_idx)

        if "AA_processed" in video_file or "1000actors_nsv" in video_file:
            video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
            audio_file = audio_file.replace("_output_output", "")
            if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
                audio_path_extra = ".safetensors"
            else:
                audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"

            video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
            audio_path_extra_extra = (
                ".pt" if "AA_processed" in video_file else "_beats_emb.pt"
            )
        else:
            video_indexes = indexes
            audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
            video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
            audio_path_extra_extra = "_beats_emb.pt"

        emotions = None
        if self.use_emotions:
            emotions = self.get_emotions(video_file, video_indexes)
            if self.get_separate_id:
                emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])

        raw_audio = None
        if self.audio_in_video:
            raw_audio, frames_video = vr.get_batch(video_indexes)
            raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")

        if self.use_latent and self.precomputed_latent:
            latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
                self.video_folder, self.latent_folder
            )
            frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
            if frames.shape[-1] != 64:
                print(f"Frames shape: {frames.shape}, video file: {video_file}")

            frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
            frames = self.normalize_latents(frames)
        else:
            if self.audio_in_video:
                frames = frames_video.permute(3, 0, 1, 2).float()
            else:
                frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()

        if raw_audio is None:
            # Audio is not in the video container, load it from the audio file
            raw_audio = self._load_audio(
                audio_file,
                max_len_sec=frames.shape[1] / self.fps,
                start=indexes[0] / self.fps,
                # indexes=indexes,
            )

        if not self.from_audio_embedding:
            audio = raw_audio
            audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
        else:
            audio = load_file(
                audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[
                    0
                ]
                + audio_path_extra
            )["audio"]
            audio_frames = audio[indexes, :]
            if self.add_extra_audio_emb:
                audio_extra = torch.load(
                    audio_file.replace(self.audio_folder, self.audio_emb_folder).split(
                        "."
                    )[0]
                    + audio_path_extra_extra
                )
                audio_extra = audio_extra[indexes, :]
                audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)

        audio_frames = (
            audio_frames[1:] if self.need_cond else audio_frames
        )  # Remove audio of the first (conditioning) frame

        if self.get_original_frames:
            original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
            original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
            original_frames = (
                original_frames[:, 1:] if self.need_cond else original_frames
            )
        else:
            original_frames = None

        if not self.use_latent or (self.use_latent and not self.precomputed_latent):
            frames = self.scale_and_crop((frames / 255.0) * 2 - 1)

        target = frames[:, 1:] if self.need_cond else frames
        if self.mode == "prediction":
            if self.use_latent:
                if self.audio_in_video:
                    clean_cond = (
                        frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
                    )
                else:
                    clean_cond = (
                        vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
                    )
                original_size = clean_cond.shape[-2:]
                clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(
                    0
                )
                if self.latent_condition:
                    noisy_cond = frames[:, 0]
                else:
                    noisy_cond = clean_cond
            else:
                clean_cond = frames[:, 0]
                noisy_cond = clean_cond
        elif self.mode == "interpolation":
            if self.use_latent:
                if self.audio_in_video:
                    clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
                else:
                    clean_cond = (
                        vr.get_batch([video_indexes[0], video_indexes[-1]])
                        .permute(3, 0, 1, 2)
                        .float()
                    )
                original_size = clean_cond.shape[-2:]
                clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
                if self.latent_condition:
                    noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
                else:
                    noisy_cond = clean_cond
            else:
                clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
                noisy_cond = clean_cond

        # Add noise to the conditioning frame(s)
        # Accept plain lists as well as OmegaConf ListConfig for (mean, std) noise levels
        if self.cond_noise and isinstance(self.cond_noise, (list, ListConfig)):
            cond_noise = (
                self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
            ).exp()
            noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
        else:
            noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
            cond_noise = self.cond_noise

        if self.get_masks:
            target_size = (
                (self.resize_size, self.resize_size)
                if not self.use_latent
                else (self.resize_size // 8, self.resize_size // 8)
            )
            # NOTE: original_size is set in the use_latent branches above, so masks
            # assume the latent pathway is enabled.
            masks, landmarks = self._load_landmarks(
                land_file, original_size, target_size, video_indexes
            )
            landmarks = None
            masks = (
                masks.permute(1, 0, 2, 3)[:, 1:]
                if self.need_cond
                else masks.permute(1, 0, 2, 3)
            )
        else:
            masks = None
            landmarks = None

        return (
            original_frames,
            clean_cond,
            noisy_cond,
            target,
            audio_frames,
            raw_audio,
            cond_noise,
            emotions,
            masks,
            landmarks,
        )

    def filter_by_length(self, video_filelist, audio_filelist, landmark_filelist=None):
        # Returns (video, audio, landmark) tuples so the result can be used directly
        # as self._indexes
        def with_opencv(filename):
            video = cv2.VideoCapture(filename)
            frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
            return int(frame_count)

        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)

        filtered = []
        min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
        for vid_file, audio_file, land_file in tqdm(
            zip(video_filelist, audio_filelist, landmark_filelist),
            total=len(video_filelist),
            desc="Filtering",
        ):
            # vr = decord.VideoReader(vid_file)
            len_video = with_opencv(vid_file)
            # Skip videos that are too short for the requested clip
            if len_video < min_length:
                continue
            filtered.append((vid_file, audio_file, land_file))

        print(f"New number of files: {len(filtered)}")
        return filtered

    def _get_indexes(self, video_filelist, audio_filelist, landmark_filelist=None):
        indexes = []
        self.og_shape = None
        if landmark_filelist is None:
            landmark_filelist = [None] * len(video_filelist)
        for vid_file, audio_file, land_file in zip(
            video_filelist, audio_filelist, landmark_filelist
        ):
            vr = decord.VideoReader(vid_file)
            if self.og_shape is None:
                self.og_shape = vr[0].shape[-2]
            len_video = len(vr)
            # Skip videos that are too short
            if len_video < self.num_frames:
                continue
            possible_indexes = list(
                sliding_window(range(len_video), self.num_frames)
            )[:: self.step]
            possible_indexes = [
                (window, vid_file, audio_file, land_file)
                for window in possible_indexes
            ]
            indexes.extend(possible_indexes)
        print("Indexes", len(indexes), "\n")
        return indexes

    def scale_and_crop(self, video):
        h, w = video.shape[-2], video.shape[-1]
        # Scale the shorter side to the target resolution
        if self.resize_size is not None:
            scale = self.resize_size / min(h, w)
            if h < w:
                target_size = (self.resize_size, math.ceil(w * scale))
            else:
                target_size = (math.ceil(h * scale), self.resize_size)
            video = F.interpolate(
                video,
                size=target_size,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            # Center crop
            h, w = video.shape[-2], video.shape[-1]
            w_start = (w - self.resize_size) // 2
            h_start = (h - self.resize_size) // 2
            video = video[
                :,
                :,
                h_start : h_start + self.resize_size,
                w_start : w_start + self.resize_size,
            ]
        return self.maybe_augment(video)

    @staticmethod
    def _item_video_path(item):
        # _indexes items are either (video, audio, landmark) or
        # (frame_indexes, video, audio, landmark); return the video path in both cases
        return item[0] if isinstance(item[0], str) else item[1]

    def _calculate_weights(self):
        aa_processed_count = sum(
            1
            for item in self._indexes
            if "AA_processed" in self._item_video_path(item)
        )
        nsv_processed_count = sum(
            1
            for item in self._indexes
            if "1000actors_nsv" in self._item_video_path(item)
        )
        other_count = len(self._indexes) - aa_processed_count - nsv_processed_count

        aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
        nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
        other_weight = 1 / other_count if other_count > 0 else 0

        print(
            f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
        )
        print(f"AA processed weight: {aa_processed_weight}")
        print(f"NSV processed weight: {nsv_processed_weight}")
        print(f"Other weight: {other_weight}")

        weights = [
            aa_processed_weight
            if "AA_processed" in self._item_video_path(item)
            else nsv_processed_weight
            if "1000actors_nsv" in self._item_video_path(item)
            else other_weight
            for item in self._indexes
        ]
        return weights

    def __getitem__(self, idx):
        if self.balance_datasets:
            idx = next(iter(self.sampler))

        try:
            (
                original_frames,
                clean_cond,
                noisy_cond,
                target,
                audio,
                raw_audio,
                cond_noise,
                emotions,
                masks,
                landmarks,
            ) = self._get_frames_and_audio(idx % len(self._indexes))
        except Exception as e:
            print(f"Error with index {idx}: {e}")
            return self.__getitem__(np.random.randint(0, len(self)))

        out_data = {}

        if original_frames is not None:
            out_data["original_frames"] = original_frames

        if audio is not None:
            out_data["audio_emb"] = audio
            out_data["raw_audio"] = raw_audio

        if self.use_emotions:
            out_data["valence"] = emotions[0]
            out_data["arousal"] = emotions[1]
            out_data["emo_labels"] = emotions[2]

        if self.use_latent:
            input_key = "latents"
        else:
            input_key = "frames"
        out_data[input_key] = target
        if noisy_cond is not None:
            out_data["cond_frames"] = noisy_cond
        out_data["cond_frames_without_noise"] = clean_cond
        if cond_noise is not None:
            out_data["cond_aug"] = cond_noise
        if masks is not None:
            out_data["masks"] = masks
            out_data["gt"] = target
        if landmarks is not None:
            out_data["landmarks"] = landmarks
        out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
        out_data["fps_id"] = torch.tensor([self.fps - 1])
        out_data["num_video_frames"] = self.num_frames
        out_data["image_only_indicator"] = torch.zeros(self.num_frames)
        return out_data


if __name__ == "__main__":
    import os

    dataset = VideoDataset(
        "/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
        resize_size=256,
        num_frames=25,
    )
    print(len(dataset))

    idx = np.random.randint(0, len(dataset))

    for i in range(10):
        sample = dataset[i]
        print(sample["frames"].shape, sample["cond_frames_without_noise"].shape)

    sample = dataset[idx]
    image_identity = (
        sample["cond_frames_without_noise"].permute(1, 2, 0).numpy() + 1
    ) / 2 * 255
    image_other = (sample["frames"][:, -1].permute(1, 2, 0).numpy() + 1) / 2 * 255
    cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])

    os.makedirs("tmp_vid_dataset", exist_ok=True)
    for i in range(25):
        image = (sample["frames"][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
        cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])
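

# The function below is a usage sketch, not part of the dataset itself: it shows how a
# training script might wrap VideoDataset in a standard PyTorch DataLoader. The filelist
# path, batch size, worker count, and frame count are illustrative assumptions; the
# dictionary keys mirror __getitem__ above.
def _example_dataloader_usage(filelist_path="filelist_videos_train.txt"):
    from torch.utils.data import DataLoader

    dataset = VideoDataset(filelist_path, resize_size=512, num_frames=14)
    loader = DataLoader(
        dataset, batch_size=2, num_workers=4, shuffle=True, drop_last=True
    )
    batch = next(iter(loader))
    # With use_latent=False the targets are stored under "frames";
    # with use_latent=True they would be under "latents" instead.
    frames = batch["frames"]  # (B, C, T, H, W), values in [-1, 1]
    audio_emb = batch["audio_emb"]  # (B, T, S) per-frame audio samples or embeddings
    cond = batch["cond_frames"]  # conditioning frame(s) with added noise
    return frames.shape, audio_emb.shape, cond.shape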