---
language:
- en
license: mit
library_name: jax
tags:
- robotics
- reinforcement-learning
- continuous-control
- test-time-adaptation
- flow-matching
- orbax
---
# Model Card for quadruped_domain_randomization_model
## Model Description
This repository contains the pre-trained neural network weights for the **Reversible Flow Adaptation** architecture, applied to the **Quadruped** environment using the **domain_randomization** policy.
These checkpoints contain the fully converged policy at the final training epoch (epoch 600) for each of 3 independent random seeds (`seed_1`, `seed_2`, `seed_42`).
### Architecture
The policy uses a dual-objective architecture with two components:
1. **Vision Transformer (ViT)**: Acts as the state encoder, augmented with an Auxiliary Physics Distillation Head that predicts privileged physical parameters such as mass or friction.
2. **1D U-Net Vector Field**: A deterministic, parallelized Optimal Transport Flow Matching (Rectified Flow) model. Because rectified flows learn near-straight transport paths, samples can be generated with a small, fixed number of vector-field evaluations, giving $O(1)$ generative inference latency.
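This card does not include training code. As a rough illustration of the two objectives, the sketch below shows the standard rectified-flow regression target (a straight-line interpolation between a noise sample `x0` and a data sample `x1`, whose velocity is the constant `x1 - x0`), a combined loss that adds an auxiliary physics-regression term, and fixed-step Euler sampling. It is written in NumPy for portability. The function names, the `aux_weight` value, and the toy vector field (standing in for the 1D U-Net) are hypothetical and are not taken from this repository's implementation.

```python
import numpy as np


def rectified_flow_pair(x0, x1, t):
    """Straight-line (rectified-flow) interpolation between noise x0 and data x1.

    Returns x_t = (1 - t) * x0 + t * x1 and the vector-field regression
    target, which is the constant velocity x1 - x0 along that straight path.
    """
    t = np.asarray(t)[..., None]  # broadcast t over the last (feature) axis
    x_t = (1.0 - t) * x0 + t * x1
    return x_t, x1 - x0


def dual_objective_loss(v_pred, v_target, phys_pred, phys_true, aux_weight=0.1):
    """Flow-matching MSE plus an auxiliary physics-parameter MSE.

    `aux_weight` is a hypothetical weighting, not a value used by this model.
    """
    flow_loss = np.mean((v_pred - v_target) ** 2)
    phys_loss = np.mean((phys_pred - phys_true) ** 2)
    return flow_loss + aux_weight * phys_loss


def euler_sample(vector_field, x0, num_steps=1):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-size Euler steps."""
    x = np.array(x0, dtype=float)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * vector_field(x, t)
    return x


# Toy check: for a single target point, the ideal straight-path velocity is
# (target - x) / (1 - t), so a single Euler step lands exactly on the target.
target = np.array([0.5, -0.2, 0.1])
ideal_field = lambda x, t: (target - x) / (1.0 - t)
noise = np.random.default_rng(0).standard_normal(3)
print(np.allclose(euler_sample(ideal_field, noise, num_steps=1), target))  # True
```

The toy check shows the intuition behind the constant-step latency claim. When the learned flow is perfectly straight, the velocity along each path is constant, so Euler integration is exact with any number of steps, including one. A trained network only approximates this, so in practice a few steps may still be needed.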
### Usage
The checkpoints are saved in the **JAX / Orbax** checkpoint format and can be restored with `orbax.checkpoint`:
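```python
import os

import orbax.checkpoint as ocp

# Example: load the seed-42 checkpoint. Assumes this repository has been
# downloaded locally so that the `seed_42/` directory exists.
checkpoint_path = os.path.abspath("seed_42")
mngr = ocp.CheckpointManager(checkpoint_path)

# Restore the latest (final) saved step.
step = mngr.latest_step()
if step is None:
    raise FileNotFoundError(f"No checkpoint steps found in {checkpoint_path}")
restored = mngr.restore(step)
print(restored.keys())
```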
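To inspect the restored checkpoint beyond its top-level keys, a small helper can walk the nested structure and list each leaf's path and shape. This is a hypothetical utility, not part of Orbax or this repository. It assumes the restored object and its nested entries expose `.items()` like a dict, and that array leaves (NumPy or JAX) expose `.shape`.

```python
import numpy as np


def summarize_tree(tree, prefix=""):
    """Flatten a nested dict-like structure into (path, shape) rows.

    Nested entries are detected by duck typing on `.items()`. Leaves without
    a `.shape` attribute (e.g. plain Python ints) are reported with shape None.
    """
    rows = []
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if hasattr(value, "items"):
            rows.extend(summarize_tree(value, path))
        else:
            rows.append((path, getattr(value, "shape", None)))
    return rows


def count_parameters(tree):
    """Total number of array elements across all leaves that have a shape."""
    return sum(int(np.prod(shape)) for _, shape in summarize_tree(tree) if shape is not None)
```

After restoring, for example, `for path, shape in summarize_tree(restored): print(path, shape)` lists every stored array, and `count_parameters(restored)` gives a rough total element count.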