How do we teach LLMs to learn what they can forget? A technical breakdown of the Self-Pruned KV Attention mechanism with fun ML techniques ⤵️
The SP-KV mechanism works by projecting a token's hidden state to a scalar and passing it through a sigmoid, giving a "future utility" score for every attention KV pair. At inference, key-value pairs under a certain utility threshold get discarded, freeing up KV cache memory (see the paper thread for details).
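To make the inference side concrete, here is a minimal sketch in plain Python. It assumes the utility predictor is a single linear projection followed by a sigmoid; the names `utility` and `prune_kv_cache`, the weights, and the toy numbers are illustrative assumptions, not the paper's code.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def utility(hidden, w, b):
    """Scalar utility in (0, 1): sigmoid of a linear projection of the hidden state."""
    return sigmoid(sum(h_i * w_i for h_i, w_i in zip(hidden, w)) + b)

def prune_kv_cache(keys, values, utilities, threshold=0.5):
    """Keep only the KV pairs whose utility reaches the threshold."""
    kept = [i for i, u in enumerate(utilities) if u >= threshold]
    return [keys[i] for i in kept], [values[i] for i in kept], kept

# Toy example: three tokens, 2-dimensional hidden states (made-up numbers).
hiddens = [[2.0, 0.0], [-2.0, 0.0], [0.5, 0.5]]
w, b = [1.0, 1.0], 0.0  # hypothetical predictor weights
utils = [utility(h, w, b) for h in hiddens]

keys, values = ["k0", "k1", "k2"], ["v0", "v1", "v2"]
kept_keys, kept_values, kept_idx = prune_kv_cache(keys, values, utils)
```

Here the second token scores below 0.5, so its KV pair is dropped from the cache while the other two are kept.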
Training is a lot more involved! ML experts might recognize that conditionally using a KV pair or not based on the utility value is a non-differentiable if/else through which gradients cannot flow. We approach this problem in two ways:
First, during training, we model the SP-KV mechanism as an attention mask bias. A KV utility value of 0 gets an attention mask bias of -infinity, preventing the key from being used in the attention computation (like the causal mask). A KV utility of 1 gets a bias of 0, enabling full use. Instead of binarizing the utility like at inference, we set the bias to log(utility), which lets us softly interpolate between the two extremes. This enables full gradient flow: both signals, "should this KV be used more" and "should this KV be used less", are preserved. By optimizing next-token prediction (NTP), the utility predictor implicitly learns to distinguish KVs that should be memorized from the others, and WITHOUT OTHER INCENTIVES, the model learns to use only a fraction of KVs for long range attention. This is basically an extra degree of freedom in attention that improves the loss with respect to standard attention.
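A small sketch of soft gating for a single query, in plain Python with scalar values to keep it short. Adding log(utility) to a score before the softmax is the same as multiplying that key's unnormalized weight by its utility. The function name and the toy numbers are assumptions for illustration.

```python
import math

def soft_gated_attention(scores, values, utilities):
    """Softmax attention for one query with an additive log(utility) bias.

    Assumes at least one utility is > 0, so the softmax is well defined.
    """
    biased = [s + (math.log(u) if u > 0.0 else -math.inf)
              for s, u in zip(scores, utilities)]
    m = max(biased)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in biased]
    total = sum(weights)
    probs = [w / total for w in weights]
    out = sum(p * v for p, v in zip(probs, values))
    return out, probs

scores = [1.0, 2.0, 0.5]     # toy query-key logits
values = [10.0, 20.0, 30.0]  # toy scalar values

out_open, probs_open = soft_gated_attention(scores, values, [1.0, 1.0, 1.0])
out_half, probs_half = soft_gated_attention(scores, values, [1.0, 0.5, 1.0])
out_shut, probs_shut = soft_gated_attention(scores, values, [1.0, 0.0, 1.0])
```

With all utilities at 1 this is plain softmax attention, a utility of 0 removes the key entirely, and a utility of 0.5 lands in between.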
Importantly, the utility predictor can choose to open all gates, yielding a mechanism that is equivalent to vanilla attention. During training, we often start from a normal checkpoint, bias the utility predictor to initialize with all gates open, then let the model self-sparsify and learn to leverage this extra degree of freedom.
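One way to picture the "all gates open" initialization is sketched below. Zero weights and a bias of +6 are assumptions chosen for illustration; the thread does not say which values are actually used.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

OPEN_GATE_BIAS = 6.0  # hypothetical value: sigmoid(6) is about 0.9975

def init_utility_predictor(dim, open_bias=OPEN_GATE_BIAS):
    """Zero weights plus a large positive bias: every utility starts near 1."""
    return [0.0] * dim, open_bias

w, b = init_utility_predictor(4)
hidden = [0.3, -1.2, 0.8, 2.0]  # any hidden state gives the same utility at init
u0 = sigmoid(sum(h * w_i for h, w_i in zip(hidden, w)) + b)
mask_bias0 = math.log(u0)  # close to 0, so attention starts out nearly vanilla
```

Because log(utility) is almost 0 for every KV pair at this point, the pretrained checkpoint's attention is left nearly unchanged when training starts.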
Since we still want to threshold at inference to remove low-utility KVs, this "soft gating" creates a slight train/test discrepancy. Towards the very end of training, we freeze the utility predictor (detaching the gradient) and binarize the utility value with a given threshold, like we would at inference. The mask bias is thus no longer continuous (it is either 0 or -inf), which matches the test-time regime, and we let the LLM adapt to it for a little while. Interestingly, the threshold value we choose at this point lets us trade off performance for sparsity (0.5 is the default, but 0.7 would lead to more forgetting at a slight performance cost).
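The binarized regime can be sketched as follows. The utilities are treated as fixed constants (standing in for the frozen predictor), and the numbers are made up to show how the threshold moves the sparsity level.

```python
import math

def binarized_bias(utilities, threshold=0.5):
    """Binary mask bias: 0 for kept KVs (gate open), -inf for discarded ones."""
    return [0.0 if u >= threshold else -math.inf for u in utilities]

def kept_fraction(utilities, threshold):
    """Fraction of KV pairs that survive a given threshold."""
    return sum(u >= threshold for u in utilities) / len(utilities)

utilities = [0.95, 0.62, 0.55, 0.30, 0.08]  # toy frozen utility predictions

bias_default = binarized_bias(utilities, threshold=0.5)
bias_strict = binarized_bias(utilities, threshold=0.7)
kept_default = kept_fraction(utilities, 0.5)
kept_strict = kept_fraction(utilities, 0.7)
```

On these toy values, raising the threshold from 0.5 to 0.7 keeps one KV pair out of five instead of three.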
A fun alternative to soft gating with log(utility) is what we call "hard gating". To reduce the train/test discrepancy, we want to train with binary open/close gate decisions (a KV is either used in the attention or not, no in between). To preserve gradient flow, we rely on a Straight-Through Estimator (STE): the forward pass is binarized like at inference, while the backward pass acts as a soft gate with the log scaling.
To preserve the counterfactual gradient of closed gates (what would have happened had this gate been open), open or close decisions are made with a Bernoulli sampling of the utility (a KV with predicted utility 0.1 still has a 10% chance of being opened during the forward pass). The combination of Bernoulli sampling and the STE basically simulates soft gating while reducing test-time discrepancies. In the paper, we mostly use soft gating but I thought this was a fun bit of ML to discuss.
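Here is one possible reading of hard gating, written in plain Python with hand-derived gradients for a single query and scalar values. The forward pass uses sampled binary gates; the backward pass differentiates as if the bias were log(utility). All names and numbers are assumptions, the first key is given utility 1.0 only so that at least one gate is always open in this toy setup, and the paper's exact estimator may differ.

```python
import math
import random

def sample_gates(utilities, rng):
    """Bernoulli sampling: gate j opens with probability utilities[j]."""
    return [rng.random() < u for u in utilities]

def hard_gated_forward(scores, values, gates):
    """Forward pass with binary gates; closed gates get zero attention weight.

    Assumes at least one gate is open.
    """
    m = max(s for s, g in zip(scores, gates) if g)
    weights = [math.exp(s - m) if g else 0.0 for s, g in zip(scores, gates)]
    total = sum(weights)
    probs = [w / total for w in weights]
    out = sum(p * v for p, v in zip(probs, values))
    return out, probs

def ste_utility_grads(probs, values, out, utilities):
    """Straight-through backward pass, treating bias_j as log(u_j).

    d out / d bias_j = p_j * (v_j - out) and d log(u_j) / d u_j = 1 / u_j.
    Utilities come from a sigmoid, so they are assumed to be > 0.
    """
    return [p * (v - out) / u for p, v, u in zip(probs, values, utilities)]

rng = random.Random(0)
scores = [1.0, 2.0, 0.5]
values = [10.0, 20.0, 30.0]
utilities = [1.0, 0.1, 0.6]  # first key always open in this toy example

gates = sample_gates(utilities, rng)
out, probs = hard_gated_forward(scores, values, gates)
grads = ste_utility_grads(probs, values, out, utilities)
```

In this sketch a closed gate has zero attention weight and therefore zero gradient for that step, so only the random sampling gives a low-utility KV the occasional chance to be opened and receive a learning signal.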
The tl;dr here is that SP-KV is just an extra degree of freedom the attention mechanism learns to use to shut down gates in long range attention. It can be learned implicitly and improves the loss, but forcing a binary decision induces a slight test-time discrepancy we need to mitigate!
https://arxiv.org/abs/2605.14037