MagicPIG leverages sampling for efficient attention estimation, using CPUs as augmented memory (V100 + MagicPIG > A100), and, unlike TopK attention, retains the ability to handle complex tasks such as long-context reasoning.
MagicPIG improves decoding throughput by 1.76-4.99x across various regimes compared with GPU-only attention. Here, we use L20 GPUs paired with Intel 8563C CPUs to simulate different VRAM sizes.
MagicPIG achieves higher downstream accuracy in retrieval and reasoning tasks than Quest, the state-of-the-art baseline, while using less computation. K10L150 and K11L300 correspond to 2% computation cost, while K10L170 corresponds to 2.5% (K and L denote MagicPIG's LSH parameters: hash bits per table and number of hash tables).
Our project is in active development! See QA and Future work for more information!
Serving long-context LLMs is challenging due to a unique bottleneck in auto-regressive generation: the key-value (KV) cache, which stores intermediate attention keys and values to avoid re-computation.
Specifically, the KV cache grows linearly with batch size and sequence length, occupying substantial GPU memory and limiting the maximum batch size, which leads to underutilization of GPU computational power. For instance, an NVIDIA A100-40GB GPU can handle only a single request for Llama-3.1-8B with a 128k context length, with nearly half of the decoding time spent accessing the KV cache, leaving the GPU poorly utilized. Inference-time strategies such as Best-of-N and long chain-of-thought exacerbate the situation by increasing the number of generated tokens.
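For a rough sense of scale, the back-of-the-envelope estimate below (assuming the published Llama-3.1-8B configuration: 32 layers, 8 KV heads under grouped-query attention, head dimension 128, fp16 storage) shows why a single 128k-token request nearly exhausts an A100-40GB:

```python
# Back-of-the-envelope KV cache size for Llama-3.1-8B (assumed config:
# 32 layers, 8 KV heads with GQA, head_dim 128, fp16 storage).
num_layers   = 32
num_kv_heads = 8
head_dim     = 128
bytes_per_el = 2           # fp16
seq_len      = 128 * 1024  # 128k context
batch_size   = 1

kv_bytes = 2 * num_layers * num_kv_heads * head_dim * bytes_per_el * seq_len * batch_size
print(f"KV cache: {kv_bytes / 2**30:.1f} GiB")   # ~16 GiB for one 128k-token request

weights_bytes = 8e9 * bytes_per_el               # ~16 GB of fp16 weights
print(f"Weights + KV: {(weights_bytes + kv_bytes) / 1e9:.0f} GB")  # ~33 GB, close to the 40 GB limit
```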
Because attention mechanisms are widely recognized to be inherently sparse, dynamic sparse attention and TopK-based approximations have been extensively explored. However, these methods often suffer from significant quality degradation. Existing KV cache compression techniques, such as Quest [3], H2O [4], and Loki, primarily focus on identifying the subset of the KV cache that yields the highest attention scores. Despite their efficiency, TopK-based attention remains a biased approximation and lacks theoretical guarantees.
The figure on the right demonstrates that even exact TopK attention leads to substantial estimation errors and downstream task performance degradation. This issue becomes even more pronounced in complex tasks that require high context utilization, such as aggregation tasks (e.g., common word extraction (CWE) and frequent word extraction (FWE)) and reasoning tasks, where the degradation caused by TopK-based approximations is particularly severe. We discuss this in The failure of TopK attention.
To tackle this problem, we leverage sampling instead of searching for the TopK entries of the KV cache.
Instead of relying solely on the keys and values with the highest attention scores, incorporating information about the underlying distribution can significantly improve estimation accuracy. We approach this as a bias-correction problem in sampling. Unbiased and efficient sampling techniques have been extensively studied in fields such as biology, sociology, and machine learning, and they come with strong theoretical guarantees. The figure on the right highlights that sampling values in proportion to their attention scores, referred to as oracle sampling, yields a much lower estimation error, up to 4x smaller than that of the naive TopK selection approach. The sections How sampling can help attention approximation? and Importance sampling, Locality Sensitive hashing, and MagicPIG discuss oracle sampling and how MagicPIG approximates it with hashing algorithms.
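As a toy illustration of the difference between the two estimators (synthetic Gaussian data, not the benchmark behind the figure), the sketch below contrasts TopK attention, which renormalizes over the largest scores, with oracle sampling, which draws keys in proportion to their attention scores and averages the corresponding values:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, k = 128, 4096, 128           # head dim, context length, budget (~3% of n)

q = rng.normal(size=d)
K = rng.normal(size=(n, d))
V = rng.normal(size=(n, d))
p = np.exp(q @ K.T / np.sqrt(d)); p /= p.sum()    # attention scores
out = p @ V                                       # exact attention output

# TopK attention: keep the k largest scores and renormalize (biased).
top = np.argsort(p)[-k:]
out_topk = (p[top] / p[top].sum()) @ V[top]

# Oracle sampling: draw k indices with probability p_i; the sample mean of
# the drawn values is an unbiased estimate of sum_i p_i * v_i.
idx = rng.choice(n, size=k, p=p)
out_sample = V[idx].mean(axis=0)

err = lambda x: np.linalg.norm(x - out) / np.linalg.norm(out)
print(f"TopK error:     {err(out_topk):.3f}")
print(f"Sampling error: {err(out_sample):.3f}")
```

The sampling estimate is unbiased because E[v_I] = sum_i p_i v_i when the index I is drawn with probability p_i, whereas TopK systematically discards the probability mass outside the selected set.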
In addition to the accuracy degradation challenge, limited GPU memory capacity restricts the applicability of existing dynamic KV cache compression methods (e.g., Quest [3], Loki) in many scenarios. At the same time, approaches like DeepSpeed-Zero-Inference [1] and FastDecode [2] demonstrate the potential of offloading the KV cache and attention computation to CPUs, utilizing CPU memory bandwidth, which is typically only about 10-20% of a GPU's.
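To make the bandwidth gap concrete, here is a rough estimate of the per-token cost of streaming a full 128k-token KV cache; the bandwidth figures are ballpark assumptions for illustration, not measurements:

```python
# Per-token attention time ~= KV bytes read / memory bandwidth.
# Assumed bandwidths: A100 HBM ~1.9 TB/s, a multi-channel DDR5 server ~0.3 TB/s.
kv_bytes = 16 * 2**30                # ~16 GiB KV cache (from the estimate above)
gpu_bw, cpu_bw = 1.9e12, 0.3e12      # bytes/s (assumptions)

t_gpu = kv_bytes / gpu_bw
t_cpu = kv_bytes / cpu_bw
print(f"GPU: {t_gpu*1e3:.0f} ms/token, CPU: {t_cpu*1e3:.0f} ms/token "
      f"({t_cpu/t_gpu:.1f}x slower)")
# Touching only ~10% of the KV cache on the CPU would roughly close this gap.
```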
This raises a natural question:
Can we reduce memory access in attention computation by 10x without compromising accuracy?
By leveraging sampling algorithms, such as the LSH-based sampling used in MagicPIG for attention estimation, we substantially reduce memory access. This effectively simulates an improvement in CPU memory bandwidth, enabling efficient attention computation without accuracy degradation. System Codesign illustrates the system design of MagicPIG, including how attention is computed and how the hashing computation is carried out.
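To make this more concrete, below is a minimal, self-contained sketch of the idea: SimHash tables built over the keys, query-time retrieval of colliding keys, and importance weights derived from the analytic collision probability. The parameters and the "collide in at least one table" rule are illustrative assumptions, not MagicPIG's exact kernels or estimator; those are covered in Importance sampling, Locality Sensitive hashing, and MagicPIG and System Codesign.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 128, 4096
K_BITS, L_TABLES = 8, 32             # hypothetical LSH configuration (cf. the K..L.. naming above)

q      = rng.normal(size=d)
keys   = rng.normal(size=(n, d))
values = rng.normal(size=(n, d))

# SimHash: K_BITS * L_TABLES random hyperplanes; each table stores a K_BITS-bit code.
planes      = rng.normal(size=(L_TABLES, K_BITS, d))
key_codes   = (np.einsum('lkd,nd->lnk', planes, keys) > 0)   # (L, n, K) bits
query_codes = (np.einsum('lkd,d->lk',  planes, q)    > 0)    # (L, K) bits

# Retrieve keys whose code matches the query's in at least one table.
collide = (key_codes == query_codes[:, None, :]).all(axis=2).any(axis=0)
idx = np.nonzero(collide)[0]

# Importance-weight each retrieved key by its analytic collision probability,
# then self-normalize -- a sketch of sampling-based attention, not MagicPIG's exact kernel.
cos   = keys[idx] @ q / (np.linalg.norm(keys[idx], axis=1) * np.linalg.norm(q))
theta = np.arccos(np.clip(cos, -1, 1))
p_bit = 1 - theta / np.pi                          # per-hyperplane agreement probability
p_col = 1 - (1 - p_bit**K_BITS)**L_TABLES          # P(collide in >=1 of L tables)

w = np.exp(keys[idx] @ q / np.sqrt(d)) / p_col
output = (w / w.sum()) @ values[idx]
print(f"retrieved {idx.size}/{n} keys")
```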
While deploying sampling-based estimation in attention shows great promise, several challenges remain. First, it is unclear how reducing attention estimation errors directly impacts downstream task performance. Second, efficient sampling requires modeling the attention score distribution, but inferring the distribution parameters can be computationally expensive. Third, fully leveraging the computational resources of modern hardware, such as GPUs and CPUs, while integrating a theoretically efficient sampling algorithm is highly non-trivial. Balancing accuracy, computational cost, and hardware efficiency is essential for practical deployment.
To answer these questions and introduce the MagicPIG solutions, our blog is structured as follows:
(1) The failure of TopK attention.
(2) How sampling can help attention approximation?
(3) Importance sampling, Locality Sensitive hashing, and MagicPIG
(4) System Codesign
[1] DeepSpeed-Zero-Inference https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md
[2] FastDecode: High-Throughput GPU-Efficient LLM Serving using Heterogeneous Pipelines https://arxiv.org/abs/2403.11421
[3] Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference https://arxiv.org/abs/2406.10774
[4] H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models https://arxiv.org/abs/2306.14048