| from typing import Iterable, Any | |
| from functools import partial | |
| import jax | |
| import jax.numpy as jnp | |
| from .tree_utils import tree_unchunk | |
| Array = Any | |
def grad_with_cache(f, **grad_kwargs):
    """Build a gradient function for ``f`` weighted by a cached array.

    The returned callable takes ``(params, cache, *args, **kwargs)`` and
    differentiates the scalar ``jnp.sum(f(params, *args, **kwargs) * cache)``
    with ``jax.grad``.

    Args:
        f: Callable invoked as ``f(params, *args, **kwargs)``; its output is
            multiplied elementwise by ``cache`` before being summed.
        **grad_kwargs: Forwarded unchanged to ``jax.grad``.

    Returns:
        The function produced by ``jax.grad`` for the weighted-sum objective.
    """

    def weighted_sum(params, cache, *args, **kwargs):
        # Reduce to a scalar so that jax.grad can be applied.
        outputs = f(params, *args, **kwargs)
        return jnp.sum(outputs * cache)

    return jax.grad(weighted_sum, **grad_kwargs)
def encode_scan_fn(f, carry, x):
    """Scan-style step that applies ``f`` to one slice of keyword inputs.

    Args:
        f: Callable invoked as ``f(**x)``.
        carry: Value passed through untouched.
        x: Mapping of keyword arguments for ``f``.

    Returns:
        A ``(carry, f(**x))`` pair.
    """
    outputs = f(**x)
    return carry, outputs
def cache_grad_scan_fn(f, params, acc, x):
    """Scan-style step that adds one chunk's cached-gradient contribution.

    Args:
        f: Callable invoked as ``f(params=..., **kwargs)``.
        params: Parameters the gradient is taken with respect to.
        acc: Running gradient accumulator; must have the same tree structure
            as the gradient of ``params`` so the two can be added leaf-wise.
        x: A two-element sequence ``(cached_grad, kwargs)`` where
            ``cached_grad`` weights the output of ``f`` and ``kwargs`` holds
            the keyword inputs for this chunk.

    Returns:
        ``(new_acc, None)`` where ``new_acc`` is ``acc`` plus this chunk's
        gradient.
    """
    cached_grad, kwargs = x

    def fwd_fn(w):
        return f(params=w, **kwargs)

    chunk_grad = grad_with_cache(fwd_fn)(params, cached_grad)
    # jax.tree_multimap was deprecated and later removed from JAX;
    # jax.tree_util.tree_map accepts multiple trees and is the replacement.
    acc = jax.tree_util.tree_map(lambda u, v: u + v, acc, chunk_grad)
    return acc, None
def chunk_encode(encode_fn):
    """Wrap ``encode_fn`` so it is applied chunk by chunk via ``jax.lax.scan``.

    Args:
        encode_fn: Callable invoked as ``encode_fn(**chunk)`` for each slice
            that ``jax.lax.scan`` takes from the keyword inputs.

    Returns:
        A function ``f(**xx)`` that scans over ``xx`` and returns the stacked
        per-chunk outputs of ``encode_fn``.
    """

    def step(carry, chunk):
        # The carry is a dummy; only the per-chunk outputs matter.
        return carry, encode_fn(**chunk)

    def f(**xx):
        _final_carry, stacked = jax.lax.scan(step, 0, xx)
        return stacked

    return f
def cache_grad(encode_fn):
    """Wrap ``encode_fn`` into a chunked gradient-accumulation function.

    Args:
        encode_fn: Callable handed to ``cache_grad_scan_fn`` as the function
            to differentiate.

    Returns:
        A function ``f(params, grad_accumulator, cached_grad, **xx)`` that
        scans ``cache_grad_scan_fn`` over ``[cached_grad, xx]`` with
        ``grad_accumulator`` as the initial carry, and returns the final carry.
    """

    def f(params, grad_accumulator, cached_grad, **xx):
        # Bind the arguments that stay fixed across chunks.
        step = partial(cache_grad_scan_fn, encode_fn, params)
        scanned_inputs = [cached_grad, xx]
        final_acc, _unused = jax.lax.scan(step, grad_accumulator, scanned_inputs)
        return final_acc

    return f
def unchunk_args(axis: int = 0, argnums: Iterable[int] = ()):
    """Decorator factory that un-chunks selected positional arguments.

    Before calling the decorated function, every positional argument whose
    index is listed in ``argnums`` is replaced by
    ``tree_unchunk(argument, axis)``. Other positional arguments and all
    keyword arguments are passed through untouched.

    Args:
        axis: Axis value forwarded to ``tree_unchunk``.
        argnums: Indices of the positional arguments to un-chunk.

    Returns:
        A decorator that wraps a function ``f`` with the un-chunking step.
    """
    from functools import wraps

    # Materialise once: a one-shot iterator would otherwise be exhausted
    # after the first call, and later calls would skip un-chunking.
    argnums = tuple(argnums)

    def decorator_unchunk(f):
        @wraps(f)  # keep the wrapped function's name and docstring
        def g(*args, **kwargs):
            new_args = list(args)
            for i in argnums:
                new_args[i] = tree_unchunk(args[i], axis)
            return f(*new_args, **kwargs)

        return g

    return decorator_unchunk