"""
Anti-Exploit Protections for Data-Centric RL Environment.

Centralised module for all anti-hacking checks:
  1. Input truncation (>200 chars → truncate, -0.02 penalty)
  2. Validate spam prevention (cooldown + diminishing returns)
  3. Recommendation ID staleness check
  4. Ground truth immutability assertion
  5. Catastrophic data loss detection
  6. Duplicate apply prevention
  7. Max applies per session (3)
  8. Episode wall-clock timeout (5 min → forced submit, -0.10)
  9. Step timeout (5 sec → timeout obs, -0.05)
"""

import logging
import time
from dataclasses import dataclass, field
from typing import Optional, Set

logger = logging.getLogger(__name__)

MAX_ACTION_CHARS = 200
MAX_APPLIES_PER_SESSION = 3
FREE_VALIDATES = 3
VALIDATE_COOLDOWN = 2  # must take this many non-validate actions before next validate
EPISODE_TIMEOUT_SECS = 5 * 60  # 5 minutes
STEP_TIMEOUT_SECS = 5  # 5 seconds per step


# ── Exploit tracker (per episode state) ──────────────────────────────────────

@dataclass
class AntiExploitState:
    """Mutable per-episode counters consulted by the anti-exploit checks.

    A fresh instance starts with every counter at zero, an empty set of
    applied recommendation IDs, and ``episode_start_time`` stamped with
    ``time.time()`` at construction.
    """

    # Validate tracking
    validate_call_count: int = 0  # total validate calls recorded so far
    steps_since_last_validate: int = 0  # cooldown counter

    # Apply tracking (cleared whenever the query session is reset)
    applied_ids_this_session: Set[int] = field(default_factory=set)
    applies_this_session: int = 0

    # Timing: wall-clock timestamp taken when the state object is created
    episode_start_time: float = field(default_factory=time.time)

    # Ground truth row count (set at reset)
    ground_truth_row_count: int = 0


# ── 1. Input truncation ───────────────────────────────────────────────────────

def check_and_truncate_input(action: str) -> tuple[str, float, bool]:
    """Clamp an action string to ``MAX_ACTION_CHARS`` characters.

    Args:
        action: Raw action text.

    Returns:
        ``(truncated_action, penalty, was_truncated)``. The penalty is
        ``-0.02`` when the input had to be cut, else ``0.0``.
    """
    # Guard clause: inputs within the limit pass through untouched.
    if len(action) <= MAX_ACTION_CHARS:
        return action, 0.0, False

    logger.warning(
        "Input truncated: original length %d > %d",
        len(action),
        MAX_ACTION_CHARS,
    )
    return action[:MAX_ACTION_CHARS], -0.02, True
# ── 2. Validate cooldown ──────────────────────────────────────────────────────

def check_validate_cooldown(state: AntiExploitState) -> tuple[bool, str]:
    """Decide whether a validate action is currently permitted.

    The first validate of an episode (``validate_call_count == 0``) is
    always allowed. After that, validate is blocked while
    ``steps_since_last_validate < VALIDATE_COOLDOWN``.

    Returns:
        ``(allowed, error_message)``; the message is empty when allowed.
    """
    # Nothing to cool down from until at least one validate has happened.
    if state.validate_call_count == 0:
        return True, ""

    steps_still_needed = VALIDATE_COOLDOWN - state.steps_since_last_validate
    if steps_still_needed > 0:
        return False, (
            f"Validate on cooldown. Take {steps_still_needed} "
            f"more action(s) before validating again."
        )
    return True, ""


def get_validate_reward(state: AntiExploitState) -> float:
    """Return the reward for the next validate call.

    ``+0.02`` while fewer than ``FREE_VALIDATES`` calls have been recorded,
    ``-0.01`` thereafter.
    """
    if state.validate_call_count < FREE_VALIDATES:
        return 0.02
    return -0.01


def record_validate(state: AntiExploitState) -> None:
    """Register a validate call: bump the count and restart the cooldown."""
    state.validate_call_count += 1
    state.steps_since_last_validate = 0


def record_non_validate_step(state: AntiExploitState) -> None:
    """Register one non-validate action towards the validate cooldown."""
    state.steps_since_last_validate += 1


# ── 3. Recommendation staleness ───────────────────────────────────────────────

def check_recommendation_staleness(
    rec_id: int,
    current_session_id: str,
    recommendation_session_id: str,
) -> tuple[bool, str]:
    """Check that a recommendation was issued in the current session.

    Args:
        rec_id: Recommendation ID; only used in the error message.
        current_session_id: Identifier of the active session.
        recommendation_session_id: Session the recommendation came from.

    Returns:
        ``(is_fresh, error_message)``; the message is empty when fresh.
    """
    if current_session_id == recommendation_session_id:
        return True, ""
    return False, (
        f"Stale recommendation ID {rec_id}. "
        "Please re-query for fresh recommendations."
    )


# ── 4. Ground truth immutability ──────────────────────────────────────────────

def assert_ground_truth_intact(
    ground_truth_len: int,
    original_gt_len: int,
) -> tuple[bool, str]:
    """Verify the ground-truth row count is unchanged.

    Despite the name this does not raise: a mismatch is logged at
    CRITICAL level and reported through the return value.

    Returns:
        ``(is_intact, message)``; the message is empty when intact.
    """
    if ground_truth_len == original_gt_len:
        return True, ""

    msg = (
        f"INTEGRITY VIOLATION: ground_truth row count changed "
        f"({original_gt_len} → {ground_truth_len}). This should never happen."
    )
    logger.critical(msg)
    return False, msg
# ── 5. Catastrophic data loss ─────────────────────────────────────────────────

def check_catastrophic_data_loss(
    current_rows: int,
    original_rows: int,
) -> tuple[bool, str]:
    """Detect whether fewer than half of the original rows remain.

    Args:
        current_rows: Row count of the working dataset now.
        original_rows: Row count of the dataset at the start.

    Returns:
        ``(is_catastrophic, message)``; the message is empty when the loss
        is not catastrophic. A catastrophic result is logged at ERROR level.
    """
    # An empty baseline has nothing to lose; without this guard 0/0 rows
    # would be reported as a 0.0% survival rate and flagged as catastrophic.
    if original_rows <= 0:
        return False, ""

    ratio = current_rows / original_rows
    if ratio >= 0.50:
        return False, ""

    msg = (
        f"CATASTROPHIC DATA LOSS: only {current_rows}/{original_rows} rows remain "
        f"({ratio*100:.1f}%). Episode terminated."
    )
    logger.error(msg)
    return True, msg


# ── 6 & 7. Duplicate apply and session limit ──────────────────────────────────

def check_apply_allowed(
    rec_id: int,
    state: AntiExploitState,
) -> tuple[bool, str]:
    """Decide whether recommendation ``rec_id`` may be applied now.

    Blocks when the session apply limit has been reached, or when the same
    ID was already applied in this session. The limit is checked first, so
    its message wins when both conditions hold.

    Returns:
        ``(allowed, error_message)``; the message is empty when allowed.
    """
    if state.applies_this_session >= MAX_APPLIES_PER_SESSION:
        return False, (
            f"Max {MAX_APPLIES_PER_SESSION} applies per query session reached. "
            "Please re-query for more options."
        )
    if rec_id in state.applied_ids_this_session:
        return False, (
            f"Recommendation {rec_id} has already been applied this session. "
            "Duplicate apply not allowed."
        )
    return True, ""


def record_apply(rec_id: int, state: AntiExploitState) -> None:
    """Register a successful apply of ``rec_id`` in the current session."""
    state.applied_ids_this_session.add(rec_id)
    state.applies_this_session += 1


def reset_session_apply_state(state: AntiExploitState) -> None:
    """Call this whenever a new query_X command resets the session."""
    # Rebind to a fresh set rather than clearing in place, so any external
    # reference to the previous set is left untouched.
    state.applied_ids_this_session = set()
    state.applies_this_session = 0


# ── 8. Episode timeout ────────────────────────────────────────────────────────

def check_episode_timeout(state: AntiExploitState) -> tuple[bool, str]:
    """Check whether the episode has exceeded its wall-clock budget.

    Elapsed time is ``time.time() - state.episode_start_time``; the episode
    is timed out once that exceeds ``EPISODE_TIMEOUT_SECS``.

    Returns:
        ``(timed_out, message)``; the message is empty when not timed out.
        A timeout is logged at WARNING level. This function only reports;
        it does not itself apply the penalty named in the message.
    """
    elapsed = time.time() - state.episode_start_time
    if elapsed <= EPISODE_TIMEOUT_SECS:
        return False, ""

    msg = (
        f"Episode wall-clock timeout ({elapsed:.0f}s > {EPISODE_TIMEOUT_SECS}s). "
        "Forcing submit. Penalty: -0.10."
    )
    logger.warning(msg)
    return True, msg
# ── 9. Step timeout context manager ──────────────────────────────────────────
# NOTE(review): only the exception type is defined in this section; no context
# manager that raises it is visible here — confirm where it lives.

class StepTimeoutError(Exception):
    """Exception type for a step that exceeds its time budget."""


def validate_calls_remaining(state: AntiExploitState) -> int:
    """Return how many of the ``FREE_VALIDATES`` calls are still unused.

    The result is clamped at zero once ``state.validate_call_count`` reaches
    or exceeds ``FREE_VALIDATES``.
    """
    return max(0, FREE_VALIDATES - state.validate_call_count)