"""Gradio demo: roll out a trained LunarLander actor for a selected ablation stage."""

import contextlib
import os
import tempfile

# On Hugging Face Spaces (SPACE_ID is set) point SDL at dummy video/audio
# drivers. Presumably needed because the container is headless and the
# environment's renderer goes through SDL; this must run before the
# rendering stack is imported/initialised.
if os.getenv("SPACE_ID") is not None:
    os.environ["SDL_VIDEODRIVER"] = "dummy"
    os.environ["SDL_AUDIODRIVER"] = "dummy"

import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
import imageio
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download

REPO_ID = "ben-dlwlrma/Representation-Over-Routing"

# Dropdown label -> checkpoint filename inside REPO_ID. The dropdown choices
# are derived from this dict so the UI and the lookup can never drift apart.
# Labels are kept verbatim (including the trailing space on stage 3) so any
# existing API callers passing those exact strings keep working.
WEIGHT_MAPPING = {
    "Stage 1: Baseline": "1_baseline.pth",
    "Stage 2: Surrogate Hacking": "2_surrogate_hacking_attention.pth",
    "Stage 3: Temporal Paradox ": "3_temporal_paradox_variance.pth",
    "Stage 4: Target Decoupling": "4_target_decoupling_final.pth",
}
DEFAULT_STAGE = "Stage 4: Target Decoupling"

ENV_ID = "LunarLander-v3"
STATE_DIM = 8
ACTION_DIM = 4
EVAL_SEED = 32  # fixed reset seed so every stage is evaluated on the same episode
MAX_STEPS = 600  # hard cap on rollout length
VIDEO_FPS = 30


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise *layer* in place and return it.

    Weights get orthogonal initialisation with gain ``std``; the bias is set
    to the constant ``bias_const``.
    """
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer


def get_actor_network(state_dim=8, action_dim=4):
    """Build the actor MLP: two 64-unit Tanh hidden layers.

    The network outputs one score per action; callers pick the action by
    argmax. The output layer uses a small gain (0.01). The layer layout must
    match the checkpoints' state_dict keys exactly.
    """
    actor = nn.Sequential(
        layer_init(nn.Linear(state_dim, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, 64)),
        nn.Tanh(),
        layer_init(nn.Linear(64, action_dim), std=0.01),
    )
    return actor


def _load_actor(weights_path, device):
    """Build the actor, load the checkpoint at *weights_path*, and set eval mode.

    Raises:
        gr.Error: if the checkpoint cannot be loaded into the network.
    """
    actor = get_actor_network(state_dim=STATE_DIM, action_dim=ACTION_DIM).to(device)
    try:
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        actor.load_state_dict(state_dict)
    except Exception as e:
        raise gr.Error(f"Architecture mismatch. Error: {e}") from e
    actor.eval()
    return actor


def _capture_frame(env, frames):
    """Append ``env.render()`` to *frames* (skipping ``None``); wrap errors in gr.Error."""
    try:
        frame = env.render()
    except Exception as e:
        raise gr.Error(f"Render failed: {e}") from e
    if frame is not None:
        frames.append(frame)


def _rollout(actor, device, seed=EVAL_SEED, max_steps=MAX_STEPS):
    """Run one greedy (argmax) episode and return the rendered RGB frames.

    The environment is always closed, even if inference, stepping or
    rendering raises.
    """
    env = gym.make(ENV_ID, render_mode="rgb_array")
    frames = []
    try:
        state, _ = env.reset(seed=seed)
        for _ in range(max_steps):
            _capture_frame(env, frames)
            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=device
            ).unsqueeze(0)
            with torch.no_grad():
                action = int(torch.argmax(actor(state_tensor), dim=1).item())
            state, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                break
        # Render once more so the video includes the state reached by the
        # final step (otherwise the episode's outcome is never shown).
        _capture_frame(env, frames)
    finally:
        env.close()
    return frames


def _encode_video(frames, fps=VIDEO_FPS):
    """Encode *frames* to a new, uniquely named MP4 in the temp dir; return its path.

    A unique file per call prevents concurrent requests from overwriting
    each other's output.

    Raises:
        gr.Error: if there are no frames or encoding fails.
    """
    if not frames:
        raise gr.Error("Video encoding failed: no frames were rendered.")
    fd, video_path = tempfile.mkstemp(prefix="eval_output_", suffix=".mp4")
    os.close(fd)  # the encoder opens the path itself
    try:
        imageio.mimsave(
            video_path, frames, fps=fps, codec="libx264", pixelformat="yuv420p"
        )
    except Exception as e:
        with contextlib.suppress(OSError):
            os.remove(video_path)
        raise gr.Error(f"Video encoding failed: {e}") from e
    return video_path


@spaces.GPU(duration=60)
def simulate_agent(stage_selection):
    """Download the checkpoint for *stage_selection*, roll it out, and return an MP4 path.

    Args:
        stage_selection: one of the keys of ``WEIGHT_MAPPING`` (the dropdown label).

    Returns:
        Filesystem path of the rendered episode video.

    Raises:
        gr.Error: on an unknown stage, download, load, render or encoding failure.
    """
    filename = WEIGHT_MAPPING.get(stage_selection)
    if filename is None:
        raise gr.Error(f"Unknown model stage: {stage_selection!r}")

    try:
        weights_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
    except Exception as e:
        raise gr.Error(f"Weight download failed. Error: {e}") from e

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    actor = _load_actor(weights_path, device)
    frames = _rollout(actor, device)
    return _encode_video(frames)


with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
    gr.Markdown("## Representation over Routing")
    gr.Markdown(
        "Multi-timescale RL evaluation environment. "
        "Select an ablation stage to visualize policy behavior."
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(WEIGHT_MAPPING),
                value=DEFAULT_STAGE,
                label="Model Stage",
            )
            run_button = gr.Button("Run Inference", variant="primary")
        with gr.Column(scale=2):
            video_output = gr.Video(label="Environment Render", autoplay=True)

    run_button.click(
        fn=simulate_agent,
        inputs=[model_dropdown],
        outputs=[video_output],
    )

if __name__ == "__main__":
    demo.launch()