"""Utility helpers for audio serving: UUIDs, base64/audio encoding, and text
normalization (Chinese/English sentence splitting, punctuation and emoji cleanup)."""
import base64
import io
import re
import uuid
from functools import lru_cache
from typing import AsyncGenerator, Union

import numpy as np
import regex
import torch
from pydub import AudioSegment

from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer
def random_uuid() -> str:
    """Return a random 128-bit UUID as a 32-character lowercase hex string."""
    # uuid4().hex is already a str; the original wrapped it in a redundant str().
    return uuid.uuid4().hex
async def async_generator_wrap(first_element, gen: AsyncGenerator):
    """Yield ``first_element`` first, then every item produced by ``gen``."""
    yield first_element
    async for element in gen:
        yield element
def encode_base64_content_from_file(file_path: str) -> str:
    """Read a local file and return its raw bytes as a base64 string."""
    # Read as binary so any content (e.g. MP3 audio) round-trips unchanged.
    with open(file_path, "rb") as fh:
        raw = fh.read()
    # b64encode yields bytes; decode to hand back a plain str payload.
    return base64.b64encode(raw).decode("utf-8")
def pcm16_to_target_format(
    np_audio: np.ndarray,
    sample_rate: int,
    bit_depth: int,
    channels: int,
    format: str,
    target_rate: int,
):
    """Package raw PCM samples into an encoded audio stream via pydub.

    The samples are wrapped into an ``AudioSegment``, optionally resampled
    to ``target_rate``, then exported in ``format`` (e.g. "mp3", "wav").
    Returns an ``io.BytesIO`` positioned at the start of the encoded data.
    """
    segment = AudioSegment(
        data=np_audio.tobytes(),
        frame_rate=sample_rate,
        sample_width=bit_depth // 8,  # bit depth -> bytes per sample
        channels=channels,
    )
    # Resample only when a different target rate was explicitly requested.
    if target_rate not in (None, sample_rate):
        segment = segment.set_frame_rate(target_rate)
    encoded = io.BytesIO()
    segment.export(encoded, format=format)
    encoded.seek(0)  # rewind so callers can read from the beginning
    return encoded
# Matches one or more characters in the CJK Unified Ideographs block.
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")


def contains_chinese(text: str):
    """Return True when ``text`` holds at least one Chinese character."""
    return chinese_char_pattern.search(text) is not None
# remove blank between chinese character
def replace_blank(text: str):
    """Remove spaces that touch non-ASCII (e.g. Chinese) characters.

    A space is kept only when both of its *existing* neighbours are
    non-space ASCII characters; a missing neighbour at a string boundary
    does not veto keeping the space.
    """

    def _ascii_non_space(ch: str) -> bool:
        return ch.isascii() and ch != " "

    out_str = []
    last = len(text) - 1
    for i, c in enumerate(text):
        if c == " ":
            # Bug fix: the original read text[i + 1] unconditionally
            # (IndexError on a trailing space) and text[i - 1] (which wraps
            # to the LAST character for a leading space). Boundary
            # neighbours now simply default to "allow keeping the space".
            prev_ok = _ascii_non_space(text[i - 1]) if i > 0 else True
            next_ok = _ascii_non_space(text[i + 1]) if i < last else True
            if prev_ok and next_ok:
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)
def replace_corner_mark(text: str):
    """Spell out superscript exponent marks in Chinese (² → 平方, ³ → 立方)."""
    for mark, spoken in (("²", "平方"), ("³", "立方")):
        text = text.replace(mark, spoken)
    return text
# remove meaningless symbol
def remove_bracket(text: str):
    """Strip bracket/backtick symbols with no spoken value and turn the
    Chinese double em-dash ("——") into a single space."""
    # Bug fix: the original chained two byte-identical "`" replacements;
    # one suffices. NOTE(review): the second may originally have been a
    # full-width backtick lost in transit — confirm against upstream.
    for symbol in ("(", ")", "【", "】", "`"):
        text = text.replace(symbol, "")
    return text.replace("——", " ")
# split paragraph logic:
# 1. per sentence max len token_max_n, min len token_min_n; merge a short
#    trailing chunk (shorter than merge_len) into the previous chunk
# 2. sentence length is measured per language (characters for zh, tokens otherwise)
# 3. sentences are cut at punctuation marks
def split_paragraph(
    text: str,
    tokenize,
    lang="zh",
    token_max_n=80,
    token_min_n=60,
    merge_len=20,
    comma_split=False,
):
    """Split ``text`` into utterance chunks bounded by punctuation and length.

    Args:
        text: paragraph to split.
        tokenize: callable used to measure length for non-"zh" languages.
        lang: "zh" counts characters; anything else counts ``tokenize`` tokens.
        token_max_n: soft upper bound for a chunk's length.
        token_min_n: a chunk is only closed once it already exceeds this.
        merge_len: a short final chunk below this length is merged backwards.
        comma_split: additionally cut at full-/half-width commas.

    Returns:
        list[str]: the utterance chunks (empty list for empty input).
    """

    def calc_utt_length(_text: str):
        # zh measures characters; other languages measure tokens.
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        # Same metric as calc_utt_length, compared against merge_len.
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    # Bug fix: the original indexed text[-1] unconditionally, raising
    # IndexError on empty input; an empty paragraph has no utterances.
    if not text:
        return []

    # NOTE(review): the mangled source showed duplicate ASCII "?", "!", ";"
    # here; reconstructed as the full-width ？！；： the duplicates imply.
    if lang == "zh":
        pounc = ["。", "？", "！", "；", "：", "、", ".", "?", "!", ";"]
    else:
        pounc = [".", "?", "!", ";", ":"]
    if comma_split:
        pounc.extend(["，", ","])

    # Guarantee a trailing terminator so the scan closes the last sentence.
    if text[-1] not in pounc:
        if lang == "zh":
            text += "。"
        else:
            text += "."

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
                # Keep a closing quote attached to its sentence.
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1

    # Pack sentences into chunks: close a chunk only when adding the next
    # sentence would exceed token_max_n AND the chunk already passed token_min_n.
    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts
# Matches strings made only of Unicode punctuation (\p{P}) or symbols (\p{S});
# the * quantifier means the empty string matches too.
_PUNCTUATION_ONLY_RE = regex.compile(r"^[\p{P}\p{S}]*$")


def is_only_punctuation(text: str):
    """Return True when ``text`` is empty or contains only punctuation/symbols."""
    # Compiled once at import time instead of re-resolving the pattern per call.
    return bool(_PUNCTUATION_ONLY_RE.fullmatch(text))
# spell Arabic numerals
def spell_out_number(text: str, inflect_parser):
    """Replace every run of decimal digits in ``text`` with its spelled-out
    form from ``inflect_parser.number_to_words`` (e.g. "12" -> "twelve")."""
    pieces = []
    run_start = None  # index where the current digit run began, or None
    for idx, ch in enumerate(text):
        if ch.isdigit():
            if run_start is None:
                run_start = idx
            continue
        if run_start is not None:
            # A digit run just ended; spell it out before the current char.
            pieces.append(inflect_parser.number_to_words(text[run_start:idx]))
            run_start = None
        pieces.append(ch)
    # Flush a digit run that reaches the end of the string.
    if run_start is not None and run_start < len(text):
        pieces.append(inflect_parser.number_to_words(text[run_start:]))
    return "".join(pieces)
# Pattern to match emojis and their modifiers:
# - every astral-plane character (U+10000..U+10FFFF, which covers the emoji
#   blocks but also strips any other supplementary-plane text)
# - zero-width joiner (U+200D)
# - variation selectors (U+FE0F, U+FE0E)
# - skin tone modifiers (U+1F3FB..U+1F3FF)
_EMOJI_RE = re.compile(
    r"["
    r"\U00010000-\U0010FFFF"  # astral-plane range used for emoji
    r"\u200D"  # zero-width joiner
    r"\uFE0F\uFE0E"  # variation selectors
    r"\U0001F3FB-\U0001F3FF"  # skin tone modifiers
    r"]+",
    flags=re.UNICODE,
)


def remove_emoji(text: str):
    """Strip emoji and emoji-modifier characters from ``text``.

    The pattern is now compiled once at import time; the original rebuilt
    it on every call.
    """
    return _EMOJI_RE.sub("", text)
def remove_repeated_punctuations(text, punctuations):
    """Collapse consecutive repeats of any listed punctuation mark into one."""
    if not punctuations:
        return text
    # Character class of all target marks, escaped for regex safety.
    char_class = "[" + re.escape("".join(punctuations)) + "]"
    # The back-reference \1 only collapses runs of the SAME character.
    return re.sub(f"({char_class})\\1+", r"\1", text)
def full_to_half_width(text: str) -> str:
    """Convert full-width ASCII punctuation to its half-width equivalent.

    Letters, digits, and non-ASCII punctuation (e.g. "。") pass through
    unchanged.
    """
    # Bug fix: the full-width literal had been garbled to plain ASCII in the
    # mangled source (a syntax error / no-op table). Reconstructed as the
    # U+FF01..U+FF5E punctuation run; both tables are 32 characters long,
    # as str.maketrans requires.
    full_width = "！＂＃＄％＆＇（）＊＋，－．／：；＜＝＞？＠［＼］＾＿｀｛｜｝～"
    half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    trans_table = str.maketrans(full_width, half_width)
    return text.translate(trans_table)
def split_interleaved_delayed_audios(
    audio_data: Union[list[list[int]], torch.Tensor],
    audio_tokenizer: HiggsAudioTokenizer,
    audio_stream_eos_id: int,
) -> list[tuple[list[list[int]], torch.Tensor]]:
    """Split an interleaved audio-token stream into per-utterance groups at
    EOS separator frames.

    A separator frame is ``audio_stream_eos_id`` repeated once per codebook
    (``audio_tokenizer.num_codebooks`` — the only attribute read from the
    tokenizer).

    NOTE(review): the two branches disagree — the tensor branch EXCLUDES
    the separator frame from each group, while the list branch INCLUDES it;
    confirm which behavior callers rely on. Also, the declared return type
    mentions tuples, but both branches return a flat list of groups.
    """
    # One EOS token per codebook forms a full separator frame.
    separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks
    if isinstance(audio_data, torch.Tensor):
        # Presumably (num_codebooks, T) -> (T, num_codebooks), so each row is
        # one time step — TODO confirm input layout against callers.
        audio_data = audio_data.transpose(1, 0)
        separator = torch.tensor(separator)
        # Indices of rows that exactly equal the separator frame.
        split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
        start = 0
        groups = []
        for idx in split_indices:
            # Separator row itself is dropped; group transposed back to
            # (num_codebooks, t_i).
            groups.append(audio_data[start:idx].transpose(1, 0))
            start = idx + 1
        # Trailing audio after the last separator becomes a final group.
        if start < len(audio_data):
            groups.append(audio_data[start:].transpose(1, 0))
    else:
        groups = []
        current = []
        for row in audio_data:
            current.append(row)
            # List-equality check: this branch keeps the separator row
            # inside the emitted group (unlike the tensor branch above).
            if row == separator:
                groups.append(current)
                current = []
        # Don't forget the last group if there's no trailing separator
        if current:
            groups.append(current)
    return groups