Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ForwardMode as ForwardMode,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
ReversibleAdjoint as ReversibleAdjoint,
)
from ._autocitation import citation as citation, citation_rules as citation_rules
from ._brownian import (
Expand Down
259 changes: 257 additions & 2 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def loop(
if is_unsafe_sde(terms):
kind = "lax"
msg = (
"Cannot reverse-mode autodifferentiate when using "
"`UnsafeBrownianPath`."
"Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`."
)
elif max_steps is None:
kind = "lax"
Expand Down Expand Up @@ -908,3 +907,259 @@ def loop(
**kwargs,
)
return final_state


# Reversible Adjoint custom vjp computes gradients w.r.t.
# - y, corresponding to the initial state;
# - args, corresponding to explicit parameters;
# - terms, corresponding to implicit parameters as part of the vector field.


@eqx.filter_custom_vjp
def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs):
del throw
y, args, terms = y__args__terms
init_state = eqx.tree_at(lambda s: s.y, init_state, y)
del y
return self._loop(
args=args,
terms=terms,
max_steps=max_steps,
init_state=init_state,
inner_while_loop=ft.partial(_inner_loop, kind="lax"),
outer_while_loop=ft.partial(_outer_loop, kind="lax"),
**kwargs,
)


@_loop_reversible.def_fwd
def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
del perturbed
final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
ts = final_state.reversible_ts
ts_final_index = final_state.reversible_save_index
y1 = final_state.y
save_state = final_state.save_state
solver_state = final_state.solver_state
return (final_state, aux_stats), (ts, ts_final_index, y1, save_state, solver_state)


@_loop_reversible.def_bwd
def _loop_reversible_bwd(
residuals,
grad_final_state__aux_stats,
perturbed,
y__args__terms,
*,
self,
saveat,
init_state,
solver,
event,
**kwargs,
):
assert event is None

del perturbed, self, init_state, kwargs
ts, ts_final_index, y1, save_state, solver_state = residuals
del residuals

grad_final_state, _ = grad_final_state__aux_stats
saveat_ts = save_state.ts
ys = save_state.ys
saveat_ts_index = save_state.saveat_ts_index - 1
grad_ys = grad_final_state.save_state.ys
grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)

if saveat.subs.t1:
grad_y1 = (ω(grad_ys)[-1]).ω
else:
grad_y1 = jtu.tree_map(jnp.zeros_like, y1)

if saveat.subs.t0:
saveat_ts_index = saveat_ts_index + 1

del grad_final_state, grad_final_state__aux_stats

y, args, terms = y__args__terms
del y__args__terms

diff_state = eqx.filter(solver_state, eqx.is_inexact_array)
diff_args = eqx.filter(args, eqx.is_inexact_array)
diff_terms = eqx.filter(terms, eqx.is_inexact_array)
grad_state = jtu.tree_map(jnp.zeros_like, diff_state)
grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
del diff_args, diff_terms

def grad_step(state):
def forward_step(y0, solver_state, args, terms):
y1, _, dense_info, new_solver_state, _ = solver.step(
terms, t0, t1, y0, args, solver_state, False
)
return y1, dense_info, new_solver_state

(
saveat_ts_index,
ts_index,
y1,
solver_state,
grad_y1,
grad_state,
grad_args,
grad_terms,
) = state

t1 = ts[ts_index]
t0 = ts[ts_index - 1]

y0, dense_info, solver_state = solver.backward_step(
terms, t0, t1, y1, args, solver_state, False
)

# Pull gradients back through interpolation

def interpolate(t, t0, t1, dense_info):
interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info)
return interpolator.evaluate(t)

def _cond_fun(inner_state):
saveat_ts_index, _ = inner_state
return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0)

def _body_fun(inner_state):
saveat_ts_index, grad_dense_info = inner_state
t = saveat_ts[saveat_ts_index]
grad_y = (ω(grad_ys)[saveat_ts_index]).ω
_, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info)
_, _, _, dgrad_dense_info = interp_vjp(grad_y)
grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info)
saveat_ts_index = saveat_ts_index - 1
return saveat_ts_index, grad_dense_info

grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info)
inner_state = (saveat_ts_index, grad_dense_info)
inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax")
saveat_ts_index, grad_dense_info = inner_state

# Pull gradients back through forward step

_, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms)
grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn(
(grad_y1, grad_dense_info, grad_state)
)

grad_args = eqx.apply_updates(grad_args, dgrad_args)
grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)

ts_index = ts_index - 1

return (
saveat_ts_index,
ts_index,
y0,
solver_state,
grad_y0,
grad_state,
grad_args,
grad_terms,
)

def cond_fun(state):
ts_index = state[1]
return ts_index > 0

state = (
saveat_ts_index,
ts_final_index,
y1,
solver_state,
grad_y1,
grad_state,
grad_args,
grad_terms,
)

state = jax.lax.while_loop(cond_fun, grad_step, state)
_, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state

# Pull solver_state gradients back onto y0, args, terms.

_, init_vjp = eqx.filter_vjp(solver.init, terms, ts[0], ts[1], y0, args)
dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state)
grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0)
grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)
grad_args = eqx.apply_updates(grad_args, dgrad_args)

return grad_y0, grad_args, grad_terms


class ReversibleAdjoint(AbstractAdjoint):
"""Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible solver
[`diffrax.AbstractReversibleSolver`][].

Gradient calculation is exact (up to floating point errors) and backpropagation
is linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps.
"""

def loop(
self,
*,
args,
terms,
solver,
saveat,
max_steps,
init_state,
passed_solver_state,
passed_controller_state,
event,
**kwargs,
):
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with `ReversibleAdjoint`."
)

if (
jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat)
!= jtu.tree_structure(0)
or saveat.dense
or saveat.subs.steps
or (saveat.subs.fn is not save_y)
):
raise ValueError(
"""`ReversibleAdjoint` is only compatible with the following `SaveAt`
properties: `t0`, `t1`, `ts`, `fn=save_y` (default).
"""
)

if event is not None:
raise NotImplementedError(
"`diffrax.ReversibleAdjoint` is not compatible with events."
)

if is_unsafe_sde(terms):
raise ValueError(
"`adjoint=ReversibleAdjoint()` does not support `UnsafeBrownianPath`. "
"Consider using `VirtualBrownianTree` instead."
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
init_state = _nondiff_solver_controller_state(
self, init_state, passed_solver_state, passed_controller_state
)

final_state, aux_stats = _loop_reversible(
(y, args, terms),
self=self,
saveat=saveat,
max_steps=max_steps,
init_state=init_state,
solver=solver,
event=event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats
33 changes: 32 additions & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class State(eqx.Module):
event_dense_info: Optional[DenseInfo]
event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]]
event_mask: Optional[PyTree[BoolScalarLike]]
#
# Information for reversible adjoint (save ts)
#
reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]]
reversible_save_index: Optional[IntScalarLike]


def _is_none(x: Any) -> bool:
Expand Down Expand Up @@ -230,7 +235,7 @@ def _outer_buffers(state):
return (
[s.ts for s in save_states]
+ [s.ys for s in save_states]
+ [state.dense_ts, state.dense_infos]
+ [state.dense_ts, state.dense_infos, state.reversible_ts]
)


Expand Down Expand Up @@ -304,6 +309,11 @@ def loop(
dense_ts = dense_ts.at[0].set(t0)
init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts)

if init_state.reversible_ts is not None:
reversible_ts = init_state.reversible_ts
reversible_ts = reversible_ts.at[0].set(t0)
init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts)

def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.t0:
save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state)
Expand Down Expand Up @@ -585,6 +595,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
result,
)

reversible_ts = state.reversible_ts
reversible_save_index = state.reversible_save_index

if state.reversible_ts is not None:
reversible_ts = maybe_inplace(
reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

new_state = State(
y=y,
tprev=tprev,
Expand All @@ -606,6 +625,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
event_dense_info=event_dense_info,
event_values=event_values,
event_mask=event_mask,
reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType]
reversible_save_index=reversible_save_index,
)

return (
Expand Down Expand Up @@ -1385,6 +1406,14 @@ def _outer_cond_fn(cond_fn_i):
)
del had_event, event_structure, event_mask_leaves, event_values__mask

# Reversible info
if max_steps is None:
reversible_ts = None
reversible_save_index = None
else:
reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
reversible_save_index = 0

# Initialise state
init_state = State(
y=y0,
Expand All @@ -1407,6 +1436,8 @@ def _outer_cond_fn(cond_fn_i):
event_dense_info=event_dense_info,
event_values=event_values,
event_mask=event_mask,
reversible_ts=reversible_ts,
reversible_save_index=reversible_save_index,
)

#
Expand Down
Loading