Ben committed on
Commit
7ed4dd8
·
1 Parent(s): 912195c

Enable ZeroGPU inference

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -9,15 +9,16 @@ import numpy as np
9
  import gymnasium as gym
10
  import imageio
11
  import gradio as gr
 
12
  from huggingface_hub import hf_hub_download
13
 
14
- # 1. Policy Network Architecture
15
  def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
16
  nn.init.orthogonal_(layer.weight, std)
17
  nn.init.constant_(layer.bias, bias_const)
18
  return layer
19
 
20
- # Reconstruct pure Sequential actor
21
  def get_actor_network(state_dim=8, action_dim=4):
22
  actor = nn.Sequential(
23
  layer_init(nn.Linear(state_dim, 64)),
@@ -28,7 +29,8 @@ def get_actor_network(state_dim=8, action_dim=4):
28
  )
29
  return actor
30
 
31
- # 2. Inference & Rendering
 
32
  def simulate_agent(stage_selection):
33
  weight_mapping = {
34
  "Stage 1: Baseline": "1_baseline.pth",
@@ -39,20 +41,16 @@ def simulate_agent(stage_selection):
39
  filename = weight_mapping.get(stage_selection)
40
  repo_id = "ben-dlwlrma/Representation-Over-Routing"
41
 
42
- # Download weights from HF Hub
43
  try:
44
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
45
  except Exception as e:
46
  raise gr.Error(f"Weight download failed. Error: {str(e)}")
47
 
48
- # Initialize env
49
  env = gym.make("LunarLander-v3", render_mode="rgb_array")
50
 
51
- # Initialize model on CPU
52
- device = torch.device("cpu")
53
  actor = get_actor_network(state_dim=8, action_dim=4).to(device)
54
 
55
- # Load weights
56
  try:
57
  actor.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
58
  actor.eval()
@@ -85,7 +83,6 @@ def simulate_agent(stage_selection):
85
 
86
  env.close()
87
 
88
- # Export to MP4
89
  video_filename = "eval_output.mp4"
90
  fps = 30
91
  try:
@@ -95,7 +92,7 @@ def simulate_agent(stage_selection):
95
 
96
  return video_filename
97
 
98
- # 3. Gradio Web UI
99
  with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
100
  gr.Markdown("## Representation over Routing")
101
  gr.Markdown("Multi-timescale RL evaluation environment. Select an ablation stage to visualize policy behavior.")
 
9
  import gymnasium as gym
10
  import imageio
11
  import gradio as gr
12
+ import spaces
13
  from huggingface_hub import hf_hub_download
14
 
15
+
16
  def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
17
  nn.init.orthogonal_(layer.weight, std)
18
  nn.init.constant_(layer.bias, bias_const)
19
  return layer
20
 
21
+
22
  def get_actor_network(state_dim=8, action_dim=4):
23
  actor = nn.Sequential(
24
  layer_init(nn.Linear(state_dim, 64)),
 
29
  )
30
  return actor
31
 
32
+
33
+ @spaces.GPU(duration=60)
34
  def simulate_agent(stage_selection):
35
  weight_mapping = {
36
  "Stage 1: Baseline": "1_baseline.pth",
 
41
  filename = weight_mapping.get(stage_selection)
42
  repo_id = "ben-dlwlrma/Representation-Over-Routing"
43
 
 
44
  try:
45
  weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
46
  except Exception as e:
47
  raise gr.Error(f"Weight download failed. Error: {str(e)}")
48
 
 
49
  env = gym.make("LunarLander-v3", render_mode="rgb_array")
50
 
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
52
  actor = get_actor_network(state_dim=8, action_dim=4).to(device)
53
 
 
54
  try:
55
  actor.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
56
  actor.eval()
 
83
 
84
  env.close()
85
 
 
86
  video_filename = "eval_output.mp4"
87
  fps = 30
88
  try:
 
92
 
93
  return video_filename
94
 
95
+
96
  with gr.Blocks(title="Representation over Routing", theme=gr.themes.Base()) as demo:
97
  gr.Markdown("## Representation over Routing")
98
  gr.Markdown("Multi-timescale RL evaluation environment. Select an ablation stage to visualize policy behavior.")