"""
Domain Schema Definition for domainTokenizer.

A declarative format that describes the fields in a domain's event data.
Each field type maps to a specific tokenization strategy.

References:
    - Nubank nuFormer (arXiv:2507.23267): amount sign(2) + amount bucket(21)
      + calendar(74) + BPE text
    - ActionPiece (arXiv:2502.13581): unordered feature sets tokenized via
      BPE-like merging
    - Banking TF (arXiv:2410.08243): date + amount + wording composite tokens
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any


class FieldType(Enum):
    """Supported field types for domain tokenization.

    Each type maps to a specific tokenization strategy:
        SIGN                 → 2 tokens (positive/negative, credit/debit)
        NUMERICAL_CONTINUOUS → quantile-based magnitude bins (default: 21, following Nubank)
        NUMERICAL_DISCRETE   → small fixed vocabulary for countable values (quantities, counts)
        CATEGORICAL_FIXED    → direct vocabulary mapping from a known set of categories
        TEMPORAL             → calendar decomposition into month/dow/dom/hour tokens
        TEXT                 → BPE subword tokenization (for descriptions, names, free text)
    """

    SIGN = "sign"
    NUMERICAL_CONTINUOUS = "numerical_continuous"
    NUMERICAL_DISCRETE = "numerical_discrete"
    CATEGORICAL_FIXED = "categorical_fixed"
    TEMPORAL = "temporal"
    TEXT = "text"


# Mapping from calendar field name → number of vocabulary tokens it produces.
# The default TEMPORAL decomposition (month, dow, dom, hour) sums to
# 12 + 7 + 31 + 24 = 74 tokens. These keys are the only calendar components
# FieldSpec accepts.
CALENDAR_FIELD_SIZES: Dict[str, int] = {
    "month": 12,
    "dow": 7,  # day of week
    "dom": 31,  # day of month
    "hour": 24,
    "quarter": 4,
    "minute_bin": 4,  # 15-min bins: 0-14, 15-29, 30-44, 45-59
}


@dataclass
class FieldSpec:
    """Specification for a single field in a domain event.

    Args:
        name: Field name (must match the key in event dictionaries).
        field_type: How this field should be tokenized.
        prefix: Token prefix override. Defaults to uppercase field name.
        n_bins: Number of quantile bins for NUMERICAL_CONTINUOUS (default: 21, Nubank).
            Must be >= 1 for NUMERICAL_CONTINUOUS fields.
        categories: List of category values for CATEGORICAL_FIXED.
        vocab_size: Explicit vocab size for NUMERICAL_DISCRETE (e.g., 11 for quantities 0-10).
            Note: not consulted by ``token_count``; the discrete vocabulary size is
            derived from ``max_value``.
        calendar_fields: Which calendar components to extract for TEMPORAL fields.
            Every entry must be a key of ``CALENDAR_FIELD_SIZES``.
        max_value: Upper bound for NUMERICAL_DISCRETE (tokens: 0..max_value, max_value+).
            Must be >= 0.

    Raises:
        ValueError: On construction, if the field's type-specific configuration is
            missing or invalid.
    """

    name: str
    field_type: FieldType
    prefix: Optional[str] = None
    n_bins: int = 21
    categories: Optional[List[str]] = None
    vocab_size: Optional[int] = None
    calendar_fields: List[str] = field(default_factory=lambda: ["month", "dow", "dom", "hour"])
    max_value: Optional[int] = None

    def __post_init__(self):
        if self.prefix is None:
            self.prefix = self.name.upper()

        # Validation
        if self.field_type == FieldType.CATEGORICAL_FIXED and self.categories is None:
            raise ValueError(f"Field '{self.name}': CATEGORICAL_FIXED requires 'categories' list")

        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            if self.max_value is None:
                raise ValueError(f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value'")
            if self.max_value < 0:
                # A negative bound would make token_count (max_value + 2) meaningless.
                raise ValueError(
                    f"Field '{self.name}': NUMERICAL_DISCRETE 'max_value' must be >= 0, "
                    f"got {self.max_value}"
                )

        if self.field_type == FieldType.NUMERICAL_CONTINUOUS and self.n_bins < 1:
            # With zero bins the field would still emit one token per event
            # but contribute no vocabulary entries.
            raise ValueError(
                f"Field '{self.name}': NUMERICAL_CONTINUOUS 'n_bins' must be >= 1, "
                f"got {self.n_bins}"
            )

        if self.field_type == FieldType.TEMPORAL:
            # An unknown component would count toward tokens_per_event but add
            # 0 tokens to the vocabulary, so reject it up front.
            unknown = [cf for cf in self.calendar_fields if cf not in CALENDAR_FIELD_SIZES]
            if unknown:
                raise ValueError(
                    f"Field '{self.name}': unknown calendar field(s) {unknown}; "
                    f"expected a subset of {sorted(CALENDAR_FIELD_SIZES)}"
                )

    @property
    def token_count(self) -> int:
        """Number of special tokens this field produces in the vocabulary.

        TEXT fields return 0 because their tokens come from BPE, not from
        the special-token vocabulary.
        """
        if self.field_type == FieldType.SIGN:
            return 2
        elif self.field_type == FieldType.NUMERICAL_CONTINUOUS:
            return self.n_bins
        elif self.field_type == FieldType.NUMERICAL_DISCRETE:
            return self.max_value + 2  # 0..max_value + overflow bin
        elif self.field_type == FieldType.CATEGORICAL_FIXED:
            return len(self.categories) + 1  # +1 for unknown
        elif self.field_type == FieldType.TEMPORAL:
            # .get() keeps this tolerant if calendar_fields is mutated after
            # construction; __post_init__ rejects unknown names up front.
            return sum(CALENDAR_FIELD_SIZES.get(cf, 0) for cf in self.calendar_fields)
        elif self.field_type == FieldType.TEXT:
            return 0  # text tokens come from BPE, not special vocab
        return 0

    @property
    def tokens_per_event(self) -> int:
        """Number of tokens this field contributes per event (fixed part only).

        Single-token field types contribute 1. TEMPORAL contributes one token
        per calendar component. TEXT contributes 0 because its length varies.
        """
        if self.field_type in (
            FieldType.SIGN,
            FieldType.NUMERICAL_CONTINUOUS,
            FieldType.NUMERICAL_DISCRETE,
            FieldType.CATEGORICAL_FIXED,
        ):
            return 1
        elif self.field_type == FieldType.TEMPORAL:
            return len(self.calendar_fields)
        elif self.field_type == FieldType.TEXT:
            return 0  # variable length
        return 0


@dataclass
class DomainSchema:
    """Complete schema for a domain's event data.

    A schema defines the ordered list of fields that make up each event
    (transaction, purchase, clinical encounter, etc.). The field order
    determines the token order within each event.

    Args:
        name: Human-readable domain name (e.g., "finance", "ecommerce").
        fields: Ordered list of field specifications.
        event_separator: Special token to separate events in a sequence.
        description: Optional human-readable description.
    """

    name: str
    fields: List[FieldSpec]
    event_separator: str = "[SEP_EVENT]"
    description: str = ""

    @property
    def special_token_count(self) -> int:
        """Total number of domain-specific special tokens needed.

        This is 8 base tokens plus each field's ``token_count``. BPE text
        tokens are not included.
        """
        # Base special tokens: PAD, UNK, BOS, EOS, MASK, CLS, SEP, event separator
        base = 8
        return base + sum(f.token_count for f in self.fields)

    @property
    def fixed_tokens_per_event(self) -> int:
        """Number of fixed (non-text) tokens per event, including separator."""
        return 1 + sum(f.tokens_per_event for f in self.fields)  # +1 for event separator

    @property
    def has_text_fields(self) -> bool:
        """Whether the schema includes any free-text fields."""
        return any(f.field_type == FieldType.TEXT for f in self.fields)

    @property
    def text_field_names(self) -> List[str]:
        """Names of all text fields in the schema, in schema order."""
        return [f.name for f in self.fields if f.field_type == FieldType.TEXT]

    @property
    def fittable_field_names(self) -> List[str]:
        """Names of fields that require fitting on training data.

        Currently these are the NUMERICAL_CONTINUOUS fields.
        """
        return [f.name for f in self.fields if f.field_type == FieldType.NUMERICAL_CONTINUOUS]

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Look up a field by name.

        Returns the first field whose ``name`` matches, or ``None`` if none does.
        """
        for f in self.fields:
            if f.name == name:
                return f
        return None

    def summary(self) -> str:
        """Human-readable summary of the schema."""
        lines = [f"DomainSchema: {self.name}"]
        if self.description:
            lines.append(f" {self.description}")
        lines.append(f" Fields: {len(self.fields)}")
        lines.append(f" Special tokens: {self.special_token_count}")
        lines.append(f" Fixed tokens/event: {self.fixed_tokens_per_event}")
        lines.append(f" Has text fields: {self.has_text_fields}")
        lines.append(f" Requires fitting: {self.fittable_field_names}")
        lines.append("")
        for f in self.fields:
            lines.append(
                f" [{f.field_type.value}] {f.name} → prefix={f.prefix}, "
                f"tokens_in_vocab={f.token_count}, tokens_per_event={f.tokens_per_event}"
            )
        return "\n".join(lines)