MgGladys's picture
Add files using upload-large-folder tool
0a937d7 verified
from typing import Any
import jax
def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any:
    """Split axis ``axis`` of every leaf in ``tree`` into ``(n_chunk, -1)``.

    A leaf whose size along ``axis`` is ``d`` is reshaped so that this axis
    becomes two axes ``(n_chunk, d // n_chunk)``; all other axes are kept.
    Because the reshape is row-major, chunk ``i`` holds the ``i``-th
    contiguous slice of length ``d // n_chunk`` along ``axis``.

    Args:
        tree: A pytree whose leaves expose ``.shape`` and ``.reshape``
            (e.g. ``jax.Array`` or ``numpy.ndarray``).
        n_chunk: Number of chunks to split ``axis`` into. The leaf's size
            along ``axis`` must be divisible by it, otherwise ``reshape``
            raises.
        axis: Axis to split. Negative values count from the end of each
            leaf's shape, as in NumPy.

    Returns:
        A pytree with the same structure; every leaf has one extra dimension.

    Raises:
        ValueError: If ``axis`` is out of range for some leaf.
    """

    def _chunk_leaf(v):
        ndim = len(v.shape)
        if not -ndim <= axis < ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for leaf of rank {ndim}"
            )
        # Normalise first: with a raw negative axis (e.g. -1),
        # ``v.shape[axis + 1:]`` would be the whole shape instead of the tail.
        ax = axis % ndim
        return v.reshape(v.shape[:ax] + (n_chunk, -1) + v.shape[ax + 1:])

    # ``jax.tree_map`` is deprecated in favour of ``jax.tree_util.tree_map``.
    return jax.tree_util.tree_map(_chunk_leaf, tree)
def tree_unchunk(tree: Any, axis: int = 0) -> Any:
    """Merge axes ``axis`` and ``axis + 1`` of every leaf in ``tree`` into one.

    A leaf of shape ``(..., n, k, ...)``, with ``n`` at position ``axis``,
    is reshaped to ``(..., n * k, ...)``. Because the reshape is row-major,
    the ``n`` chunks are concatenated in order along the merged axis. For a
    nonnegative ``axis`` this undoes ``tree_chunk(tree, n, axis)``.

    Args:
        tree: A pytree whose leaves expose ``.shape`` and ``.reshape``
            (e.g. ``jax.Array`` or ``numpy.ndarray``) and have rank >= 2.
        axis: Position of the first of the two axes to merge. Negative
            values are interpreted relative to the rank of the merged result
            (leaf rank - 1), i.e. the position the merged axis will occupy.

    Returns:
        A pytree with the same structure; every leaf has one fewer dimension.

    Raises:
        ValueError: If ``axis`` is out of range for some leaf (including
            leaves of rank < 2, which have no pair of axes to merge).
    """

    def _unchunk_leaf(x):
        out_ndim = len(x.shape) - 1
        if not -out_ndim <= axis < out_ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for merging a leaf of rank "
                f"{out_ndim + 1}"
            )
        # Normalise first: with a raw negative axis, ``x.shape[axis + 2:]``
        # would re-include leading dims (e.g. axis=-1 gives x.shape[1:]).
        ax = axis % out_ndim
        return x.reshape(x.shape[:ax] + (-1,) + x.shape[ax + 2:])

    # ``jax.tree_map`` is deprecated in favour of ``jax.tree_util.tree_map``.
    return jax.tree_util.tree_map(_unchunk_leaf, tree)