File size: 1,434 Bytes
ce16420
616f571
 
 
 
 
 
ce16420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from dataclasses import dataclass, field
import torch

@dataclass
class Cache:
    """In-place key/value cache whose position axis is dim 2.

    Both tensors are written in place by `update`; nothing is reallocated.
    The code only requires that position is indexed along dim 2 -- the layout
    is presumably (batch, heads, seq_len, head_dim); confirm against callers.
    """
    key_states: torch.Tensor
    value_states: torch.Tensor
    # Probed once in __post_init__ on the device of `key_states`; selects the
    # write strategy used by `update`.
    _supports_index_copy: bool = field(init=False) # For CUDA graph support

    def __post_init__(self):
        self._supports_index_copy = self._check_index_copy_support()

    def _check_index_copy_support(self) -> bool:
        """Verifies support for `index_copy_` on device.

        Runs a tiny `index_copy_` on the device of `key_states`.

        Returns:
            bool: True if the probe ran, False if it raised NotImplementedError.
        """
        device = self.key_states.device
        try:
            probe = torch.tensor([0, 0], device=device)
            probe.index_copy_(0, torch.tensor([0], device=device), torch.tensor([1], device=device))
        except NotImplementedError:
            return False
        return True

    def update(self, curr_pos_id: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        """
        Updates the cache based on device operator support.

        Writes `k` / `v` into `key_states` / `value_states` in place at the
        positions given by `curr_pos_id` along dim 2. Both code paths accept
        one or several positions.

        Args:
            curr_pos_id (torch.Tensor): Current position indices for decoding.
            k (torch.Tensor): The keys to update
            v (torch.Tensor): The values to update
        """
        if self._supports_index_copy: # CUDA/CPU
            self.key_states.index_copy_(2, curr_pos_id, k)
            self.value_states.index_copy_(2, curr_pos_id, v)
        else: # MPS
            self._write_without_index_copy(self.key_states, curr_pos_id, k)
            self._write_without_index_copy(self.value_states, curr_pos_id, v)

    @staticmethod
    def _write_without_index_copy(cache: torch.Tensor, curr_pos_id: torch.Tensor, states: torch.Tensor) -> None:
        """Writes `states` into `cache` along dim 2 without using `index_copy_`."""
        if isinstance(curr_pos_id, torch.Tensor) and curr_pos_id.numel() != 1:
            # Several (or zero) positions: a slice cannot express them, so use
            # indexed assignment instead. The explicit cast mirrors the dtype
            # conversion that `copy_` performs on the single-position path.
            positions = curr_pos_id.reshape(-1)
            cache[:, :, positions] = states.to(cache.dtype)
        else:
            # Single position: plain slice view + copy_.
            cache[:, :, curr_pos_id:curr_pos_id + 1, ...].copy_(states)