From 9078bedee831d21ee3dd0414cb549a0c84e39dce Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Wed, 19 Feb 2025 10:11:51 +0000 Subject: [PATCH 1/7] AbstractReversibleSolver + ReversibleAdjoint --- diffrax/__init__.py | 1 + diffrax/_adjoint.py | 279 ++++++++++++++++++++++++- diffrax/_integrate.py | 33 ++- diffrax/_solver/base.py | 57 +++++ diffrax/_solver/leapfrog_midpoint.py | 48 ++++- diffrax/_solver/reversible_heun.py | 37 +++- diffrax/_solver/semi_implicit_euler.py | 34 ++- test/test_reversible.py | 162 ++++++++++++++ 8 files changed, 637 insertions(+), 14 deletions(-) create mode 100644 test/test_reversible.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d35a7fac..672c3cb9 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -7,6 +7,7 @@ ForwardMode as ForwardMode, ImplicitAdjoint as ImplicitAdjoint, RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint, + ReversibleAdjoint as ReversibleAdjoint, ) from ._autocitation import citation as citation, citation_rules as citation_rules from ._brownian import ( diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index db701bd2..0850d7b4 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -362,8 +362,7 @@ def loop( if is_unsafe_sde(terms): kind = "lax" msg = ( - "Cannot reverse-mode autodifferentiate when using " - "`UnsafeBrownianPath`." + "Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`." ) elif max_steps is None: kind = "lax" @@ -908,3 +907,279 @@ def loop( **kwargs, ) return final_state + + +# Reversible Adjoint custom vjp computes gradients w.r.t. +# - y, corresponding to the initial state; +# - args, corresponding to explicit parameters; +# - terms, corresponding to implicit parameters as part of the vector field. 
+ + +@eqx.filter_custom_vjp +def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs): + del throw + y, args, terms = y__args__terms + init_state = eqx.tree_at(lambda s: s.y, init_state, y) + del y + return self._loop( + args=args, + terms=terms, + max_steps=max_steps, + init_state=init_state, + inner_while_loop=ft.partial(_inner_loop, kind="lax"), + outer_while_loop=ft.partial(_outer_loop, kind="lax"), + **kwargs, + ) + + +@_loop_reversible.def_fwd +def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs): + del perturbed + final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs) + ts = final_state.reversible_ts + ts_final_index = final_state.reversible_save_index + y1 = final_state.y + save_state = final_state.save_state + solver_state = final_state.solver_state + return (final_state, aux_stats), (ts, ts_final_index, y1, save_state, solver_state) + + +@_loop_reversible.def_bwd +def _loop_reversible_bwd( + residuals, + grad_final_state__aux_stats, + perturbed, + y__args__terms, + *, + self, + saveat, + init_state, + solver, + event, + **kwargs, +): + assert event is None + + del perturbed, self, init_state, kwargs + ts, ts_final_index, y1, save_state, solver_state = residuals + del residuals + + grad_final_state, _ = grad_final_state__aux_stats + saveat_ts = save_state.ts + ys = save_state.ys + saveat_ts_index = save_state.saveat_ts_index - 1 + grad_ys = grad_final_state.save_state.ys + grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys) + + if saveat.subs.t1: + grad_y1 = (ω(grad_ys)[-1]).ω + else: + grad_y1 = jtu.tree_map(jnp.zeros_like, y1) + + if saveat.subs.t0: + saveat_ts_index = saveat_ts_index + 1 + + del grad_final_state, grad_final_state__aux_stats + + y, args, terms = y__args__terms + del y__args__terms + + diff_state = eqx.filter(solver_state, eqx.is_inexact_array) + diff_args = eqx.filter(args, eqx.is_inexact_array) + diff_terms = eqx.filter(terms, eqx.is_inexact_array) + grad_state = 
jtu.tree_map(jnp.zeros_like, diff_state) + grad_args = jtu.tree_map(jnp.zeros_like, diff_args) + grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms) + del diff_args, diff_terms + + def grad_step(state): + def forward_step(y0, solver_state, args, terms): + y1, _, _, new_solver_state, _ = solver.step( + terms, t0, t1, y0, args, solver_state, False + ) + return y1, new_solver_state + + ( + saveat_ts_index, + ts_index, + y1, + solver_state, + grad_y1, + grad_state, + grad_args, + grad_terms, + ) = state + + t1 = ts[ts_index] + t0 = ts[ts_index - 1] + + y0, _, dense_info, solver_state, _ = solver.backward_step( + terms, t0, t1, y1, args, solver_state, False + ) + + # Pull gradients back through interpolation + + def interpolate(t, t0, t1, dense_info): + interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info) + return interpolator.evaluate(t) + + def _cond_fun(inner_state): + saveat_ts_index, _, _ = inner_state + return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0) + + def _body_fun(inner_state): + saveat_ts_index, grad_y0, grad_y1 = inner_state + t = saveat_ts[saveat_ts_index] + grad_y = (ω(grad_ys)[saveat_ts_index]).ω + _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info) + interp_grads = interp_vjp(grad_y) + grad_y0 = eqx.apply_updates(grad_y0, interp_grads[3]["y0"]) + grad_y1 = eqx.apply_updates(grad_y1, interp_grads[3]["y1"]) + saveat_ts_index = saveat_ts_index - 1 + return saveat_ts_index, grad_y0, grad_y1 + + grad_y0 = jtu.tree_map(jnp.zeros_like, grad_y1) + inner_state = (saveat_ts_index, grad_y0, grad_y1) + inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax") + saveat_ts_index, grad_y0, grad_y1 = inner_state + + # Pull gradients back through forward step + + _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms) + dgrad_y1 = vjp_fn((grad_y1, grad_state)) + + grad_y0 = eqx.apply_updates(grad_y0, dgrad_y1[0]) + grad_state = dgrad_y1[1] + grad_args = eqx.apply_updates(grad_args, 
dgrad_y1[2]) + grad_terms = eqx.apply_updates(grad_terms, dgrad_y1[3]) + + ts_index = ts_index - 1 + + return ( + saveat_ts_index, + ts_index, + y0, + solver_state, + grad_y0, + grad_state, + grad_args, + grad_terms, + ) + + def cond_fun(state): + ts_index = state[1] + return ts_index > 0 + + state = ( + saveat_ts_index, + ts_final_index, + y1, + solver_state, + grad_y1, + grad_state, + grad_args, + grad_terms, + ) + + state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax") + _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state + + # Pull solver_state gradients back onto y0, args, terms. + + _, init_vjp = eqx.filter_vjp(solver.init, terms, ts[0], ts[1], y0, args) + dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state) + grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0) + grad_terms = eqx.apply_updates(grad_terms, dgrad_terms) + grad_args = eqx.apply_updates(grad_args, dgrad_args) + + return grad_y0, grad_args, grad_terms + + +class ReversibleAdjoint(AbstractAdjoint): + """Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible solver + [`diffrax.AbstractReversibleSolver`][]. + + Gradient calculation is exact (up to floating point errors) and backpropagation + is linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps. + """ + + def loop( + self, + *, + args, + terms, + solver, + saveat, + max_steps, + init_state, + passed_solver_state, + passed_controller_state, + event, + **kwargs, + ): + if max_steps is None: + raise ValueError( + "`max_steps=None` is incompatible with `ReversibleAdjoint`." + ) + + if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure( + 0 + ): + raise NotImplementedError( + "Cannot use `adjoint=ReversibleAdjoint()` with `SaveAt(subs=...)`." + ) + + if saveat.dense or saveat.subs.steps: + raise NotImplementedError( + "Cannot use `adjoint=ReversibleAdjoint()` with " + "`saveat=SaveAt(steps=True)` or saveat=SaveAt(dense=True)`." 
+ ) + + if saveat.subs.fn is not save_y: + raise NotImplementedError( + "Cannot use `adjoint=ReversibleAdjoint()` with `saveat=SaveAt(fn=...)`." + ) + + if event is not None: + raise NotImplementedError( + "`diffrax.ReversibleAdjoint` is not compatible with events." + ) + + if is_unsafe_sde(terms): + raise ValueError( + "`adjoint=ReversibleAdjoint()` does not support `UnsafeBrownianPath`. " + "Consider using `VirtualBrownianTree` instead." + ) + if is_sde(terms): + if isinstance(solver, AbstractItoSolver): + raise NotImplementedError( + f"`{solver.__class__.__name__}` converges to the Itô solution. " + "However `ReversibleAdjoint` currently only supports Stratonovich " + "SDEs." + ) + elif not isinstance(solver, AbstractStratonovichSolver): + warnings.warn( + f"{solver.__class__.__name__} is not marked as converging to " + "either the Itô or the Stratonovich solution. Note that " + "`ReversibleAdjoint` will only produce the correct solution for " + "Stratonovich SDEs." + ) + + y = init_state.y + init_state = eqx.tree_at(lambda s: s.y, init_state, object()) + init_state = _nondiff_solver_controller_state( + self, init_state, passed_solver_state, passed_controller_state + ) + + final_state, aux_stats = _loop_reversible( + (y, args, terms), + self=self, + saveat=saveat, + max_steps=max_steps, + init_state=init_state, + solver=solver, + event=event, + **kwargs, + ) + final_state = _only_transpose_ys(final_state) + return final_state, aux_stats diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 88c014aa..ab712dec 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -112,6 +112,11 @@ class State(eqx.Module): event_dense_info: Optional[DenseInfo] event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]] event_mask: Optional[PyTree[BoolScalarLike]] + # + # Information for reversible adjoint (save ts) + # + reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]] + reversible_save_index: Optional[IntScalarLike] def 
_is_none(x: Any) -> bool: @@ -230,7 +235,7 @@ def _outer_buffers(state): return ( [s.ts for s in save_states] + [s.ys for s in save_states] - + [state.dense_ts, state.dense_infos] + + [state.dense_ts, state.dense_infos, state.reversible_ts] ) @@ -304,6 +309,11 @@ def loop( dense_ts = dense_ts.at[0].set(t0) init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) + if init_state.reversible_ts is not None: + reversible_ts = init_state.reversible_ts + reversible_ts = reversible_ts.at[0].set(t0) + init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts) + def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.t0: save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state) @@ -585,6 +595,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): result, ) + reversible_ts = state.reversible_ts + reversible_save_index = state.reversible_save_index + + if state.reversible_ts is not None: + reversible_ts = maybe_inplace( + reversible_save_index + 1, tprev, reversible_ts + ) + reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0) + new_state = State( y=y, tprev=tprev, @@ -606,6 +625,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType] + reversible_save_index=reversible_save_index, ) return ( @@ -1385,6 +1406,14 @@ def _outer_cond_fn(cond_fn_i): ) del had_event, event_structure, event_mask_leaves, event_values__mask + # Reversible info + if max_steps is None: + reversible_ts = None + reversible_save_index = None + else: + reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype) + reversible_save_index = 0 + # Initialise state init_state = State( y=y0, @@ -1407,6 +1436,8 @@ def _outer_cond_fn(cond_fn_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_ts=reversible_ts, + 
reversible_save_index=reversible_save_index, ) # diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index 42f19e4c..02a7c484 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -348,3 +348,60 @@ def func( - `solver`: The solver to wrap. """ + + +class AbstractReversibleSolver(AbstractSolver[_SolverState]): + """Indicates that this is a reversible differential equation solver. This means + that the state at `t0` can be reconstructed (in closed form) from the state at `t1`. + + The reconstruction must be implemented by + [`diffrax.AbstractReversibleSolver.backward_step`][]. + + This solver can be combined with `adjoint=diffrax.ReversibleAdjoint` for exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory, for $n$ time steps. + """ + + @abc.abstractmethod + def backward_step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: + """ + Make a single backward step with the reversible solver. + + Each step is made over the specified interval $[t_1, t_0]$. + + **Arguments:** + + - `terms`: The PyTree of terms representing the vector fields and controls. + - `t0`: The end of the interval that the backward step is made over. + - `t1`: The start of the interval that the backward step is made over. + - `y1`: The current value of the solution at `t1`. + - `args`: Any extra arguments passed to the vector field. + - `solver_state`: Any evolving state for the solver itself, at `t1`. + - `made_jump`: Whether there was a discontinuity in the vector field at `t1`. + Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there + are no jumps and for efficiency re-use information between steps; this + indicates that a jump has just occurred and this assumption is not true. 
+ + **Returns:** + + A tuple of several objects: + + - The value of the solution at `t0`. + - A local error estimate made during the step. (Used by adaptive step size + controllers to change the step size.) May be `None` if no estimate was + made. + - Some dictionary of information that is passed to the solver's interpolation + routine to calculate dense output. Note that this is assumed to be the same + information returned on the forward step. + - The value of the solver state at `t1`. + - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step + happened successfully, or if (unusually) it failed for some reason. + """ diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index 00ba11da..7645afe2 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import jax from equinox.internal import ω from jaxtyping import PyTree @@ -9,15 +10,15 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractSolver +from .base import AbstractReversibleSolver _ErrorEstimate: TypeAlias = None -_SolverState: TypeAlias = tuple[RealScalarLike, PyTree] +_SolverState: TypeAlias = tuple[RealScalarLike, PyTree, RealScalarLike] # TODO: support arbitrary linear multistep methods -class LeapfrogMidpoint(AbstractSolver): +class LeapfrogMidpoint(AbstractReversibleSolver): r"""Leapfrog/midpoint method. 2nd order linear multistep method. Uses 1st order local linear interpolation for @@ -29,6 +30,12 @@ class LeapfrogMidpoint(AbstractSolver): (which is usually taken to refer to the explicit Runge--Kutta method [`diffrax.Midpoint`][]). + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. 
This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. + ??? cite "Reference" ```bibtex @@ -60,9 +67,12 @@ def init( y0: Y, args: Args, ) -> _SolverState: - del terms, t1, args + # We pre-compute the step size to avoid numerical instability during the + # backward_step. This is okay (albeit slightly ugly) as `LeapfrogMidpoint` can't + # be used with adaptive step sizes. + dt = t1 - t0 # Corresponds to making an explicit Euler step on the first step. - return t0, y0 + return t0, y0, dt def step( self, @@ -75,12 +85,36 @@ def step( made_jump: BoolScalarLike, ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del made_jump - tm1, ym1 = solver_state + tm1, ym1, dt = solver_state control = terms.contr(tm1, t1) y1 = (ym1**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω dense_info = dict(y0=y0, y1=y1) - solver_state = (t0, y0) + solver_state = (t0, y0, dt) return y1, None, dense_info, solver_state, RESULTS.successful + def backward_step( + self, + terms: AbstractTerm, + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del made_jump + t0, y0, dt = solver_state + tm1 = t0 - dt + control = terms.contr(tm1, t1) + ym1 = (y1**ω - terms.vf_prod(t0, y0, args, control) ** ω).ω + dense_info = dict(y0=y0, y1=y1) + # On the last step we need to make sure our solver state is correct + # (i.e. the state used on the forward). Otherwise, in `ReversibleAdjoint`, + # we would take a local forward step from an incorrect `solver_state`. 
+ solver_state = jax.lax.cond( + tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None + ) + return y0, None, dense_info, solver_state, RESULTS.successful + def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 0f0a9fe9..932fb4b3 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -10,13 +10,19 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver +from .base import ( + AbstractAdaptiveSolver, + AbstractReversibleSolver, + AbstractStratonovichSolver, +) _SolverState: TypeAlias = tuple[PyTree, PyTree] -class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): +class ReversibleHeun( + AbstractReversibleSolver, AbstractAdaptiveSolver, AbstractStratonovichSolver +): """Reversible Heun method. Algebraically reversible 2nd order method. Has an embedded 1st order method for @@ -24,6 +30,12 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): When used to solve SDEs, converges to the Stratonovich solution. + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. + ??? 
cite "Reference" ```bibtex @@ -83,5 +95,26 @@ def step( solver_state = (yhat1, vf1) return y1, y1_error, dense_info, solver_state, RESULTS.successful + def backward_step( + self, + terms: AbstractTerm, + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]: + yhat1, vf1 = solver_state + + control = terms.contr(t0, t1) + yhat0 = (2 * y1**ω - yhat1**ω - terms.prod(vf1, control) ** ω).ω + vf0 = terms.vf(t0, yhat0, args) + y0 = (y1**ω - 0.5 * terms.prod((vf0**ω + vf1**ω).ω, control) ** ω).ω + + dense_info = dict(y0=y0, y1=y1) + solver_state = (yhat0, vf0) + return y0, None, dense_info, solver_state, RESULTS.successful + def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 00b9e1db..1084aad5 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -9,7 +9,7 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractSolver +from .base import AbstractReversibleSolver _ErrorEstimate: TypeAlias = None @@ -19,11 +19,17 @@ Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] -class SemiImplicitEuler(AbstractSolver): +class SemiImplicitEuler(AbstractReversibleSolver): """Semi-implicit Euler's method. Symplectic method. Does not support adaptive step sizing. Uses 1st order local linear interpolation for dense/ts output. + + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. 
""" term_structure: ClassVar = (AbstractTerm, AbstractTerm) @@ -68,6 +74,30 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful + def backward_step( + self, + terms: tuple[AbstractTerm, AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: tuple[Ya, Yb], + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[tuple[Ya, Yb], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump + + term_1, term_2 = terms + y1_1, y1_2 = y1 + + control1 = term_1.contr(t0, t1) + control2 = term_2.contr(t0, t1) + y0_2 = (y1_2**ω - term_2.vf_prod(t0, y1_1, args, control2) ** ω).ω + y0_1 = (y1_1**ω - term_1.vf_prod(t0, y0_2, args, control1) ** ω).ω + + y0 = (y0_1, y0_2) + dense_info = dict(y0=y0, y1=y1) + return y0, None, dense_info, None, RESULTS.successful + def func( self, terms: tuple[AbstractTerm, AbstractTerm], diff --git a/test/test_reversible.py b/test/test_reversible.py new file mode 100644 index 00000000..5dad9b2b --- /dev/null +++ b/test/test_reversible.py @@ -0,0 +1,162 @@ +from typing import cast + +import diffrax +import equinox as eqx +import jax.numpy as jnp +import jax.random as jr +import pytest +from jaxtyping import Array + +from .helpers import tree_allclose + + +class VectorField(eqx.Module): + mlp: eqx.nn.MLP + + def __init__(self, in_size, out_size, width_size, depth, key): + self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key) + + def __call__(self, t, y, args): + return args * self.mlp(y) + + +@eqx.filter_value_and_grad +def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller, dual_y0): + y0, args, term = y0__args__term + + sol = diffrax.diffeqsolve( + term, + solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + saveat=saveat, + max_steps=4096, + adjoint=adjoint, + stepsize_controller=stepsize_controller, + ) + if dual_y0: + y1 = sol.ys[0] # pyright: ignore + else: + y1 = sol.ys + return jnp.sum(cast(Array, y1)) 
+ + +def _compare_grads(y0__args__term, solver, saveat, stepsize_controller, dual_y0): + loss, grads_base = _loss( + y0__args__term, + solver, + saveat, + adjoint=diffrax.RecursiveCheckpointAdjoint(), + stepsize_controller=stepsize_controller, + dual_y0=dual_y0, + ) + loss, grads_reversible = _loss( + y0__args__term, + solver, + saveat, + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=stepsize_controller, + dual_y0=dual_y0, + ) + assert tree_allclose(grads_base, grads_reversible, atol=1e-5) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_semi_implicit_euler(saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + fkey, gkey = jr.split(key, 2) + f = VectorField(n, n, n, depth=4, key=fkey) + g = VectorField(n, n, n, depth=4, key=gkey) + terms = (diffrax.ODETerm(f), diffrax.ODETerm(g)) + y0 = (y0, y0) + args = jnp.linspace(0, 1, n) + solver = diffrax.SemiImplicitEuler() + stepsize_controller = diffrax.ConstantStepSize() + + _compare_grads((y0, args, terms), solver, saveat, stepsize_controller, dual_y0=True) + + +@pytest.mark.parametrize( + "stepsize_controller", + [diffrax.ConstantStepSize(), diffrax.PIDController(rtol=1e-8, atol=1e-8)], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_reversible_heun_ode(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + y0 = y0 + args = jnp.linspace(0, 1, n) + solver = diffrax.ReversibleHeun() + + _compare_grads( + (y0, args, terms), solver, saveat, stepsize_controller, dual_y0=False + ) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def 
test_reversible_heun_sde(saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + fkey, Wkey = jr.split(key, 2) + f = VectorField(n, n, n, depth=4, key=fkey) + g = lambda t, y, args: jnp.ones((n,)) + W = diffrax.VirtualBrownianTree(t0=0, t1=5, tol=1e-3, shape=(n,), key=Wkey) + terms = diffrax.MultiTerm(diffrax.ODETerm(f), diffrax.ControlTerm(g, W)) + y0 = y0 + args = jnp.linspace(0, 1, n) + solver = diffrax.ReversibleHeun() + stepsize_controller = diffrax.ConstantStepSize() + + _compare_grads( + (y0, args, terms), solver, saveat, stepsize_controller, dual_y0=False + ) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_leapfrog_midpoint(saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + y0 = y0 + args = jnp.linspace(0, 1, n) + solver = diffrax.LeapfrogMidpoint() + stepsize_controller = diffrax.ConstantStepSize() + + _compare_grads( + (y0, args, terms), solver, saveat, stepsize_controller, dual_y0=False + ) From 78c9858e3018aa001d7576e91ae564dc1743d47a Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Wed, 19 Feb 2025 12:38:42 +0000 Subject: [PATCH 2/7] allow arbitrary interpolation --- diffrax/_adjoint.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 0850d7b4..cda06b48 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -994,10 +994,10 @@ def _loop_reversible_bwd( def grad_step(state): def forward_step(y0, solver_state, args, terms): - y1, _, _, new_solver_state, _ = solver.step( + y1, _, dense_info, new_solver_state, _ = solver.step( terms, t0, t1, y0, args, solver_state, False ) - return y1, new_solver_state + return y1, dense_info, new_solver_state ( saveat_ts_index, @@ -1024,31 +1024,30 @@ def interpolate(t, t0, t1, dense_info): 
return interpolator.evaluate(t) def _cond_fun(inner_state): - saveat_ts_index, _, _ = inner_state + saveat_ts_index, _ = inner_state return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0) def _body_fun(inner_state): - saveat_ts_index, grad_y0, grad_y1 = inner_state + saveat_ts_index, grad_dense_info = inner_state t = saveat_ts[saveat_ts_index] grad_y = (ω(grad_ys)[saveat_ts_index]).ω _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info) interp_grads = interp_vjp(grad_y) - grad_y0 = eqx.apply_updates(grad_y0, interp_grads[3]["y0"]) - grad_y1 = eqx.apply_updates(grad_y1, interp_grads[3]["y1"]) + grad_dense_info = eqx.apply_updates(grad_dense_info, interp_grads[3]) saveat_ts_index = saveat_ts_index - 1 - return saveat_ts_index, grad_y0, grad_y1 + return saveat_ts_index, grad_dense_info - grad_y0 = jtu.tree_map(jnp.zeros_like, grad_y1) - inner_state = (saveat_ts_index, grad_y0, grad_y1) + grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info) + inner_state = (saveat_ts_index, grad_dense_info) inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax") - saveat_ts_index, grad_y0, grad_y1 = inner_state + saveat_ts_index, grad_dense_info = inner_state # Pull gradients back through forward step _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms) - dgrad_y1 = vjp_fn((grad_y1, grad_state)) + dgrad_y1 = vjp_fn((grad_y1, grad_dense_info, grad_state)) - grad_y0 = eqx.apply_updates(grad_y0, dgrad_y1[0]) + grad_y0 = dgrad_y1[0] grad_state = dgrad_y1[1] grad_args = eqx.apply_updates(grad_args, dgrad_y1[2]) grad_terms = eqx.apply_updates(grad_terms, dgrad_y1[3]) From efa3765ab9d20ebe00e3e3cd82a5e0f48a704951 Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Tue, 4 Mar 2025 10:44:30 +0000 Subject: [PATCH 3/7] unpacking over indexing --- diffrax/_adjoint.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index cda06b48..92618b1a 100644 --- 
a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1032,8 +1032,8 @@ def _body_fun(inner_state): t = saveat_ts[saveat_ts_index] grad_y = (ω(grad_ys)[saveat_ts_index]).ω _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info) - interp_grads = interp_vjp(grad_y) - grad_dense_info = eqx.apply_updates(grad_dense_info, interp_grads[3]) + _, _, _, dgrad_dense_info = interp_vjp(grad_y) + grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info) saveat_ts_index = saveat_ts_index - 1 return saveat_ts_index, grad_dense_info @@ -1045,12 +1045,12 @@ def _body_fun(inner_state): # Pull gradients back through forward step _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms) - dgrad_y1 = vjp_fn((grad_y1, grad_dense_info, grad_state)) + grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn( + (grad_y1, grad_dense_info, grad_state) + ) - grad_y0 = dgrad_y1[0] - grad_state = dgrad_y1[1] - grad_args = eqx.apply_updates(grad_args, dgrad_y1[2]) - grad_terms = eqx.apply_updates(grad_terms, dgrad_y1[3]) + grad_args = eqx.apply_updates(grad_args, dgrad_args) + grad_terms = eqx.apply_updates(grad_terms, dgrad_terms) ts_index = ts_index - 1 From 67baf5d512311583b7e481c0fc6b9c384404846e Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Tue, 4 Mar 2025 10:45:40 +0000 Subject: [PATCH 4/7] jax while loop --- diffrax/_adjoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 92618b1a..6b9e8c85 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1080,7 +1080,7 @@ def cond_fun(state): grad_terms, ) - state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax") + state = jax.lax.while_loop(cond_fun, grad_step, state) _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state # Pull solver_state gradients back onto y0, args, terms. 
From 01a7cc3efcc997707183d69fed03776bfb95f4b4 Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Tue, 4 Mar 2025 11:06:24 +0000 Subject: [PATCH 5/7] collapse saveat ValueErrors --- diffrax/_adjoint.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 6b9e8c85..8b59cab3 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1121,22 +1121,17 @@ def loop( "`max_steps=None` is incompatible with `ReversibleAdjoint`." ) - if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure( - 0 + if ( + jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) + != jtu.tree_structure(0) + or saveat.dense + or saveat.subs.steps + or (saveat.subs.fn is not save_y) ): - raise NotImplementedError( - "Cannot use `adjoint=ReversibleAdjoint()` with `SaveAt(subs=...)`." - ) - - if saveat.dense or saveat.subs.steps: - raise NotImplementedError( - "Cannot use `adjoint=ReversibleAdjoint()` with " - "`saveat=SaveAt(steps=True)` or saveat=SaveAt(dense=True)`." - ) - - if saveat.subs.fn is not save_y: - raise NotImplementedError( - "Cannot use `adjoint=ReversibleAdjoint()` with `saveat=SaveAt(fn=...)`." + raise ValueError( + """`ReversibleAdjoint` is only compatible with the following `SaveAt` + properties: `t0`, `t1`, `ts`, `fn=save_y` (default). + """ ) if event is not None: From 61bbe3c6288c26e456402117dd37af83d1df9f8e Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Tue, 4 Mar 2025 11:16:59 +0000 Subject: [PATCH 6/7] remove Stratonovich solver condition --- diffrax/_adjoint.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 8b59cab3..bd7d6123 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1144,20 +1144,6 @@ def loop( "`adjoint=ReversibleAdjoint()` does not support `UnsafeBrownianPath`. " "Consider using `VirtualBrownianTree` instead." 
) - if is_sde(terms): - if isinstance(solver, AbstractItoSolver): - raise NotImplementedError( - f"`{solver.__class__.__name__}` converges to the Itô solution. " - "However `ReversibleAdjoint` currently only supports Stratonovich " - "SDEs." - ) - elif not isinstance(solver, AbstractStratonovichSolver): - warnings.warn( - f"{solver.__class__.__name__} is not marked as converging to " - "either the Itô or the Stratonovich solution. Note that " - "`ReversibleAdjoint` will only produce the correct solution for " - "Stratonovich SDEs." - ) y = init_state.y init_state = eqx.tree_at(lambda s: s.y, init_state, object()) From 529910e6bd8a469fe5d10773dda8d23feb9cd7d6 Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Tue, 4 Mar 2025 11:34:26 +0000 Subject: [PATCH 7/7] remove unused returns from AbstractReversibleSolver backward_step --- diffrax/_adjoint.py | 2 +- diffrax/_solver/base.py | 9 ++------- diffrax/_solver/leapfrog_midpoint.py | 4 ++-- diffrax/_solver/reversible_heun.py | 4 ++-- diffrax/_solver/semi_implicit_euler.py | 4 ++-- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index bd7d6123..a46b0d59 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1013,7 +1013,7 @@ def forward_step(y0, solver_state, args, terms): t1 = ts[ts_index] t0 = ts[ts_index - 1] - y0, _, dense_info, solver_state, _ = solver.backward_step( + y0, dense_info, solver_state = solver.backward_step( terms, t0, t1, y1, args, solver_state, False ) diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index 02a7c484..93fc5658 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -371,7 +371,7 @@ def backward_step( args: Args, solver_state: _SolverState, made_jump: BoolScalarLike, - ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: + ) -> tuple[Y, DenseInfo, _SolverState]: """ Make a single backward step with the reversible solver. 
@@ -392,16 +392,11 @@ def backward_step( **Returns:** - A tuple of several objects: + A tuple of three objects: - The value of the solution at `t0`. - - A local error estimate made during the step. (Used by adaptive step size - controllers to change the step size.) May be `None` if no estimate was - made. - Some dictionary of information that is passed to the solver's interpolation routine to calculate dense output. Note that this is assumed to be the same information returned on the forward step. - The value of the solver state at `t1`. - - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step - happened successfully, or if (unusually) it failed for some reason. """ diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index 7645afe2..d6e9ce75 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -101,7 +101,7 @@ def backward_step( args: Args, solver_state: _SolverState, made_jump: BoolScalarLike, - ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + ) -> tuple[Y, DenseInfo, _SolverState]: del made_jump t0, y0, dt = solver_state tm1 = t0 - dt @@ -114,7 +114,7 @@ def backward_step( solver_state = jax.lax.cond( tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None ) - return y0, None, dense_info, solver_state, RESULTS.successful + return y0, dense_info, solver_state def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 932fb4b3..2779320f 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -104,7 +104,7 @@ def backward_step( args: Args, solver_state: _SolverState, made_jump: BoolScalarLike, - ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]: + ) -> tuple[Y, DenseInfo, _SolverState]: yhat1, vf1 = solver_state control = terms.contr(t0, t1) @@ -114,7 +114,7 @@ 
def backward_step( dense_info = dict(y0=y0, y1=y1) solver_state = (yhat0, vf0) - return y0, None, dense_info, solver_state, RESULTS.successful + return y0, dense_info, solver_state def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 1084aad5..c3eaacda 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -83,7 +83,7 @@ def backward_step( args: Args, solver_state: _SolverState, made_jump: BoolScalarLike, - ) -> tuple[tuple[Ya, Yb], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + ) -> tuple[tuple[Ya, Yb], DenseInfo, _SolverState]: del solver_state, made_jump term_1, term_2 = terms @@ -96,7 +96,7 @@ def backward_step( y0 = (y0_1, y0_2) dense_info = dict(y0=y0, y1=y1) - return y0, None, dense_info, None, RESULTS.successful + return y0, dense_info, None def func( self,