DataBoySu committed on
Commit
74aae3b
·
1 Parent(s): 9670629

cot improv

Browse files
Files changed (2) hide show
  1. inference.py +176 -230
  2. models.py +34 -27
inference.py CHANGED
@@ -11,12 +11,13 @@ import textwrap
11
  from typing import Any, Dict, List, Optional, Tuple
12
 
13
  from openai import OpenAI
 
14
 
15
  from server.AML_env_environment import AmlEnvironment
16
- from models import AmlAction
17
 
18
 
19
- API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
20
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
21
  HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
22
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
@@ -26,46 +27,45 @@ TASKS = ["aml_easy", "aml_medium", "aml_hard"]
26
  BENCHMARK = "aml_investigator"
27
  MAX_STEPS = 25
28
 
29
- OBS_RESULT_MAX_ITEMS = 8
30
  HISTORY_MAX_STEPS = 3
31
- HISTORY_MAX_CHARS = 1600
32
- TEXT_CLIP_CHARS = 320
33
 
34
  SYSTEM_PROMPT = textwrap.dedent(
35
  """
36
- You are a Tier 1 AML Compliance Investigator.
37
- You must investigate the provided alert by querying the bank's internal APIs.
38
 
39
- You have a strict API budget. Be efficient.
40
- Respond with EXACTLY ONE valid JSON object representing your action. Do not include markdown formatting or explanations.
41
-
42
- Available Action JSON Schemas:
43
- 1. {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}}
44
- 2. {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}}
45
- 3. {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}}
46
- 4. {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}} (Use "CLEAR" for False Positives with empty evidence_links).
47
-
48
- Required top-level JSON format:
49
  {
50
- "thought": {
51
- "observation": "...",
52
- "plan": "...",
53
- "action": "..."
54
- },
55
- "action": {...}
56
  }
57
 
58
- Thought rules:
59
- - Use caveman style: short, simple, low-token wording.
60
- - Keep thought informative but brief.
61
- - observation = what clue found now.
62
- - plan = next investigation goal.
63
- - action = exact tool call you will make now.
64
-
65
- Data rules:
66
- - get_kyc_record must use ENT-XXXX only, never ACC-XXXX.
67
- - submit_decision only when evidence is enough; else keep investigating.
68
- - Use only the alert, the current observation, and the recent history shown here.
 
 
 
 
 
 
 
 
 
 
 
 
69
  """
70
  ).strip()
71
 
@@ -156,80 +156,100 @@ def _coerce_json_object(raw_text: str) -> str:
156
  return text
157
 
158
 
159
- def _clip_text(value: Any, max_chars: int = TEXT_CLIP_CHARS) -> str:
160
- text = str(value).replace("\n", " ").strip()
161
- if len(text) <= max_chars:
 
 
 
 
 
162
  return text
163
- return text[: max_chars - 3] + "..."
164
-
165
-
166
- def _compact_record(record: Dict[str, Any]) -> Dict[str, Any]:
167
- keep_keys = [
168
- "txn_id",
169
- "timestamp",
170
- "sender_account",
171
- "receiver_account",
172
- "amount",
173
- "memo_text",
174
- "account_id",
175
- "owner_entity_id",
176
- "status",
177
- "entity_id",
178
- "name",
179
- "type",
180
- "registration_address",
181
- "directors",
182
- ]
183
- compact: Dict[str, Any] = {}
184
- for key in keep_keys:
185
- if key not in record:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  continue
187
- value = record.get(key)
188
- if key == "directors" and isinstance(value, list):
189
- compact[key] = value[:4]
190
- if len(value) > 4:
191
- compact["directors_truncated"] = len(value) - 4
192
  continue
193
- if isinstance(value, str):
194
- compact[key] = _clip_text(value, max_chars=180)
195
- else:
196
- compact[key] = value
197
- return compact
198
 
 
 
199
 
200
- def _compact_action_result(last_action: Optional[str], value: Any) -> Any:
201
- if value is None:
202
- return None
203
- if isinstance(value, list):
204
- items = []
205
- for item in value[:OBS_RESULT_MAX_ITEMS]:
206
- if isinstance(item, dict):
207
- items.append(_compact_record(item))
208
- else:
209
- items.append(_clip_text(item))
210
- return {
211
- "kind": "list",
212
- "count": len(value),
213
- "items": items,
214
- "truncated": len(value) > OBS_RESULT_MAX_ITEMS,
215
- "source_action": last_action,
216
- }
217
- if isinstance(value, dict):
218
- return _compact_record(value)
219
- if isinstance(value, str):
220
- return _clip_text(value, max_chars=420)
221
- return value
222
 
223
 
224
  def _build_model_observation(obs_dict: Dict[str, Any]) -> Dict[str, Any]:
 
225
  return {
226
- "alert_details": obs_dict.get("alert_details"),
227
- "budget_remaining": obs_dict.get("budget_remaining"),
228
- "last_action": obs_dict.get("last_action"),
229
- "last_action_result": _compact_action_result(obs_dict.get("last_action"), obs_dict.get("last_action_result")),
230
- "error_message": _clip_text(obs_dict.get("error_message")) if obs_dict.get("error_message") else None,
231
- "done": obs_dict.get("done"),
232
- "reward": obs_dict.get("reward"),
233
  }
234
 
235
 
@@ -237,9 +257,7 @@ def _render_history(history: List[Dict[str, Any]]) -> str:
237
  if not history:
238
  return "No previous steps."
239
  entries = history[-HISTORY_MAX_STEPS:]
240
- lines = [json.dumps(item, ensure_ascii=True, separators=(",", ":")) for item in entries]
241
- while lines and len("\n".join(lines)) > HISTORY_MAX_CHARS:
242
- lines.pop(0)
243
  return "\n".join(lines) if lines else "No previous steps."
244
 
245
 
@@ -268,49 +286,6 @@ def _build_recovery_action_from_obs(obs_dict: dict, next_offsets: Dict[str, int]
268
  }
269
 
270
 
271
- def _normalize_thought(payload: Dict[str, Any]) -> None:
272
- action = payload.get("action") if isinstance(payload.get("action"), dict) else {}
273
- action_type = action.get("action_type", "unknown")
274
- if "thought" not in payload or not isinstance(payload.get("thought"), dict):
275
- payload["thought"] = {
276
- "observation": "see current clue now.",
277
- "plan": "find next real link.",
278
- "action": f"do {action_type} now.",
279
- }
280
- return
281
-
282
- thought = payload["thought"]
283
- for key, fallback in (
284
- ("observation", "see clue now."),
285
- ("plan", "next check key link."),
286
- ("action", f"do {action_type} now."),
287
- ):
288
- value = thought.get(key)
289
- if not isinstance(value, str) or not value.strip():
290
- thought[key] = fallback
291
- else:
292
- thought[key] = _clip_text(value, max_chars=140)
293
-
294
-
295
- def _try_validate_action_json(raw_text: str) -> Optional[str]:
296
- """Return canonical JSON string if valid, else None."""
297
- candidate = _coerce_json_object(raw_text)
298
- try:
299
- payload = json.loads(candidate)
300
- if not isinstance(payload, dict):
301
- raise ValueError("top-level JSON is not an object")
302
- action = payload.get("action")
303
- if not isinstance(action, dict):
304
- raise ValueError("missing 'action' object")
305
- action_type = action.get("action_type")
306
- if not isinstance(action_type, str):
307
- raise ValueError("missing 'action_type' string")
308
- _normalize_thought(payload)
309
- return json.dumps(payload, ensure_ascii=True)
310
- except Exception:
311
- return None
312
-
313
-
314
  def log_start(task: str, env: str, model: str) -> None:
315
  print(f"[START] task={task} env={env} model={model}", flush=True)
316
 
@@ -349,25 +324,9 @@ def get_model_message(
349
  user_prompt = (
350
  f"Observation:\n{json.dumps(model_obs, ensure_ascii=True, indent=2)}\n\n"
351
  f"History:\n{history_block}\n\n"
352
- "Return exactly one JSON object with keys: thought, action."
 
353
  )
354
- parse_errors: List[str] = []
355
-
356
- try:
357
- response = client.responses.create(
358
- model=MODEL_NAME,
359
- instructions=SYSTEM_PROMPT,
360
- input=user_prompt,
361
- max_output_tokens=700,
362
- )
363
- raw_text = _extract_text_from_responses_api(response)
364
- canonical = _try_validate_action_json(raw_text)
365
- if canonical is not None:
366
- return canonical, False
367
- parse_errors.append("responses:invalid_json")
368
- except Exception as responses_exc:
369
- parse_errors.append(f"responses:{responses_exc}")
370
-
371
  try:
372
  completion = client.chat.completions.create(
373
  model=MODEL_NAME,
@@ -376,46 +335,46 @@ def get_model_message(
376
  {"role": "user", "content": user_prompt},
377
  ],
378
  temperature=0.0,
379
- max_tokens=700,
 
380
  )
381
- raw_text = _extract_text_from_chat_completion(completion)
382
- canonical = _try_validate_action_json(raw_text)
383
- if canonical is not None:
384
- return canonical, False
385
- parse_errors.append("chat:invalid_json")
386
  except Exception as chat_exc:
387
- parse_errors.append(f"chat:{chat_exc}")
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  try:
390
  completion = client.completions.create(
391
  model=MODEL_NAME,
392
  prompt=f"{SYSTEM_PROMPT}\n\n{user_prompt}",
393
  temperature=0.0,
394
- max_tokens=280,
395
  )
396
- raw_text = _extract_text_from_completions_api(completion)
397
- canonical = _try_validate_action_json(raw_text)
398
- if canonical is not None:
399
- return canonical, False
400
- parse_errors.append("completions:invalid_json")
401
  except Exception as completions_exc:
402
- parse_errors.append(f"completions:{completions_exc}")
403
 
404
  recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets)
405
  print(
406
  (
407
- "[DEBUG] Non-JSON/invalid model action; using recovery action "
408
- f"({'; '.join(parse_errors)})"
409
  ),
410
  file=sys.stderr,
411
  flush=True,
412
  )
413
  recovery_payload = {
414
- "thought": {
415
- "observation": "model output bad json.",
416
- "plan": "use safe step. keep investigate.",
417
- "action": "query alert account next page.",
418
- },
419
  "action": recovery_json["action"],
420
  }
421
  return json.dumps(recovery_payload, ensure_ascii=True), True
@@ -433,7 +392,6 @@ async def main() -> None:
433
  success = False
434
  had_parse_error = False
435
  next_offsets: Dict[str, int] = {}
436
- query_seen_counts: Dict[Tuple[str, int], int] = {}
437
 
438
  log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
439
 
@@ -444,60 +402,46 @@ async def main() -> None:
444
  if obs.done:
445
  break
446
 
447
- obs_dict = obs.model_dump()
448
  action_str, used_recovery = get_model_message(client, obs_dict, history, next_offsets)
449
  if used_recovery:
450
  had_parse_error = True
451
 
452
  action_for_log = action_str
453
  action_payload_for_history: Dict[str, Any] = {}
 
 
454
  try:
455
- clean_str = _coerce_json_object(action_str)
456
- action_json = json.loads(clean_str)
457
- thought_for_log = action_json.get("thought")
458
- if thought_for_log is None:
459
- action_type = action_json.get("action", {}).get("action_type", "unknown")
460
- thought_for_log = f"do {action_type} now"
461
- log_thought(step=step, thought=thought_for_log)
462
- action_obj = AmlAction.model_validate(action_json)
463
-
464
- action_payload_for_history = action_json.get("action", {}) if isinstance(action_json, dict) else {}
465
  action_for_log = json.dumps({"action": action_payload_for_history}, ensure_ascii=True)
466
  if action_payload_for_history.get("action_type") == "query_transactions":
467
  acc = action_payload_for_history.get("account_id")
468
  offset = int(action_payload_for_history.get("offset", 0))
469
  limit = int(action_payload_for_history.get("limit", 10))
470
  if isinstance(acc, str):
471
- query_key = (acc, offset)
472
- query_seen_counts[query_key] = query_seen_counts.get(query_key, 0) + 1
473
- # Hard guardrail: avoid wasting budget on repeated same page.
474
- if task_name == "aml_hard" and query_seen_counts[query_key] > 2:
475
- new_offset = max(next_offsets.get(acc, offset + max(limit, 1)), offset + max(limit, 1))
476
- action_json["action"]["offset"] = new_offset
477
- action_json["thought"]["plan"] = _clip_text(
478
- f"repeat page seen. move to next offset {new_offset}.",
479
- max_chars=120,
480
- )
481
- action_json["thought"]["action"] = _clip_text(
482
- f"query_transactions {acc} offset {new_offset}",
483
- max_chars=120,
484
- )
485
- action_for_log = json.dumps(action_json, ensure_ascii=True)
486
- action_obj = AmlAction.model_validate(action_json)
487
- offset = new_offset
488
  next_offsets[acc] = max(next_offsets.get(acc, 0), offset + max(limit, 1))
489
  error = None
490
  except Exception as e:
491
  had_parse_error = True
492
  error = f"JSON Parse/Schema Error: {str(e)}"
493
- log_thought(step=step, thought="parse fail; use recovery action")
 
 
 
 
 
 
 
 
 
494
  recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets)
495
  recovery_payload = {
496
- "thought": {
497
- "observation": "parse fail now.",
498
- "plan": "safe step, keep digging.",
499
- "action": "query alert next page.",
500
- },
501
  "action": recovery_json["action"],
502
  }
503
  action_obj = AmlAction.model_validate(recovery_payload)
@@ -513,17 +457,19 @@ async def main() -> None:
513
  steps_taken = step
514
 
515
  log_step(step=step, action=action_for_log.replace("\n", ""), reward=reward, done=done, error=error)
516
- history.append(
517
- {
518
- "step": step,
519
- "action": action_payload_for_history,
520
- "result": _compact_action_result(obs.last_action, obs.last_action_result),
521
- "error": _clip_text(obs.error_message) if obs.error_message else None,
522
- "budget_remaining": obs.budget_remaining,
523
- }
524
- )
525
- if len(history) > 24:
526
- history = history[-24:]
 
 
527
 
528
  if done:
529
  break
 
11
  from typing import Any, Dict, List, Optional, Tuple
12
 
13
  from openai import OpenAI
14
+ from pydantic import ValidationError
15
 
16
  from server.AML_env_environment import AmlEnvironment
17
+ from models import AmlAction, AmlObservation
18
 
19
 
20
+ API_BASE_URL = os.getenv("API_BASE_URL") or "http://127.0.0.1:1234/v1"
21
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
22
  HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
23
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
 
27
  BENCHMARK = "aml_investigator"
28
  MAX_STEPS = 25
29
 
 
30
  HISTORY_MAX_STEPS = 3
 
 
31
 
32
  SYSTEM_PROMPT = textwrap.dedent(
33
  """
34
+ You are a Tier 1 AML compliance investigator using a ReAct-style loop.
35
+ Think privately, then return exactly one JSON object for the next action.
36
 
37
+ Output format:
 
 
 
 
 
 
 
 
 
38
  {
39
+ "thought": "Observation: ... Plan: ...",
40
+ "action": {
41
+ "action_type": "...",
42
+ ...
43
+ }
 
44
  }
45
 
46
+ The "thought" field is your thinking pad and is required.
47
+ It must include two labeled sections in order:
48
+ - Observation: what evidence you see now.
49
+ - Plan: the single next action and why.
50
+ Keep it concise.
51
+
52
+ Available actions:
53
+ - {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}}
54
+ - {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}}
55
+ - {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}}
56
+ - {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}}
57
+ - For false positives, use {"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}}
58
+
59
+ Rules:
60
+ - Use only the alert, current observation, and recent history shown here.
61
+ - get_kyc_record must use ENT ids, never ACC ids.
62
+ - Return JSON only. No markdown fences. No explanation outside JSON.
63
+
64
+ Example 1:
65
+ {"thought":"Observation: The flagged account sent a large payment with a business-like memo. Plan: Check receiver KYC before deciding.","action":{"action_type":"get_kyc_record","entity_id":"ENT-9002"}}
66
+
67
+ Example 2:
68
+ {"thought":"Observation: There are multiple inbound deposits just under 10000 from different accounts. Plan: Inspect one sender's KYC to test structuring.","action":{"action_type":"get_kyc_record","entity_id":"ENT-9011"}}
69
  """
70
  ).strip()
71
 
 
156
  return text
157
 
158
 
159
+ def _strip_channel_wrappers(raw_text: str) -> str:
160
+ """
161
+ Some OSS reasoning models emit channel tags like:
162
+ <|channel|>analysis<|message|>...<|channel|>final<|message|>{...}
163
+ Keep only the final/message payload before JSON parsing.
164
+ """
165
+ text = raw_text.strip()
166
+ if "<|channel|>" not in text:
167
  return text
168
+
169
+ final_marker = "<|channel|>final<|message|>"
170
+ if final_marker in text:
171
+ return text.split(final_marker, 1)[1].strip()
172
+
173
+ message_marker = "<|message|>"
174
+ if message_marker in text:
175
+ return text.split(message_marker, 1)[1].strip()
176
+
177
+ return text
178
+
179
+
180
+ def _extract_balanced_json_object(text: str) -> Optional[str]:
181
+ start = text.find("{")
182
+ if start == -1:
183
+ return None
184
+
185
+ depth = 0
186
+ in_string = False
187
+ escape = False
188
+ for idx in range(start, len(text)):
189
+ ch = text[idx]
190
+ if in_string:
191
+ if escape:
192
+ escape = False
193
+ elif ch == "\\":
194
+ escape = True
195
+ elif ch == '"':
196
+ in_string = False
197
+ continue
198
+
199
+ if ch == '"':
200
+ in_string = True
201
+ elif ch == "{":
202
+ depth += 1
203
+ elif ch == "}":
204
+ depth -= 1
205
+ if depth == 0:
206
+ return text[start : idx + 1]
207
+ return None
208
+
209
+
210
def _parse_action_payload(raw_text: str) -> AmlAction:
    """Parse raw model output into a validated ``AmlAction``.

    Candidate JSON strings are tried in order: the coerced form of the
    channel-stripped text, then the first JSON object found in the
    channel-stripped text, then the first JSON object found in the raw text.

    Args:
        raw_text: Raw text returned by the model.

    Returns:
        The first candidate that decodes to a JSON object and validates
        against ``AmlAction``.

    Raises:
        ValueError: If no candidate validates. The message is the first
            schema error when one occurred, otherwise the last decode error,
            otherwise a generic description.
    """
    cleaned_text = _strip_channel_wrappers(raw_text)
    candidates = (
        _coerce_json_object(cleaned_text),
        _extract_balanced_json_object(cleaned_text),
        _extract_balanced_json_object(raw_text),
    )

    schema_errors: List[str] = []
    other_errors: List[str] = []
    seen = set()
    for attempt in candidates:
        # The extraction strategies often agree; validating the same string
        # again would only repeat the same failure.
        if not attempt or attempt in seen:
            continue
        seen.add(attempt)
        try:
            payload = json.loads(attempt)
            if isinstance(payload, dict):
                return AmlAction.model_validate(payload)
            other_errors.append("decoded JSON was not an object")
        except ValidationError as exc:
            schema_errors.append(f"schema: {exc.errors()[0]['msg']}")
        except Exception as exc:
            other_errors.append(f"json: {exc}")

    # A schema error means a candidate was well-formed JSON, so it describes
    # the real problem better than a decode failure from a later fallback.
    if schema_errors:
        details = schema_errors[0]
    elif other_errors:
        details = other_errors[-1]
    else:
        details = "could not parse model output into JSON object"
    raise ValueError(details)
236
 
237
+
238
+ def _debug_text_repr(value: Any) -> str:
239
+ text = str(value)
240
+ escaped = text.encode("unicode_escape", errors="backslashreplace").decode("ascii", errors="replace")
241
+ return f"len={len(text)} repr={escaped!r}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
 
244
  def _build_model_observation(obs_dict: Dict[str, Any]) -> Dict[str, Any]:
245
+ validated = AmlObservation.model_validate(obs_dict)
246
  return {
247
+ "alert_details": validated.alert_details,
248
+ "budget_remaining": validated.budget_remaining,
249
+ "last_action": validated.last_action,
250
+ "last_action_result": validated.last_action_result,
251
+ "done": validated.done,
252
+ "reward": validated.reward,
 
253
  }
254
 
255
 
 
257
  if not history:
258
  return "No previous steps."
259
  entries = history[-HISTORY_MAX_STEPS:]
260
+ lines = [json.dumps(item, ensure_ascii=True) for item in entries]
 
 
261
  return "\n".join(lines) if lines else "No previous steps."
262
 
263
 
 
286
  }
287
 
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  def log_start(task: str, env: str, model: str) -> None:
290
  print(f"[START] task={task} env={env} model={model}", flush=True)
291
 
 
324
  user_prompt = (
325
  f"Observation:\n{json.dumps(model_obs, ensure_ascii=True, indent=2)}\n\n"
326
  f"History:\n{history_block}\n\n"
327
+ "Return exactly one JSON object with keys: thought, action. "
328
+ "thought must include 'Observation:' and 'Plan:'."
329
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  try:
331
  completion = client.chat.completions.create(
332
  model=MODEL_NAME,
 
335
  {"role": "user", "content": user_prompt},
336
  ],
337
  temperature=0.0,
338
+ max_tokens=260,
339
+ response_format={"type": "json_object"},
340
  )
341
+ return _extract_text_from_chat_completion(completion), False
 
 
 
 
342
  except Exception as chat_exc:
343
+ chat_error = f"chat:{chat_exc}"
344
+
345
+ try:
346
+ response = client.responses.create(
347
+ model=MODEL_NAME,
348
+ instructions=SYSTEM_PROMPT,
349
+ input=user_prompt,
350
+ max_output_tokens=1000,
351
+ )
352
+ return _extract_text_from_responses_api(response), False
353
+ except Exception as responses_exc:
354
+ responses_error = f"responses:{responses_exc}"
355
 
356
  try:
357
  completion = client.completions.create(
358
  model=MODEL_NAME,
359
  prompt=f"{SYSTEM_PROMPT}\n\n{user_prompt}",
360
  temperature=0.0,
361
+ max_tokens=260,
362
  )
363
+ return _extract_text_from_completions_api(completion), False
 
 
 
 
364
  except Exception as completions_exc:
365
+ completions_error = f"completions:{completions_exc}"
366
 
367
  recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets)
368
  print(
369
  (
370
+ "[DEBUG] Model request failed; using recovery action "
371
+ f"({completions_error}; {chat_error}; {responses_error})"
372
  ),
373
  file=sys.stderr,
374
  flush=True,
375
  )
376
  recovery_payload = {
377
+ "thought": "Observation: Model request failed. Plan: take a safe recovery action.",
 
 
 
 
378
  "action": recovery_json["action"],
379
  }
380
  return json.dumps(recovery_payload, ensure_ascii=True), True
 
392
  success = False
393
  had_parse_error = False
394
  next_offsets: Dict[str, int] = {}
 
395
 
396
  log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
397
 
 
402
  if obs.done:
403
  break
404
 
405
+ obs_dict = AmlObservation.model_validate(obs.model_dump()).model_dump()
406
  action_str, used_recovery = get_model_message(client, obs_dict, history, next_offsets)
407
  if used_recovery:
408
  had_parse_error = True
409
 
410
  action_for_log = action_str
411
  action_payload_for_history: Dict[str, Any] = {}
412
+ parsed_model_action = False
413
+ model_thought_for_history: Optional[str] = None
414
  try:
415
+ action_obj = _parse_action_payload(action_str)
416
+ log_thought(step=step, thought=action_obj.thought)
417
+ model_thought_for_history = action_obj.thought
418
+ parsed_model_action = True
419
+
420
+ action_payload_for_history = action_obj.action.model_dump(exclude={"metadata"}, exclude_none=True)
 
 
 
 
421
  action_for_log = json.dumps({"action": action_payload_for_history}, ensure_ascii=True)
422
  if action_payload_for_history.get("action_type") == "query_transactions":
423
  acc = action_payload_for_history.get("account_id")
424
  offset = int(action_payload_for_history.get("offset", 0))
425
  limit = int(action_payload_for_history.get("limit", 10))
426
  if isinstance(acc, str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  next_offsets[acc] = max(next_offsets.get(acc, 0), offset + max(limit, 1))
428
  error = None
429
  except Exception as e:
430
  had_parse_error = True
431
  error = f"JSON Parse/Schema Error: {str(e)}"
432
+ debug_payload = _debug_text_repr(action_str) if action_str.strip() else "empty model output"
433
+ print(
434
+ f"[DEBUG] step={step} parse_failed_raw={debug_payload}",
435
+ file=sys.stderr,
436
+ flush=True,
437
+ )
438
+ log_thought(
439
+ step=step,
440
+ thought="Observation: model output was invalid. Plan: use safe recovery action.",
441
+ )
442
  recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets)
443
  recovery_payload = {
444
+ "thought": "Observation: JSON/schema parse failed. Plan: query next page safely.",
 
 
 
 
445
  "action": recovery_json["action"],
446
  }
447
  action_obj = AmlAction.model_validate(recovery_payload)
 
457
  steps_taken = step
458
 
459
  log_step(step=step, action=action_for_log.replace("\n", ""), reward=reward, done=done, error=error)
460
+ # Keep prompt context clean: only feed back model-authored, schema-valid turns.
461
+ if parsed_model_action:
462
+ history.append(
463
+ {
464
+ "step": step,
465
+ "thought": model_thought_for_history,
466
+ "action": action_payload_for_history,
467
+ "result": obs.last_action_result,
468
+ "budget_remaining": obs.budget_remaining,
469
+ }
470
+ )
471
+ if len(history) > HISTORY_MAX_STEPS:
472
+ history = history[-HISTORY_MAX_STEPS:]
473
 
474
  if done:
475
  break
models.py CHANGED
@@ -11,13 +11,15 @@ The AML_env environment is a simple test environment that echoes back messages.
11
  """
12
 
13
  from openenv.core.env_server.types import Action, Observation
14
- from pydantic import BaseModel, Field
15
  from typing import List, Literal, Optional, Any, Union
16
 
17
  # ==========================================
18
  # OBSERVATION SPACE
19
  # ==========================================
20
  class AmlObservation(Observation):
 
 
21
  alert_details: str = Field(description="The constant mission objective and initial alert.")
22
  budget_remaining: int = Field(description="API calls remaining.")
23
  last_action: Optional[str] = Field(default=None, description="Last tool used.")
@@ -28,48 +30,53 @@ class AmlObservation(Observation):
28
  # ACTION SPACE
29
  # ==========================================
30
  class QueryTransactions(Action):
 
 
31
  action_type: Literal["query_transactions"]
32
- account_id: str = Field(description="The exact ACC-XXXX ID to query.")
33
- limit: int = Field(default=10, description="Max transactions to return.")
34
- offset: int = Field(default=0, description="Offset for pagination.")
35
 
36
  class SearchTransactions(Action):
 
 
37
  action_type: Literal["search_transactions"]
38
- account_id: str = Field(description="The exact ACC-XXXX ID to query.")
39
- keyword: str = Field(description="Keyword to search in memo_text.")
40
 
41
  class GetKYCRecord(Action):
 
 
42
  action_type: Literal["get_kyc_record"]
43
- entity_id: str = Field(description="The exact ENT-XXXX ID to look up.")
44
 
45
  class SubmitDecision(Action):
 
 
46
  action_type: Literal["submit_decision"]
47
  decision: Literal["FRAUD", "CLEAR"] = Field(description="Your final verdict.")
48
- evidence_links: List[str] = Field(description="List of ACC-XXXX or ENT-XXXX IDs proving fraud.")
49
-
50
-
51
- # ==========================================
52
- # OPTIONAL THOUGHT SCRATCHPAD
53
- # ==========================================
54
- class ThoughtProcess(BaseModel):
55
- observation: str = Field(
56
- description="Analyze what just happened and summarize useful clues from the last tool output."
57
- )
58
- plan: str = Field(
59
- description="State the next investigation step and why it follows from the current evidence."
60
- )
61
- action: str = Field(
62
- description="Explain which tool call you are about to make and with which key parameters."
63
  )
64
 
65
  # The master Action model using Union
66
  class AmlAction(Action):
67
- # Keep this optional so existing inference JSON remains compatible.
68
- thought: Optional[ThoughtProcess] = Field(
69
- default=None,
70
- description="Optional ReAct-style scratchpad for model reasoning.",
 
71
  )
72
  action: Union[QueryTransactions, SearchTransactions, GetKYCRecord, SubmitDecision] = Field(
73
  discriminator='action_type'
74
  )
75
-
 
 
 
 
 
 
 
 
 
11
  """
12
 
13
  from openenv.core.env_server.types import Action, Observation
14
+ from pydantic import ConfigDict, Field, field_validator
15
  from typing import List, Literal, Optional, Any, Union
16
 
17
  # ==========================================
18
  # OBSERVATION SPACE
19
  # ==========================================
20
  class AmlObservation(Observation):
21
+ model_config = ConfigDict(extra="forbid", strict=True)
22
+
23
  alert_details: str = Field(description="The constant mission objective and initial alert.")
24
  budget_remaining: int = Field(description="API calls remaining.")
25
  last_action: Optional[str] = Field(default=None, description="Last tool used.")
 
30
  # ACTION SPACE
31
  # ==========================================
32
  class QueryTransactions(Action):
33
+ model_config = ConfigDict(extra="forbid", strict=True)
34
+
35
  action_type: Literal["query_transactions"]
36
+ account_id: str = Field(pattern=r"^ACC-\d{4}$", description="The exact ACC-XXXX ID to query.")
37
+ limit: int = Field(default=10, ge=1, le=100, description="Max transactions to return.")
38
+ offset: int = Field(default=0, ge=0, description="Offset for pagination.")
39
 
40
  class SearchTransactions(Action):
41
+ model_config = ConfigDict(extra="forbid", strict=True)
42
+
43
  action_type: Literal["search_transactions"]
44
+ account_id: str = Field(pattern=r"^ACC-\d{4}$", description="The exact ACC-XXXX ID to query.")
45
+ keyword: str = Field(min_length=1, description="Keyword to search in memo_text.")
46
 
47
  class GetKYCRecord(Action):
48
+ model_config = ConfigDict(extra="forbid", strict=True)
49
+
50
  action_type: Literal["get_kyc_record"]
51
+ entity_id: str = Field(pattern=r"^ENT-\d{4}$", description="The exact ENT-XXXX ID to look up.")
52
 
53
  class SubmitDecision(Action):
54
+ model_config = ConfigDict(extra="forbid", strict=True)
55
+
56
  action_type: Literal["submit_decision"]
57
  decision: Literal["FRAUD", "CLEAR"] = Field(description="Your final verdict.")
58
+ evidence_links: List[str] = Field(
59
+ default_factory=list,
60
+ description="List of ACC-XXXX or ENT-XXXX IDs proving fraud.",
 
 
 
 
 
 
 
 
 
 
 
 
61
  )
62
 
63
  # The master Action model using Union
64
  class AmlAction(Action):
65
+ model_config = ConfigDict(extra="forbid", strict=True)
66
+
67
+ thought: str = Field(
68
+ min_length=1,
69
+ description="Short thinking pad with Observation: and Plan: sections.",
70
  )
71
  action: Union[QueryTransactions, SearchTransactions, GetKYCRecord, SubmitDecision] = Field(
72
  discriminator='action_type'
73
  )
74
+
75
+ @field_validator("thought")
76
+ @classmethod
77
+ def thought_must_include_sections(cls, value: str) -> str:
78
+ text = value.strip()
79
+ lower_text = text.lower()
80
+ if "observation:" not in lower_text or "plan:" not in lower_text:
81
+ raise ValueError("thought must include 'Observation:' and 'Plan:' sections")
82
+ return text