"""Helpers for splitting / merging one axis of every leaf in a JAX pytree."""

from typing import Any

import jax


def _normalize_axis(axis: int, ndim: int) -> int:
    """Map a possibly negative ``axis`` into ``range(ndim)``; raise ValueError if out of bounds."""
    # Explicit range check instead of a bare modulo: a modulo would silently wrap
    # out-of-range axes and divide by zero when ndim == 0.
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    return axis % ndim


def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any:
    """Split ``axis`` of every leaf into two axes ``(n_chunk, size // n_chunk)``.

    Each leaf is reshaped so that dimension ``axis`` of size ``s`` becomes two
    consecutive dimensions ``(n_chunk, s // n_chunk)``; all other dimensions are
    kept. Leaves must expose ``.ndim``, ``.shape`` and ``.reshape`` (e.g. NumPy or
    JAX arrays).

    Args:
        tree: A pytree whose leaves are arrays.
        n_chunk: Number of chunks (size of the new leading axis).
        axis: Axis of each leaf to split; negative values count from the end.

    Returns:
        A pytree of the same structure with each leaf's rank increased by one.

    Raises:
        ValueError: If ``axis`` is out of bounds for a leaf.
        The underlying ``reshape`` raises if the axis size is not divisible by
        ``n_chunk``.
    """

    def _chunk(leaf):
        ax = _normalize_axis(axis, leaf.ndim)
        # ``-1`` lets reshape infer the per-chunk size (and reject non-divisible sizes).
        new_shape = leaf.shape[:ax] + (n_chunk, -1) + leaf.shape[ax + 1:]
        return leaf.reshape(new_shape)

    return jax.tree_util.tree_map(_chunk, tree)


def tree_unchunk(tree: Any, axis: int = 0) -> Any:
    """Merge dimensions ``axis`` and ``axis + 1`` of every leaf into one.

    This is the inverse of :func:`tree_chunk` called with the same ``axis``.
    ``axis`` is the position of the merged dimension in the *result*, so a
    negative value counts from the end of the result's shape (rank of the leaf
    minus one). For example, ``axis=-1`` merges a leaf's last two dimensions.

    Args:
        tree: A pytree whose leaves are arrays of rank >= 2.
        axis: Position of the merged axis in the output; may be negative.

    Returns:
        A pytree of the same structure with each leaf's rank decreased by one.

    Raises:
        ValueError: If ``axis`` is out of bounds (including leaves of rank < 2).
    """

    def _unchunk(leaf):
        ax = _normalize_axis(axis, leaf.ndim - 1)
        # Use the explicit merged size rather than -1, which is ambiguous for
        # zero-size arrays.
        merged = leaf.shape[ax] * leaf.shape[ax + 1]
        new_shape = leaf.shape[:ax] + (merged,) + leaf.shape[ax + 2:]
        return leaf.reshape(new_shape)

    return jax.tree_util.tree_map(_unchunk, tree)