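"""Batch video-to-audio generation with MMAudio.

Iterates over a directory of .mp4 files (or a single video file), generates a
soundtrack for each one conditioned on a text prompt and, optionally, a
reference audio clip for timbre, and writes a .wav (plus a recomposited .mp4
unless --skip_video_composite is set) per input video.

Example invocation (the script filename and paths are illustrative):

    python batch_video_to_audio.py --variant large_44k --video_dir ./videos \
        --prompt "rain on a tin roof" --output ./outputs
"""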
import logging
import time
from argparse import ArgumentParser
from pathlib import Path

import torch
import torchaudio
import tqdm

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils

# Allow TF32 matmuls for faster inference on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

class Audio:
    """Loads reference audio clips used to condition generation on timbre."""

    def __init__(self, audio_paths, sample_rate):
        self.audio_paths = audio_paths
        self.sample_rate = sample_rate
        # Fixed conditioning length: 89088 samples (~2.02 s at 44.1 kHz), 32768 otherwise.
        self.num_timbre_sample = 89088 if sample_rate == 44100 else 32768
        self.resampler = {}

    def load_audio(self):
        chunk_list = []
        for audio_path in self.audio_paths:
            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # downmix to mono

            # Peak-normalize to 0.95, guarding against all-zero (silent) clips.
            abs_max = audio_chunk.abs().max()
            if abs_max > 0:
                audio_chunk = audio_chunk / abs_max * 0.95

            # Resample if needed, caching one resampler per source sample rate.
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            # Zero-pad or truncate to the fixed conditioning length.
            if audio_chunk.size(0) < self.num_timbre_sample:
                padding_length = self.num_timbre_sample - audio_chunk.size(0)
                audio_chunk = torch.cat([audio_chunk, torch.zeros(padding_length)], dim=0)
            else:
                audio_chunk = audio_chunk[:self.num_timbre_sample]
            chunk_list.append(audio_chunk)
        return chunk_list
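
# Example usage (hypothetical paths): load two reference clips at 44.1 kHz.
#   loader = Audio(['ref_a.wav', 'ref_b.wav'], sample_rate=44100)
#   chunks = loader.load_audio()  # list of 1-D tensors, each num_timbre_sample long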
    
def process_video(video_path: Path, args, model: ModelConfig, net: MMAudio, fm: FlowMatching,
                  feature_utils: FeaturesUtils, device: str, dtype: torch.dtype,
                  audio: torch.Tensor, i: int):
    log.info(f'Processing video: {video_path}')
    start = time.time()

    # Length of the timbre-conditioning audio; falls back to the 44.1 kHz default.
    audio_num_sample = 89088
    if audio is not None:
        audio_num_sample = audio.shape[0]

    video_info = load_video(video_path, args.duration)
    clip_frames = video_info.clip_frames
    sync_frames = video_info.sync_frames
    duration = video_info.duration_sec
    if args.mask_away_clip:
        clip_frames = None
    else:
        clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)

    # Update the network's sequence lengths to match this video's duration
    # and the length of the conditioning audio.
    model.seq_cfg.duration = duration
    model.seq_cfg.audio_num_sample = audio_num_sample
    net.update_seq_lengths(model.seq_cfg.latent_seq_len, model.seq_cfg.clip_seq_len,
                           model.seq_cfg.sync_seq_len, model.seq_cfg.audio_seq_len)

    log.info(f'Prompt: {args.prompt}')
    log.info(f'Negative prompt: {args.negative_prompt}')
    audios = generate(clip_frames,
                      sync_frames, [args.prompt], audio,
                      negative_text=[args.negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=torch.Generator(device=device).manual_seed(args.seed),
                      cfg_strength=args.cfg_strength)
    audio = audios.float().cpu()[0]

    save_path = args.output / f'{video_path.stem}{i}.wav'
    torchaudio.save(save_path, audio, model.seq_cfg.sampling_rate)
    log.info(f'Audio saved to {save_path}')

    if not args.skip_video_composite:
        video_save_path = args.output / f'{video_path.stem}{i}.mp4'
        make_video(video_info, video_save_path, audio, sampling_rate=model.seq_cfg.sampling_rate)
        log.info(f'Video saved to {video_save_path}')
    log.info(f'Finished {video_path.stem} in {time.time() - start:.1f}s')

@torch.inference_mode()
def main():
    setup_eval_logging()

    parser = ArgumentParser()
    parser.add_argument('--variant', type=str, default='large_44k',
                        help='Model variant (a key of all_model_cfg)')
    parser.add_argument('--video_dir', type=Path,
                        help='Directory of input .mp4 files, or a single video file')
    parser.add_argument('--audio_path', type=str, default='',
                        help='Optional reference audio clip for timbre conditioning')
    parser.add_argument('--prompt', type=str, help='Input prompt', default='')
    parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
    parser.add_argument('--duration', type=float, default=8.0)
    parser.add_argument('--cfg_strength', type=float, default=4.5)
    parser.add_argument('--num_steps', type=int, default=25)
    parser.add_argument('--mask_away_clip', action='store_true')
    parser.add_argument('--output', type=Path, help='Output directory', default='./')
    parser.add_argument('--seed', type=int, help='Random seed', default=42)
    parser.add_argument('--skip_video_composite', action='store_true')
    parser.add_argument('--full_precision', action='store_true')
    parser.add_argument('--model_path', type=str, default='weights/model.pth',
                        help='Path to the model weights')

    args = parser.parse_args()

    if args.variant not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {args.variant}')
    model: ModelConfig = all_model_cfg[args.variant]
    model.download_if_needed()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        log.warning('CUDA/MPS are not available, running on CPU')
    dtype = torch.float32 if args.full_precision else torch.bfloat16

    args.output.mkdir(parents=True, exist_ok=True)

    # Reference audio is loaded at 44.1 kHz, matching the default large_44k variant.
    if args.audio_path != '':
        SAMPLE_RATE = 44100
        audio = Audio([args.audio_path], SAMPLE_RATE)
        audio_list = audio.load_audio()
    else:
        audio_list = None

    model.model_path = Path(args.model_path)
    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    # Checkpoints store the state dict under the 'weights' key.
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True)['weights'])
    log.info(f'Loaded weights from {model.model_path}')

    # Euler flow-matching sampler and the feature extractors used for conditioning.
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=args.num_steps)
    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=True)
    feature_utils = feature_utils.to(device, dtype).eval()

    if args.video_dir:
        video_dir: Path = args.video_dir.expanduser()
        # Accept either a single video file or a directory of .mp4 files.
        if video_dir.is_file():
            video_files = [video_dir]
        else:
            video_files = sorted(video_dir.glob('*.mp4'))
        if not video_files:
            log.warning(f'No video files found in {video_dir}')
        else:
            # Pair each video with a reference clip: broadcast a single clip
            # across all videos, or use None when no reference was given.
            if audio_list is None:
                audio_list = [None] * len(video_files)
            if len(audio_list) == 1:
                audio_list = audio_list * len(video_files)
            for i in range(1):  # increase the range to generate multiple variants per video
                for video_path, audio in tqdm.tqdm(zip(video_files, audio_list),
                                                   total=len(video_files)):
                    # Draw a fresh random seed per video; note this overrides --seed.
                    args.seed = torch.seed()
                    process_video(video_path, args, model, net, fm, feature_utils, device, dtype,
                                  audio, i)

if __name__ == '__main__':
    main()