import os
import re

import torch
import torchaudio
import numpy as np
from transformers import AutoTokenizer

from modeling_asteroid import AsteroidTTSInstruct
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer

MAX_CHANNELS = 8
SILENCE_DURATION = 0.0  # Fixed silence duration: 0 seconds


def load_model(model_path, spt_config_path, spt_checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AsteroidTTSInstruct.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
    )
    spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
    model.eval()
    spt.eval()
    return tokenizer, model, spt
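

# A minimal usage sketch, not part of the original script: all three paths below are
# hypothetical placeholders. It only shows how the loaded components are typically moved
# to a device before process_batch is called.
def _example_load_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model, spt = load_model(
        "path/to/model_dir",                 # hypothetical HF model directory
        "path/to/xy_tokenizer_config.yaml",  # hypothetical XY_Tokenizer config
        "path/to/xy_tokenizer.ckpt",         # hypothetical XY_Tokenizer checkpoint
    )
    model = model.to(device)
    spt = spt.to(device)
    return tokenizer, model, spt, device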


def process_jsonl_item(item):
    """Process a JSONL data item and extract its audio and text information according to the new format."""
    base_path = item.get("base_path", "")
    text = item.get("text", "")

    # Process prompt audio and text
    if "prompt_audio" in item and "prompt_text" in item:
        print("Using prompt_audio and prompt_text directly from item.")
        # If prompt_audio and prompt_text exist, use them directly
        prompt_audio = item["prompt_audio"]
        prompt_text = item["prompt_text"]
        # Only join paths when prompt_audio is a string path
        if isinstance(prompt_audio, str) and base_path and prompt_audio:
            prompt_audio = os.path.join(base_path, prompt_audio)
    else:
        print("Using speaker1 and speaker2 information for prompt audio and text.")
        # Otherwise, merge the speaker1 and speaker2 information
        prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
        prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
        prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
        prompt_text_speaker2 = item.get("prompt_text_speaker2", "")

        # Audio: join with base_path if it is a string path; use a (wav, sr) tuple directly
        if isinstance(prompt_audio_speaker1, str):
            speaker1_audio = os.path.join(base_path, prompt_audio_speaker1) if base_path and prompt_audio_speaker1 else prompt_audio_speaker1
        else:
            speaker1_audio = prompt_audio_speaker1  # Use tuple directly

        if isinstance(prompt_audio_speaker2, str):
            speaker2_audio = os.path.join(base_path, prompt_audio_speaker2) if base_path and prompt_audio_speaker2 else prompt_audio_speaker2
        else:
            speaker2_audio = prompt_audio_speaker2  # Use tuple directly

        prompt_audio = {
            "speaker1": speaker1_audio,
            "speaker2": speaker2_audio
        }

        # Merge the prompt texts
        prompt_text = ""
        if prompt_text_speaker1:
            prompt_text += f"[S1]{prompt_text_speaker1}"
        if prompt_text_speaker2:
            prompt_text += f"[S2]{prompt_text_speaker2}"
        prompt_text = prompt_text.strip()

    return {
        "text": text,
        "prompt_text": prompt_text,
        "prompt_audio": prompt_audio
    }
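

# Illustrative sketch, not from the original file: a made-up JSONL record in the
# per-speaker format. process_jsonl_item joins the (hypothetical) file names with
# base_path and merges the two prompt texts into one "[S1]...[S2]..." string.
def _example_process_jsonl_item():
    item = {
        "base_path": "/data/prompts",            # hypothetical directory
        "text": "[S1]今天我们聊什么?[S2]聊聊播客吧。",
        "prompt_audio_speaker1": "spk1.wav",     # hypothetical file
        "prompt_text_speaker1": "说话人一的参考文本。",
        "prompt_audio_speaker2": "spk2.wav",     # hypothetical file
        "prompt_text_speaker2": "说话人二的参考文本。",
    }
    processed = process_jsonl_item(item)
    # processed["prompt_audio"] is {"speaker1": "/data/prompts/spk1.wav", ...} and
    # processed["prompt_text"] is "[S1]说话人一的参考文本。[S2]说话人二的参考文本。"
    return processed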


def load_audio_data(prompt_audio, target_sample_rate=16000):
    """Load audio data and return a processed audio tensor.

    Args:
        prompt_audio: One of the following formats:
            - str: audio file path
            - tuple: a (wav, sr) pair as returned by torchaudio.load
            - dict: {"speaker1": path_or_tuple, "speaker2": path_or_tuple}
    """
    if prompt_audio is None:
        return None

    try:
        # Check whether prompt_audio is a dict containing speaker1 and speaker2
        if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
            # Load the audio of each speaker separately
            wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
            wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
            # Merge the audio of both speakers
            wav = merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate)
            if wav is None:
                return None
        else:
            # Single audio
            wav, sr = _load_single_audio(prompt_audio)
            # Resample to 16 kHz
            if sr != target_sample_rate:
                wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
            # Ensure mono: average multi-channel audio down to one channel
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            if len(wav.shape) == 1:
                wav = wav.unsqueeze(0)
        return wav
    except Exception as e:
        print(f"Error loading audio data: {e}")
        raise


def _load_single_audio(audio_input):
    """Load a single audio clip from a file path or a (wav, sr) tuple.

    Args:
        audio_input: str (file path) or tuple (wav, sr)

    Returns:
        tuple: (wav, sr)
    """
    if isinstance(audio_input, tuple) and len(audio_input) == 2:
        # Already a (wav, sr) tuple
        wav, sr = audio_input
        return wav, sr
    elif isinstance(audio_input, str):
        # A file path that still needs to be loaded
        wav, sr = torchaudio.load(audio_input)
        return wav, sr
    else:
        raise ValueError(f"Unsupported audio input format: {type(audio_input)}")


def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
    """Merge the audio data of two speakers into a single waveform."""
    try:
        # Process the first audio clip
        if sr1 != target_sample_rate:
            wav1 = torchaudio.functional.resample(wav1, sr1, target_sample_rate)
        # Ensure mono: average multi-channel audio down to one channel
        if wav1.shape[0] > 1:
            wav1 = wav1.mean(dim=0, keepdim=True)
        if len(wav1.shape) == 1:
            wav1 = wav1.unsqueeze(0)

        # Process the second audio clip
        if sr2 != target_sample_rate:
            wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
        # Ensure mono: average multi-channel audio down to one channel
        if wav2.shape[0] > 1:
            wav2 = wav2.mean(dim=0, keepdim=True)
        if len(wav2.shape) == 1:
            wav2 = wav2.unsqueeze(0)

        # Concatenate the two clips along the time axis
        merged_wav = torch.cat([wav1, wav2], dim=1)
        return merged_wav
    except Exception as e:
        print(f"Error merging audio: {e}")
        raise
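

# Illustrative sketch, not from the original file: load_audio_data accepts either a single
# path/tuple or the {"speaker1": ..., "speaker2": ...} dict built by process_jsonl_item; in
# both cases the result is a mono 16 kHz tensor of shape (1, samples). File names are hypothetical.
def _example_load_audio_data():
    single = load_audio_data("path/to/prompt.wav")
    paired = load_audio_data({"speaker1": "path/to/spk1.wav", "speaker2": "path/to/spk2.wav"})
    return single, paired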


def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, max_channels=8, pad_token=1024):
    seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
    inputs1 = np.array(tokenizer.encode(seq))
    input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
    input_ids[:, 0] = inputs1

    if audio_data is not None:
        try:
            # audio_data should already be a processed audio tensor
            wav = audio_data
            # Append SILENCE_DURATION seconds of silence at the end of the audio (at the 16 kHz sample rate)
            silence_samples = int(SILENCE_DURATION * 16000)
            silence = torch.zeros(wav.shape[0], silence_samples)
            wav = torch.cat([wav, silence], dim=1)

            with torch.no_grad():
                # Encode with the SPT (XY_Tokenizer) codec
                encode_result = spt.encode([wav.squeeze().to(device)])
                audio_token = encode_result["codes_list"][0].permute(1, 0).cpu().numpy()  # Adjust dimension order
            # Offset channel 0 into the text tokenizer's vocabulary range, similar to the DAC
            # encoding adjustment (the same 151665 offset is subtracted again after generation)
            audio_token[:, 0] = audio_token[:, 0] + 151665
            input_ids = np.concatenate([input_ids, audio_token])
        except Exception as e:
            print(f"Error processing audio data: {e}")
            raise
    return input_ids


def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
    seq_len = input_ids.shape[0]
    new_seq_len = seq_len + max_channels - 1
    shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
    shifted_input_ids[:, 0] = np.full(new_seq_len, tokenizer.pad_token_id, dtype=np.int64)
    for i in range(max_channels):
        shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
    return shifted_input_ids
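

# Illustrative sketch, not from the original file: shows the delay pattern produced by
# shifting_inputs on a toy 3-step, 4-channel input. The toy tokenizer only provides the
# pad_token_id attribute that shifting_inputs reads, and its value 0 is a made-up stand-in.
def _example_shifting_inputs():
    from types import SimpleNamespace

    toy_tokenizer = SimpleNamespace(pad_token_id=0)   # hypothetical pad id
    toy_inputs = np.arange(12).reshape(3, 4)          # (seq_len=3, channels=4)
    shifted = shifting_inputs(toy_inputs, toy_tokenizer, pad_token=1024, max_channels=4)
    # Channel j now starts j rows later; unused slots hold 1024, or the text pad id in channel 0.
    print(shifted)
    return shifted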


def rpadding(input_ids, channels, tokenizer):
    attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
    max_length = max(ids.shape[0] for ids in input_ids)

    padded_input_ids, padded_attns = [], []
    for ids, attn in zip(input_ids, attention_masks):
        pad_len = max_length - ids.shape[0]
        input_pad = np.full((pad_len, channels), 1024)
        input_pad[:, 0] = tokenizer.pad_token_id
        padded_input_ids.append(np.concatenate([input_pad, ids]))
        attn_pad = np.zeros(pad_len)
        padded_attns.append(np.concatenate([attn_pad, attn]))

    input_ids = torch.tensor(np.stack(padded_input_ids))
    attention_mask = torch.tensor(np.stack(padded_attns))
    return input_ids, attention_mask
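

# Illustrative sketch, not from the original file: rpadding left-pads a ragged batch to a
# common length. The toy tokenizer only provides the pad_token_id attribute that rpadding
# uses, and its value 0 is a made-up stand-in.
def _example_rpadding():
    from types import SimpleNamespace

    toy_tokenizer = SimpleNamespace(pad_token_id=0)   # hypothetical pad id
    batch = [np.ones((3, MAX_CHANNELS), dtype=np.int64),
             np.ones((5, MAX_CHANNELS), dtype=np.int64)]
    ids, mask = rpadding(batch, MAX_CHANNELS, toy_tokenizer)
    print(ids.shape, mask.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5])
    return ids, mask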


def find_max_valid_positions(C: torch.Tensor, invalid_value=1024) -> torch.Tensor:
    values = C[:, :, 1]
    mask = (values != invalid_value)
    reversed_mask = mask.flip(dims=[1])
    reversed_indices = torch.argmax(reversed_mask.int(), dim=1)
    seq_len = C.size(1)
    original_indices = seq_len - 1 - reversed_indices
    has_valid = mask.any(dim=1)
    original_indices = torch.where(has_valid, original_indices, -1)
    return original_indices
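

# Illustrative sketch, not from the original file: find_max_valid_positions scans channel 1
# of a (batch, seq, channels) tensor and returns, per sample, the last index that is not the
# 1024 pad code (or -1 if the sample has none).
def _example_find_max_valid_positions():
    toy = torch.full((1, 5, 8), 1024)
    toy[0, :3, 1] = torch.tensor([7, 8, 9])  # the first three steps hold real codes
    positions = find_max_valid_positions(toy)
    print(positions)  # tensor([2])
    return positions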


def normalize_text(text: str) -> str:
    """
    Normalize a multi-speaker script.
    1. Do not preserve line breaks.
    2. For non-speaker tags, remove the brackets themselves (if [] does not contain an S1/S2/.../Sx tag, keep the content and drop only the brackets).
    3. Remove decorative symbols: 【】《》()『』「」"-“”.
    4. Internal punctuation !;:、 → ,; only allow ? and ,。
    5. When there are multiple 。, keep only the last one; the others → ,。
    6. Replace runs of two or more consecutive "哈" with "(笑)".
    7. Auto-recognize [S1] / [S2] ... tags; if they are missing, treat the text as a single segment.
    """
    # Rewrite [1], [2], ... style tags as [S1], [S2], ...
    text = re.sub(r'\[(\d+)\]', r'[S\1]', text)

    # Decorative characters to remove
    remove_chars = "【】《》()『』「」\"-“”"

    # For non-speaker tags, keep the content and remove only the brackets
    text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)

    # Split the text at speaker tags using a lookahead so the tags themselves are preserved
    segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
    normalized_lines = []

    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        # Extract the speaker tag
        m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
        tag, content = m.groups() if m else ('', seg)
        # Remove decorative symbols
        content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
        # Replace runs of two or more consecutive "哈" with "(笑)"
        content = re.sub(r'哈{2,}', '(笑)', content)
        # Handle multi-character punctuation marks first
        content = content.replace('——', ',')
        content = content.replace('……', ',')
        # Map single-character internal punctuation to commas
        internal_punct_map = str.maketrans({
            '！': '，', '!': ',',
            '；': '，', ';': ',',
            '：': '，', ':': ',',
            '、': '，',
            '？': '，', '?': ','
        })
        content = content.translate(internal_punct_map)
        content = content.strip()

        # Keep only the final period: a trailing comma becomes a period, internal 。 become ，
        if len(content) > 1:
            last_ch = "。" if content[-1] == "，" else ("." if content[-1] == "," else content[-1])
            body = content[:-1].replace('。', '，')
            content = body + last_ch

        normalized_lines.append(f"{tag}{content}".strip())

    return "".join(normalized_lines)


def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, start_idx, use_normalize=False):
    """Process a batch of data items, generate audio, and return the audio data together with its metadata."""
    try:
        # Prepare the batch data
        batch_size = len(batch_items)
        texts = []
        prompts = [system_prompt] * batch_size
        prompt_audios = []
        actual_texts_data = []  # Store the text that is actually used

        print(f"Processing {batch_size} samples starting from index {start_idx}...")

        # Extract text and audio from each sample
        for i, item in enumerate(batch_items):
            # Use the new processing function
            processed_item = process_jsonl_item(item)

            text = processed_item["text"]
            prompt_text = processed_item["prompt_text"]

            # Merge prompt text and target text
            full_text = prompt_text + text
            original_full_text = full_text  # Keep the original text

            # Apply text normalization if requested
            if use_normalize:
                full_text = normalize_text(full_text)

            # Replace speaker tags with the model's speaker tokens
            final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")
            texts.append(final_text)

            # Record the text variants that were actually used
            actual_texts_data.append({
                "index": start_idx + i,
                "original_text": original_full_text,
                "normalized_text": normalize_text(original_full_text) if use_normalize else None,
                "final_text": final_text,
                "use_normalize": use_normalize
            })

            # Collect the reference audio
            prompt_audios.append(processed_item["prompt_audio"])

        # Build the model inputs
        input_ids_list = []
        for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)):
            # Load the audio data here
            audio_data = load_audio_data(audio_path) if audio_path else None
            inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data)
            inputs = shifting_inputs(inputs, tokenizer)
            input_ids_list.append(inputs)

        # Pad the batch inputs to a common length
        input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)

        # Batch generation
        print("Starting batch audio generation...")
        start = input_ids.shape[1] - MAX_CHANNELS + 1

        # Move the inputs to the GPU
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Generate the model outputs
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        print(f"Original outputs shape: {outputs.shape}")
        print(f"Start value: {start}")
        print(f"Shape after slicing: {outputs[:, start:].shape}")
        print(f"MAX_CHANNELS: {MAX_CHANNELS}")
        print(f"Calculated seq_len: {outputs.shape[1] - MAX_CHANNELS + 1}")

        # Post-process the outputs
        outputs = outputs[:, start:]
        seq_len = outputs.shape[1] - MAX_CHANNELS + 1
        speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)

        # Undo the delay pattern channel by channel; channel 0 also drops the 151665 text-vocabulary offset
        for j in range(MAX_CHANNELS):
            speech_ids[..., j] = outputs[:, j : seq_len + j, j]
            if j == 0:
                speech_ids[..., j] = speech_ids[..., j] - 151665

        # Find the last valid position of each sample
        li = find_max_valid_positions(speech_ids)

        # Collect the audio results
        audio_results = []

        # Process each sample in the batch individually
        for i in range(batch_size):
            try:
                # Extract the valid speech tokens
                end_idx = li[i] + 1
                if end_idx <= 0:
                    print(f"Sample {start_idx + i} has no valid speech tokens")
                    audio_results.append(None)
                    continue

                this_speech_id = speech_ids[i, :end_idx]
                print(f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}")

                # Decode the generated audio
                with torch.no_grad():
                    codes_list = [this_speech_id.permute(1, 0)]  # Convert to the format expected by SPT
                    decode_result = spt.decode(codes_list, overlap_seconds=10)
                    audio_result = decode_result["syn_wav_list"][0].cpu().detach()
                    if audio_result.ndim == 1:  # If 1D [samples]
                        audio_result = audio_result.unsqueeze(0)  # Convert to 2D [1, samples]

                # Keep the audio data itself instead of a file path
                audio_results.append({
                    "audio_data": audio_result,
                    "sample_rate": spt.output_sample_rate,
                    "index": start_idx + i
                })
                print(f"Audio generation completed: sample {start_idx + i}")
            except Exception as e:
                print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
                import traceback
                traceback.print_exc()
                audio_results.append(None)

        # Free GPU memory
        torch.cuda.empty_cache()

        # Return the text metadata and the audio data
        return actual_texts_data, audio_results
    except Exception as e:
        print(f"Error during batch processing: {str(e)}")
        raise
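

# Minimal driver sketch, not part of the original file: it assumes a JSONL file whose lines
# are items in the format handled by process_jsonl_item, hypothetical model/tokenizer paths,
# an empty style prompt, and writes one WAV per successfully generated sample.
if __name__ == "__main__":
    import json

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model, spt = load_model(
        "path/to/model_dir",                 # hypothetical
        "path/to/xy_tokenizer_config.yaml",  # hypothetical
        "path/to/xy_tokenizer.ckpt",         # hypothetical
    )
    model = model.to(device)
    spt = spt.to(device)

    with open("path/to/items.jsonl", "r", encoding="utf-8") as f:  # hypothetical input
        items = [json.loads(line) for line in f if line.strip()]

    texts_data, audio_results = process_batch(
        items, tokenizer, model, spt, device,
        system_prompt="",   # assumption: no extra style prompt
        start_idx=0,
        use_normalize=True,
    )
    for result in audio_results:
        if result is not None:
            torchaudio.save(f"output_{result['index']}.wav",
                            result["audio_data"], result["sample_rate"])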