torch_remat most directly competes with selective activation checkpointing (SAC), where you write a policy function that decides, on a per-operator basis, whether that operator's output is saved or recomputed. torch_remat flips this on its head: instead, you decide whether to save or recompute on a per-op-call basis.
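To make the distinction concrete, here is a toy model in plain Python. It does not use PyTorch or torch_remat, and every name in it (`OpCall`, `per_operator_decisions`, `per_call_decisions`) is made up for illustration; it only sketches the difference between deciding by operator and deciding by individual call.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OpCall:
    op: str       # operator name, e.g. "mm"
    call_id: int  # position of this particular call in the forward pass


# A toy forward pass in which the same operator ("mm") is called twice.
FORWARD = [OpCall("mm", 0), OpCall("relu", 1), OpCall("mm", 2)]


def per_operator_decisions(calls, policy):
    """Policy-function style: the decision depends only on the operator."""
    return {c.call_id: policy(c.op) for c in calls}


def per_call_decisions(calls, saved_call_ids):
    """Per-call style: each individual call is marked on its own."""
    return {
        c.call_id: "save" if c.call_id in saved_call_ids else "recompute"
        for c in calls
    }


def save_matmuls(op):
    # Hypothetical policy: keep matmul outputs, recompute everything else.
    return "save" if op == "mm" else "recompute"


policy_plan = per_operator_decisions(FORWARD, save_matmuls)
call_plan = per_call_decisions(FORWARD, saved_call_ids={0})

print(policy_plan)
print(call_plan)
```

In this toy, the operator-keyed policy treats both `mm` calls the same way, while the per-call plan can save the first `mm` and recompute the second without any extra bookkeeping.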
I've been experimenting with a new activation checkpointing API, which we're calling torch_remat. We're still putting it through its paces, but I think it's already interesting enough to share for some initial public feedback: https://github.com/meta-pytorch/remat