"""SupportDesk environment implementation.""" from __future__ import annotations import os import threading import uuid from pathlib import Path from typing import ClassVar from graders import grade_case from models import ( ActionHistoryEntry, CustomerFollowUp, SupportCaseProgress, SupportDeskAction, SupportDeskObservation, SupportDeskState, ) from openenv_compat import Environment, EnvironmentMetadata from tasks import ( ALL_ISSUE_TYPES, ALL_PRIORITIES, ALL_QUEUES, ALL_STATUSES, SupportTaskSpec, get_task, list_task_ids, ) class SupportDeskEnvironment( Environment[SupportDeskAction, SupportDeskObservation, SupportDeskState] ): """A realistic customer support triage environment with dense rewards.""" _state_lock: ClassVar[threading.RLock] = threading.RLock() _episode_store: ClassVar[dict[str, SupportDeskState]] = {} _episode_task_ids: ClassVar[dict[str, str]] = {} _latest_episode_id: ClassVar[str | None] = None _shared_reset_counter: ClassVar[int] = 0 def __init__(self, task_id: str | None = None): super().__init__() env_task_id = os.getenv("SUPPORTDESK_TASK_ID") self._explicit_task_id = task_id is not None or env_task_id is not None requested_task = task_id or env_task_id or list_task_ids()[0] self.task: SupportTaskSpec = get_task(requested_task) self._max_steps = self.task.max_steps self._step_count = 0 self._reward_total = 0.0 self._done = False self._last_feedback = "" self._history: list[ActionHistoryEntry] = [] self._case = SupportCaseProgress() self._episode_id: str | None = None self._current_sla_minutes_remaining = self.task.ticket.sla_minutes_remaining initial_grade = grade_case(self.task, self._case) self._score = initial_grade.total_score self._completed_milestones = list(initial_grade.completed_milestones) @classmethod def _build_initial_state(cls, task: SupportTaskSpec, episode_id: str) -> SupportDeskState: initial_case = SupportCaseProgress() initial_grade = grade_case(task, initial_case) return SupportDeskState( episode_id=episode_id, task_id=task.task_id, difficulty=task.difficulty, step_count=0, reward=0.0, done=False, current_score=initial_grade.total_score, max_steps=task.max_steps, case=initial_case, current_sla_minutes_remaining=task.ticket.sla_minutes_remaining, workflow_stage="intake", required_next_actions=["classify"], risk_flags=[], action_history=[], completed_milestones=list(initial_grade.completed_milestones), last_feedback="New case loaded. Review the ticket and policy snippets before acting.", ) @classmethod def _extract_episode_id(cls, episode_id: str | None = None, **kwargs) -> str | None: if episode_id: return episode_id for key in ("episode_id", "request_id"): value = kwargs.get(key) if isinstance(value, str) and value: return value return None def _load_episode(self, episode_id: str | None = None, **kwargs) -> None: resolved_episode_id = self._extract_episode_id(episode_id, **kwargs) or self.__class__._latest_episode_id if not resolved_episode_id: return episode_state = self.__class__._episode_store.get(resolved_episode_id) if episode_state is None: raise ValueError( f"Unknown episode_id '{resolved_episode_id}'. Call reset() first or provide a valid episode_id." ) task = get_task(self.__class__._episode_task_ids.get(resolved_episode_id, episode_state.task_id)) self.task = task self._max_steps = episode_state.max_steps self._step_count = episode_state.step_count self._reward_total = episode_state.reward self._done = episode_state.done self._last_feedback = episode_state.last_feedback self._history = [entry.model_copy(deep=True) for entry in episode_state.action_history] self._case = episode_state.case.model_copy(deep=True) self._episode_id = resolved_episode_id self._score = episode_state.current_score self._completed_milestones = list(episode_state.completed_milestones) self._current_sla_minutes_remaining = episode_state.current_sla_minutes_remaining def _persist_episode(self) -> None: if self._episode_id is None: return self.__class__._episode_store[self._episode_id] = SupportDeskState( episode_id=self._episode_id, task_id=self.task.task_id, difficulty=self.task.difficulty, step_count=self._step_count, reward=round(self._reward_total, 4), done=self._done, current_score=round(self._score, 4), max_steps=self._max_steps, case=self._case.model_copy(deep=True), current_sla_minutes_remaining=self._current_sla_minutes_remaining, workflow_stage=self._workflow_stage(), required_next_actions=self._required_next_actions(), risk_flags=self._risk_flags(), action_history=[entry.model_copy(deep=True) for entry in self._history], completed_milestones=list(self._completed_milestones), last_feedback=self._last_feedback, ) self.__class__._episode_task_ids[self._episode_id] = self.task.task_id self.__class__._latest_episode_id = self._episode_id @property def state(self) -> SupportDeskState: with self.__class__._state_lock: self._load_episode() return SupportDeskState( episode_id=self._episode_id, task_id=self.task.task_id, difficulty=self.task.difficulty, step_count=self._step_count, reward=round(self._reward_total, 4), done=self._done, current_score=round(self._score, 4), max_steps=self._max_steps, case=self._case.model_copy(deep=True), current_sla_minutes_remaining=self._current_sla_minutes_remaining, workflow_stage=self._workflow_stage(), required_next_actions=self._required_next_actions(), risk_flags=self._risk_flags(), action_history=[entry.model_copy(deep=True) for entry in self._history], completed_milestones=list(self._completed_milestones), last_feedback=self._last_feedback, ) def reset( self, seed: int | None = None, episode_id: str | None = None, **kwargs, ) -> SupportDeskObservation: with self.__class__._state_lock: if not self._explicit_task_id: task_ids = list_task_ids() next_task_id = task_ids[self.__class__._shared_reset_counter % len(task_ids)] self.__class__._shared_reset_counter += 1 self.task = get_task(next_task_id) self._max_steps = self.task.max_steps self._episode_id = episode_id or f"{self.task.task_id}-{uuid.uuid4().hex[:8]}" initial_state = self.__class__._build_initial_state(self.task, self._episode_id) self.__class__._episode_store[self._episode_id] = initial_state self.__class__._episode_task_ids[self._episode_id] = self.task.task_id self.__class__._latest_episode_id = self._episode_id self._load_episode(self._episode_id) return self._build_observation(reward=0.0, done=False) def step( self, action: SupportDeskAction, timeout_s: float | None = None, episode_id: str | None = None, **kwargs, ) -> SupportDeskObservation: with self.__class__._state_lock: self._load_episode(episode_id, **kwargs) if self._done: return self._build_observation( reward=-0.05, done=True, feedback="Episode already finished. Call reset() before taking more actions.", ) previous_grade = grade_case(self.task, self._case) previous_stage = self._workflow_stage() self._apply_action(action) self._step_count += 1 self._advance_external_events(action) self._degrade_sla() current_grade = grade_case(self.task, self._case) reward = current_grade.total_score - previous_grade.total_score reward += self._process_bonus(action, previous_stage, current_grade.total_score) reward += self._action_penalty( action, current_grade.total_score, previous_grade.total_score, ) reward = round(reward, 4) self._score = current_grade.total_score self._completed_milestones = list(current_grade.completed_milestones) if action.operation == "submit": self._done = True self._last_feedback = ( "Case submitted. Final deterministic grade is " f"{current_grade.total_score:.2f}." ) elif self._step_count >= self._max_steps: self._done = True self._last_feedback = ( f"Reached max steps ({self._max_steps}). Final deterministic grade is " f"{current_grade.total_score:.2f}." ) else: self._last_feedback = self._build_feedback(current_grade, reward) self._reward_total = round(self._reward_total + reward, 4) self._history.append( ActionHistoryEntry( step=self._step_count, operation=action.operation, summary=self._summarize_action(action), reward_delta=reward, ) ) self._persist_episode() return self._build_observation(reward=reward, done=self._done) @classmethod def state_for_episode(cls, episode_id: str) -> SupportDeskState: with cls._state_lock: state = cls._episode_store.get(episode_id) if state is None: raise ValueError(f"Unknown episode_id '{episode_id}'. Call reset() first.") return state.model_copy(deep=True) def close(self) -> None: """No-op close hook for compatibility with local scripts.""" def get_metadata(self) -> EnvironmentMetadata: """Return richer metadata for docs, validators, and HF Space UI.""" readme_path = Path(__file__).resolve().parents[1] / "README.md" readme_content = readme_path.read_text(encoding="utf-8") if readme_path.exists() else None return EnvironmentMetadata( name="supportdesk_env", description=( "A policy-heavy enterprise operations desk with deterministic grading, delayed " "customer follow-ups, SLA pressure, escalation tradeoffs, and sharper cross-functional triage." ), readme_content=readme_content, version="0.1.0", author="HyperBrick", ) def _apply_action(self, action: SupportDeskAction) -> None: if action.operation == "classify": if action.queue is not None: self._case.queue = action.queue if action.priority is not None: self._case.priority = action.priority if action.issue_type is not None: self._case.issue_type = action.issue_type return if action.operation == "request_info": if action.requested_fields: merged = {item for item in self._case.requested_fields} merged.update(action.requested_fields) self._case.requested_fields = sorted(merged) if self.task.follow_up_outcome != "none" and self._case.customer_follow_up.status == "none": self._case.customer_follow_up = CustomerFollowUp(status="pending") return if action.operation == "draft_reply": if action.reply is not None: self._case.reply = action.reply return if action.operation == "add_internal_note": if action.internal_note is not None: self._case.internal_note = action.internal_note return if action.operation == "submit": if action.status is not None: self._case.status = action.status if action.resolution_code is not None: self._case.resolution_code = action.resolution_code def _advance_external_events(self, action: SupportDeskAction) -> None: if self._case.customer_follow_up.status == "pending" and action.operation == "wait": self._case.customer_follow_up = CustomerFollowUp( status=self.task.follow_up_outcome, message=self.task.follow_up_message or None, provided_fields=list(self.task.follow_up_provided_fields), wrong_fields=list(self.task.follow_up_wrong_fields), ) def _degrade_sla(self) -> None: if self._current_sla_minutes_remaining is None: return self._current_sla_minutes_remaining = max( 0, self._current_sla_minutes_remaining - self.task.sla_step_cost, ) def _action_penalty( self, action: SupportDeskAction, current_score: float, previous_score: float, ) -> float: penalty = 0.0 if current_score <= previous_score: penalty -= 0.03 penalty -= self._mixed_action_penalty(action) penalty -= self._escalation_tradeoff_penalty() if action.operation == "draft_reply" and not action.reply: penalty -= 0.03 if action.operation == "request_info" and not action.requested_fields: penalty -= 0.03 if action.operation == "add_internal_note" and not action.internal_note: penalty -= 0.03 if action.operation == "classify" and not any( [action.queue, action.priority, action.issue_type, action.status, action.resolution_code] ): penalty -= 0.03 if action.operation == "wait" and self._case.customer_follow_up.status != "pending": penalty -= 0.02 if action.operation == "submit" and self._required_next_actions(): penalty -= 0.08 if ( self.task.under_escalation_deadline_step is not None and self._step_count >= self.task.under_escalation_deadline_step and (self._case.queue != self.task.gold_queue or self._case.priority != self.task.gold_priority) ): penalty -= 0.04 if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining <= 15: penalty -= 0.02 return round(penalty, 4) def _build_feedback(self, grade, reward: float) -> str: return ( f"Reward delta {reward:+.2f}. Current score {grade.total_score:.2f}. " f"SLA remaining: {self._current_sla_minutes_remaining if self._current_sla_minutes_remaining is not None else 'n/a'} minutes. " f"Stage: {self._workflow_stage()}. " f"Customer follow-up: {self._case.customer_follow_up.status}. " f"Next actions: {', '.join(self._required_next_actions()) or 'none'}. " f"Completed milestones: {', '.join(grade.completed_milestones) or 'none yet'}." ) def _summarize_action(self, action: SupportDeskAction) -> str: parts = [action.operation] if action.queue: parts.append(f"queue={action.queue}") if action.priority: parts.append(f"priority={action.priority}") if action.issue_type: parts.append(f"issue_type={action.issue_type}") if action.status: parts.append(f"status={action.status}") if action.resolution_code: parts.append(f"resolution={action.resolution_code}") if action.requested_fields: parts.append(f"requested={','.join(action.requested_fields)}") if action.reply: parts.append("reply=yes") if action.internal_note: parts.append("note=yes") return " | ".join(parts) def _build_observation( self, reward: float, done: bool, feedback: str | None = None, ) -> SupportDeskObservation: return SupportDeskObservation( task_id=self.task.task_id, difficulty=self.task.difficulty, objective=self.task.objective, ticket=self.task.ticket, knowledge_base=list(self.task.knowledge_base), available_queues=list(ALL_QUEUES), available_priorities=list(ALL_PRIORITIES), available_statuses=list(ALL_STATUSES), available_issue_types=list(ALL_ISSUE_TYPES), case=self._case.model_copy(deep=True), current_sla_minutes_remaining=self._current_sla_minutes_remaining, workflow_stage=self._workflow_stage(), required_next_actions=self._required_next_actions(), risk_flags=self._risk_flags(), action_history=[entry.model_copy(deep=True) for entry in self._history], feedback=feedback or self._last_feedback, remaining_steps=max(self._max_steps - self._step_count, 0), reward=reward, done=done, ) def _workflow_stage(self) -> str: if self._done: return "closed" if self._case.queue is None or self._case.priority is None or self._case.issue_type is None: return "intake" if self.task.required_requested_fields and sorted(self._case.requested_fields) != sorted(self.task.required_requested_fields): return "verification" if self._case.customer_follow_up.status == "pending": return "awaiting_customer" if self._case.customer_follow_up.status in {"partial", "incorrect"}: return "follow_up_review" if not self._case.reply: return "customer_communication" if not self._case.internal_note: return "internal_handoff" if self._case.status != self.task.gold_status or self._case.resolution_code != self.task.gold_resolution_code: return "final_resolution" return "ready_to_submit" def _required_next_actions(self) -> list[str]: if self._case.queue is None or self._case.priority is None or self._case.issue_type is None: return ["classify"] if self.task.required_requested_fields and sorted(self._case.requested_fields) != sorted(self.task.required_requested_fields): return ["request_info"] if self._case.customer_follow_up.status == "pending": return ["wait"] needed: list[str] = [] if not self._case.reply: needed.append("draft_reply") if not self._case.internal_note: needed.append("add_internal_note") if self._case.status != self.task.gold_status or self._case.resolution_code != self.task.gold_resolution_code: needed.append("submit") return needed def _risk_flags(self) -> list[str]: flags = list(self.task.risk_flags) if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining <= 30: flags.append("sla_breach_risk") if self.task.ticket.affected_users and self.task.ticket.affected_users >= 1000: flags.append("high_customer_impact") if self.task.ticket.secondary_concerns: flags.append("secondary_issue_present") if self._case.customer_follow_up.status == "partial": flags.append("customer_reply_incomplete") if self._case.customer_follow_up.status == "incorrect": flags.append("customer_reply_irrelevant") return sorted(set(flags)) def _process_bonus( self, action: SupportDeskAction, previous_stage: str, current_score: float, ) -> float: bonus = 0.0 stage_rank = { "intake": 0, "verification": 1, "awaiting_customer": 2, "follow_up_review": 3, "customer_communication": 4, "internal_handoff": 5, "final_resolution": 6, "ready_to_submit": 7, "closed": 8, } current_stage = self._workflow_stage() if stage_rank.get(current_stage, 0) > stage_rank.get(previous_stage, 0): bonus += 0.02 if action.operation == "classify" and self._step_count == 1: if self._case.queue == self.task.gold_queue and self._case.priority == self.task.gold_priority: bonus += 0.03 if action.operation == "request_info" and current_score > 0 and self.task.required_requested_fields: bonus += 0.02 if action.operation == "wait" and self._case.customer_follow_up.status in {"partial", "complete", "incorrect"}: bonus += 0.02 if action.operation == "submit" and not self._required_next_actions(): bonus += 0.03 if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining > 0: if self.task.gold_priority == "urgent" and self._step_count <= 2 and self._case.queue == self.task.gold_queue: bonus += 0.02 return round(bonus, 4) def _mixed_action_penalty(self, action: SupportDeskAction) -> float: allowed_fields = { "classify": {"queue", "priority", "issue_type"}, "request_info": {"requested_fields"}, "draft_reply": {"reply"}, "add_internal_note": {"internal_note"}, "submit": {"status", "resolution_code"}, "wait": set(), } populated_fields = { "queue": action.queue, "priority": action.priority, "issue_type": action.issue_type, "status": action.status, "resolution_code": action.resolution_code, "requested_fields": action.requested_fields, "reply": action.reply, "internal_note": action.internal_note, } extras = 0 for field_name, value in populated_fields.items(): if field_name in allowed_fields[action.operation]: continue if value is None: continue if isinstance(value, list) and not value: continue if isinstance(value, str) and not value: continue extras += 1 return min(0.06, extras * 0.02) def _escalation_tradeoff_penalty(self) -> float: penalty = 0.0 if self._case.queue in self.task.over_escalation_queues and self._case.queue != self.task.gold_queue: penalty += 0.06 return round(penalty, 4)