A unifying view of linear attention (part 2)
Part 2 — Four Axes of Finite-Size Memory Models
1. Recap
Part 1 showed that linear attention and the other sub-quadratic models we discussed all fit one recurrence — roughly $S_t = \alpha_t\, S_{t-1} + \beta_t\, v_t k_t^\top$ with readout $o_t = S_t\, \phi(q_t)$ — with different choices of gate, write rule, and feature map $\phi$. Softmax can be seen as a special form with an infinite state: in practice the “memory” is the entire growing KV cache, which gives perfect recall at quadratic compute. If we pick a finite-dimensional feature map $\phi$, we get linear attention as a compressed form of softmax with a fixed-size state in matrix form. We accept lossy recall to escape quadratic compute.
We also showed that a state $S \in \mathbb{R}^{d \times d}$ holds at most $d$ near-orthogonal key-value pairs cleanly; above that, the cross-talk between the stored vectors drowns the signal.
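The capacity ceiling is easy to see numerically. Below is a small numpy sketch (all names illustrative): with $d$ orthonormal keys, the outer-product store $S = \sum_i v_i k_i^\top$ reads back every value exactly; with $4d$ random unit keys, the readout is dominated by interference.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64

def recall_error(K, V):
    # Store pairs in a d x d state S = sum_i v_i k_i^T, read back as S k_j.
    S = V.T @ K
    O = K @ S.T                              # row j equals S @ k_j
    return float(np.mean(np.linalg.norm(O - V, axis=1) / np.linalg.norm(V, axis=1)))

# d orthonormal keys: recall is exact.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
err_orth = recall_error(Q, rng.standard_normal((d, d)))

# 4d random unit keys: pairwise overlaps drown the signal.
K = rng.standard_normal((4 * d, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)
err_over = recall_error(K, rng.standard_normal((4 * d, d)))

print(err_orth)   # ~1e-15: perfect recall below the ceiling
print(err_over)   # O(1): recall collapses above it
```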
The recurrence of DeltaNet is one gradient step on the reconstruction loss $\mathcal{L}_t(S) = \tfrac12\lVert S k_t - v_t\rVert^2$, giving us

$$S_t = S_{t-1}\left(I - \beta_t\, k_t k_t^\top\right) + \beta_t\, v_t k_t^\top.$$

The main benefit is that seeing the same key twice doesn't grow the associated value, and already-stored orthogonal keys are left untouched. The capacity ceiling is still unchanged, but the budget is used more cleanly.
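Both properties can be checked in a few lines of numpy. A minimal sketch, assuming $\beta_t = 1$ and unit-norm orthogonal keys so that each write is stored exactly (`delta_step` is an illustrative name, not from any library):

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8

def delta_step(S, k, v, beta=1.0):
    # One GD step on 1/2 ||S k - v||^2: subtract the outer product of the
    # current prediction error with the key.
    return S - beta * np.outer(S @ k - v, k)

S = np.zeros((d, d))
k1, v1 = np.eye(d)[0], rng.standard_normal(d)
k2, v2 = np.eye(d)[1], rng.standard_normal(d)

S = delta_step(S, k1, v1)
S = delta_step(S, k2, v2)
S_again = delta_step(S, k1, v1)      # rewrite an already-stored pair

print(np.allclose(S_again, S))       # True: no growth, unlike S += v k^T
print(np.allclose(S_again @ k2, v2)) # True: the orthogonal write survives
```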
Gating is introduced to exponentially decay the state, via Mamba2's $S_t = \alpha_t\, S_{t-1} + \beta_t\, v_t k_t^\top$ and Gated DeltaNet's $S_t = \alpha_t\, S_{t-1}\left(I - \beta_t k_t k_t^\top\right) + \beta_t\, v_t k_t^\top$. This doesn't raise the capacity ceiling, but it lets us recycle the slots by decaying old writes so new ones can take their place.
The state was always a matrix $S$, and DeltaNet's update rule followed from doing one step of gradient descent on $\mathcal{L}_t(S) = \tfrac12\lVert S k_t - v_t\rVert^2$. What if the memory is instead a neural network with multiple layers? That's the step that TTT and Titans take and build upon.
2. Test-time training (TTT)
Reframe the state update — sometimes referenced as the inner loop — as one GD step on a parameterized inner model $f_W$:

$$W_t = W_{t-1} - \beta_t\, \nabla_W\, \ell\big(f_{W_{t-1}}(k_t),\, v_t\big).$$

For $f_W(k) = Wk$ with $\ell(\hat v, v) = \tfrac12\lVert \hat v - v\rVert^2$ we get what we had before:

$$W_t = W_{t-1}\left(I - \beta_t\, k_t k_t^\top\right) + \beta_t\, v_t k_t^\top.$$
This is DeltaNet’s (Yang et al., 2024) update rule from the last part, also called TTT-Linear (Sun et al., 2024). It shows DeltaNet wasn’t the choice of a specific recurrence; it follows directly from one GD step on $\ell$ with $f_W$ being a single linear layer.
(Sun et al., 2024) also introduces a second variant, TTT-MLP, with a richer $f_W$: a two-layer MLP $f_W(k) = W_2\,\sigma(W_1 k)$, $W = \{W_1, W_2\}$, where we can also derive the gradients explicitly. Let $z = W_1 k_t$, $a = \sigma(z)$, $\hat v = W_2 a$, $e = \hat v - v_t$:

$$\nabla_{W_2} \ell = e\, a^\top, \qquad \nabla_{W_1} \ell = \big(W_2^\top e \odot \sigma'(z)\big)\, k_t^\top.$$

An explicit update rule for each weight matrix.
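These are just standard backprop through two layers. A self-contained numpy check, with $\sigma = \tanh$ as an illustrative choice of nonlinearity, verifying the explicit $\nabla_{W_1}$ against a finite difference:

```python
import numpy as np

rng = np.random.default_rng(2)
d, h = 6, 12
sigma = np.tanh                    # any smooth elementwise nonlinearity

W1 = rng.standard_normal((h, d)) * 0.1
W2 = rng.standard_normal((d, h)) * 0.1
k, v = rng.standard_normal(d), rng.standard_normal(d)

# Forward pass with cached intermediates.
z = W1 @ k                 # pre-activation
a = sigma(z)               # hidden activation
e = W2 @ a - v             # prediction error, loss = 1/2 ||e||^2

# Explicit gradients (backprop through the two layers).
gW2 = np.outer(e, a)
gW1 = np.outer((W2.T @ e) * (1 - a**2), k)   # tanh'(z) = 1 - tanh(z)^2

# Finite-difference check on one entry of W1.
def loss(W1_):
    err = W2 @ sigma(W1_ @ k) - v
    return 0.5 * err @ err

eps = 1e-6
W1p = W1.copy(); W1p[0, 0] += eps
print(abs((loss(W1p) - loss(W1)) / eps - gW1[0, 0]) < 1e-4)   # True
```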
A note on the name.
In my opinion “test-time training” is misleading. What we have is an inner-loop state-update rule applied during forward inference, and not training in the conventional sense. There’s no held-out set, and no “done” condition. The outer parameters are trained once on a training set, the normal way, and are not changed during inference. What runs at inference is a per-token state update of what we can casually refer to as “memory”.
TTT-MLP doesn’t fit our recurrence form.
So far our general form rested on two assumptions:
- the state is a single matrix $S$,
- the readout is linear in the state ($o_t = S_t q_t$ — a single matrix-vector product).
TTT-MLP breaks both. With $f_W(k) = W_2\,\sigma(W_1 k)$ containing a nonlinearity, there is no single matrix $S$ whose matrix-vector product captures the readout. Our general form from Part 1 stops being expressive enough, but we will define a new general form across four axes.
3. Titans
Titans (Behrouz et al., 2024) extends TTT-MLP with two textbook optimizer ingredients: momentum and weight decay. Because of momentum our state is now the pair $(W_t, m_t)$, and the update rule reads

$$m_t = \eta_t\, m_{t-1} + \nabla_W\, \mathcal{L}_t(W_{t-1}), \qquad W_t = \alpha_t\, W_{t-1} - \beta_t\, m_t.$$
Note: the Titans paper writes $S_t$ for the momentum buffer (not the state!). To keep $W$ as the parameters across the post, we use $\eta_t$ for momentum decay and $m_t$ for the buffer (Adam-style first-moment convention). The Titans paper also writes the retention as $(1 - \alpha_t)$; we follow Mamba2 / Gated DeltaNet and use $\alpha_t$ directly, so $\alpha_t = 1$ means full retention across all gated models.
$m_t$ is an EMA of recent gradient directions (first moment), which acts as a low-pass on write directions (not on what’s already been written).
$\alpha_t$ is scalar retention on $W$, structurally identical to Gated DeltaNet’s retention. Same idea as in Part 1, applied to all MLP parameters instead of a single matrix.
Both moves are standard SGD ingredients (momentum, weight decay) ported into the memory-state update step. Since this is just SGD on the memory state, anything we know about training neural nets is reusable here. We’ll see further such techniques later in this post.
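A toy sketch of the two ingredients in our convention (`titans_step` and the constants are illustrative, not the paper's values): repeatedly writing the same association converges, with the small residual coming from the weight-decay pull toward zero.

```python
import numpy as np

def titans_step(W, m, grad, eta=0.9, alpha=0.99, beta=0.1):
    # m_t = eta * m_{t-1} + grad          -- momentum buffer (first moment)
    # W_t = alpha * W_{t-1} - beta * m_t  -- retention alpha, inner learning rate beta
    m = eta * m + grad
    W = alpha * W - beta * m
    return W, m

rng = np.random.default_rng(3)
d = 4
W, m = np.zeros((d, d)), np.zeros((d, d))
k, v = np.eye(d)[0], rng.standard_normal(d)

for _ in range(200):                      # keep writing the same association
    grad = np.outer(W @ k - v, k)         # grad of 1/2 ||W k - v||^2
    W, m = titans_step(W, m, grad)

# Momentum smooths the write direction; alpha < 1 biases slightly toward zero.
print(np.linalg.norm(W @ k - v))          # small residual
```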
4. MIRAS
We moved from the linear recurrence of a matrix $S$ to the general case: we optimise the parameters $W$ of some model $f_W$ on an inner objective that matches $f_W(k_t)$ to $v_t$, with optional retention on $W$. This gives us four design choices: (1) memory architecture $f_W$, (2) inner objective $\mathcal{L}_t$, (3) inner optimizer, (4) retention on $W$. We take them in turn.
Axis 1 — memory architecture $f_W$.
Instead of having a single matrix like before, we generalise our memory architecture to any model $f_W$ with parameters $W$. We have already seen these special cases:
- Linear: $f_W(k) = Wk$ — linear attention, DeltaNet, Gated DeltaNet.
- 2-layer MLP: $f_W(k) = W_2\,\sigma(W_1 k)$ — TTT-MLP, Titans.
This also includes all variants that use kernel mappings to increase the capacity of the state (Zhong et al., 2025).
Axis 2 — memory objective $\mathcal{L}_t$.
Next we define the loss which $f_W$ is optimised on. In the general case we want $f_W(k_t)$ to reconstruct its associated value $v_t$.
For the simple case $f_W(k) = Wk$, two objectives (dot-product, L2) generate two well-known models from the same one-step-GD recipe:

$$\mathcal{L}_t(W) = -\langle W k_t,\, v_t\rangle \;\;\Rightarrow\;\; W_t = W_{t-1} + \beta_t\, v_t k_t^\top \quad \text{(linear attention)},$$

$$\mathcal{L}_t(W) = \tfrac12\lVert W k_t - v_t\rVert^2 \;\;\Rightarrow\;\; W_t = W_{t-1}\left(I - \beta_t\, k_t k_t^\top\right) + \beta_t\, v_t k_t^\top \quad \text{(DeltaNet)}.$$
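A quick numpy check of the recipe, assuming $f_W(k) = Wk$: differentiating each objective by hand and taking one GD step reproduces the linear-attention write and the DeltaNet update respectively.

```python
import numpy as np

rng = np.random.default_rng(4)
d = 5
W = rng.standard_normal((d, d))
k, v, beta = rng.standard_normal(d), rng.standard_normal(d), 0.5

# Dot-product objective L(W) = -<W k, v>: gradient is -v k^T.
W_dot = W - beta * (-np.outer(v, k))
print(np.allclose(W_dot, W + beta * np.outer(v, k)))        # True: linear attention write

# L2 objective L(W) = 1/2 ||W k - v||^2: gradient is (W k - v) k^T.
W_l2 = W - beta * np.outer(W @ k - v, k)
print(np.allclose(W_l2, W @ (np.eye(d) - beta * np.outer(k, k))
                  + beta * np.outer(v, k)))                 # True: DeltaNet update
```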
Axis 3 — inner optimizer.
This is how the next state $W_t$ is computed from $W_{t-1}$.
The default is one GD step. Linear attention, Mamba2, DeltaNet, Gated DeltaNet, and TTT-MLP all use this — what differs across them is the objective, not the optimizer. Titans is the first to iterate on the optimizer by introducing momentum. Later we will see Muon (ATLAS), the second optimizer move and an optimizer heavily used today for training LLMs.
Axis 4 — retention on .
We motivated this in Part 1 as a necessity to erase stale writes (key-value pairs that are no longer important but take up space). MIRAS (Behrouz et al., 2025) further shows a derivation that retention can be seen as regularization on our state, in the form of adding an $\ell_2$ penalty $\tfrac{\lambda_t}{2}\lVert W\rVert_F^2$ to $\mathcal{L}_t$.
Whereas DeltaNet and TTT-MLP used none, Mamba2, Gated DeltaNet, and Titans used a scalar $\alpha_t$. In §5 we will see KDA introduce an extension.
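The regularization view is a one-line identity: a GD step on $\mathcal{L}_t + \tfrac{\lambda}{2}\lVert W\rVert_F^2$ is exactly a scalar retention $\alpha = 1 - \beta\lambda$ followed by the plain gradient step. A numpy check (names illustrative):

```python
import numpy as np

rng = np.random.default_rng(5)
d = 5
W = rng.standard_normal((d, d))
k, v = rng.standard_normal(d), rng.standard_normal(d)
beta, lam = 0.3, 0.1

grad = np.outer(W @ k - v, k)              # grad of the L2 memory loss
# GD on loss + (lam/2) ||W||_F^2 ...
W_reg = W - beta * (grad + lam * W)
# ... equals scalar retention alpha = 1 - beta * lam plus the usual step.
alpha = 1 - beta * lam
W_ret = alpha * W - beta * grad
print(np.allclose(W_reg, W_ret))           # True
```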
Where this leaves us.
The recurrence from Part 1 is the special case where the memory is linear (a single matrix), the objective is dot-product or L2, the optimizer is one GD step, and retention is at most a scalar. The four axes generalize each of those choices. Every model in this post — including the two we still have to introduce — slots into a row of one shared table; we’ll fill it in in §7 once KDA and ATLAS are placed.
5. Kimi Linear / KDA
So far retention was always a scalar $\alpha_t$ that acted on the whole state. What if we make it a vector $\alpha_t \in \mathbb{R}^d$, where each channel/dimension of the key gets its own retention scalar?
This is what Kimi Linear (Moonshot AI, 2025) introduces, giving us the following recurrence:

$$S_t = S_{t-1}\,\mathrm{Diag}(\alpha_t)\left(I - \beta_t\, k_t k_t^\top\right) + \beta_t\, v_t k_t^\top.$$

Note that the paper writes the recurrence for the transposed state; transposed to our convention, the transition matrix lands on the right of $S_{t-1}$.
That transition matrix, $\mathrm{Diag}(\alpha_t)\left(I - \beta_t k_t k_t^\top\right) = \mathrm{Diag}(\alpha_t) - \beta_t\,\big(\mathrm{Diag}(\alpha_t)\, k_t\big)\, k_t^\top$, is a DPLR matrix (diagonal-plus-low-rank, rank 1 here), supporting fast matrix-vector ops and a closed-form inverse via Sherman–Morrison.
With this per-channel gating, each $\alpha_{t,i}$ controls the decay of the $i$-th key-coordinate independently, giving the model fine-grained control over what to forget — analogous to how the Adam optimizer has per-parameter learning rates.
A scalar gate forces the same forgetting rate across all key-coordinates, but different coordinates might encode information that ages at different rates (positional vs. semantic information). A diagonal $\mathrm{Diag}(\alpha_t)$ lets the model trade off retention per coordinate.
MIRAS-lens. Same L2 reconstruction loss as DeltaNet, same one-step optimizer; the only change is the channel-wise retention as a regularizer toward the channel-decayed prior:

$$W_t = \arg\min_W \; \tfrac{\beta_t}{2}\big\lVert W k_t - v_t\big\rVert^2 + \tfrac12\big\lVert W - W_{t-1}\,\mathrm{Diag}(\alpha_t)\big\rVert_F^2.$$

The exact (closed-form) minimizer of this objective is the KDA recurrence above, with the reparameterization $\tilde\beta_t = \beta_t / (1 + \beta_t \lVert k_t\rVert^2)$. (Derivation: set $\nabla_W = 0$, giving $W\left(I + \beta_t k_t k_t^\top\right) = W_{t-1}\,\mathrm{Diag}(\alpha_t) + \beta_t\, v_t k_t^\top$, then invert via Sherman–Morrison.)
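To sanity-check the Sherman–Morrison step, a few numpy lines comparing the direct solve of the objective above with the explicit rank-1 inverse and the reparameterized learning rate (all variable names illustrative):

```python
import numpy as np

rng = np.random.default_rng(6)
d = 5
W_prev = rng.standard_normal((d, d))
k, v = rng.standard_normal(d), rng.standard_normal(d)
alpha = rng.uniform(0.8, 1.0, d)           # per-channel retention
beta = 0.7
D = np.diag(alpha)

# Minimizer of beta/2 ||W k - v||^2 + 1/2 ||W - W_prev D||_F^2,
# solved directly: W (I + beta k k^T) = W_prev D + beta v k^T.
W_direct = (W_prev @ D + beta * np.outer(v, k)) \
           @ np.linalg.inv(np.eye(d) + beta * np.outer(k, k))

# Same thing via Sherman-Morrison: (I + beta k k^T)^{-1} = I - bt k k^T,
# with the reparameterized learning rate bt = beta / (1 + beta ||k||^2).
bt = beta / (1 + beta * (k @ k))
W_sm = (W_prev @ D + beta * np.outer(v, k)) @ (np.eye(d) - bt * np.outer(k, k))

print(np.allclose(W_direct, W_sm))         # True
```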
6. ATLAS
ATLAS (Behrouz et al., 2025) introduces three independent changes vs. Titans that can each be placed on one of Axes 1, 2, and 3.
6.1 Omega rule — windowed inner objective (Axis 2).
Up to here, every model’s memory loss has been a function of the current token only. ATLAS’s Omega rule sums the loss over a window of the last $c$ tokens, with per-token in-window gates $\gamma_i^{(t)}$:

$$\mathcal{L}_t(W) = \sum_{i = t-c+1}^{t} \gamma_i^{(t)}\; \tfrac12\big\lVert f_W(k_i) - v_i\big\rVert^2.$$

For linear memory ($f_W(k) = Wk$, L2 loss), one GD step gives:

$$W_t = \alpha_t\, W_{t-1} - \beta_t \sum_{i = t-c+1}^{t} \gamma_i^{(t)}\, \big(W_{t-1} k_i - v_i\big)\, k_i^\top.$$

The edge case $c = 1$ gives us DeltaNet/TTT-Linear, while $c = t$ recovers Mesa-layer-style global least-squares (which we will not cover further). Each token’s gradient enters the update sum at $c$ different timesteps. This is different from momentum, which stores the gradient once and lets it decay through a buffer — reusing a potentially outdated gradient direction. The Omega rule is more like mini-batch gradient descent on a sliding window, where the gradients are recomputed fresh each step. Footnote: not quite true in practice due to chunkwise computation, introduced in Part 3.
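A compact numpy sketch of the windowed step for linear memory (uniform in-window gates $\gamma_i^{(t)} = 1$ for simplicity; `run` is an illustrative helper), including the $c = 1$ edge case collapsing to the delta rule:

```python
import numpy as np

rng = np.random.default_rng(7)
d, T = 4, 10
K, V = rng.standard_normal((T, d)), rng.standard_normal((T, d))
beta, alpha = 0.1, 1.0

def run(c):
    # One GD step per token on the windowed loss
    # sum_{i=t-c+1..t} 1/2 ||W k_i - v_i||^2 (gates set to 1).
    W = np.zeros((d, d))
    for t in range(T):
        grad = sum(np.outer(W @ K[i] - V[i], K[i])
                   for i in range(max(0, t - c + 1), t + 1))
        W = alpha * W - beta * grad      # gradients recomputed fresh each step
    return W

# Window c = 1 recovers the delta rule from earlier.
W_delta = np.zeros((d, d))
for t in range(T):
    W_delta = alpha * W_delta - beta * np.outer(W_delta @ K[t] - V[t], K[t])
print(np.allclose(run(1), W_delta))      # True
```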
6.2 Kernel feature maps (Axis 1).
ATLAS adds (or rather reintroduces from our view) a feature map on keys (and queries). The inner loss applies to instead of :
For polynomial of degree , the effective key dimension grows from to , and the capacity ceiling rises from to (paper’s Proposition 2). The exponential kernel — Taylor expansion of — is the limit and recovers softmax attention as the global-window special case (paper §4.2).
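A small numpy illustration of the degree-2 case (names illustrative): $\phi(k) = \mathrm{vec}(k k^\top)$ realizes the kernel identity $\langle\phi(q), \phi(k)\rangle = \langle q, k\rangle^2$, and because overlaps get squared, outer-product storage of more than $d$ pairs suffers visibly less cross-talk in feature space.

```python
import numpy as np

rng = np.random.default_rng(8)
d, n = 16, 48                              # 3d pairs: above the linear ceiling

def phi(k):
    # Degree-2 polynomial feature map: effective key dimension d -> d^2.
    return np.outer(k, k).ravel()

q, k = rng.standard_normal(d), rng.standard_normal(d)
print(np.isclose(phi(q) @ phi(k), (q @ k) ** 2))   # True: <phi(q),phi(k)> = <q,k>^2

K = rng.standard_normal((n, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((n, d))

def recall_error(F):
    # Outer-product storage with feature-space keys, readout S f_j.
    F = F / np.linalg.norm(F, axis=1, keepdims=True)
    S = V.T @ F
    O = F @ S.T
    return float(np.mean(np.linalg.norm(O - V, axis=1) / np.linalg.norm(V, axis=1)))

print(recall_error(K))                               # raw keys: heavy cross-talk
print(recall_error(np.stack([phi(row) for row in K])))  # overlaps squared: cleaner
```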
6.3 Muon (Axis 3).
Notation note: the ATLAS paper’s symbols for the parameters, the momentum buffer, the momentum decay scalar, and the Newton–Schulz iteration count all differ from ours (its iteration-count symbol clashes with our key vector $k$). Stripped to our convention: parameters $W$, momentum buffer $m$, momentum decay $\eta$, learning rate $\beta$, retention $\alpha$, Newton–Schulz iteration count $T$.
The full ATLAS recurrence:

$$m_t = \eta_t\, m_{t-1} + \nabla_W\, \mathcal{L}_t(W_{t-1}), \qquad W_t = \alpha_t\, W_{t-1} - \beta_t\, \mathrm{NS}_T(m_t),$$

where $\mathcal{L}_t$ is the windowed loss over the past $c$ tokens.
A two-state recurrence with the same shape as Titans: $m_t$ is the EMA of recent (windowed) gradients with decay $\eta_t$; $W_t$ is the memory parameters with scalar retention $\alpha_t$ per step. The novel piece is the Newton–Schulz wrapper $\mathrm{NS}_T(\cdot)$ around the buffer.
Muon intuition. Momentum buffers for matrix-shaped parameters tend to be low-rank and ill-conditioned (a few dominant directions). Orthogonalization rescales every singular value to 1, producing a step that “moves equally in every direction” of the gradient subspace. As $T \to \infty$, $\mathrm{NS}_T(m_t)$ converges to the nearest semi-orthogonal matrix to $m_t$ — a second-order approximation in the sense that it inverts the local curvature scale.
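A sketch of the idea using the classic cubic Newton–Schulz iteration (Muon itself uses a tuned quintic variant, but the limit is the same): starting from the Frobenius-normalized matrix, every singular value is driven to 1, and the result matches the nearest orthogonal matrix $UV^\top$ from the SVD.

```python
import numpy as np

rng = np.random.default_rng(9)
M = rng.standard_normal((5, 5))

def newton_schulz(M, n_iter=50):
    # Classic cubic Newton-Schulz orthogonalization: drives every singular
    # value toward 1 while leaving the singular vectors untouched.
    X = M / np.linalg.norm(M)          # Frobenius norm puts all sigma in (0, 1]
    for _ in range(n_iter):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

O = newton_schulz(M)
print(np.allclose(O.T @ O, np.eye(5), atol=1e-5))   # near-orthogonal result

# Matches the polar factor U V^T, the nearest orthogonal matrix to M.
U, _, Vt = np.linalg.svd(M)
print(np.allclose(O, U @ Vt, atol=1e-5))
```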
Summary. ATLAS introduces three changes. (1) Omega: a per-token loss is essentially online SGD with batch size 1; a window of $c$ tokens gives the inner step access to recent context, like minibatch GD. (2) Kernels: Titans’s MLP memory still bottlenecks at the key dimension $d$; polynomial features lift the capacity ceiling to $O(d^p)$. (3) Muon: matrix-shaped momentum buffers concentrate energy in a few directions; Newton–Schulz redistributes step size across the gradient subspace.
7. Overview
This gives us the following complete overview of all architectures, including those from Part 1:
Figure 2
The MIRAS 4-axis lens, updated with KDA and ATLAS.
| model | memory architecture | objective | optimizer | retention |
|---|---|---|---|---|
| linear attention | linear | dot-product | 1-step GD, scalar $\beta_t$ | identity |
| Mamba2 | linear | dot-product | 1-step GD, scalar $\beta_t$ | scalar $\alpha_t$ |
| DeltaNet | linear | L2 | 1-step GD, scalar $\beta_t$ | identity |
| Gated DeltaNet | linear | L2 | 1-step GD, scalar $\beta_t$ | scalar $\alpha_t$ |
| TTT-MLP | 2-layer MLP | L2 | 1-step GD, scalar $\beta_t$ | identity |
| Titans | 2-layer MLP | L2 | 1-step GD + momentum (buffer $m_t$) | scalar $\alpha_t$ |
| KDA | linear | L2 | 1-step GD, scalar $\beta_t$ | diagonal $\mathrm{Diag}(\alpha_t)$ |
| ATLAS | MLP (poly/exp kernel on $k$) | windowed L2 (Omega rule) | GD + momentum + Newton–Schulz (Muon) | scalar $\alpha_t$ |
Each paper typically introduces a single change from the baseline of linear attention; some are combinations of earlier ones rather than genuinely new (Gated DeltaNet pairs DeltaNet’s L2 objective with Mamba2’s retention). ATLAS is the exception, moving three axes at once.
8. Outlook
What we’ve neglected so far is the practicality of training these methods on modern accelerated hardware. Per-token updates are inherently sequential — fine for inference, and what motivated our derivation from softmax attention — but they kill training throughput.
Part 3 covers chunkwise parallelization (DeltaNet) and further practical tricks that we will have to employ to train these recurrent algorithms fast.
We will implement all architectures mentioned so far (including Part 1) and benchmark them on MQAR to read off the capacity ceilings empirically.
Using the four axes from MIRAS lets us clearly see what changes across these architectures and which axis each paper moves. Part 3 will be more about what it actually takes to run them at modern-LLM scale and how they benchmark on MQAR.
References
- Yang, Wang, Zhang, Shen, Kim (2024). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. NeurIPS 2024. arXiv:2406.06484
- Sun, Li, Geng, Hua, Wang, Zhao, Liu, Hardt, Chen, Pan, Lin, Wang, Han, Guestrin (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv:2407.04620
- Behrouz, Zhong, Mirrokni (2024). Titans: Learning to Memorize at Test Time. arXiv:2501.00663
- Behrouz, Razaviyayn, Zhong, Mirrokni (2025). It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization. arXiv:2504.13173
- Zhong, Xu, Ao, Shi (2025). Understanding Transformer from the Perspective of Associative Memory. arXiv:2505.19488
- Behrouz, Razaviyayn, Zhong, Mirrokni (2025). ATLAS: Learning to Optimally Memorize the Context at Test Time. arXiv:2505.23735
- Moonshot AI (2025). Kimi Linear: An Expressive, Efficient Attention Architecture. arXiv:2510.26692