from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window
    attention and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`]
    for sliding window attention and [`StaticCache`] for global attention. For more information, see the documentation
    of each subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Deprecated in favor of `max_batch_size`. Note that a
            new instance must be instantiated if a smaller batch size is used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. If you're using more than 1 computation device, you
            should pass the `layer_device_map` argument instead.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
        max_batch_size (`int`, *optional*):
            The maximum batch size with which the model will be used.
        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
            Mapping between the layers and their devices. This is required when you are manually initializing the
            cache and the model is split between different GPUs. You can check which layers are mapped to which
            device via the associated device map: `model.hf_device_map`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HybridCache()
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: Union[torch.device, str] = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        if batch_size is not None:
            logger.warning_once(
                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
            )
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = batch_size or max_batch_size
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        # Every `layer_switch`-th layer uses global attention; the remaining layers use sliding window attention.
        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2
        self.is_sliding = torch.tensor(
            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.chunk_cache = {}
        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            self.max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        device = torch.device(device) if device is not None else None
        for i in range(config.num_hidden_layers):
            if layer_device_map is not None:
                layer_device = layer_device_map[i]
            else:
                layer_device = device
            # Sliding layers get a window-sized buffer, global layers get a full-length buffer.
            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            # Static addresses let `torch.compile` treat the buffers as fixed storage across calls.
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        if cache_position.shape[0] > max_cache_len:
            # Prefill longer than the window: keep only the last `max_cache_len` tokens in the cache.
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # The layer cache is still all zeros here, so `+=` is equivalent to `=` but compile-friendly.
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # Return the full states so the whole prompt is attended to during prefill, not just the window.
            return key_states, value_states

        # Once the window is full, roll the cached entries left by one slot before writing the new token(s).
        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        # `zero_()` followed by `+=` is equivalent to assignment, but avoids graph breaks under `torch.compile`.
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out
        return k_out, v_out

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        # Global-attention layers write in place at `cache_position`; no shifting is needed.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def _set_chunk_cache(self, layer_idx, cache):
        # Stores an arbitrary per-layer chunk alongside the key/value buffers.
        self.chunk_cache[layer_idx] = cache

    def _get_chunk_cache(self, layer_idx):
        return self.chunk_cache.setdefault(layer_idx, None)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")

        # Lazily move the per-layer buffers to the device of the incoming states (e.g. with `device_map="auto"`).
        if self.key_cache[layer_idx].device != key_states.device:
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
        if self.value_cache[layer_idx].device != value_states.device:
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        key_states = key_states.to(k_out.dtype)
        value_states = value_states.to(v_out.dtype)

        # Sliding-window layers use the rolling update, global-attention layers write in place.
        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )

    def get_max_cache_shape(self) -> Optional[int]:
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # Occupied cache == any slot along the sequence dim holds a non-zero value; checking batch 0, head 0 of
        # layer 0 is enough and keeps the computation cheap.
        if layer_idx != 0:
            raise ValueError(
                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                "Using the `layer_idx` argument is not supported."
            )
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops preserve the tensors' addresses, so compiled graphs keep pointing at the same buffers.
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()

    @property
    def batch_size(self):
        logger.warning_once(
            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
        )
        return self.max_batch_size
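

# A minimal, illustrative usage sketch (not part of the cache implementation). It assumes the
# "google/gemma-2-2b" checkpoint is available and that the installed `transformers` version accepts a cache
# object via `generate(past_key_values=...)`; any Gemma2-style config with a `sliding_window` field should
# behave the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

    inputs = tokenizer("My name is Gemma", return_tensors="pt")
    # Reserve room for the prompt plus 10 generated tokens.
    max_generated_length = inputs.input_ids.shape[1] + 10
    past_key_values = HybridCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=max_generated_length,
        device=model.device,
        dtype=model.dtype,
    )
    # The cache is filled in place as `generate` iteratively calls the model's forward pass.
    outputs = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=10)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))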