
MIT's Phillip Isola proposes Supervised Memory Training to train RNNs without backpropagation through time

The method decouples memory updates to enable time-parallel training.

Original post · Han Guo #741
Akarsh Kumar @akarshkumar0101

We never really knew how to train nonlinear RNNs well… BPTT struggled with vanishing grads (no long-range memory) and sequential rollout (hard to parallelize).

What if instead an oracle told us the optimal memory state m_t at each step? Then the RNN could do one-step supervised learning on (m_t, x_{t+1}) → m_{t+1} labels.

We call this Supervised Memory Training (SMT): a replacement for BPTT that trains RNNs without unrolling them. SMT is time-parallelizable and solves vanishing gradients.

Website: https://akarshkumar.com/smt/ arXiv: https://arxiv.org/abs/2606.06479
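
The one-step supervision described in this post can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the tanh cell, the squared-error loss, the sizes, and the random stand-in labels are all assumptions. It shows that once target memory states are given, every (m_t, x_{t+1}) → m_{t+1} pair can be scored in one batched call with no loop over time.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_x, d_m = 16, 4, 8   # sequence length, input size, memory size (arbitrary choices)

# Hypothetical RNN cell parameters.
W = rng.normal(scale=0.3, size=(d_m, d_m))
U = rng.normal(scale=0.3, size=(d_x, d_m))
b = np.zeros(d_m)

def cell(m, x):
    """One memory update; accepts a single step or a whole batch of steps."""
    return np.tanh(m @ W + x @ U + b)

def smt_loss(oracle_m, x):
    """One-step regression loss over all time steps at once.

    oracle_m: (T + 1, d_m) target memory states m_0 .. m_T
    x:        (T + 1, d_x) inputs x_0 .. x_T (x_0 is unused here)
    """
    pred = cell(oracle_m[:-1], x[1:])   # (m_t, x_{t+1}) -> predicted m_{t+1}, no loop over t
    return float(np.mean((pred - oracle_m[1:]) ** 2))

x = rng.normal(size=(T + 1, d_x))
random_oracle = rng.normal(size=(T + 1, d_m))   # stand-in labels, not a real oracle
loss_random = smt_loss(random_oracle, x)

# Sanity check: labels produced by rolling out the cell itself give (near) zero loss.
self_oracle = np.zeros((T + 1, d_m))
for t in range(T):
    self_oracle[t + 1] = cell(self_oracle[t], x[t + 1])
loss_self = smt_loss(self_oracle, x)
```

In a real training loop the loss would be differentiated with respect to the cell parameters; because each pair starts from a given label, no gradient has to flow between time steps.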

2:24 PM · Jun 7, 2026 · 40.9K Views
Sentiment

Many users call Supervised Memory Training a fundamental improvement over BPTT for RNNs because it enables parallel optimization through qualitatively different credit assignment.

Pos: 100.0% · Neg: 0.0% (5 comments with sentiment)
Cluster Engagement
Posts from X
Most Activity: Views 9.5K · Bookmarks 42 · Likes 62 · Retweets 4 · Replies 2
Vincent Sitzmann @vincesitzmann

A really cool idea! The question of how we can train sequence models such that they remember things that are T timesteps in the past without backpropping through T timesteps remains one of the core problems in ML, and this looks like an inspiring approach!

7h · Views 9.5K · Likes 62 · Bookmarks 42
Phillip Isola @phillip_isola

We introduce a method for training RNNs that is time-parallel and does not suffer from vanishing/exploding gradients.

Key idea is to decouple learning 1) what should be remembered (can be done without recurrence) and 2) how to update memory (can be one-step supervised by #1).

7h · Views 2K · Likes 31 · Bookmarks 17
Minqi Jiang @MinqiJiang

It's so over. The kids are now giving their neural networks DMT.

Jokes aside, this is a clever + promising way for time-parallel training of sequence models—one of the original motivations for the transformer. Despite being an alternative solution, it is also complementary.

3h · Views 5.6K · Likes 22 · Bookmarks 16
Danfei Xu @danfei_xu

Listened to @phillip_isola talking about this work. Really clever idea!

6h · Views 4.5K · Likes 22 · Bookmarks 11
Jayden Teoh @jayden_teoh_

@akarshkumar0101 Awesome stuff. In Next-Latent Prediction Transformers (https://arxiv.org/abs/2511.05963) we also showed that you can pre-train an RNN without recurrence by using the transformer backbone to forecast latent states and training the RNN on one-step latent predictions.

6h · Views 508 · Likes 11 · Bookmarks 8
Akarsh Kumar @akarshkumar0101

SMT is akin to off-policy behavior cloning, and is mainly for pretraining.

To stabilize RNN rollouts, we introduce an on-policy imitation algo: DAgger Memory Training (DMT), a relatively lightweight fine-tuning phase.
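
The off-policy versus on-policy distinction can be illustrated with a toy DAgger-style loop. This is a guess at the general shape based on standard DAgger, not the paper's DMT algorithm: the linear cell, the least-squares fit, the stand-in oracle, and the three aggregation rounds are all assumptions. The learner is first fit on the oracle's own states, then repeatedly rolled out so that the states it actually visits are added to the dataset with oracle labels.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_x, d_m = 64, 3, 5
x = rng.normal(size=(T + 1, d_x))

# Stand-in oracle: a fixed causal summary of the inputs (purely illustrative).
P = rng.normal(size=(d_x, d_m))
oracle = np.zeros((T + 1, d_m))
for t in range(T):
    oracle[t + 1] = np.tanh(0.9 * oracle[t] + x[t + 1] @ P)

def fit(states, inputs, targets):
    """Least-squares fit of a linear cell: [m_t, x_{t+1}] @ A ~= m_{t+1}."""
    features = np.concatenate([states, inputs], axis=1)
    A, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return A

def rollout(A):
    """Run the learned cell on its own memory states (on-policy)."""
    m = np.zeros((T + 1, d_m))
    for t in range(T):
        m[t + 1] = np.concatenate([m[t], x[t + 1]]) @ A
    return m

# Phase 1 (off-policy, SMT-like): training inputs are the oracle's own states.
states, inputs, targets = oracle[:-1], x[1:], oracle[1:]
A = fit(states, inputs, targets)

# Phase 2 (on-policy, DAgger-like): visit states with the learner, label them with the oracle.
for _ in range(3):
    visited = rollout(A)
    states = np.concatenate([states, visited[:-1]])
    inputs = np.concatenate([inputs, x[1:]])
    targets = np.concatenate([targets, oracle[1:]])
    A = fit(states, inputs, targets)

rollout_error = float(np.mean((rollout(A) - oracle) ** 2))
```

The point of the second phase is that the cell is also trained on its own slightly wrong memory states, which is where errors would otherwise compound during a rollout.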

8h · Views 662 · Likes 12 · Bookmarks 4
Akarsh Kumar @akarshkumar0101

Long Range Memory

Encoder+decoder are Transformers and can look up any token in the past and future and associate them immediately via attention (O(1) gradient path).

This solves vanishing gradients (left).

With this, SMT can learn long-range memory and even train next-pixel prediction RNNs (right).
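
The contrast with BPTT can be written out explicitly. The notation below is an assumption for illustration (a generic recurrence with parameters θ, oracle states marked with an asterisk, and a squared-error loss); none of it is taken from the thread. Under BPTT the gradient between distant steps is a product of per-step Jacobians, so its norm can shrink or grow geometrically with the distance. Under one-step supervision each loss term involves a single application of the cell, so no such product appears in the cell's gradient.

```latex
% Assumed recurrence (notation not from the thread): m_k = f_\theta(m_{k-1}, x_k)
\frac{\partial m_T}{\partial m_t} = \prod_{k=t+1}^{T} \frac{\partial m_k}{\partial m_{k-1}},
\qquad
\left\lVert \frac{\partial m_T}{\partial m_t} \right\rVert \le \gamma^{\,T-t}
\quad \text{whenever} \quad
\left\lVert \frac{\partial m_k}{\partial m_{k-1}} \right\rVert \le \gamma \ \text{for all } k.

% One-step supervision against oracle states m^\ast (squared error is an assumption):
\mathcal{L}(\theta) = \sum_{t} \left\lVert f_\theta(m^\ast_t, x_{t+1}) - m^\ast_{t+1} \right\rVert^2,
\qquad
\frac{\partial \mathcal{L}}{\partial \theta}
= \sum_{t} 2 \left( \frac{\partial f_\theta(m^\ast_t, x_{t+1})}{\partial \theta} \right)^{\!\top}
\left( f_\theta(m^\ast_t, x_{t+1}) - m^\ast_{t+1} \right).
```

With γ < 1 the first bound decays exponentially in T − t, which is the vanishing-gradient regime; the long-range dependencies then have to be captured by whatever produces the oracle states, here the attention-based encoder.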

8h · Views 1.2K · Likes 12 · Bookmarks 3
Akarsh Kumar @akarshkumar0101

Time-parallelism

SMT is fully time-parallel, making it efficient on GPUs.

SMT needs less sequential computation than BPTT to reach a given loss.
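
A rough way to see the difference in sequential depth is to count dependent cell calls in the forward pass. The sketch below is an illustration under assumptions (a tanh cell, random stand-in targets, and a call counter added for the demonstration); here `x[t]` denotes the input consumed when moving from m_t to m_{t+1}.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_x, d_m = 32, 4, 8
W = rng.normal(scale=0.3, size=(d_m, d_m))
U = rng.normal(scale=0.3, size=(d_x, d_m))

calls = {"n": 0}   # counts how many times the cell is invoked

def cell(m, x):
    calls["n"] += 1
    return np.tanh(m @ W + x @ U)

x = rng.normal(size=(T, d_x))
targets = rng.normal(size=(T + 1, d_m))   # stand-in memory labels m_0 .. m_T

# Rollout (as in BPTT): step t needs the output of step t - 1, so the calls form a chain.
calls["n"] = 0
m = targets[0]
for t in range(T):
    m = cell(m, x[t])
rollout_depth = calls["n"]

# One-step supervision: every step starts from a given label, so one batched call suffices.
calls["n"] = 0
pred = cell(targets[:-1], x)
smt_depth = calls["n"]
```

The total arithmetic is similar in both cases; what changes is how much of it must happen one step after another, which is the quantity a GPU cannot parallelize away.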

8h · Views 792 · Likes 10 · Bookmarks 1
Akarsh Kumar @akarshkumar0101

In scaling laws, the y-axis is often loss. But what if it was instead compression?

In SMT, increasing training compute allows you to get to the same loss, but with a smaller memory state size.

This is a new way to spend your compute.

8h · Views 712 · Likes 7 · Bookmarks 1
Akarsh Kumar @akarshkumar0101

Thanks to @phillip_isola for inspiring me to pursue this direction in depth and providing invaluable guidance!

8h · Views 596 · Likes 5 · Bookmarks 2
Akarsh Kumar @akarshkumar0101

SMT estimates the oracle via a time-parallel encoder trained to embed past context into a representation that a decoder can use to predict the future.

This creates memory states that remember important info and purposefully forget unimportant details, similar to biological memory.
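
A toy version of this encoder/decoder setup is sketched below. It is only a stand-in: a single causal self-attention layer with a linear bottleneck plays the encoder, a linear readout plays the decoder (the thread says both are Transformers), and all weights, sizes, and the squared-error objective are assumptions. It shows the two properties the post relies on: memory states for all time steps come out of one parallel pass, and each state depends only on past context.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_x, d_m = 12, 6, 3   # d_m < d_x forces the memory state to compress

x = rng.normal(size=(T, d_x))

# Hypothetical encoder parameters: one causal self-attention layer plus a bottleneck.
Wq, Wk, Wv = (rng.normal(scale=0.3, size=(d_x, d_x)) for _ in range(3))
W_mem = rng.normal(scale=0.3, size=(d_x, d_m))
# Hypothetical decoder: a linear readout standing in for a Transformer decoder.
W_dec = rng.normal(scale=0.3, size=(d_m, d_x))

def encode(x):
    """Map every prefix x_0..x_t to a memory state m_t, for all t at once."""
    scores = (x @ Wq) @ (x @ Wk).T / np.sqrt(x.shape[1])
    future = np.triu(np.ones((len(x), len(x)), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)   # position t may only attend to s <= t
    scores -= scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return (weights @ (x @ Wv)) @ W_mem          # shape (T, d_m)

def prediction_loss(x):
    """Score how well m_t alone predicts the next input x_{t+1}."""
    m = encode(x)
    pred_next = m[:-1] @ W_dec
    return float(np.mean((pred_next - x[1:]) ** 2))

memory = encode(x)
loss = prediction_loss(x)
```

Training would minimize `prediction_loss` over the encoder and decoder weights; the resulting `memory` rows are what the RNN would then be asked to reproduce one step at a time.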

8h · Views 1.5K · Likes 9
Akarsh Kumar @akarshkumar0101

SMT+DMT are a fundamental improvement over BPTT because they perform credit assignment across a sequence in a qualitatively different way (without recurrence).

Check out the paper for many more experiments and insights.

8h · Views 632 · Likes 8
Akarsh Kumar @akarshkumar0101

@vincesitzmann Thanks Vincent!

6h · Views 259 · Bookmarks 1
Phillip Isola @phillip_isola

What should be remembered: a compressed representation of the past that predicts the future (predictive state).

How to update memory: predict the next predictive state.

7h · Views 665 · Likes 7 · Bookmarks 0
Francois Chaubard @FrancoisChauba1

@akarshkumar0101 I like anything getting us off of BPTT, but what if the oracle doesn't exist? What if we are trying to solve a class of problems humans don't know how to solve? Then there is no trace to train on. That's what we have to solve.

2h · Views 174
Akarsh Kumar @akarshkumar0101

@MinqiJiang Thanks Minqi!

3h · Views 306 · Likes 2
secemp @secemp9

@akarshkumar0101 cc @neurallambda

59m · Views 97
Jiaqi Feng @FengLeader

@vincesitzmann For AR we use embeddings; for diffusion we use encoders/decoders. Yet for hybrid AR-diffusion models like recent world models, we know too little about what makes a good encoder.

6h · Views 82
Vincent @InsiderPresider

@danfei_xu @phillip_isola This work is actually valid, but does SMT really hold up against transformers in the long run anyway?

6h · Views 54
Leandro Morel @MorelLeand78015

@FrancoisChauba1 @akarshkumar0101 There is a mechanism for reconstruction, although how to implement it is a different matter. It is in the experiments section.

https://github.com/Lexlangel/Interaction-dynamics-core/tree/main

2h · Views 15