From aaa5f00ff9aa12531264e2ff0d029bf727f634a4 Mon Sep 17 00:00:00 2001
From: Hudson Cooper
Date: Wed, 15 Jan 2025 12:38:28 -0800
Subject: [PATCH] [Bug] HybridCache not subscriptable (#1047)

Use the transformers `Cache` interface rather than legacy tuples (which
are still supported for backwards compatibility).

- When we overflow the size of a StaticCache or HybridCache, reallocate
  the cache with double the size
- Other fixed-size caches just emit a warning and drop the cache until
  we adapt the doubling logic to those cache types
- Use `Cache.crop` when available for backtracking the cache
- When `Cache.crop` is unavailable, try `Cache.reset` to avoid
  reallocation, finally falling back on deleting the cache
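
For context, a rough sketch of the failure this fixes, with `cache`
standing in for the engine's `self._past_key_values`: the old code
measured the cache length by indexing nested key/value tuples, which
`Cache` subclasses such as `HybridCache` do not support, so the length
query now goes through the `Cache` interface instead.

    past_length = cache[0][0].size(-2)    # TypeError: HybridCache is not subscriptable
    past_length = cache.get_seq_length()  # Cache-interface equivalent used below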
---
 guidance/models/transformers/_transformers.py | 67 +++++++++++++++++--
 1 file changed, 60 insertions(+), 7 deletions(-)

diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py
index 19b73a038..93edd531e 100644
--- a/guidance/models/transformers/_transformers.py
+++ b/guidance/models/transformers/_transformers.py
@@ -409,7 +409,7 @@ def __init__(self,
             self.model = model.__class__.__name__
         self.device = self.model_obj.device # otherwise note the current device
 
-        self._past_key_values = None
+        self._past_key_values: Union[transformers_package.Cache, tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], None] = None
         self._cached_logits = None
         self._cached_token_ids: list[int] = []
 
@@ -479,13 +479,66 @@ def get_logits(self, token_ids):
 
         # reset the cache length according to that number of positions
         past_key_values = self._past_key_values
-        past_length = past_key_values[0][0].size(-2) if past_key_values is not None else 0
-        if past_length > num_cached:
-            # note we recompute the last token because we don't bother to handle the special case of just computing logits
+        max_cache_shape = None
+        if past_key_values is None:
+            past_length = 0
+        elif isinstance(past_key_values, tuple):
+            past_length = past_key_values[0][0].size(-2)
+        elif isinstance(past_key_values, transformers_package.Cache):
+            # TODO: use model's `cache_position` as this may be deprecated in a future version
+            # https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/cache_utils.py#L64
+            past_length = past_key_values.get_seq_length()
+            # TODO: use `get_max_cache_shape` as `get_max_length` will be deprecated in a future version
+            # (`get_max_cache_shape` is not yet available so we can't use it yet)
+            # https://github.com/huggingface/transformers/blob/70b07d97cf2c5f61fff55700b65528a1b6845cd2/src/transformers/cache_utils.py#L67
+            max_cache_shape = past_key_values.get_max_length()
+        else:
+            raise TypeError(f"Unknown type of past_key_values: {type(past_key_values)}")
+
+        if max_cache_shape is not None and len(token_ids) > max_cache_shape:
+            # TODO: this seems to get set to the length of the first sequence we pass for models using
+            # StaticCache or HybridCache. We need to initialize our own cache with a large enough size
+            # if we want to continue generation with the same cache.
+            if isinstance(past_key_values, (transformers_package.StaticCache, transformers_package.HybridCache)):
+                # The __init__ API isn't consistent between different cache types, but there seems to be consistency
+                # between these two types, so we can use the same logic for both.
+                warnings.warn("Cache is too small. "
+                              "Re-initializing cache with larger size.")
+                cache_type = type(past_key_values)
+                config = self.model_obj.config
+                device = self.model_obj.device
+                hf_device_map = getattr(self.model_obj, "hf_device_map", {})
+                # hf_device_map is not always a complete mapping of layers to devices...
+                layer_device_map = {k: hf_device_map.get(k, device) for k in range(config.num_hidden_layers)}
+                self._past_key_values = cache_type(
+                    config=config,
+                    batch_size=past_key_values.batch_size,
+                    # Double the cache size to be safe
+                    max_cache_len=len(token_ids)*2,
+                    dtype=past_key_values.dtype,
+                    layer_device_map=layer_device_map,
+                )
+            else:
+                warnings.warn(f"Cache is too small. Resetting cache (no method implemented to resize cache for type {type(past_key_values)}).")
+                self._past_key_values = None
+            past_length = 0
+        elif past_length > num_cached:
             past_length = max(0, num_cached - 1)
-            self._past_key_values = tuple(
-                tuple(p[..., :past_length, :] for p in v) for v in past_key_values
-            )
+            if isinstance(past_key_values, tuple):
+                self._past_key_values = tuple(
+                    tuple(p[..., :past_length, :] for p in v) for v in past_key_values
+                )
+            else:
+                if hasattr(past_key_values, "crop"):
+                    self._past_key_values.crop(past_length)
+                else:
+                    warnings.warn(f"Cropping unsupported for cache type: {type(self._past_key_values)}. Resetting cache.")
+                    if hasattr(self._past_key_values, "reset"):
+                        # Use built-in reset method if available to avoid constructing/allocating a new cache
+                        self._past_key_values.reset()
+                    else:
+                        self._past_key_values = None
+                    past_length = 0
         cache_token_ids[past_length:] = []
 
         # call the model