zachzzc's picture
Add more voice clone voices; Update model names; Update playground
b4da283
import asyncio
import base64
from io import BytesIO
from dataclasses import dataclass, field
from typing import List, Optional, Union
from copy import deepcopy
from dataclasses import asdict
import threading

import torch
import numpy as np
from transformers import AutoTokenizer, AutoProcessor
from transformers.cache_utils import StaticCache
from transformers.generation.streamers import BaseStreamer
from transformers.generation.stopping_criteria import StoppingCriteria
from loguru import logger
import librosa

from ..dataset.chatml_dataset import (
    ChatMLSample,
    ChatMLDatasetSample,
    prepare_chatml_sample,
)
from ..model import HiggsAudioModel
from ..model.utils import revert_delay_pattern
from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
# Mapping of Chinese / full-width punctuation to English (half-width) punctuation.
# Keys are written as \u escapes on purpose: with literal characters, an editor or a
# Unicode (NFKC) normalization pass can silently turn e.g. "\uff0c" -> "," into a
# no-op "," -> "," entry, which is exactly how this table previously broke.
_CHINESE_TO_ENGLISH_PUNCT = {
    "\uff0c": ",",  # full-width comma
    "\u3002": ".",  # ideographic full stop
    "\uff1a": ":",  # full-width colon
    "\uff1b": ";",  # full-width semicolon
    "\uff1f": "?",  # full-width question mark
    "\uff01": "!",  # full-width exclamation mark
    "\uff08": "(",  # full-width left parenthesis
    "\uff09": ")",  # full-width right parenthesis
    "\u3010": "[",  # left black lenticular bracket
    "\u3011": "]",  # right black lenticular bracket
    "\u300a": "<",  # left double angle bracket
    "\u300b": ">",  # right double angle bracket
    "\u201c": '"',  # left double quotation mark
    "\u201d": '"',  # right double quotation mark
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
    "\u3001": ",",  # ideographic (enumeration) comma
    "\u2014": "-",  # em dash
    "\u2026": "...",  # horizontal ellipsis (expands to three characters)
    "\u00b7": ".",  # middle dot
    "\u300c": '"',  # left corner bracket
    "\u300d": '"',  # right corner bracket
    "\u300e": '"',  # left white corner bracket
    "\u300f": '"',  # right white corner bracket
}
# Built once at import time; every key is a single character, so a single
# str.translate pass is equivalent to replacing each key in turn (no replacement
# value contains a key).
_CHINESE_TO_ENGLISH_PUNCT_TABLE = str.maketrans(_CHINESE_TO_ENGLISH_PUNCT)


def normalize_chinese_punctuation(text):
    """
    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.

    Args:
        text: Input string.

    Returns:
        A new string in which every punctuation character listed in
        `_CHINESE_TO_ENGLISH_PUNCT` is replaced by its ASCII counterpart (the
        ellipsis expands to "..."); all other characters are left unchanged.
    """
    return text.translate(_CHINESE_TO_ENGLISH_PUNCT_TABLE)
@dataclass
class HiggsAudioStreamerDelta:
    """
    A single streamed chunk of generation output.

    A text chunk carries `text` together with the `text_tokens` it was decoded
    from; an audio chunk carries `audio_tokens`. Every field defaults to None.
    """

    # Decoded text of this chunk (text chunks only).
    text: Union[str, None] = None
    # Token ids that `text` was decoded from (text chunks only).
    text_tokens: Union[torch.Tensor, None] = None
    # Raw audio token ids of this chunk (audio chunks only).
    audio_tokens: Union[torch.Tensor, None] = None
    # Optional finish reason; left as None unless a producer sets it.
    finish_reason: Union[str, None] = None
class AsyncHiggsAudioStreamer(BaseStreamer):
    """
    Async streamer that handles both text and audio token generation from Higgs-Audio model.
    Stores chunks in an `asyncio.Queue` to be consumed by downstream applications.

    The streamer captures the running event loop at construction time
    (`asyncio.get_running_loop()`), so it must be created from inside a coroutine.
    `put` and `end` hand items to that loop with `call_soon_threadsafe`, so they may be
    called from another thread (e.g. the thread running `model.generate`).

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode text tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to drop the prompt (the first `put` call) instead of emitting it as text.
        timeout (`float`, *optional*):
            The timeout for the queue. If `None`, the queue will block indefinitely.
        audio_num_codebooks (`int`, *optional*, defaults to 1):
            Expected first-dimension size of every audio token chunk.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:
        ```python
        >>> from transformers import AutoTokenizer
        >>> from threading import Thread
        >>> import asyncio
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
        >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
        >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")
        >>> async def main():
        ...     streamer = AsyncHiggsAudioStreamer(tokenizer)
        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
        ...     thread.start()
        ...
        ...     async for delta in streamer:
        ...         if delta.text is not None:
        ...             print("Text:", delta.text)
        ...         if delta.audio_tokens is not None:
        ...             print("Audio tokens shape:", delta.audio_tokens.shape)
        >>> asyncio.run(main())
        ```
    """

    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        audio_num_codebooks: int = 1,
        **decode_kwargs,
    ):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs
        self.audio_num_codebooks = audio_num_codebooks

        # Queue to store generated chunks
        self.queue = asyncio.Queue()
        # Sentinel enqueued by `end`; compared by identity in `__anext__`.
        self.stop_signal = None

        # Get running event loop (raises RuntimeError if no loop is running)
        self.loop = asyncio.get_running_loop()
        # asyncio.timeout only exists on newer Python versions; fall back to wait_for.
        self.has_asyncio_timeout = hasattr(asyncio, "timeout")

        # True until the first `put` (the prompt) has been seen; reset by `end`.
        self.next_tokens_are_prompt = True

    def put(self, value: torch.Tensor):
        """
        Receive a chunk of tokens from the generation loop and enqueue it as a delta.

        - The first call (after construction or after `end`) is the prompt: it is
          dropped when `skip_prompt` is set, otherwise decoded and emitted as text.
        - Afterwards, a chunk whose first dimension is larger than 1 is treated as
          audio tokens and enqueued unchanged.
        - Anything else is decoded as text immediately (no buffering across calls);
          for a 2-D chunk only the first row is decoded.

        Raises:
            ValueError: if an audio chunk's first dimension is not `audio_num_codebooks`.
        """
        if self.next_tokens_are_prompt:
            # Clear the flag whether or not the prompt is skipped. It used to be cleared
            # only when skipping, which made the audio branch below unreachable with
            # skip_prompt=False and caused audio codes to be decoded as text.
            self.next_tokens_are_prompt = False
            if self.skip_prompt:
                return
        elif value.shape[0] > 1:
            # Presumably audio tokens of shape [audio_num_codebooks] -- TODO confirm with model.generate
            if value.shape[0] != self.audio_num_codebooks:
                raise ValueError(
                    f"Number of codebooks mismatch: expected {self.audio_num_codebooks}, got {value.shape[0]}"
                )
            delta = HiggsAudioStreamerDelta(audio_tokens=value)
            self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
            return

        # Process as text tokens
        if len(value.shape) > 1:
            value = value[0]
        text = self.tokenizer.decode(value, **self.decode_kwargs)
        delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
        self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)

    def end(self):
        """Signal the end of generation and re-arm prompt detection for the next run."""
        self.next_tokens_are_prompt = True
        self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """
        Return the next delta from the queue.

        Raises:
            TimeoutError: if no item arrives within `timeout` seconds.
            StopAsyncIteration: once the stop sentinel enqueued by `end` is received.
        """
        try:
            if self.has_asyncio_timeout:
                async with asyncio.timeout(self.timeout):
                    value = await self.queue.get()
            else:
                value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
        except asyncio.TimeoutError:
            raise TimeoutError()
        if value is self.stop_signal:
            raise StopAsyncIteration()
        return value
class AsyncStoppingCriteria(StoppingCriteria):
    """
    Stopping criterion backed by a shared `threading.Event`.

    Each time the criterion is called it reports whether the event has been set,
    so another thread can request a stop simply by calling `event.set()`.

    Args:
        stop_signal (threading.Event): The event whose state decides whether to stop.
    """

    def __init__(self, stop_signal: threading.Event):
        self.stop_signal = stop_signal

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # input_ids / scores are part of the StoppingCriteria call signature but unused here.
        requested = self.stop_signal.is_set()
        if requested:
            logger.info("Stop signal received. Can be caused by client disconnection.")
        return requested
@dataclass
class HiggsAudioResponse:
    """Result container returned by `HiggsAudioServeEngine.generate`."""

    # Decoded waveform, or None when no audio was generated.
    audio: Optional[np.ndarray] = None
    # Raw generated audio token ids, or None when no audio was generated.
    generated_audio_tokens: Optional[np.ndarray] = None
    # Sampling rate of `audio`.
    sampling_rate: Optional[int] = None
    # Decoded generated text.
    generated_text: str = ""
    # Generated text token ids. Uses default_factory: a plain `np.array([])` default is an
    # unhashable mutable default, which dataclasses reject (ValueError) on Python >= 3.11
    # and which would otherwise be one array shared by every instance.
    generated_text_tokens: np.ndarray = field(default_factory=lambda: np.array([]))
    # Token accounting (prompt/completion/total/cached counts).
    usage: Optional[dict] = None
class HiggsAudioServeEngine:
    """
    Serving wrapper around `HiggsAudioModel`.

    Owns the model, the text tokenizer, the audio tokenizer, a set of pre-allocated
    `StaticCache` KV caches (one per configured length) and the sample collator.
    `generate` holds an internal `threading.Lock`, so only one generation runs at a time.
    """

    def __init__(
        self,
        model_name_or_path: str,
        audio_tokenizer_name_or_path: str,
        tokenizer_name_or_path: Optional[str] = None,
        device: str = "cuda",
        torch_dtype: Union[torch.dtype, str] = "auto",
        kv_cache_lengths: List[int] = [1024, 4096, 8192],  # Multiple KV cache sizes (only read, never mutated)
    ):
        """
        Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
        The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.

        Args:
            model_name_or_path (str):
                The name or path of the model to load.
            audio_tokenizer_name_or_path (str):
                The name or path of the audio tokenizer to load.
            tokenizer_name_or_path (str, optional):
                The name or path of the tokenizer to load. Defaults to `model_name_or_path`.
            device (str):
                The device to use for the model.
            torch_dtype (Union[torch.dtype, str]):
                The dtype to use for the model.
            kv_cache_lengths (List[int]):
                The lengths of the KV caches to allocate. CUDA graphs are captured for each
                of them when `device == "cuda"`.
        """
        self.device = device
        self.model_name_or_path = model_name_or_path
        self.torch_dtype = torch_dtype

        # Initialize model and tokenizer
        self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
        logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")

        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path
        logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        logger.info("Initializing Higgs Audio Tokenizer")
        self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)

        self.audio_num_codebooks = self.model.config.audio_num_codebooks
        self.audio_codebook_size = self.model.config.audio_codebook_size
        self.audio_tokenizer_tps = self.audio_tokenizer.tps
        self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
        self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token

        # Set the audio special tokens
        self.model.set_audio_special_tokens(self.tokenizer)

        # Cache config: the text backbone's layers plus one extra layer per audio dual-FFN layer.
        cache_config = deepcopy(self.model.config.text_config)
        cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
        if self.model.config.audio_dual_ffn_layers:
            cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)

        # One pre-allocated static KV cache per requested length, keyed by length (ascending).
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
        }

        if self.model.config.encode_whisper_embed:
            logger.info("Loading whisper processor")
            whisper_processor = AutoProcessor.from_pretrained(
                "openai/whisper-large-v3-turbo",
                # The documented from_pretrained kwarg is `trust_remote_code`; the previous
                # `trust_remote` spelling was not recognised as that option.
                trust_remote_code=True,
                device=self.device,
            )
        else:
            whisper_processor = None

        # Reuse collator to prepare inference samples
        self.collator = HiggsAudioSampleCollator(
            whisper_processor=whisper_processor,
            encode_whisper_embed=self.model.config.encode_whisper_embed,
            audio_in_token_id=self.model.config.audio_in_token_idx,
            audio_out_token_id=self.model.config.audio_out_token_idx,
            audio_stream_bos_id=self.model.config.audio_stream_bos_id,
            audio_stream_eos_id=self.model.config.audio_stream_eos_id,
            pad_token_id=self.model.config.pad_token_id,
            return_audio_in_tokens=False,
            use_delay_pattern=self.model.config.use_delay_pattern,
            audio_num_codebooks=self.model.config.audio_num_codebooks,
            round_to=1,
        )

        # Lock to prevent multiple generations from happening at the same time
        self.generate_lock = threading.Lock()

        # Capture CUDA graphs for each KV cache length
        if device == "cuda":
            logger.info("Capturing CUDA graphs for each KV cache length")
            self.model.capture_model(self.kv_caches.values())

    def _load_audio_content(self, audio_content):
        """
        Load one audio content item as a waveform at the audio tokenizer's sampling rate.

        Prefers `audio_url` (a path/URL passed to `librosa.load`) when it is set and not
        the "placeholder" marker; otherwise decodes base64 `raw_audio`. Returns None when
        neither source is available. A None/empty `audio_url` falls through to `raw_audio`.
        """
        audio_url = audio_content.audio_url
        if audio_url and audio_url != "placeholder":
            raw_audio, _ = librosa.load(audio_url, sr=self.audio_tokenizer.sampling_rate)
            return raw_audio
        if audio_content.raw_audio is not None:
            raw_audio, _ = librosa.load(
                BytesIO(base64.b64decode(audio_content.raw_audio)),
                sr=self.audio_tokenizer.sampling_rate,
            )
            return raw_audio
        return None

    def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
        """
        Build model inputs for a single ChatML sample.

        Tokenizes the sample, appends the assistant header (plus `<|audio_out_bos|>` when
        `force_audio_gen` is True), encodes every loadable audio item with the audio
        tokenizer, runs the collator and moves all resulting tensors to the model device.

        Returns:
            dict: The collated batch as a dict (via `dataclasses.asdict`).
        """
        input_tokens, _, audio_contents, _ = prepare_chatml_sample(
            chat_ml_sample,
            self.tokenizer,
        )

        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if force_audio_gen:
            postfix += "<|audio_out_bos|>"
        input_tokens.extend(self.tokenizer.encode(postfix, add_special_tokens=False))

        # Configure the audio inputs
        audio_ids_l = []
        for audio_content in audio_contents:
            raw_audio = self._load_audio_content(audio_content)
            if raw_audio is not None:
                audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
                audio_ids_l.append(audio_ids.squeeze(0).cpu())

        if audio_ids_l:
            # Start offset of each clip along dim 1 of the concatenated codes.
            audio_ids_start = torch.tensor(
                np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
                dtype=torch.long,
                device=self.device,
            )[0:-1]
            audio_ids_concat = torch.cat(audio_ids_l, dim=1)
        else:
            audio_ids_start = None
            audio_ids_concat = None

        sample = ChatMLDatasetSample(
            input_ids=torch.LongTensor(input_tokens),
            label_ids=None,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=None,
            audio_waveforms_start=None,
            audio_sample_rate=None,
            audio_speaker_indices=None,
        )
        data = self.collator([sample])
        inputs = asdict(data)
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.model.device)

        return inputs

    def _prepare_kv_caches(self):
        """Reset every pre-allocated KV cache before a new generation."""
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    def generate(
        self,
        chat_ml_sample: ChatMLSample,
        max_new_tokens: int,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: float = 0.95,
        stop_strings: Optional[List[str]] = None,
        force_audio_gen: bool = False,
        ras_win_len: Optional[int] = None,
        ras_win_max_num_repeat: int = 2,
    ):
        """
        Generate a response (text and, optionally, audio) for a ChatML sample.

        Args:
            chat_ml_sample: The ChatML sample to respond to.
            max_new_tokens: The maximum number of new tokens to generate.
            temperature: Sampling temperature; exactly 0.0 disables sampling (`do_sample=False`).
            top_k: Optional top-k cutoff, forwarded to `model.generate`.
            top_p: Top-p (nucleus) threshold, forwarded to `model.generate`.
            stop_strings: Strings that end generation. Defaults to
                `["<|end_of_text|>", "<|eot_id|>"]`.
            force_audio_gen: If True, `<|audio_out_bos|>` is appended after the assistant
                header (presumably forcing the reply to start with audio).
            ras_win_len: Forwarded unchanged to `model.generate`.
            ras_win_max_num_repeat: Forwarded unchanged to `model.generate`.

        Returns:
            HiggsAudioResponse with:
                audio: decoded waveforms of all generated audio segments, concatenated,
                    or None when no audio was generated.
                generated_audio_tokens: raw tokens of the first audio segment, or None
                    when no audio was generated.
                sampling_rate: the audio tokenizer's sampling rate.
                generated_text / generated_text_tokens: the text generated after the prompt.
                usage: prompt/completion/total token counts (audio counted from the
                    first segment's second dimension; 0 when there is no audio).
        """
        # Default stop strings
        if stop_strings is None:
            stop_strings = ["<|end_of_text|>", "<|eot_id|>"]

        with torch.no_grad(), self.generate_lock:
            inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
            prompt_token_ids = inputs["input_ids"][0].cpu().numpy()

            self._prepare_kv_caches()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stop_strings=stop_strings,
                tokenizer=self.tokenizer,
                do_sample=temperature != 0.0,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
            )

            audio_outputs = outputs[1]
            if len(audio_outputs) > 0:
                waveforms = []
                for output_audio in audio_outputs:
                    # Undo the delay pattern, clamp to valid codebook ids and drop the first/last columns.
                    vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
                    waveforms.append(self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0])
                wv_numpy = np.concatenate(waveforms)
                generated_audio_tokens = audio_outputs[0].cpu().numpy()
                num_audio_tokens = generated_audio_tokens.shape[1]
            else:
                # Text-only response: previously this path crashed on outputs[1][0].
                wv_numpy = None
                generated_audio_tokens = None
                num_audio_tokens = 0

            # We only support one request at a time now
            generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
            generated_text = self.tokenizer.decode(generated_text_tokens)

            num_prompt_tokens = prompt_token_ids.shape[0]
            num_completion_tokens = generated_text_tokens.shape[0] + num_audio_tokens
            return HiggsAudioResponse(
                audio=wv_numpy,
                generated_audio_tokens=generated_audio_tokens,
                sampling_rate=self.audio_tokenizer.sampling_rate,
                generated_text=generated_text,
                generated_text_tokens=generated_text_tokens,
                usage={
                    "prompt_tokens": num_prompt_tokens,
                    "completion_tokens": num_completion_tokens,
                    "total_tokens": num_prompt_tokens + num_completion_tokens,
                    "cached_tokens": 0,
                },
            )

    def text_normalize(self, text: str) -> str:
        """
        Normalize the text.

        Maps Chinese punctuation to ASCII via `normalize_chinese_punctuation`, then
        replaces every ASCII parenthesis with a space.
        """
        # Perform some basic normalization
        text = normalize_chinese_punctuation(text)
        # Handle parentheses
        text = text.replace("(", " ")
        text = text.replace(")", " ")
        return text