Update server/env.py
Browse files- server/env.py +36 -0
server/env.py
CHANGED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium import spaces
|
| 4 |
+
|
| 5 |
+
class EmailTriageEnv(gym.Env):
    """Gymnasium environment that steps through a fixed queue of labelled emails.

    Each dataset entry carries a ``correct_actions`` triple whose components
    are integers in ``0..2``. One call to :meth:`step` scores the agent's
    action against the current email's triple and advances to the next email;
    the episode terminates once the queue is exhausted.

    NOTE(review): the three action components presumably index
    ``URGENCY_LABELS``, ``ROUTING_LABELS`` and ``RESOLUTION_LABELS`` (in that
    order) -- confirm against the caller.

    Args:
        task: ``"all"`` to use every email, or a difficulty name
            (``"easy"``, ``"medium"``, ``"hard"``) to keep only the emails
            with that difficulty. An unrecognised value yields an empty
            queue, so the first ``step`` terminates immediately.
    """

    # Length of the observation vector returned by reset() and step().
    _OBS_SIZE = 10

    def __init__(self, task="all"):
        super().__init__()
        self.full_dataset = [
            {"difficulty": "easy", "description": "Spam promo", "correct_actions": (0, 0, 0)},
            {"difficulty": "easy", "description": "Routine support", "correct_actions": (0, 1, 1)},
            {"difficulty": "medium", "description": "Billing dispute", "correct_actions": (1, 2, 2)},
            {"difficulty": "medium", "description": "Refund request", "correct_actions": (1, 2, 2)},
            {"difficulty": "hard", "description": "IT password reset phish", "correct_actions": (2, 1, 2)},
            {"difficulty": "hard", "description": "Ransomware threat", "correct_actions": (2, 2, 2)},
        ]
        if task == "all":
            self._queue = self.full_dataset
        else:
            self._queue = [
                e for e in self.full_dataset if e.get("difficulty") == task
            ]
        self._step_idx = 0

        # The Gymnasium API expects every environment to declare its spaces.
        # The observation is currently an all-zero placeholder vector, so the
        # Box is left unbounded rather than guessing at feature ranges.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self._OBS_SIZE,), dtype=np.float32
        )
        # Three action components, each taking a value in {0, 1, 2}.
        self.action_space = spaces.MultiDiscrete([3, 3, 3])

    def _observation(self):
        """Return the placeholder observation: a float32 zero vector."""
        return np.zeros(self._OBS_SIZE, dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Rewind to the first email in the queue.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that
                ``self.np_random`` is seeded as the Gymnasium API specifies.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self._step_idx = 0
        return self._observation(), {}

    def step(self, action):
        """Score ``action`` against the current email and advance the queue.

        Reward scheme:
            * ``1.0`` when all three components match ``correct_actions``;
            * ``0.0`` otherwise;
            * ``-2.0`` (overriding the above) when the email's first correct
              component is ``2`` and the agent's first component is not.

        Args:
            action: Indexable sequence of three integers.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False``. Stepping after the queue is
            exhausted returns zero reward, ``terminated=True`` and an empty
            ``info`` dict.
        """
        if self._step_idx >= len(self._queue):
            return self._observation(), 0.0, True, False, {}

        email = self._queue[self._step_idx]
        correct = email["correct_actions"]

        reward = 1.0 if tuple(action) == tuple(correct) else 0.0
        # Missing a top-level (value 2) first component is penalised harder
        # than an ordinary mismatch.
        if correct[0] == 2 and action[0] != 2:
            reward = -2.0

        self._step_idx += 1
        done = self._step_idx >= len(self._queue)
        # float32 keeps step() observations consistent with reset().
        return self._observation(), float(reward), done, False, {"raw_reward": reward}
|
| 33 |
+
|
| 34 |
+
# Human-readable names for the three action components. Every list has
# exactly three entries, matching the 0..2 range of each component of the
# dataset's ``correct_actions`` triples.
# NOTE(review): presumably index i of a component maps to label i -- confirm.
URGENCY_LABELS = [
    "General",
    "Billing",
    "Security Breach",
]
ROUTING_LABELS = [
    "AI Auto-Reply",
    "Tech Support",
    "Legal",
]
RESOLUTION_LABELS = [
    "Archive",
    "Draft Reply",
    "Escalate to Human",
]
|