import logging
from argparse import ArgumentParser
from pathlib import Path

import torch
import torchaudio
import tqdm

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils

# Enable TF32 matmuls/convolutions for faster inference on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()
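

# `Audio` loads optional reference audio that is passed to generate() as an
# extra conditioning input alongside the video frames (a timbre reference,
# judging by `num_timbre_sample`): each clip is downmixed to mono,
# peak-normalized, resampled to the model rate, and padded/cropped to a
# fixed length.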
class Audio:

    def __init__(self, audio_paths, sample_rate):
        self.audio_paths = audio_paths
        self.sample_rate = sample_rate
        # Fixed reference length: 89088 samples at 44.1 kHz, 32768 otherwise.
        self.num_timbre_sample = 89088 if sample_rate == 44100 else 32768
        self.resampler = {}

    def load_audio(self):
        chunk_list = []
        for audio_path in self.audio_paths:
            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono
            abs_max = audio_chunk.abs().max()
            if abs_max > 0:  # guard against all-zero (silent) input
                audio_chunk = audio_chunk / abs_max * 0.95

            if sample_rate != self.sample_rate:
                # Cache one resampler per source sample rate.
                if sample_rate not in self.resampler:
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            # Zero-pad or crop to the fixed reference length.
            if audio_chunk.size(0) < self.num_timbre_sample:
                padding_length = self.num_timbre_sample - audio_chunk.size(0)
                audio_chunk = torch.cat([audio_chunk, torch.zeros(padding_length)], dim=0)
            else:
                audio_chunk = audio_chunk[:self.num_timbre_sample]

            chunk_list.append(audio_chunk)
        return chunk_list
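

# Generates audio for a single video: load frames, resize the network's
# sequence lengths to the clip duration, run flow-matching generation, save
# the result as <stem><i>.wav, and optionally mux it back into an .mp4.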
def process_video(video_path: Path, args, model: ModelConfig, net: MMAudio, fm: FlowMatching,
                  feature_utils: FeaturesUtils, device: str, dtype: torch.dtype,
                  audio: torch.Tensor, i: int):
    log.info(f'Processing video: {video_path}')

    # Default reference length when no audio conditioning is given.
    audio_num_sample = 89088
    if audio is not None:
        audio_num_sample = audio.shape[0]

    video_info = load_video(video_path, args.duration)
    clip_frames = video_info.clip_frames
    sync_frames = video_info.sync_frames
    duration = video_info.duration_sec
    if args.mask_away_clip:
        clip_frames = None
    else:
        clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)

    model.seq_cfg.duration = duration
    model.seq_cfg.audio_num_sample = audio_num_sample
    net.update_seq_lengths(model.seq_cfg.latent_seq_len, model.seq_cfg.clip_seq_len,
                           model.seq_cfg.sync_seq_len, model.seq_cfg.audio_seq_len)

    log.info(f'Prompt: {args.prompt}')
    log.info(f'Negative prompt: {args.negative_prompt}')
    audios = generate(clip_frames,
                      sync_frames, [args.prompt], audio,
                      negative_text=[args.negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=torch.Generator(device=device).manual_seed(args.seed),
                      cfg_strength=args.cfg_strength)
    audio = audios.float().cpu()[0]

    save_path = args.output / f'{video_path.stem}{i}.wav'
    torchaudio.save(save_path, audio, model.seq_cfg.sampling_rate)
    log.info(f'Audio saved to {save_path}')

    if not args.skip_video_composite:
        video_save_path = args.output / f'{video_path.stem}{i}.mp4'
        make_video(video_info, video_save_path, audio, sampling_rate=model.seq_cfg.sampling_rate)
        log.info(f'Video saved to {video_save_path}')


@torch.inference_mode()
def main():
    setup_eval_logging()

    parser = ArgumentParser()
    parser.add_argument('--variant', type=str, default='large_44k')
    parser.add_argument('--video_dir',
                        type=Path,
                        help='Directory of input .mp4 videos (or a single video file)')
    parser.add_argument('--audio_path', type=str, default='',
                        help='Path to an optional reference audio file')
    parser.add_argument('--prompt', type=str, help='Input prompt', default='')
    parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
    parser.add_argument('--duration', type=float, default=8.0)
    parser.add_argument('--cfg_strength', type=float, default=4.5)
    parser.add_argument('--num_steps', type=int, default=25)
    parser.add_argument('--mask_away_clip', action='store_true')
    parser.add_argument('--output', type=Path, help='Output directory', default='./')
    parser.add_argument('--seed', type=int, help='Random seed', default=42)
    parser.add_argument('--skip_video_composite', action='store_true')
    parser.add_argument('--full_precision', action='store_true')
    parser.add_argument('--model_path', type=str, default='weights/model.pth',
                        help='Path to the model weights')

    args = parser.parse_args()

    if args.variant not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {args.variant}')
    model: ModelConfig = all_model_cfg[args.variant]
    model.download_if_needed()

    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
        log.warning('CUDA/MPS are not available; running on CPU')
    dtype = torch.float32 if args.full_precision else torch.bfloat16

    args.output.mkdir(parents=True, exist_ok=True)

    if args.audio_path != '':
        SAMPLE_RATE = 44100  # reference audio is loaded at the 44.1 kHz model rate
        audio = Audio([args.audio_path], SAMPLE_RATE)
        audio_list = audio.load_audio()
    else:
        audio_list = None

    model.model_path = Path(args.model_path)
    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True)['weights'])
    log.info(f'Loaded weights from {model.model_path}')

    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=args.num_steps)
    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=True)
    feature_utils = feature_utils.to(device, dtype).eval()

    if args.video_dir:
        video_dir: Path = args.video_dir.expanduser()
        if video_dir.is_file():
            # A single video file was passed instead of a directory.
            video_files = [video_dir]
        else:
            video_files = sorted(video_dir.glob('*.mp4'))
        if not video_files:
            log.warning(f'No video files found in {video_dir}')
        else:
            if audio_list is None:
                audio_list = [None] * len(video_files)
            if len(audio_list) == 1:
                # Reuse the single reference audio for every video.
                audio_list = audio_list * len(video_files)
            for i in range(1):  # single pass; raise the range for more samples per video
                for video_path, audio in tqdm.tqdm(zip(video_files, audio_list),
                                                   total=len(video_files)):
                    # Draw a fresh random seed for each video (overrides --seed).
                    args.seed = torch.seed()
                    process_video(video_path, args, model, net, fm, feature_utils, device, dtype,
                                  audio, i)


if __name__ == '__main__':
    main()
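
# Example invocation (script name, paths, and prompt are illustrative):
#   python demo.py --video_dir ./videos --prompt 'rain on a window' \
#       --output ./output --model_path weights/model.pth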