""" AML Investigator - Baseline Inference Script Loops through all 3 tasks to satisfy the Phase 2 Validator. """ import asyncio import json import os import re import sys import textwrap from typing import Any, Dict, List, Optional, Tuple from openai import OpenAI from pydantic import ValidationError from server.AML_env_environment import AmlEnvironment from models import AmlAction, AmlObservation API_BASE_URL = os.getenv("API_BASE_URL") or "http://127.0.0.1:1234/v1" MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b") HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio" LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:7860") TASK_NAME = os.getenv("TASK_NAME", "aml_easy") TASKS = ["aml_easy", "aml_medium", "aml_hard"] BENCHMARK = "aml_investigator" MAX_STEPS = 25 HISTORY_MAX_STEPS = 3 SYSTEM_PROMPT = textwrap.dedent( """ You are a Tier 1 AML compliance investigator using a ReAct-style loop. Think privately, then return exactly one JSON object for the next action. Output format: { "thought": "Observation: ... Plan: ...", "action": { "action_type": "...", ... } } The "thought" field is your thinking pad and is required. It must include two labeled sections in order: - Observation: what evidence you see now. - Plan: the single next action and why. Keep it concise. Available actions: - {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}} - {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}} - {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}} - {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}} - For false positives, use {"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}} Rules: - Use only the alert, current observation, and recent history shown here. - get_kyc_record must use ENT ids, never ACC ids. - Return JSON only. No markdown fences. No explanation outside JSON. Example 1: {"thought":"Observation: The flagged account sent a large payment with a business-like memo. Plan: Check receiver KYC before deciding.","action":{"action_type":"get_kyc_record","entity_id":"ENT-9002"}} Example 2: {"thought":"Observation: There are multiple inbound deposits just under 10000 from different accounts. Plan: Inspect one sender's KYC to test structuring.","action":{"action_type":"get_kyc_record","entity_id":"ENT-9011"}} """ ).strip() def _extract_text_from_chat_completion(completion: object) -> str: choices = getattr(completion, "choices", None) or [] if not choices: raise ValueError("Model response has no choices") first_choice = choices[0] message = getattr(first_choice, "message", None) if message is None: raise ValueError("Model response choice has no message") content = getattr(message, "content", None) if isinstance(content, str) and content.strip(): return content.strip() if isinstance(content, list): chunks: List[str] = [] for item in content: if isinstance(item, dict): text_val = item.get("text") if isinstance(text_val, str): chunks.append(text_val) else: text_val = getattr(item, "text", None) if isinstance(text_val, str): chunks.append(text_val) merged = "".join(chunks).strip() if merged: return merged raise ValueError("Model response content is empty") def _extract_text_from_responses_api(response: object) -> str: output_text = getattr(response, "output_text", None) if isinstance(output_text, str) and output_text.strip(): return output_text.strip() output = getattr(response, "output", None) or [] chunks: List[str] = [] for item in output: content = getattr(item, "content", None) or [] for part in content: text_val = getattr(part, "text", None) if isinstance(text_val, str): chunks.append(text_val) elif isinstance(part, dict): maybe_text = part.get("text") if isinstance(maybe_text, str): chunks.append(maybe_text) merged = "".join(chunks).strip() if merged: return merged raise ValueError("Responses API output is empty") def _extract_text_from_completions_api(completion: object) -> str: choices = getattr(completion, "choices", None) or [] if not choices: raise ValueError("Completions API response has no choices") first_choice = choices[0] text_val = getattr(first_choice, "text", None) if isinstance(text_val, str) and text_val.strip(): return text_val.strip() raise ValueError("Completions API response text is empty") def _coerce_json_object(raw_text: str) -> str: text = raw_text.strip() if text.startswith("```"): text = text.replace("```json", "").replace("```", "").strip() if text.startswith("{") and text.endswith("}"): return text start = text.find("{") end = text.rfind("}") if start != -1 and end > start: return text[start : end + 1] return text def _strip_channel_wrappers(raw_text: str) -> str: """ Some OSS reasoning models emit channel tags like: <|channel|>analysis<|message|>...<|channel|>final<|message|>{...} Keep only the final/message payload before JSON parsing. """ text = raw_text.strip() if "<|channel|>" not in text: return text final_marker = "<|channel|>final<|message|>" if final_marker in text: return text.split(final_marker, 1)[1].strip() message_marker = "<|message|>" if message_marker in text: return text.split(message_marker, 1)[1].strip() return text def _extract_balanced_json_object(text: str) -> Optional[str]: start = text.find("{") if start == -1: return None depth = 0 in_string = False escape = False for idx in range(start, len(text)): ch = text[idx] if in_string: if escape: escape = False elif ch == "\\": escape = True elif ch == '"': in_string = False continue if ch == '"': in_string = True elif ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: return text[start : idx + 1] return None def _parse_action_payload(raw_text: str) -> AmlAction: cleaned_text = _strip_channel_wrappers(raw_text) candidate = _coerce_json_object(cleaned_text) parse_errors: List[str] = [] for attempt in ( candidate, _extract_balanced_json_object(cleaned_text) or "", _extract_balanced_json_object(raw_text) or "", ): if not attempt: continue try: payload = json.loads(attempt) if isinstance(payload, dict): return AmlAction.model_validate(payload) parse_errors.append("decoded JSON was not an object") continue except ValidationError as exc: parse_errors.append(f"schema: {exc.errors()[0]['msg']}") continue except Exception as exc: parse_errors.append(f"json: {exc}") details = parse_errors[-1] if parse_errors else "could not parse model output into JSON object" raise ValueError(details) def _debug_text_repr(value: Any) -> str: text = str(value) escaped = text.encode("unicode_escape", errors="backslashreplace").decode("ascii", errors="replace") return f"len={len(text)} repr={escaped!r}" def _build_model_observation(obs_dict: Dict[str, Any]) -> Dict[str, Any]: validated = AmlObservation.model_validate(obs_dict) return { "alert_details": validated.alert_details, "budget_remaining": validated.budget_remaining, "last_action": validated.last_action, "last_action_result": validated.last_action_result, "done": validated.done, "reward": validated.reward, } def _render_history(history: List[Dict[str, Any]]) -> str: if not history: return "No previous steps." entries = history[-HISTORY_MAX_STEPS:] lines = [json.dumps(item, ensure_ascii=True) for item in entries] return "\n".join(lines) if lines else "No previous steps." def _build_recovery_action_from_obs(obs_dict: dict, next_offsets: Dict[str, int]) -> dict: """Use a non-terminal fallback action when model output is malformed.""" alert = str(obs_dict.get("alert_details", "") or "") match = re.search(r"ACC-\d+", alert) if match: account_id = match.group(0) offset = next_offsets.get(account_id, 0) next_offsets[account_id] = offset + 10 return { "action": { "action_type": "query_transactions", "account_id": account_id, "limit": 10, "offset": offset, } } return { "action": { "action_type": "submit_decision", "decision": "CLEAR", "evidence_links": [], } } def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: error_val = error if error else "null" done_val = str(done).lower() print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True) def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: rewards_str = ",".join(f"{r:.2f}" for r in rewards) print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True) def log_thought(step: int, thought: Optional[object]) -> None: """Print model thought to stderr so stdout contract stays validator-safe.""" if thought is None: return if isinstance(thought, dict): compact = json.dumps(thought, ensure_ascii=True) else: compact = str(thought) compact = compact.replace("\n", " ").strip() print(f"[THOUGHT] step={step} thought={compact}", file=sys.stderr, flush=True) def get_model_message( client: OpenAI, obs_dict: dict, history: List[Dict[str, Any]], next_offsets: Dict[str, int], ) -> Tuple[str, bool]: model_obs = _build_model_observation(obs_dict) history_block = _render_history(history) user_prompt = ( f"Observation:\n{json.dumps(model_obs, ensure_ascii=True, indent=2)}\n\n" f"History:\n{history_block}\n\n" "Return exactly one JSON object with keys: thought, action. " "thought must include 'Observation:' and 'Plan:'." ) try: completion = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], temperature=0.0, max_tokens=260, response_format={"type": "json_object"}, ) return _extract_text_from_chat_completion(completion), False except Exception as chat_exc: chat_error = f"chat:{chat_exc}" try: response = client.responses.create( model=MODEL_NAME, instructions=SYSTEM_PROMPT, input=user_prompt, max_output_tokens=1000, ) return _extract_text_from_responses_api(response), False except Exception as responses_exc: responses_error = f"responses:{responses_exc}" try: completion = client.completions.create( model=MODEL_NAME, prompt=f"{SYSTEM_PROMPT}\n\n{user_prompt}", temperature=0.0, max_tokens=260, ) return _extract_text_from_completions_api(completion), False except Exception as completions_exc: completions_error = f"completions:{completions_exc}" recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets) print( ( "[DEBUG] Model request failed; using recovery action " f"({completions_error}; {chat_error}; {responses_error})" ), file=sys.stderr, flush=True, ) recovery_payload = { "thought": "Observation: Model request failed. Plan: take a safe recovery action.", "action": recovery_json["action"], } return json.dumps(recovery_payload, ensure_ascii=True), True async def main() -> None: client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) env = AmlEnvironment() for task_name in TASKS: history: List[Dict[str, Any]] = [] rewards: List[float] = [] steps_taken = 0 score = 0.0 success = False had_parse_error = False next_offsets: Dict[str, int] = {} log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME) try: obs = env.reset(task=task_name) for step in range(1, MAX_STEPS + 1): if obs.done: break obs_dict = AmlObservation.model_validate(obs.model_dump()).model_dump() action_str, used_recovery = get_model_message(client, obs_dict, history, next_offsets) if used_recovery: had_parse_error = True action_for_log = action_str action_payload_for_history: Dict[str, Any] = {} parsed_model_action = False model_thought_for_history: Optional[str] = None try: action_obj = _parse_action_payload(action_str) log_thought(step=step, thought=action_obj.thought) model_thought_for_history = action_obj.thought parsed_model_action = True action_payload_for_history = action_obj.action.model_dump(exclude={"metadata"}, exclude_none=True) action_for_log = json.dumps({"action": action_payload_for_history}, ensure_ascii=True) if action_payload_for_history.get("action_type") == "query_transactions": acc = action_payload_for_history.get("account_id") offset = int(action_payload_for_history.get("offset", 0)) limit = int(action_payload_for_history.get("limit", 10)) if isinstance(acc, str): next_offsets[acc] = max(next_offsets.get(acc, 0), offset + max(limit, 1)) error = None except Exception as e: had_parse_error = True error = f"JSON Parse/Schema Error: {str(e)}" debug_payload = _debug_text_repr(action_str) if action_str.strip() else "empty model output" print( f"[DEBUG] step={step} parse_failed_raw={debug_payload}", file=sys.stderr, flush=True, ) log_thought( step=step, thought="Observation: model output was invalid. Plan: use safe recovery action.", ) recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets) recovery_payload = { "thought": "Observation: JSON/schema parse failed. Plan: query next page safely.", "action": recovery_json["action"], } action_obj = AmlAction.model_validate(recovery_payload) action_payload_for_history = recovery_payload["action"] action_for_log = json.dumps({"action": action_payload_for_history}, ensure_ascii=True) obs = env.step(action_obj) reward = obs.reward or 0.0 done = obs.done rewards.append(reward) steps_taken = step log_step(step=step, action=action_for_log.replace("\n", ""), reward=reward, done=done, error=error) # Keep prompt context clean: only feed back model-authored, schema-valid turns. if parsed_model_action: history.append( { "step": step, "thought": model_thought_for_history, "action": action_payload_for_history, "result": obs.last_action_result, "budget_remaining": obs.budget_remaining, } ) if len(history) > HISTORY_MAX_STEPS: history = history[-HISTORY_MAX_STEPS:] if done: break if had_parse_error or obs.error_message: score = 0.05 elif "submit_decision" in (obs.last_action or ""): score = 0.75 else: score = 0.25 score = min(max(score, 0.01), 0.99) success = (not had_parse_error) and (obs.error_message is None) and score > 0.5 finally: log_end(success=success, steps=steps_taken, score=score, rewards=rewards) if __name__ == "__main__": asyncio.run(main())