higgs_audio_v2 / higgs_audio /model /cuda_graph_runner.py
zachzzc's picture
Upload tts playground and serving engine
07f1f64
raw
history blame
5.11 kB
import torch
import torch.nn as nn
from typing import Optional, List, Dict, Tuple, Union
import gc
from transformers.cache_utils import Cache

# Number of eager forward passes run before CUDA-graph capture.
# NOTE(review): warmup presumably flushes one-time lazy initialization
# (kernel autotuning, workspace allocation) so it is not recorded into
# the graph — standard practice for torch.cuda.CUDAGraph capture.
_NUM_WARMUP_ITERS = 2
class CUDAGraphRunner(nn.Module):
    """Run one decode step of a model as a replayed CUDA graph.

    Call :meth:`capture` once with static input tensors to record the model's
    forward pass into a ``torch.cuda.CUDAGraph``. Afterwards every
    :meth:`forward` call copies fresh inputs into the captured static buffers
    and replays the graph, avoiding per-step kernel-launch overhead.
    """

    def __init__(self, model):
        """Wrap *model*; no graph exists until :meth:`capture` is called."""
        super().__init__()
        self.model = model
        # Static tensors the captured graph reads its inputs from.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # Static tensors the captured graph writes its outputs into.
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """The captured CUDA graph; only valid after :meth:`capture`."""
        assert self._graph is not None, "capture() has not been called yet"
        return self._graph

    def capture(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Union["Cache", List[torch.FloatTensor]],
        use_cache: bool,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        is_decoding_audio_token: Optional[bool] = None,
        is_using_cuda_graph: Optional[bool] = False,
        # Fixed: was `stream: torch.cuda.Stream = None` (implicit Optional).
        stream: Optional[torch.cuda.Stream] = None,
        memory_pool: Optional[Tuple[int, int]] = None,
    ) -> None:
        """Record the model's forward pass into a CUDA graph.

        The tensor arguments become the graph's static buffers: every replay
        reads inputs from (and writes outputs to) these exact tensor
        addresses, so later calls must go through :meth:`forward`, which
        copies new data into them. May only be called once per runner.

        Args:
            hidden_states / causal_mask / position_ids /
            audio_discrete_codes_mask / cache_position /
            audio_attention_mask / fast_forward_attention_mask:
                Static input tensors saved into ``input_buffers``.
            past_key_values: KV cache passed through to the model; saved in
                ``input_buffers`` but never ``copy_``-ed by ``forward`` —
                the captured cache tensors are presumably updated in place
                on each replay (standard CUDA-graph decode pattern; confirm
                against the model implementation).
            use_cache / output_attentions / output_hidden_states /
            is_decoding_audio_token / is_using_cuda_graph:
                Plain flags baked into the captured trace.
            stream: Optional capture stream forwarded to ``torch.cuda.graph``.
            memory_pool: Optional graph memory-pool handle so multiple graphs
                can share one pool.
        """
        assert self._graph is None, "capture() may only be called once"
        # Warmup: run eagerly first so one-time lazy work (autotuning,
        # workspace allocation — presumably) happens outside the capture.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
        torch.cuda.synchronize()
        # Capture the graph: the model is assumed to return a 3-tuple
        # (hidden_states, all_hidden_states, all_self_attns) — see the
        # unpacking below.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            out_hidden_states, all_hidden_states, all_self_attns = self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
            # Drop host-side garbage created while tracing before the pool
            # snapshot is finalized.
            gc.collect()
        torch.cuda.synchronize()
        # Save the static buffers: forward() copies fresh inputs into
        # input_buffers and reads the replay's result from output_buffers.
        self.input_buffers = {
            "hidden_states": hidden_states,
            "causal_mask": causal_mask,
            "position_ids": position_ids,
            "audio_discrete_codes_mask": audio_discrete_codes_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "audio_attention_mask": audio_attention_mask,
            "fast_forward_attention_mask": fast_forward_attention_mask,
        }
        self.output_buffers = {
            "hidden_states": out_hidden_states,
            "all_hidden_states": all_hidden_states,
            "all_self_attns": all_self_attns,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, None, None]:
        # Fixed annotation: the method returns a 3-tuple, not a bare Tensor.
        """Copy inputs into the captured buffers and replay the graph.

        ``past_key_values`` is intentionally absent: the cache tensors
        captured at graph time are reused in place on every replay.
        Extra ``**kwargs`` (e.g. flags that were baked in at capture time)
        are accepted and ignored.

        Returns:
            ``(hidden_states, None, None)`` — the captured
            ``all_hidden_states`` / ``all_self_attns`` buffers are not
            returned, matching the eager model's 3-tuple shape.
        """
        # Copy input tensors into the graph's static buffers; non_blocking
        # is safe because replay() is queued on the same stream afterwards.
        self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
        self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
        self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
        self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
        self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
        self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
        self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)
        # Replay the captured kernels on the static buffers.
        self.graph.replay()
        return self.output_buffers["hidden_states"], None, None