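# Spaces: Paused
# NOTE(review): the line above appears to be page-header text captured along
# with the file when it was copied from a hosting site; it is not module code.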
import os | |
from einops import rearrange | |
import torch | |
import torch.nn as nn | |
from xfuser.core.distributed import ( | |
get_sequence_parallel_rank, | |
get_sequence_parallel_world_size, | |
get_sp_group, | |
) | |
from einops import rearrange, repeat | |
from functools import lru_cache | |
import imageio | |
import uuid | |
from tqdm import tqdm | |
import numpy as np | |
import subprocess | |
import soundfile as sf | |
import torchvision | |
import binascii | |
import os.path as osp | |
from skimage import color | |
# File extensions recognised as video files.
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# Resolution buckets keyed by an aspect-ratio string. Each value is
# ``([a, b], 1)`` and the key is approximately ``a / b`` to two decimals.
# NOTE(review): ``[a, b]`` is presumably ``[height, width]`` and the trailing
# ``1`` is not interpreted anywhere in this file -- confirm against callers.
ASPECT_RATIO_627 = {
    "0.26": ([320, 1216], 1),
    "0.38": ([384, 1024], 1),
    "0.50": ([448, 896], 1),
    "0.67": ([512, 768], 1),
    "0.82": ([576, 704], 1),
    "1.00": ([640, 640], 1),
    "1.22": ([704, 576], 1),
    "1.50": ([768, 512], 1),
    "1.86": ([832, 448], 1),
    "2.00": ([896, 448], 1),
    "2.50": ([960, 384], 1),
    "2.83": ([1088, 384], 1),
    "3.60": ([1152, 320], 1),
    "3.80": ([1216, 320], 1),
    "4.00": ([1280, 320], 1),
}

# Same layout as ASPECT_RATIO_627, with larger resolutions.
ASPECT_RATIO_960 = {
    "0.22": ([448, 2048], 1),
    "0.29": ([512, 1792], 1),
    "0.36": ([576, 1600], 1),
    "0.45": ([640, 1408], 1),
    "0.55": ([704, 1280], 1),
    "0.63": ([768, 1216], 1),
    "0.76": ([832, 1088], 1),
    "0.88": ([896, 1024], 1),
    "1.00": ([960, 960], 1),
    "1.14": ([1024, 896], 1),
    "1.31": ([1088, 832], 1),
    "1.50": ([1152, 768], 1),
    "1.58": ([1216, 768], 1),
    "1.82": ([1280, 704], 1),
    "1.91": ([1344, 704], 1),
    "2.20": ([1408, 640], 1),
    "2.30": ([1472, 640], 1),
    "2.67": ([1536, 576], 1),
    "2.89": ([1664, 576], 1),
    "3.62": ([1856, 512], 1),
    "3.75": ([1920, 512], 1),
}
def torch_gc():
    """Release cached CUDA memory held by PyTorch, when CUDA is usable.

    Calls ``torch.cuda.empty_cache()`` followed by ``torch.cuda.ipc_collect()``.
    On hosts without a usable CUDA device the function is a no-op:
    ``ipc_collect`` initialises the CUDA runtime and therefore raises there,
    which would otherwise break every CPU-only code path that calls this
    helper.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def split_token_counts_and_frame_ids(T, token_frame, world_size, rank):
    """Describe which frames a rank's contiguous token slice touches.

    The flattened sequence of ``T * token_frame`` tokens is split into
    ``world_size`` contiguous slices whose sizes differ by at most one (the
    first ``S % world_size`` ranks get the extra token). For the slice owned
    by ``rank`` this returns how many of its tokens fall into each frame.

    Args:
        T: Number of frames.
        token_frame: Number of tokens per frame.
        world_size: Number of ranks the sequence is split across.
        rank: Index of the rank whose slice is described.

    Returns:
        ``(counts, frame_ids)``: two parallel lists; ``frame_ids[i]`` is a
        frame index touched by the slice and ``counts[i]`` the number of the
        slice's tokens inside that frame. Frames with no tokens are omitted.
    """
    total_tokens = T * token_frame
    split_sizes = [
        total_tokens // world_size + (1 if i < total_tokens % world_size else 0)
        for i in range(world_size)
    ]
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]

    # Intersect the half-open slice [start, end) with each frame's token
    # interval instead of visiting every token: O(T) rather than
    # O(T * token_frame).
    counts = []
    frame_ids = []
    for frame in range(T):
        frame_start = frame * token_frame
        overlap = min(end, frame_start + token_frame) - max(start, frame_start)
        if overlap > 0:
            counts.append(overlap)
            frame_ids.append(frame)
    return counts, frame_ids
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    """Linearly map values from ``source_range`` onto ``target_range``.

    Args:
        column: Value(s) to rescale; anything supporting ``-``, ``/``, ``*``
            and ``+`` with numbers.
        source_range: ``(min, max)`` of the input interval.
        target_range: ``(min, max)`` of the output interval.
        epsilon: Added to the source span to avoid division by zero when the
            source interval is degenerate.

    Returns:
        The rescaled value(s).
    """
    src_lo, src_hi = source_range
    dst_lo, dst_hi = target_range
    unit = (column - src_lo) / (src_hi - src_lo + epsilon)
    return unit * (dst_hi - dst_lo) + dst_lo
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, mode='mean', attn_bias=None):
    """Compute, per target mask, how strongly each query token attends to it.

    Softmax attention probabilities between ``visual_q`` and ``ref_k`` are
    computed, then for every mask the probabilities falling on masked
    reference tokens are summed and divided by the mask's total, and finally
    reduced over heads.

    Args:
        visual_q: Query tensor; axes 1 and 2 are swapped before the matmul
            (assumed layout ``(B, x_seqlens, H, K)`` -- TODO confirm).
        ref_k: Key tensor in the same layout (assumed
            ``(B, ref_seqlens, H, K)``); cast to ``visual_q``'s dtype/device.
        ref_target_masks: Iterable tensor of masks, one per class, each
            broadcast over the last (reference-token) axis of the attention
            map. A mask that sums to zero yields NaN for its class.
        mode: Head reduction, ``'mean'`` or ``'max'``.
        attn_bias: Optional tensor added to the attention logits before
            softmax.

    Returns:
        Tensor of the per-class maps concatenated along dim 0, i.e.
        ``(class_num * B, x_seqlens)``.

    Raises:
        ValueError: If ``mode`` is neither ``'mean'`` nor ``'max'``.
    """
    if mode not in ('mean', 'max'):
        raise ValueError(f"mode must be 'mean' or 'max', got {mode!r}")

    ref_k = ref_k.to(dtype=visual_q.dtype, device=visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = (visual_q * scale).transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)

    attn = visual_q @ ref_k.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn_probs = attn.softmax(-1).to(visual_q.dtype)  # B, H, x_seqlens, ref_seqlens
    # The logits are no longer needed once the probabilities exist; dropping
    # them before the per-class loop lowers peak memory.
    del attn

    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_maps = []
    for ref_target_mask in ref_target_masks:
        torch_gc()
        ref_target_mask = ref_target_mask[None, None, None, ...]
        # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        class_map = (attn_probs * ref_target_mask).sum(-1) / ref_target_mask.sum()
        class_map = class_map.permute(0, 2, 1)  # B, x_seqlens, H
        if mode == 'mean':
            class_map = class_map.mean(-1)  # B, x_seqlens
        else:
            # Tensor.max(dim) returns (values, indices); only the values can
            # be concatenated below.
            class_map = class_map.max(-1).values  # B, x_seqlens
        x_ref_attn_maps.append(class_map)

    del attn_probs
    torch_gc()
    return torch.concat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2, enable_sp=False):
    """Average per-class reference attention maps over all heads, in chunks.

    The heads are processed in ``split_num`` chunks (to bound memory) via
    ``calculate_x_ref_attn_map`` and the chunk results are combined into a
    mean over all heads.

    Args:
        visual_q (torch.Tensor): Queries, ``B M H K``.
        ref_k (torch.Tensor): Keys, ``B M H K``. Only the first
            ``N_h * N_w`` tokens along dim 1 are used. When ``enable_sp`` is
            set the keys are first all-gathered along dim 1 across the
            sequence-parallel group.
        shape (tuple): ``(N_t, N_h, N_w)``; ``N_t`` is not used here.
        ref_target_masks (torch.Tensor): ``(class_num, N_h * N_w)`` masks.
            Required despite the ``None`` default.
        split_num (int): Number of head chunks. When the head count is not a
            multiple of ``split_num`` the chunks differ in size by at most
            one and are weighted by their head count, so no head is dropped.
        enable_sp (bool): Whether sequence parallelism is active.

    Returns:
        torch.Tensor: ``(class_num, M)`` attention map in ``visual_q``'s
        dtype and device.
        NOTE(review): the accumulator has ``class_num`` rows while each chunk
        result has ``class_num * B`` rows, so this appears to assume B == 1.

    Raises:
        ValueError: If ``ref_target_masks`` is ``None`` or ``split_num < 1``.
    """
    if ref_target_masks is None:
        raise ValueError("ref_target_masks must be provided")
    if split_num < 1:
        raise ValueError(f"split_num must be >= 1, got {split_num}")

    _, N_h, N_w = shape
    if enable_sp:
        ref_k = get_sp_group().all_gather(ref_k, dim=1)
    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]

    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(
        class_num, seq_lens, device=visual_q.device, dtype=visual_q.dtype
    )

    base_chunk, extra = divmod(heads, split_num)
    # With equal chunks a plain average of chunk means is exact (and keeps
    # the original arithmetic); otherwise weight each chunk by its size.
    equal_chunks = extra == 0
    head_start = 0
    for i in range(split_num):
        chunk = base_chunk + (1 if i < extra else 0)
        if chunk == 0:
            # Fewer heads than chunks: nothing left to process.
            continue
        head_end = head_start + chunk
        chunk_maps = calculate_x_ref_attn_map(
            visual_q[:, :, head_start:head_end, :],
            ref_k[:, :, head_start:head_end, :],
            ref_target_masks,
        )
        x_ref_attn_maps += chunk_maps if equal_chunks else chunk_maps * chunk
        head_start = head_end

    return x_ref_attn_maps / (split_num if equal_chunks else heads)
def rotate_half(x):
    """Rotate adjacent pairs of the last dimension: ``(a, b) -> (-b, a)``.

    The last dimension is viewed as consecutive pairs, each pair is rotated,
    and the result is flattened back to the input's shape.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
class RotaryPositionalEmbedding1D(nn.Module):
    """1D rotary positional embedding (RoPE) with a fixed base of 10000."""

    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    def precompute_freqs_cis_1d(self, pos_indices):
        """Return rotation angles for ``pos_indices``.

        Despite the name, the result holds real-valued angles (position times
        inverse frequency), with every angle repeated twice along the last
        axis so it lines up with the feature pairs rotated in ``forward``.

        Args:
            pos_indices (torch.Tensor): Position indices.

        Returns:
            torch.Tensor: ``pos_indices.shape + (2 * (head_dim // 2),)``
            angles on ``pos_indices``'s device.
        """
        half_dim = self.head_dim // 2
        exponents = torch.arange(0, self.head_dim, 2)[:half_dim].float() / self.head_dim
        inv_freq = (1.0 / (self.base ** exponents)).to(pos_indices.device)
        angles = torch.einsum("..., f -> ... f", pos_indices.float(), inv_freq)
        return repeat(angles, "... n -> ... (n r)", r=2)

    def forward(self, x, pos_indices):
        """Apply 1D RoPE.

        Args:
            x (torch.Tensor): [B, head, seq, head_dim]
            pos_indices (torch.Tensor): [seq,]

        Returns:
            Tensor with the same shape and dtype as ``x``; the rotation
            itself is computed in float32.
        """
        angles = self.precompute_freqs_cis_1d(pos_indices).float().to(x.device)
        cos = rearrange(angles.cos(), 'n d -> 1 1 n d')
        sin = rearrange(angles.sin(), 'n d -> 1 1 n d')
        x_fp32 = x.float()
        rotated = (x_fp32 * cos) + (rotate_half(x_fp32) * sin)
        return rotated.type_as(x)
def rand_name(length=8, suffix=''):
    """Return a random lowercase-hex name, optionally with a file suffix.

    Args:
        length: Number of random bytes; the hex stem has ``2 * length``
            characters.
        suffix: Optional extension; a leading ``.`` is added when missing.

    Returns:
        The random stem followed by the (normalised) suffix, if any.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to a file with imageio (libx264).

    Each time step (dim 2 of ``tensor``) is turned into an image grid with
    ``torchvision.utils.make_grid`` and written as one frame.

    Args:
        tensor: Video tensor whose dim 2 is time; each slice along it is
            passed to ``make_grid`` (assumed ``(B, C, T, H, W)`` -- TODO
            confirm).
        save_file: Output path. When ``None`` a random name under ``/tmp``
            with ``suffix`` is used.
        fps: Frames per second of the output.
        suffix: Extension for the auto-generated file name.
        nrow: ``make_grid`` images-per-row.
        normalize: Passed to ``make_grid``.
        value_range: Values are clamped to this range and it is passed to
            ``make_grid`` for normalisation.
        retry: Maximum number of write attempts. At least one attempt is
            always made.

    Returns:
        Path of the written video file.

    Raises:
        OSError, RuntimeError: The last write error, once all attempts have
            failed.
    """
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # Preprocess once, outside the retry loop, so a retry re-encodes the
    # same frames instead of re-processing already converted data.
    frames = tensor.clamp(min(value_range), max(value_range))
    frames = torch.stack([
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in frames.unbind(2)
    ], dim=1).permute(1, 2, 3, 0)
    frames = (frames * 255).type(torch.uint8).cpu().numpy()

    error = None
    for _ in range(max(1, retry)):
        try:
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=10,
                ffmpeg_params=["-crf", "10"])
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Always release the encoder, including after a failed write.
                writer.close()
        except (OSError, RuntimeError) as exc:
            error = exc
        else:
            return cache_file
    raise error
def save_video_ffmpeg(gen_video_samples, save_path, vocal_audio_list, fps=25, quality=5, high_quality_save=False):
    """Write a generated video and mux it with (cropped) audio via ffmpeg.

    Steps: encode the frames to ``<save_path>-temp.mp4``, crop the first
    entry of ``vocal_audio_list`` to the video duration into
    ``<save_path>-cropaudio.wav``, then mux both into ``<save_path>.mp4``.
    The two intermediate files are removed afterwards, also when a step
    fails.

    Args:
        gen_video_samples: Video tensor ``(C, T, H, W)`` in ``[-1, 1]``.
        save_path: Output path without extension.
        vocal_audio_list: Sequence of audio file paths; only the first is
            used.
        fps: Frames per second of the video.
        quality: imageio quality for the normal (non high-quality) path.
        high_quality_save: When true, frames are written with
            ``cache_video`` and the final mux uses lossless x264
            (``-crf 0 -preset veryslow``).

    Raises:
        IndexError: If ``vocal_audio_list`` is empty.
        subprocess.CalledProcessError: If an ffmpeg invocation fails.
    """

    def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
        writer = imageio.get_writer(
            save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
        )
        try:
            for frame in tqdm(frames, desc="Saving video"):
                writer.append_data(np.array(frame))
        finally:
            writer.close()

    # Resolve the audio path first so an empty list fails before any
    # expensive encoding happens.
    audio_path = vocal_audio_list[0]
    save_path_tmp = save_path + "-temp.mp4"
    save_path_crop_audio = save_path + "-cropaudio.wav"

    try:
        if high_quality_save:
            cache_video(
                tensor=gen_video_samples.unsqueeze(0),
                save_file=save_path_tmp,
                fps=fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1)
            )
        else:
            video_audio = (gen_video_samples + 1) / 2  # C T H W
            video_audio = video_audio.permute(1, 2, 3, 0).cpu().numpy()
            video_audio = np.clip(video_audio * 255, 0, 255).astype(np.uint8)  # to [0, 255]
            save_video(video_audio, save_path_tmp, fps=fps, quality=quality)

        # Crop audio according to video length.
        _, T, _, _ = gen_video_samples.shape
        duration = T / fps
        subprocess.run(
            [
                "ffmpeg",
                "-y",  # overwrite a stale crop file instead of prompting
                "-i", audio_path,
                "-t", f'{duration}',
                save_path_crop_audio,
            ],
            check=True,
        )

        video_codec_args = ["-c:v", "libx264"]
        if high_quality_save:
            video_codec_args += ["-crf", "0", "-preset", "veryslow"]
        subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i", save_path_tmp,
                "-i", save_path_crop_audio,
                *video_codec_args,
                "-c:a", "aac",
                "-shortest",
                save_path + ".mp4",
            ],
            check=True,
        )
    finally:
        # Intermediate files are never useful to the caller; remove them on
        # both success and failure.
        for leftover in (save_path_tmp, save_path_crop_audio):
            if os.path.exists(leftover):
                os.remove(leftover)
class MomentumBuffer:
    """Momentum-weighted running accumulation of update tensors.

    Each ``update`` sets ``running_average`` to
    ``update_value + momentum * running_average``; the accumulator starts at
    the integer ``0``. Despite the attribute name, the value is not
    normalised.
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into the running accumulator."""
        self.running_average = update_value + self.momentum * self.running_average
def project(
    v0: torch.Tensor,  # [B, C, T, H, W]
    v1: torch.Tensor,  # [B, C, T, H, W]
):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    The projection treats the last four dimensions as one vector per batch
    element and is computed in float64; both results are cast back to
    ``v0``'s original dtype.

    Returns:
        ``(v0_parallel, v0_orthogonal)`` with
        ``v0_parallel + v0_orthogonal == v0`` up to rounding.
    """
    out_dtype = v0.dtype
    reduce_dims = [-1, -2, -3, -4]
    v0_64 = v0.double()
    unit_v1 = torch.nn.functional.normalize(v1.double(), dim=reduce_dims)
    coefficient = (v0_64 * unit_v1).sum(dim=reduce_dims, keepdim=True)
    parallel = coefficient * unit_v1
    orthogonal = v0_64 - parallel
    return parallel.to(out_dtype), orthogonal.to(out_dtype)
def adaptive_projected_guidance(
    diff: torch.Tensor,  # [B, C, T, H, W]
    pred_cond: torch.Tensor,  # [B, C, T, H, W]
    momentum_buffer: "MomentumBuffer | None" = None,
    eta: float = 0.0,
    norm_threshold: float = 55,
):
    """Compute a projected guidance update from ``diff``.

    Steps:
      1. If a momentum buffer is given, fold ``diff`` into it and continue
         with the buffer's running value.
      2. If ``norm_threshold > 0``, rescale ``diff`` so that its L2 norm over
         the last four dimensions does not exceed ``norm_threshold``.
      3. Decompose ``diff`` into parts parallel and orthogonal to
         ``pred_cond`` and return ``orthogonal + eta * parallel``.

    Args:
        diff: Guidance difference tensor.
        pred_cond: Tensor defining the projection direction.
        momentum_buffer: Optional buffer; it is mutated by this call.
        eta: Weight of the parallel component.
        norm_threshold: Maximum allowed norm; values ``<= 0`` disable the
            rescaling.

    Returns:
        The guidance update, same shape as ``diff``.
    """
    import logging

    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3, -4], keepdim=True)
        # Diagnostics only at DEBUG level: formatting the tensor on every
        # call is costly and floods stdout.
        logging.getLogger(__name__).debug("diff_norm: %s", diff_norm)
        # min(1, threshold / norm): shrink only when the norm exceeds the
        # threshold. Clamping the per-sample factor avoids allocating a
        # full-size tensor of ones.
        scale_factor = torch.clamp(norm_threshold / diff_norm, max=1.0)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond)
    return diff_orthogonal + eta * diff_parallel
def _tensor_to_float_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Return ``tensor`` as a CPU NumPy array, upcasting non-fp32/fp64 dtypes."""
    tensor = tensor.detach()
    # NumPy cannot represent bfloat16, and half precision is too coarse for
    # the colour-space conversion; float32/float64 inputs are left unchanged.
    if tensor.dtype not in (torch.float32, torch.float64):
        tensor = tensor.float()
    return tensor.cpu().numpy()


def match_and_blend_colors(source_chunk: torch.Tensor, reference_image: torch.Tensor, strength: float) -> torch.Tensor:
    """
    Matches the color of a source video chunk to a reference image and blends with the original.

    For every frame, each Lab channel is shifted and scaled so that its mean
    and standard deviation match those of the reference image; the result is
    converted back to RGB and linearly blended with the original frame.

    Args:
        source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1].
                                     Assumes B=1 (batch size of 1).
        reference_image (torch.Tensor): The reference image (B, C, 1, H, W) in range [-1, 1].
                                        Assumes B=1 and T=1 (single reference frame).
        strength (float): The strength of the color correction (0.0 to 1.0).
                          0.0 means no correction, 1.0 means full correction.

    Returns:
        torch.Tensor: The color-corrected and blended video chunk, on the
        device and in the dtype of ``source_chunk``. ``source_chunk`` itself
        is returned when ``strength`` is 0, when the reference cannot be
        converted to Lab, or when the chunk has no frames.

    Raises:
        ValueError: If ``strength`` is outside ``[0.0, 1.0]``.
    """
    if strength == 0.0:
        return source_chunk

    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")

    device = source_chunk.device
    dtype = source_chunk.dtype

    # Source: (1, C, T, H, W) -> (T, H, W, C); Reference: (1, C, 1, H, W) -> (H, W, C)
    source_np = _tensor_to_float_numpy(source_chunk.squeeze(0).permute(1, 2, 3, 0))
    ref_np = _tensor_to_float_numpy(reference_image.squeeze(0).squeeze(1).permute(1, 2, 0))

    # Map [-1, 1] -> [0, 1] for skimage and clip away any overshoot.
    source_np_01 = np.clip((source_np + 1.0) / 2.0, 0.0, 1.0)
    ref_np_01 = np.clip((ref_np + 1.0) / 2.0, 0.0, 1.0)

    try:
        ref_lab = color.rgb2lab(ref_np_01)
    except ValueError as e:
        # Handle potential errors if image data is not valid for conversion
        print(f"Warning: Could not convert reference image to Lab: {e}. Skipping color correction for this chunk.")
        return source_chunk

    # Reference statistics do not depend on the frame; compute them once.
    ref_stats = [(ref_lab[:, :, j].mean(), ref_lab[:, :, j].std()) for j in range(3)]

    corrected_frames_np_01 = []
    for i, source_frame_rgb_01 in enumerate(source_np_01):  # Iterate over time (T)
        try:
            source_lab = color.rgb2lab(source_frame_rgb_01)
        except ValueError as e:
            print(f"Warning: Could not convert source frame {i} to Lab: {e}. Using original frame.")
            corrected_frames_np_01.append(source_frame_rgb_01)
            continue

        corrected_lab_frame = source_lab.copy()
        for j, (mean_ref, std_ref) in enumerate(ref_stats):  # L, a, b
            channel = source_lab[:, :, j]
            mean_src, std_src = channel.mean(), channel.std()
            if std_src == 0:
                # A flat source channel cannot be rescaled (division by
                # zero); set it to the reference mean instead.
                corrected_lab_frame[:, :, j] = mean_ref
            else:
                corrected_lab_frame[:, :, j] = (channel - mean_src) * (std_ref / std_src) + mean_ref

        try:
            fully_corrected_frame_rgb_01 = color.lab2rgb(corrected_lab_frame)
        except ValueError as e:
            print(f"Warning: Could not convert corrected frame {i} back to RGB: {e}. Using original frame.")
            corrected_frames_np_01.append(source_frame_rgb_01)
            continue

        # Clip again after lab2rgb as it can go slightly out of [0, 1].
        fully_corrected_frame_rgb_01 = np.clip(fully_corrected_frame_rgb_01, 0.0, 1.0)

        # Blend with original source frame (in [0, 1] RGB).
        blended_frame_rgb_01 = (1 - strength) * source_frame_rgb_01 + strength * fully_corrected_frame_rgb_01
        corrected_frames_np_01.append(blended_frame_rgb_01)

    if not corrected_frames_np_01:
        # Zero-length chunk: nothing to correct (np.stack would raise).
        return source_chunk

    corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0)

    # Convert back to [-1, 1].
    corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0

    # (T, H, W, C) -> (1, C, T, H, W), contiguous, on the original device/dtype.
    corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0)
    corrected_chunk_tensor = corrected_chunk_tensor.contiguous()
    return corrected_chunk_tensor.to(device=device, dtype=dtype)