import asyncio
import base64
import torch
import numpy as np
from io import BytesIO
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union
from copy import deepcopy
from transformers import AutoTokenizer, AutoProcessor
from transformers.cache_utils import StaticCache
from transformers.generation.streamers import BaseStreamer
from transformers.generation.stopping_criteria import StoppingCriteria
from loguru import logger
import threading
import librosa
from ..dataset.chatml_dataset import (
ChatMLSample,
ChatMLDatasetSample,
prepare_chatml_sample,
)
from ..model import HiggsAudioModel
from ..model.utils import revert_delay_pattern
from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
def normalize_chinese_punctuation(text: str) -> str:
"""
Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
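
    Example:
        >>> normalize_chinese_punctuation("你好,世界。")
        '你好,世界.'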
"""
# Mapping of Chinese punctuation to English punctuation
chinese_to_english_punct = {
",": ",", # comma
"。": ".", # period
":": ":", # colon
";": ";", # semicolon
"?": "?", # question mark
"!": "!", # exclamation mark
"(": "(", # left parenthesis
")": ")", # right parenthesis
"【": "[", # left square bracket
"】": "]", # right square bracket
"《": "<", # left angle quote
"》": ">", # right angle quote
"“": '"', # left double quotation
"”": '"', # right double quotation
"‘": "'", # left single quotation
"’": "'", # right single quotation
"、": ",", # enumeration comma
"—": "-", # em dash
"…": "...", # ellipsis
"·": ".", # middle dot
"「": '"', # left corner bracket
"」": '"', # right corner bracket
"『": '"', # left double corner bracket
"』": '"', # right double corner bracket
}
# Replace each Chinese punctuation with its English counterpart
for zh_punct, en_punct in chinese_to_english_punct.items():
text = text.replace(zh_punct, en_punct)
return text
@dataclass
class HiggsAudioStreamerDelta:
"""Represents a chunk of generated content, either text or audio tokens."""
text: Optional[str] = None
text_tokens: Optional[torch.Tensor] = None
audio_tokens: Optional[torch.Tensor] = None
finish_reason: Optional[str] = None
class AsyncHiggsAudioStreamer(BaseStreamer):
"""
Async streamer that handles both text and audio token generation from Higgs-Audio model.
Stores chunks in a queue to be consumed by downstream applications.
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode text tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt tokens when streaming.
timeout (`float`, *optional*):
The timeout for the queue. If `None`, the queue will block indefinitely.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
```python
>>> from transformers import AutoTokenizer
>>> from threading import Thread
>>> import asyncio
>>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
>>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
>>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")
>>> async def main():
... streamer = AsyncHiggsAudioStreamer(tokenizer)
... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
... thread = Thread(target=model.generate, kwargs=generation_kwargs)
... thread.start()
...
... async for delta in streamer:
... if delta.text is not None:
... print("Text:", delta.text)
... if delta.audio_tokens is not None:
... print("Audio tokens shape:", delta.audio_tokens.shape)
>>> asyncio.run(main())
```
"""
def __init__(
self,
tokenizer: "AutoTokenizer",
skip_prompt: bool = False,
timeout: Optional[float] = None,
audio_num_codebooks: int = 1,
**decode_kwargs,
):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.timeout = timeout
self.decode_kwargs = decode_kwargs
self.audio_num_codebooks = audio_num_codebooks
# Queue to store generated chunks
self.queue = asyncio.Queue()
self.stop_signal = None
# Get running event loop
self.loop = asyncio.get_running_loop()
self.has_asyncio_timeout = hasattr(asyncio, "timeout")
# State tracking
self.next_tokens_are_prompt = True
def put(self, value: torch.Tensor):
"""
Receives tokens and processes them as either text or audio tokens.
        For text tokens, decodes them and queues the decoded text as a delta.
For audio tokens, directly queues them.
"""
if value.shape[0] > 1 and not self.next_tokens_are_prompt:
# This is likely audio tokens (shape: [audio_num_codebooks])
assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
delta = HiggsAudioStreamerDelta(audio_tokens=value)
self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
return
# Skip prompt tokens if configured
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
# Process as text tokens
if len(value.shape) > 1:
value = value[0]
text = self.tokenizer.decode(value, **self.decode_kwargs)
delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
def end(self):
"""Flushes any remaining text tokens and signals the end of generation."""
self.next_tokens_are_prompt = True
self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)
def __aiter__(self):
return self
async def __anext__(self):
try:
if self.has_asyncio_timeout:
async with asyncio.timeout(self.timeout):
value = await self.queue.get()
else:
value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
except asyncio.TimeoutError:
raise TimeoutError()
else:
if value == self.stop_signal:
raise StopAsyncIteration()
else:
return value
class AsyncStoppingCriteria(StoppingCriteria):
"""
    Stopping criteria that checks for a stop signal from a threading.Event.
Args:
stop_signal (threading.Event): Event that will receive stop signals
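
    Example (illustrative):
        >>> import threading
        >>> from transformers.generation.stopping_criteria import StoppingCriteriaList
        >>> stop_event = threading.Event()
        >>> criteria = StoppingCriteriaList([AsyncStoppingCriteria(stop_event)])
        >>> # Pass `stopping_criteria=criteria` to `model.generate(...)`; calling
        >>> # `stop_event.set()` from another thread aborts the generation.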
"""
def __init__(self, stop_signal: threading.Event):
self.stop_signal = stop_signal
def __call__(self, input_ids, scores, **kwargs) -> bool:
if self.stop_signal.is_set():
logger.info(f"Stop signal received. Can be caused by client disconnection.")
return True
return False
@dataclass
class HiggsAudioResponse:
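    """Container for the result of a single HiggsAudioServeEngine.generate call."""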
audio: Optional[np.ndarray] = None
generated_audio_tokens: Optional[np.ndarray] = None
sampling_rate: Optional[int] = None
generated_text: str = ""
    generated_text_tokens: np.ndarray = field(default_factory=lambda: np.array([]))
usage: Optional[dict] = None
class HiggsAudioServeEngine:
def __init__(
self,
model_name_or_path: str,
audio_tokenizer_name_or_path: str,
tokenizer_name_or_path: Optional[str] = None,
device: str = "cuda",
torch_dtype: Union[torch.dtype, str] = "auto",
kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes
):
"""
Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.
Args:
model_name_or_path (str):
The name or path of the model to load.
audio_tokenizer_name_or_path (str):
The name or path of the audio tokenizer to load.
            tokenizer_name_or_path (Optional[str]):
                The name or path of the tokenizer to load. Defaults to ``model_name_or_path`` when ``None``.
            device (str):
                The device to use for the model.
            kv_cache_lengths (List[int]):
                The bucket sizes of the static KV caches. Also used for CUDA graph capture when the device is "cuda".
            torch_dtype (Union[torch.dtype, str]):
                The dtype to use for the model.
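
        Example (illustrative; the checkpoint paths are placeholders):
            >>> engine = HiggsAudioServeEngine(
            ...     model_name_or_path="path/to/higgs/model",
            ...     audio_tokenizer_name_or_path="path/to/higgs/audio_tokenizer",
            ...     device="cuda",
            ... )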
"""
self.device = device
self.model_name_or_path = model_name_or_path
self.torch_dtype = torch_dtype
# Initialize model and tokenizer
self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
if tokenizer_name_or_path is None:
tokenizer_name_or_path = model_name_or_path
logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
logger.info(f"Initializing Higgs Audio Tokenizer")
self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
self.audio_num_codebooks = self.model.config.audio_num_codebooks
self.audio_codebook_size = self.model.config.audio_codebook_size
self.audio_tokenizer_tps = self.audio_tokenizer.tps
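        # One audio token covers sampling_rate / tokens-per-second PCM samples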
self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
# Set the audio special tokens
self.model.set_audio_special_tokens(self.tokenizer)
# Prepare KV caches for different lengths
cache_config = deepcopy(self.model.config.text_config)
cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
if self.model.config.audio_dual_ffn_layers:
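            # Audio dual-FFN layers keep their own KV entries, so allocate extra cache layers for them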
cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
# A list of KV caches for different lengths
self.kv_caches = {
length: StaticCache(
config=cache_config,
max_batch_size=1,
max_cache_len=length,
device=self.model.device,
dtype=self.model.dtype,
)
for length in sorted(kv_cache_lengths)
}
if self.model.config.encode_whisper_embed:
logger.info(f"Loading whisper processor")
whisper_processor = AutoProcessor.from_pretrained(
"openai/whisper-large-v3-turbo",
                trust_remote_code=True,
device=self.device,
)
else:
whisper_processor = None
# Reuse collator to prepare inference samples
self.collator = HiggsAudioSampleCollator(
whisper_processor=whisper_processor,
encode_whisper_embed=self.model.config.encode_whisper_embed,
audio_in_token_id=self.model.config.audio_in_token_idx,
audio_out_token_id=self.model.config.audio_out_token_idx,
audio_stream_bos_id=self.model.config.audio_stream_bos_id,
audio_stream_eos_id=self.model.config.audio_stream_eos_id,
pad_token_id=self.model.config.pad_token_id,
return_audio_in_tokens=False,
use_delay_pattern=self.model.config.use_delay_pattern,
audio_num_codebooks=self.model.config.audio_num_codebooks,
round_to=1,
)
# Lock to prevent multiple generations from happening at the same time
self.generate_lock = threading.Lock()
# Capture CUDA graphs for each KV cache length
if device == "cuda":
logger.info(f"Capturing CUDA graphs for each KV cache length")
self.model.capture_model(self.kv_caches.values())
def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
input_tokens, _, audio_contents, _ = prepare_chatml_sample(
chat_ml_sample,
self.tokenizer,
)
postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
if force_audio_gen:
postfix += "<|audio_out_bos|>"
postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
input_tokens.extend(postfix)
# Configure the audio inputs
audio_ids_l = []
for audio_content in audio_contents:
if audio_content.audio_url not in ["placeholder", ""]:
raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
elif audio_content.raw_audio is not None:
raw_audio, _ = librosa.load(
BytesIO(base64.b64decode(audio_content.raw_audio)),
sr=self.audio_tokenizer.sampling_rate,
)
else:
raw_audio = None
if raw_audio is not None:
audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
audio_ids_l.append(audio_ids.squeeze(0).cpu())
if len(audio_ids_l) > 0:
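            # Start offset of each audio clip inside the concatenated codebook sequence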
audio_ids_start = torch.tensor(
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
dtype=torch.long,
device=self.device,
)[0:-1]
audio_ids_concat = torch.cat(audio_ids_l, dim=1)
else:
audio_ids_start = None
audio_ids_concat = None
sample = ChatMLDatasetSample(
input_ids=torch.LongTensor(input_tokens),
label_ids=None,
audio_ids_concat=audio_ids_concat,
audio_ids_start=audio_ids_start,
audio_waveforms_concat=None,
audio_waveforms_start=None,
audio_sample_rate=None,
audio_speaker_indices=None,
)
data = self.collator([sample])
inputs = asdict(data)
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.model.device)
return inputs
    def _prepare_kv_caches(self):
        """Reset every static KV cache bucket so the next request starts from an empty cache."""
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()
def generate(
self,
chat_ml_sample: ChatMLSample,
max_new_tokens: int,
temperature: float = 0.7,
top_k: Optional[int] = None,
top_p: float = 0.95,
stop_strings: Optional[List[str]] = None,
force_audio_gen: bool = False,
ras_win_len: Optional[int] = None,
ras_win_max_num_repeat: int = 2,
):
"""
Generate audio from a chatml sample.
Args:
            chat_ml_sample: A chatml sample.
            max_new_tokens: The maximum number of new tokens to generate.
            temperature: The sampling temperature. ``0.0`` disables sampling (greedy decoding).
            top_k: The top-k value to use for sampling.
            top_p: The top-p value to use for sampling.
            stop_strings: Strings that stop generation. Defaults to ``["<|end_of_text|>", "<|eot_id|>"]``.
            force_audio_gen: Whether to append ``<|audio_out_bos|>`` to the prompt to force audio output.
            ras_win_len: Window length for repetition-aware sampling (RAS) over audio tokens; disabled when ``None``.
            ras_win_max_num_repeat: Maximum number of repetitions allowed within the RAS window.
        Returns:
            A ``HiggsAudioResponse`` with the generated audio waveform, its sampling rate, the
            generated text, the corresponding token arrays, and usage statistics.
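
        Example (illustrative; assumes ``ChatMLSample`` and a ``Message`` dataclass with
        ``role``/``content`` fields, as defined in the repo's ``chatml_dataset`` module):
            >>> sample = ChatMLSample(messages=[Message(role="user", content="Say hello.")])
            >>> response = engine.generate(sample, max_new_tokens=512, force_audio_gen=True)
            >>> audio, sr = response.audio, response.sampling_rate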
"""
# Default stop strings
if stop_strings is None:
stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
with torch.no_grad(), self.generate_lock:
inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
prompt_token_ids = inputs["input_ids"][0].cpu().numpy()
self._prepare_kv_caches()
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
use_cache=True,
stop_strings=stop_strings,
tokenizer=self.tokenizer,
                do_sample=temperature > 0.0,
temperature=temperature,
top_k=top_k,
top_p=top_p,
past_key_values_buckets=self.kv_caches,
ras_win_len=ras_win_len,
ras_win_max_num_repeat=ras_win_max_num_repeat,
)
            if len(outputs[1]) > 0:
                wv_list = []
                for output_audio in outputs[1]:
                    # Undo the delay pattern, clamp codes to the valid codebook range,
                    # and strip the audio-stream BOS/EOS tokens before decoding.
                    vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
                    wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
                    wv_list.append(wv_numpy)
                wv_numpy = np.concatenate(wv_list)
                generated_audio_tokens = outputs[1][0].cpu().numpy()
                num_audio_tokens = generated_audio_tokens.shape[1]
            else:
                # No audio was generated; avoid indexing into an empty list
                wv_numpy = None
                generated_audio_tokens = None
                num_audio_tokens = 0
            # We only support one request at a time for now
            generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
            generated_text = self.tokenizer.decode(generated_text_tokens)
return HiggsAudioResponse(
audio=wv_numpy,
generated_audio_tokens=generated_audio_tokens,
sampling_rate=self.audio_tokenizer.sampling_rate,
generated_text=generated_text,
generated_text_tokens=generated_text_tokens,
                usage={
                    "prompt_tokens": prompt_token_ids.shape[0],
                    "completion_tokens": generated_text_tokens.shape[0] + num_audio_tokens,
                    "total_tokens": (
                        prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + num_audio_tokens
                    ),
                    "cached_tokens": 0,
                },
)
def text_normalize(self, text: str) -> str:
"""
        Normalize input text before synthesis: convert full-width Chinese punctuation to
        half-width equivalents and replace parentheses with spaces.
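
        Example (assuming ``engine`` is a constructed ``HiggsAudioServeEngine``):
            >>> engine.text_normalize("你好,世界(测试)")
            '你好,世界 测试 '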
"""
# Perform some basic normalization
text = normalize_chinese_punctuation(text)
# Handle parentheses
text = text.replace("(", " ")
text = text.replace(")", " ")
return text