import os
import torch
import soundfile as sf
from tqdm import tqdm
from torchdiffeq import odeint
from einops import rearrange
from capspeech.nar.utils import make_pad_mask

@torch.no_grad()
def sample(model, vocoder,
           x, cond, text, prompt, clap, prompt_mask,
           steps=25, cfg=2.0,
           sway_sampling_coef=-1.0, device='cuda'):
    """Integrate the learned ODE from noise (t=0) to t=1 with Euler steps and
    classifier-free guidance, then decode the result to a waveform."""

    model.eval()
    vocoder.eval()

    # Start from Gaussian noise with the same shape as the target features.
    y0 = torch.randn_like(x)

    # Null inputs for classifier-free guidance: -1 text tokens, zeroed CLAP
    # and prompt embeddings, and a prompt mask with only its first position set.
    neg_text = torch.ones_like(text) * -1
    neg_clap = torch.zeros_like(clap)
    neg_prompt = torch.zeros_like(prompt)
    neg_prompt_mask = torch.zeros_like(prompt_mask)
    neg_prompt_mask[:, 0] = 1

    def fn(t, x):
        # Conditional and unconditional (null-input) predictions at time t.
        pred = model(x=x, cond=cond, text=text, time=t,
                     prompt=prompt, clap=clap,
                     mask=None,
                     prompt_mask=prompt_mask)

        null_pred = model(x=x, cond=cond, text=neg_text, time=t,
                          prompt=neg_prompt, clap=neg_clap,
                          mask=None,
                          prompt_mask=neg_prompt_mask)
        # Classifier-free guidance: extrapolate away from the null prediction.
        return pred + (pred - null_pred) * cfg

    # Timesteps from t=0 (noise) to t=1. Sway sampling warps the uniform grid;
    # with a negative coefficient, steps are concentrated near t=0.
    t_start = 0
    t = torch.linspace(t_start, 1, steps, device=device)
    if sway_sampling_coef is not None:
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

    # Fixed-step Euler integration; keep only the final state of the trajectory.
    trajectory = odeint(fn, y0, t, method="euler")
    out = trajectory[-1]
    out = rearrange(out, 'b n d -> b d n')  # vocoder expects (batch, dim, frames)

    with torch.inference_mode():
        wav_gen = vocoder(out)
    wav_gen_float = wav_gen.squeeze().cpu().numpy()  # wav_gen is a FloatTensor of shape [1, T_time]
    return wav_gen_float


def prepare_batch(batch, mel, latent_sr):
    """Unpack a batch and convert the target audio to mel features.

    Returns (text, text_lens, mel_features, mel_lens, prompt, prompt_lens, clap)
    in the order expected by eval_model().
    """
    x, x_lens, y, y_lens, c, c_lens, tag = (
        batch["x"], batch["x_lens"],
        batch["y"], batch["y_lens"],
        batch["c"], batch["c_lens"],
        batch["tag"],
    )

    # Account for the extra position taken by the CLAP embedding.
    x_lens = x_lens + 1

    with torch.no_grad():
        audio_clip = mel(y)
        audio_clip = rearrange(audio_clip, 'b d n -> b n d')
        # Convert audio durations to mel-frame lengths (latent_sr = frames per second).
        y_lens = (y_lens * latent_sr).long()

    return x, x_lens, audio_clip, y_lens, c, c_lens, tag

# Use ground-truth durations for simple inference.
@torch.no_grad()
def eval_model(model, vocos, mel, val_loader, params,
               steps=25, cfg=2.0,
               sway_sampling_coef=-1.0, device='cuda',
               epoch=0, save_path='logs/eval/', val_num=5):
    """Generate up to val_num validation samples and write them as wav files
    under save_path/<epoch>/."""

    save_path = os.path.join(save_path, str(epoch))
    os.makedirs(save_path, exist_ok=True)

    # Mel frames per second of audio.
    latent_sr = params['mel']['target_sample_rate'] / params['mel']['hop_length']

    for step, batch in enumerate(tqdm(val_loader)):
        (text, text_lens, audio_clips, audio_lens,
         prompt, prompt_lens, clap) = prepare_batch(batch, mel, latent_sr)
        cond = None

        # Padding mask over the prompt sequence.
        seq_len_prompt = prompt.shape[1]
        prompt_mask = make_pad_mask(prompt_lens, seq_len_prompt).to(prompt.device)

        gen = sample(model, vocos,
                     audio_clips, cond, text, prompt, clap, prompt_mask,
                     steps=steps, cfg=cfg,
                     sway_sampling_coef=sway_sampling_coef, device=device)

        sf.write(os.path.join(save_path, f'{step}.wav'), gen, samplerate=params['mel']['target_sample_rate'])

        if step + 1 >= val_num:
            break
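

if __name__ == "__main__":
    # Minimal standalone sketch (added for illustration; not part of the
    # original CapSpeech code). It only prints the sway-sampled timestep
    # schedule that sample() builds, assuming the default settings of
    # 25 Euler steps and a sway coefficient of -1.0.
    steps = 25
    sway_sampling_coef = -1.0
    t_uniform = torch.linspace(0, 1, steps)
    t_sway = t_uniform + sway_sampling_coef * (
        torch.cos(torch.pi / 2 * t_uniform) - 1 + t_uniform
    )
    # With a negative coefficient the warped grid is denser near t=0, so the
    # Euler solver spends more of its steps early in the trajectory.
    print("uniform grid :", [round(v, 3) for v in t_uniform[:5].tolist()])
    print("sway schedule:", [round(v, 3) for v in t_sway[:5].tolist()])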