While TopK effectively captures the ranking of tokens, it falls short in estimating the full distribution of attention scores. When the attention distribution is not highly concentrated, significant estimation errors can arise. Rather than relying solely on the keys and values with the highest scores, incorporating knowledge of the full distribution can improve accuracy. We frame this as a bias-correction problem in sampling. Unbiased and efficient sampling techniques, extensively studied in fields such as biology, sociology, and machine learning, serve as the foundation for our approach.
Instead of keeping only the keys and values with the largest attention scores (TopK), the attention output can be estimated by sampling tokens from the attention distribution itself and averaging their value vectors; we refer to this idealized procedure, which assumes access to the exact attention scores, as oracle sampling.
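To make this concrete, here is a minimal sketch in PyTorch (single query, single head; the function name `oracle_sampling_attention` and the random toy data are illustrative, not part of the original text). Since the attention output o = Σ_i a_i v_i is the expectation of the value vectors under the attention distribution a, averaging value vectors drawn from a is an unbiased estimate of o:

```python
import torch

def oracle_sampling_attention(q, K, V, budget, generator=None):
    """Estimate o = sum_i softmax(qK^T / sqrt(d))_i * v_i by drawing `budget`
    token indices from the exact attention distribution (the "oracle" part)
    and averaging the corresponding value vectors."""
    scores = (q @ K.T) / K.shape[-1] ** 0.5            # (n,) exact scores
    probs = torch.softmax(scores, dim=-1)               # (n,) attention distribution
    # o = E_{i ~ probs}[v_i], so a Monte Carlo average over sampled indices
    # (with replacement) is an unbiased estimator of the attention output.
    idx = torch.multinomial(probs, budget, replacement=True, generator=generator)
    return V[idx].mean(dim=0)                            # (d,) estimated output

# Illustrative comparison against exact attention and TopK truncation.
torch.manual_seed(0)
n, d, budget = 16384, 128, 256                           # 16k-token toy context
q, K, V = torch.randn(d), torch.randn(n, d), torch.randn(n, d)

probs = torch.softmax(q @ K.T / d ** 0.5, dim=-1)
exact = probs @ V                                        # full attention output

est = oracle_sampling_attention(q, K, V, budget)

topk = torch.topk(probs, budget).indices
topk_probs = probs[topk] / probs[topk].sum()             # renormalize over kept tokens
topk_out = topk_probs @ V[topk]

print("oracle sampling error:", torch.norm(est - exact).item())
print("TopK attention error: ", torch.norm(topk_out - exact).item())
```

The sketch uses sampling with replacement, which is also why, at very large budgets, oracle sampling repeatedly draws the same high-probability tokens while TopK keeps adding new ones (see the figure discussion below).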
Left and Middle: Oracle sampling estimation significantly reduces numerical error compared to TopK attention. The evaluated context size is 16k. The x-axis shows the sampling budget for oracle sampling and the computation budget for TopK attention. Note that the estimation error of TopK attention drops below that of oracle sampling once the budget is sufficiently large (around 12k in the figures): oracle sampling repeatedly draws the same subset of high-probability tokens, whereas TopK does not. Right: Downstream comparison between oracle sampling estimation and TopK attention. The x-axis for both methods is the computation budget ratio, i.e., the fraction of selected/sampled tokens.
Oracle sampling performs significantly better than TopK attention in both estimation error and downstream accuracy. It maintains full accuracy on reasoning tasks with merely 0.005% of the keys and values sampled, a budget far below what TopK needs to remain effective.
However, just like TopK, oracle sampling requires access to the exact attention scores, which is impractical: obtaining them means computing full attention in the first place. In the next section, we introduce how MagicPIG addresses this problem via self-normalized importance sampling and locality-sensitive hashing.