import os
import torch
import librosa
import pandas as pd
import soundfile as sf
from tqdm import tqdm
from torchdiffeq import odeint
from einops import rearrange
from capspeech.nar.utils import make_pad_mask
def sample(model, vocoder,
           x, cond, text, prompt, clap, prompt_mask,
           steps=25, cfg=2.0,
           sway_sampling_coef=-1.0, device='cuda'):
    model.eval()
    vocoder.eval()

    # Start from Gaussian noise with the same shape as the target mel latents.
    y0 = torch.randn_like(x)

    # Null inputs for classifier-free guidance: text filled with -1 and zeroed
    # prompt/CLAP embeddings stand in for "no conditioning".
    neg_text = torch.ones_like(text) * -1
    neg_clap = torch.zeros_like(clap)
    neg_prompt = torch.zeros_like(prompt)
    # Null prompt mask: all zeros except the first position.
    neg_prompt_mask = torch.zeros_like(prompt_mask)
    neg_prompt_mask[:, 0] = 1

    def fn(t, x):
        # Conditional and unconditional predictions at ODE time t.
        pred = model(x=x, cond=cond, text=text, time=t,
                     prompt=prompt, clap=clap,
                     mask=None,
                     prompt_mask=prompt_mask)
        null_pred = model(x=x, cond=cond, text=neg_text, time=t,
                          prompt=neg_prompt, clap=neg_clap,
                          mask=None,
                          prompt_mask=neg_prompt_mask)
        # Classifier-free guidance: push the prediction away from the unconditional one.
        return pred + (pred - null_pred) * cfg

    # Time grid from 0 to 1, optionally warped by sway sampling so that more
    # solver steps are spent early in the flow.
    t_start = 0
    t = torch.linspace(t_start, 1, steps, device=device)
    if sway_sampling_coef is not None:
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

    # Integrate the flow ODE with a fixed-step Euler solver and keep the final state.
    trajectory = odeint(fn, y0, t, method="euler")
    out = trajectory[-1]

    # The vocoder expects (batch, mel_dim, frames).
    out = rearrange(out, 'b n d -> b d n')
    with torch.inference_mode():
        wav_gen = vocoder(out)  # wav_gen is a FloatTensor with shape [1, T_time]
    wav_gen_float = wav_gen.squeeze().cpu().numpy()
    return wav_gen_float
def prepare_batch(batch, mel, latent_sr):
    x, x_lens, y, y_lens, c, c_lens, tag = (
        batch["x"], batch["x_lens"], batch["y"], batch["y_lens"],
        batch["c"], batch["c_lens"], batch["tag"],
    )
    # Add one position to the text lengths to account for the prepended CLAP embedding.
    x_lens = x_lens + 1
    with torch.no_grad():
        # Mel spectrogram of the reference audio, reshaped to (batch, frames, mel_dim).
        audio_clip = mel(y)
        audio_clip = rearrange(audio_clip, 'b d n -> b n d')
    # Scale the audio lengths to mel-frame lengths (latent_sr is frames per second).
    y_lens = (y_lens * latent_sr).long()
    return x, x_lens, audio_clip, y_lens, c, c_lens, tag
# Use ground-truth durations for simple inference.
def eval_model(model, vocos, mel, val_loader, params,
               steps=25, cfg=2.0,
               sway_sampling_coef=-1.0, device='cuda',
               epoch=0, save_path='logs/eval/', val_num=5):
    save_path = save_path + '/' + str(epoch) + '/'
    os.makedirs(save_path, exist_ok=True)

    # Frames per second of the mel/latent sequence.
    latent_sr = params['mel']['target_sample_rate'] / params['mel']['hop_length']

    for step, batch in enumerate(tqdm(val_loader)):
        (text, text_lens, audio_clips, audio_lens,
         prompt, prompt_lens, clap) = prepare_batch(batch, mel, latent_sr)
        cond = None

        # Padding mask for the prompt sequence.
        seq_len_prompt = prompt.shape[1]
        prompt_mask = make_pad_mask(prompt_lens, seq_len_prompt).to(prompt.device)

        gen = sample(model, vocos,
                     audio_clips, cond, text, prompt, clap, prompt_mask,
                     steps=steps, cfg=cfg,
                     sway_sampling_coef=sway_sampling_coef, device=device)
        sf.write(save_path + f'{step}.wav', gen,
                 samplerate=params['mel']['target_sample_rate'])

        if step + 1 >= val_num:
            break