File size: 3,369 Bytes
c0db7bb
 
 
 
 
 
 
 
 
 
 
 
 
74aae3b
81e1efb
c0db7bb
81e1efb
 
 
c0db7bb
74aae3b
 
81e1efb
 
 
 
 
 
 
 
 
 
74aae3b
 
81e1efb
74aae3b
 
 
81e1efb
 
74aae3b
 
81e1efb
74aae3b
 
81e1efb
 
74aae3b
 
81e1efb
74aae3b
81e1efb
 
74aae3b
 
81e1efb
 
74aae3b
 
 
dfd1faa
 
81e1efb
 
74aae3b
 
 
 
 
dfd1faa
81e1efb
 
a4c032a
74aae3b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
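Data models for the AML environment.

The AML_env environment is an alert-investigation environment: the agent
queries and searches account transactions, looks up KYC records, and submits
a final FRAUD or CLEAR decision, within a limited budget of API calls.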
"""

from openenv.core.env_server.types import Action, Observation
from pydantic import ConfigDict, Field, field_validator
from typing import List, Literal, Optional, Any, Union

# ==========================================
# OBSERVATION SPACE
# ==========================================
class AmlObservation(Observation):
    # Observation model for the AML environment. Unknown keys are rejected
    # (extra="forbid") and values are not coerced between types (strict).
    model_config = ConfigDict(strict=True, extra="forbid")

    alert_details: str = Field(
        description="The constant mission objective and initial alert.",
    )
    budget_remaining: int = Field(description="API calls remaining.")
    # The three fields below are optional and default to None.
    last_action: Optional[str] = Field(None, description="Last tool used.")
    last_action_result: Optional[Any] = Field(
        None,
        description="Payload returned by the API.",
    )
    error_message: Optional[str] = Field(
        None,
        description="Error string if action failed.",
    )

# ==========================================
# ACTION SPACE
# ==========================================
class QueryTransactions(Action):
    # Paginated listing of the transactions of a single account.
    model_config = ConfigDict(strict=True, extra="forbid")

    # Tag value that AmlAction's discriminated union uses to pick this model.
    action_type: Literal["query_transactions"]
    account_id: str = Field(
        description="The exact ACC-XXXX ID to query.",
        pattern=r"^ACC-\d{4}$",
    )
    # Page size is bounded to 1..100; the offset must be non-negative.
    limit: int = Field(10, le=100, ge=1, description="Max transactions to return.")
    offset: int = Field(0, ge=0, description="Offset for pagination.")

class SearchTransactions(Action):
    # Keyword search over the memo text of one account's transactions.
    model_config = ConfigDict(strict=True, extra="forbid")

    # Tag value that AmlAction's discriminated union uses to pick this model.
    action_type: Literal["search_transactions"]
    account_id: str = Field(
        description="The exact ACC-XXXX ID to query.",
        pattern=r"^ACC-\d{4}$",
    )
    # min_length=1 only rejects the empty string; whitespace-only keywords
    # still pass validation.
    keyword: str = Field(
        description="Keyword to search in memo_text.",
        min_length=1,
    )

class GetKYCRecord(Action):
    # Look up the KYC record of a single entity.
    model_config = ConfigDict(strict=True, extra="forbid")

    # Tag value that AmlAction's discriminated union uses to pick this model.
    action_type: Literal["get_kyc_record"]
    entity_id: str = Field(
        description="The exact ENT-XXXX ID to look up.",
        pattern=r"^ENT-\d{4}$",
    )

class SubmitDecision(Action):
    # Final verdict (FRAUD or CLEAR) plus the IDs cited as evidence.
    model_config = ConfigDict(strict=True, extra="forbid")

    # Tag value that AmlAction's discriminated union uses to pick this model.
    action_type: Literal["submit_decision"]
    decision: Literal["FRAUD", "CLEAR"] = Field(description="Your final verdict.")
    # Defaults to a fresh empty list per instance.
    # NOTE(review): unlike account_id/entity_id in the other actions, these
    # entries are not pattern-checked; any string is accepted.
    evidence_links: List[str] = Field(
        description="List of ACC-XXXX or ENT-XXXX IDs proving fraud.",
        default_factory=list,
    )

# The master Action model using Union
class AmlAction(Action):
    # Top-level action envelope: a free-text reasoning note plus exactly one
    # concrete tool call.
    model_config = ConfigDict(strict=True, extra="forbid")

    thought: str = Field(
        description="Short thinking pad with Observation: and Plan: sections.",
        min_length=1,
    )
    # Discriminated union: the incoming payload's `action_type` literal
    # selects which of the four action models validates it.
    action: Union[
        QueryTransactions,
        SearchTransactions,
        GetKYCRecord,
        SubmitDecision,
    ] = Field(discriminator="action_type")

    @field_validator("thought")
    @classmethod
    def thought_must_include_sections(cls, value: str) -> str:
        """Require 'Observation:' and 'Plan:' markers in ``thought``.

        Runs in pydantic's default "after" mode, i.e. once the field's own
        constraints (``min_length``) have passed. Markers are matched
        case-insensitively anywhere in the text; their order is not checked.

        Args:
            value: The ``thought`` string supplied by the caller.

        Returns:
            ``value`` with leading/trailing whitespace removed; this stripped
            string is what gets stored on the model.

        Raises:
            ValueError: If either marker is absent.
        """
        stripped = value.strip()
        folded = stripped.lower()
        required_markers = ("observation:", "plan:")
        if not all(marker in folded for marker in required_markers):
            raise ValueError("thought must include 'Observation:' and 'Plan:' sections")
        return stripped