File size: 6,694 Bytes
d727210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f139ff
 
 
 
 
 
 
d727210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffea7f4
 
 
 
 
 
 
 
 
d727210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128f77d
d727210
 
128f77d
d727210
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
Utilities for running the baseline agent programmatically (used by /baseline endpoint).
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

from payops_env.environment import PayOpsEnvironment
from payops_env.grader import grade_episode
from payops_env.models import PayOpsAction
from payops_env.tasks import TASKS


# ---------------------------------------------------------------------------
# Adaptive rule-based policy
# ---------------------------------------------------------------------------

_DANGER_FLAGS = {
    "sanctioned_country", "app_scam_indicator", "mule_account_pattern",
    "structuring_pattern", "ctr_threshold_avoidance", "fraud_ring_indicator",
    "geo_impossible_login", "account_takeover_indicator",
}

_WATCHLIST_FLAGS = {
    "new_account_7d", "large_first_transfer", "solicitor_mule_pattern",
    "dormant_receiver", "sudden_activity", "insider_threat", "internal_to_personal",
    "invoice_mismatch", "trade_finance",
}


def _should_investigate(obs) -> Optional[str]:
    """
    Decide whether to issue an investigation sub-action first.

    Returns the sub-action name, or None if we should decide directly.

    Priority:
      1. KYC expired / pending → verify_kyc (once)
      2. Watchlist flags AND high amount AND docs not yet requested → request_docs
      3. Low ml_confidence AND medium risk → inspect (once)
      4. contact_sender: APP scam pattern or insider threat
      5. file_sar: if structuring/fraud-ring flags and not yet filed
    """
    # The env sets both "already_used" (bool for this action) and
    # "investigation_used" (list of all inv actions used for this task)
    if isinstance(obs.info, dict):
        inv_used = obs.info.get("investigation_used", [])
        already = set(inv_used) if isinstance(inv_used, (list, set)) else set()
    else:
        already = set()

    # file_sar if structuring / fraud ring and SAR not yet filed
    sar_flags = {"structuring_pattern", "ctr_threshold_avoidance",
                 "fraud_ring_indicator", "coordinated_transfers"}
    if sar_flags & set(obs.flags) and "file_sar" not in already:
        return "file_sar"

    # contact_sender for APP scam or insider
    contact_flags = {"app_scam_indicator", "internal_to_personal", "account_takeover_indicator"}
    if contact_flags & set(obs.flags) and "contact_sender" not in already:
        return "contact_sender"

    # verify_kyc if expired or pending
    if obs.kyc_status in ("expired", "pending") and "verify_kyc" not in already:
        return "verify_kyc"

    # request_docs for first-time payees with high value
    doc_flags = {"first_time_payee", "large_first_transfer", "invoice_mismatch", "trade_finance"}
    if (doc_flags & set(obs.flags)
            and obs.amount >= 50_000
            and "request_docs" not in already):
        return "request_docs"

    # inspect when ml_confidence is low or watchlist flags are present
    ml_conf = getattr(obs, "ml_confidence", 0.9) or 0.9
    watchlist_hit = bool(_WATCHLIST_FLAGS & set(obs.flags))
    if (ml_conf < 0.60 or watchlist_hit) and "inspect" not in already:
        return "inspect"

    # Fallback for chain-gated tasks (chain_total > 1): if not enough investigation
    # steps have been done yet, issue generic actions in priority order so the
    # baseline agent never gets stuck in an infinite chain-gate loop.
    chain_min = max(0, getattr(obs, "chain_total", 1) - 1)
    if chain_min > 0 and len(already) < chain_min:
        for inv_action in ("inspect", "verify_kyc", "request_docs", "contact_sender", "file_sar"):
            if inv_action not in already:
                return inv_action

    return None


def _terminal_decision(obs) -> str:
    """
    Make a terminal decision.  Uses enriched observation fields where available
    (inspect_reveal, docs_notes, kyc_notes, contact_notes already in obs.info).
    """
    # Definitive danger signals
    if any(f in _DANGER_FLAGS for f in obs.flags):
        # Exception: if we confirmed via contact that it is genuine, approve
        contact = getattr(obs, "contact_notes", None) or ""
        if "CEO confirms they did NOT" in contact or "Classic APP scam" in contact:
            return "reject"
        return "reject"

    vel = obs.velocity_1h
    if vel is not None and vel >= 10:
        return "reject"

    if obs.kyc_status in ("failed", "none"):
        return "escalate"

    # KYC expired → hold (wait for renewal)
    if obs.kyc_status == "expired":
        return "hold"

    # Very high risk
    if obs.risk_score >= 0.85:
        return "reject"

    # Watchlist flags with high value
    if _WATCHLIST_FLAGS & set(obs.flags) and obs.amount >= 20_000:
        return "escalate"

    # FX / correspondent banking settlement
    if "fx_settlement" in obs.flags:
        return "approve"

    # After inspection revealed legitimacy (corp-level, FX)
    inspection = getattr(obs, "inspection_notes", None) or ""
    if "correspondent banking" in inspection.lower() or "on file" in inspection.lower():
        if obs.risk_score < 0.75:
            return "approve"

    if obs.risk_score >= 0.65:
        return "escalate"
    elif obs.risk_score >= 0.40 or obs.flags:
        return "flag"
    else:
        return "approve"


def _rule_based_policy(obs) -> str:
    """
    Adaptive policy: choose the next action for ``obs``.

    An investigation sub-action proposed by ``_should_investigate`` takes
    precedence; only when it proposes nothing (None) is a terminal decision
    taken via ``_terminal_decision``.
    """
    investigation = _should_investigate(obs)
    return _terminal_decision(obs) if investigation is None else investigation


# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------

async def run_baseline() -> Tuple[List[Dict[str, Any]], float, float, int]:
    """
    Run the adaptive rule-based baseline over the full task set.

    A fresh ``PayOpsEnvironment`` is reset and stepped with the action chosen
    by ``_rule_based_policy`` until the observation reports ``done``; the
    recorded actions are then graded with ``grade_episode``.  The baseline
    reports no confidence values, so the grader receives ``None`` for every
    step.

    The environment is closed even if resetting, stepping or the policy
    raises; the exception itself is propagated to the caller.

    Returns:
        (per_task_rewards, total_reward, normalised_score, steps)
    """
    env = PayOpsEnvironment()
    actions_taken: List[str] = []
    try:
        obs = await env.reset_async()
        while not obs.done:
            action_type = _rule_based_policy(obs)
            action = PayOpsAction(
                action_type=action_type,
                transaction_id=obs.transaction_id,
            )
            obs = await env.step_async(action)
            # Record the action only after the step succeeded.
            actions_taken.append(action_type)

        # Snapshot the environment's task list before it is closed.
        jittered_tasks = list(env._tasks)
    finally:
        env.close()

    # One (absent) confidence entry per action taken.
    confs: List[Optional[float]] = [None] * len(actions_taken)
    result = grade_episode(actions_taken, jittered_tasks, confs)
    return (
        result.per_task_rewards,
        result.total_reward,
        result.normalised_score,
        len(actions_taken),
    )