import math
import os
import random
import warnings

import librosa
import numpy as np
import torch
from PIL import Image
import cv2
from einops import rearrange
import torchvision.transforms.functional as TF
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F


def get_random_mask(shape, image_start_only=False):
    f, c, h, w = shape
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)

    if not image_start_only:
        if f != 1:
            mask_index = np.random.choice(
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05],
            )
        else:
            mask_index = np.random.choice([0, 1], p=[0.2, 0.8])

        if mask_index == 0:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # block width range
            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # block height range

            start_x = max(center_x - block_size_x // 2, 0)
            end_x = min(center_x + block_size_x // 2, w)
            start_y = max(center_y - block_size_y // 2, 0)
            end_y = min(center_y + block_size_y // 2, h)
            mask[:, :, start_y:end_y, start_x:end_x] = 1
        elif mask_index == 1:
            mask[:, :, :, :] = 1
        elif mask_index == 2:
            mask_frame_index = np.random.randint(1, 5)
            mask[mask_frame_index:, :, :, :] = 1
        elif mask_index == 3:
            mask_frame_index = np.random.randint(1, 5)
            mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
        elif mask_index == 4:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # block width range
            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # block height range

            start_x = max(center_x - block_size_x // 2, 0)
            end_x = min(center_x + block_size_x // 2, w)
            start_y = max(center_y - block_size_y // 2, 0)
            end_y = min(center_y + block_size_y // 2, h)

            mask_frame_before = np.random.randint(0, f // 2)
            mask_frame_after = np.random.randint(f // 2, f)
            mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
        elif mask_index == 5:
            mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
        elif mask_index == 6:
            num_frames_to_mask = random.randint(1, max(f // 2, 1))
            frames_to_mask = random.sample(range(f), num_frames_to_mask)

            for i in frames_to_mask:
                block_height = random.randint(1, h // 4)
                block_width = random.randint(1, w // 4)
                top_left_y = random.randint(0, h - block_height)
                top_left_x = random.randint(0, w - block_width)
                mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
        elif mask_index == 7:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item()  # semi-major axis
            b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()  # semi-minor axis

            for i in range(h):
                for j in range(w):
                    if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
                        mask[:, :, i, j] = 1
        elif mask_index == 8:
            center_x = torch.randint(0, w, (1,)).item()
            center_y = torch.randint(0, h, (1,)).item()
            radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
            for i in range(h):
                for j in range(w):
                    if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
                        mask[:, :, i, j] = 1
        elif mask_index == 9:
            for idx in range(f):
                if np.random.rand() > 0.5:
                    mask[idx, :, :, :] = 1
        else:
            raise ValueError(f"The mask_index {mask_index} is not defined")
    else:
        if f != 1:
            mask[1:, :, :, :] = 1
        else:
            mask[:, :, :, :] = 1
    return mask
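

# Illustrative helper, not used by the training pipeline: a minimal sketch of how the
# random mask is typically combined with a video tensor. The shape below is a placeholder.
def _demo_random_mask():
    video = torch.randn(49, 3, 64, 64)                           # (frames, channels, H, W)
    mask = get_random_mask(video.shape, image_start_only=True)   # 1 marks the region to inpaint
    masked_video = video * (1 - mask)                            # only the first frame stays visible
    return masked_video
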

class LargeScaleTalkingFantasyVideos(Dataset):
    def __init__(self, txt_path, width, height, n_sample_frames, sample_frame_rate,
                 only_last_features=False, vocal_encoder=None, audio_encoder=None,
                 vocal_sample_rate=16000, audio_sample_rate=24000, enable_inpaint=True,
                 audio_margin=2, vae_stride=None, patch_size=None,
                 wav2vec_processor=None, wav2vec=None):
        self.txt_path = txt_path
        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_frame_rate = sample_frame_rate
        self.only_last_features = only_last_features
        self.vocal_encoder = vocal_encoder
        self.audio_encoder = audio_encoder
        self.vocal_sample_rate = vocal_sample_rate
        self.audio_sample_rate = audio_sample_rate
        self.enable_inpaint = enable_inpaint
        self.wav2vec_processor = wav2vec_processor
        self.audio_margin = audio_margin
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.max_area = height * width
        self.aspect_ratio = height / width
        self.video_files = self._read_txt_file_images()

        self.lat_h = round(
            np.sqrt(self.max_area * self.aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        self.lat_w = round(
            np.sqrt(self.max_area / self.aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])

    def _read_txt_file_images(self):
        with open(self.txt_path, 'r') as file:
            lines = file.readlines()
        video_files = []
        for line in lines:
            video_file = line.strip()
            video_files.append(video_file)
        return video_files

    def __len__(self):
        return len(self.video_files)

    def frame_count(self, frames_path):
        files = os.listdir(frames_path)
        png_files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')]
        png_files_count = len(png_files)
        return png_files_count

    def find_frames_list(self, frames_path):
        files = os.listdir(frames_path)
        image_files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')]
        if image_files[0].startswith('frame_'):
            image_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
        else:
            image_files.sort(key=lambda x: int(x.split('.')[0]))
        return image_files
    def __getitem__(self, idx):
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        warnings.filterwarnings('ignore', category=FutureWarning)

        video_path = os.path.join(self.video_files[idx], "sub_clip.mp4")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        try:
            is_0_fps = 2 / fps  # raises ZeroDivisionError when the container reports 0 fps
        except Exception as e:
            print(f"The fps of {video_path} is 0 !!!")
            vocal_audio_path = os.path.join(self.video_files[idx], "audio.wav")
            vocal_duration = librosa.get_duration(filename=vocal_audio_path)
            frames_path = os.path.join(self.video_files[idx], "images")
            total_frame_number = self.frame_count(frames_path)
            fps = total_frame_number / vocal_duration
            print(f"The calculated fps of {video_path} is {fps} !!!")
            # idx = random.randint(0, len(self.video_files) - 1)
            # video_path = os.path.join(self.video_files[idx], "sub_clip.mp4")
            # cap = cv2.VideoCapture(video_path)
            # fps = cap.get(cv2.CAP_PROP_FPS)

        frames_path = os.path.join(self.video_files[idx], "images")
        face_masks_path = os.path.join(self.video_files[idx], "face_masks")
        lip_masks_path = os.path.join(self.video_files[idx], "lip_masks")
        raw_audio_path = os.path.join(self.video_files[idx], "audio.wav")
        # vocal_audio_path = os.path.join(self.video_files[idx], "vocal.wav")
        vocal_audio_path = os.path.join(self.video_files[idx], "audio.wav")

        video_length = self.frame_count(frames_path)
        frames_list = self.find_frames_list(frames_path)

        clip_length = min(video_length, (self.n_sample_frames - 1) * self.sample_frame_rate + 1)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()

        all_indices = list(range(0, video_length))
        reference_frame_idx = random.choice(all_indices)

        tgt_pil_image_list = []
        tgt_face_masks_list = []
        tgt_lip_masks_list = []

        # reference_frame_path = os.path.join(frames_path, frames_list[reference_frame_idx])
        reference_frame_path = os.path.join(frames_path, frames_list[start_idx])
        reference_pil_image = Image.open(reference_frame_path).convert('RGB')
        reference_pil_image = reference_pil_image.resize((self.width, self.height))
        reference_pil_image = torch.from_numpy(np.array(reference_pil_image)).float()
        reference_pil_image = reference_pil_image / 127.5 - 1

        for index in batch_index:
            tgt_img_path = os.path.join(frames_path, frames_list[index])
            # file_name = os.path.splitext(os.path.basename(tgt_img_path))[0]
            file_name = os.path.basename(tgt_img_path)
            face_mask_path = os.path.join(face_masks_path, file_name)
            lip_mask_path = os.path.join(lip_masks_path, file_name)
            try:
                tgt_img_pil = Image.open(tgt_img_path).convert('RGB')
            except Exception as e:
                print(f"Failed to load the image: {tgt_img_path}")

            try:
                tgt_lip_mask = Image.open(lip_mask_path)
                # tgt_lip_mask = Image.open(lip_mask_path).convert('RGB')
                tgt_lip_mask = tgt_lip_mask.resize((self.width, self.height))
                tgt_lip_mask = torch.from_numpy(np.array(tgt_lip_mask)).float()
                # tgt_lip_mask = tgt_lip_mask / 127.5 - 1
                tgt_lip_mask = tgt_lip_mask / 255
            except Exception as e:
                print(f"Failed to load the lip mask: {lip_mask_path}")
                tgt_lip_mask = torch.ones(self.height, self.width)
                # tgt_lip_mask = torch.ones(self.height, self.width, 3)
            tgt_lip_masks_list.append(tgt_lip_mask)

            try:
                tgt_face_mask = Image.open(face_mask_path)
                # tgt_face_mask = Image.open(face_mask_path).convert('RGB')
                tgt_face_mask = tgt_face_mask.resize((self.width, self.height))
                tgt_face_mask = torch.from_numpy(np.array(tgt_face_mask)).float()
                tgt_face_mask = tgt_face_mask / 255
                # tgt_face_mask = tgt_face_mask / 127.5 - 1
            except Exception as e:
                print(f"Failed to load the face mask: {face_mask_path}")
                tgt_face_mask = torch.ones(self.height, self.width)
                # tgt_face_mask = torch.ones(self.height, self.width, 3)
            tgt_face_masks_list.append(tgt_face_mask)

            tgt_img_pil = tgt_img_pil.resize((self.width, self.height))
            tgt_img_tensor = torch.from_numpy(np.array(tgt_img_pil)).float()
            tgt_img_normalized = tgt_img_tensor / 127.5 - 1
            tgt_pil_image_list.append(tgt_img_normalized)

        sr = 16000
        vocal_input, sample_rate = librosa.load(vocal_audio_path, sr=sr)
        vocal_duration = librosa.get_duration(filename=vocal_audio_path)

        start_time = batch_index[0] / fps
        end_time = (clip_length / fps) + start_time
        start_sample = int(start_time * sr)
        end_sample = int(end_time * sr)
        # Slicing never raises, so check the audio length explicitly instead of using try/except.
        if end_sample <= len(vocal_input):
            vocal_segment = vocal_input[start_sample:end_sample]
        else:
            print(f"The current vocal segment is too short: {vocal_audio_path}, "
                  f"[{batch_index[0]}, {batch_index[-1]}], fps={fps}, clip_length={clip_length}, "
                  f"vocal_duration={vocal_duration}, [{start_time}, {end_time}]")
            vocal_segment = vocal_input[start_sample:]

        vocal_input_values = self.wav2vec_processor(
            vocal_segment, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values

        tgt_pil_image_list = torch.stack(tgt_pil_image_list, dim=0)
        tgt_pil_image_list = rearrange(tgt_pil_image_list, "f h w c -> f c h w")
        reference_pil_image = rearrange(reference_pil_image, "h w c -> c h w")
        tgt_face_masks_list = torch.stack(tgt_face_masks_list, dim=0)
        tgt_face_masks_list = torch.unsqueeze(tgt_face_masks_list, dim=-1)
        tgt_face_masks_list = rearrange(tgt_face_masks_list, "f h w c -> c f h w")
        tgt_lip_masks_list = torch.stack(tgt_lip_masks_list, dim=0)
        tgt_lip_masks_list = torch.unsqueeze(tgt_lip_masks_list, dim=-1)
        tgt_lip_masks_list = rearrange(tgt_lip_masks_list, "f h w c -> c f h w")

        clip_pixel_values = reference_pil_image.permute(1, 2, 0).contiguous()
        clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255

        # Coarse motion score from frame-to-frame cosine similarity (lower similarity -> more motion).
        cos_similarities = []
        stride = 8
        for i in range(0, tgt_pil_image_list.size()[0] - stride, stride):
            frame1 = tgt_pil_image_list[i]
            frame2 = tgt_pil_image_list[i + stride]
            frame1_flat = frame1.contiguous().view(-1)
            frame2_flat = frame2.contiguous().view(-1)
            cos_sim = F.cosine_similarity(frame1_flat, frame2_flat, dim=0)
            cos_sim = (cos_sim + 1) / 2
            cos_similarities.append(cos_sim.item())

        overall_cos_sim = F.cosine_similarity(
            tgt_pil_image_list[0].contiguous().view(-1),
            tgt_pil_image_list[-1].contiguous().view(-1),
            dim=0)
        overall_cos_sim = (overall_cos_sim + 1) / 2
        cos_similarities.append(overall_cos_sim.item())
        motion_id = (1.0 - sum(cos_similarities) / len(cos_similarities)) * 100

        if "singing" in self.video_files[idx]:
            text_prompt = "The protagonist is singing"
        elif "speech" in self.video_files[idx]:
            text_prompt = "The protagonist is talking"
        elif "dancing" in self.video_files[idx]:
            text_prompt = "The protagonist is simultaneously dancing and singing"
        else:
            # The original code crashed here via `print(1 / 0)`; fail with an explicit error instead.
            raise ValueError(f"Unknown clip category (singing/speech/dancing) for {self.video_files[idx]}")

        sample = dict(
            pixel_values=tgt_pil_image_list,
            reference_image=reference_pil_image,
            clip_pixel_values=clip_pixel_values,
            tgt_face_masks=tgt_face_masks_list,
            vocal_input_values=vocal_input_values,
            text_prompt=text_prompt,
            motion_id=motion_id,
            tgt_lip_masks=tgt_lip_masks_list,
            audio_path=raw_audio_path,
        )

        if self.enable_inpaint:
            pixel_value_masks = get_random_mask(tgt_pil_image_list.size(), image_start_only=True)
            masked_pixel_values = tgt_pil_image_list * (1 - pixel_value_masks)
            sample["masked_pixel_values"] = masked_pixel_values
            sample["pixel_value_masks"] = pixel_value_masks

        return sample
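

# Minimal usage sketch, not part of the training code. The txt path, resolution, frame
# count, VAE stride, patch size and wav2vec checkpoint below are illustrative assumptions;
# each line of the txt file is expected to point at a clip directory containing
# sub_clip.mp4, audio.wav, images/, face_masks/ and lip_masks/.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2Processor

    dataset = LargeScaleTalkingFantasyVideos(
        txt_path="train_clips.txt",     # placeholder path
        width=480,
        height=480,
        n_sample_frames=81,
        sample_frame_rate=1,
        vae_stride=(4, 8, 8),           # assumed temporal/spatial VAE strides
        patch_size=(1, 2, 2),           # assumed transformer patch size
        wav2vec_processor=Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h"),
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch["pixel_values"].shape, batch["vocal_input_values"].shape, batch["motion_id"])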