diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d35a7fac..672c3cb9 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -7,6 +7,7 @@ ForwardMode as ForwardMode, ImplicitAdjoint as ImplicitAdjoint, RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint, + ReversibleAdjoint as ReversibleAdjoint, ) from ._autocitation import citation as citation, citation_rules as citation_rules from ._brownian import ( diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index db701bd2..a46b0d59 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -362,8 +362,7 @@ def loop( if is_unsafe_sde(terms): kind = "lax" msg = ( - "Cannot reverse-mode autodifferentiate when using " - "`UnsafeBrownianPath`." + "Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`." ) elif max_steps is None: kind = "lax" @@ -908,3 +907,259 @@ def loop( **kwargs, ) return final_state + + +# Reversible Adjoint custom vjp computes gradients w.r.t. +# - y, corresponding to the initial state; +# - args, corresponding to explicit parameters; +# - terms, corresponding to implicit parameters as part of the vector field. + + +@eqx.filter_custom_vjp +def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs): + del throw + y, args, terms = y__args__terms + init_state = eqx.tree_at(lambda s: s.y, init_state, y) + del y + return self._loop( + args=args, + terms=terms, + max_steps=max_steps, + init_state=init_state, + inner_while_loop=ft.partial(_inner_loop, kind="lax"), + outer_while_loop=ft.partial(_outer_loop, kind="lax"), + **kwargs, + ) + + +@_loop_reversible.def_fwd +def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs): + del perturbed + final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs) + ts = final_state.reversible_ts + ts_final_index = final_state.reversible_save_index + y1 = final_state.y + save_state = final_state.save_state + solver_state = final_state.solver_state + return (final_state, aux_stats), (ts, ts_final_index, y1, save_state, solver_state) + + +@_loop_reversible.def_bwd +def _loop_reversible_bwd( + residuals, + grad_final_state__aux_stats, + perturbed, + y__args__terms, + *, + self, + saveat, + init_state, + solver, + event, + **kwargs, +): + assert event is None + + del perturbed, self, init_state, kwargs + ts, ts_final_index, y1, save_state, solver_state = residuals + del residuals + + grad_final_state, _ = grad_final_state__aux_stats + saveat_ts = save_state.ts + ys = save_state.ys + saveat_ts_index = save_state.saveat_ts_index - 1 + grad_ys = grad_final_state.save_state.ys + grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys) + + if saveat.subs.t1: + grad_y1 = (ω(grad_ys)[-1]).ω + else: + grad_y1 = jtu.tree_map(jnp.zeros_like, y1) + + if saveat.subs.t0: + saveat_ts_index = saveat_ts_index + 1 + + del grad_final_state, grad_final_state__aux_stats + + y, args, terms = y__args__terms + del y__args__terms + + diff_state = eqx.filter(solver_state, eqx.is_inexact_array) + diff_args = eqx.filter(args, eqx.is_inexact_array) + diff_terms = eqx.filter(terms, eqx.is_inexact_array) + grad_state = jtu.tree_map(jnp.zeros_like, diff_state) + grad_args = jtu.tree_map(jnp.zeros_like, diff_args) + grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms) + del diff_args, diff_terms + + def grad_step(state): + def forward_step(y0, solver_state, args, terms): + y1, _, dense_info, new_solver_state, _ = solver.step( + terms, t0, t1, y0, args, solver_state, False + ) + return y1, dense_info, new_solver_state + + ( + saveat_ts_index, + ts_index, + y1, + solver_state, + grad_y1, + grad_state, + grad_args, + grad_terms, + ) = state + + t1 = ts[ts_index] + t0 = ts[ts_index - 1] + + y0, dense_info, solver_state = solver.backward_step( + terms, t0, t1, y1, args, solver_state, False + ) + + # Pull gradients back through interpolation + + def interpolate(t, t0, t1, dense_info): + interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info) + return interpolator.evaluate(t) + + def _cond_fun(inner_state): + saveat_ts_index, _ = inner_state + return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0) + + def _body_fun(inner_state): + saveat_ts_index, grad_dense_info = inner_state + t = saveat_ts[saveat_ts_index] + grad_y = (ω(grad_ys)[saveat_ts_index]).ω + _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info) + _, _, _, dgrad_dense_info = interp_vjp(grad_y) + grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info) + saveat_ts_index = saveat_ts_index - 1 + return saveat_ts_index, grad_dense_info + + grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info) + inner_state = (saveat_ts_index, grad_dense_info) + inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax") + saveat_ts_index, grad_dense_info = inner_state + + # Pull gradients back through forward step + + _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms) + grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn( + (grad_y1, grad_dense_info, grad_state) + ) + + grad_args = eqx.apply_updates(grad_args, dgrad_args) + grad_terms = eqx.apply_updates(grad_terms, dgrad_terms) + + ts_index = ts_index - 1 + + return ( + saveat_ts_index, + ts_index, + y0, + solver_state, + grad_y0, + grad_state, + grad_args, + grad_terms, + ) + + def cond_fun(state): + ts_index = state[1] + return ts_index > 0 + + state = ( + saveat_ts_index, + ts_final_index, + y1, + solver_state, + grad_y1, + grad_state, + grad_args, + grad_terms, + ) + + state = jax.lax.while_loop(cond_fun, grad_step, state) + _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state + + # Pull solver_state gradients back onto y0, args, terms. + + _, init_vjp = eqx.filter_vjp(solver.init, terms, ts[0], ts[1], y0, args) + dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state) + grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0) + grad_terms = eqx.apply_updates(grad_terms, dgrad_terms) + grad_args = eqx.apply_updates(grad_args, dgrad_args) + + return grad_y0, grad_args, grad_terms + + +class ReversibleAdjoint(AbstractAdjoint): + """Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible solver + [`diffrax.AbstractReversibleSolver`][]. + + Gradient calculation is exact (up to floating point errors) and backpropagation + is linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps. + """ + + def loop( + self, + *, + args, + terms, + solver, + saveat, + max_steps, + init_state, + passed_solver_state, + passed_controller_state, + event, + **kwargs, + ): + if max_steps is None: + raise ValueError( + "`max_steps=None` is incompatible with `ReversibleAdjoint`." + ) + + if ( + jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) + != jtu.tree_structure(0) + or saveat.dense + or saveat.subs.steps + or (saveat.subs.fn is not save_y) + ): + raise ValueError( + """`ReversibleAdjoint` is only compatible with the following `SaveAt` + properties: `t0`, `t1`, `ts`, `fn=save_y` (default). + """ + ) + + if event is not None: + raise NotImplementedError( + "`diffrax.ReversibleAdjoint` is not compatible with events." + ) + + if is_unsafe_sde(terms): + raise ValueError( + "`adjoint=ReversibleAdjoint()` does not support `UnsafeBrownianPath`. " + "Consider using `VirtualBrownianTree` instead." + ) + + y = init_state.y + init_state = eqx.tree_at(lambda s: s.y, init_state, object()) + init_state = _nondiff_solver_controller_state( + self, init_state, passed_solver_state, passed_controller_state + ) + + final_state, aux_stats = _loop_reversible( + (y, args, terms), + self=self, + saveat=saveat, + max_steps=max_steps, + init_state=init_state, + solver=solver, + event=event, + **kwargs, + ) + final_state = _only_transpose_ys(final_state) + return final_state, aux_stats diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 88c014aa..ab712dec 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -112,6 +112,11 @@ class State(eqx.Module): event_dense_info: Optional[DenseInfo] event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]] event_mask: Optional[PyTree[BoolScalarLike]] + # + # Information for reversible adjoint (save ts) + # + reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]] + reversible_save_index: Optional[IntScalarLike] def _is_none(x: Any) -> bool: @@ -230,7 +235,7 @@ def _outer_buffers(state): return ( [s.ts for s in save_states] + [s.ys for s in save_states] - + [state.dense_ts, state.dense_infos] + + [state.dense_ts, state.dense_infos, state.reversible_ts] ) @@ -304,6 +309,11 @@ def loop( dense_ts = dense_ts.at[0].set(t0) init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) + if init_state.reversible_ts is not None: + reversible_ts = init_state.reversible_ts + reversible_ts = reversible_ts.at[0].set(t0) + init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts) + def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.t0: save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state) @@ -585,6 +595,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): result, ) + reversible_ts = state.reversible_ts + reversible_save_index = state.reversible_save_index + + if state.reversible_ts is not None: + reversible_ts = maybe_inplace( + reversible_save_index + 1, tprev, reversible_ts + ) + reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0) + new_state = State( y=y, tprev=tprev, @@ -606,6 +625,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType] + reversible_save_index=reversible_save_index, ) return ( @@ -1385,6 +1406,14 @@ def _outer_cond_fn(cond_fn_i): ) del had_event, event_structure, event_mask_leaves, event_values__mask + # Reversible info + if max_steps is None: + reversible_ts = None + reversible_save_index = None + else: + reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype) + reversible_save_index = 0 + # Initialise state init_state = State( y=y0, @@ -1407,6 +1436,8 @@ def _outer_cond_fn(cond_fn_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_ts=reversible_ts, + reversible_save_index=reversible_save_index, ) # diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index 42f19e4c..93fc5658 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -348,3 +348,55 @@ def func( - `solver`: The solver to wrap. """ + + +class AbstractReversibleSolver(AbstractSolver[_SolverState]): + """Indicates that this is a reversible differential equation solver. This means + that the state at `t0` can be reconstructed (in closed form) from the state at `t1`. + + The reconstruction must be implemented by + [`diffrax.AbstractReversibleSolver.backward_step`][]. + + This solver can be combined with `adjoint=diffrax.ReversibleAdjoint` for exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory, for $n$ time steps. + """ + + @abc.abstractmethod + def backward_step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState]: + """ + Make a single backward step with the reversible solver. + + Each step is made over the specified interval $[t_1, t_0]$. + + **Arguments:** + + - `terms`: The PyTree of terms representing the vector fields and controls. + - `t0`: The end of the interval that the backward step is made over. + - `t1`: The start of the interval that the backward step is made over. + - `y1`: The current value of the solution at `t1`. + - `args`: Any extra arguments passed to the vector field. + - `solver_state`: Any evolving state for the solver itself, at `t1`. + - `made_jump`: Whether there was a discontinuity in the vector field at `t1`. + Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there + are no jumps and for efficiency re-use information between steps; this + indicates that a jump has just occurred and this assumption is not true. + + **Returns:** + + A tuple of three objects: + + - The value of the solution at `t0`. + - Some dictionary of information that is passed to the solver's interpolation + routine to calculate dense output. Note that this is assumed to be the same + information returned on the forward step. + - The value of the solver state at `t1`. + """ diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index 00ba11da..d6e9ce75 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -2,6 +2,7 @@ from typing import ClassVar from typing_extensions import TypeAlias +import jax from equinox.internal import ω from jaxtyping import PyTree @@ -9,15 +10,15 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractSolver +from .base import AbstractReversibleSolver _ErrorEstimate: TypeAlias = None -_SolverState: TypeAlias = tuple[RealScalarLike, PyTree] +_SolverState: TypeAlias = tuple[RealScalarLike, PyTree, RealScalarLike] # TODO: support arbitrary linear multistep methods -class LeapfrogMidpoint(AbstractSolver): +class LeapfrogMidpoint(AbstractReversibleSolver): r"""Leapfrog/midpoint method. 2nd order linear multistep method. Uses 1st order local linear interpolation for @@ -29,6 +30,12 @@ class LeapfrogMidpoint(AbstractSolver): (which is usually taken to refer to the explicit Runge--Kutta method [`diffrax.Midpoint`][]). + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. + ??? cite "Reference" ```bibtex @@ -60,9 +67,12 @@ def init( y0: Y, args: Args, ) -> _SolverState: - del terms, t1, args + # We pre-compute the step size to avoid numerical instability during the + # backward_step. This is okay (albeit slightly ugly) as `LeapfrogMidpoint` can't + # be used with adaptive step sizes. + dt = t1 - t0 # Corresponds to making an explicit Euler step on the first step. - return t0, y0 + return t0, y0, dt def step( self, @@ -75,12 +85,36 @@ def step( made_jump: BoolScalarLike, ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del made_jump - tm1, ym1 = solver_state + tm1, ym1, dt = solver_state control = terms.contr(tm1, t1) y1 = (ym1**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω dense_info = dict(y0=y0, y1=y1) - solver_state = (t0, y0) + solver_state = (t0, y0, dt) return y1, None, dense_info, solver_state, RESULTS.successful + def backward_step( + self, + terms: AbstractTerm, + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState]: + del made_jump + t0, y0, dt = solver_state + tm1 = t0 - dt + control = terms.contr(tm1, t1) + ym1 = (y1**ω - terms.vf_prod(t0, y0, args, control) ** ω).ω + dense_info = dict(y0=y0, y1=y1) + # On the last step we need to make sure our solver state is correct + # (i.e. the state used on the forward). Otherwise, in `ReversibleAdjoint`, + # we would take a local forward step from an incorrect `solver_state`. + solver_state = jax.lax.cond( + tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None + ) + return y0, dense_info, solver_state + def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 0f0a9fe9..2779320f 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -10,13 +10,19 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver +from .base import ( + AbstractAdaptiveSolver, + AbstractReversibleSolver, + AbstractStratonovichSolver, +) _SolverState: TypeAlias = tuple[PyTree, PyTree] -class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): +class ReversibleHeun( + AbstractReversibleSolver, AbstractAdaptiveSolver, AbstractStratonovichSolver +): """Reversible Heun method. Algebraically reversible 2nd order method. Has an embedded 1st order method for @@ -24,6 +30,12 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): When used to solve SDEs, converges to the Stratonovich solution. + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. + ??? cite "Reference" ```bibtex @@ -83,5 +95,26 @@ def step( solver_state = (yhat1, vf1) return y1, y1_error, dense_info, solver_state, RESULTS.successful + def backward_step( + self, + terms: AbstractTerm, + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState]: + yhat1, vf1 = solver_state + + control = terms.contr(t0, t1) + yhat0 = (2 * y1**ω - yhat1**ω - terms.prod(vf1, control) ** ω).ω + vf0 = terms.vf(t0, yhat0, args) + y0 = (y1**ω - 0.5 * terms.prod((vf0**ω + vf1**ω).ω, control) ** ω).ω + + dense_info = dict(y0=y0, y1=y1) + solver_state = (yhat0, vf0) + return y0, dense_info, solver_state + def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 00b9e1db..c3eaacda 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -9,7 +9,7 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractSolver +from .base import AbstractReversibleSolver _ErrorEstimate: TypeAlias = None @@ -19,11 +19,17 @@ Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] -class SemiImplicitEuler(AbstractSolver): +class SemiImplicitEuler(AbstractReversibleSolver): """Semi-implicit Euler's method. Symplectic method. Does not support adaptive step sizing. Uses 1st order local linear interpolation for dense/ts output. + + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. """ term_structure: ClassVar = (AbstractTerm, AbstractTerm) @@ -68,6 +74,30 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful + def backward_step( + self, + terms: tuple[AbstractTerm, AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: tuple[Ya, Yb], + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[tuple[Ya, Yb], DenseInfo, _SolverState]: + del solver_state, made_jump + + term_1, term_2 = terms + y1_1, y1_2 = y1 + + control1 = term_1.contr(t0, t1) + control2 = term_2.contr(t0, t1) + y0_2 = (y1_2**ω - term_2.vf_prod(t0, y1_1, args, control2) ** ω).ω + y0_1 = (y1_1**ω - term_1.vf_prod(t0, y0_2, args, control1) ** ω).ω + + y0 = (y0_1, y0_2) + dense_info = dict(y0=y0, y1=y1) + return y0, dense_info, None + def func( self, terms: tuple[AbstractTerm, AbstractTerm], diff --git a/test/test_reversible.py b/test/test_reversible.py new file mode 100644 index 00000000..5dad9b2b --- /dev/null +++ b/test/test_reversible.py @@ -0,0 +1,162 @@ +from typing import cast + +import diffrax +import equinox as eqx +import jax.numpy as jnp +import jax.random as jr +import pytest +from jaxtyping import Array + +from .helpers import tree_allclose + + +class VectorField(eqx.Module): + mlp: eqx.nn.MLP + + def __init__(self, in_size, out_size, width_size, depth, key): + self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key) + + def __call__(self, t, y, args): + return args * self.mlp(y) + + +@eqx.filter_value_and_grad +def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller, dual_y0): + y0, args, term = y0__args__term + + sol = diffrax.diffeqsolve( + term, + solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + saveat=saveat, + max_steps=4096, + adjoint=adjoint, + stepsize_controller=stepsize_controller, + ) + if dual_y0: + y1 = sol.ys[0] # pyright: ignore + else: + y1 = sol.ys + return jnp.sum(cast(Array, y1)) + + +def _compare_grads(y0__args__term, solver, saveat, stepsize_controller, dual_y0): + loss, grads_base = _loss( + y0__args__term, + solver, + saveat, + adjoint=diffrax.RecursiveCheckpointAdjoint(), + stepsize_controller=stepsize_controller, + dual_y0=dual_y0, + ) + loss, grads_reversible = _loss( + y0__args__term, + solver, + saveat, + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=stepsize_controller, + dual_y0=dual_y0, + ) + assert tree_allclose(grads_base, grads_reversible, atol=1e-5) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_semi_implicit_euler(saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + fkey, gkey = jr.split(key, 2) + f = VectorField(n, n, n, depth=4, key=fkey) + g = VectorField(n, n, n, depth=4, key=gkey) + terms = (diffrax.ODETerm(f), diffrax.ODETerm(g)) + y0 = (y0, y0) + args = jnp.linspace(0, 1, n) + solver = diffrax.SemiImplicitEuler() + stepsize_controller = diffrax.ConstantStepSize() + + _compare_grads((y0, args, terms), solver, saveat, stepsize_controller, dual_y0=True) + + +@pytest.mark.parametrize( + "stepsize_controller", + [diffrax.ConstantStepSize(), diffrax.PIDController(rtol=1e-8, atol=1e-8)], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_reversible_heun_ode(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + y0 = y0 + args = jnp.linspace(0, 1, n) + solver = diffrax.ReversibleHeun() + + _compare_grads( + (y0, args, terms), solver, saveat, stepsize_controller, dual_y0=False + ) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_reversible_heun_sde(saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + fkey, Wkey = jr.split(key, 2) + f = VectorField(n, n, n, depth=4, key=fkey) + g = lambda t, y, args: jnp.ones((n,)) + W = diffrax.VirtualBrownianTree(t0=0, t1=5, tol=1e-3, shape=(n,), key=Wkey) + terms = diffrax.MultiTerm(diffrax.ODETerm(f), diffrax.ControlTerm(g, W)) + y0 = y0 + args = jnp.linspace(0, 1, n) + solver = diffrax.ReversibleHeun() + stepsize_controller = diffrax.ConstantStepSize() + + _compare_grads( + (y0, args, terms), solver, saveat, stepsize_controller, dual_y0=False + ) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_leapfrog_midpoint(saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + y0 = y0 + args = jnp.linspace(0, 1, n) + solver = diffrax.LeapfrogMidpoint() + stepsize_controller = diffrax.ConstantStepSize() + + _compare_grads( + (y0, args, terms), solver, saveat, stepsize_controller, dual_y0=False + )