import torch
import torch.nn as nn
from typing import Optional, List, Dict, Tuple, Union
import gc

from transformers.cache_utils import Cache


# Number of eager warmup passes before capture; a couple of iterations is
# generally enough to trigger autotuning and lazy workspace allocation.
_NUM_WARMUP_ITERS = 2


class CUDAGraphRunner(nn.Module):
    """Captures one decode step of `model` into a CUDA graph and replays it.

    The tensors passed to `capture` become static buffers whose addresses are
    baked into the graph; `forward` copies fresh inputs into those buffers and
    replays the recorded kernels, avoiding per-step launch overhead.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

        # Static input/output buffers recorded at capture time.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Union[Cache, List[torch.FloatTensor]],
        use_cache: bool,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        is_decoding_audio_token: Optional[bool] = None,
        is_using_cuda_graph: bool = False,
        stream: Optional[torch.cuda.Stream] = None,
        memory_pool: Optional[Tuple[int, int]] = None,
    ):
        assert self._graph is None, "capture() may only be called once"
        # Warmup passes: run the model eagerly so autotuning and lazy
        # allocations happen before capture, not during it.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )

        torch.cuda.synchronize()

        # Capture the graph
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            out_hidden_states, all_hidden_states, all_self_attns = self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
            # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
            # del outputs
            # Collect Python garbage while still inside the capture context so
            # freed blocks return to the graph's private memory pool.
            gc.collect()
        torch.cuda.synchronize()

        # Save input and output buffers
        self.input_buffers = {
            "hidden_states": hidden_states,
            "causal_mask": causal_mask,
            "position_ids": position_ids,
            "audio_discrete_codes_mask": audio_discrete_codes_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "audio_attention_mask": audio_attention_mask,
            "fast_forward_attention_mask": fast_forward_attention_mask,
        }
        self.output_buffers = {
            "hidden_states": out_hidden_states,
            "all_hidden_states": all_hidden_states,
            "all_self_attns": all_self_attns,
        }
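        # Note: when capturing several graphs (e.g. one per batch size), a
        # single handle from torch.cuda.graph_pool_handle() can be passed as
        # `memory_pool` to each capture() call so the graphs share one memory
        # pool instead of each reserving its own.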

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, None, None]:
        # Copy the fresh inputs into the static buffers whose addresses were
        # baked into the captured graph.
        self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
        self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
        self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
        self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
        self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
        self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
        self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)

        # Replay the recorded kernels against the static buffers.
        self.graph.replay()

        # Mirror the model's (hidden_states, all_hidden_states, all_self_attns)
        # return shape; only hidden_states is read back from the output buffers.
        return self.output_buffers["hidden_states"], None, None
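

if __name__ == "__main__":
    # Illustrative smoke test, not part of the runner: `_ToyDecoder` is a
    # hypothetical stand-in with the call signature capture() expects, and the
    # shapes below are arbitrary assumptions. Requires a CUDA device.
    class _ToyDecoder(nn.Module):
        def __init__(self, hidden_size: int = 64):
            super().__init__()
            self.proj = nn.Linear(hidden_size, hidden_size)

        def forward(
            self,
            hidden_states,
            causal_mask,
            position_ids,
            audio_discrete_codes_mask,
            cache_position,
            past_key_values,
            use_cache,
            audio_attention_mask,
            fast_forward_attention_mask,
            output_attentions,
            output_hidden_states,
            is_decoding_audio_token=None,
            is_using_cuda_graph=False,
        ):
            # Ignore the auxiliary inputs; return the 3-tuple the runner expects.
            return self.proj(hidden_states), None, None

    device = torch.device("cuda")
    hidden = torch.zeros(1, 1, 64, device=device)
    mask = torch.zeros(1, 1, 1, 8, device=device)
    pos = torch.zeros(1, 1, dtype=torch.long, device=device)
    codes_mask = torch.zeros(1, 1, dtype=torch.bool, device=device)
    cache_pos = torch.zeros(1, dtype=torch.long, device=device)

    runner = CUDAGraphRunner(_ToyDecoder().to(device))
    runner.capture(
        hidden_states=hidden,
        causal_mask=mask,
        position_ids=pos,
        audio_discrete_codes_mask=codes_mask,
        cache_position=cache_pos,
        past_key_values=None,  # no KV cache in this toy example
        use_cache=True,
        audio_attention_mask=mask,
        fast_forward_attention_mask=mask,
        output_attentions=False,
        output_hidden_states=False,
    )
    # Steady-state step: copy new inputs into the static buffers and replay.
    out, _, _ = runner(
        hidden_states=torch.randn(1, 1, 64, device=device),
        causal_mask=mask,
        position_ids=pos,
        audio_discrete_codes_mask=codes_mask,
        cache_position=cache_pos,
        audio_attention_mask=mask,
        fast_forward_attention_mask=mask,
    )
    print(out.shape)  # torch.Size([1, 1, 64])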