---
language:
- en
license: mit
library_name: jax
tags:
- robotics
- reinforcement-learning
- continuous-control
- test-time-adaptation
- flow-matching
- orbax
---
# Model Card for quadruped_domain_randomization_model

## Model Description

This repository contains the pre-trained neural network weights for the **Reversible Flow Adaptation** architecture, applied to the **Quadruped** environment using the **domain_randomization** policy.

These checkpoints represent the fully converged policy at the final epoch (epoch 600) for three independent random seeds (`seed_1`, `seed_2`, `seed_42`).

### Architecture

The policy consists of a dual-objective architecture:

1. **Vision Transformer (ViT)**: Acts as the state encoder, augmented with an Auxiliary Physics Distillation Head to predict privileged parameters (like mass or friction).
2. **1D U-Net Vector Field**: A deterministic, parallelized Optimal Transport Flow Matching (Rectified Flows) model providing $O(1)$ generative inference latency.
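
To illustrate the sampling side of the second component, here is a minimal sketch of how an action could be drawn from a rectified-flow vector field by Euler integration. It is not the repository's implementation: `toy_vector_field` is a hypothetical closed-form stand-in for the 1D U-Net, and `sample_action`, `num_steps`, and the 4-dimensional target are illustrative assumptions.

```python
import numpy as np


def toy_vector_field(x, t, target):
    """Hypothetical stand-in for the learned U-Net vector field.

    For straight paths x_t = (1 - t) * x_0 + t * target, the exact
    velocity at (x, t) is (target - x) / (1 - t).
    """
    return (target - x) / (1.0 - t)


def sample_action(vector_field, noise, target, num_steps=1):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler."""
    x = np.asarray(noise, dtype=np.float64)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # t stays strictly below 1, so the field is well defined
        x = x + dt * vector_field(x, t, target)
    return x


rng = np.random.default_rng(0)
target = np.array([0.5, -0.2, 0.1, 0.9])  # illustrative action, not real data
noise = rng.standard_normal(target.shape)

one_step = sample_action(toy_vector_field, noise, target, num_steps=1)
ten_steps = sample_action(toy_vector_field, noise, target, num_steps=10)
```

Because the toy paths are perfectly straight, a single Euler step already lands on the target. A learned field is only approximately straight, so in practice the number of steps trades accuracy against latency.
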
### Usage

The checkpoints are saved in the **JAX / Orbax** format and can be restored with `orbax.checkpoint`.

```python
import os

import orbax.checkpoint as ocp

# Example: load the checkpoint for seed 42.
checkpoint_path = "seed_42"

# Orbax expects an absolute path to the checkpoint directory.
mngr = ocp.CheckpointManager(os.path.abspath(checkpoint_path))

# Restore the parameters from the most recent saved step.
restored = mngr.restore(mngr.latest_step())
print(restored.keys())
```
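
If the restored object is a nested dictionary of arrays (an assumption about this checkpoint's layout), a quick sanity check is to count its parameters. The sketch below uses a hypothetical helper, `count_parameters`, and a made-up `dummy_params` tree; the key names are for illustration only and do not reflect the actual checkpoint contents.

```python
import numpy as np


def count_parameters(tree):
    """Recursively sum the number of elements in a nested dict of arrays."""
    if isinstance(tree, dict):
        return sum(count_parameters(child) for child in tree.values())
    return int(np.size(tree))


# Made-up stand-in for a restored checkpoint; real key names will differ.
dummy_params = {
    "encoder": {"kernel": np.zeros((8, 16)), "bias": np.zeros(16)},
    "vector_field": {"kernel": np.zeros((16, 4)), "bias": np.zeros(4)},
}

total = count_parameters(dummy_params)
print(f"total parameters: {total}")
```

To apply this to the real weights, pass `restored` in place of `dummy_params`. For pytrees that are not plain dictionaries, `jax.tree_util.tree_leaves` can be used to collect the leaves instead.
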