# Pyre PPO Agent (krooz/pyre-ppo-agent)

PPO-trained actor-critic agent for the Pyre fire-evacuation environment (OpenEnv Hackathon, Apr 2026).

⚠️ This is a raw PyTorch checkpoint, not a `transformers` model, so the Hugging Face hosted Inference API cannot run it directly. Use the inference code below to load and run it locally.

## Training summary (artifact run: pyre_ppo_hard_v2)

Values below come from `artifacts/pyre_ppo_hard_v2.csv`, `pyre_ppo_hard_v2_eval.csv`, and `pyre_ppo_hard_v2.png` (MA-20 curves match `save_training_graph_png` in `train_torch_ppo.py`), plus `artifacts/pyre_ppo_hard_v2_training.log` (HTTP trainer via `train_torch_ppo_http.py`, env at `http://localhost:8000`).

| Metric | Value |
| --- | --- |
| Total episodes | 600 |
| Wall-clock training time | ~227 s (~2.6 eps/s) |
| Final success rate (MA-20, training graph title) | 55% |
| Final reward mean (MA-20) | +3.21 |
| Final success rate (rolling last 30 ep, CSV `s30` / log) | 47% |
| Overall evacuation rate (all 600 ep, CSV) | 52.7% |
| Per-difficulty evacuation (easy / medium / hard) | 67.7% / 59.5% / 10.5% |
| Curriculum | easy → medium → hard with patience gate (0.70 over 20 ep); hard-phase mix hard:0.4, medium:0.4, easy:0.2 (see the sketch below) |
| Eval cadence | Every 25 episodes, 5 deterministic rollouts |
| Eval difficulty | hard (`pyre_ppo_hard_v2_eval.csv`) |
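
For context, here is a minimal sketch of how such a patience-gated curriculum can work. It is illustrative only: `next_difficulty`, `record_episode`, and the stage bookkeeping are assumptions, not code from `train_torch_ppo_http.py`; the thresholds and mix weights are the ones from this run.

```python
import random
from collections import deque

SCHEDULE = ["easy", "medium", "hard"]
HARD_MIX = {"hard": 0.4, "medium": 0.4, "easy": 0.2}   # --hard-mix-dist
PATIENCE_THRESHOLD, PATIENCE_WINDOW = 0.70, 20          # --patience-*

stage = 0
recent = deque(maxlen=PATIENCE_WINDOW)  # 1.0 = evacuated, 0.0 = failed

def next_difficulty() -> str:
    if stage < len(SCHEDULE) - 1:
        return SCHEDULE[stage]
    # Final (hard) phase: keep replaying easier difficulties at a fixed mix.
    return random.choices(list(HARD_MIX), weights=list(HARD_MIX.values()))[0]

def record_episode(evacuated: bool) -> None:
    global stage
    recent.append(float(evacuated))
    if (stage < len(SCHEDULE) - 1
            and len(recent) == PATIENCE_WINDOW
            and sum(recent) / PATIENCE_WINDOW >= PATIENCE_THRESHOLD):
        stage += 1
        recent.clear()  # restart the patience window in the new stage
```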

## Training command (this run)

```bash
uv run python training/ppo/train_torch_ppo_http.py \
  --episodes 600 \
  --difficulty-schedule easy,medium,hard \
  --patience-threshold 0.70 \
  --patience-window 20 \
  --hard-mix-dist hard:0.4,medium:0.4,easy:0.2 \
  --update-every 8 \
  --update-epochs 6 \
  --eval-every 25 \
  --eval-difficulty hard \
  --eval-episodes 5 \
  --checkpoint-every 50 \
  --entropy-coef 0.05 \
  --step-delay 0 \
  --viz-after-ep 500 \
  --output artifacts/pyre_ppo_hard_v2.pt \
  --log-file artifacts/pyre_ppo_hard_v2_training.log
```

## Network architecture (from training log)

| Property | Value |
| --- | --- |
| Total parameters | 12,065,650 |
| Input vector dim | 23,140 (encoder `base_dim` 5,785 × 4 stacked frames) |
| Action dim | 41 (4 move + 4 look + 1 wait + 16 door open + 16 door close) |
| Hidden MLP | 512 → 256 → 128 |
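
The input and action sizes follow directly from the encoder and action space quoted above; a quick sanity check of the arithmetic:

```python
# Sanity-check the dimensions quoted above.
base_dim, history_length = 5785, 4
assert base_dim * history_length == 23_140      # input vector dim

# 4 move + 4 look + 1 wait + 16 door-open + 16 door-close
assert 4 + 4 + 1 + 16 + 16 == 41                # action dim
```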

## Hyperparameters (this run)

| Param | Value |
| --- | --- |
| Learning rate | 3×10⁻⁴ (with LR decay toward a 0.1× end factor unless disabled) |
| PPO clip ε | 0.2 (see the loss sketch below) |
| Entropy coeff | 0.05 |
| Value coeff | 0.5 |
| Gamma | 0.99 |
| GAE λ | 0.95 |
| PPO update every | 8 episodes |
| PPO epochs / minibatch | 6 / 256 |
| Max grad norm | 0.5 |
| Observation mode | visible (partial observability) |
| Device | cuda (`train_torch_ppo.py` default; set `--device cpu` if needed) |
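
For readers less familiar with PPO, this is the standard clipped objective those coefficients plug into. A sketch wired to the values above, not the repository's actual update code:

```python
import torch
import torch.nn.functional as F

CLIP_EPS, ENT_COEF, VF_COEF = 0.2, 0.05, 0.5

def ppo_loss(new_logp, old_logp, advantages, values, returns, entropy):
    ratio = (new_logp - old_logp).exp()                       # pi_new / pi_old
    clipped = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS)
    policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    value_loss = F.mse_loss(values, returns)
    return policy_loss + VF_COEF * value_loss - ENT_COEF * entropy.mean()

# The "0.1x end factor" LR decay is consistent with a linear schedule such as
# torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0, end_factor=0.1,
#                                   total_iters=num_updates),
# though the trainer's exact scheduler isn't shown here.
```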

## Periodic eval on hard (from pyre_ppo_hard_v2_eval.csv)

| Episode | Difficulty | Success rate | Reward mean | Steps mean |
| --- | --- | --- | --- | --- |
| 25 | hard | 0% | -10.124 | 58.0 |
| 50 | hard | 0% | -11.184 | 58.4 |
| 75 | hard | 0% | -11.468 | 35.6 |
| 100 | hard | 0% | -9.827 | 74.0 |
| 125 | hard | 20% | -7.792 | 25.0 |
| 150 | hard | 40% | -4.238 | 28.0 |
| 175 | hard | 20% | -6.674 | 35.2 |
| 200 | hard | 0% | -12.304 | 74.6 |
| 225 | hard | 0% | -11.080 | 100.0 |
| 250 | hard | 20% | -5.648 | 38.4 |
| 275 | hard | 0% | -10.368 | 76.2 |
| 300 | hard | 20% | -4.421 | 72.8 |
| 325 | hard | 0% | -11.180 | 48.2 |
| 350 | hard | 0% | -9.845 | 74.0 |
| 375 | hard | 0% | -11.320 | 26.4 |
| 400 | hard | 0% | -12.256 | 34.0 |
| 425 | hard | 20% | -7.024 | 36.4 |
| 450 | hard | 0% | -10.726 | 56.4 |
| 475 | hard | 0% | -9.072 | 88.6 |
| 500 | hard | 0% | -12.050 | 66.6 |
| 525 | hard | 20% | -5.528 | 41.6 |
| 550 | hard | 0% | -11.274 | 52.4 |
| 575 | hard | 0% | -10.578 | 58.4 |
| 600 | hard | 0% | -12.068 | 36.6 |
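
Hard-difficulty eval stays noisy throughout training (0-40% success per 5-rollout checkpoint), consistent with the 10.5% hard evacuation rate above. To recompute aggregates yourself, a minimal sketch (only the CSV path is taken from this card; inspect the header before relying on column names):

```python
import pandas as pd

df = pd.read_csv("artifacts/pyre_ppo_hard_v2_eval.csv")
print(df.columns.tolist())  # verify the actual column names first
print(df.describe())        # summary stats across the 24 eval checkpoints
```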

## Files in this repository

| File | Description |
| --- | --- |
| `model.pt` | PyTorch checkpoint (`network_state`, `optimizer_state`, `scheduler_state`, `args`, `episode`) |
| `training_graph.png` | Training curves (reward + success rate vs episode) |
| `episode_metrics.csv` | Per-episode training metrics |
| `eval_metrics.csv` | Periodic eval aggregates |
| `training.log` | Full console transcript of the HTTP training run |
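
A quick way to confirm the checkpoint layout before wiring up inference (the keys are the ones listed for `model.pt` above):

```python
import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="krooz/pyre-ppo-agent", filename="model.pt")
# weights_only=False because the file stores plain Python objects
# (the args dict and episode counter), not just tensors.
ckpt = torch.load(path, map_location="cpu", weights_only=False)
print(sorted(ckpt))  # expect: args, episode, network_state, optimizer_state, scheduler_state
```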

## Running inference locally

```python
import sys
from collections import deque

import numpy as np
import torch
from huggingface_hub import hf_hub_download

# 1. Point Python at your local pyre_env checkout (or install the package)
sys.path.insert(0, "pyre_env")

from training.ppo.train_torch_ppo import (
    ActorCritic,
    ObservationEncoder,
    action_index_to_env_action,
    build_action_mask,
)

# 2. Download the checkpoint from this Hub repo
ckpt_path = hf_hub_download(repo_id="krooz/pyre-ppo-agent", filename="model.pt")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# 3. Rebuild the policy from saved training args
saved_args = ckpt["args"]
encoder = ObservationEncoder(mode=saved_args.get("observation_mode", "visible"))
hidden_sizes = tuple(int(x) for x in saved_args.get("hidden_sizes", "512,256,128").split(","))
history_length = saved_args.get("history_length", 4)
input_dim = encoder.base_dim * history_length
network = ActorCritic(input_dim, 41, hidden_sizes)
network.load_state_dict(ckpt["network_state"])
network.eval()
print(f"Loaded checkpoint from episode {ckpt.get('episode', '?')}")

# 4. Roll out one episode (in-process env; swap for the HTTP client if you prefer)
from openenv_pyre import PyreEnvironment

env = PyreEnvironment()
obs = env.reset(difficulty="medium")
frames = deque([np.zeros(encoder.base_dim, dtype=np.float32)] * history_length, maxlen=history_length)
frames.append(encoder.encode(obs))

total_reward = 0.0
with torch.no_grad():
    while True:
        state_vec = np.concatenate(list(frames), dtype=np.float32)
        obs_t = torch.tensor(state_vec, dtype=torch.float32).unsqueeze(0)
        mask_t = torch.tensor(build_action_mask(obs, exclude_look=True), dtype=torch.float32).unsqueeze(0)
        action_t, _, _ = network.act(obs_t, mask_t, deterministic=True)
        obs = env.step(action_index_to_env_action(int(action_t.item())))
        total_reward += float(obs.reward or 0.0)
        frames.append(encoder.encode(obs))
        if obs.done:
            break

print(f"Episode finished: evacuated={obs.agent_evacuated}  reward={total_reward:.3f}")
```

## Environment & training resources

- HF Space (live env): Krooz/pyre_env
- PPO training in Colab (HTTP to the Space): Pyre PPO training (Google Colab)
- Local HTTP trainer: `training/ppo/train_torch_ppo_http.py`
- Local in-process trainer: `training/ppo/train_torch_ppo.py`
- Notebook source: `training/ppo/pyre_ppo_training.ipynb`