# DMOSpeech2 / infer.py
import os
import torch
import torch.nn.functional as F
import torchaudio
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from duration_predictor import SpeechLengthPredictor
from f5_tts.infer.utils_infer import (chunk_text, convert_char_to_pinyin,
hop_length, load_vocoder,
preprocess_ref_audio_text, speed,
target_rms, target_sample_rate,
transcribe)
# Import F5-TTS modules
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (default, exists, get_tokenizer, lens_to_mask,
list_str_to_idx, list_str_to_tensor,
mask_from_frac_lengths)
# Import custom modules
from unimodel import UniModel
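# DMOInference pairs a distilled few-step student generator with its flow-matching
# teacher (both restored from a single checkpoint), a separate duration predictor,
# and the Vocos vocoder for prompt-conditioned text-to-speech.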
class DMOInference:
"""F5-TTS Inference wrapper class for easy text-to-speech generation."""
def __init__(
self,
student_checkpoint_path="",
duration_predictor_path="",
device="cuda",
model_type="F5TTS_Base", # "F5TTS_Base" or "E2TTS_Base"
tokenizer="pinyin",
dataset_name="Emilia_ZH_EN",
):
"""
Initialize F5-TTS inference model.
Args:
student_checkpoint_path: Path to student model checkpoint
duration_predictor_path: Path to duration predictor checkpoint
device: Device to run inference on
model_type: Model architecture type
tokenizer: Tokenizer type ("pinyin", "char", or "custom")
dataset_name: Dataset name for tokenizer
"""
self.device = device
self.model_type = model_type
self.tokenizer = tokenizer
self.dataset_name = dataset_name
# Model parameters
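# Fixed 24 kHz mel front-end (100 bins, hop 256); the guidance scales are forwarded
# to UniModel below, and num_student_step controls the distilled sampler (4 steps).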
self.target_sample_rate = 24000
self.n_mel_channels = 100
self.hop_length = 256
self.real_guidance_scale = 2
self.fake_guidance_scale = 0
self.gen_cls_loss = False
self.num_student_step = 4
# Initialize components
self._setup_tokenizer()
self._setup_models(student_checkpoint_path)
self._setup_mel_spec()
self._setup_vocoder()
self._setup_duration_predictor(duration_predictor_path)
def _setup_tokenizer(self):
"""Setup tokenizer and vocabulary."""
if self.tokenizer == "custom":
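# NOTE: the "custom" option reads self.tokenizer_path, which this class never sets;
# use "pinyin" or "char" (with dataset_name) unless tokenizer_path is supplied externally.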
tokenizer_path = self.tokenizer_path
else:
tokenizer_path = self.dataset_name
self.vocab_char_map, self.vocab_size = get_tokenizer(
tokenizer_path, self.tokenizer
)
def _setup_models(self, student_checkpoint_path):
"""Initialize teacher and student models."""
# Model configuration
if self.model_type == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
elif self.model_type == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
else:
raise ValueError(f"Unknown model type: {self.model_type}")
# Initialize UniModel (student)
self.model = UniModel(
model_cls(
**model_cfg,
text_num_embeds=self.vocab_size,
mel_dim=self.n_mel_channels,
second_time=self.num_student_step > 1,
),
checkpoint_path="",
vocab_char_map=self.vocab_char_map,
frac_lengths_mask=(0.5, 0.9),
real_guidance_scale=self.real_guidance_scale,
fake_guidance_scale=self.fake_guidance_scale,
gen_cls_loss=self.gen_cls_loss,
sway_coeff=0,
)
# Load student checkpoint
checkpoint = torch.load(student_checkpoint_path, map_location="cpu")
self.model.load_state_dict(checkpoint["model_state_dict"], strict=False)
# Setup generator and teacher
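# The checkpoint bundles the distilled generator, the guidance model (whose real_unet
# acts as the teacher at inference), and the mel normalization "scale" used in
# _prepare_inputs and when decoding.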
self.generator = self.model.feedforward_model.to(self.device)
self.teacher = self.model.guidance_model.real_unet.to(self.device)
self.scale = checkpoint["scale"]
def _setup_mel_spec(self):
"""Initialize mel spectrogram module."""
mel_spec_kwargs = dict(
target_sample_rate=self.target_sample_rate,
n_mel_channels=self.n_mel_channels,
hop_length=self.hop_length,
)
self.mel_spec = MelSpec(**mel_spec_kwargs)
def _setup_vocoder(self):
"""Initialize vocoder."""
self.vocos = load_vocoder(is_local=False, local_path="")
self.vocos = self.vocos.to(self.device)
def _setup_duration_predictor(self, checkpoint_path):
"""Initialize duration predictor."""
self.wav2mel = MelSpec(
target_sample_rate=24000,
n_mel_channels=100,
hop_length=256,
win_length=1024,
n_fft=1024,
mel_spec_type="vocos",
).to(self.device)
self.SLP = SpeechLengthPredictor(
vocab_size=2545,
n_mel=100,
hidden_dim=512,
n_text_layer=4,
n_cross_layer=4,
n_head=8,
output_dim=301,
).to(self.device)
self.SLP.eval()
self.SLP.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
)
def predict_duration(
self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0.7, temperature=0
):
"""
Predict duration for target text based on prompt audio.
Args:
pmt_wav_path: Path to prompt audio
tar_text: Target text to generate
pmt_text: Prompt text
dp_softmax_range: allowed softmax window around the rate-based duration estimate
temperature: temperature for softmax sampling (if 0, will use argmax)
Returns:
Estimated duration in frames
"""
pmt_wav, sr = torchaudio.load(pmt_wav_path)
if sr != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
pmt_wav = resampler(pmt_wav)
if pmt_wav.size(0) > 1:
pmt_wav = pmt_wav[0].unsqueeze(0)
pmt_wav = pmt_wav.to(self.device)
pmt_mel = self.wav2mel(pmt_wav).permute(0, 2, 1)
tar_tokens = self._convert_to_pinyin(list(tar_text))
pmt_tokens = self._convert_to_pinyin(list(pmt_text))
# Calculate duration
ref_text_len = len(pmt_tokens)
gen_text_len = len(tar_tokens)
ref_audio_len = pmt_mel.size(1)
duration = int(ref_audio_len / ref_text_len * gen_text_len / speed)
duration = duration // 10
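# The predictor classifies duration into 301 buckets of 10 mel frames each
# (~0.107 s per bucket at 24 kHz / hop 256); the rate-based estimate above, now in
# bucket units, bounds the softmax window below.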
min_duration = max(int(duration * dp_softmax_range), 0)
max_duration = min(int(duration * (1 + dp_softmax_range)), 301)
all_tokens = pmt_tokens + [" "] + tar_tokens
text_ids = list_str_to_idx([all_tokens], self.vocab_char_map).to(self.device)
text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)
with torch.no_grad():
predictions = self.SLP(text_ids=text_ids, mel=pmt_mel)
predictions = predictions[:, -1, :]
predictions[:, :min_duration] = float("-inf")
predictions[:, max_duration:] = float("-inf")
if temperature == 0:
est_label = predictions.argmax(-1)[..., -1].item() * 10
else:
probs = torch.softmax(predictions / temperature, dim=-1)
sampled_idx = torch.multinomial(probs.squeeze(0), num_samples=1)
est_label = sampled_idx.item() * 10
return est_label
def _convert_to_pinyin(self, char_list):
"""Convert character list to pinyin."""
result = []
for x in convert_char_to_pinyin(char_list):
result = result + x
while result[0] == " " and len(result) > 1:
result = result[1:]
return result
def generate(
self,
gen_text,
audio_path,
prompt_text=None,
teacher_steps=16,
teacher_stopping_time=0.07,
student_start_step=1,
duration=None,
dp_softmax_range=0.7,
temperature=0,
eta=1.0,
cfg_strength=2.0,
sway_coefficient=-1.0,
verbose=False,
):
"""
Generate speech from text using teacher-student distillation.
Args:
gen_text: Text to generate
audio_path: Path to prompt audio
prompt_text: Prompt text (if None, will use ASR)
teacher_steps: Number of teacher guidance steps
teacher_stopping_time: When to stop teacher sampling
student_start_step: When to start student sampling
duration: Total duration (if None, will predict)
dp_softmax_range: Allowed softmax window around the rate-based duration estimate
temperature: Temperature for duration predictor sampling (0 means use argmax)
eta: Stochasticity control (0=DDIM, 1=DDPM)
cfg_strength: Classifier-free guidance strength
sway_coefficient: Sway sampling coefficient
verbose: Output sampling steps
Returns:
Generated audio waveform
"""
if prompt_text is None:
prompt_text = transcribe(audio_path)
# Predict duration if not provided
if duration is None:
duration = self.predict_duration(
audio_path, gen_text, prompt_text, dp_softmax_range, temperature
)
# Preprocess audio and text
ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text)
audio, sr = torchaudio.load(ref_audio)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
# Normalize audio
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
audio = resampler(audio)
audio = audio.to(self.device)
# Prepare text
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
# Calculate durations
ref_audio_len = audio.shape[-1] // self.hop_length
if duration is None:
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(
ref_audio_len / ref_text_len * gen_text_len / speed
)
else:
duration = ref_audio_len + duration
if verbose:
print("audio:", audio.shape)
print("text:", final_text_list)
print("duration:", duration)
print("eta (stochasticity):", eta) # Print eta value for debugging
# Run inference
with torch.inference_mode():
cond, text, step_cond, cond_mask, max_duration, duration_tensor = (
self._prepare_inputs(audio, final_text_list, duration)
)
# Teacher-student sampling
if teacher_steps > 0 and student_start_step > 0:
if verbose:
print(
"Start teacher sampling with hybrid DDIM/DDPM (eta={})....".format(
eta
)
)
x1 = self._teacher_sampling(
step_cond,
text,
cond_mask,
max_duration,
duration_tensor, # Use duration_tensor
teacher_steps,
teacher_stopping_time,
eta,
cfg_strength,
verbose,
sway_coefficient,
)
else:
x1 = step_cond
if verbose:
print("Start student sampling...")
# Student sampling
x1 = self._student_sampling(
x1, cond, text, student_start_step, verbose, sway_coefficient
)
# Decode to audio
mel = x1.permute(0, 2, 1) * self.scale
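# cond_mask.sum() counts the prompt frames, so only the newly generated frames are vocoded.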
generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :])
return generated_wave.cpu().numpy().squeeze()
def generate_teacher_only(
self,
gen_text,
audio_path,
prompt_text=None,
teacher_steps=32,
duration=None,
eta=1.0,
cfg_strength=2.0,
sway_coefficient=-1.0,
):
"""
Generate speech using teacher model only (no student distillation).
Args:
gen_text: Text to generate
audio_path: Path to prompt audio
prompt_text: Prompt text (if None, will use ASR)
teacher_steps: Number of sampling steps
duration: Total duration (if None, will predict)
eta: Stochasticity control (0=DDIM, 1=DDPM)
cfg_strength: Classifier-free guidance strength
sway_coefficient: Sway sampling coefficient
Returns:
Generated audio waveform
"""
if prompt_text is None:
prompt_text = transcribe(audio_path)
# Predict duration if not provided
if duration is None:
duration = self.predict_duration(audio_path, gen_text, prompt_text)
# Preprocess audio and text
ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text)
audio, sr = torchaudio.load(ref_audio)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
# Normalize audio
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
audio = resampler(audio)
audio = audio.to(self.device)
# Prepare text
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
# Calculate durations
ref_audio_len = audio.shape[-1] // self.hop_length
if duration is None:
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(
ref_audio_len / ref_text_len * gen_text_len / speed
)
else:
duration = ref_audio_len + duration
# Run inference
with torch.inference_mode():
cond, text, step_cond, cond_mask, max_duration, duration_tensor = (
self._prepare_inputs(audio, final_text_list, duration)
)
# Teacher-only sampling (stopping_time=1.0 runs the full schedule)
x1 = self._teacher_sampling(
step_cond,
text,
cond_mask,
max_duration,
duration_tensor,
teacher_steps,
1.0,  # teacher_stopping_time
eta,
cfg_strength,
False,  # verbose
sway_coefficient,
)
# Decode to audio
mel = x1.permute(0, 2, 1) * self.scale
generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :])
return generated_wave
def _prepare_inputs(self, audio, text_list, duration):
"""Prepare inputs for generation."""
lens = None
max_duration_limit = 4096
cond = audio
text = text_list
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = cond.permute(0, 2, 1)
assert cond.shape[-1] == self.n_mel_channels
cond = cond / self.scale
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
# Process text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim=-1)
lens = torch.maximum(text_lens, lens)
# Process duration
cond_mask = lens_to_mask(lens)
if isinstance(duration, int):
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(lens + 1, duration)
duration = duration.clamp(max=max_duration_limit)
max_duration = duration.amax()
# Pad conditioning
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
cond_mask = F.pad(
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
)
cond_mask = cond_mask.unsqueeze(-1)
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
return cond, text, step_cond, cond_mask, max_duration, duration
def _teacher_sampling(
self,
step_cond,
text,
cond_mask,
max_duration,
duration,
teacher_steps,
teacher_stopping_time,
eta,
cfg_strength,
verbose,
sway_sampling_coef=-1,
):
"""Perform teacher model sampling."""
device = step_cond.device
# Pre-generate noise sequence for stochastic sampling
noise_seq = None
if eta > 0:
noise_seq = [
torch.randn(1, max_duration, 100, device=device)
for _ in range(teacher_steps)
]
def fn(t, x):
with torch.inference_mode():
with torch.autocast(device_type="cuda", dtype=torch.float16):
if verbose:
print(f"current t: {t}")
step_frac = 1.0 - t.item()
step_idx = (
min(int(step_frac * len(noise_seq)), len(noise_seq) - 1)
if noise_seq
else 0
)
# Predict flow
pred = self.teacher(
x=x,
cond=step_cond,
text=text,
time=t,
mask=None,
drop_audio_cond=False,
drop_text=False,
)
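# Classifier-free guidance: a second pass with both audio and text dropped gives the
# unconditional flow, and the conditional flow is extrapolated away from it.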
if cfg_strength > 1e-5:
null_pred = self.teacher(
x=x,
cond=step_cond,
text=text,
time=t,
mask=None,
drop_audio_cond=True,
drop_text=True,
)
pred = pred + (pred - null_pred) * cfg_strength
# Add stochasticity if eta > 0
if eta > 0 and noise_seq is not None:
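# Flow-matching convention here: alpha_t = 1 - t, sigma_t = t. The injected noise is
# scaled by sqrt(eta * sigma_t^2 / (alpha_t^2 + sigma_t^2)), so eta=0 keeps the update
# deterministic (DDIM-like) and eta=1 gives DDPM-like stochasticity.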
alpha_t = 1.0 - t.item()
sigma_t = t.item()
noise_scale = torch.sqrt(
torch.tensor(
(sigma_t**2) / (alpha_t**2 + sigma_t**2) * eta,
device=device,
)
)
return pred + noise_scale * noise_seq[step_idx]
else:
return pred
# Initialize noise
y0 = []
for dur in duration:
y0.append(torch.randn(dur, 100, device=device, dtype=step_cond.dtype))
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
# Setup time steps
t = torch.linspace(
0, 1, teacher_steps + 1, device=device, dtype=step_cond.dtype
)
if sway_sampling_coef is not None:
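# Sway sampling (as in F5-TTS): a negative coefficient warps the schedule toward
# small t, spending more solver steps in the noisier early region.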
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
t = t[: (t > teacher_stopping_time).float().argmax() + 2]
t = t[:-1]
# Solve ODE
trajectory = odeint(fn, y0, t, method="euler")
if teacher_stopping_time < 1.0:
# If early stopping, compute final step
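# A single Euler extrapolation from the stopping time straight to t = 1 gives a coarse
# x1 estimate, which the student sampler then refines.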
pred = fn(t[-1], trajectory[-1])
test_out = trajectory[-1] + (1 - t[-1]) * pred
return test_out
else:
return trajectory[-1]
def _student_sampling(
self, x1, cond, text, student_start_step, verbose, sway_coeff=-1
):
"""Perform student model sampling."""
steps = torch.Tensor([0, 0.25, 0.5, 0.75])
steps = steps + sway_coeff * (torch.cos(torch.pi / 2 * steps) - 1 + steps)
steps = steps[student_start_step:]
for step in steps:
time = torch.Tensor([step]).to(x1.device)
x0 = torch.randn_like(x1)
t = time.unsqueeze(-1).unsqueeze(-1)
phi = (1 - t) * x0 + t * x1
if verbose:
print(f"current step: {step}")
with torch.no_grad():
pred = self.generator(
x=phi,
cond=cond,
text=text,
time=time,
drop_audio_cond=False,
drop_text=False,
)
# Predicted mel spectrogram
output = phi + (1 - t) * pred
x1 = output
return x1
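# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The checkpoint, prompt-audio, and output paths are placeholders; the call
# otherwise follows the DMOInference API defined above and assumes a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tts = DMOInference(
        student_checkpoint_path="path/to/student_checkpoint.pt",  # placeholder
        duration_predictor_path="path/to/duration_predictor.pt",  # placeholder
        device="cuda",
    )
    # generate() returns a 24 kHz mono waveform as a numpy array
    wave = tts.generate(
        gen_text="Hello, this is a quick DMOSpeech 2 test.",
        audio_path="path/to/prompt.wav",  # placeholder prompt audio
    )
    torchaudio.save("output.wav", torch.from_numpy(wave).unsqueeze(0), 24000)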