from functools import partial
import jax
import jax.numpy as jnp
from .functional import chunk_encode, cache_grad, unchunk_args
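# Helpers from .functional, as used below (descriptions inferred from their
# usage in this file, not from that module's docs): chunk_encode wraps an
# encoder so inputs are processed in fixed-size chunks and the chunked
# representations are returned; cache_grad re-runs the encoder chunk by chunk
# and accumulates parameter gradients from cached representation gradients;
# unchunk_args concatenates the chunked arguments along the given axis before
# calling the wrapped function.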


def cache_train_step(loss_fn, state, ss, tt, axis='device'):
    """One gradient-cached training step: encode both sides in chunks, take the
    loss gradient w.r.t. the cached representations, then replay the encoder."""

    def encode_with_params(params, **kwargs):
        return state.apply_fn(params=params, **kwargs)

    encode_fn = chunk_encode(partial(encode_with_params, state.params))
    grad_fn = cache_grad(encode_with_params)

    # Chunked forward passes: cache representations for both input sets.
    s_reps = encode_fn(**ss)
    t_reps = encode_fn(**tt)

    @unchunk_args(axis=0, argnums=(0, 1))
    def grad_cache_fn(xx, yy):
        return jnp.mean(loss_fn(xx, yy, axis=axis))

    # Loss and its gradients w.r.t. the cached representations.
    loss, (s_grads, t_grads) = jax.value_and_grad(grad_cache_fn, argnums=(0, 1))(s_reps, t_reps)

    # Replay the encoder chunk by chunk, accumulating parameter gradients.
    grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params)
    grads = grad_fn(state.params, grads, s_grads, **ss)
    grads = grad_fn(state.params, grads, t_grads, **tt)

    # Average loss and gradients across devices, then apply the update.
    loss, grads = jax.lax.pmean([loss, grads], axis)
    new_state = state.apply_gradients(grads=grads)
    return loss, new_state
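

# --- Usage sketch (hypothetical; not part of the original module) ---
# Assumes a replicated Flax TrainState `state`, per-device batches `ss`/`tt`
# of tokenizer outputs, and a contrastive `loss_fn(xx, yy, axis=...)`. The
# step must run under jax.pmap with the same axis name as `axis` so that the
# jax.lax.pmean call above can average loss and gradients across devices:
#
#     p_train_step = jax.pmap(partial(cache_train_step, loss_fn),
#                             axis_name='device')
#     loss, state = p_train_step(state, ss, tt)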