import math
import os
import random
import re
import subprocess

import librosa
import numpy as np
import pyloudnorm as pyln
import soundfile as sf
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from transformers import Wav2Vec2FeatureExtractor

import wan
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.utils import cache_image, cache_video, str2bool
# from wan.utils.multitalk_utils import save_video_ffmpeg
# from .kokoro import KPipeline
from .wav2vec2 import Wav2Vec2Model


def custom_init(device, wav2vec):
    """Load the wav2vec2 feature extractor and audio encoder from a local checkpoint."""
    audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
    audio_encoder.feature_extractor._freeze_parameters()
    wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
    return wav2vec_feature_extractor, audio_encoder


def loudness_norm(audio_array, sr=16000, lufs=-23):
    """Normalize the audio loudness to the target LUFS, skipping degenerate clips."""
    meter = pyln.Meter(sr)
    loudness = meter.integrated_loudness(audio_array)
    if abs(loudness) > 100:
        return audio_array
    normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
    return normalized_audio


def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps=25):
    """Extract wav2vec2 hidden states aligned to the number of video frames."""
    audio_duration = len(speech_array) / sr
    video_length = audio_duration * fps

    # wav2vec feature extractor
    audio_feature = np.squeeze(
        wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
    )
    audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
    audio_feature = audio_feature.unsqueeze(0)

    # audio encoder
    with torch.no_grad():
        embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)

    if len(embeddings) == 0:
        print("Failed to extract audio embedding")
        return None

    audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
    audio_emb = rearrange(audio_emb, "b s d -> s b d")

    audio_emb = audio_emb.cpu().detach()
    return audio_emb


def audio_prepare_single(audio_path, sample_rate=16000, duration=0):
    """Load a mono speech track from an audio or video file and loudness-normalize it."""
    ext = os.path.splitext(audio_path)[1].lower()
    if ext in ['.mp4', '.mov', '.avi', '.mkv']:
        human_speech_array = extract_audio_from_video(audio_path, sample_rate)
        return human_speech_array
    else:
        human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
        human_speech_array = loudness_norm(human_speech_array, sr)
        return human_speech_array


def audio_prepare_multi(left_path, right_path, audio_type="add", sample_rate=16000, duration=0, pad=0):
    """Prepare both speaker tracks: 'para' keeps them in parallel, 'add' plays them back to back."""
    if not (left_path is None or right_path is None):
        human_speech_array1 = audio_prepare_single(left_path, duration=duration)
        human_speech_array2 = audio_prepare_single(right_path, duration=duration)
    elif left_path is None:
        human_speech_array2 = audio_prepare_single(right_path, duration=duration)
        human_speech_array1 = np.zeros(human_speech_array2.shape[0])
    elif right_path is None:
        human_speech_array1 = audio_prepare_single(left_path, duration=duration)
        human_speech_array2 = np.zeros(human_speech_array1.shape[0])

    if audio_type == 'para':
        new_human_speech1 = human_speech_array1
        new_human_speech2 = human_speech_array2
    elif audio_type == 'add':
        new_human_speech1 = np.concatenate([human_speech_array1, np.zeros(human_speech_array2.shape[0])])
        new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2])
    # Don't include the padding in the summed audio, which is used to build the output audio track.
    sum_human_speechs = new_human_speech1 + new_human_speech2
    if pad > 0:
        new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
        new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])
    return new_human_speech1, new_human_speech2, sum_human_speechs


def process_tts_single(text, save_dir, voice1):
    """Synthesize a single-speaker track with Kokoro TTS and return it resampled to 16 kHz."""
    s1_sentences = []

    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')

    voice_tensor = torch.load(voice1, weights_only=True)
    generator = pipeline(
        text, voice=voice_tensor,  # <= change voice here
        speed=1, split_pattern=r'\n+'
    )
    audios = []
    for i, (gs, ps, audio) in enumerate(generator):
        audios.append(audio)
    audios = torch.concat(audios, dim=0)
    s1_sentences.append(audios)

    s1_sentences = torch.concat(s1_sentences, dim=0)
    save_path1 = f'{save_dir}/s1.wav'
    sf.write(save_path1, s1_sentences, 24000)  # save each audio file
    s1, _ = librosa.load(save_path1, sr=16000)
    return s1, save_path1


def process_tts_multi(text, save_dir, voice1, voice2):
    """Synthesize a two-speaker script of the form '(s1) ... (s2) ...' and return both 16 kHz tracks."""
    pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
    matches = re.findall(pattern, text, re.DOTALL)

    s1_sentences = []
    s2_sentences = []

    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
    for idx, (speaker, content) in enumerate(matches):
        if speaker == '1':
            voice_tensor = torch.load(voice1, weights_only=True)
            generator = pipeline(
                content, voice=voice_tensor,  # <= change voice here
                speed=1, split_pattern=r'\n+'
            )
            audios = []
            for i, (gs, ps, audio) in enumerate(generator):
                audios.append(audio)
            audios = torch.concat(audios, dim=0)
            s1_sentences.append(audios)
            s2_sentences.append(torch.zeros_like(audios))
        elif speaker == '2':
            voice_tensor = torch.load(voice2, weights_only=True)
            generator = pipeline(
                content, voice=voice_tensor,  # <= change voice here
                speed=1, split_pattern=r'\n+'
            )
            audios = []
            for i, (gs, ps, audio) in enumerate(generator):
                audios.append(audio)
            audios = torch.concat(audios, dim=0)
            s2_sentences.append(audios)
            s1_sentences.append(torch.zeros_like(audios))

    s1_sentences = torch.concat(s1_sentences, dim=0)
    s2_sentences = torch.concat(s2_sentences, dim=0)
    sum_sentences = s1_sentences + s2_sentences
    save_path1 = f'{save_dir}/s1.wav'
    save_path2 = f'{save_dir}/s2.wav'
    save_path_sum = f'{save_dir}/sum.wav'
    sf.write(save_path1, s1_sentences, 24000)  # save each audio file
    sf.write(save_path2, s2_sentences, 24000)
    sf.write(save_path_sum, sum_sentences, 24000)
    s1, _ = librosa.load(save_path1, sr=16000)
    s2, _ = librosa.load(save_path2, sr=16000)
    # sum, _ = librosa.load(save_path_sum, sr=16000)
    return s1, s2, save_path_sum


def get_full_audio_embeddings(audio_guide1=None, audio_guide2=None, combination_type="add", num_frames=0, fps=25, sr=16000, padded_frames_for_embeddings=0):
    """Compute per-speaker wav2vec2 embeddings for the whole generation, plus the summed output audio."""
    wav2vec_feature_extractor, audio_encoder = custom_init('cpu', "ckpts/chinese-wav2vec2-base")
    # wav2vec_feature_extractor, audio_encoder = custom_init('cpu', "ckpts/wav2vec")
    pad = int(padded_frames_for_embeddings / fps * sr)
    new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration=num_frames / fps, pad=pad)
    audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps=fps)
    audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps=fps)
    full_audio_embs = []
    if audio_guide1 is not None:
        full_audio_embs.append(audio_embedding_1)
    if audio_guide2 is not None:
        full_audio_embs.append(audio_embedding_2)
    if audio_guide2 is None:
        sum_human_speechs = None
    return full_audio_embs, sum_human_speechs


def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length=81, vae_scale=4, audio_window=5):
    """Slice the full audio embeddings into the windowed per-clip form expected by the model."""
    if full_audio_embs is None:
        return None
    HUMAN_NUMBER = len(full_audio_embs)
    audio_end_idx = audio_start_idx + clip_length
    indices = (torch.arange(2 * 2 + 1) - 2) * 1

    audio_embs = []
    # split audio with window size
    for human_idx in range(HUMAN_NUMBER):
        center_indices = torch.arange(
            audio_start_idx,
            audio_end_idx,
            1,
        ).unsqueeze(1) + indices.unsqueeze(0)
        center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0] - 1).to(full_audio_embs[human_idx].device)
        audio_emb = full_audio_embs[human_idx][center_indices][None, ...]  # .to(self.device)
        audio_embs.append(audio_emb)
    audio_embs = torch.concat(audio_embs, dim=0)  # .to(self.param_dtype)

    # audio_cond = audio.to(device=x.device, dtype=x.dtype)
    audio_cond = audio_embs
    first_frame_audio_emb_s = audio_cond[:, :1, ...]
    latter_frame_audio_emb = audio_cond[:, 1:, ...]
    latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale)
    middle_index = audio_window // 2
    latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index + 1, ...]
    latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
    latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index + 1, ...]
    latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)

    return [first_frame_audio_emb_s, latter_frame_audio_emb_s]


def resize_and_centercrop(cond_image, target_size):
    """
    Resize image or tensor to the target size without padding.
    """
    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width
    target_h, target_w = target_size

    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w

    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # tensor and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :]

    return cropped_tensor


def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    t = t / num_timesteps
    # shift the timestep based on ratio
    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t


def parse_speakers_locations(speakers_locations):
    bbox = {}
    if speakers_locations is None or len(speakers_locations) == 0:
        return None, ""
    speakers = speakers_locations.split(" ")
    if len(speakers) != 2:
        error = "Two speaker locations should be defined"
        return "", error

    for i, speaker in enumerate(speakers):
        location = speaker.strip().split(":")
        if len(location) not in (2, 4):
            error = f"Invalid Speaker Location '{location}'. A Speaker Location should be defined in the format Left:Right or using a BBox Left:Top:Right:Bottom"
            return "", error
        try:
            good = False
            location_float = [float(val) for val in location]
            good = all(0 <= val <= 100 for val in location_float)
        except ValueError:
            pass
        if not good:
            error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100."
            return "", error
        if len(location_float) == 2:
            location_float = [location_float[0], 0, location_float[1], 100]
        bbox[f"human{i}"] = location_float

    return bbox, ""


# construct human masks
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale=0.05, bbox=None):
    human_masks = []
    if HUMAN_NUMBER == 1:
        background_mask = torch.ones([src_h, src_w])
        human_mask1 = torch.ones([src_h, src_w])
        human_mask2 = torch.ones([src_h, src_w])
        human_masks = [human_mask1, human_mask2, background_mask]
    elif HUMAN_NUMBER == 2:
        if bbox is not None:
            assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes should match the number of audio tracks"
            background_mask = torch.zeros([src_h, src_w])
            for _, person_bbox in bbox.items():
                y_min, x_min, y_max, x_max = person_bbox
                x_min, y_min, x_max, y_max = max(x_min, 5), max(y_min, 5), min(x_max, 95), min(y_max, 95)
                x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100)
                human_mask = torch.zeros([src_h, src_w])
                human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
                background_mask += human_mask
                human_masks.append(human_mask)
        else:
            # Default layout: one speaker in each horizontal half of the frame.
            x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
            background_mask = torch.zeros([src_h, src_w])
            human_mask1 = torch.zeros([src_h, src_w])
            human_mask2 = torch.zeros([src_h, src_w])
            lefty_min, lefty_max = int((src_w // 2) * face_scale), int((src_w // 2) * (1 - face_scale))
            righty_min, righty_max = int((src_w // 2) * face_scale + (src_w // 2)), int((src_w // 2) * (1 - face_scale) + (src_w // 2))
            human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
            human_mask2[x_min:x_max, righty_min:righty_max] = 1
            background_mask += human_mask1
            background_mask += human_mask2
            human_masks = [human_mask1, human_mask2]
        background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
        human_masks.append(background_mask)

    ref_target_masks = torch.stack(human_masks, dim=0)  # .to(self.device)
    # resize and centercrop for ref_target_masks
    # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
    N_h, N_w = lat_h // 2, lat_w // 2
    token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze()
    token_ref_target_masks = (token_ref_target_masks > 0)
    token_ref_target_masks = token_ref_target_masks.float()  # .to(self.device)
    token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)

    return token_ref_target_masks
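

# Illustrative usage sketch (not part of the original module): shows how the audio
# helpers chain together, building the full-length embeddings once and then slicing
# a window for a single clip. The audio path below is a hypothetical placeholder;
# the wav2vec checkpoint is whatever get_full_audio_embeddings() defaults to
# ("ckpts/chinese-wav2vec2-base").
if __name__ == "__main__":
    example_audio = "examples/speaker1.wav"  # hypothetical input file
    num_frames = 81                          # one 81-frame clip at 25 fps

    full_audio_embs, sum_audio = get_full_audio_embeddings(
        audio_guide1=example_audio,
        audio_guide2=None,
        combination_type="add",
        num_frames=num_frames,
        fps=25,
    )
    # sum_audio is None here because only one speaker track was provided.
    window_embs = get_window_audio_embeddings(
        full_audio_embs,
        audio_start_idx=0,
        clip_length=num_frames,
    )
    if window_embs is not None:
        first_frame_emb, latter_frames_emb = window_embs
        print(first_frame_emb.shape, latter_frames_emb.shape)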