Spaces:
Sleeping
Sleeping
| """SupportDesk environment implementation.""" | |
| from __future__ import annotations | |
| import os | |
| import threading | |
| import uuid | |
| from pathlib import Path | |
| from typing import ClassVar | |
| from graders import grade_case | |
| from models import ( | |
| ActionHistoryEntry, | |
| CustomerFollowUp, | |
| SupportCaseProgress, | |
| SupportDeskAction, | |
| SupportDeskObservation, | |
| SupportDeskState, | |
| ) | |
| from openenv_compat import Environment, EnvironmentMetadata | |
| from tasks import ( | |
| ALL_ISSUE_TYPES, | |
| ALL_PRIORITIES, | |
| ALL_QUEUES, | |
| ALL_STATUSES, | |
| SupportTaskSpec, | |
| get_task, | |
| list_task_ids, | |
| ) | |
class SupportDeskEnvironment(
    Environment[SupportDeskAction, SupportDeskObservation, SupportDeskState]
):
    """A realistic customer support triage environment with dense rewards.

    Episode state is kept in class-level storage shared by every instance, so
    an instance can resume an episode that another instance created.
    ``step()`` and ``state()`` reload the instance from that store before
    acting, and ``step()`` writes the updated snapshot back afterwards.
    """

    # Re-entrant lock guarding all of the class-level bookkeeping below.
    _state_lock: ClassVar[threading.RLock] = threading.RLock()
    # episode_id -> most recently persisted state snapshot.
    # NOTE(review): nothing in this class evicts entries, so the store grows
    # with every reset() — confirm whether cleanup happens elsewhere.
    _episode_store: ClassVar[dict[str, SupportDeskState]] = {}
    # episode_id -> task_id, used to rebuild the task spec when loading.
    _episode_task_ids: ClassVar[dict[str, str]] = {}
    # Most recently reset/persisted episode; fallback when no id is supplied.
    _latest_episode_id: ClassVar[str | None] = None
    # Round-robin cursor over list_task_ids() for resets without a pinned task.
    _shared_reset_counter: ClassVar[int] = 0
def __init__(self, task_id: str | None = None):
    """Create an environment bound to an initial task.

    Args:
        task_id: Task to pin this instance to. When omitted, the
            ``SUPPORTDESK_TASK_ID`` environment variable is used instead.
            If neither supplies a non-empty id, the first entry of
            ``list_task_ids()`` is loaded, and ``reset()`` rotates through
            all tasks instead of staying on one.

    The instance starts without an episode id. Episode state is placed in the
    class-level store once ``reset()`` is called.
    """
    super().__init__()
    env_task_id = os.getenv("SUPPORTDESK_TASK_ID")
    # A task only counts as explicitly chosen when the id is non-empty.
    # Checking ``is not None`` treated an exported-but-blank env var (or "")
    # as an explicit choice. That disabled task rotation even though the
    # default task below was what actually got loaded.
    self._explicit_task_id = bool(task_id or env_task_id)
    requested_task = task_id or env_task_id or list_task_ids()[0]
    self.task: SupportTaskSpec = get_task(requested_task)
    self._max_steps = self.task.max_steps
    self._step_count = 0
    self._reward_total = 0.0
    self._done = False
    self._last_feedback = ""
    self._history: list[ActionHistoryEntry] = []
    self._case = SupportCaseProgress()
    self._episode_id: str | None = None
    self._current_sla_minutes_remaining = self.task.ticket.sla_minutes_remaining
    # Grade the untouched case so a pre-reset state reports a baseline score.
    initial_grade = grade_case(self.task, self._case)
    self._score = initial_grade.total_score
    self._completed_milestones = list(initial_grade.completed_milestones)
@classmethod
def _build_initial_state(cls, task: SupportTaskSpec, episode_id: str) -> SupportDeskState:
    """Return the pristine state snapshot for a new episode of ``task``.

    This must be a classmethod. ``reset()`` calls it through the class as
    ``self.__class__._build_initial_state(task, episode_id)``. As a plain
    function, that call would bind ``task`` to ``cls`` and fail with a
    missing-argument TypeError.

    Args:
        task: Task spec the episode will run.
        episode_id: Identifier to stamp on the new state.
    """
    initial_case = SupportCaseProgress()
    # Score the empty case so the starting score and milestones match what
    # grade_case reports before any action is taken.
    initial_grade = grade_case(task, initial_case)
    return SupportDeskState(
        episode_id=episode_id,
        task_id=task.task_id,
        difficulty=task.difficulty,
        step_count=0,
        reward=0.0,
        done=False,
        current_score=initial_grade.total_score,
        max_steps=task.max_steps,
        case=initial_case,
        current_sla_minutes_remaining=task.ticket.sla_minutes_remaining,
        workflow_stage="intake",
        required_next_actions=["classify"],
        risk_flags=[],
        action_history=[],
        completed_milestones=list(initial_grade.completed_milestones),
        last_feedback="New case loaded. Review the ticket and policy snippets before acting.",
    )
@classmethod
def _extract_episode_id(cls, episode_id: str | None = None, **kwargs) -> str | None:
    """Return the episode id a caller supplied, if any.

    A truthy ``episode_id`` wins. Otherwise, a non-empty string passed as the
    ``request_id`` keyword is used. Returns ``None`` when neither is present.
    """
    if episode_id:
        return episode_id
    # "episode_id" itself can never reach **kwargs because it binds to the
    # named parameter, so only the request_id alias needs checking.
    request_id = kwargs.get("request_id")
    if isinstance(request_id, str) and request_id:
        return request_id
    return None
def _load_episode(self, episode_id: str | None = None, **kwargs) -> None:
    """Reload this instance's fields from the class-level episode store.

    The episode to load is chosen in this order:
      1. an id supplied by the caller (``episode_id`` or a ``request_id`` kwarg);
      2. the episode this instance last reset or loaded;
      3. the most recently persisted episode across all instances.
    If none is available (nothing has been reset yet), the instance is left
    unchanged. The stored case and history entries are deep-copied, so later
    mutation does not touch the stored snapshot.

    Raises:
        ValueError: if the resolved id is not present in the store.
    """
    owner = self.__class__
    resolved_episode_id = (
        self._extract_episode_id(episode_id, **kwargs)
        # Prefer this instance's own episode over the global "latest".
        # Otherwise, with several live instances, an id-less step()/state()
        # would act on whichever episode another instance touched last.
        or self._episode_id
        or owner._latest_episode_id
    )
    if not resolved_episode_id:
        return
    episode_state = owner._episode_store.get(resolved_episode_id)
    if episode_state is None:
        raise ValueError(
            f"Unknown episode_id '{resolved_episode_id}'. Call reset() first or provide a valid episode_id."
        )
    task = get_task(owner._episode_task_ids.get(resolved_episode_id, episode_state.task_id))
    self.task = task
    self._max_steps = episode_state.max_steps
    self._step_count = episode_state.step_count
    self._reward_total = episode_state.reward
    self._done = episode_state.done
    self._last_feedback = episode_state.last_feedback
    self._history = [entry.model_copy(deep=True) for entry in episode_state.action_history]
    self._case = episode_state.case.model_copy(deep=True)
    self._episode_id = resolved_episode_id
    self._score = episode_state.current_score
    self._completed_milestones = list(episode_state.completed_milestones)
    self._current_sla_minutes_remaining = episode_state.current_sla_minutes_remaining
def _persist_episode(self) -> None:
    """Save this instance's current episode into the shared class-level store.

    Does nothing until an episode id has been assigned. Besides the state
    snapshot, it records which task the episode uses and marks the episode
    as the latest one.
    """
    key = self._episode_id
    if key is None:
        return
    registry = self.__class__
    fields = {
        "episode_id": key,
        "task_id": self.task.task_id,
        "difficulty": self.task.difficulty,
        "step_count": self._step_count,
        "reward": round(self._reward_total, 4),
        "done": self._done,
        "current_score": round(self._score, 4),
        "max_steps": self._max_steps,
        "case": self._case.model_copy(deep=True),
        "current_sla_minutes_remaining": self._current_sla_minutes_remaining,
        "workflow_stage": self._workflow_stage(),
        "required_next_actions": self._required_next_actions(),
        "risk_flags": self._risk_flags(),
        "action_history": [record.model_copy(deep=True) for record in self._history],
        "completed_milestones": list(self._completed_milestones),
        "last_feedback": self._last_feedback,
    }
    registry._episode_store[key] = SupportDeskState(**fields)
    registry._episode_task_ids[key] = self.task.task_id
    registry._latest_episode_id = key
def state(self) -> SupportDeskState:
    """Return a detached snapshot of the active episode.

    Runs under the shared class lock and first calls ``_load_episode()`` with
    no arguments, so the snapshot reflects whichever episode that loads, if
    any. The case and action history are deep copies.

    NOTE(review): this is a plain method here. If the base ``Environment``
    declares ``state`` as a property, confirm the override is intended.
    """
    lock = self.__class__._state_lock
    with lock:
        self._load_episode()
        case_copy = self._case.model_copy(deep=True)
        history_copy = [record.model_copy(deep=True) for record in self._history]
        snapshot = SupportDeskState(
            episode_id=self._episode_id,
            task_id=self.task.task_id,
            difficulty=self.task.difficulty,
            step_count=self._step_count,
            reward=round(self._reward_total, 4),
            done=self._done,
            current_score=round(self._score, 4),
            max_steps=self._max_steps,
            case=case_copy,
            current_sla_minutes_remaining=self._current_sla_minutes_remaining,
            workflow_stage=self._workflow_stage(),
            required_next_actions=self._required_next_actions(),
            risk_flags=self._risk_flags(),
            action_history=history_copy,
            completed_milestones=list(self._completed_milestones),
            last_feedback=self._last_feedback,
        )
    return snapshot
def reset(
    self,
    seed: int | None = None,
    episode_id: str | None = None,
    **kwargs,
) -> SupportDeskObservation:
    """Start a fresh episode and return its first observation.

    When no task was pinned at construction time, tasks are handed out
    round-robin from ``list_task_ids()`` using a counter shared by all
    instances. ``seed`` and extra ``kwargs`` are accepted for interface
    compatibility but are not used here. A caller-supplied ``episode_id`` is
    reused as-is and replaces any stored episode with that id. Otherwise an id
    of the form ``"<task_id>-<8 hex chars>"`` is generated.
    """
    owner = self.__class__
    with owner._state_lock:
        if not self._explicit_task_id:
            ids = list_task_ids()
            slot = owner._shared_reset_counter % len(ids)
            owner._shared_reset_counter += 1
            self.task = get_task(ids[slot])
            self._max_steps = self.task.max_steps
        if episode_id:
            self._episode_id = episode_id
        else:
            self._episode_id = f"{self.task.task_id}-{uuid.uuid4().hex[:8]}"
        fresh = owner._build_initial_state(self.task, self._episode_id)
        owner._episode_store[self._episode_id] = fresh
        owner._episode_task_ids[self._episode_id] = self.task.task_id
        owner._latest_episode_id = self._episode_id
        self._load_episode(self._episode_id)
        return self._build_observation(reward=0.0, done=False)
def step(
    self,
    action: SupportDeskAction,
    timeout_s: float | None = None,
    episode_id: str | None = None,
    **kwargs,
) -> SupportDeskObservation:
    """Apply one agent action to the current episode and return the next observation.

    The step reward is the change in ``grade_case`` score, plus the
    non-negative ``_process_bonus`` and the non-positive ``_action_penalty``,
    rounded to 4 decimals. The episode ends on ``submit`` or once
    ``max_steps`` steps have been taken. The updated episode is then
    persisted to the shared store.

    Args:
        action: The agent's operation and its payload fields.
        timeout_s: Accepted for interface compatibility; not used here.
        episode_id: Episode to act on. When omitted, ``_load_episode``
            decides which episode to resume. A ``request_id`` keyword
            argument is also honoured.

    Returns:
        The observation after the action, carrying this step's reward.

    Raises:
        ValueError: propagated from ``_load_episode`` for an unknown episode id.
    """
    with self.__class__._state_lock:
        self._load_episode(episode_id, **kwargs)
        if self._done:
            # Acting on a finished episode only reports a small penalty. It is
            # not added to reward_total, recorded in history, or persisted.
            return self._build_observation(
                reward=-0.05,
                done=True,
                feedback="Episode already finished. Call reset() before taking more actions.",
            )
        # Capture grade and stage before mutating, for the reward delta and
        # the stage-advance bonus.
        previous_grade = grade_case(self.task, self._case)
        previous_stage = self._workflow_stage()
        # Order matters: apply the action, count the step, resolve any pending
        # customer follow-up, burn SLA time, then grade the resulting case.
        self._apply_action(action)
        self._step_count += 1
        self._advance_external_events(action)
        self._degrade_sla()
        current_grade = grade_case(self.task, self._case)
        reward = current_grade.total_score - previous_grade.total_score
        reward += self._process_bonus(action, previous_stage, current_grade.total_score)
        # NOTE(review): both shaping helpers see the follow-up status *after*
        # _advance_external_events. A "wait" that just delivered a pending
        # follow-up therefore appears to still hit the "nothing pending" wait
        # penalty — confirm intent.
        reward += self._action_penalty(
            action,
            current_grade.total_score,
            previous_grade.total_score,
        )
        reward = round(reward, 4)
        self._score = current_grade.total_score
        self._completed_milestones = list(current_grade.completed_milestones)
        if action.operation == "submit":
            self._done = True
            self._last_feedback = (
                "Case submitted. Final deterministic grade is "
                f"{current_grade.total_score:.2f}."
            )
        elif self._step_count >= self._max_steps:
            # Step budget exhausted: close the episode without an explicit submit.
            self._done = True
            self._last_feedback = (
                f"Reached max steps ({self._max_steps}). Final deterministic grade is "
                f"{current_grade.total_score:.2f}."
            )
        else:
            self._last_feedback = self._build_feedback(current_grade, reward)
        self._reward_total = round(self._reward_total + reward, 4)
        self._history.append(
            ActionHistoryEntry(
                step=self._step_count,
                operation=action.operation,
                summary=self._summarize_action(action),
                reward_delta=reward,
            )
        )
        self._persist_episode()
        return self._build_observation(reward=reward, done=self._done)
@classmethod
def state_for_episode(cls, episode_id: str) -> SupportDeskState:
    """Return a deep copy of the stored state for ``episode_id``.

    Declared as a classmethod so it can be called on the class itself
    (``SupportDeskEnvironment.state_for_episode(...)``) without an instance.
    Calling it through an instance keeps working.

    Raises:
        ValueError: if ``episode_id`` is not present in the episode store.
    """
    with cls._state_lock:
        state = cls._episode_store.get(episode_id)
        if state is None:
            raise ValueError(f"Unknown episode_id '{episode_id}'. Call reset() first.")
        return state.model_copy(deep=True)
def close(self) -> None:
    """Release resources; intentionally a no-op, kept so local scripts can call it."""
    return None
def get_metadata(self) -> EnvironmentMetadata:
    """Describe this environment for docs, validators, and the HF Space UI.

    The README is read from the parent of the directory that contains this
    module (``parents[1]``). ``readme_content`` is ``None`` when that file
    does not exist.
    """
    project_root = Path(__file__).resolve().parents[1]
    readme_file = project_root / "README.md"
    readme_text = None
    if readme_file.exists():
        readme_text = readme_file.read_text(encoding="utf-8")
    summary = (
        "A policy-heavy enterprise operations desk with deterministic grading, delayed "
        "customer follow-ups, SLA pressure, escalation tradeoffs, and sharper cross-functional triage."
    )
    return EnvironmentMetadata(
        name="supportdesk_env",
        description=summary,
        readme_content=readme_text,
        version="0.1.0",
        author="HyperBrick",
    )
def _apply_action(self, action: SupportDeskAction) -> None:
    """Mutate the working case according to ``action``.

    Only the fields relevant to the operation are applied, and ``None`` values
    leave the existing case field untouched. ``request_info`` merges the new
    fields into the sorted, de-duplicated list of requested fields. For tasks
    with a scripted customer follow-up, it also opens a pending follow-up if
    none exists yet. ``wait`` (and any unlisted operation) changes nothing here.
    """
    case = self._case
    op = action.operation
    if op == "classify":
        for attr in ("queue", "priority", "issue_type"):
            value = getattr(action, attr)
            if value is not None:
                setattr(case, attr, value)
    elif op == "request_info":
        if action.requested_fields:
            case.requested_fields = sorted(
                set(case.requested_fields) | set(action.requested_fields)
            )
        follow_up_enabled = self.task.follow_up_outcome != "none"
        if follow_up_enabled and case.customer_follow_up.status == "none":
            case.customer_follow_up = CustomerFollowUp(status="pending")
    elif op == "draft_reply":
        if action.reply is not None:
            case.reply = action.reply
    elif op == "add_internal_note":
        if action.internal_note is not None:
            case.internal_note = action.internal_note
    elif op == "submit":
        if action.status is not None:
            case.status = action.status
        if action.resolution_code is not None:
            case.resolution_code = action.resolution_code
def _advance_external_events(self, action: SupportDeskAction) -> None:
    """Deliver the scripted customer follow-up when the agent waits for it.

    Only a ``wait`` issued while the follow-up is ``pending`` has an effect.
    The follow-up is replaced with the task's configured outcome, message
    (an empty message becomes ``None``), and provided/wrong field lists.
    Any other combination is a no-op.
    """
    if action.operation != "wait":
        return
    if self._case.customer_follow_up.status != "pending":
        return
    spec = self.task
    self._case.customer_follow_up = CustomerFollowUp(
        status=spec.follow_up_outcome,
        message=spec.follow_up_message or None,
        provided_fields=list(spec.follow_up_provided_fields),
        wrong_fields=list(spec.follow_up_wrong_fields),
    )
def _degrade_sla(self) -> None:
    """Burn one step's worth of SLA time, clamping the remainder at zero.

    Tickets without an SLA (``None`` remaining) are left untouched.
    """
    remaining = self._current_sla_minutes_remaining
    if remaining is not None:
        reduced = remaining - self.task.sla_step_cost
        self._current_sla_minutes_remaining = reduced if reduced > 0 else 0
def _action_penalty(
    self,
    action: SupportDeskAction,
    current_score: float,
    previous_score: float,
) -> float:
    """Return the non-positive shaping penalty for the action just taken.

    Penalties accumulate for each of the following:
      * no score improvement;
      * payload fields that do not belong to the operation;
      * routing into an over-escalation queue;
      * empty payloads;
      * waiting while no follow-up is pending;
      * submitting with required actions outstanding;
      * still being mis-routed past the task's escalation deadline;
      * running within 15 minutes of the SLA.
    The result is rounded to 4 decimals.

    NOTE(review): ``step()`` calls this after external events are applied, so
    a ``wait`` that just resolved a pending follow-up is also penalized —
    confirm intent.
    """
    penalty = 0.0
    if current_score <= previous_score:
        penalty -= 0.03
    penalty -= self._mixed_action_penalty(action)
    penalty -= self._escalation_tradeoff_penalty()
    if action.operation == "draft_reply" and not action.reply:
        penalty -= 0.03
    if action.operation == "request_info" and not action.requested_fields:
        penalty -= 0.03
    if action.operation == "add_internal_note" and not action.internal_note:
        penalty -= 0.03
    # Only the fields classify actually applies make it non-empty.
    # status/resolution_code are ignored by classify (and already charged as
    # mixed fields), so they must not mask an empty classification.
    if action.operation == "classify" and not any(
        (action.queue, action.priority, action.issue_type)
    ):
        penalty -= 0.03
    if action.operation == "wait" and self._case.customer_follow_up.status != "pending":
        penalty -= 0.02
    if action.operation == "submit" and self._required_next_actions():
        penalty -= 0.08
    if (
        self.task.under_escalation_deadline_step is not None
        and self._step_count >= self.task.under_escalation_deadline_step
        and (self._case.queue != self.task.gold_queue or self._case.priority != self.task.gold_priority)
    ):
        penalty -= 0.04
    if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining <= 15:
        penalty -= 0.02
    return round(penalty, 4)
def _build_feedback(self, grade, reward: float) -> str:
    """Compose the human-readable progress summary shown after a non-final step.

    ``grade`` is expected to expose ``total_score`` and
    ``completed_milestones`` (as produced by ``grade_case``).
    """
    sla = self._current_sla_minutes_remaining
    sla_text = sla if sla is not None else 'n/a'
    stage = self._workflow_stage()
    follow_up = self._case.customer_follow_up.status
    next_actions = ', '.join(self._required_next_actions()) or 'none'
    milestones = ', '.join(grade.completed_milestones) or 'none yet'
    segments = [
        f"Reward delta {reward:+.2f}. Current score {grade.total_score:.2f}. ",
        f"SLA remaining: {sla_text} minutes. ",
        f"Stage: {stage}. ",
        f"Customer follow-up: {follow_up}. ",
        f"Next actions: {next_actions}. ",
        f"Completed milestones: {milestones}.",
    ]
    return "".join(segments)
def _summarize_action(self, action: SupportDeskAction) -> str:
    """Render a compact, ``" | "``-separated one-line description of ``action``.

    Falsy fields are omitted; free-text reply and note are reported only as
    ``yes`` when present.
    """
    labelled = (
        ("queue", action.queue),
        ("priority", action.priority),
        ("issue_type", action.issue_type),
        ("status", action.status),
        ("resolution", action.resolution_code),
    )
    pieces = [action.operation]
    pieces.extend(f"{label}={value}" for label, value in labelled if value)
    if action.requested_fields:
        pieces.append("requested=" + ",".join(action.requested_fields))
    if action.reply:
        pieces.append("reply=yes")
    if action.internal_note:
        pieces.append("note=yes")
    return " | ".join(pieces)
def _build_observation(
    self,
    reward: float,
    done: bool,
    feedback: str | None = None,
) -> SupportDeskObservation:
    """Assemble the agent-facing observation for the current case.

    Combines the task's static material (objective, ticket, knowledge base,
    option lists) with detached copies of the mutable case and history.

    Args:
        reward: Reward to report for the step that produced this observation.
        done: Whether the episode is over.
        feedback: Message to show. A falsy value falls back to the last
            stored feedback.
    """
    spec = self.task
    history = [record.model_copy(deep=True) for record in self._history]
    steps_left = self._max_steps - self._step_count
    if steps_left < 0:
        steps_left = 0
    return SupportDeskObservation(
        task_id=spec.task_id,
        difficulty=spec.difficulty,
        objective=spec.objective,
        ticket=spec.ticket,
        knowledge_base=list(spec.knowledge_base),
        available_queues=list(ALL_QUEUES),
        available_priorities=list(ALL_PRIORITIES),
        available_statuses=list(ALL_STATUSES),
        available_issue_types=list(ALL_ISSUE_TYPES),
        case=self._case.model_copy(deep=True),
        current_sla_minutes_remaining=self._current_sla_minutes_remaining,
        workflow_stage=self._workflow_stage(),
        required_next_actions=self._required_next_actions(),
        risk_flags=self._risk_flags(),
        action_history=history,
        feedback=feedback if feedback else self._last_feedback,
        remaining_steps=steps_left,
        reward=reward,
        done=done,
    )
def _workflow_stage(self) -> str:
    """Return the workflow stage the case is currently in.

    Checks run in order and the first match wins:
      * ``closed``: the episode is done;
      * ``intake``: queue, priority and issue type are not all set;
      * ``verification``: some required field has not been requested yet;
      * ``awaiting_customer``: the follow-up is pending;
      * ``follow_up_review``: the follow-up is partial or incorrect;
      * ``customer_communication``: there is no reply yet;
      * ``internal_handoff``: there is no internal note yet;
      * ``final_resolution``: status or resolution code differs from the
        task's gold values;
      * ``ready_to_submit``: none of the above applies.
    """
    if self._done:
        return "closed"
    case = self._case
    if case.queue is None or case.priority is None or case.issue_type is None:
        return "intake"
    required = self.task.required_requested_fields
    # Subset, not equality: requested_fields is only ever merged into (never
    # shrunk). With an exact-match test, a single extra request left the case
    # stuck in verification for the rest of the episode.
    if required and not set(required).issubset(case.requested_fields):
        return "verification"
    if case.customer_follow_up.status == "pending":
        return "awaiting_customer"
    if case.customer_follow_up.status in {"partial", "incorrect"}:
        return "follow_up_review"
    if not case.reply:
        return "customer_communication"
    if not case.internal_note:
        return "internal_handoff"
    if case.status != self.task.gold_status or case.resolution_code != self.task.gold_resolution_code:
        return "final_resolution"
    return "ready_to_submit"


def _required_next_actions(self) -> list[str]:
    """Return the operations still needed to move the case toward submission.

    Gating steps are returned alone, in order:
      1. ``classify`` until the case is fully classified;
      2. ``request_info`` until every required field has been requested;
      3. ``wait`` while a follow-up is pending.
    After that, the result lists whichever of ``draft_reply``,
    ``add_internal_note`` and ``submit`` are still outstanding. The list is
    empty when the case is ready.
    """
    case = self._case
    if case.queue is None or case.priority is None or case.issue_type is None:
        return ["classify"]
    required = self.task.required_requested_fields
    # Same subset rule as the stage check: extra requested fields must not
    # block progress, since they can never be removed.
    if required and not set(required).issubset(case.requested_fields):
        return ["request_info"]
    if case.customer_follow_up.status == "pending":
        return ["wait"]
    needed: list[str] = []
    if not case.reply:
        needed.append("draft_reply")
    if not case.internal_note:
        needed.append("add_internal_note")
    if case.status != self.task.gold_status or case.resolution_code != self.task.gold_resolution_code:
        needed.append("submit")
    return needed
def _risk_flags(self) -> list[str]:
    """Return the sorted, de-duplicated risk flags for the current case.

    Starts from the task's static flags and adds dynamic ones for:
      * 30 or fewer SLA minutes left;
      * 1000 or more affected users;
      * secondary concerns on the ticket;
      * a partial or irrelevant (incorrect) customer follow-up.
    """
    ticket = self.task.ticket
    sla = self._current_sla_minutes_remaining
    follow_up_status = self._case.customer_follow_up.status
    dynamic = {
        "sla_breach_risk": sla is not None and sla <= 30,
        "high_customer_impact": bool(ticket.affected_users) and ticket.affected_users >= 1000,
        "secondary_issue_present": bool(ticket.secondary_concerns),
        "customer_reply_incomplete": follow_up_status == "partial",
        "customer_reply_irrelevant": follow_up_status == "incorrect",
    }
    flags = set(self.task.risk_flags)
    flags.update(name for name, active in dynamic.items() if active)
    return sorted(flags)
def _process_bonus(
    self,
    action: SupportDeskAction,
    previous_stage: str,
    current_score: float,
) -> float:
    """Return the non-negative process-shaping bonus for the action just taken.

    Components, added in this order:
      * +0.02 if the workflow stage moved later than ``previous_stage``
        (unknown stage names rank as ``intake``);
      * +0.03 for a step-1 ``classify`` that set the gold queue and priority;
      * +0.02 for ``request_info`` on a task with required fields, once
        ``current_score`` is positive;
      * +0.02 for ``wait`` while the follow-up is partial, complete or incorrect;
      * +0.03 for ``submit`` with no required actions outstanding;
      * +0.02, while SLA time remains, for urgent tasks already in the gold
        queue within the first two steps.
    The total is rounded to 4 decimals.
    """
    ordered_stages = (
        "intake",
        "verification",
        "awaiting_customer",
        "follow_up_review",
        "customer_communication",
        "internal_handoff",
        "final_resolution",
        "ready_to_submit",
        "closed",
    )
    position = {stage: idx for idx, stage in enumerate(ordered_stages)}
    op = action.operation
    spec = self.task
    case = self._case
    sla_left = self._current_sla_minutes_remaining

    total = 0.0
    if position.get(self._workflow_stage(), 0) > position.get(previous_stage, 0):
        total += 0.02
    if (
        op == "classify"
        and self._step_count == 1
        and case.queue == spec.gold_queue
        and case.priority == spec.gold_priority
    ):
        total += 0.03
    if op == "request_info" and current_score > 0 and spec.required_requested_fields:
        total += 0.02
    if op == "wait" and case.customer_follow_up.status in {"partial", "complete", "incorrect"}:
        total += 0.02
    if op == "submit" and not self._required_next_actions():
        total += 0.03
    if (
        sla_left is not None
        and sla_left > 0
        and spec.gold_priority == "urgent"
        and self._step_count <= 2
        and case.queue == spec.gold_queue
    ):
        total += 0.02
    return round(total, 4)
def _mixed_action_penalty(self, action: SupportDeskAction) -> float:
    """Return the penalty for payload fields that do not belong to the operation.

    Each populated foreign field costs 0.02, capped at 0.06. ``None``, an empty
    string and an empty list count as unpopulated. An ``operation`` missing
    from the table below raises ``KeyError``.
    """
    own_fields = {
        "classify": {"queue", "priority", "issue_type"},
        "request_info": {"requested_fields"},
        "draft_reply": {"reply"},
        "add_internal_note": {"internal_note"},
        "submit": {"status", "resolution_code"},
        "wait": set(),
    }[action.operation]
    payload_names = (
        "queue",
        "priority",
        "issue_type",
        "status",
        "resolution_code",
        "requested_fields",
        "reply",
        "internal_note",
    )

    def is_blank(value) -> bool:
        # Only None, "" and [] count as empty; other falsy values still count.
        return value is None or (isinstance(value, (list, str)) and not value)

    extras = sum(
        1
        for name in payload_names
        if name not in own_fields and not is_blank(getattr(action, name))
    )
    return min(0.06, extras * 0.02)
def _escalation_tradeoff_penalty(self) -> float:
    """Return 0.06 when the case sits in an over-escalation queue that is not the gold queue, else 0.0."""
    routed = self._case.queue
    over_escalated = routed in self.task.over_escalation_queues and routed != self.task.gold_queue
    return round(0.06 if over_escalated else 0.0, 4)