# DMOSpeech2 / guidance_model.py
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ctcmodel import ConformerCTC
from discriminator_conformer import ConformerDiscirminator
from ecapa_tdnn import ECAPA_TDNN
from f5_tts.model import DiT
from f5_tts.model.utils import (exists, lens_to_mask, list_str_to_idx,
                                list_str_to_tensor, mask_from_frac_lengths)
class NoOpContext:
def __enter__(self):
pass
def __exit__(self, *args):
pass
def predict_flow(
transformer, # flow model
x, # noisy input
    cond,  # conditioning input: prompt mel kept, target span zeroed
text, # text input
time, # time step
second_time=None,
cfg_strength=1.0,
):
pred = transformer(
x=x,
cond=cond,
text=text,
time=time,
second_time=second_time,
drop_audio_cond=False,
drop_text=False,
)
if cfg_strength < 1e-5:
return pred
null_pred = transformer(
x=x,
cond=cond,
text=text,
time=time,
second_time=second_time,
drop_audio_cond=True,
drop_text=True,
)
return pred + (pred - null_pred) * cfg_strength
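# Illustrative sketch (an addition, not used by training): with toy tensors,
# classifier-free guidance extrapolates the conditional prediction away from
# the unconditional one, guided = pred + (pred - null_pred) * cfg_strength.
def _cfg_sketch():
    pred = torch.tensor([1.0, 2.0])       # conditional flow prediction
    null_pred = torch.tensor([0.5, 1.5])  # prediction with cond/text dropped
    cfg_strength = 2.0
    return pred + (pred - null_pred) * cfg_strength  # tensor([2., 3.])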
def _kl_dist_func(x, y):
    log_probs = F.log_softmax(x, dim=2)
    target_probs = F.log_softmax(y, dim=2)
    return F.kl_div(log_probs, target_probs, reduction="batchmean", log_target=True)
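# Usage sketch for _kl_dist_func (shapes are illustrative): x and y are raw
# logits of shape (batch, time, vocab); log-softmax is taken over the vocab
# axis and the result is KL(softmax(y) || softmax(x)) with batchmean reduction:
#   _kl_dist_func(torch.randn(2, 50, 100), torch.randn(2, 50, 100))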
class Guidance(nn.Module):
def __init__(
self,
real_unet: DiT, # teacher flow model
fake_unet: DiT, # student flow model
use_fp16: bool = True,
real_guidance_scale: float = 0.0,
fake_guidance_scale: float = 0.0,
gen_cls_loss: bool = False,
sv_path_en: str = "",
sv_path_zh: str = "",
ctc_path: str = "",
sway_coeff: float = 0.0,
scale: float = 1.0,
):
super().__init__()
self.vocab_size = real_unet.vocab_size
if ctc_path != "":
model = ConformerCTC(
vocab_size=real_unet.vocab_size,
mel_dim=real_unet.mel_dim,
num_heads=8,
d_hid=512,
nlayers=6,
)
self.ctc_model = model.eval()
self.ctc_model.requires_grad_(False)
self.ctc_model.load_state_dict(
torch.load(ctc_path, weights_only=True, map_location="cpu")[
"model_state_dict"
]
)
if sv_path_en != "":
model = ECAPA_TDNN()
self.sv_model_en = model.eval()
self.sv_model_en.requires_grad_(False)
self.sv_model_en.load_state_dict(
                torch.load(sv_path_en, weights_only=True, map_location="cpu")[
"model_state_dict"
]
)
if sv_path_zh != "":
model = ECAPA_TDNN()
self.sv_model_zh = model.eval()
self.sv_model_zh.requires_grad_(False)
self.sv_model_zh.load_state_dict(
torch.load(sv_path_zh, weights_only=True, map_location="cpu")[
"model_state_dict"
]
)
self.scale = scale
self.real_unet = real_unet
self.real_unet.requires_grad_(False) # no update on the teacher model
self.fake_unet = fake_unet
self.fake_unet.requires_grad_(True) # update the student model
self.real_guidance_scale = real_guidance_scale
self.fake_guidance_scale = fake_guidance_scale
assert self.fake_guidance_scale == 0, "no guidance for fake"
self.use_fp16 = use_fp16
self.gen_cls_loss = gen_cls_loss
self.sway_coeff = sway_coeff
if self.gen_cls_loss:
self.cls_pred_branch = ConformerDiscirminator(
input_dim=(self.fake_unet.depth + 1) * self.fake_unet.dim
+ 3 * 512, # 3 is the number of layers from the CTC model
num_layers=3,
channels=self.fake_unet.dim // 2,
)
self.cls_pred_branch.requires_grad_(True)
self.network_context_manager = (
torch.autocast(device_type="cuda", dtype=torch.float16)
if self.use_fp16
else NoOpContext()
)
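        # Mixed precision (fp16 autocast) is applied only around the transformer
        # forward passes; the losses themselves are computed in fp32.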
        # Build the tokenizer vocab map used to convert raw text into indices.
        from f5_tts.model.utils import get_tokenizer

        tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
        dataset_name = "Emilia_ZH_EN"
        # For a 'custom' tokenizer, set tokenizer_path to your vocab.txt instead.
        tokenizer_path = dataset_name
        vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
        self.vocab_char_map = vocab_char_map
def compute_distribution_matching_loss(
self,
inp: float["b n d"] | float["b nw"], # mel or raw wave, ground truth latent
text: int["b nt"] | list[str], # text input
*,
second_time: torch.Tensor | None = None, # second time step for flow prediction
rand_span_mask: (
bool["b n d"] | bool["b nw"] | None
) = None, # combined mask (prompt mask + padding mask)
):
"""
Compute DMD loss (L_DMD) between the student distribution and teacher distribution.
Following the DMDSpeech logic:
- Sample time t
- Construct noisy input phi = (1 - t)*x0 + t*x1, where x0 is noise and x1 is inp
- Predict flows with teacher (f_phi) and student (G_theta)
- Compute gradient that aligns student distribution with teacher distribution
The code is adapted from F5-TTS but conceptualized per DMD:
L_DMD encourages p_theta to match p_data via the difference between teacher and student predictions.
"""
original_inp = inp
with torch.no_grad():
batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
# mel is x1
x1 = inp
# x0 is gaussian noise
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype=dtype, device=device)
# get flow
t = time.unsqueeze(-1).unsqueeze(-1)
# t = t + self.sway_coeff * (torch.cos(torch.pi / 2 * t) - 1 + t)
sigma_t, alpha_t = (1 - t), t
phi = (1 - t) * x0 + t * x1 # noisy x
flow = x1 - x0 # flow target
# only predict what is within the random mask span for infilling
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
            # predictions run under the optional fp16 autocast; no gradients are needed here
with self.network_context_manager:
pred_fake = predict_flow(
self.fake_unet,
phi,
                cond,  # conditioning input: prompt mel kept, target span zeroed
text, # text input
time, # time step
second_time=second_time,
cfg_strength=self.fake_guidance_scale,
)
# pred = (x1 - x0), thus phi + (1-t) * pred = (1 - t) * x0 + t * x1 + (1 - t) * (x1 - x0) = (1 - t) * x1 + t * x1 = x1
pred_fake_image = phi + (1 - t) * pred_fake
pred_fake_image[~rand_span_mask] = inp[~rand_span_mask]
with self.network_context_manager:
pred_real = predict_flow(
self.real_unet,
phi,
cond,
text,
time,
cfg_strength=self.real_guidance_scale,
)
pred_real_image = phi + (1 - t) * pred_real
pred_real_image[~rand_span_mask] = inp[~rand_span_mask]
p_real = inp - pred_real_image
p_fake = inp - pred_fake_image
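            # grad = p_real - p_fake = pred_fake_image - pred_real_image: the gap
            # between the student's and the teacher's denoised estimates,
            # normalized below by the teacher's mean absolute residual.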
grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2], keepdim=True)
grad = torch.nan_to_num(grad)
# grad = grad / sigma_t # pred_fake - pred_real
# grad = grad * (1 + sigma_t / alpha_t)
# grad = grad / (1 + sigma_t / alpha_t) # noise
# grad = grad / sigma_t # score difference
# grad = grad * alpha_t
# grad = grad * (sigma_t ** 2 / alpha_t)
# grad = grad * (alpha_t + sigma_t ** 2 / alpha_t)
# The DMD loss: MSE to move student distribution closer to teacher distribution
# Only optimize over the masked region
loss = (
0.5
* F.mse_loss(
original_inp.float(),
(original_inp - grad).detach().float(),
reduction="none",
)
* rand_span_mask.unsqueeze(-1)
)
loss = loss.sum() / (rand_span_mask.sum() * grad.size(-1))
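        # Why 0.5 * MSE(x, (x - grad).detach()) above: with the second argument
        # detached, d(loss)/dx evaluates exactly to `grad`, so backprop pushes
        # the generator output along the precomputed DMD direction, e.g.
        #   x = torch.randn(3, requires_grad=True); g = torch.randn(3)
        #   (0.5 * F.mse_loss(x, (x - g).detach(), reduction="sum")).backward()
        #   assert torch.allclose(x.grad, g)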
loss_dict = {"loss_dm": loss}
dm_log_dict = {
"dmtrain_time": time.detach().float(),
"dmtrain_noisy_inp": phi.detach().float(),
"dmtrain_pred_real_image": pred_real_image.detach().float(),
"dmtrain_pred_fake_image": pred_fake_image.detach().float(),
"dmtrain_grad": grad.detach().float(),
"dmtrain_gradient_norm": torch.norm(grad).item(),
}
return loss_dict, dm_log_dict
def compute_ctc_sv_loss(
self,
real_inp: torch.Tensor, # real data latent
fake_inp: torch.Tensor, # student-generated data latent
text: torch.Tensor,
text_lens: torch.Tensor,
rand_span_mask: torch.Tensor,
second_time: torch.Tensor | None = None,
):
"""
Compute CTC + SV loss for direct metric optimization, as described in DMDSpeech.
- CTC loss reduces WER
- SV loss improves speaker similarity
Both CTC and SV models operate on latents.
"""
# compute CTC loss
out, layer, ctc_loss = self.ctc_model(
fake_inp * self.scale, text, text_lens
) # lengths from rand_span_mask or known
with torch.no_grad():
real_out, real_layers, ctc_loss_test = self.ctc_model(
real_inp * self.scale, text, text_lens
)
real_logits = real_out.log_softmax(dim=2)
# emb_real = self.sv_model(real_inp * self.scale) # snippet from prompt region
fake_logits = out.log_softmax(dim=2)
kl_loss = F.kl_div(fake_logits, real_logits, reduction="mean", log_target=True)
# For SV:
# Extract speaker embeddings from real (prompt) and fake:
# emb_fake = self.sv_model(fake_inp * self.scale)
# sv_loss = 1 - F.cosine_similarity(emb_real, emb_fake, dim=-1).mean()
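        # rand_span_mask marks the generated (infilled) region; its complement is
        # the prompt. Equal-length chunks are sampled from the fake region and
        # taken from the start of the real utterance for speaker comparison.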
        input_lengths = rand_span_mask.sum(dim=-1).cpu().numpy()
chunks_real = []
chunks_fake = []
mel_len = min([int(input_lengths.min().item() - 1), 300])
for bib in range(len(input_lengths)):
mel_length = int(input_lengths[bib].item())
mask = rand_span_mask[bib]
mask = torch.where(mask)[0]
prompt_start = mask[0].cpu().numpy()
prompt_end = mask[-1].cpu().numpy()
if prompt_end - mel_len <= prompt_start:
random_start = np.random.randint(0, mel_length - mel_len)
else:
random_start = np.random.randint(prompt_start, prompt_end - mel_len)
chunks_fake.append(fake_inp[bib, random_start : random_start + mel_len, :])
chunks_real.append(real_inp[bib, :mel_len, :])
chunks_real = torch.stack(chunks_real, dim=0)
chunks_fake = torch.stack(chunks_fake, dim=0)
with torch.no_grad():
emb_real_en = self.sv_model_en(chunks_real * self.scale)
emb_fake_en = self.sv_model_en(chunks_fake * self.scale)
sv_loss_en = 1 - F.cosine_similarity(emb_real_en, emb_fake_en, dim=-1).mean()
with torch.no_grad():
emb_real_zh = self.sv_model_zh(chunks_real * self.scale)
emb_fake_zh = self.sv_model_zh(chunks_fake * self.scale)
sv_loss_zh = 1 - F.cosine_similarity(emb_real_zh, emb_fake_zh, dim=-1).mean()
sv_loss = (sv_loss_en + sv_loss_zh) / 2
return (
{"loss_ctc": ctc_loss, "loss_kl": kl_loss, "loss_sim": sv_loss},
layer,
real_layers,
)
def compute_loss_fake(
self,
inp: torch.Tensor, # student generator output
text: torch.Tensor | list[str],
rand_span_mask: torch.Tensor,
second_time: torch.Tensor | None = None,
):
"""
Compute flow loss for the fake flow model, which is trained to estimate the flow (score) of the student distribution.
This is the same as L_diff in the paper.
"""
# Similar to distribution matching, but only train fake to predict flow directly
batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# Sample a time
time = torch.rand((batch,), dtype=dtype, device=device)
x1 = inp
x0 = torch.randn_like(x1)
t = time.unsqueeze(-1).unsqueeze(-1)
phi = (1 - t) * x0 + t * x1
flow = x1 - x0
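        # Rectified-flow parameterization: phi = (1 - t) * x0 + t * x1 lies on
        # the straight path from noise to data, and the target is the constant
        # velocity x1 - x0, so one Euler step of size (1 - t) from phi recovers x1.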
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
with self.network_context_manager:
pred = self.fake_unet(
x=phi,
cond=cond,
text=text,
time=time,
second_time=second_time,
drop_audio_cond=False,
drop_text=False, # make sure the cfg=1
)
# Compute MSE between predicted flow and actual flow, masked by rand_span_mask
loss = F.mse_loss(pred, flow, reduction="none")
loss = loss[rand_span_mask].mean()
loss_dict = {"loss_fake_mean": loss}
log_dict = {
"faketrain_noisy_inp": phi.detach().float(),
"faketrain_x1": x1.detach().float(),
"faketrain_pred_flow": pred.detach().float(),
}
return loss_dict, log_dict
def compute_cls_logits(
self,
inp: torch.Tensor, # student generator output
layer: torch.Tensor,
text: torch.Tensor,
rand_span_mask: torch.Tensor,
second_time: torch.Tensor | None = None,
guidance: bool = False,
):
"""
Compute adversarial loss logits for the generator.
This is used to compute L_adv in the paper.
"""
context_no_grad = torch.no_grad if guidance else NoOpContext
with context_no_grad():
            # If the generator classification loss is disabled, return dummy logits
            if not self.gen_cls_loss:
                return torch.zeros_like(inp[..., 0]), None  # (b, n) logits, no layers
# For classification, we need some representation:
# We'll mimic the logic from compute_loss_fake
batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# Sample a time
time = torch.rand((batch,), dtype=dtype, device=device)
x1 = inp
x0 = torch.randn_like(x1)
t = time.unsqueeze(-1).unsqueeze(-1)
phi = (1 - t) * x0 + t * x1
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
with self.network_context_manager:
layers = self.fake_unet(
x=phi,
cond=cond,
text=text,
time=time,
second_time=second_time,
drop_audio_cond=False,
drop_text=False, # make sure the cfg=1
classify_mode=True,
)
# layers = torch.stack(layers, dim=0)
if guidance:
layers = [layer.detach() for layer in layers]
layer = layer[-3:] # only use the last 3 layers
layer = [l.transpose(-1, -2) for l in layer]
# layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer]
if layer[0].size(1) < layers[0].size(1):
layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer]
layers = layer + layers
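        # `layers` now holds the last three CTC-model features followed by the
        # student DiT's hidden states; the discriminator consumes this list.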
# logits: (b, 1)
logits = self.cls_pred_branch(layers)
return logits, layers
def compute_generator_cls_loss(
self,
inp: torch.Tensor, # student generator output
layer: torch.Tensor,
real_layers: torch.Tensor,
text: torch.Tensor,
rand_span_mask: torch.Tensor,
second_time: torch.Tensor | None = None,
mse_loss: bool = False,
mse_inp: torch.Tensor | None = None,
):
"""
Compute the adversarial loss for the generator.
"""
# Compute classification loss for generator:
if not self.gen_cls_loss:
return {"gen_cls_loss": 0}
logits, fake_layers = self.compute_cls_logits(
inp, layer, text, rand_span_mask, second_time, guidance=False
)
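        # Least-squares GAN generator objective: push the discriminator's output
        # on generated samples toward the "real" label of 1.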
loss = ((1 - logits) ** 2).mean()
return {"gen_cls_loss": loss, "loss_mse": 0}
def compute_guidance_cls_loss(
self,
fake_inp: torch.Tensor,
text: torch.Tensor,
rand_span_mask: torch.Tensor,
real_data: dict,
second_time: torch.Tensor | None = None,
):
"""
        This function computes the adversarial loss for the discriminator.
The discriminator is trained to classify the generator output as real or fake.
"""
with torch.no_grad():
# get layers from CTC model
_, layer = self.ctc_model(fake_inp * self.scale)
logits_fake, _ = self.compute_cls_logits(
fake_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True
)
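        # Least-squares GAN discriminator objective: fake samples are pushed
        # toward label 0 here, real samples toward label 1 below.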
loss_fake = (logits_fake**2).mean()
real_inp = real_data["inp"]
with torch.no_grad():
# get layers from CTC model
_, layer = self.ctc_model(real_inp * self.scale)
logits_real, _ = self.compute_cls_logits(
real_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True
)
loss_real = ((1 - logits_real) ** 2).mean()
classification_loss = loss_real + loss_fake
loss_dict = {"guidance_cls_loss": classification_loss}
log_dict = {
"pred_realism_on_real": loss_real.detach().item(),
"pred_realism_on_fake": loss_fake.detach().item(),
}
return loss_dict, log_dict
def generator_forward(
self,
inp: torch.Tensor,
text: torch.Tensor,
text_lens: torch.Tensor,
text_normalized: torch.Tensor,
text_normalized_lens: torch.Tensor,
rand_span_mask: torch.Tensor,
real_data: (
dict | None
) = None, # ground truth data (primarily prompt) to compute SV loss
second_time: torch.Tensor | None = None,
mse_loss: bool = False,
):
"""
Forward pass for the generator.
This function computes the loss for the generator, which includes:
- Distribution matching loss (L_DMD)
- Adversarial generator loss (L_adv(G; D))
- CTC/SV loss (L_ctc + L_sv)
"""
# 1. Compute DM loss
dm_loss_dict, dm_log_dict = self.compute_distribution_matching_loss(
inp, text, rand_span_mask=rand_span_mask, second_time=second_time
)
ctc_sv_loss_dict = {}
cls_loss_dict = {}
# 2. Compute optional CTC/SV loss if real_data provided
if real_data is not None:
real_inp = real_data["inp"]
ctc_sv_loss_dict, layer, real_layers = self.compute_ctc_sv_loss(
real_inp,
inp,
text_normalized,
text_normalized_lens,
rand_span_mask,
second_time=second_time,
)
        # 3. Compute the optional adversarial loss; it consumes the CTC layers
        # computed above, so real_data must be provided when gen_cls_loss is on
        if self.gen_cls_loss:
            assert real_data is not None, "gen_cls_loss requires real_data"
cls_loss_dict = self.compute_generator_cls_loss(
inp,
layer,
real_layers,
text,
rand_span_mask=rand_span_mask,
second_time=second_time,
mse_inp=real_data["inp"] if real_data is not None else None,
mse_loss=mse_loss,
)
loss_dict = {**dm_loss_dict, **cls_loss_dict, **ctc_sv_loss_dict}
log_dict = {**dm_log_dict}
return loss_dict, log_dict
def guidance_forward(
self,
fake_inp: torch.Tensor,
text: torch.Tensor,
text_lens: torch.Tensor,
rand_span_mask: torch.Tensor,
real_data: dict | None = None,
second_time: torch.Tensor | None = None,
):
"""
Forward pass for the guidnce module (discriminator + fake flow function).
This function computes the loss for the guidance module, which includes:
- Flow matching loss (L_diff)
- Adversarial discrminator loss (L_adv(D; G))
"""
# Compute fake loss (like epsilon prediction loss in Guidance)
fake_loss_dict, fake_log_dict = self.compute_loss_fake(
fake_inp, text, rand_span_mask=rand_span_mask, second_time=second_time
)
# If gen_cls_loss, compute guidance cls loss
cls_loss_dict = {}
cls_log_dict = {}
if self.gen_cls_loss and real_data is not None:
cls_loss_dict, cls_log_dict = self.compute_guidance_cls_loss(
fake_inp, text, rand_span_mask, real_data, second_time=second_time
)
loss_dict = {**fake_loss_dict, **cls_loss_dict}
log_dict = {**fake_log_dict, **cls_log_dict}
return loss_dict, log_dict
def forward(
self,
generator_turn=False,
guidance_turn=False,
generator_data_dict=None,
guidance_data_dict=None,
):
if generator_turn:
loss_dict, log_dict = self.generator_forward(
inp=generator_data_dict["inp"],
text=generator_data_dict["text"],
text_lens=generator_data_dict["text_lens"],
text_normalized=generator_data_dict["text_normalized"],
text_normalized_lens=generator_data_dict["text_normalized_lens"],
rand_span_mask=generator_data_dict["rand_span_mask"],
real_data=generator_data_dict.get("real_data", None),
second_time=generator_data_dict.get("second_time", None),
mse_loss=generator_data_dict.get("mse_loss", False),
)
elif guidance_turn:
loss_dict, log_dict = self.guidance_forward(
fake_inp=guidance_data_dict["inp"],
text=guidance_data_dict["text"],
text_lens=guidance_data_dict["text_lens"],
rand_span_mask=guidance_data_dict["rand_span_mask"],
real_data=guidance_data_dict.get("real_data", None),
second_time=guidance_data_dict.get("second_time", None),
)
else:
raise NotImplementedError(
"Must specify either generator_turn or guidance_turn"
)
return loss_dict, log_dict
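# Training alternates between the two turns: a generator step updates the
# student through generator_forward, then a guidance step updates the fake flow
# model and the discriminator through guidance_forward (see the demo below).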
if __name__ == "__main__":
from f5_tts.model.utils import get_tokenizer
bsz = 16
    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    dataset_name = "Emilia_ZH_EN"
    # For a 'custom' tokenizer, set tokenizer_path to your vocab.txt instead.
    tokenizer_path = dataset_name
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
real_unet = DiT(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
text_num_embeds=vocab_size,
mel_dim=100,
)
fake_unet = DiT(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
text_num_embeds=vocab_size,
mel_dim=100,
)
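    # Note: ctc_path / sv_path_en / sv_path_zh are left empty here, so the
    # CTC/SV losses and the discriminator's CTC features are unavailable;
    # point them at real checkpoints to exercise those branches end to end.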
guidance = Guidance(
real_unet,
fake_unet,
real_guidance_scale=1.0,
fake_guidance_scale=0.0,
use_fp16=True,
gen_cls_loss=True,
).cuda()
text = ["hello world"] * bsz
lens = torch.randint(1, 1000, (bsz,)).cuda()
    inp = torch.randn(bsz, int(lens.max()), 100).cuda()  # last dim must match mel_dim=100
batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device
# handle text as string
if isinstance(text, list):
if exists(vocab_char_map):
text = list_str_to_idx(text, vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(
lens, length=seq_len
) # useless here, as collate_fn will pad to max length in batch
frac_lengths_mask = (0.7, 1.0)
# get a random span to mask out for training conditionally
frac_lengths = (
torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask)
)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# Construct data dicts for generator and guidance phases
# For flow, `real_data` can just be the ground truth if available; here we simulate it
real_data_dict = {
"inp": torch.zeros_like(inp), # simulating real data
}
    # generator_forward also needs text lengths and normalized text; for this
    # smoke test the raw text doubles as its normalized form
    text_lens = torch.full((batch,), text.shape[-1], device=device)
    generator_data_dict = {
        "inp": inp,
        "text": text,
        "text_lens": text_lens,
        "text_normalized": text,
        "text_normalized_lens": text_lens,
        "rand_span_mask": rand_span_mask,
        "real_data": real_data_dict,
    }
    guidance_data_dict = {
        "inp": inp,
        "text": text,
        "text_lens": text_lens,
        "rand_span_mask": rand_span_mask,
        "real_data": real_data_dict,
    }
# Generator forward pass
loss_dict, log_dict = guidance(
generator_turn=True, generator_data_dict=generator_data_dict
)
print("Generator turn losses:", loss_dict)
# Guidance forward pass
loss_dict, log_dict = guidance(
guidance_turn=True, guidance_data_dict=guidance_data_dict
)
print("Guidance turn losses:", loss_dict)