AbstractReversibleSolver + ReversibleAdjoint #603

Closed
sammccallum wants to merge 1 commit into patrick-kidger:dev from sammccallum:AbstractReversibleSolver

@sammccallum

Contributor

Re-opening #593.

Implements the AbstractReversibleSolver base class and ReversibleAdjoint for reversible backpropagation.

This updates SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun to subclass AbstractReversibleSolver.

Implementation

AbstractReversibleSolver subclasses AbstractSolver and adds a backward_step method:

@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, DenseInfo, _SolverState]:

This method should reconstruct y0 and solver_state at t0 from y1 and solver_state at t1. See the aforementioned solvers for examples.
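
As a rough illustration of what "reconstructing y0 from y1" means, here is a minimal sketch in plain JAX of a semi-implicit Euler step and its algebraic inverse. It does not use the Diffrax API: the function names, the harmonic-oscillator vector fields and the absence of a solver_state are all assumptions made for this example, and the solvers in the PR may be structured differently.

```python
import jax.numpy as jnp


def f(t, p, args):
    # Hypothetical vector field for dq/dt (depends only on p).
    return p


def g(t, q, args):
    # Hypothetical vector field for dp/dt (depends only on q); args is a stiffness k.
    return -args * q


def forward_step(t0, t1, y0, args):
    """Semi-implicit Euler: update q using p0, then update p using the new q."""
    q0, p0 = y0
    dt = t1 - t0
    q1 = q0 + dt * f(t0, p0, args)
    p1 = p0 + dt * g(t0, q1, args)
    return q1, p1


def backward_step(t0, t1, y1, args):
    """Undo forward_step by inverting its two updates in reverse order."""
    q1, p1 = y1
    dt = t1 - t0
    p0 = p1 - dt * g(t0, q1, args)
    q0 = q1 - dt * f(t0, p0, args)
    return q0, p0


y0 = (jnp.array(1.0), jnp.array(0.0))
y1 = forward_step(0.0, 0.1, y0, 2.0)
y0_reconstructed = backward_step(0.0, 0.1, y1, 2.0)
```

The inverse needs no stored intermediate values: each update in forward_step only reads quantities that are still available at t1, which is the property that makes the step reversible.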

When backpropagating, ReversibleAdjoint uses this backward_step to reconstruct state. We then take a vjp through a local forward step and accumulate gradients.
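
The reconstruct-then-vjp loop described above can be sketched in plain JAX as follows. This is only an illustration under stated assumptions: forward_step, backward_step, the quadratic terminal loss and the fixed step size are hypothetical stand-ins, and none of this is the ReversibleAdjoint code from the PR.

```python
import jax
import jax.numpy as jnp


def forward_step(y0, dt, k):
    # Hypothetical reversible step (semi-implicit Euler on a harmonic oscillator).
    q0, p0 = y0[0], y0[1]
    q1 = q0 + dt * p0
    p1 = p0 - dt * k * q1
    return jnp.stack([q1, p1])


def backward_step(y1, dt, k):
    # Algebraic inverse of forward_step.
    q1, p1 = y1[0], y1[1]
    p0 = p1 + dt * k * q1
    q0 = q1 - dt * p0
    return jnp.stack([q0, p0])


def solve(y0, k, dt, n_steps):
    y = y0
    for _ in range(n_steps):
        y = forward_step(y, dt, k)
    return y


def terminal_loss(y):
    return jnp.sum(y**2)


def loss(y0, k, dt, n_steps):
    return terminal_loss(solve(y0, k, dt, n_steps))


def reversible_grad(y0, k, dt, n_steps):
    # The forward pass keeps only the final state.
    y = solve(y0, k, dt, n_steps)
    grad_y = jax.grad(terminal_loss)(y)
    grad_k = jnp.zeros_like(k)
    for _ in range(n_steps):
        # Reconstruct the previous state instead of reading it from memory.
        y_prev = backward_step(y, dt, k)
        # vjp through one local forward step, then accumulate gradients.
        _, vjp_fn = jax.vjp(lambda y_, k_: forward_step(y_, dt, k_), y_prev, k)
        grad_y, dk = vjp_fn(grad_y)
        grad_k = grad_k + dk
        y = y_prev
    return grad_y, grad_k


y0 = jnp.array([1.0, 0.0])
k = jnp.array(2.0)
dt, n_steps = 0.05, 20

grad_y0, grad_k = reversible_grad(y0, k, dt, n_steps)
# Reference gradients from ordinary backpropagation through the unrolled solve.
ref_y0, ref_k = jax.grad(loss, argnums=(0, 1))(y0, k, dt, n_steps)
```

The two sets of gradients should agree up to floating-point error in the reconstruction, while the reversible version stores only one state at a time.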

ReversibleAdjoint now also pulls back gradients from any interpolated values, so we can use SaveAt(ts=...)!
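
To make "pulling back gradients from interpolated values" concrete, here is a small sketch in plain JAX. It assumes a linear interpolant between the two ends of a step (one of the commits on this branch mentions switching to linear interpolation); linear_interp and the numbers are hypothetical and not taken from the PR.

```python
import jax
import jax.numpy as jnp


def linear_interp(t, t0, t1, y0, y1):
    # Hypothetical dense output: a straight line between (t0, y0) and (t1, y1).
    w = (t - t0) / (t1 - t0)
    return (1.0 - w) * y0 + w * y1


t0, t1 = 0.0, 1.0
y0 = jnp.array([1.0, 2.0])
y1 = jnp.array([3.0, 6.0])
t_save = 0.25  # a requested save time that falls inside the step

# Value saved at t_save, plus a function that pulls cotangents back to y0 and y1.
y_save, vjp_fn = jax.vjp(
    lambda a, b: linear_interp(t_save, t0, t1, a, b), y0, y1
)

# Cotangent arriving from the loss on the saved value.
grad_y_save = jnp.array([1.0, 1.0])
grad_y0, grad_y1 = vjp_fn(grad_y_save)
```

In this sketch the gradient on the saved value is split between the step endpoints with weights 0.75 and 0.25, after which it can be accumulated into the per-step gradients like any other contribution.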

We allow arbitrary solver_state (provided it can be reconstructed reversibly) and calculate gradients w.r.t. solver_state. Finally, we pull back these gradients onto y0, args, terms using the solver.init method.
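
The final pull-back through the solver's init can be pictured with the sketch below. The init function here is a hypothetical stand-in that stores one vector-field evaluation as the solver state; it is not Diffrax's solver.init signature, and the incoming gradient values are made up for illustration.

```python
import jax
import jax.numpy as jnp


def vector_field(y, args):
    # Hypothetical linear vector field.
    return -args * y


def init(y0, args):
    # Hypothetical solver state: the vector field evaluated at the initial point.
    return vector_field(y0, args)


y0 = jnp.array([1.0, 2.0])
args = jnp.array(3.0)

# Suppose the backward pass has produced these gradients at t0.
grad_y0 = jnp.array([0.5, 0.5])
grad_args = jnp.array(0.0)
grad_solver_state = jnp.array([1.0, -1.0])

# Pull the solver_state gradient back through init and add it to the others.
_, init_vjp = jax.vjp(init, y0, args)
dy0, dargs = init_vjp(grad_solver_state)
grad_y0 = grad_y0 + dy0
grad_args = grad_args + dargs
```

Because the initial solver state is itself a function of y0 and args, its gradient has to be folded back into theirs for the totals to be correct.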

@sammccallum

Contributor Author

I've also added the Reversible RK solvers here, which just subclass AbstractReversibleSolver. Let me know what you think of this and I can add some documentation when it's good to go!

@patrick-kidger patrick-kidger left a comment

Owner

Okay, gosh, this one took far too long for me to get around to. Thank you for your patience! If I can, I'd like this to be the next big thing I focus on getting into Diffrax.

Comment thread diffrax/__init__.py Outdated
Comment thread diffrax/_solver/base.py
Comment thread diffrax/_solver/reversible.py
Comment thread diffrax/_solver/reversible.py
Comment thread diffrax/_solver/reversible.py Outdated
Comment thread diffrax/_solver/base.py Outdated
Comment thread diffrax/_adjoint.py
Comment thread diffrax/_integrate.py
reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

Owner

A very minor bug here: if it just so happens that we run with t0 == t1, then we'll end up with reversible_ts = [t0, inf, inf, inf, ...], which will not produce the desired results in the backward solve.

We have a special branch to handle the saving in the t0 == t1 case; we should add a line there handling the state.reversible_ts is not None case.
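
For readers following the thread, the shape of the buffer being discussed can be sketched in plain JAX as below. This is a simplified model, not the code in _integrate.py: the attempts list, the write rule and the variable names are assumptions chosen to mirror the snippet quoted above.

```python
import jax.numpy as jnp

max_steps = 4
t0 = 0.0

# Fixed-size buffer of step times, padded with inf, with t0 in the first slot.
ts = jnp.full(max_steps + 1, jnp.inf).at[0].set(t0)
save_index = 0

# Hypothetical attempted steps: (end time of the step, whether it was accepted).
attempts = [(0.5, True), (1.5, False), (1.0, True)]
for t_end, keep_step in attempts:
    # Write into the next slot only when the step is kept.
    new_value = jnp.where(keep_step, t_end, ts[save_index + 1])
    ts = ts.at[save_index + 1].set(new_value)
    # Advance the index only when the step is kept.
    save_index = save_index + jnp.where(keep_step, 1, 0)

# With t0 == t1 no steps are taken, so only the first slot is ever finite.
ts_degenerate = jnp.full(max_steps + 1, jnp.inf).at[0].set(t0)
```

In this model the normal run ends with two recorded steps followed by inf padding, whereas the degenerate run leaves the buffer in the [t0, inf, inf, ...] form that the comment flags.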

Contributor Author

Comment thread diffrax/_adjoint.py Outdated
Comment thread diffrax/_integrate.py
Comment on lines +1409 to +1415
# Reversible info
if max_steps is None:
    reversible_ts = None
    reversible_save_index = None
else:
    reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
    reversible_save_index = 0

Owner

I've thought of an alternative for this extra buffer, btw: ReversibleAdjoint.loop could intercept saveat and add a SubSaveAt(steps=True, save_fn=lambda t, y, args: None) to record the extra times. Then peel it off again when returning the final state.

I think this (a) might be doable without making any changes to _integrate.py, (b) would also allow supporting SaveAt(steps=True), since in that case we can just skip adding the extra SubSaveAt, and (c) would avoid a few of the subtle issues I've commented on above about exactly which tprev/tnext-like value is actually being saved, because you can trust the rest of the existing diffeqsolve to do that for you.

It's not a strong suggestion though.

Contributor Author

Yeah, this was the original idea I tried but I couldn't get around a leaked tracer error! I'm willing to give it another go if you start feeling strongly about it though ;)

Owner

Maybe let's nail everything else down and then consider this. Reflecting on this, I do suspect it will make the code much easier to maintain in the long run.

Comment thread diffrax/_solver/reversible.py Outdated
@sammccallum sammccallum force-pushed the AbstractReversibleSolver branch from 0cfd4ec to 3a26ac3 on May 14, 2025 10:14

@patrick-kidger patrick-kidger left a comment

Owner

Not sure if you were ready for a review on this yet, but I took a look over anyway 😁 We're making really good progress! In particular, now that we're settled on just AbstractERK, I think all our complicated state-reconstruction concerns go away, so the chance of footgunning ourselves has gone way down 😁

Comment thread diffrax/_solver/base.py Outdated
Comment thread diffrax/_solver/base.py Outdated
Comment thread diffrax/_solver/leapfrog_midpoint.py
Comment thread diffrax/_solver/leapfrog_midpoint.py Outdated
Comment thread diffrax/_solver/leapfrog_midpoint.py Outdated
Comment thread diffrax/_solver/reversible.py
Comment thread diffrax/_solver/reversible.py Outdated
Comment thread diffrax/_adjoint.py
@patrick-kidger patrick-kidger changed the base branch from main to dev June 16, 2025 22:39
@patrick-kidger

patrick-kidger commented Jun 16, 2025

Owner

Heads-up that I've just updated the base branch to dev. It looks like there are a number of old commits sitting around on this PR, likely from where this branch forked off of main. You should be able to resolve these by first (a) squashing all the commits that actually belong on this branch together, and then (b) rebasing that new single commit on top of dev.

(Unrelatedly, lmk when this branch is ready for review.)

add reversible

testing

testing

AbstractReversibleSolver + ReversibleAdjoint

allow arbitrary interpolation

unpacking over indexing

jax while loop

collapse saveat ValueErrors

remove stratonovich solver condition

remove unused returns from AbstractReversibleSolver backward_step

add test and remove messy benchmark

add wrapped solver + tests

made_jump=True for both solver steps

improve docstrings

AbstractSolver and docstring note about SDEs

add AbstractReversibleSolver to public API

newline in docstrings

return RESULTS from reversible backward_step

restrict Reversible to AbstractERK and check result in adjoint

correct tprev and tnext of solver init

switch to linear interpolation and y0,y1 dense_info

name UReversible

various doc formatting changes

AbstractReversibleSolver check

add disable_fsal property to AbstractRungeKutta and use in UReversible

allow t0 != 0

Handle StepTo controller

t0==t1 branch
@sammccallum sammccallum force-pushed the AbstractReversibleSolver branch from 8d59058 to ae01942 on October 12, 2025 15:35
@sammccallum

Contributor Author

I think I've now addressed all of your comments, so it should be ready for review 👍

Understanding how to rebase through multiple merges was an experience but I believe that is correct now...

Comment thread diffrax/_integrate.py
final_state,
(reversible_ts, reversible_save_index),
is_leaf=_is_none,
)

Contributor Author

I've left this t0==t1 branch separate from the above jax.lax.cond for readability but we can obviously combine if this is okay.

@sammccallum

Contributor Author

(No pressure to review anytime soon Patrick, I just marked this as "review requested" so it's easy for you to see across your sprawling jax empire)

@luke-a-thompson


Would love to see this implemented, especially in light of https://arxiv.org/abs/2507.21006

@sammccallum

Contributor Author

@luke-a-thompson in the meantime if you're interested, Daniil and I implemented the EES methods here: https://github.com/sammccallum/diffrax/tree/EES

@patrick-kidger patrick-kidger deleted the branch patrick-kidger:dev February 1, 2026 12:53
@patrick-kidger

Owner

This one hasn't been intentionally closed btw! Just autoclosed with the new release. I think this is the second time this has happened to you (#593!) but I still intend to get this in.

I've recently found more time to review some of these larger feature PRs. So Optimistix is next on my list and after that I'm aiming for a new Diffrax release, including this PR :)
