# payops_env / scripts_util.py
# padmapriyagosakan's picture
# Iteration_4: expand task bank 20→30, add variants for all tasks, fix baseline policy for chain-gated CRIT tasks
# ffea7f4
"""
Utilities for running the baseline agent programmatically (used by /baseline endpoint).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from payops_env.environment import PayOpsEnvironment
from payops_env.grader import grade_episode
from payops_env.models import PayOpsAction
from payops_env.tasks import TASKS
# ---------------------------------------------------------------------------
# Adaptive rule-based policy
# ---------------------------------------------------------------------------
_DANGER_FLAGS = {
"sanctioned_country", "app_scam_indicator", "mule_account_pattern",
"structuring_pattern", "ctr_threshold_avoidance", "fraud_ring_indicator",
"geo_impossible_login", "account_takeover_indicator",
}
_WATCHLIST_FLAGS = {
"new_account_7d", "large_first_transfer", "solicitor_mule_pattern",
"dormant_receiver", "sudden_activity", "insider_threat", "internal_to_personal",
"invoice_mismatch", "trade_finance",
}
def _should_investigate(obs) -> Optional[str]:
"""
Decide whether to issue an investigation sub-action first.
Returns the sub-action name, or None if we should decide directly.
Priority:
1. KYC expired / pending → verify_kyc (once)
2. Watchlist flags AND high amount AND docs not yet requested → request_docs
3. Low ml_confidence AND medium risk → inspect (once)
4. contact_sender: APP scam pattern or insider threat
5. file_sar: if structuring/fraud-ring flags and not yet filed
"""
# The env sets both "already_used" (bool for this action) and
# "investigation_used" (list of all inv actions used for this task)
if isinstance(obs.info, dict):
inv_used = obs.info.get("investigation_used", [])
already = set(inv_used) if isinstance(inv_used, (list, set)) else set()
else:
already = set()
# file_sar if structuring / fraud ring and SAR not yet filed
sar_flags = {"structuring_pattern", "ctr_threshold_avoidance",
"fraud_ring_indicator", "coordinated_transfers"}
if sar_flags & set(obs.flags) and "file_sar" not in already:
return "file_sar"
# contact_sender for APP scam or insider
contact_flags = {"app_scam_indicator", "internal_to_personal", "account_takeover_indicator"}
if contact_flags & set(obs.flags) and "contact_sender" not in already:
return "contact_sender"
# verify_kyc if expired or pending
if obs.kyc_status in ("expired", "pending") and "verify_kyc" not in already:
return "verify_kyc"
# request_docs for first-time payees with high value
doc_flags = {"first_time_payee", "large_first_transfer", "invoice_mismatch", "trade_finance"}
if (doc_flags & set(obs.flags)
and obs.amount >= 50_000
and "request_docs" not in already):
return "request_docs"
# inspect when ml_confidence is low or watchlist flags are present
ml_conf = getattr(obs, "ml_confidence", 0.9) or 0.9
watchlist_hit = bool(_WATCHLIST_FLAGS & set(obs.flags))
if (ml_conf < 0.60 or watchlist_hit) and "inspect" not in already:
return "inspect"
# Fallback for chain-gated tasks (chain_total > 1): if not enough investigation
# steps have been done yet, issue generic actions in priority order so the
# baseline agent never gets stuck in an infinite chain-gate loop.
chain_min = max(0, getattr(obs, "chain_total", 1) - 1)
if chain_min > 0 and len(already) < chain_min:
for inv_action in ("inspect", "verify_kyc", "request_docs", "contact_sender", "file_sar"):
if inv_action not in already:
return inv_action
return None
def _terminal_decision(obs) -> str:
"""
Make a terminal decision. Uses enriched observation fields where available
(inspect_reveal, docs_notes, kyc_notes, contact_notes already in obs.info).
"""
# Definitive danger signals
if any(f in _DANGER_FLAGS for f in obs.flags):
# Exception: if we confirmed via contact that it is genuine, approve
contact = getattr(obs, "contact_notes", None) or ""
if "CEO confirms they did NOT" in contact or "Classic APP scam" in contact:
return "reject"
return "reject"
vel = obs.velocity_1h
if vel is not None and vel >= 10:
return "reject"
if obs.kyc_status in ("failed", "none"):
return "escalate"
# KYC expired → hold (wait for renewal)
if obs.kyc_status == "expired":
return "hold"
# Very high risk
if obs.risk_score >= 0.85:
return "reject"
# Watchlist flags with high value
if _WATCHLIST_FLAGS & set(obs.flags) and obs.amount >= 20_000:
return "escalate"
# FX / correspondent banking settlement
if "fx_settlement" in obs.flags:
return "approve"
# After inspection revealed legitimacy (corp-level, FX)
inspection = getattr(obs, "inspection_notes", None) or ""
if "correspondent banking" in inspection.lower() or "on file" in inspection.lower():
if obs.risk_score < 0.75:
return "approve"
if obs.risk_score >= 0.65:
return "escalate"
elif obs.risk_score >= 0.40 or obs.flags:
return "flag"
else:
return "approve"
def _rule_based_policy(obs) -> str:
    """
    Pick the next action name for the given observation.

    An investigation sub-action proposed by `_should_investigate` takes
    precedence; `_terminal_decision` is consulted only when no investigation
    is proposed (i.e. `_should_investigate` returned None).
    """
    investigation = _should_investigate(obs)
    return _terminal_decision(obs) if investigation is None else investigation
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
async def run_baseline() -> Tuple[List[Dict[str, Any]], float, float, int]:
    """
    Run the adaptive rule-based baseline for one episode and grade it.

    A fresh PayOpsEnvironment is reset and then stepped with the action chosen
    by `_rule_based_policy` until the observation reports ``done``. The
    environment is always closed, even if resetting or stepping raises.

    Returns:
        (per_task_rewards, total_reward, normalised_score, steps), where the
        first three come from the `grade_episode` result and ``steps`` is the
        number of actions issued.
    """
    env = PayOpsEnvironment()
    actions_taken: List[str] = []
    try:
        obs = await env.reset_async()
        while not obs.done:
            action_type = _rule_based_policy(obs)
            action = PayOpsAction(
                action_type=action_type,
                transaction_id=obs.transaction_id,
            )
            obs = await env.step_async(action)
            actions_taken.append(action_type)
        # Snapshot the environment's task list before it is closed.
        # NOTE(review): presumably the per-episode (jittered) task variants
        # the grader must score against — confirm against PayOpsEnvironment.
        jittered_tasks = list(env._tasks)
    finally:
        # Release the environment on every path, not only on success.
        env.close()

    # The baseline reports no confidence values: one None per action taken.
    confs: List[Optional[float]] = [None] * len(actions_taken)
    result = grade_episode(actions_taken, jittered_tasks, confs)
    return (
        result.per_task_rewards,
        result.total_reward,
        result.normalised_score,
        len(actions_taken),
    )