|
import dataclasses |
|
import logging |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
from colorlog import ColoredFormatter |
|
from PIL import Image |
|
from torchvision.transforms import v2 |
|
|
|
from meanaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio |
|
from meanaudio.model.flow_matching import FlowMatching |
|
from meanaudio.model.mean_flow import MeanFlow |
|
from meanaudio.model.networks import MeanAudio, FluxAudio |
|
from meanaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig |
|
from meanaudio.model.utils.features_utils import FeaturesUtils |
|
from meanaudio.utils.download_utils import download_model_if_needed |
|
|
|
# Module-level handle on the root logger; handlers are attached by
# setup_eval_logging() further down in this file.
log = logging.getLogger()
|
|
|
|
|
@dataclasses.dataclass
class ModelConfig:
    """Checkpoint paths and sampling-rate mode for one released model.

    Instances of this class are the values of ``all_model_cfg`` and are
    looked up by ``model_name``.
    """

    # Identifier for this model; also used as the key in ``all_model_cfg``.
    model_name: str
    # Path to the main network checkpoint (.pth).
    model_path: Path
    # Path to the VAE checkpoint used for latent encoding/decoding.
    vae_path: Path
    # BigVGAN vocoder weights; presumably only required in '16k' mode — the
    # released configs below always set it. TODO(review): confirm for '44k'.
    bigvgan_16k_path: Optional[Path]
    # Either '16k' or '44k'; selects the sequence configuration.
    mode: str

    @property
    def seq_cfg(self) -> 'SequenceConfig':
        """Return the sequence configuration matching ``mode``.

        Raises:
            ValueError: if ``mode`` is neither '16k' nor '44k'.  (The
                previous implementation fell off the end and silently
                returned ``None`` for unknown modes.)
        """
        if self.mode == '16k':
            return CONFIG_16K
        elif self.mode == '44k':
            return CONFIG_44K
        raise ValueError(f"unknown mode {self.mode!r}; expected '16k' or '44k'")

    def download_if_needed(self):
        """Download checkpoints — intentionally disabled in this build.

        Raises:
            NotImplementedError: always.  The download calls that used to
                follow the raise were unreachable dead code and have been
                removed; re-enable by restoring calls to
                ``download_model_if_needed`` for each checkpoint path.
        """
        raise NotImplementedError("Downloading models is not supported")
|
|
|
|
|
# All released checkpoints run in '16k' mode and share the same VAE and
# BigVGAN vocoder weights; only the main network checkpoint differs, and its
# filename matches the model name.
_WEIGHTS_DIR = Path('./weights')


def _make_16k_config(name: str) -> ModelConfig:
    """Build the ModelConfig for a 16 kHz model whose checkpoint is ``<name>.pth``."""
    return ModelConfig(model_name=name,
                       model_path=_WEIGHTS_DIR / f'{name}.pth',
                       vae_path=_WEIGHTS_DIR / 'v1-16.pth',
                       bigvgan_16k_path=_WEIGHTS_DIR / 'best_netG.pt',
                       mode='16k')


fluxaudio_s_full = _make_16k_config('fluxaudio_s_full')
meanaudio_s_full = _make_16k_config('meanaudio_s_full')
meanaudio_s_ac = _make_16k_config('meanaudio_s_ac')
meanaudio_l_full = _make_16k_config('meanaudio_l_full')

# Registry of every released model, keyed by its model_name.
all_model_cfg: dict[str, ModelConfig] = {
    cfg.model_name: cfg
    for cfg in (meanaudio_l_full, meanaudio_s_full, meanaudio_s_ac, fluxaudio_s_full)
}
|
|
|
|
|
def generate_fm(
    text: Optional[list[str]],
    *,
    negative_text: Optional[list[str]] = None,
    feature_utils: FeaturesUtils,
    net: FluxAudio,
    fm: FlowMatching,
    rng: torch.Generator,
    cfg_strength: float,
) -> torch.Tensor:
    """Sample audio from a flow-matching model, optionally conditioned on text.

    Args:
        text: batch of text prompts, or ``None`` to use the network's
            empty-string conditioning (batch size defaults to 1).
        negative_text: optional negative prompts for classifier-free
            guidance; must have the same batch size as ``text``.
        feature_utils: provides text encoding, latent decoding and vocoding.
        net: the flow-matching audio network.
        fm: solver that integrates the ODE from noise to data.
        rng: generator for the initial Gaussian latent (reproducibility).
        cfg_strength: classifier-free guidance strength.

    Returns:
        The vocoded audio waveform tensor.

    Raises:
        ValueError: if ``negative_text`` has a different batch size than ``text``.
    """
    device = feature_utils.device
    dtype = feature_utils.dtype

    if text is not None:
        # BUG FIX: the original computed ``bs = len(text)`` before this None
        # check, so passing text=None crashed on len(None) and the
        # unconditional branch below was unreachable.
        bs = len(text)
        text_features, text_features_c = feature_utils.encode_text(text)
    else:
        # NOTE(review): batch size is otherwise undetermined without prompts;
        # default to a single unconditional sample.
        bs = 1
        text_features, text_features_c = net.get_empty_string_sequence(bs)

    if negative_text is not None:
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if len(negative_text) != bs:
            raise ValueError('negative_text must match the batch size of text')
        negative_text_features = feature_utils.encode_text(negative_text)
    else:
        negative_text_features = net.get_empty_string_sequence(bs)

    # Initial latent noise drawn from the supplied generator.
    x0 = torch.randn(bs,
                     net.latent_seq_len,
                     net.latent_dim,
                     device=device,
                     dtype=dtype,
                     generator=rng)
    preprocessed_conditions = net.preprocess_conditions(text_features, text_features_c)
    empty_conditions = net.get_empty_conditions(
        bs, negative_text_features=negative_text_features if negative_text is not None else None)

    def cfg_ode_wrapper(t, x):
        # Classifier-free-guided velocity field fed to the ODE solver.
        return net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions, cfg_strength)

    x1 = fm.to_data(cfg_ode_wrapper, x0)
    x1 = net.unnormalize(x1)
    spec = feature_utils.decode(x1)
    audio = feature_utils.vocode(spec)
    return audio
|
|
|
|
|
def generate_mf(
    text: Optional[list[str]],
    *,
    negative_text: Optional[list[str]] = None,
    feature_utils: FeaturesUtils,
    net: MeanAudio,
    mf: MeanFlow,
    rng: torch.Generator,
    cfg_strength: float,
) -> torch.Tensor:
    """Sample audio from a MeanFlow model, optionally conditioned on text.

    Mirrors ``generate_fm`` but the MeanFlow ODE wrapper takes two time
    arguments ``(t, r)`` — presumably the two time variables of the
    mean-flow formulation (TODO confirm against MeanFlow.to_data).

    Args:
        text: batch of text prompts, or ``None`` to use the network's
            empty-string conditioning (batch size defaults to 1).
        negative_text: optional negative prompts for classifier-free
            guidance; must have the same batch size as ``text``.
        feature_utils: provides text encoding, latent decoding and vocoding.
        net: the MeanAudio network.
        mf: MeanFlow solver that maps noise to data.
        rng: generator for the initial Gaussian latent (reproducibility).
        cfg_strength: classifier-free guidance strength.

    Returns:
        The vocoded audio waveform tensor.

    Raises:
        ValueError: if ``negative_text`` has a different batch size than ``text``.
    """
    device = feature_utils.device
    dtype = feature_utils.dtype

    if text is not None:
        # BUG FIX: the original computed ``bs = len(text)`` before this None
        # check, so passing text=None crashed on len(None) and the
        # unconditional branch below was unreachable.
        bs = len(text)
        text_features, text_features_c = feature_utils.encode_text(text)
    else:
        # NOTE(review): batch size is otherwise undetermined without prompts;
        # default to a single unconditional sample.
        bs = 1
        text_features, text_features_c = net.get_empty_string_sequence(bs)

    if negative_text is not None:
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if len(negative_text) != bs:
            raise ValueError('negative_text must match the batch size of text')
        negative_text_features = feature_utils.encode_text(negative_text)
    else:
        negative_text_features = net.get_empty_string_sequence(bs)

    # Initial latent noise drawn from the supplied generator.
    x0 = torch.randn(bs,
                     net.latent_seq_len,
                     net.latent_dim,
                     device=device,
                     dtype=dtype,
                     generator=rng)
    preprocessed_conditions = net.preprocess_conditions(text_features, text_features_c)
    empty_conditions = net.get_empty_conditions(
        bs, negative_text_features=negative_text_features if negative_text is not None else None)

    def cfg_ode_wrapper(t, r, x):
        # Classifier-free-guided field fed to the MeanFlow solver.
        return net.ode_wrapper(t, r, x, preprocessed_conditions, empty_conditions, cfg_strength)

    x1 = mf.to_data(cfg_ode_wrapper, x0)
    x1 = net.unnormalize(x1)
    spec = feature_utils.decode(x1)
    audio = feature_utils.vocode(spec)
    return audio
|
|
|
|
|
# colorlog format string: level name (padded to 8 chars) and message, both
# tinted with the per-level color escapes that ColoredFormatter injects.
LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
|
|
|
|
|
def setup_eval_logging(log_level: int = logging.INFO):
    """Attach a colorized stream handler to the root logger.

    Args:
        log_level: minimum level applied to both the root logger and the
            new handler (defaults to ``logging.INFO``).

    NOTE(review): each call adds another handler, so calling this more than
    once duplicates log output — call it exactly once per process.
    """
    formatter = ColoredFormatter(LOGFORMAT)
    stream = logging.StreamHandler()
    stream.setLevel(log_level)
    stream.setFormatter(formatter)
    # Renamed from ``log`` to avoid shadowing the module-level logger; the
    # original also set the root level twice (``logging.root.setLevel`` and
    # ``log.setLevel`` target the same root logger) — set it once here.
    root = logging.getLogger()
    root.setLevel(log_level)
    root.addHandler(stream)