# WenetSpeech-Yue-TTS / cosyvoice / dataset / custom_processor.py
# Uploaded by gz412 -- commit a3f711d: "add base and zjg"
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
# Presumably the audio file extensions (lower-case, no leading dot) that the
# data pipeline accepts.
# NOTE(review): not referenced anywhere in the visible part of this file --
# confirm it is still needed.
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
import json
from typing import Iterable, Dict, Any
def json_line_opener(data: Iterable[Dict[str, Any]], mode='train'):
    """Read JSON-lines manifests and emit one sample dict per record.

    Args:
        data: Iterable of dicts; each one carries the path of a manifest file
            under the 'src' key (per the original note these come from
            DataList and contain only {'src': 'xxx.txt'}).
        mode: Not used by this function.

    Yields:
        dict with the keys 'key', 'text', 'text_token' and 'speech_token',
        filled from the record's 'key', 'txt', 'txt' and 'code' fields.
        Lines that are empty or whitespace-only are skipped.
    """
    for entry in data:
        manifest_path = entry['src']
        with open(manifest_path, 'r', encoding='utf-8') as manifest:
            for raw_line in manifest:
                if raw_line.strip():
                    record = json.loads(raw_line)
                    yield {
                        'key': record['key'],
                        'text': record['txt'],
                        # Keep the raw text for now; the tokenize stage is
                        # expected to turn it into ids later.
                        'text_token': record['txt'],
                        # Per the original note this is already a list[int].
                        'speech_token': record['code'],
                    }
# def parquet_opener(data, mode='train', tts_data={}):
# """ Give url or local file, return file descriptor
# Inplace operation.
# Args:
# data(Iterable[str]): url or local file list
# Returns:
# Iterable[{src, stream}]
# """
# for sample in data:
# assert 'src' in sample
# url = sample['src']
# try:
# for df in pq.ParquetFile(url).iter_batches(batch_size=64):
# df = df.to_pandas()
# for i in range(len(df)):
# sample.update(dict(df.loc[i]))
# if mode == 'train':
# # NOTE do not return sample directly, must initialize a new dict
# yield {**sample}
# else:
# for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
# yield {**sample, 'tts_index': index, 'tts_text': text}
# except Exception as ex:
# logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """Decode each sample's audio and drop samples outside the length limits.

    Every sample is modified in place: 'audio_data' is decoded with
    torchaudio into 'speech' (averaged over dim 0, keeping that dim) and
    'sample_rate', and the 'audio_data' entry is removed.

    Args:
        data: Iterable of dicts holding at least 'audio_data', 'text_token'
            and 'speech_token'.
        max_length: Drop samples longer than this many frames, where one
            second of audio counts as 100 frames.
        min_length: Drop samples shorter than this many frames.
        token_max_length: Drop samples with more text tokens than this.
        token_min_length: Drop samples with fewer text tokens than this.
        min_output_input_ratio: Minimum allowed ratio of
            len(text_token) / num_frames.
        max_output_input_ratio: Maximum allowed ratio of
            len(text_token) / num_frames.
        mode: Not used by this function.

    Yields:
        The samples that pass every check. Samples with an empty
        'speech_token', or with a present-but-empty 'reject_speech_token',
        are dropped as well.
    """
    for sample in data:
        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
        del sample['audio_data']

        # Duration expressed in frames, at 100 frames per second.
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length or num_frames > max_length:
            continue

        num_text_tokens = len(sample['text_token'])
        if num_text_tokens < token_min_length or num_text_tokens > token_max_length:
            continue

        if len(sample['speech_token']) == 0:
            continue
        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
            continue

        # The ratio check is skipped for zero-length audio to avoid dividing by 0.
        if num_frames != 0:
            ratio = num_text_tokens / num_frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue

        yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """Bring every sample to `resample_rate` and limit its peak amplitude to 1.

    Samples are modified in place.

    Args:
        data: Iterable of dicts holding 'speech' (a tensor) and 'sample_rate'.
        resample_rate: Target sample rate.
        min_sample_rate: Samples whose rate differs from `resample_rate` and
            is below this value are dropped instead of being resampled.
        mode: Not used by this function.

    Yields:
        The surviving samples, with 'speech' and 'sample_rate' updated.
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        source_rate = sample['sample_rate']
        audio = sample['speech']

        if source_rate != resample_rate:
            if source_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            to_target_rate = torchaudio.transforms.Resample(
                orig_freq=source_rate, new_freq=resample_rate)
            sample['speech'] = to_target_rate(audio)

        # Scale down (in place) only when the signal exceeds [-1, 1].
        peak = sample['speech'].abs().max()
        if peak > 1:
            sample['speech'] /= peak
        yield sample
def truncate(data, truncate_length=24576, mode='train'):
    """Crop or zero-pad every waveform to exactly `truncate_length` samples.

    Waveforms longer than `truncate_length` are cropped at a random offset
    (drawn with `random.randint`); shorter ones are right-padded with zeros.
    The padding matches the waveform's channel count, dtype and device, so
    multi-channel or non-float32 input works as well as mono float32.

    Args:
        data: Iterable of dicts holding 'speech', a 2-D tensor indexed as
            (channel, time).
        truncate_length: Target length along dim 1.
        mode: Not used by this function.

    Yields:
        The same dicts with 'speech' replaced by the fixed-length waveform.
    """
    for sample in data:
        waveform = sample['speech']
        num_samples = waveform.shape[1]
        if num_samples > truncate_length:
            start = random.randint(0, num_samples - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # new_zeros inherits dtype/device from the waveform; the row count
            # follows the waveform instead of being hard-coded to 1.
            pad = waveform.new_zeros((waveform.shape[0], truncate_length - num_samples))
            waveform = torch.concat([waveform, pad], dim=1)
        sample['speech'] = waveform
        yield sample
def compute_fbank(data,
                  feat_extractor,
                  token_mel_ratio=0,
                  mode='train'):
    """Extract acoustic features and optionally align them with speech tokens.

    Args:
        data: Iterable of dicts holding 'sample_rate', 'speech', 'utt',
            'text_token' and (when `token_mel_ratio` is non-zero)
            'speech_token'.
        feat_extractor: Callable applied to sample['speech']; its output is
            squeezed on dim 0 and transposed, so it is assumed to return
            (1, feat_dim, frames) -- TODO confirm against the configured
            extractor.
        token_mel_ratio: Number of feature frames per speech token. When
            non-zero, both the features and 'speech_token' are trimmed so
            that frames == token_mel_ratio * tokens. 0 disables trimming.
        mode: Not used by this function.

    Yields:
        The same dicts with 'speech_feat' added (and 'speech_token' possibly
        shortened).
    """
    for sample in data:
        # NOTE(review): samples from json_line_opener carry 'key' rather than
        # 'utt'; confirm which upstream stage is meant to feed this function.
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        if token_mel_ratio != 0:
            # Trim to align speech_token and speech_feat. len() is used instead
            # of .shape[0] so that plain Python lists work as well as
            # tensors/arrays, matching how sort/dynamic_batch measure tokens.
            token_len = int(min(feat.shape[0] / token_mel_ratio, len(sample["speech_token"])))
            feat = feat[:token_mel_ratio * token_len]
            sample["speech_token"] = sample["speech_token"][:token_len]
        sample['speech_feat'] = feat
        yield sample
def compute_f0(data, sample_rate, hop_size, mode='train'):
    """Estimate an F0 contour per sample and store it as 'pitch_feat'.

    The contour is estimated with pyworld (harvest, falling back to dio when
    harvest yields fewer than 5 non-zero values), refined with stonemask and
    linearly interpolated to the length of sample['speech_feat'].

    Args:
        data: Iterable of dicts holding 'sample_rate', 'speech', 'utt',
            'text_token' and 'speech_feat'.
        sample_rate: Sample rate passed to pyworld.
        hop_size: Hop size in samples; converted to a frame period in
            milliseconds.
        mode: Not used by this function.

    Yields:
        The same dicts with 'pitch_feat' added.
    """
    frame_period = hop_size * 1000 / sample_rate
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        # pyworld is fed a 1-D float64 array; build it once and reuse it.
        signal = sample['speech'].squeeze(dim=0).numpy().astype('double')
        coarse_f0, time_axis = pw.harvest(signal, sample_rate, frame_period=frame_period)
        if sum(coarse_f0 != 0) < 5:  # this happens when the algorithm fails
            # if harvest fails, try dio
            coarse_f0, time_axis = pw.dio(signal, sample_rate, frame_period=frame_period)
        refined_f0 = pw.stonemask(signal, coarse_f0, time_axis, sample_rate)
        target_len = sample['speech_feat'].shape[0]
        contour = torch.from_numpy(refined_f0).view(1, 1, -1)
        sample['pitch_feat'] = F.interpolate(contour, size=target_len, mode='linear').view(-1)
        yield sample
def parse_embedding(data, normalize, mode='train'):
    """Convert 'utt_embedding' and 'spk_embedding' into float32 tensors.

    Args:
        data: Iterable of dicts holding 'utt_embedding' and 'spk_embedding'.
        normalize: When truthy, both embeddings are normalized along dim 0
            with `F.normalize`.
        mode: Not used by this function.

    Yields:
        The same dicts with both embeddings replaced by tensors.
    """
    embedding_keys = ('utt_embedding', 'spk_embedding')
    for sample in data:
        for name in embedding_keys:
            sample[name] = torch.tensor(sample[name], dtype=torch.float32)
        if normalize:
            for name in embedding_keys:
                sample[name] = F.normalize(sample[name], dim=0)
        yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """Encode each sample's 'text' and store the result in 'text_token'.

    Samples are modified in place.

    Args:
        data: Iterable of dicts holding 'text'.
        get_tokenizer: Zero-argument callable returning an object with an
            `encode(text, allowed_special=...)` method. It is called once,
            when iteration starts.
        allowed_special: Forwarded unchanged to `encode`.
        mode: Not used by this function.

    Yields:
        The same dicts with 'text_token' set to the encoder output.
    """
    text_tokenizer = get_tokenizer()
    for item in data:
        assert 'text' in item
        token_ids = text_tokenizer.encode(item['text'], allowed_special=allowed_special)
        item['text_token'] = token_ids
        yield item
def shuffle(data, shuffle_size=10000, mode='train'):
    """Shuffle the stream locally, one buffer of `shuffle_size` items at a time.

    Args:
        data: Iterable of samples.
        shuffle_size: Number of samples collected before the buffer is
            shuffled (with `random.shuffle`) and flushed.
        mode: Not used by this function.

    Yields:
        The samples, shuffled within each buffer; the final partial buffer is
        shuffled and flushed as well.
    """
    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) < shuffle_size:
            continue
        random.shuffle(pending)
        yield from pending
        pending = []
    # Flush whatever is left once the input is exhausted.
    random.shuffle(pending)
    yield from pending
def sort(data, sort_size=500, mode='train'):
    """Sort the stream by speech-token count within buffers of `sort_size`.

    Per the original note, this runs after shuffle and before batch so that
    samples of similar length land in the same batch, and `sort_size` should
    be smaller than `shuffle_size`.

    Args:
        data: Iterable of dicts holding 'speech_token'.
        sort_size: Number of samples collected before the buffer is sorted
            (ascending by len(sample['speech_token'])) and flushed.
        mode: Not used by this function.

    Yields:
        The samples, sorted within each buffer; the final partial buffer is
        sorted and flushed as well.
    """
    def token_count(item):
        return len(item['speech_token'])

    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) < sort_size:
            continue
        pending.sort(key=token_count)
        yield from pending
        pending = []
    # Flush whatever is left once the input is exhausted.
    pending.sort(key=token_count)
    yield from pending
def static_batch(data, batch_size=16):
    """Group the stream into lists of `batch_size` samples.

    Args:
        data: Iterable of samples.
        batch_size: Number of samples per emitted list.

    Yields:
        Lists of samples; the last list may be shorter than `batch_size`.
    """
    current = []
    for sample in data:
        current.append(sample)
        if len(current) < batch_size:
            continue
        yield current
        current = []
    # Emit the trailing partial batch, if any.
    if len(current) > 0:
        yield current
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """Group samples until the padded token count would exceed the budget.

    The cost of a batch is measured as
    (longest len(speech_token) in the batch) * (number of samples), i.e. its
    size after padding every sample to the longest one.

    Args:
        data: Iterable of dicts holding 'speech_token'.
        max_frames_in_batch: Upper bound on the padded size of one batch.
        mode: Not used by this function.

    Yields:
        Non-empty lists of samples. A single sample that on its own exceeds
        `max_frames_in_batch` is still emitted, as a batch of one.
    """
    buf = []
    longest_frames = 0
    for sample in data:
        new_sample_frames = len(sample['speech_token'])
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            # The buffer is still empty when the sample opening a batch is
            # itself over budget; never emit an empty batch, because the
            # downstream padding step cannot pad an empty list.
            if buf:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if buf:
        yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """Dispatch to `static_batch` or `dynamic_batch` according to `batch_type`.

    Args:
        data: Iterable of samples.
        batch_type: 'static' or 'dynamic'.
        batch_size: Forwarded to `static_batch`.
        max_frames_in_batch: Forwarded to `dynamic_batch`.
        mode: Not used by this function.

    Returns:
        The generator produced by the selected batching function. For any
        other `batch_type` a critical message is logged and None is returned.
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal('Unsupported batch type {}'.format(batch_type))
    return None
import torch.distributed as dist
def padding(data, **kw):
    """Collate lists of samples into padded tensors.

    This stage also compensates for empty ranks: when torch.distributed is
    initialized and this rank produced no batch at all, a single dummy batch
    is emitted. Per the original note, the intent is to keep the DataLoader
    iteration count consistent across ranks.

    Args:
        data: Iterable of lists of dicts, each dict holding 'key',
            'text_token' and 'speech_token'.
        **kw: Accepted and ignored.

    Yields:
        dict with 'key' (list), 'text_token' / 'speech_token' (int64 tensors
        right-padded with 0, batch first) and 'text_token_len' /
        'speech_token_len' (int64 tensors with the unpadded lengths).
    """
    produced_any = False
    for samples in data:
        produced_any = True
        keys = [item['key'] for item in samples]
        text_seqs = [torch.tensor(item['text_token'], dtype=torch.long) for item in samples]
        speech_seqs = [torch.tensor(item['speech_token'], dtype=torch.long) for item in samples]
        text_lens = torch.tensor([seq.size(0) for seq in text_seqs], dtype=torch.long)
        speech_lens = torch.tensor([seq.size(0) for seq in speech_seqs], dtype=torch.long)
        yield {
            'key': keys,
            'text_token': pad_sequence(text_seqs, batch_first=True, padding_value=0),
            'text_token_len': text_lens,
            'speech_token': pad_sequence(speech_seqs, batch_first=True, padding_value=0),
            'speech_token_len': speech_lens,
        }

    # This rank produced no batch at all: emit one placeholder batch.
    if dist.is_initialized() and not produced_any:
        yield {
            'key': ['dummy'],
            'text_token': torch.zeros(1, 1, dtype=torch.long),
            'text_token_len': torch.tensor([1]),
            'speech_token': torch.zeros(1, 1, dtype=torch.long),
            'speech_token_len': torch.tensor([1]),
        }
# def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
# """ Padding the data into training data
# Args:
# data: Iterable[List[{key, feat, label}]]
# Returns:
# Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
# """
# for sample in data:
# assert isinstance(sample, list)
# speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
# dtype=torch.int32)
# order = torch.argsort(speech_feat_len, descending=True)
# utts = [sample[i]['utt'] for i in order]
# speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
# speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
# speech = pad_sequence(speech, batch_first=True, padding_value=0)
# speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
# speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
# speech_token = pad_sequence(speech_token,
# batch_first=True,
# padding_value=0)
# speech_feat = [sample[i]['speech_feat'] for i in order]
# speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
# speech_feat = pad_sequence(speech_feat,
# batch_first=True,
# padding_value=0)
# text = [sample[i]['text'] for i in order]
# text_token = [torch.tensor(sample[i]['text_token']) for i in order]
# text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
# text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
# utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
# spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
# batch = {
# "utts": utts,
# "speech": speech,
# "speech_len": speech_len,
# "speech_token": speech_token,
# "speech_token_len": speech_token_len,
# "speech_feat": speech_feat,
# "speech_feat_len": speech_feat_len,
# "text": text,
# "text_token": text_token,
# "text_token_len": text_token_len,
# "utt_embedding": utt_embedding,
# "spk_embedding": spk_embedding,
# }
# if gan is True:
# # in gan train, we need pitch_feat
# pitch_feat = [sample[i]['pitch_feat'] for i in order]
# pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
# pitch_feat = pad_sequence(pitch_feat,
# batch_first=True,
# padding_value=0)
# batch["pitch_feat"] = pitch_feat
# batch["pitch_feat_len"] = pitch_feat_len
# else:
# # only gan train needs speech, delete it to save memory
# del batch["speech"]
# del batch["speech_len"]
# if dpo is True:
# reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
# reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
# reject_speech_token = pad_sequence(reject_speech_token,
# batch_first=True,
# padding_value=0)
# batch['reject_speech_token'] = reject_speech_token
# batch['reject_speech_token_len'] = reject_speech_token_len
# if use_spk_embedding is True:
# batch["embedding"] = batch["spk_embedding"]
# else:
# batch["embedding"] = batch["utt_embedding"]
# yield batch