Cold-Compress 1.0: A Hackable Toolkit for KV-Cache Compression

Authors: Griffin Adams, Faisal Ladhak

Published: August 1, 2024

Answer.AI is proud to publicly release Cold Compress 1.0, a toolkit for creating and benchmarking state-of-the-art KV cache compression methods. Built on top of GPT-Fast, it is hackable, lightweight, open-source, torch.compilable, and performant.

The KV cache, which stores precomputed activations for the entire context, is one of the most important, and most under-appreciated, parts of the language modeling pipeline. Google Gemini recently announced that they now support creating a context cache, and vLLM uses a page-based cache. KV cache compression can reduce the size of these stored activations. It is a critical component to reducing inference costs and has become an active area of research.

Despite its importance, there is currently no quick-and-easy way to compare methods or rapidly develop new ones.

Until now!

Cold Compress brings together recent advances in KV cache compression, providing a toolkit that makes it easier both to integrate these advances into your own systems and to go beyond them by trying out new approaches. All with minimal code and minimal overhead.

And it will continue to grow and improve.

Because cache compression is lossy, it’s important to know how the different techniques impact downstream performance. Therefore, Cold Compress includes support for a growing list of long context benchmarks. If you are ready to get started, feel free to skip to Part 2: The Cold Compress Toolkit, which describes this toolkit and how to use it right now.

Part 1: What is KV cache compression?

KV cache compression is a family of optimizations to support long-context generation in the transformer architecture.

When GPT-2 was first released, its context window was limited to 1k tokens. Now, models can read and generate much longer sequences. This allows models to coherently sustain long conversations (like at Cohere.ai), generate book-length reports, or summarize hundreds of pages (like Gemini Advanced).

To eliminate duplicate computation during inference, we cache previous tokens’ intermediate values in the KV Cache. Since the KV Cache’s memory usage grows linearly with the sequence length, compression can reduce memory bottlenecks and dramatically speed up inference.

Before getting into specific KV cache compression methods, let’s first discuss how the KV cache is created and updated during generation.

What is the KV Cache?

The standard decoder-only transformer is stateless. When generating the next token, it “re-reads” all previous tokens. The reading mechanism is the scaled dot product attention within the self-attention layer:

\[
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
\]

Eq 1. Scaled dot product attention.

The query vector at the current timestep attends to the KV vectors of all prior tokens.

The KV vectors of past tokens, once generated, do not change. To avoid recomputing them, KV vectors are cached and reused at each successive timestep. This KV Cache is our decoder-only transformer model’s “state”, but it is not a constant size: since we append new keys and values to the cache at each timestep, its size grows linearly with the sequence length.

The Two Phases of Generation

Generation occurs in two distinct phases: the prefill phase and the decoding phase.

The prefill phase is also referred to as prompt encoding. The prompt is encoded in parallel with causal attention to produce KV vectors for each token, which are then used to pre-fill the KV cache.

The decoding stage follows the prefill. For each generated token, new KV vectors are computed and appended to the cache at each decoder layer. Attention is performed between the current query and KV cache, which contains the prompt and previously generated tokens.

We can write pseudocode for the two-step generation process1:

n_layers = len(model.layers)
# One key cache and one value cache per decoder layer: filled during prefill, grown during decoding
K_cache, V_cache = [None] * n_layers, [None] * n_layers

def prefill(model, prompt):
    x = model.embedding(prompt)
    for i, layer in enumerate(model.layers):
        q, k, v = layer.qkv(x)
        # Multi-Head Attention over the full prompt with a causal mask
        x += MHA(q, k, v, mask=causal)
        x += layer.ffn(x)
        # Insert the prompt's KV vectors into this layer's cache
        K_cache[i], V_cache[i] = k, v
    # Hidden state of the last prompt token -> logits for the first generated token
    return model.output(x[-1])

During prefill, the prompt is processed in parallel with causal attention. The KV vectors for the prompt are inserted into the KV cache at each layer. The prefill function returns the final hidden state of the last token, which will be used to generate the first token.

def decode(model, next_token, max_new_tokens):
    tokens = [next_token]
    for _ in range(max_new_tokens - 1):
        x = model.embedding(next_token)
        for i, layer in enumerate(model.layers):
            q, k, v = layer.qkv(x)
            # Append the new token's KV vectors to this layer's cache (concat along the sequence dimension)
            K_cache[i], V_cache[i] = concat(K_cache[i], k), concat(V_cache[i], v)
            # Attend from the single current query to the full cache (no mask needed)
            x += MHA(q, K_cache[i], V_cache[i], mask=None)
            x += layer.ffn(x)
        next_token = sample(model.output(x))
        tokens.append(next_token)
    return tokens

The decoder takes in the first token generated from the prefill stage and decodes one token at a time. Each new token’s KV vectors are inserted into the cache. Unlike in the prefill, attention is performed between the current query token and the entire KV cache.

We combine the two methods for generation:

prompt = "What is the capital of New York?"
logits = prefill(model, prompt)
next_token = sample(logits)
generated_tokens = decode(model, next_token, max_new_tokens=16)

The prefill stage is parallelizable and is primarily compute bound. This means that the input prompt can be processed all at once, in order to initialize the KV cache. On the other hand, the decoding stage is largely memory bound due to linear growth of the KV cache. This means that shrinking the size of the KV cache can grant us inference speedups at the decoding stage.

Due to its lack of parallelism, the decoding stage takes much longer than the prefill.

Why Compress the KV Cache?

As the context length increases, the size of the KV cache grows linearly. For long contexts and large batch sizes, the cache can far exceed the size of the model weights themselves. This explosive growth in memory footprint creates a cascade of challenges: it caps the maximum context length, potentially throttles inference speed, and can render state-of-the-art OSS models impractical on common hardware. KV Cache Compression mitigates these issues by reducing the size of the cache, allowing for more efficient long-context generation.
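
To make this concrete, here is a back-of-envelope sketch (using assumed Llama-3-8B-like shapes in fp16; the numbers are illustrative, not measured):

# Rough KV cache sizing with assumed Llama-3-8B-like shapes (illustrative only).
n_layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_val = 2  # fp16 / bf16

def kv_cache_bytes(seq_len, batch_size=1):
    # 2x for keys and values, stored at every decoder layer
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_val * seq_len * batch_size

weight_bytes = 8e9 * bytes_per_val  # ~8B parameters in fp16

for seq_len in (8_192, 65_536):
    cache = kv_cache_bytes(seq_len, batch_size=8)
    print(f"{seq_len:>6} tokens, batch 8: KV cache ~{cache / 2**30:.0f} GiB "
          f"vs. weights ~{weight_bytes / 2**30:.0f} GiB")
# At 64k tokens and batch 8, the cache (~64 GiB) is several times larger than the
# weights (~15 GiB), and all of it must be streamed from memory at every decoding step.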

One of the most commonly used methods for KV cache compression is token dropping.

What is Token Dropping?

Let’s illustrate the process of token dropping with a simplified diagram:

Fig 1. Illustration of token dropping in KV Cache compression.

In this example, the KV Cache is fixed to 7 tokens. Once the sequence length surpasses 7, tokens are evicted to make room for new ones, and new incoming tokens are inserted into the evicted token’s now-empty memory slot. This token dropping mechanism allows us to maintain a constant memory footprint for the KV cache regardless of the sequence length.

To maintain a fixed-size cache with token dropping, only two simple changes are needed:

Prefill

If the length of the prompt exceeds our fixed cache size, we first generate all KV vectors and then filter out |prompt| - cache size tokens before inserting the rest into the cache. We refer to this as Prompt Compression.

Decoding

Once the KV Cache has been filled, a single token is evicted at each timestep, freeing up space in the cache for the newly generated token. This is KV Cache Compression.

Prior research has focused more heavily on KV Cache Compression, i.e., compression at the decoding stage. However, both stages are crucial, especially in the common use case of long input text! We hope that our clear separation of these two steps helps clarify the landscape.
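
To tie the two stages together, here is a deliberately simplified sketch of a fixed-size, token-dropping cache (a toy illustration, not the Cold Compress API; the scoring function is a stand-in for the strategies described in Part 2):

import torch

class ToyDroppingCache:
    """Toy fixed-size KV cache with token dropping (illustration only)."""

    def __init__(self, cache_size, head_dim):
        self.cache_size = cache_size
        self.k = torch.zeros(cache_size, head_dim)
        self.v = torch.zeros(cache_size, head_dim)
        # Original position of the token held in each slot; -1 marks an empty slot
        self.pos = torch.full((cache_size,), -1)

    def prefill(self, k, v, scores):
        """Prompt Compression: keep only the highest-scoring prompt tokens.
        k, v: (prompt_len, head_dim); scores: (prompt_len,) importance scores."""
        keep = torch.topk(scores, min(k.shape[0], self.cache_size)).indices.sort().values
        self.k[: len(keep)], self.v[: len(keep)] = k[keep], v[keep]
        self.pos[: len(keep)] = keep

    def insert(self, k, v, input_pos, scores):
        """KV Cache Compression: evict one slot, then overwrite it with the new token.
        k, v: (head_dim,); input_pos: int; scores: (cache_size,) per-slot importance."""
        empty = (self.pos == -1).nonzero()
        slot = int(empty[0]) if len(empty) > 0 else int(scores.argmin())
        self.k[slot], self.v[slot], self.pos[slot] = k, v, input_pos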

The Impact on Attention

We’ve discussed how to implement compression, yet we haven’t explained how it affects the model. The best way to do this is to visualize the impact of compression on attention masking:

Fig 2. Illustration of how prompt compression and decoding eviction work to keep a fixed size cache.

The left hand side shows the attention map with green indicating un-masked attention, white indicating masked or future tokens, and red indicating dropped tokens.

This toy diagram shows a fixed cache size of 4 tokens. You can verify this by making sure the number of green (un-masked) cells in each row of the decoding phase is equal to 4.

Because the prompt has 6 tokens, we have to compress it by removing 2 tokens (“What’s” and “the”) before pre-filling the KV cache. Subsequently, during decoding, each newly generated token overwrites an existing slot in the KV cache, one at a time.

Token dropping is permanent in that once a token is evicted from the KV cache, it never reappears. This column-wise sparsity is represented by each vertical red bar being unbroken.

Part 2: The Cold Compress Toolkit

Cold Compress is designed to simplify and centralize SOTA research in KV Cache compression. It supports existing published methods with proper attribution. In the process, we have unified some key abstractions under generic umbrella categories to make navigating these methods easier (e.g., Heavy Hitter to cover {H2O, Scissorhands, PyramidKV, etc.}).

Cold Compress 1.0 is flexible, modular, and compilable, supporting many SOTA methods.

Cold Compress was designed to be lightweight and easily hackable, which is why we built on top of GPT-Fast rather than complex engines, e.g., vLLM. Yet, we maintain the impressive performance of GPT-Fast by ensuring that all KV cache operations are static and torch compilable. The end result is a toolkit that strikes a balance between simplicity and performance, making it both accessible to all and performant enough for experimentation.

Fig 3. Without compression, memory grows linearly with sequence length leading to throttled inference speeds.

On a single A100 GPU, we see 2-3x speed gains from using torch.compile on book completion. In addition, if we add KV cache compression (e.g., --cache_strategy heavy_hitter), the cache size is fixed at 4,096 tokens, which allows us to maintain 70+ toks/sec inference at 64k tokens.

Cold Compress also comes with its own evaluation harness, which supports a growing list of long context evals to assess performance regressions caused by cache compression, as well as support for “debugging” compression by tracking its impact on self-attention.

Scope

One way to delineate between the many types of KV Cache Compression is to categorize methods as either “training-free” or “training-required”.

Fig 4. Cold Compress 1.0 focuses on training free approaches for KV Cache compression.

Cold Compress supports…

For Cold Compress 1.0, we focus on “training-free” methods. These techniques can be applied to any model at inference time without the need for additional training.

The primary focus of Cold Compress 1.0 is Token Dropping, which involves evicting unimportant tokens from the cache.

Note

MQA and GQA can be added to a multi-headed attention (MHA) model via fine-tuning, but since most LLMs (Llama-3.1, Gemma 2) are trained with KV head-sharing, e.g., MQA or GQA, we support these compression methods by default and view them as “training-free” cache size reductions in such cases.

In addition to Token Dropping, Cold Compress 1.0 also supports quantization of the KV cache. By reducing the precision of the stored values to 8-bits or 4-bits, quantization significantly decreases the memory footprint of the KV cache. This can lead to substantial memory savings while maintaining comparable performance.

Why Focus on Token Dropping?

Token Dropping, with its simplicity and broad applicability, is a useful gateway to more complex methods. Token dropping induces attention sparsity at inference time and may offer insight into which sparsity patterns are most aligned with pre-trained models.

Cold Compress does not (currently) support…

Cold Compress plans to support methods that require model adaptation via fine-tuning.

Getting Involved

We’re excited about the potential of Cold Compress and we believe in the power of community-driven development. We’d love for you to get involved and help improve Cold Compress for future releases. Here’s how you can get involved:

  1. Join the Conversation: Participate in discussions of KV Cache Compression on our public discord channel. Share your ideas, ask questions, and collaborate with fellow researchers and developers.
  2. Report Issues: Found a bug or have a suggestion for improvement? Raise an issue on our GitHub repository. Your feedback is crucial for making Cold Compress better.
  3. Contribute Code: Have a great idea for a new feature or optimization? We welcome pull requests! Whether it’s a small fix or a major enhancement, your contributions can make a significant impact.
  4. Submit New Compression Methods: [Coming soon!] We’re working on a public leaderboard built with FastHTML. Soon, you’ll be able to submit your own compression methods and see how they stack up against existing techniques.

Getting Started

Tip

For detailed instructions to help you get started coding, please refer to the repo README.

The next section explains the methods implemented and connects them to the relevant bits of code.

Modular Architecture

  • Base Class: All shared cache logic is handled by a KVCache base class.
  • Custom Strategies: Each custom cache strategy only needs to specify an eviction policy.
  • Separated Logic: Logic for prompt compression (prefill) and cache evictions (decoding) is kept separate. The KVCache class handles the latter and PromptCompressor the former. Separation allows us to create static compilable functions for each. It also allows for users to mix-and-match prefill and decoding strategies, or focus on optimizing one part. For each KV Cache strategy, we implement an analogous one for prompt compression.
  • Command Line Flexibility: You can invoke cache configs from the command line, making it easy to tune hyper-parameters and mix-and-match strategies for each layer.

Strategies Supported

Beyond subtle nuances, existing compression methods differ primarily in how they identify important tokens. For ease of understanding, we group them into non-exclusive categories:

  1. Position-Based: Prioritizes initial and recent tokens.
  2. Attention-Based: Tokens that have received high attention should not be evicted.
  3. Attention-Proxy: KV Cache methods that require full attention scores cannot take advantage of efficient attention implementations, e.g., FlashAttention-2, which never materialize the full attention matrix. Recent work has looked into identifying correlates of high-attention tokens, such as the L2-norm of key vectors.
  4. Hybrid: Each decoder layer and each attention head has unique attention patterns. Ideal compression ratios may vary by layer, and ideal eviction strategies may vary across attention heads.

Full Cache (--cache_strategy full)

The Full Cache strategy serves as our baseline, without any compression. For this strategy, each new token is inserted into the first unfilled cache position.

class KVCacheFull(KVCacheHeadConstant):
    ...
    def _eviction_idx(self, input_pos):
        # Select the first unfilled slot
        return self.pos[0, 0].argmin().view(1)

self.pos stores the original positions of each token in the cache and is initialized with -1s.
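
A tiny illustration (hypothetical values) of why argmin over self.pos finds the first unfilled slot:

import torch

# A cache of size 6 that currently holds tokens from original positions 0-3
pos = torch.tensor([0, 1, 2, 3, -1, -1])
print(pos.argmin())  # tensor(4): -1 sorts below every real position, so the first empty slot wins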

Recent Global (--cache_strategy recent_global)

Research which introduces Recent Global KV Compression: Attention Sink, LM-Infinite, LongFormer.

It’s widely recognized that the first (“global”2) and most recent tokens (“local”) are worth saving.

Preserving recent tokens makes sense intuitively, but the importance of global tokens may seem less obvious. In practice, however, models often assign large attention weights to the first few tokens, using them as a “sink” for excess attention.

The attention pattern produced with “global” tokens and a sliding window is A-shaped.

Inheriting the base class methods, RecentGlobal (aka StreamingLLM) requires only 1 new LoC:

class KVCacheRecentGlobal(KVCacheHeadConstant):
    ...
    def _eviction_idx(self, input_pos):
        return (
            torch.argmin(self.pos[:, :, self.global_tokens :], dim=-1)
            + self.global_tokens
        ).view(1)

Random (--cache_strategy random)

Random extends RecentGlobal by keeping a random set of tokens in the middle, in addition to recent and global tokens.

Random provides a lower bound on attention-based eviction methods. All else equal, random selection should perform worse than importance-based selection.

To implement it, we override the _token_importances method to assign random scores:

class KVCacheRandom(KVCacheHeadConstant):
    ...
    def _token_importances(self, input_pos):
        # Assign random importance
        scores = torch.rand(
            self.max_cache_length, device=input_pos.device
        )

        # Protect Recent Tokens
        recent_mask = self.pos[0, 0] >= input_pos - self.recent_window
        scores[recent_mask] = float("inf")
        return scores
Tip

For this strategy, we implemented a _token_importances function rather than an _eviction_idx function. Either is valid; _token_importances is more suitable for methods that define token-level importance scores.
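
For intuition, the relationship between the two hooks could look roughly like this (a hypothetical sketch of the glue; the actual base-class plumbing in Cold Compress may differ):

# Hypothetical glue between the two hooks (not the actual Cold Compress implementation).
def _eviction_idx(self, input_pos):
    scores = self._token_importances(input_pos)
    # Evict the slot with the lowest importance; protected tokens carry a score of +inf
    return scores.argmin(dim=-1).view(1)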

Heavy Hitter (--cache_strategy heavy_hitter)

Research which introduces / uses Heavy Hitters: Scissorhands, H2O, PyramidKV, SnapKV

A popular class of methods uses historical attention scores to identify important, i.e., “Heavy Hitter”, tokens: those that are more heavily or frequently attended to.

These papers rely on the empirical finding that tokens which have been important tend to remain important, i.e., “the persistence of importance”:

Fig 6. Figure from Scissorhands

Apart from subtle nuances, each Heavy Hitter method follows a similar blueprint:

1. Record attention scores at each decoding timestep

2. Evict tokens with the lowest accumulated score

To perform step 1, we write an optional update_state function for our KVCacheHeavyHitter class. For step 2, we rely on these cached attention histories:

class KVCacheHeavyHitter(KVCacheHeadSpecific):
    ...
    def _eviction_idx(self, input_pos):
        # Computes average historical attention
        numerator = self.attn_history_num.sum(dim=-1).float()
        denominator = self.attn_history_denom.clamp(1, self.history_window_size)
        avg_attn = numerator / denominator

        # Save the global & most recent tokens from being evicted
        avg_attn.masked_fill_(
            torch.logical_or(
                self.pos < self.global_tokens,
                self.pos >= input_pos - self.recent_window,
            ),
            1.0,
        )
        avg_attn.masked_fill_(self.pos == -1, 0.0)

        fill_idx = avg_attn.argmin(dim=-1).squeeze()
        ...
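
The bookkeeping for step 1 lives in update_state. A hypothetical version could keep a rolling window of recent attention scores for every cached token (attn_history_num and attn_history_denom follow the snippet above; the rolling-buffer layout and the update_counter field are assumptions):

# Hypothetical sketch of the attention bookkeeping (not the actual implementation).
def update_state(self, attn):
    # attn: attention probabilities of the newest query over the cache,
    # shape (batch, n_kv_heads, cache_length)
    slot = self.update_counter % self.history_window_size
    # Overwrite the oldest column of the rolling history buffer
    self.attn_history_num[..., slot] = attn
    # Track how many observations each cached token has accumulated
    self.attn_history_denom += 1
    self.update_counter += 1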

Evicting the tokens with the lowest historical attention makes sense for the decoding phase. Yet, if we need to compress the prompt during prefill, we don’t have any history over which to accumulate attention scores. To get around this issue, for Prompt Compression with heavy hitters, we rely on SnapKV (implemented as PromptCompressorHeavyHitter), which divides the prompt into a prefix and an observation window:

Fig 7. SnapKV forms an observation window over which to compute Heavy Hitters for prompt compression.

SnapKV computes cumulative attention scores over the prefix (“This is a very long”) from the observation window (“prompt .”) when deciding which tokens to drop. We keep the observation window and the prefix tokens with the highest cumulative attention.
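
A minimal sketch of the idea (single head, no pooling or protected tokens; the actual PromptCompressorHeavyHitter handles these details):

import torch

def snapkv_style_keep(q, k, obs_window, n_keep):
    """Return indices of prompt tokens to keep: the observation window plus the
    prefix tokens that receive the most attention from it (simplified sketch)."""
    # q, k: (prompt_len, head_dim) for a single head
    prefix_len = q.shape[0] - obs_window
    scale = k.shape[-1] ** -0.5
    # Attention of the observation-window queries over the prefix keys
    attn = torch.softmax(q[prefix_len:] @ k[:prefix_len].T * scale, dim=-1)  # (obs, prefix)
    cum_scores = attn.sum(dim=0)                                             # (prefix,)
    prefix_keep = torch.topk(cum_scores, min(n_keep, prefix_len)).indices.sort().values
    obs_idx = torch.arange(prefix_len, q.shape[0])
    return torch.cat([prefix_keep, obs_idx])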

Hybrid (--cache_strategy hybrid)

Research which introduces Hybrid Strategies: FastGen, MInference.

The above methods implement the same strategy for each attention head.

Yet, it’s known that attention heads perform different functions and focus on different things.

The FastGen paper analyzes attention and identifies 5 types of heads, each focusing on a different mix of tokens:

  1. Special tokens
  2. + punctuation
  3. + recent tokens
  4. + heavy hitters
  5. Full

We can potentially get more bang for our buck by specializing eviction strategies for each attention head. To support a heavily customizable FastGen, we have created a KVCacheHybrid class, which supports any user-defined combination of attention head strategies.

Strategies can be chained together (logical or) by concatenating the string names: e.g., “window_heavy_hitter” will protect heavy hitter tokens as well as recent tokens. Each type of attention head can be specified with its own hyper-parameters, as shown in this config file.
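
Conceptually, chaining amounts to a union of each strategy’s keep-mask, roughly like the hypothetical sketch below (the real KVCacheHybrid parsing and masking logic differs in detail):

import torch

# Hypothetical illustration of chaining strategies (not the actual KVCacheHybrid code).
def combined_keep_mask(head_strategy: str, masks: dict) -> torch.Tensor:
    """Union (logical or) of the keep-masks of every strategy named in `head_strategy`,
    e.g. "window_heavy_hitter" protects recent tokens OR heavy hitters."""
    keep = torch.zeros_like(next(iter(masks.values())), dtype=torch.bool)
    for name, mask in masks.items():
        if name in head_strategy:
            keep |= mask
    return keep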

Attention Head Profiling

How can we optimally match head types with each attention head?

FastGen offers a neat solution! They first introduce a term called “attention loss”:

Fig 8. Depiction of “Attention Loss” from FastGen.

Attention loss is the sum of the attention probabilities for the evicted or excluded tokens.

Each attention head is assigned to the strategy which satisfies a user-defined minimum recovery ratio (1 - attention loss) at the lowest memory cost. As in FastGen, we perform this logic separately for each prompt during the prefill stage (although it can be done offline).
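
A simplified, hypothetical sketch of that selection logic (the shapes, strategy list, and default threshold are assumptions; the Cold Compress version lives alongside KVCacheHybrid and differs in detail):

# Hypothetical sketch of FastGen-style head profiling (not the actual implementation).
def profile_head(attn, strategies, min_recovery=0.9):
    """attn: (prompt_len,) attention probabilities for one head during prefill.
    strategies: list of (name, keep_mask, cost) tuples.
    Returns the cheapest strategy whose recovery ratio (1 - attention loss) clears the bar."""
    for name, keep_mask, cost in sorted(strategies, key=lambda s: s[2]):
        attention_loss = attn[~keep_mask].sum()  # probability mass on the evicted tokens
        if 1.0 - attention_loss >= min_recovery:
            return name
    return "full"  # fall back to no compression if nothing meets the recovery ratio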

L2-Norm (--cache_strategy l2)

Research which introduces L2 Compression: L2-Norm Compression

FlashAttention-2, like most optimized attention implementations, does not materialize the full attention map. As such, attention-based eviction methods may not be practical. A natural question arises: can we choose good eviction strategies in advance without observing the full attention map? A recent paper linked the L2-norm of a token’s key vector to the attention it receives:

class KVCacheL2(KVCacheHeadSpecific):
    ...
    def _token_importances(self, input_pos):
        # 1. Lowest l2 norms have high importance (- self.key_norm)
        # 2. Lowest score needs to be > -1 : we evict unfilled tokens first (+ max value such that min score is 0)
        # 3. Save Recent Window (+ inf)
        return (
            (self.key_norm.max() - self.key_norm)
            .masked_fill(
                self.pos >= input_pos - self.recent_window,
                float("inf")
            ).squeeze(0)
        )

KVCacheL2 prioritizes recent tokens, followed by tokens whose key vectors have low L2-norms.

Benchmarking

Tasks

Cold Compress includes an evaluation harness to allow for reproducible benchmarking across a growing list of long-context tasks, including summarization, domain-specific tasks, coding, question answering, and synthetic retrieval.

Llama-3 8B

We use Llama-3 8B Instruct for some light benchmarking, which can be replicated here.

The full set of results can be found in our repository.

Overall, we find performance tends to decline as compression ratio increases, though the rate of decline varies significantly between different methods and tasks.

Tip

As noted in most prior work, there’s usually a task-specific sweet spot with the compression ratio (~0.4-0.6) in which there’s minimal performance degradation, but significant compression is achieved.

Let’s dive deeper into some of the more intriguing insights from our benchmarks:

Heavy Hitter: The consistent performer

Fig 9. Heavy Hitter tends to consistently outperform other token dropping methods using Llama-3-8B.

Heavy Hitter tends to perform well across different tasks, significantly outperforming other methods. The gains are particularly pronounced at high compression ratios. Our benchmarks on Musique and RulerQA illustrate this trend clearly. While most methods experience a sharper performance decline beyond 50% compression, Heavy Hitter maintains its effectiveness, showing minimal degradation even at extreme compression ratios (>90%).

Note

This pattern suggests that tokens receiving high attention scores in the past are likely to be relevant in future computations, and prioritizing the retention of these tokens is helpful.

Simplicity vs. Complexity

While more complex strategies often promise better performance, our benchmarks reveal some surprising insights about the trade-offs between simple and sophisticated approaches.

Fig 10. Recent Global performs well given its simplicity, often getting close to the performance of Heavy Hitter.

The Recent Global approach performs reasonably well across the different tasks, given its simplicity. Conversely, the Hybrid strategy, which combines multiple approaches for each attention head (including Heavy Hitter, Recent Global, and Full, i.e., no compression), does not consistently outperform the Heavy Hitter-only approach.

Note

The complexity of Hybrid approaches, which require hyper-parameter tuning for each variant, may hinder their out-of-the-box performance. We encourage researchers to experiment with novel combinations using the KVCacheHybrid class, and even explore fine-tuning with hybrid sparsity.

Remarkably, our results show that Random, which randomly drops 20-30% of tokens, leads to only minimal performance degradation across most tasks. In some cases, such as the SQuALITY figure shown above, we can drop as much as 50% of the tokens with minimal degradation. This showcases the remarkable ability of LLMs to infer missing information from context on certain tasks, effectively filling in the gaps based on surrounding tokens.

Not all Tasks are Equally Compressible

Fig 11. Tasks requiring broader contextual understanding, like QMSum, are easier to compress than tasks like RulerNIAH, which require retrieving specific information.

We also observe task-dependent variations in compression tolerance. Tasks requiring broader contextual understanding, such as summarization (e.g., QMSum and SQuALITY), maintain high performance even at 60-70% compression ratios. In contrast, tasks relying on specific information retrieval, like the “needle in a haystack” task RulerNIAH, experience sharp performance drops if compression does not exploit attention patterns, e.g., Heavy Hitter.

Note

There is a complex interplay between the nature of the task, compression strategies, and model performance, which calls for task-aware optimization of context compression in LLMs.

Qwen 2 7B

To get a sense of robustness, we perform the same experiments on Qwen-2 7B Instruct.

Fig 12. Heavy Hitter’s performance degrades more on Qwen2 than Llama-3 due to architectural differences.

Interestingly, Heavy Hitter is significantly worse than the full KV Cache, especially for RULER synthetic retrieval. We hypothesize that this disparity can be attributed to Qwen-2 having \(\frac{1}{2}\) the number of KV heads of Llama-3 (4 vs. 8). Given that Heavy Hitter relies heavily on specialization of attention heads, this difference has a notable impact on its performance.

Note

Heavy Hitter compression may be less beneficial for architectures employing increased KV head sharing across query heads, e.g., GQA. These findings highlight the importance of considering model architecture when selecting cache compression methods.

Quantization

In addition to the token dropping strategies we’ve discussed, our benchmarks also included quantization as a compression method. In particular, we experimented with 8-bit and 4-bit quantization of the KV Cache, which results in 50% and 75% compression, respectively.
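
As a rough illustration of where those savings come from (a symmetric per-token int8 sketch with assumed shapes, not the quantization code used in Cold Compress):

import torch

def quantize_int8(x, dim=-1):
    """Symmetric per-token int8 quantization of a KV tensor (sketch only)."""
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize(q, scale):
    return q.float() * scale

k = torch.randn(1024, 8, 128)  # assumed (seq_len, n_kv_heads, head_dim)
q, scale = quantize_int8(k)
err = (dequantize(q, scale) - k).abs().mean().item()
# Storing int8 instead of fp16 halves the cache (50% compression); 4-bit would cut it by 75%.
print(f"mean abs error: {err:.4f}; {k.numel() * 2} bytes (fp16) -> {q.numel()} bytes (int8)")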

Our findings reveal that KV Cache quantization performs consistently well across various tasks, with minimal performance degradation even at 4-bit quantization (75% compression).

Fig 13. Quantization consistently performs well across the different tasks, with minimal performance degradation relative to the full, unquantized KV Cache.

This consistency is exemplified by its performance on domain-specific reasoning (Dolomites). While all token dropping methods experience performance drops beyond 40% compression, 4-bit quantization manages to match the performance of the unquantized, full KV Cache.

On the other hand, quantization alone has limited compression potential. Although we support 2-bit quantization, naively quantizing the KV Cache to such low precision leads to dramatic drops in performance, with the model generating gibberish. In contrast, token dropping methods like Heavy Hitter can remain effective on certain tasks at much higher compression levels, as demonstrated in the RulerQA results above.

Note

These quantization results suggest that reducing precision, rather than discarding tokens, can be a remarkably effective approach to managing the memory footprint of large language models during inference. The robust performance of quantization across our diverse suite of tasks underscores its potential as a valuable alternative or complement to token dropping strategies.

Looking ahead, we plan to support mixed precision caches in the next release of Cold Compress. This enhancement will allow for higher compression ratios while mitigating the drawbacks of extreme low-bit quantization. Additionally, hybrid compression strategies—quantization + selective token dropping—may allow for even higher levels of compression.

Final Thoughts

Beyond the raw numbers, it can be useful to share more high-level insights:

Token Dropping isn’t One-Size-Fits-All

The strategy of permanently discarding certain tokens intuitively seems risky, particularly for extended inference scenarios such as open-ended chat bots. In these contexts, topics can shift abruptly, potentially rendering previously discarded information crucial.

This input sensitivity likely explains the varying degrees of accuracy degradation observed. Moreover, token dropping exhibits significant performance deterioration over long sequences. To quantify this effect, we evaluated perplexity at various generation sequence lengths using a PG-19 book corpus completion task (--tasks pg19):

Fig 14. “Attention Loss” increases as a function of the number of generation steps.

This chart illustrates the difference in Perplexity compared to no compression for 3 different compression ratios: low (25%), medium (50%), and high (75%). Initially, there is a minimal increase in PPL over no compression for up to 1,000 decoding steps. However, beyond this point, the gap starts to widen steadily. High compression ratio (75%) exhibits the steepest slope. We include the previously discussed Attention Loss, which tracks the sum of the attention probabilities for discarded tokens. As we evict more tokens, both the attention loss and the performance gap (PPL Δ) rise correspondingly.

Note

These findings underscore the importance of exploring recoverable sparsity: mechanisms that compress the KV cache without permanently discarding tokens.

Recoverable evictions can be achieved by moving unimportant tokens into a low-rank cache, offloading them to CPU, or quantizing them. Leveraging hardware memory hierarchies is promising, though it will require a precise retrieval mechanism to minimize cache misses.

Given the risks of assigning non-uniform importance to prior states, you can also learn compressive memories (Infini-attention, GIST, LoMA, Multi-head Latent Attention (MLA)) or, in the case of State Space Models (SSMs), do away with the KV Cache entirely in favor of linear recurrence. However, these modifications are still undergoing rigorous testing and validation.

Untuned or Unfit?

Token dropping KV cache compression methods are zero-shot: they force a model to perform sparse attention for the very first time at inference. This may push models out of distribution relative to their training setup.

For many of the methods we have explored, a key question arises:

Are they fundamentally flawed (unfit) or simply unadapted (untuned)?

Some methods might benefit from additional hyper-parameter tweaks or fine-tuning.

One such area is position embeddings. GPT-Fast, and Cold Compress by extension, caches KV values with their original positions by performing RoPE rotations before inserting them into the cache. It is unclear, however, whether RoPE embeddings should instead be based on a token’s relative position within the cache. Adapting positional embeddings for KV cache compression could help narrow the gap with uncompressed models, warranting further investigation.
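
To make the distinction concrete, here is a minimal RoPE sketch (a hypothetical helper, not GPT-Fast’s implementation) applied with each choice of position:

import torch

def apply_rope(x, positions, base=10000.0):
    """Rotate channel pairs of x (seq, n_heads, head_dim) by position-dependent angles (sketch)."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    angles = positions.float()[:, None] * inv_freq                 # (seq, d/2)
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

k = torch.randn(4, 8, 128)                     # 4 surviving tokens, 8 KV heads, head_dim 128
original_pos = torch.tensor([0, 1, 500, 501])  # positions in the full sequence
relative_pos = torch.arange(4)                 # positions within the compressed cache

k_original = apply_rope(k, original_pos)       # what GPT-Fast (and Cold Compress) does today
k_relative = apply_rope(k, relative_pos)       # the alternative discussed above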

Re-Thinking How to Compute Heavy Hitters

Heavy Hitter methods rely on historical attention scores to prioritize evictions.

Yet, many efficient attention implementations, such as Flash Attention, intentionally avoid materializing the full attention map. While GPT-Fast experiences no slowdown due to its reliance on torch.compile, Heavy Hitters may not be feasible for inference engines like vLLM.

It may be beneficial to decouple QKV self-attention from token salience detection, potentially learning smaller, separate parameters for the latter.

Additionally, Heavy Hitter methods perform worse on models with more aggressive GQA (Qwen-2 versus Llama-3) because these methods rely on attention head specialization. As more LLMs adopt reduced KV caches through techniques like GQA, local-global attention, or state tying, further compression at inference time becomes increasingly challenging.

Working at the intersection of pre-training architectures, fine-tuning, and post-training modifications will be a focus of the next release of Cold Compress!

Acknowledgments

We would like to extend a special thank you to our collaborators at EleutherAI and Meta.

At EleutherAI, we want to call out the wildly talented Hailey Schoelkopf for her ongoing, savvy guidance on this project.

At Meta, we want to thank the brilliant creators of GPT-Fast within the PyTorch Team for creating such a performant, yet easy to hack, inference engine in native PyTorch. It is a tremendous piece of open-source software and we are thrilled to be able to showcase it!

In particular, we want to call out Yanbo Liang and Horace He, who have been invaluable advisors in helping us navigate the quirks of torch.compile. We are grateful for their continued interest in Cold Compress, and are excited to see them showcase our work at the PyTorch conference in September!

Citation

@misc{cold-compress-2024,
  title={Cold Compress: A Toolkit for Benchmarking KV Cache Compression Approaches},
  author={Adams, Griffin and Ladhak, Faisal and Schoelkopf, Hailey and Biswas, Raja},
  month=8,
  year=2024,
  version={v1.0},
  url={https://www.answer.ai/posts/2024-08-01-cold-compress.html}
}


Footnotes

  1. Certain details, e.g., layer norm, have been intentionally removed for simplicity.↩︎

  2. Here, “global” is an homage to the LongFormer, which referred to lead tokens as global because they were un-masked during self-attention and thus had “global” reach.↩︎