Spaces:
Sleeping
Sleeping
| """ | |
| Core SupportEnv environment logic. | |
| Simulates a customer support ticket triage workflow: | |
| - Task 1 (easy): Ticket Classification — assign category + priority | |
| - Task 2 (medium): Information Extraction — pull entities + required actions | |
| - Task 3 (hard): Resolution Generation — write response + resolution steps | |
| Manages episode lifecycle: | |
| reset(task_id, ticket_index) → Observation | |
| step(episode_id, action) → StepResult | |
| get_state(episode_id) → State | |
| grade(episode_id) → (score, breakdown, feedback) | |
| """ | |
| from __future__ import annotations | |
| import uuid | |
| from typing import Any, Dict, Optional, Tuple | |
| from data import TASK_META, get_task_meta, get_tickets | |
| from graders import grade_task | |
| from models import ( | |
| Action, | |
| Observation, | |
| Reward, | |
| State, | |
| StepResult, | |
| TicketInfo, | |
| ) | |
# In-memory store: episode_id → episode dict
# Entries are added by reset() and mutated by step()/grade(); nothing in the
# code visible here ever removes an entry.
_EPISODES: Dict[str, Dict[str, Any]] = {}
# ---------------------------------------------------------------------------
# Reward constants (match openenv.yaml)
# ---------------------------------------------------------------------------
# Starting value of every step's reward in _calculate_step_reward().
STEP_COST = -0.02
# Added on top of STEP_COST when the action type is "submit".
SUBMIT_BONUS = 0.05
# Added by step() when the episode ends on the step limit without a submit.
MAX_STEP_PENALTY = -0.10
| # --------------------------------------------------------------------------- | |
| # Core API | |
| # --------------------------------------------------------------------------- | |
def reset(task_id: str, ticket_index: int = 0) -> Observation:
    """Start a fresh episode and return its initial observation.

    The episode is registered in ``_EPISODES`` under a newly generated UUID
    string, with a zeroed step counter, zero total reward, an empty action
    history and no final score.

    Args:
        task_id: Key into ``TASK_META`` selecting the task.
        ticket_index: Position of the ticket within ``get_tickets(task_id)``.

    Returns:
        The step-0 ``Observation`` (empty thread history, step-0 hint).

    Raises:
        ValueError: If ``task_id`` is not in ``TASK_META`` or
            ``ticket_index`` falls outside the task's ticket list.
    """
    if task_id not in TASK_META:
        raise ValueError(f"Unknown task_id {task_id!r}. Valid: {list(TASK_META)}")

    meta = TASK_META[task_id]
    tickets = get_tickets(task_id)
    if ticket_index < 0 or ticket_index >= len(tickets):
        raise ValueError(
            f"ticket_index {ticket_index} out of range [0, {len(tickets) - 1}]"
        )

    ticket = tickets[ticket_index]
    # The observation is built from get_task_meta(), while max_steps is read
    # from the raw TASK_META entry.
    public_meta = get_task_meta(task_id)
    step_limit = meta["max_steps"]

    episode_id = str(uuid.uuid4())
    _EPISODES[episode_id] = dict(
        task_id=task_id,
        ticket_index=ticket_index,
        ticket_data=ticket,
        step_number=0,
        max_steps=step_limit,
        done=False,
        total_reward=0.0,
        action_history=[],
        final_score=None,
    )

    # Mandatory ticket fields are copied verbatim; attachments are optional.
    required_fields = (
        "ticket_id",
        "subject",
        "body",
        "customer_tier",
        "account_age_days",
        "previous_tickets",
    )
    ticket_kwargs = {name: ticket[name] for name in required_fields}
    ticket_kwargs["attachments"] = ticket.get("attachments", [])
    ticket_info = TicketInfo(**ticket_kwargs)

    return Observation(
        task_id=task_id,
        task_description=public_meta["description"],
        episode_id=episode_id,
        ticket=ticket_info,
        thread_history=[],
        available_actions=public_meta["available_actions"],
        step_number=0,
        max_steps=step_limit,
        hint=_get_hint(task_id, 0),
    )
def step(episode_id: str, action: Action) -> StepResult:
    """Apply one agent action to an episode and report the outcome.

    The step counter is incremented and the action is appended to the
    episode's history before anything else. The episode ends when the action
    is a ``"submit"`` or the step counter reaches ``max_steps``. On the
    terminal step the grader score is added to the step reward, and
    ``MAX_STEP_PENALTY`` is added as well if the episode ended without a
    submit.

    Args:
        episode_id: Identifier of an episode stored in ``_EPISODES``.
        action: The action taken on this step.

    Returns:
        A ``StepResult`` with the new observation, the reward, the done flag
        and an info dict (``"step"``, plus ``"final_score"`` when done).

    Raises:
        KeyError: If no episode with ``episode_id`` exists.
        ValueError: If the episode has already finished.
    """
    episode = _EPISODES.get(episode_id)
    if episode is None:
        raise KeyError(f"Episode {episode_id} not found")
    if episode["done"]:
        raise ValueError(f"Episode {episode_id} is already done.")

    task_id = episode["task_id"]
    episode["step_number"] += 1
    episode["action_history"].append(action.model_dump())

    # Terminal either by explicit submit or by exhausting the step budget.
    submitted = action.action_type == "submit"
    finished = submitted or episode["step_number"] >= episode["max_steps"]

    step_reward, explanation = _calculate_step_reward(
        task_id, action, episode, finished
    )

    final_score = None
    if finished:
        final_score, _breakdown, _feedback = grade_task(task_id, episode)
        episode["final_score"] = final_score
        # Grader score is the terminal bonus (0–1)
        step_reward += final_score
        explanation += f" | Grader score: {final_score:.3f}"
        # Finished without a submit means the step limit was hit.
        if not submitted:
            step_reward += MAX_STEP_PENALTY
            explanation += f" | Max-step penalty: {MAX_STEP_PENALTY}"

    episode["total_reward"] = round(episode["total_reward"] + step_reward, 4)
    episode["done"] = finished

    # Rebuild the observation from the stored ticket and the full history.
    ticket = episode["ticket_data"]
    public_meta = get_task_meta(task_id)
    required_fields = (
        "ticket_id",
        "subject",
        "body",
        "customer_tier",
        "account_age_days",
        "previous_tickets",
    )
    ticket_kwargs = {name: ticket[name] for name in required_fields}
    ticket_kwargs["attachments"] = ticket.get("attachments", [])
    ticket_info = TicketInfo(**ticket_kwargs)

    thread_history = []
    for past_action in episode["action_history"]:
        thread_history.append(
            {"role": "agent", "content": _summarize_action(past_action)}
        )

    observation = Observation(
        task_id=task_id,
        task_description=public_meta["description"],
        episode_id=episode_id,
        ticket=ticket_info,
        thread_history=thread_history,
        # A finished episode offers no further actions and no hint.
        available_actions=[] if finished else public_meta["available_actions"],
        step_number=episode["step_number"],
        max_steps=episode["max_steps"],
        hint=None if finished else _get_hint(task_id, episode["step_number"]),
    )

    reward = Reward(
        step_reward=round(step_reward, 4),
        total_reward=episode["total_reward"],
        explanation=explanation,
    )

    info: Dict[str, Any] = {"step": episode["step_number"]}
    if finished:
        info["final_score"] = final_score

    return StepResult(observation=observation, reward=reward, done=finished, info=info)
def get_state(episode_id: str) -> State:
    """Snapshot an episode's bookkeeping fields as a ``State``.

    Args:
        episode_id: Identifier of an episode stored in ``_EPISODES``.

    Returns:
        A ``State`` carrying the task id, step counters, done flag, total
        reward, action history and final score (``None`` until graded).

    Raises:
        KeyError: If no episode with ``episode_id`` exists.
    """
    episode = _EPISODES.get(episode_id)
    if episode is None:
        raise KeyError(f"Episode {episode_id} not found")

    snapshot = {
        "task_id": episode["task_id"],
        "episode_id": episode_id,
        "step_number": episode["step_number"],
        "max_steps": episode["max_steps"],
        "done": episode["done"],
        "total_reward": episode["total_reward"],
        "history": episode["action_history"],
        "final_score": episode.get("final_score"),
    }
    return State(**snapshot)
def grade(episode_id: str) -> Tuple[float, Dict[str, float], str]:
    """Run the grader on a completed episode.

    The grader is invoked again on every call and the resulting score is
    written back to the episode's ``final_score``.

    Args:
        episode_id: Identifier of an episode stored in ``_EPISODES``.

    Returns:
        The ``(score, breakdown, feedback)`` triple from ``grade_task``.

    Raises:
        KeyError: If no episode with ``episode_id`` exists.
        ValueError: If the episode is not done yet.
    """
    episode = _EPISODES.get(episode_id)
    if episode is None:
        raise KeyError(f"Episode {episode_id} not found")
    if not episode.get("done"):
        raise ValueError(f"Episode {episode_id} is not done yet")

    result = grade_task(episode["task_id"], episode)
    score, breakdown, feedback = result
    episode["final_score"] = score
    return score, breakdown, feedback
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _calculate_step_reward(
    task_id: str, action: Action, ep: Dict[str, Any], done: bool
) -> Tuple[float, str]:
    """Compute the dense reward for a single step.

    Every step starts from ``STEP_COST``. A ``"submit"`` adds
    ``SUBMIT_BONUS``. Otherwise a small capped bonus is added when the action
    type is the one matching the task (``classify`` for task1, ``extract``
    for task2, ``respond`` for task3). ``ep`` and ``done`` are accepted but
    not read here.

    Returns:
        ``(reward, explanation)`` for the step, excluding any grader score.
    """
    kind = action.action_type
    total = STEP_COST  # small cost per step

    if kind == "submit":
        return total + SUBMIT_BONUS, "Submitted for grading"

    # Partial-progress signals: only the task's own action type earns a bonus.
    if task_id == "task1" and kind == "classify":
        for provided in (action.category, action.priority):
            if provided:
                total += 0.02
        return (
            total,
            f"Classified: category={action.category}, priority={action.priority}",
        )

    if task_id == "task2" and kind == "extract":
        n_entities = len(action.extracted_entities or ())
        n_actions = len(action.required_actions or ())
        total += min(n_entities * 0.005, 0.04)
        total += min(n_actions * 0.005, 0.02)
        return total, f"Extracted {n_entities} entities, {n_actions} actions"

    if task_id == "task3" and kind == "respond":
        text_len = len(action.response_text or "")
        n_steps = len(action.resolution_steps or ())
        if text_len:
            total += min(text_len * 0.0001, 0.03)
        if n_steps:
            total += min(n_steps * 0.005, 0.02)
        return total, f"Response ({text_len} chars), {n_steps} resolution steps"

    return total, "Step taken"
| def _summarize_action(action_dict: Dict[str, Any]) -> str: | |
| """One-line summary of an action for thread_history.""" | |
| atype = action_dict.get("action_type", "unknown") | |
| if atype == "classify": | |
| return f"classify(category={action_dict.get('category')}, priority={action_dict.get('priority')})" | |
| elif atype == "extract": | |
| ents = action_dict.get("extracted_entities") or {} | |
| acts = action_dict.get("required_actions") or [] | |
| return f"extract(entities={list(ents.keys())}, actions={acts})" | |
| elif atype == "respond": | |
| text = (action_dict.get("response_text") or "")[:60] | |
| steps = action_dict.get("resolution_steps") or [] | |
| return f"respond(text='{text}...', steps={len(steps)})" | |
| elif atype == "submit": | |
| return "submit()" | |
| return f"{atype}()" | |
| def _get_hint(task_id: str, step: int) -> Optional[str]: | |
| """Contextual hints to guide the agent.""" | |
| if step == 0: | |
| hints = { | |
| "task1": "Read the ticket carefully and classify by category and priority.", | |
| "task2": "Extract all entities (IDs, names, amounts) and identify required actions.", | |
| "task3": "Write a professional response and list resolution steps.", | |
| } | |
| return hints.get(task_id) | |
| return None | |