metadata
language:
  - en
license: mit
library_name: jax
tags:
  - robotics
  - reinforcement-learning
  - continuous-control
  - test-time-adaptation
  - flow-matching
  - orbax

Model Card for quadruped_domain_randomization_model

Model Description

This repository contains the pre-trained neural network weights for the Reversible Flow Adaptation architecture, applied to the Quadruped environment using the domain_randomization policy.

These checkpoints hold the converged policy at the final epoch (epoch 600) for three independent random seeds (seed_1, seed_2, seed_42).

Architecture

The policy has a dual-objective architecture with two components:

  1. Vision Transformer (ViT): Acts as the state encoder, augmented with an Auxiliary Physics Distillation Head that predicts privileged parameters (such as mass or friction).
  2. 1D U-Net Vector Field: A deterministic, parallelized Optimal Transport Flow Matching (Rectified Flows) model providing $O(1)$ generative inference latency.
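One way to read the $O(1)$ latency claim is that sampling integrates the learned vector field with a small, fixed number of Euler steps. The sketch below illustrates that idea only; it is not this repository's sampler. The names `euler_sample` and `toy_field`, the 12-dimensional output, and the toy field itself (which stands in for the trained U-Net and moves every point in a straight line toward a fixed target) are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def euler_sample(vector_field, x0, num_steps):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler."""
    dt = 1.0 / num_steps

    def step(x, k):
        t = k * dt  # time at the start of step k
        return x + dt * vector_field(x, t), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(num_steps))
    return x1


# Hypothetical stand-in for the trained U-Net: a field whose trajectories
# are straight lines from any starting point to a fixed target.
target = jnp.linspace(-1.0, 1.0, 12)  # 12 dimensions is an arbitrary choice


def toy_field(x, t):
    return (target - x) / (1.0 - t)


noise = jax.random.normal(jax.random.PRNGKey(0), (12,))
for n in (1, 4):
    sample = euler_sample(toy_field, noise, n)
    print(n, "step(s), max abs error:", float(jnp.max(jnp.abs(sample - target))))
```

Because the toy trajectories are exactly straight, a single Euler step already lands on the target. A trained model is only approximately straight, so the step count becomes a trade-off between accuracy and latency.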

Usage

The checkpoints are saved with Orbax, the checkpointing library for JAX.

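import os

import orbax.checkpoint as ocp

# Example: loading seed 42
checkpoint_path = "seed_42"
# os.path.abspath turns the relative directory name into an absolute path
mngr = ocp.CheckpointManager(os.path.abspath(checkpoint_path))

# Restore the parameters from the most recent saved step
restored = mngr.restore(mngr.latest_step())
print(restored.keys())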
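To compare the three seeds, the same restore call can be wrapped in a small helper and applied to each seed directory. This is a minimal sketch: the helper names `load_seed` and `count_params` are hypothetical, and it assumes each seed directory sits in the current working directory and restores to a pytree of arrays, as the example above suggests.

```python
import os

import jax
import jax.numpy as jnp

SEED_DIRS = ("seed_1", "seed_2", "seed_42")  # seed names listed in this card


def count_params(tree):
    """Total number of scalar entries across all leaves of a pytree."""
    return sum(int(jnp.size(leaf)) for leaf in jax.tree_util.tree_leaves(tree))


def load_seed(seed_dir):
    """Restore the latest step from one seed directory, or None if unavailable."""
    if not os.path.isdir(seed_dir):
        return None
    import orbax.checkpoint as ocp  # imported lazily; only needed when loading

    mngr = ocp.CheckpointManager(os.path.abspath(seed_dir))
    step = mngr.latest_step()
    return None if step is None else mngr.restore(step)


if __name__ == "__main__":
    for seed_dir in SEED_DIRS:
        restored = load_seed(seed_dir)
        if restored is None:
            print(f"{seed_dir}: no checkpoint found")
        else:
            print(f"{seed_dir}: {count_params(restored)} parameters")
```

If the seeds share one architecture, the reported parameter counts should match, which makes this a quick sanity check after download.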