File size: 4,739 Bytes
30f8a30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gc
import logging

import torch

from .eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
                                load_video, make_video, setup_eval_logging)
from .model.flow_matching import FlowMatching
from .model.networks import MMAudio, get_my_mmaudio
from .model.sequence_config import SequenceConfig
from .model.utils.features_utils import FeaturesUtils

# Module-level cache used by get_model() when persistent_models=True.
# All slots are initialised together so that every name get_model() declares
# as a cache global exists from import time (previously only the offload
# object did, and the others only came into existence on the first call).
persistent_offloadobj = None
persistent_net = None
persistent_features_utils = None
persistent_seq_cfg = None

def get_model(persistent_models = False, verboseLevel = 1) -> tuple[MMAudio, FeaturesUtils, SequenceConfig, object]:
    """Build (or fetch from the module cache) the MMAudio network and its helpers.

    Args:
        persistent_models: When True, the loaded objects are stored in the
            module-level ``persistent_*`` globals and reused on later calls.
            When False, those globals are reset to None before returning
            (objects that were already cached are still returned for this call).
        verboseLevel: Verbosity forwarded to ``mmgp.offload.profile``.

    Returns:
        A 4-tuple ``(net, feature_utils, seq_cfg, offloadobj)`` where
        ``offloadobj`` is whatever ``mmgp.offload.profile`` returned.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    global device, persistent_offloadobj, persistent_net, persistent_features_utils, persistent_seq_cfg

    log = logging.getLogger()

    # Weights are materialised on the CPU; the mmgp offload profile set up
    # below is handed the sub-models afterwards.
    device = 'cpu'
    dtype = torch.bfloat16

    # model.download_if_needed() is not called here, so the checkpoint files
    # must already exist at the paths stored in the model config.
    model: ModelConfig = all_model_cfg['large_44k_v2']

    setup_eval_logging()

    seq_cfg = model.seq_cfg
    if persistent_offloadobj is None:
        net: MMAudio = get_my_mmaudio(model.model_name)
        net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
        net.to(device, dtype).eval()
        log.info(f'Loaded weights from {model.model_path}')
        feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                    synchformer_ckpt=model.synchformer_ckpt,
                                    enable_conditions=True,
                                    mode=model.mode,
                                    bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                    need_vae_encoder=False)
        feature_utils = feature_utils.to(device, dtype).eval()
        # NOTE(review): the modules were moved to the CPU above, yet the
        # device attribute is set to "cuda" -- presumably so that code reading
        # feature_utils.device targets the GPU once mmgp has moved the
        # sub-models there; confirm against FeaturesUtils.
        feature_utils.device = "cuda"

        pipe = {
            "net": net,
            "clip": feature_utils.clip_model,
            "syncformer": feature_utils.synchformer,
            "vocode": feature_utils.tod.vocoder,
            "vae": feature_utils.tod.vae,
        }
        from mmgp import offload
        # Honour the caller's verbosity instead of a hard-coded level.
        offloadobj = offload.profile(pipe, profile_no=4, verboseLevel=verboseLevel)
        if persistent_models:
            persistent_offloadobj = offloadobj
            persistent_net = net
            persistent_features_utils = feature_utils
            persistent_seq_cfg = seq_cfg
    else:
        # A previous call cached everything; reuse it instead of reloading.
        offloadobj = persistent_offloadobj
        net = persistent_net
        feature_utils = persistent_features_utils
        seq_cfg = persistent_seq_cfg

    if not persistent_models:
        # Drop the module-level references so the objects can be freed once
        # the caller is done with the values returned below.
        persistent_offloadobj = None
        persistent_net = None
        persistent_features_utils = None
        persistent_seq_cfg = None

    return net, feature_utils, seq_cfg, offloadobj

@torch.inference_mode()
def video_to_audio(video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float, save_path , persistent_models = False, audio_file_only = False, verboseLevel = 1):
    """Generate audio for a video and write the result to ``save_path``.

    Args:
        video: Video input handed to ``load_video`` (and to ``make_video``
            when a muxed video is written).
        prompt: Text prompt passed to ``generate``.
        negative_prompt: Negative text prompt passed to ``generate``.
        seed: RNG seed; a negative value seeds the generator non-deterministically.
        num_steps: Number of Euler steps for the flow-matching sampler.
        cfg_strength: Guidance strength passed to ``generate``.
        duration: Requested duration passed to ``load_video``; the duration
            reported back in the loaded video info is what is actually used.
        save_path: Output path for the audio file or the video.
        persistent_models: Forwarded to ``get_model``; when False the offload
            object is also released after the run.
        audio_file_only: When True, save only the audio via ``torchaudio``;
            otherwise call ``make_video`` with the generated audio.
        verboseLevel: Forwarded to ``get_model``.

    Returns:
        ``save_path``, unchanged.
    """
    net, feature_utils, seq_cfg, offloadobj = get_model(persistent_models, verboseLevel)

    # Everything after model acquisition runs under try/finally so that the
    # models are unloaded (and released when not persistent) even if loading
    # the video, generation, or saving raises.
    try:
        rng = torch.Generator(device="cuda")
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()
        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        video_info = load_video(video, duration)
        # Use the duration reported by the loader rather than the requested one.
        duration = video_info.duration_sec
        # unsqueeze(0) adds a leading dimension (presumably batch -- confirm
        # against generate()).
        clip_frames = video_info.clip_frames.unsqueeze(0)
        sync_frames = video_info.sync_frames.unsqueeze(0)
        seq_cfg.duration = duration
        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        audios = generate(clip_frames,
                          sync_frames, [prompt],
                          negative_text=[negative_prompt],
                          feature_utils=feature_utils,
                          net=net,
                          fm=fm,
                          rng=rng,
                          cfg_strength=cfg_strength,
                          offloadobj = offloadobj
                          )
        audio = audios.float().cpu()[0]

        if audio_file_only:
            import torchaudio
            # A 1-D tensor gets a leading dimension added before saving.
            waveform = audio.unsqueeze(0) if audio.dim() == 1 else audio
            torchaudio.save(save_path, waveform, seq_cfg.sampling_rate)
        else:
            make_video(video, video_info, save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    finally:
        offloadobj.unload_all()
        if not persistent_models:
            offloadobj.release()

        torch.cuda.empty_cache()
        gc.collect()

    return save_path