""" PayOps Environment — core simulation. Supports all 10 action types: Terminal decisions: approve | reject | flag | escalate | hold Investigation sub-actions: inspect | request_docs | verify_kyc | contact_sender | file_sar Investigation sub-actions do NOT advance the task pointer; they reveal additional context and consume budget. A file_sar sub-action is required for full credit on regulatory tasks. Multi-step chain tasks (chain_total > 1) are presented as a single task entry; the grader handles them as one decision unit. """ from __future__ import annotations import copy import random import uuid from collections import deque from typing import Deque, List, Optional from payops_env.grader import INVESTIGATION_BONUS, grade from payops_env.models import PayOpsAction, PayOpsObservation, PayOpsState from payops_env.tasks import ACTION_COSTS, TASK_VARIANTS, TASKS, PayOpsTask TERMINAL_ACTIONS = {"approve", "reject", "flag", "escalate", "hold"} INVESTIGATION_ACTIONS = {"inspect", "request_docs", "verify_kyc", "contact_sender", "file_sar"} VALID_ACTIONS = TERMINAL_ACTIONS | INVESTIGATION_ACTIONS _RECENT_WINDOW = 4 # how many past (task_id, action) pairs to surface in obs class PayOpsEnvironment: """ OpenEnv-compatible environment for Payment Operations triage. An episode proceeds through every task in TASKS. For each task the agent may issue any number of investigation sub-actions before a terminal decision closes the task. Budget ------ The agent starts with ``budget_limit`` points. Each investigation sub-action deducts its cost (see tasks.ACTION_COSTS). Budget overspend is penalised in grade_episode; within the environment it is tracked but does not terminate the episode. """ BUDGET_LIMIT = 5.0 def __init__(self): self._tasks: List[PayOpsTask] = [] self._current_task: Optional[PayOpsTask] = None self._state: PayOpsState = PayOpsState() self._used_inv: dict = {} # task_id → set of investigation actions used self._sar_filed: set = set() # task_ids where file_sar was issued self._recent_decisions: Deque = deque(maxlen=_RECENT_WINDOW) # ------------------------------------------------------------------ # OpenEnv API # ------------------------------------------------------------------ async def reset_async( self, seed: Optional[int] = None, episode_id: Optional[str] = None, ) -> PayOpsObservation: # --- per-episode jitter: prevents agent overfitting to fixed values --- # Use caller-supplied seed when provided (enables reproducibility). episode_seed = seed if seed is not None else int(uuid.uuid4().int % 2**31) rng = random.Random(episode_seed) jittered: List[PayOpsTask] = [] for t in TASKS: jt = copy.copy(t) jt.amount = round(t.amount * rng.uniform(0.85, 1.20), 2) jt.risk_score = round(min(1.0, max(0.0, t.risk_score + rng.gauss(0, 0.03))), 4) if t.velocity_1h is not None: jt.velocity_1h = max(0, t.velocity_1h + rng.randint(-3, 3)) if t.velocity_24h is not None: jt.velocity_24h = max(0, t.velocity_24h + rng.randint(-3, 3)) jittered.append(jt) self._tasks = jittered # ── Variant pool selection ────────────────────────────────────────────── # Each task in TASK_VARIANTS has 2 alternative scenarios (in addition to # the base). The episode seed selects which scenario plays out, making # the correct_action unknowable from the task_id alone — the agent MUST # investigate to discover the decisive evidence. # # seed=0 is the canonical episode (all base variants); used by the test # suite so that hardcoded expected rewards remain stable. if episode_seed != 0: for i, jt in enumerate(self._tasks): variants = TASK_VARIANTS.get(jt.task_id, []) if not variants: continue # variant_idx=0 → base (no changes); 1..N → variants[0..N-1] variant_idx = (episode_seed * 1009 + i * 17) % (len(variants) + 1) if variant_idx == 0: continue overrides = variants[variant_idx - 1] for key, val in overrides.items(): setattr(jt, key, val) self._current_task = self._tasks[0] self._used_inv = {} self._sar_filed = set() self._recent_decisions = deque(maxlen=_RECENT_WINDOW) self._episode_seed = episode_seed self._state = PayOpsState( episode_id=episode_id if episode_id is not None else str(uuid.uuid4()), episode_seed=episode_seed, step_count=0, current_task_id=self._current_task.task_id, transactions_processed=0, total_tasks=len(self._tasks), cumulative_reward=0.0, actions_taken=[], last_action=None, done=False, budget_spent=0.0, budget_limit=self.BUDGET_LIMIT, investigation_actions_used=[], correct_decisions=0, wrong_high_cost=0, recent_decisions=[], ) return self._make_observation(reward=0.0, done=False, info={"event": "reset"}) async def step_async(self, action: PayOpsAction) -> PayOpsObservation: if self._current_task is None: raise RuntimeError("Environment must be reset before stepping.") action_type = action.action_type.lower() if action_type not in VALID_ACTIONS: raise ValueError( f"Invalid action '{action_type}'. " f"Valid actions: {sorted(VALID_ACTIONS)}" ) task = self._current_task task_id = task.task_id cost = ACTION_COSTS.get(action_type, 0.0) # ── MULTI-STEP CHAIN GATE ───────────────────────────────────────────── # Critical tasks with chain_total > 1 require (chain_total − 1) # investigation sub-actions before a terminal decision is accepted. # Blocked attempts return a helpful message without advancing the task. if action_type in TERMINAL_ACTIONS: chain_min = max(0, getattr(task, "chain_total", 1) - 1) inv_done = len(self._used_inv.get(task_id, set())) if chain_min > 0 and inv_done < chain_min: needed = chain_min - inv_done return self._make_observation( reward=-0.05, done=False, info={ "event": "chain_gate_blocked", "chain_status": "investigation_required", "chain_steps_needed": needed, "message": ( f"This {task.difficulty} transaction requires {needed} " f"more investigation step(s) before a terminal decision. " f"Please investigate first." ), }, ) # Deduct cost self._state.budget_spent = round(self._state.budget_spent + cost, 4) self._state.step_count += 1 self._state.actions_taken.append(action_type) self._state.last_action = action_type # ── INVESTIGATION SUB-ACTIONS ──────────────────────────────────── if action_type in INVESTIGATION_ACTIONS: used = self._used_inv.setdefault(task_id, set()) already = action_type in used reward = 0.0 if already else INVESTIGATION_BONUS used.add(action_type) self._state.investigation_actions_used.append(action_type) self._state.cumulative_reward = round( self._state.cumulative_reward + reward, 4 ) # Determine reveal text reveal_text: Optional[str] = None reveal_field: Optional[str] = None if action_type == "inspect": reveal_text = task.inspect_reveal or "No additional information available." reveal_field = "inspection_notes" elif action_type == "request_docs": reveal_text = task.docs_reveal or "No documents on record for this transaction." reveal_field = "docs_notes" elif action_type == "verify_kyc": reveal_text = task.kyc_reveal or "KYC records could not be retrieved." reveal_field = "kyc_notes" elif action_type == "contact_sender": reveal_text = task.contact_reveal or "Sender did not respond to contact attempt." reveal_field = "contact_notes" elif action_type == "file_sar": self._sar_filed.add(task_id) reveal_text = ( "SAR filed with FinCEN. Reference number will be generated within 24 h." if task.regulatory_action else "SAR filed. Note: this transaction may not meet SAR-filing threshold." ) reveal_field = "docs_notes" used_so_far = list(self._used_inv.get(task_id, set())) info = { "event": action_type, "already_used": already, "investigation_used": used_so_far, # full list for this task reveal_field: reveal_text, "budget_remaining": round( self.BUDGET_LIMIT - self._state.budget_spent, 4 ), } return self._make_observation( reward=reward, done=False, info=info, reveal_field=reveal_field, reveal_text=reveal_text ) # ── TERMINAL DECISION ──────────────────────────────────────────── used_inv = list(self._used_inv.get(task_id, set())) sar_used = task_id in self._sar_filed inspected_already = "inspect" in self._used_inv.get(task_id, set()) investigation_done = bool(self._used_inv.get(task_id, set())) reward = grade( action_type, task, inspected_already=inspected_already, investigation_done=investigation_done, ) self._state.cumulative_reward = round( self._state.cumulative_reward + reward, 4 ) is_correct = action_type == task.correct_action if is_correct: self._state.correct_decisions += 1 elif action_type == "approve" and task.correct_action in ("reject", "escalate"): self._state.wrong_high_cost += 1 self._recent_decisions.append( {"task_id": task_id, "action": action_type, "correct": is_correct} ) self._state.recent_decisions = list(self._recent_decisions) self._state.transactions_processed += 1 # Advance task pointer task_idx = self._tasks.index(task) remaining = self._tasks[task_idx + 1:] done = len(remaining) == 0 if not done: self._current_task = remaining[0] self._state.current_task_id = self._current_task.task_id else: self._state.done = True return self._make_observation( reward=reward, done=done, info={ "event": "step", "action_taken": action_type, "correct_action": task.correct_action, "task_id": task_id, "difficulty": task.difficulty, "is_correct": is_correct, "investigation_used": used_inv, "budget_remaining": round( self.BUDGET_LIMIT - self._state.budget_spent, 4 ), }, ) def state(self) -> PayOpsState: return self._state def close(self): pass # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _make_observation( self, reward: float, done: bool, info: dict, reveal_field: Optional[str] = None, reveal_text: Optional[str] = None, ) -> PayOpsObservation: task = self._current_task steps_remaining = self._state.total_tasks - self._state.transactions_processed # Progressive disclosure: sensitive forensic fields are only surfaced after # the agent has called 'inspect' on this task. This makes the investigate- # before-deciding mechanic load-bearing rather than cosmetic. task_id = task.task_id inspected = "inspect" in self._used_inv.get(task_id, set()) contacted = "contact_sender" in self._used_inv.get(task_id, set()) return PayOpsObservation( # ── transaction core ── transaction_id=task.transaction_id, amount=task.amount, currency=task.currency, sender=task.sender, receiver=task.receiver, transaction_type=task.transaction_type, status=( "inspected" if info.get("event") in INVESTIGATION_ACTIONS else ("done" if done else "pending") ), # ── risk signals ── risk_score=task.risk_score, ml_confidence=getattr(task, "ml_confidence", 0.90), flags=list(task.flags), velocity_1h=task.velocity_1h, velocity_24h=getattr(task, "velocity_24h", None), avg_transaction_amount=getattr(task, "avg_transaction_amount", None), account_age_days=getattr(task, "account_age_days", None), country_risk=task.country_risk, kyc_status=task.kyc_status, kyc_expiry_days=getattr(task, "kyc_expiry_days", None), previous_violations=task.previous_violations if inspected else None, previous_sars=getattr(task, "previous_sars", None) if inspected else None, counterparty_risk=getattr(task, "counterparty_risk", None) if inspected else None, # ── chain context ── chain_total=getattr(task, "chain_total", 1), chain_step=self._state.step_count, chain_context=task.description, # ── investigation reveals ── inspection_notes=( reveal_text if reveal_field == "inspection_notes" else info.get("inspection_notes") ), docs_notes=( reveal_text if reveal_field == "docs_notes" else None ), kyc_notes=( reveal_text if reveal_field == "kyc_notes" else None ), contact_notes=( reveal_text if reveal_field == "contact_notes" else (info.get("contact_notes") if contacted else None) ), # ── episode meta ── task_id=task.task_id, task_difficulty=task.difficulty, step_in_episode=self._state.step_count, steps_remaining=steps_remaining, action_cost=ACTION_COSTS.get(info.get("event", ""), 0.0), budget_remaining=round(self.BUDGET_LIMIT - self._state.budget_spent, 4), investigation_hints=[], # not surfaced upfront; agent must explore recent_decisions=list(self._recent_decisions), reward=reward, cumulative_reward=self._state.cumulative_reward, done=done, network_graph=getattr(task, "network_graph", None) if inspected else None, info=info, )