IndexTTS-2-Demo / tests /padding_test.py
kemuriririn's picture
init
fba9477
import torch
import torchaudio
from indextts.infer import IndexTTS
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from torch.nn import functional as F
def build_padded_variants(text_tokens):
    """Return five padded copies of ``text_tokens`` used by the padding test.

    Padding is applied along the last dimension with ``F.pad``, using id 0
    (labelled "bos" by this test) on some sides and id 1 (labelled "eos") on
    others. Every variant is 8 positions longer than the input, so the
    variants can be concatenated into one batch.

    Args:
        text_tokens: Token-id tensor; the caller passes shape ``[1, L]``.

    Returns:
        A list of five tensors, each with last dimension ``L + 8``.
    """
    return [
        F.pad(text_tokens, (8, 0), value=0),  # left bos
        F.pad(text_tokens, (0, 8), value=1),  # right eos
        F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1),  # both side
        F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1),
        # Right side only: four 0s followed by four 1s.
        F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1),
    ]


def infer_each(infer_fn, conditioning, token_variants, **kwargs):
    """Run ``infer_fn`` once per token variant and collect the squeezed outputs.

    Args:
        infer_fn: Callable invoked as ``infer_fn(conditioning, tokens, **kwargs)``.
        conditioning: First positional argument forwarded unchanged to ``infer_fn``.
        token_variants: Iterable of token tensors, one inference call per item.
        **kwargs: Extra keyword arguments forwarded to every call.

    Returns:
        A list with ``result.squeeze(0)`` for each variant, in input order.
    """
    # Each call must receive its own variant; passing the unpadded tokens here
    # would make the padded-vs-baseline comparison trivially succeed.
    return [
        infer_fn(conditioning, tokens, **kwargs).squeeze(0)
        for tokens in token_variants
    ]


def find_mismatches(baseline, outputs):
    """Return the indices of ``outputs`` that are not exactly equal to ``baseline``.

    Uses ``Tensor.equal``, so a difference in shape counts as a mismatch.

    Args:
        baseline: Reference tensor.
        outputs: A list of tensors, or a tensor iterated along dim 0.

    Returns:
        A list of integer indices (empty when everything matches).
    """
    return [i for i, out in enumerate(outputs) if not baseline.equal(out)]


def report_mismatches(title, baseline, outputs):
    """Print ``title`` followed by the mismatch report for ``outputs``.

    Args:
        title: Heading line printed before the report.
        baseline: Reference tensor.
        outputs: Indexable collection of tensors compared against ``baseline``.

    Returns:
        The list of mismatching indices.
    """
    print(title)
    mismatch_idx = find_mismatches(baseline, outputs)
    if mismatch_idx:
        print("mismatch:", mismatch_idx)
        for i in mismatch_idx:
            print(f"[{i}]: {outputs[i]}")
    else:
        print("all matched")
    return mismatch_idx


def main():
    """
    Test the padding of text tokens in inference.
    ```
    python tests/padding_test.py checkpoints
    python tests/padding_test.py IndexTTS-1.5
    ```
    """
    import transformers
    transformers.set_seed(42)
    import sys
    # NOTE(review): the indextts imports at the top of the file have already
    # run by this point, so this only affects imports performed later.
    sys.path.append("..")

    # Optional first CLI argument selects the model directory.
    model_dir = sys.argv[1] if len(sys.argv) > 1 else "checkpoints"
    audio_prompt = "tests/sample_prompt.wav"
    tts = IndexTTS(
        cfg_path=f"{model_dir}/config.yaml",
        model_dir=model_dir,
        is_fp16=False,
        use_cuda_kernel=False,
    )

    text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
    text_tokens = tts.tokenizer.encode(text)
    text_tokens = torch.tensor(
        text_tokens, dtype=torch.int32, device=tts.device
    ).unsqueeze(0)  # [1, L]

    # Load the prompt, average the channels, resample to 24 kHz, extract mel.
    audio, sr = torchaudio.load(audio_prompt)
    audio = torch.mean(audio, dim=0, keepdim=True)
    audio = torchaudio.transforms.Resample(sr, 24000)(audio)
    auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device)
    cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device)

    with torch.no_grad():
        kwargs = {
            "cond_mel_lengths": cond_mel_lengths,
            "do_sample": False,
            "top_p": 0.8,
            "top_k": None,
            "temperature": 1.0,
            "num_return_sequences": 1,
            "length_penalty": 0.0,
            "num_beams": 1,
            "repetition_penalty": 10.0,
            "max_generate_length": 100,
        }
        # baseline for non-pad
        baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs)
        baseline = baseline.squeeze(0)

        print("Inference padded text tokens...")
        pad_text_tokens = build_padded_variants(text_tokens)
        # test for each padded text
        output_for_padded = infer_each(
            tts.gpt.inference_speech, auto_conditioning, pad_text_tokens, **kwargs
        )

        # batched inference
        print("Inference padded text tokens as one batch...")
        batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device)
        assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2
        # NOTE(review): the conditioning (and cond_mel_lengths) still has a
        # single entry here; presumably inference_speech handles a larger text
        # batch -- confirm against its implementation.
        batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs)
        del pad_text_tokens

    print("baseline:", baseline.shape, baseline)
    print("--"*10)
    report_mismatches("baseline vs padded output:", baseline, output_for_padded)
    del output_for_padded
    print("--"*10)
    # NOTE(review): exact equality assumes each batch row has the same length
    # as the baseline; rows padded to a longer length would be reported as
    # mismatches -- confirm the intended comparison.
    report_mismatches("baseline vs batched output:", baseline, batch_output)
    print("Test finished.")


if __name__ == "__main__":
    main()