import os
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Optional, Union, List, Tuple, Dict
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from importlib.resources import files
from pydub import AudioSegment, silence

from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)
from f5_tts.infer.utils_infer import (
    chunk_text,
    load_vocoder,
    transcribe,
    initialize_asr_pipeline,
)


class F5TTSWrapper:
    """
    A wrapper class for F5-TTS that preprocesses reference audio once
    and allows for repeated TTS generation.
    """
    def __init__(
        self,
        model_name: str = "F5TTS_v1_Base",
        ckpt_path: Optional[str] = None,
        vocab_file: Optional[str] = None,
        vocoder_name: str = "vocos",
        use_local_vocoder: bool = False,
        vocoder_path: Optional[str] = None,
        device: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
        target_sample_rate: int = 24000,
        n_mel_channels: int = 100,
        hop_length: int = 256,
        win_length: int = 1024,
        n_fft: int = 1024,
        ode_method: str = "euler",
        use_ema: bool = True,
    ):
""" | |
Initialize the F5-TTS wrapper with model configuration. | |
Args: | |
model_name: Name of the F5-TTS model variant (e.g., "F5TTS_v1_Base") | |
ckpt_path: Path to the model checkpoint file. If None, will use default path. | |
vocab_file: Path to the vocab file. If None, will use default. | |
vocoder_name: Name of the vocoder to use ("vocos" or "bigvgan") | |
use_local_vocoder: Whether to use a local vocoder or download from HF | |
vocoder_path: Path to the local vocoder. Only used if use_local_vocoder is True. | |
device: Device to run the model on. If None, will automatically determine. | |
hf_cache_dir: Directory to cache HuggingFace models | |
target_sample_rate: Target sample rate for audio | |
n_mel_channels: Number of mel channels | |
hop_length: Hop length for the mel spectrogram | |
win_length: Window length for the mel spectrogram | |
n_fft: FFT size for the mel spectrogram | |
ode_method: ODE method for sampling ("euler" or "midpoint") | |
use_ema: Whether to use EMA weights from the checkpoint | |
""" | |
        # Set device
        if device is None:
            self.device = (
                "cuda" if torch.cuda.is_available()
                else "xpu" if torch.xpu.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu"
            )
        else:
            self.device = device

        # Audio processing parameters
        self.target_sample_rate = target_sample_rate
        self.n_mel_channels = n_mel_channels
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = n_fft
        self.mel_spec_type = vocoder_name

        # Sampling parameters
        self.ode_method = ode_method

        # Initialize ASR for transcription if needed
        initialize_asr_pipeline(device=self.device)
        # Resolve the checkpoint path (download from the HuggingFace Hub if not provided)
        if ckpt_path is None:
            repo_name = "F5-TTS"
            ckpt_step = 1250000
            ckpt_type = "safetensors"
            ckpt_name = model_name
            # Adjust for previous models; keep model_name untouched so the
            # config file lookup below still resolves
            if model_name == "F5TTS_Base":
                if vocoder_name == "vocos":
                    ckpt_step = 1200000
                elif vocoder_name == "bigvgan":
                    ckpt_name = "F5TTS_Base_bigvgan"
                    ckpt_type = "pt"
            elif model_name == "E2TTS_Base":
                repo_name = "E2-TTS"
                ckpt_step = 1200000
            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{ckpt_name}/model_{ckpt_step}.{ckpt_type}"))

        # Load model configuration
        config_path = str(files("f5_tts").joinpath(f"configs/{model_name}.yaml"))
        model_cfg = OmegaConf.load(config_path)
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch
        # Load tokenizer
        if vocab_file is None:
            vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
        tokenizer_type = "custom"
        self.vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer_type)
        # Create model
        self.model = CFM(
            transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
            mel_spec_kwargs=dict(
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                n_mel_channels=n_mel_channels,
                target_sample_rate=target_sample_rate,
                mel_spec_type=vocoder_name,
            ),
            odeint_kwargs=dict(
                method=ode_method,
            ),
            vocab_char_map=self.vocab_char_map,
        ).to(self.device)

        # Load checkpoint
        dtype = torch.float32 if vocoder_name == "bigvgan" else None
        self._load_checkpoint(self.model, ckpt_path, dtype=dtype, use_ema=use_ema)
        # Load vocoder
        if vocoder_path is None:
            if vocoder_name == "vocos":
                vocoder_path = "../checkpoints/vocos-mel-24khz"
            elif vocoder_name == "bigvgan":
                vocoder_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
        self.vocoder = load_vocoder(
            vocoder_name=vocoder_name,
            is_local=use_local_vocoder,
            local_path=vocoder_path,
            device=self.device,
            hf_cache_dir=hf_cache_dir,
        )

        # Storage for reference data
        self.ref_audio_processed = None
        self.ref_text = None
        self.ref_audio_len = None

        # Default inference parameters
        self.target_rms = 0.1
        self.cross_fade_duration = 0.15
        self.nfe_step = 32
        self.cfg_strength = 2.0
        self.sway_sampling_coef = -1.0
        self.speed = 1.0
        self.fix_duration = None
    def _load_checkpoint(self, model, ckpt_path, dtype=None, use_ema=True):
        """
        Load model checkpoint with proper handling of different checkpoint formats.

        Args:
            model: The model to load weights into
            ckpt_path: Path to the checkpoint file
            dtype: Data type for model weights
            use_ema: Whether to use EMA weights from the checkpoint

        Returns:
            Loaded model
        """
        if dtype is None:
            dtype = (
                torch.float16
                if "cuda" in self.device
                and torch.cuda.get_device_properties(self.device).major >= 7
                and not torch.cuda.get_device_name().endswith("[ZLUDA]")
                else torch.float32
            )
        model = model.to(dtype)

        ckpt_type = ckpt_path.split(".")[-1]
        if ckpt_type == "safetensors":
            from safetensors.torch import load_file

            checkpoint = load_file(ckpt_path, device=self.device)
        else:
            checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=True)

        if use_ema:
            if ckpt_type == "safetensors":
                checkpoint = {"ema_model_state_dict": checkpoint}
            checkpoint["model_state_dict"] = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }
            # Patch for backward compatibility
            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                if key in checkpoint["model_state_dict"]:
                    del checkpoint["model_state_dict"][key]
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            if ckpt_type == "safetensors":
                checkpoint = {"model_state_dict": checkpoint}
            model.load_state_dict(checkpoint["model_state_dict"])

        del checkpoint
        torch.cuda.empty_cache()
        return model.to(self.device)
    def preprocess_reference(self, ref_audio_path: str, ref_text: str = "", clip_short: bool = True):
        """
        Preprocess the reference audio and text, storing them for later use.

        Args:
            ref_audio_path: Path to the reference audio file
            ref_text: Text transcript of reference audio. If empty, will auto-transcribe.
            clip_short: Whether to clip long audio to shorter segments

        Returns:
            Tuple of processed audio and text
        """
        print("Converting audio...")

        # Load audio file
        aseg = AudioSegment.from_file(ref_audio_path)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                    print("Audio is over 12s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 12000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
                        print("Audio is over 12s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 12000:
                aseg = aseg[:12000]
                print("Audio is over 12s, clipping short. (3)")

        # Remove silence edges
        aseg = self._remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        # Export to temporary file and load as tensor
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            aseg.export(tmp_file.name, format="wav")
            processed_audio_path = tmp_file.name

        # Transcribe if needed
        if not ref_text.strip():
            print("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(processed_audio_path)
        else:
            print("Using custom reference text...")

        # Ensure ref_text ends with proper punctuation
        if not ref_text.endswith(". ") and not ref_text.endswith("。"):
            if ref_text.endswith("."):
                ref_text += " "
            else:
                ref_text += ". "

        print("\nReference text:", ref_text)
        # Load and process audio
        audio, sr = torchaudio.load(processed_audio_path)
        if audio.shape[0] > 1:  # Convert stereo to mono
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize volume
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < self.target_rms:
            audio = audio * self.target_rms / rms

        # Resample if needed
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)

        # Move to device
        audio = audio.to(self.device)

        # Store reference data
        self.ref_audio_processed = audio
        self.ref_text = ref_text
        self.ref_audio_len = audio.shape[-1] // self.hop_length

        # Remove temporary file
        os.unlink(processed_audio_path)

        return audio, ref_text
    def _remove_silence_edges(self, audio, silence_threshold=-42):
        """
        Remove silence from the start and end of audio.

        Args:
            audio: AudioSegment to process
            silence_threshold: dB threshold to consider as silence

        Returns:
            Processed AudioSegment
        """
        # Remove silence from the start
        non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
        audio = audio[non_silent_start_idx:]

        # Remove silence from the end
        non_silent_end_duration = audio.duration_seconds
        for ms in reversed(audio):
            if ms.dBFS > silence_threshold:
                break
            non_silent_end_duration -= 0.001
        trimmed_audio = audio[: int(non_silent_end_duration * 1000)]

        return trimmed_audio
    def generate(
        self,
        text: str,
        output_path: Optional[str] = None,
        nfe_step: Optional[int] = None,
        cfg_strength: Optional[float] = None,
        sway_sampling_coef: Optional[float] = None,
        speed: Optional[float] = None,
        fix_duration: Optional[float] = None,
        cross_fade_duration: Optional[float] = None,
        return_numpy: bool = False,
        return_spectrogram: bool = False,
    ) -> Union[str, Tuple[np.ndarray, int], Tuple[np.ndarray, int, np.ndarray]]:
""" | |
Generate speech for the given text using the stored reference audio. | |
Args: | |
text: Text to synthesize | |
output_path: Path to save the generated audio. If None, won't save. | |
nfe_step: Number of function evaluation steps | |
cfg_strength: Classifier-free guidance strength | |
sway_sampling_coef: Sway sampling coefficient | |
speed: Speed of generated audio | |
fix_duration: Fixed duration in seconds | |
cross_fade_duration: Duration of cross-fade between segments | |
return_numpy: If True, returns the audio as a numpy array | |
return_spectrogram: If True, also returns the spectrogram | |
Returns: | |
If output_path provided: path to output file | |
If return_numpy=True: tuple of (audio_array, sample_rate) | |
If return_spectrogram=True: tuple of (audio_array, sample_rate, spectrogram) | |
""" | |
if self.ref_audio_processed is None or self.ref_text is None: | |
raise ValueError("Reference audio not preprocessed. Call preprocess_reference() first.") | |
        # Use default values if not specified
        nfe_step = nfe_step if nfe_step is not None else self.nfe_step
        cfg_strength = cfg_strength if cfg_strength is not None else self.cfg_strength
        sway_sampling_coef = sway_sampling_coef if sway_sampling_coef is not None else self.sway_sampling_coef
        speed = speed if speed is not None else self.speed
        fix_duration = fix_duration if fix_duration is not None else self.fix_duration
        cross_fade_duration = cross_fade_duration if cross_fade_duration is not None else self.cross_fade_duration
        # Split the input text into batches
        audio_len = self.ref_audio_processed.shape[-1] / self.target_sample_rate
        max_chars = int(len(self.ref_text.encode("utf-8")) / audio_len * (22 - audio_len))
        text_batches = chunk_text(text, max_chars=max_chars)
        for i, text_batch in enumerate(text_batches):
            print(f"Text batch {i}: {text_batch}")
        print("\n")

        # Generate audio for each batch
        generated_waves = []
        spectrograms = []

        for text_batch in text_batches:
            # Adjust speed for very short texts
            local_speed = speed
            if len(text_batch.encode("utf-8")) < 10:
                local_speed = 0.3

            # Prepare the text
            text_list = [self.ref_text + text_batch]
            final_text_list = convert_char_to_pinyin(text_list)
            # Calculate duration
            if fix_duration is not None:
                duration = int(fix_duration * self.target_sample_rate / self.hop_length)
            else:
                # Calculate duration based on text length
                ref_text_len = len(self.ref_text.encode("utf-8"))
                gen_text_len = len(text_batch.encode("utf-8"))
                duration = self.ref_audio_len + int(self.ref_audio_len / ref_text_len * gen_text_len / local_speed)

            # Generate audio
            with torch.inference_mode():
                generated, _ = self.model.sample(
                    cond=self.ref_audio_processed,
                    text=final_text_list,
                    duration=duration,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )

            # Process the generated mel spectrogram
            generated = generated.to(torch.float32)
            generated = generated[:, self.ref_audio_len:, :]
            generated = generated.permute(0, 2, 1)
            # Convert to audio
            if self.mel_spec_type == "vocos":
                generated_wave = self.vocoder.decode(generated)
            elif self.mel_spec_type == "bigvgan":
                generated_wave = self.vocoder(generated)
            else:
                raise ValueError(f"Unsupported vocoder type: {self.mel_spec_type}")

            # Normalize volume if needed
            rms = torch.sqrt(torch.mean(torch.square(self.ref_audio_processed)))
            if rms < self.target_rms:
                generated_wave = generated_wave * rms / self.target_rms

            # Convert to numpy and append to list
            generated_wave = generated_wave.squeeze().cpu().numpy()
            generated_waves.append(generated_wave)

            # Store spectrogram if needed
            if return_spectrogram or output_path is not None:
                spectrograms.append(generated.squeeze().cpu().numpy())
        # Combine all segments
        if generated_waves:
            if cross_fade_duration <= 0:
                # Simply concatenate
                final_wave = np.concatenate(generated_waves)
            else:
                # Cross-fade between segments
                final_wave = generated_waves[0]
                for i in range(1, len(generated_waves)):
                    prev_wave = final_wave
                    next_wave = generated_waves[i]

                    # Calculate cross-fade samples
                    cross_fade_samples = int(cross_fade_duration * self.target_sample_rate)
                    cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                    if cross_fade_samples <= 0:
                        # No overlap possible, concatenate
                        final_wave = np.concatenate([prev_wave, next_wave])
                        continue

                    # Create cross-fade
                    prev_overlap = prev_wave[-cross_fade_samples:]
                    next_overlap = next_wave[:cross_fade_samples]
                    fade_out = np.linspace(1, 0, cross_fade_samples)
                    fade_in = np.linspace(0, 1, cross_fade_samples)
                    cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
                    final_wave = np.concatenate([
                        prev_wave[:-cross_fade_samples],
                        cross_faded_overlap,
                        next_wave[cross_fade_samples:],
                    ])

            # Combine spectrograms if needed
            if return_spectrogram or output_path is not None:
                combined_spectrogram = np.concatenate(spectrograms, axis=1)
            # Save to file if path provided
            if output_path is not None:
                output_dir = os.path.dirname(output_path)
                if output_dir and not os.path.exists(output_dir):
                    os.makedirs(output_dir)

                # Save audio
                torchaudio.save(
                    output_path,
                    torch.tensor(final_wave).unsqueeze(0),
                    self.target_sample_rate,
                )

                # Save spectrogram if needed
                if return_spectrogram:
                    spectrogram_path = os.path.splitext(output_path)[0] + "_spec.png"
                    self._save_spectrogram(combined_spectrogram, spectrogram_path)

                if not return_numpy:
                    return output_path

            # Return as requested
            if return_spectrogram:
                return final_wave, self.target_sample_rate, combined_spectrogram
            else:
                return final_wave, self.target_sample_rate
        else:
            raise RuntimeError("No audio generated")
    def _save_spectrogram(self, spectrogram, path):
        """Save spectrogram as image"""
        import matplotlib.pyplot as plt

        plt.figure(figsize=(12, 4))
        plt.imshow(spectrogram, origin="lower", aspect="auto")
        plt.colorbar()
        plt.savefig(path)
        plt.close()
    def get_current_audio_length(self):
        """Get the length of the reference audio in seconds"""
        if self.ref_audio_processed is None:
            return 0
        return self.ref_audio_processed.shape[-1] / self.target_sample_rate
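

if __name__ == "__main__":
    # Minimal usage sketch, not part of the wrapper itself. It assumes the
    # f5_tts package and its default checkpoints are available for download,
    # and the "ref.wav" / "out.wav" paths below are placeholders to replace
    # with your own files.
    tts = F5TTSWrapper(model_name="F5TTS_v1_Base", vocoder_name="vocos")

    # Preprocess the reference once; with ref_text="" the audio is auto-transcribed.
    tts.preprocess_reference("ref.wav", ref_text="", clip_short=True)

    # Reuse the cached reference for repeated generations.
    out_path = tts.generate(
        "Hello from the F5-TTS wrapper.",
        output_path="out.wav",
        nfe_step=32,
        cfg_strength=2.0,
    )
    print("Saved:", out_path)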