Spaces:
Paused
Paused
Iteration_4: expand task bank 20→30, add variants for all tasks, fix baseline policy for chain-gated CRIT tasks
ffea7f4 | """ | |
| Utilities for running the baseline agent programmatically (used by /baseline endpoint). | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from payops_env.environment import PayOpsEnvironment | |
| from payops_env.grader import grade_episode | |
| from payops_env.models import PayOpsAction | |
| from payops_env.tasks import TASKS | |
| # --------------------------------------------------------------------------- | |
| # Adaptive rule-based policy | |
| # --------------------------------------------------------------------------- | |
# Flags that the terminal decision treats as definitive danger signals:
# a transaction carrying any one of them is rejected.
_DANGER_FLAGS = {
    "sanctioned_country",
    "app_scam_indicator",
    "mule_account_pattern",
    "structuring_pattern",
    "ctr_threshold_avoidance",
    "fraud_ring_indicator",
    "geo_impossible_login",
    "account_takeover_indicator",
}
# Softer signals: they trigger an "inspect" sub-action during investigation
# and, combined with a high amount, an "escalate" terminal decision.
_WATCHLIST_FLAGS = {
    "new_account_7d",
    "large_first_transfer",
    "solicitor_mule_pattern",
    "dormant_receiver",
    "sudden_activity",
    "insider_threat",
    "internal_to_personal",
    "invoice_mismatch",
    "trade_finance",
}
def _should_investigate(obs) -> Optional[str]:
    """
    Decide whether to issue an investigation sub-action before deciding.

    Returns the sub-action name, or None if a terminal decision should be
    made directly. Each sub-action is issued at most once per task; the set
    of sub-actions already used is read from obs.info["investigation_used"].

    Checks are evaluated in this order and the first match wins:
    1. file_sar: structuring / fraud-ring / coordinated-transfer flags
    2. contact_sender: APP scam, internal-to-personal or account-takeover flags
    3. verify_kyc: KYC status is "expired" or "pending"
    4. request_docs: payee / invoice / trade-finance flags AND amount >= 50,000
    5. inspect: ml_confidence below 0.60 OR any watchlist flag
    6. chain-gate fallback: when chain_total requires more investigation
       steps than were taken, issue the first unused generic sub-action
    """
    # "investigation_used" is expected to list every investigation action
    # already taken for this task; anything else is treated as "none used".
    info = getattr(obs, "info", None)
    used = info.get("investigation_used", []) if isinstance(info, dict) else []
    already = set(used) if isinstance(used, (list, set)) else set()

    # Build the flag set once; a missing flag list counts as "no flags".
    flags = set(obs.flags or ())

    # file_sar if structuring / fraud ring and SAR not yet filed
    sar_flags = {"structuring_pattern", "ctr_threshold_avoidance",
                 "fraud_ring_indicator", "coordinated_transfers"}
    if sar_flags & flags and "file_sar" not in already:
        return "file_sar"

    # contact_sender for APP scam or insider
    contact_flags = {"app_scam_indicator", "internal_to_personal",
                     "account_takeover_indicator"}
    if contact_flags & flags and "contact_sender" not in already:
        return "contact_sender"

    # verify_kyc if expired or pending
    if obs.kyc_status in ("expired", "pending") and "verify_kyc" not in already:
        return "verify_kyc"

    # request_docs for first-time payees with high value
    doc_flags = {"first_time_payee", "large_first_transfer",
                 "invoice_mismatch", "trade_finance"}
    if doc_flags & flags and obs.amount >= 50_000 and "request_docs" not in already:
        return "request_docs"

    # inspect when ml_confidence is low or watchlist flags are present.
    # Only a *missing* confidence falls back to 0.9: a reported 0.0 is the
    # lowest possible confidence and must not be mistaken for "absent".
    ml_conf = getattr(obs, "ml_confidence", None)
    if ml_conf is None:
        ml_conf = 0.9
    if "inspect" not in already and (ml_conf < 0.60 or _WATCHLIST_FLAGS & flags):
        return "inspect"

    # Fallback for chain-gated tasks (chain_total > 1): if not enough
    # investigation steps have been done yet, issue generic actions in
    # priority order so the baseline agent never gets stuck in an infinite
    # chain-gate loop. A missing/None chain_total means "no chain".
    chain_total = getattr(obs, "chain_total", None) or 1
    chain_min = max(0, chain_total - 1)
    if len(already) < chain_min:
        for inv_action in ("inspect", "verify_kyc", "request_docs",
                           "contact_sender", "file_sar"):
            if inv_action not in already:
                return inv_action
    return None
def _terminal_decision(obs) -> str:
    """
    Make a terminal decision for the current transaction.

    Reads the observation's flags, velocity_1h, kyc_status, risk_score and
    amount, plus the optional inspection_notes attribute when present.
    Rules are evaluated top-down and the first match wins.

    Returns one of "reject", "escalate", "hold", "flag" or "approve".
    """
    # Build the flag set once; a missing flag list counts as "no flags".
    flags = set(obs.flags or ())

    # Definitive danger signals are always rejected. Contact notes cannot
    # change this outcome: every danger-flagged transaction ends in "reject".
    if flags & _DANGER_FLAGS:
        return "reject"

    # Burst of activity within the last hour
    velocity = obs.velocity_1h
    if velocity is not None and velocity >= 10:
        return "reject"

    if obs.kyc_status in ("failed", "none"):
        return "escalate"

    # KYC expired → hold (wait for renewal)
    if obs.kyc_status == "expired":
        return "hold"

    # Very high risk
    if obs.risk_score >= 0.85:
        return "reject"

    # Watchlist flags with high value
    if flags & _WATCHLIST_FLAGS and obs.amount >= 20_000:
        return "escalate"

    # FX / correspondent banking settlement
    if "fx_settlement" in flags:
        return "approve"

    # After inspection revealed legitimacy (corp-level, FX); only trusted
    # below a 0.75 risk score, otherwise fall through to the risk bands.
    inspection = (getattr(obs, "inspection_notes", None) or "").lower()
    looks_legitimate = "correspondent banking" in inspection or "on file" in inspection
    if looks_legitimate and obs.risk_score < 0.75:
        return "approve"

    # Remaining cases are graded purely on risk score / presence of flags.
    if obs.risk_score >= 0.65:
        return "escalate"
    if obs.risk_score >= 0.40 or flags:
        return "flag"
    return "approve"
def _rule_based_policy(obs) -> str:
    """
    Adaptive policy: choose the next action for the given observation.

    An investigation sub-action takes precedence whenever
    _should_investigate proposes one; only when it returns None is the
    terminal decision from _terminal_decision used.
    """
    investigation = _should_investigate(obs)
    return _terminal_decision(obs) if investigation is None else investigation
| # --------------------------------------------------------------------------- | |
| # Episode runner | |
| # --------------------------------------------------------------------------- | |
async def run_baseline() -> Tuple[List[Dict[str, Any]], float, float, int]:
    """
    Run the adaptive rule-based baseline over the full task set.

    Creates a fresh PayOpsEnvironment, steps it with _rule_based_policy
    until the observation reports done, then grades the recorded actions
    with grade_episode.

    Returns:
        (per_task_rewards, total_reward, normalised_score, steps)
        where steps is the number of actions sent to the environment.
    """
    env = PayOpsEnvironment()
    actions_taken: List[str] = []
    try:
        obs = await env.reset_async()
        while not obs.done:
            action_type = _rule_based_policy(obs)
            action = PayOpsAction(
                action_type=action_type,
                transaction_id=obs.transaction_id,
            )
            obs = await env.step_async(action)
            actions_taken.append(action_type)
        # Snapshot the environment's task list before it is closed; grading
        # must use the tasks exactly as this episode presented them.
        jittered_tasks = list(env._tasks)
    finally:
        # Always release the environment, even if a step or the policy raised.
        env.close()

    # The baseline reports no confidence: one None per action taken.
    confs: List[Optional[float]] = [None] * len(actions_taken)
    result = grade_episode(actions_taken, jittered_tasks, confs)
    return (
        result.per_task_rewards,
        result.total_reward,
        result.normalised_score,
        len(actions_taken),
    )