
A unifying view of linear attention (part 2)

Part 2 — Four Axes of Finite-Size Memory Models

1. Recap

Part 1 showed that linear attention and the other sub-quadratic models we discussed all fit

$$S_t = S_{t-1} A_t + b_t k_t^\top$$

with different choices of $(A_t, b_t)$. Softmax attention can be seen as a special case with an infinite state; in practice the “memory” is the entire growing KV cache, which gives perfect recall at quadratic compute. If we pick a finite-dimensional feature map $\phi$ we get linear attention as a compressed form of softmax with the fixed-size state $S_t = \sum_{j \le t} v_j \phi(k_j)^\top$ in matrix form. We accept lossy recall to escape quadratic compute.

We also showed that a state $S$ of size $d_v \times d_k$ holds at most $n \lesssim d_k$ near-orthogonal $(k, v)$ pairs cleanly; above that, cross-talk between the stored vectors drowns the signal.
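
To make that ceiling concrete, here is a minimal NumPy sketch (dimensions and data are illustrative choices of ours, not from Part 1): store $n$ key-value pairs as a sum of outer products, using orthonormal keys while $n \le d_k$ and random unit-norm keys beyond, then read every key back.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = d_v = 64

for n in [16, 64, 256]:
    K = rng.standard_normal((n, d_k))
    if n <= d_k:
        Q, _ = np.linalg.qr(K.T)        # orthonormal keys while they still fit
        K = Q.T
    else:
        K /= np.linalg.norm(K, axis=1, keepdims=True)  # can't be orthogonal anymore
    V = rng.standard_normal((n, d_v))
    S = V.T @ K                          # state = sum of outer products v_j k_j^T
    V_hat = (S @ K.T).T                  # recall every stored key
    err = np.linalg.norm(V_hat - V) / np.linalg.norm(V)
    print(f"n={n:3d}  relative recall error: {err:.2f}")   # ~0, ~0, then large
```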

The recurrence of DeltaNet is one gradient step on $L_t(S) = \tfrac{1}{2}\|v_t - S k_t\|^2$, giving us

$$S_t = S_{t-1}(I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top.$$

The main benefit is that seeing the same key twice doesn’t grow the associated value, and already-stored orthogonal keys are preserved. The capacity ceiling is unchanged, but the budget is used more cleanly.

Gating is introduced to exponentially decay the state $S$, via Mamba2’s $A_t = \alpha_t I$ and Gated DeltaNet’s $A_t = \alpha_t (I - \beta_t k_t k_t^\top)$. This doesn’t raise the capacity ceiling, but it lets us recycle the $d_k$ slots by decaying old writes so new ones can take their place.
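
As a reference, a per-token sketch of one step of each gated recurrence; the function names and toy calling convention are ours, and real implementations use the chunkwise formulations deferred to Part 3.

```python
import numpy as np

def mamba2_step(S, k, v, alpha, beta):
    """One step of S_t = alpha_t * S_{t-1} + beta_t * v_t k_t^T (A_t = alpha_t I)."""
    return alpha * S + beta * np.outer(v, k)

def gated_deltanet_step(S, k, v, alpha, beta):
    """One step of S_t = alpha_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T."""
    I = np.eye(k.shape[0])
    return alpha * S @ (I - beta * np.outer(k, k)) + beta * np.outer(v, k)
```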

The state was always a matrix $S$, and DeltaNet’s update rule followed from one step of gradient descent on $L_t(S) = \tfrac{1}{2}\|v_t - S k_t\|^2$. What if $S$ is instead a neural network with multiple layers? That’s the step that TTT and Titans take and build upon.

2. Test-time training (TTT)

Reframe the state update — sometimes referred to as the inner loop — as one GD step on a parameterized inner model:

$$L_t(\theta) = \tfrac{1}{2}\|v_t - M_\theta(k_t)\|^2$$
$$\theta_t = \theta_{t-1} - \beta_t \nabla_\theta L_t(\theta_{t-1})$$

For $M_\theta(k) = S k$ with $\theta = S$ we get what we had before:

  • $\nabla_\theta L_t \big|_{\theta_{t-1}} = -(v_t - S_{t-1} k_t)\, k_t^\top$
  • $S_t = S_{t-1} + \beta_t (v_t - S_{t-1} k_t)\, k_t^\top$ (checked numerically in the sketch below)
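
A quick numerical check (toy dimensions of our choosing) that one generic GD step and the closed-form delta rule coincide, as the bullet above claims:

```python
import numpy as np

rng = np.random.default_rng(1)
d_k = d_v = 8
beta = 0.5
S = rng.standard_normal((d_v, d_k))
k, v = rng.standard_normal(d_k), rng.standard_normal(d_v)

grad = -np.outer(v - S @ k, k)       # dL/dS = -(v - S k) k^T
S_gd = S - beta * grad               # generic one-step GD
S_delta = S @ (np.eye(d_k) - beta * np.outer(k, k)) + beta * np.outer(v, k)
assert np.allclose(S_gd, S_delta)    # the two forms coincide
```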

This is DeltaNet’s update rule (Yang et al., 2024) from the last part, also called TTT-Linear (Sun et al., 2024), and it shows DeltaNet wasn’t a choice of a specific recurrence: it follows from one GD step on $\tfrac{1}{2}\|v - M_\theta(k)\|^2$ with $M_\theta$ being a single linear layer.

(Sun et al., 2024) also introduce a second variant, TTT-MLP, with a richer $M$: a two-layer MLP $M_\theta(k) = W_2 \sigma(W_1 k)$, $\theta = (W_1, W_2)$, where we can also derive the gradients explicitly. Let $u = W_1 k_t$, $h = \sigma(u)$, $f = W_2 h$, $r = v_t - f$:

$$\nabla_{W_2} L_t = -r\, h^\top$$
$$\nabla_{W_1} L_t = -\big[W_2^\top r \odot \sigma'(u)\big]\, k_t^\top$$

An explicit update rule for each weight matrix.
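
A sketch of these gradients with a finite-difference spot-check; the choice $\sigma = \tanh$ and all dimensions are illustrative assumptions, not fixed by the paper:

```python
import numpy as np

rng = np.random.default_rng(2)
d_k, d_h, d_v = 6, 10, 6
W1, W2 = rng.standard_normal((d_h, d_k)), rng.standard_normal((d_v, d_h))
k, v = rng.standard_normal(d_k), rng.standard_normal(d_v)
sigma, dsigma = np.tanh, lambda u: 1.0 - np.tanh(u) ** 2

u = W1 @ k; h = sigma(u); r = v - W2 @ h      # forward pass, residual r
gW2 = -np.outer(r, h)                         # explicit grad wrt W2
gW1 = -np.outer((W2.T @ r) * dsigma(u), k)    # explicit grad wrt W1

def loss(W1, W2):
    return 0.5 * np.sum((v - W2 @ sigma(W1 @ k)) ** 2)

eps, i, j = 1e-6, 2, 3                        # spot-check one entry of W1
E = np.zeros_like(W1); E[i, j] = eps
num = (loss(W1 + E, W2) - loss(W1 - E, W2)) / (2 * eps)
assert np.isclose(gW1[i, j], num, atol=1e-5)
```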

A note on the name.

In my opinion, “test-time training” is misleading. What we have is an inner-loop state-update rule applied during forward inference, not training in the conventional sense. There’s no held-out set and no “done” condition. The outer parameters $(W_q, W_k, W_v, W_1, W_2)$ are trained once on a training set, the normal way, and are not changed during inference. What runs at inference is a per-token state update of what we can casually refer to as “memory”.

TTT-MLP doesn’t fit our recurrence form.

So far our general form $S_t = S_{t-1} A_t + b_t k_t^\top$ rested on two assumptions:

  • the state is a single matrix,
  • the readout is linear in the state ($S k_q$ — a single matrix-vector product).

TTT-MLP breaks both. With $M_\theta(k_q) = W_2 \sigma(W_1 k_q)$ the readout is nonlinear, so there is no $A_t$ that captures it. The general form from Part 1 stops being expressive enough, so we will define a new general form across four axes.

3. Titans

Titans (Behrouz et al., 2024) extends TTT-MLP with two textbook optimizer ingredients: momentum and weight decay. Because of momentum, our state is now the pair $(\theta_t, m_t)$, and the update rule reads

$$m_t = \nu_t\, m_{t-1} + \nabla_\theta L_t(\theta_{t-1})$$
$$\theta_t = \alpha_t\, \theta_{t-1} - \beta_t\, m_t.$$

Note: the Titans paper writes $\theta_t$ for the momentum decay scalar (not the parameters!) and $\mathcal{S}_t$ for the momentum buffer. To keep $\theta$ as parameters across the post, we use $\nu_t$ for the momentum decay and $m_t$ for the buffer (Adam-style first-moment convention). The Titans paper also writes the retention as $(1 - \alpha_t)$; we follow Mamba2 / Gated DeltaNet and use $\alpha_t$ directly, so $\alpha_t = 1$ means full retention across all gated models.

$m_t$ is an EMA of recent gradient directions (a first moment), which acts as a low-pass filter on write directions (not on what’s already been written).

$\alpha_t$ is a scalar retention on $\theta_{t-1}$, structurally identical to Gated DeltaNet’s retention: the same idea as in Part 1, applied to all MLP parameters instead of a single matrix.

Both moves are standard SGD ingredients (momentum, weight decay) ported into the memory-state update step. Since this is just SGD on the memory state, anything we know about training neural nets is reusable here; we’ll see further such techniques later in this post.
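
In our notation, the whole update is a few lines. A sketch (the helper name and toy values are ours; in the real model $\nu_t, \alpha_t, \beta_t$ are produced per token by the outer network, and the step is applied to each weight matrix of the MLP memory):

```python
import numpy as np

def titans_step(theta, m, grad, nu, alpha, beta):
    """m_t = nu * m_{t-1} + grad;  theta_t = alpha * theta_{t-1} - beta * m_t."""
    m = nu * m + grad                 # EMA of gradient directions (first moment)
    theta = alpha * theta - beta * m  # scalar retention + momentum step
    return theta, m

# applied independently to each weight matrix of the MLP memory, e.g. W1:
rng = np.random.default_rng(0)
W1, m1 = rng.standard_normal((10, 6)), np.zeros((10, 6))
gW1 = rng.standard_normal((10, 6))    # stand-in for the TTT-MLP gradient above
W1, m1 = titans_step(W1, m1, gW1, nu=0.9, alpha=0.99, beta=0.1)
```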

4. MIRAS

We moved from the linear recurrence $S_t = S_{t-1} A_t + b_t k_t^\top$ of a matrix to the general case where we optimise the parameters $\theta$ of some model $M_\theta$ on an inner objective $L_t$ that matches $v \approx M_\theta(k)$, with optional retention on $\theta$. This gives us four design choices: (1) memory architecture $M_\theta(k)$, (2) inner objective $L_t(\theta)$, (3) inner optimizer, (4) retention on $\theta$. We take them in turn.

Axis 1 — memory architecture $M_\theta(k)$.

Instead of a single matrix $S$ as before, we generalise our memory architecture to any model $M$ with parameters $\theta$. We have already seen these special cases:

  • Linear: $M_\theta(k) = \theta k$ — linear attention, DeltaNet, Gated DeltaNet.
  • 2-layer MLP: $M_\theta(k) = W_2 \sigma(W_1 k)$ — TTT-MLP, Titans.

This also includes all variants that use kernel mappings $M_\theta(\phi(k))$ to increase the capacity of the state (Zhong et al., 2025).

Axis 2 — memory objective $L_t(\theta)$.

Next we define the loss that $M_\theta$ is optimised on. In the general case we want $M_\theta(k)$ to reconstruct its associated value $v$.

For the simple case $M_\theta(k) = \theta k$, two objectives (dot-product, L2) generate two well-known models from the same one-step-GD recipe:

$$L_t(\theta) = -v_t^\top M_\theta(k_t) \;\Longrightarrow\; \theta_t = \theta_{t-1} + \beta_t\, v_t k_t^\top \quad\text{(linear attention)}$$
$$L_t(\theta) = \tfrac{1}{2}\|v_t - M_\theta(k_t)\|^2 \;\Longrightarrow\; \theta_t = \theta_{t-1} + \beta_t (v_t - \theta_{t-1} k_t)\, k_t^\top \quad\text{(DeltaNet)}$$
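
Side by side as a sketch, with toy data, so the only difference between the two models is the gradient being applied:

```python
import numpy as np

rng = np.random.default_rng(3)
d_k, d_v, beta = 8, 8, 0.5
theta = rng.standard_normal((d_v, d_k))
k, v = rng.standard_normal(d_k), rng.standard_normal(d_v)

# dot-product objective: grad of -v^T (theta k) wrt theta is -v k^T -> Hebbian write
theta_la = theta + beta * np.outer(v, k)                 # linear attention
# L2 objective: grad is -(v - theta k) k^T -> error-correcting write
theta_dn = theta + beta * np.outer(v - theta @ k, k)     # DeltaNet
```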

Axis 3 — inner optimizer.

This is how the next $\theta_t$ is computed from $\nabla L_t$.

The default is one GD step. Linear attention, Mamba2, DeltaNet, Gated DeltaNet, and TTT-MLP all use it; what differs across them is the objective, not the optimizer. Titans is the first to move on the optimizer axis by introducing momentum. Later we will see Muon (ATLAS), the second optimizer move, which is also heavily used today for training LLMs.

Axis 4 — retention on $\theta$.

We motivated this in Part 1 as a necessity for erasing stale writes (key-value pairs that no longer matter but take up space). MIRAS (Behrouz et al., 2025) further shows that this can be derived as regularization on our state $\theta$, in the form of adding $\|\theta_t - \alpha_t \theta_{t-1}\|^2$ to $L_t$.

Whereas DeltaNet and TTT-MLP used none, Mamba2, Gated DeltaNet, and Titans used a scalar $\alpha_t$. In §5 we will see KDA introduce an extension.

Where this leaves us.

The recurrence $S_t = S_{t-1} A_t + b_t k_t^\top$ from Part 1 is the special case where the memory is linear (a single matrix), the objective is dot-product or L2, the optimizer is one GD step, and retention is at most a scalar. The four axes generalize each of those choices. Every model in this post (including the two we still have to introduce) slots into a row of one shared table; we’ll fill it in in §7, once KDA and ATLAS are placed.

5. Kimi Linear / KDA

So far retention was always a scalar $\alpha_t \in \mathbb{R}$ acting on the whole state. What if we make it a vector $\alpha_t \in \mathbb{R}^{d_k}$, where each channel/dimension of $k$ gets its own retention scalar?

This is what Kimi Linear (Moonshot AI, 2025) introduces, giving us the following recurrence

$$S_t = S_{t-1}\, \mathrm{Diag}(\alpha_t)\, (I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top.$$

Note that their paper writes $S \in \mathbb{R}^{d_k \times d_v}$; transposed to our convention, $S \in \mathbb{R}^{d_v \times d_k}$.

The transition matrix on the right of $S_{t-1}$ is a DPLR matrix (diagonal-plus-low-rank, rank-1 here), supporting fast matrix-vector ops and a closed-form inverse via Sherman–Morrison:

$$\mathrm{Diag}(\alpha_t)\,(I - \beta_t k_t k_t^\top) = \mathrm{Diag}(\alpha_t) - \beta_t (\alpha_t \odot k_t)\, k_t^\top.$$

With this per-channel gating, each $\alpha_{t,i}$ controls the decay of the $i$-th key coordinate independently, giving the model fine-grained control over what to forget. This is analogous to how the Adam optimizer has per-parameter learning rates.

A scalar gate $\alpha_t$ forces the same forgetting rate across all key coordinates, but different coordinates might encode information that ages at different rates (positional vs. semantic information). A diagonal $\alpha_t$ lets the model trade off retention per coordinate.
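
A minimal sketch of one KDA step using this expanded form, so the $d_k \times d_k$ transition matrix is never materialized (per-token loop for clarity; the actual kernel is chunkwise):

```python
import numpy as np

def kda_step(S, k, v, alpha, beta):
    """One KDA step, S_t = S_{t-1} Diag(alpha)(I - beta k k^T) + beta v k^T,
    via the expanded form Diag(alpha) - beta (alpha * k) k^T."""
    decayed = S * alpha  # S_{t-1} Diag(alpha): scales column i by alpha_i
    return decayed - beta * np.outer(decayed @ k, k) + beta * np.outer(v, k)
```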

MIRAS lens. Same L2 reconstruction loss as DeltaNet, same one-step GD optimizer; the only change is the channel-wise retention, now acting as a regularizer toward the channel-decayed prior:

$$\tilde L_t(S) = \tfrac{1}{2}\|v_t - S k_t\|^2 + \tfrac{1}{2\eta_R}\big\|S - S_{t-1}\,\mathrm{Diag}(\alpha_t)\big\|_F^2.$$

The exact (closed-form) minimizer of $\tilde L_t$ is the KDA recurrence above, with the reparameterization $\beta_t = \eta_R / (1 + \eta_R)$. (Derivation: set $\nabla \tilde L_t = 0$ and invert via Sherman–Morrison.)
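
We can check this numerically. A sketch assuming a unit-norm key, which is what makes the scalar reparameterization exact; for a general key the same derivation gives $\beta_t = \eta_R / (1 + \eta_R \|k_t\|^2)$:

```python
import numpy as np

rng = np.random.default_rng(4)
d_k, d_v, eta = 8, 8, 3.0
S_prev = rng.standard_normal((d_v, d_k))
alpha = rng.uniform(0.5, 1.0, d_k)                      # per-channel retention
k = rng.standard_normal(d_k); k /= np.linalg.norm(k)    # unit-norm key (assumption)
v = rng.standard_normal(d_v)

P = S_prev * alpha                                      # channel-decayed prior
# gradient = 0: -(v - S k) k^T + (1/eta)(S - P) = 0
# => S = (P + eta v k^T)(I + eta k k^T)^{-1}
S_exact = (P + eta * np.outer(v, k)) @ np.linalg.inv(np.eye(d_k) + eta * np.outer(k, k))

beta = eta / (1 + eta)                                  # Sherman-Morrison reparameterization
S_kda = P @ (np.eye(d_k) - beta * np.outer(k, k)) + beta * np.outer(v, k)
assert np.allclose(S_exact, S_kda)
```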

6. ATLAS

ATLAS (Behrouz et al., 2025) introduces three independent changes vs. Titans that can each be placed on Axes 1, 2, and 3.

6.1 Omega rule — windowed inner objective (Axis 2).

Up to here, every model’s memory loss has been a function of the current token only. ATLAS’s Omega rule sums the loss over a window of the $c$ most recent tokens, with per-token in-window gates $\gamma_i^{(t)} \in [0,1]$:

$$L_t(\theta) = \sum_{i=t-c+1}^{t} \gamma_i^{(t)}\, \tfrac{1}{2}\|v_i - M_\theta(k_i)\|^2.$$

For linear memory ($M_\theta(k) = S k$, $\theta = S$), one GD step gives:

$$S_t = S_{t-1}\Big(I - \beta_t \sum_{i=t-c+1}^{t} \gamma_i^{(t)}\, k_i k_i^\top\Big) + \beta_t \sum_{i=t-c+1}^{t} \gamma_i^{(t)}\, v_i k_i^\top.$$

The edge case $c = 1$ recovers DeltaNet/TTT-Linear, while $c \to \infty$ recovers Mesa-layer-style global least-squares (which we will not cover further). Each token’s gradient enters the update sum at $c$ different timesteps. This is different from momentum, which stores the gradient once and lets it decay through a buffer, reusing a potentially outdated gradient direction. The Omega rule is closer to mini-batch gradient descent on a sliding window, where the gradients are recomputed at each step. (Footnote: not quite true in practice due to chunkwise computation, introduced in Part 3.)
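
A sketch of the windowed one-step update for linear memory (window contents and gates are passed in explicitly; with $c = 1$ and $\gamma = 1$ it reduces to the DeltaNet step above):

```python
import numpy as np

def omega_step(S, K_win, V_win, gamma, beta):
    """One GD step on the windowed Omega loss for linear memory.
    K_win: (c, d_k) recent keys, V_win: (c, d_v) recent values, gamma: (c,) gates."""
    d_k = K_win.shape[1]
    G = K_win.T @ (gamma[:, None] * K_win)   # sum_i gamma_i k_i k_i^T
    W = V_win.T @ (gamma[:, None] * K_win)   # sum_i gamma_i v_i k_i^T
    return S @ (np.eye(d_k) - beta * G) + beta * W
```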

6.2 Kernel feature maps (Axis 1).

ATLAS adds (or rather, from our view, reintroduces) a feature map $\phi$ on keys (and queries). The inner loss applies $M_\theta$ to $\phi(k_i)$ instead of $k_i$:

$$\ell(\theta; k_i, v_i) = \tfrac{1}{2}\|v_i - M_\theta(\phi(k_i))\|^2.$$

For a polynomial $\phi_p$ of degree $\le p$, the effective key dimension grows from $d_k$ to $\binom{d_k + p}{p}$, and the capacity ceiling rises from $\mathcal{O}(d_k)$ to $\mathcal{O}(d_k^p)$ (the paper’s Proposition 2). The exponential kernel $\phi^*$ — the Taylor expansion of $\exp$ — is the $p \to \infty$ limit and recovers softmax attention as the global-window special case (paper §4.2).
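
A small sketch enumerating monomials of degree $\le p$ to confirm that count; a real kernel map would also fold in multinomial weights so inner products match the kernel, which we skip here:

```python
from itertools import combinations_with_replacement
from math import comb

import numpy as np

def poly_features(k, p):
    """All monomials of k up to degree p (degree 0..p), as a flat vector."""
    feats = [np.prod(k[list(idx)]) for deg in range(p + 1)
             for idx in combinations_with_replacement(range(len(k)), deg)]
    return np.array(feats)

d_k, p = 8, 3
phi = poly_features(np.random.default_rng(5).standard_normal(d_k), p)
assert len(phi) == comb(d_k + p, p)   # effective key dimension: (d_k + p choose p)
print(len(phi))                       # 165 for d_k = 8, p = 3
```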

6.3 Muon (Axis 3).

Notation note: the ATLAS paper writes $\mathcal{M}_t$ for the parameters, $\mathcal{S}_t$ for the momentum buffer, $\theta_t$ for the momentum decay scalar, and $k$ for the Newton–Schulz iteration count (which clashes with our key vector). Stripped to our convention: parameters $\theta_t$, momentum buffer $m_t$, momentum decay $\nu_t$, learning rate $\beta_t$, retention $\alpha_t$, Newton–Schulz iteration count $J$.

The full ATLAS recurrence:

$$m_t = \nu_t\, m_{t-1} + \nabla_\theta L_t^{\text{Omega}}(\theta_{t-1}),$$
$$\theta_t = \alpha_t\, \theta_{t-1} - \beta_t \cdot \texttt{NewtonSchulz}_J(m_t),$$

where $L_t^{\text{Omega}}(\theta) = \sum_{i=t-c+1}^{t} \gamma_i^{(t)}\, \tfrac{1}{2}\|v_i - M_\theta(\phi(k_i))\|^2$ is the windowed loss over the past $c$ tokens.

This is a two-state recurrence with the same shape as Titans: $m_t$ is the EMA of recent (windowed) gradients with decay $\nu_t$, and $\theta_t$ holds the memory parameters with scalar retention $\alpha_t$ per step. The novel piece is the $\texttt{NewtonSchulz}_J$ wrapper around the buffer.

Muon intuition. Momentum buffers for matrix-shaped parameters tend to be low-rank and ill-conditioned (a few dominant directions). Orthogonalization rescales every singular value to 1, producing a step that “moves equally in every direction” of the gradient subspace. As $J \to \infty$, $\texttt{NewtonSchulz}_J(m_t)$ converges to the nearest semi-orthogonal matrix to $m_t$ — a second-order approximation in the sense that it inverts the local curvature scale.
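
A sketch of the classic cubic Newton–Schulz iteration; Muon in practice uses a tuned quintic polynomial, but the cubic shows the mechanism of pushing all singular values toward 1:

```python
import numpy as np

def newton_schulz(M, J=20):
    """Cubic Newton-Schulz orthogonalization of a matrix-shaped buffer."""
    X = M / np.linalg.norm(M)             # Frobenius norm keeps the spectrum in range
    for _ in range(J):
        X = 1.5 * X - 0.5 * X @ X.T @ X   # pushes every singular value toward 1
    return X

M = np.random.default_rng(6).standard_normal((16, 8))
O = newton_schulz(M)
print(np.linalg.svd(O, compute_uv=False).round(3))   # all singular values close to 1
```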

Summary. ATLAS introduces three changes. (1) Omega: the per-token loss is essentially online SGD with batch size 1; window $c$ gives the inner step access to recent context, like minibatch GD. (2) Kernels: Titans’s MLP memory still bottlenecks at the matrix-output dimension $d_k$; polynomial features lift the capacity ceiling to $\mathcal{O}(d_k^p)$. (3) Muon: matrix-shaped momentum buffers concentrate energy in a few directions; Newton–Schulz redistributes step size across the gradient subspace.

7. Overview

This gives us the following complete overview of all architectures, including those from Part 1:

Figure 2: The MIRAS 4-axis lens, updated with KDA and ATLAS.

| model | memory architecture | objective | optimizer | retention |
| --- | --- | --- | --- | --- |
| linear attention | linear $M_\theta(k) = \theta k$ | $-v^\top M_\theta(k)$ | 1-step GD, scalar $\beta_t$ | identity |
| Mamba2 | linear $M_\theta(k) = \theta k$ | $-v^\top M_\theta(k)$ | 1-step GD, scalar $\beta_t$ | scalar $\alpha_t$ |
| DeltaNet | linear $M_\theta(k) = \theta k$ | $\tfrac{1}{2}\lVert v - M_\theta(k)\rVert^2$ | 1-step GD, scalar $\beta_t$ | identity |
| Gated DeltaNet | linear $M_\theta(k) = \theta k$ | $\tfrac{1}{2}\lVert v - M_\theta(k)\rVert^2$ | 1-step GD, scalar $\beta_t$ | scalar $\alpha_t$ |
| TTT-MLP | 2-layer MLP $M_\theta(k) = W_2\sigma(W_1 k)$ | $\tfrac{1}{2}\lVert v - M_\theta(k)\rVert^2$ | 1-step GD, scalar $\beta_t$ | identity |
| Titans | 2-layer MLP $M_\theta(k) = W_2\sigma(W_1 k)$ | $\tfrac{1}{2}\lVert v - M_\theta(k)\rVert^2$ | 1-step GD + momentum (buffer $m_t$) | scalar $\alpha_t$ |
| KDA | linear $M_\theta(k) = \theta k$ | $\tfrac{1}{2}\lVert v - M_\theta(k)\rVert^2$ | 1-step GD, scalar $\beta_t$ | diagonal $\alpha_t \in \mathbb{R}^{d_k}$ |
| ATLAS | MLP $\circ\,\phi$ (poly/exp kernel on $k$) | $\sum_{i=t-c+1}^{t} \gamma_i^{(t)} \tfrac{1}{2}\lVert v_i - M_\theta(\phi(k_i))\rVert^2$ | GD + momentum + Newton–Schulz (Muon) | scalar $\alpha_t$ |

Each paper typically introduces a single change from the baseline of linear attention; some are combinations of earlier ones rather than genuinely new (Gated DeltaNet pairs DeltaNet’s L2 objective with Mamba2’s retention). ATLAS is the exception, moving three axes at once.

8. Outlook

What we’ve neglected so far is the practicality of training these methods on modern accelerated hardware. Per-token updates are inherently sequential (nice for inference, and what motivated our derivation from softmax attention), but sequentiality kills training throughput.

Part 3 covers chunkwise parallelization (DeltaNet) and further practical tricks that we will have to employ to train these recurrent algorithms fast.

We will implement all architectures mentioned so far (including Part 1) and benchmark them on MQAR to read off the capacity ceilings empirically.

Using the four axes from MIRAS lets us clearly see what changes across these architectures and which axis each paper moves. Part 3 will be more about what it actually takes to run them at modern-LLM scale and how they benchmark on MQAR.

References

  1. Yang, Wang, Zhang, Shen, Kim (2024). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. NeurIPS 2024. arXiv:2406.06484
  2. Sun, Li, Geng, Hua, Wang, Zhao, Liu, Hardt, Chen, Pan, Lin, Wang, Han, Guestrin (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. arXiv:2407.04620
  3. Behrouz, Zhong, Mirrokni (2024). Titans: Learning to Memorize at Test Time. arXiv:2501.00663
  4. Behrouz, Razaviyayn, Zhong, Mirrokni (2025). It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization. arXiv:2504.13173
  5. Zhong, Xu, Ao, Shi (2025). Understanding Transformer from the Perspective of Associative Memory. arXiv:2505.19488
  6. Behrouz, Razaviyayn, Zhong, Mirrokni (2025). ATLAS: Learning to Optimally Memorize the Context at Test Time. arXiv:2505.23735
  7. Moonshot AI (2025). Kimi Linear: An Expressive, Efficient Attention Architecture. arXiv:2510.26692