# higgs_audio_v2/higgs_audio/model/cuda_graph_runner.py
import gc
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.cache_utils import Cache

# Number of eager warmup passes to run before capture, so that one-time work
# (allocator growth, kernel autotuning, lazy initialization) is not recorded
# into the captured graph.
_NUM_WARMUP_ITERS = 2
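# For orientation, a minimal sketch (not part of this module) of the raw
# torch.cuda.CUDAGraph capture/replay pattern that CUDAGraphRunner wraps:
# allocate static tensors, capture the computation once, then overwrite the
# static inputs in place and replay. Tensor names here are illustrative only.
#
#     static_in = torch.zeros(8, device="cuda")
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g):
#         static_out = static_in * 2
#     static_in.copy_(torch.arange(8.0, device="cuda"))
#     g.replay()  # static_out is updated in place: 0, 2, 4, ...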

class CUDAGraphRunner(nn.Module):
    """Captures a single forward pass of ``model`` into a CUDA graph, then
    replays it with new inputs copied into fixed buffers, avoiding per-step
    kernel launch overhead during decoding."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Union[Cache, List[torch.FloatTensor]],
        use_cache: bool,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        is_decoding_audio_token: Optional[bool] = None,
        is_using_cuda_graph: Optional[bool] = False,
        stream: Optional[torch.cuda.Stream] = None,
        memory_pool: Optional[Tuple[int, int]] = None,
    ):
        assert self._graph is None

        # Run a few eager iterations first so that one-time costs are paid
        # outside the capture.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
        torch.cuda.synchronize()

        # Capture the graph. The tensors passed in here become the fixed
        # buffers the graph reads from and writes to on every replay.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            out_hidden_states, all_hidden_states, all_self_attns = self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
            # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
            # del outputs
            # Make sure intermediate tensors are released inside the graph's
            # memory pool before capture ends.
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers. Replay reads inputs from, and
        # writes outputs to, these exact tensor addresses.
        self.input_buffers = {
            "hidden_states": hidden_states,
            "causal_mask": causal_mask,
            "position_ids": position_ids,
            "audio_discrete_codes_mask": audio_discrete_codes_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "audio_attention_mask": audio_attention_mask,
            "fast_forward_attention_mask": fast_forward_attention_mask,
        }
        self.output_buffers = {
            "hidden_states": out_hidden_states,
            "all_hidden_states": all_hidden_states,
            "all_self_attns": all_self_attns,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, None, None]:
        # Copy the new inputs into the captured buffers in place. The KV cache
        # (past_key_values) is mutated by the replayed graph itself, so it is
        # not copied here.
        self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
        self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
        self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
        self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
        self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
        self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
        self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)

        # Replay the captured graph; results land in the output buffers.
        self.graph.replay()

        # The None placeholders mirror the model's (hidden_states,
        # all_hidden_states, all_self_attns) return shape.
        return self.output_buffers["hidden_states"], None, None
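
# Hedged usage sketch (not from this repo): `decoder` stands for a model whose
# forward matches the signature captured above, and `static_inputs` for a dict
# of pre-allocated, fixed-shape GPU tensors plus the flags capture() expects.
# The same buffer objects must back every subsequent decode step.
#
#     runner = CUDAGraphRunner(decoder)
#     runner.capture(**static_inputs)     # one-time: warmup + graph capture
#     hidden, _, _ = runner(              # per-step: copy inputs in, replay
#         hidden_states=new_hidden,
#         causal_mask=new_mask,
#         position_ids=new_pos,
#         audio_discrete_codes_mask=new_codes_mask,
#         cache_position=new_cache_pos,
#         audio_attention_mask=new_audio_mask,
#         fast_forward_attention_mask=new_ff_mask,
#     )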