| from typing import Any, Dict, Optional, Tuple |
| from transformers.cache_utils import DynamicCache |
| import torch |
|
|
|
|
class DynamicCacheWithQuery(DynamicCache):
    """
    A `DynamicCache` that additionally stores per-layer query states.

    Used by QRRetriever. To keep the same `update(...)` signature as
    `DynamicCache`, the query states are smuggled in through `cache_kwargs`
    under the key ``"query_states"``; `update` pulls them out and appends
    them to `self.query_cache`, mirroring the key/value layout.
    """

    def __init__(self, query_indices=None) -> None:
        """
        Parameters:
            query_indices (`list`, *optional*):
                Indices identifying which positions are queries. Defaults to
                an empty list. (Default is `None` rather than `[]` to avoid
                the shared-mutable-default pitfall.)
        """
        super().__init__()
        # Avoid a mutable default argument: a literal `[]` default would be
        # shared across every instance constructed without an explicit list.
        self._query_indices = [] if query_indices is None else query_indices
        # Per-layer list of cached query states, aligned with self.key_cache.
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. This subclass reads
                `cache_kwargs["query_states"]` (if present) and appends it to
                `self.query_cache` for layer `layer_idx`.

        Return:
            A tuple containing the updated key and value states.
        """
        # The token counter is only advanced once per forward pass, on the
        # first layer, since every layer sees the same new tokens.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty tensors so that list
                # indices stay aligned with layer indices.
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif (
                not self.key_cache[layer_idx].numel()
            ):
                # Replace a previously padded (empty) entry for this layer.
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                # Append the new tokens along the sequence dimension.
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        query_states = cache_kwargs.get("query_states", None) if cache_kwargs is not None else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                # Mirror the key/value padding so query_cache[layer_idx] always
                # refers to layer `layer_idx`, even when layers are skipped.
                # (The original code appended unconditionally, which misaligns
                # the cache if any earlier layer never supplied query states.)
                for _ in range(len(self.query_cache), layer_idx):
                    self.query_cache.append(torch.tensor([]))
                self.query_cache.append(query_states)
            elif not self.query_cache[layer_idx].numel():
                self.query_cache[layer_idx] = query_states
            else:
                self.query_cache[layer_idx] = torch.cat([self.query_cache[layer_idx], query_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    @classmethod
    def from_legacy_cache_with_query_indices(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, query_indices=None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent
        `DynamicCacheWithQuery`. Used for backward compatibility.

        Parameters:
            past_key_values:
                Legacy tuple-of-tuples cache: one `(key_states, value_states)`
                pair per layer, or `None` for an empty cache.
            query_indices (`list`, *optional*):
                Forwarded to the constructor. Defaults to an empty list
                (`None` sentinel avoids a shared mutable default).
        """
        cache = cls(query_indices=query_indices)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
|
|