Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,106 Bytes
07f1f64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import torch.nn as nn
from typing import Optional, List, Dict, Tuple, Union
import gc
from transformers.cache_utils import Cache
# Number of eager forward passes run before CUDA-graph capture (used by
# CUDAGraphRunner.capture) so one-time work happens outside the captured region.
_NUM_WARMUP_ITERS = 2
class CUDAGraphRunner(nn.Module):
    """Captures one forward pass of ``model`` as a CUDA graph and replays it.

    Usage: call :meth:`capture` exactly once with representative tensors;
    those tensors become the graph's static input buffers.  Afterwards,
    :meth:`forward` copies fresh (same-shaped) inputs into the buffers and
    replays the captured graph, returning the graph's static output.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Static tensors the captured graph reads (inputs) and writes
        # (outputs); populated by capture().
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """Return the captured graph; fails if capture() has not run yet."""
        assert self._graph is not None, "capture() must be called before use"
        return self._graph

    def capture(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        # String annotation defers evaluation so the class is definable
        # even where transformers' Cache is unavailable at import time.
        past_key_values: Union["Cache", List[torch.FloatTensor]],
        use_cache: bool,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        is_decoding_audio_token: Optional[bool] = None,
        is_using_cuda_graph: Optional[bool] = False,
        stream: Optional[torch.cuda.Stream] = None,
        memory_pool: Optional[Tuple[int, int]] = None,
    ) -> None:
        """Warm up the model, then capture one forward pass into a CUDA graph.

        The tensor arguments become the graph's static input buffers, so
        later replays must write new data into these exact tensors (done by
        ``forward``).  May only be called once per runner instance.

        Args:
            stream: optional capture stream forwarded to ``torch.cuda.graph``.
            memory_pool: optional memory-pool handle forwarded as ``pool``,
                allowing multiple graphs to share allocations.

        Raises:
            AssertionError: if a graph has already been captured.
        """
        assert self._graph is None, "capture() may only be called once"
        # The model must see identical kwargs during warmup and capture;
        # build the mapping once so the two call sites cannot drift apart.
        model_kwargs = dict(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            position_ids=position_ids,
            audio_discrete_codes_mask=audio_discrete_codes_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            use_cache=use_cache,
            audio_attention_mask=audio_attention_mask,
            fast_forward_attention_mask=fast_forward_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            is_decoding_audio_token=is_decoding_audio_token,
            is_using_cuda_graph=is_using_cuda_graph,
        )

        # Warmup: run the model eagerly a few times so one-time work
        # (e.g. lazy allocations) happens outside the captured region.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(**model_kwargs)
        torch.cuda.synchronize()

        # Capture a single forward pass.  The model returns a
        # (hidden_states, all_hidden_states, all_self_attns) triple.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            out_hidden_states, all_hidden_states, all_self_attns = self.model(
                **model_kwargs
            )
        gc.collect()
        torch.cuda.synchronize()

        # Record the static tensors so forward() can refresh inputs in
        # place and return the graph's output storage.
        self.input_buffers = {
            "hidden_states": hidden_states,
            "causal_mask": causal_mask,
            "position_ids": position_ids,
            "audio_discrete_codes_mask": audio_discrete_codes_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "audio_attention_mask": audio_attention_mask,
            "fast_forward_attention_mask": fast_forward_attention_mask,
        }
        self.output_buffers = {
            "hidden_states": out_hidden_states,
            "all_hidden_states": all_hidden_states,
            "all_self_attns": all_self_attns,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, None, None]:
        """Copy the new inputs into the static buffers and replay the graph.

        Extra keyword arguments are accepted (and ignored) so callers can
        pass the same kwargs they would give the eager model.

        Returns:
            ``(hidden_states, None, None)`` — the graph's static output
            tensor plus two ``None`` placeholders mirroring the eager
            model's triple.
        """
        # Refresh every captured input tensor in place; the graph reads
        # from these exact storages during replay.
        for name, value in (
            ("hidden_states", hidden_states),
            ("causal_mask", causal_mask),
            ("position_ids", position_ids),
            ("audio_discrete_codes_mask", audio_discrete_codes_mask),
            ("cache_position", cache_position),
            ("audio_attention_mask", audio_attention_mask),
            ("fast_forward_attention_mask", fast_forward_attention_mask),
        ):
            self.input_buffers[name].copy_(value, non_blocking=True)
        # Run the captured graph
        self.graph.replay()
        return self.output_buffers["hidden_states"], None, None
|