Spaces:
Sleeping
Sleeping
| """ | |
| Anti-Exploit Protections for Data-Centric RL Environment. | |
| Centralised module for all anti-hacking checks: | |
1. Input truncation (>200 chars → truncate, -0.02 penalty)
2. Validate spam prevention (cooldown + diminishing returns)
3. Recommendation ID staleness check
4. Ground truth immutability assertion
5. Catastrophic data loss detection
6. Duplicate apply prevention
7. Max applies per session (3)
8. Episode wall-clock timeout (5 min → forced submit, -0.10 penalty)
9. Step timeout (5 sec → timeout observation, -0.05 penalty)
| """ | |
| import logging | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Set | |
logger = logging.getLogger(__name__)

# Input limits: actions longer than this many characters are truncated.
MAX_ACTION_CHARS = 200

# At most this many recommendation applies are allowed per query session.
MAX_APPLIES_PER_SESSION = 3

# Validate calls: the first FREE_VALIDATES calls are rewarded. After any
# validate, VALIDATE_COOLDOWN non-validate actions must be taken before the
# next validate is allowed.
FREE_VALIDATES = 3
VALIDATE_COOLDOWN = 2

# Timeouts, in seconds.
EPISODE_TIMEOUT_SECS = 300  # 5 minutes of wall-clock time per episode
STEP_TIMEOUT_SECS = 5  # per step
# -- Exploit tracker (per-episode state) --------------------------------------
@dataclass
class AntiExploitState:
    """Mutable per-episode bookkeeping read and written by the anti-exploit checks.

    The ``@dataclass`` decorator is required. Without it, the
    ``field(default_factory=...)`` defaults below stay as raw
    ``dataclasses.Field`` objects on the class. Calls such as
    ``state.applied_ids_this_session.add(...)`` or
    ``time.time() - state.episode_start_time`` would then fail, and no fresh
    set would be created per instance.
    """

    # Validate tracking
    validate_call_count: int = 0
    steps_since_last_validate: int = 0  # cooldown counter
    # Apply tracking (a new set per instance via default_factory)
    applied_ids_this_session: Set[int] = field(default_factory=set)
    applies_this_session: int = 0
    # Timing: time.time() captured when the instance is constructed
    episode_start_time: float = field(default_factory=time.time)
    # Ground truth row count (set at reset)
    ground_truth_row_count: int = 0
# -- 1. Input truncation ------------------------------------------------------
def check_and_truncate_input(action: str) -> tuple[str, float, bool]:
    """Clamp an action string to at most ``MAX_ACTION_CHARS`` characters.

    Returns:
        ``(action_out, penalty, was_truncated)``. ``penalty`` is ``-0.02`` when
        the input had to be cut down (a warning is logged) and ``0.0`` otherwise.
    """
    original_len = len(action)
    if original_len <= MAX_ACTION_CHARS:
        return action, 0.0, False
    logger.warning(
        "Input truncated: original length %d > %d", original_len, MAX_ACTION_CHARS
    )
    return action[:MAX_ACTION_CHARS], -0.02, True
# -- 2. Validate cooldown -----------------------------------------------------
def check_validate_cooldown(state: AntiExploitState) -> tuple[bool, str]:
    """Decide whether a validate action is permitted right now.

    The first validate (``validate_call_count`` still 0) is always allowed.
    After that, ``steps_since_last_validate`` must have reached
    ``VALIDATE_COOLDOWN``.

    Returns:
        ``(allowed, error_message)``; the message is empty when allowed.
    """
    never_validated = not state.validate_call_count > 0
    cooled_down = not state.steps_since_last_validate < VALIDATE_COOLDOWN
    if never_validated or cooled_down:
        return True, ""
    steps_left = VALIDATE_COOLDOWN - state.steps_since_last_validate
    return False, (
        f"Validate on cooldown. Take {steps_left} "
        f"more action(s) before validating again."
    )
def get_validate_reward(state: AntiExploitState) -> float:
    """Return the reward for a validate call: +0.02 within the free quota, else -0.01.

    The quota check compares ``state.validate_call_count`` with
    ``FREE_VALIDATES``. Calling this before ``record_validate`` for the
    current call makes the first ``FREE_VALIDATES`` calls earn +0.02
    (caller ordering not visible here; confirm).
    """
    within_free_quota = state.validate_call_count < FREE_VALIDATES
    return 0.02 if within_free_quota else -0.01
def record_validate(state: AntiExploitState):
    """Count one validate call and restart the cooldown counter at zero."""
    state.validate_call_count = state.validate_call_count + 1
    state.steps_since_last_validate = 0
def record_non_validate_step(state: AntiExploitState):
    """Advance the validate cooldown counter by one non-validate action."""
    state.steps_since_last_validate = state.steps_since_last_validate + 1
# -- 3. Recommendation staleness ----------------------------------------------
def check_recommendation_staleness(
    rec_id: int,
    current_session_id: str,
    recommendation_session_id: str,
) -> tuple[bool, str]:
    """Check that a recommendation was issued in the currently active session.

    Args:
        rec_id: Recommendation ID; used only to build the error message.
        current_session_id: ID of the active session.
        recommendation_session_id: ID of the session the recommendation came from.

    Returns:
        ``(is_fresh, error_message)``; the message is empty when fresh.
    """
    if current_session_id == recommendation_session_id:
        return True, ""
    stale_msg = (
        f"Stale recommendation ID {rec_id}. "
        "Please re-query for fresh recommendations."
    )
    return False, stale_msg
# -- 4. Ground truth immutability ---------------------------------------------
def assert_ground_truth_intact(
    ground_truth_len: int,
    original_gt_len: int,
) -> tuple[bool, str]:
    """Check that the ground-truth row count has not changed.

    Despite its name, this does not raise; it reports the result.

    Args:
        ground_truth_len: Current ground-truth row count.
        original_gt_len: Baseline row count to compare against (presumably
            captured at episode reset; confirm with caller).

    Returns:
        ``(intact, message)``. On a mismatch the message is also logged at
        CRITICAL level; it is empty when the counts match.
    """
    if ground_truth_len != original_gt_len:
        # ASCII "->" replaces a mis-encoded arrow that rendered as "β" in logs.
        msg = (
            f"INTEGRITY VIOLATION: ground_truth row count changed "
            f"({original_gt_len} -> {ground_truth_len}). This should never happen."
        )
        logger.critical(msg)
        return False, msg
    return True, ""
# -- 5. Catastrophic data loss ------------------------------------------------
def check_catastrophic_data_loss(
    current_rows: int,
    original_rows: int,
    threshold: float = 0.50,
) -> tuple[bool, str]:
    """Flag a dataset that has shrunk below ``threshold`` of its original size.

    Args:
        current_rows: Row count now.
        original_rows: Row count to compare against.
        threshold: Fraction of ``original_rows`` that must remain; a ratio
            strictly below it is catastrophic. Defaults to 0.50.

    Returns:
        ``(is_catastrophic, message)``. When catastrophic, the message is also
        logged at ERROR level; otherwise it is empty.
    """
    # With no original rows there is nothing to lose. The previous
    # ``max(original_rows, 1)`` guard avoided division by zero but then
    # reported an empty dataset (0/0 rows, 0.0%) as a catastrophic loss.
    if original_rows <= 0:
        return False, ""
    ratio = current_rows / original_rows
    if ratio < threshold:
        msg = (
            f"CATASTROPHIC DATA LOSS: only {current_rows}/{original_rows} rows remain "
            f"({ratio*100:.1f}%). Episode terminated."
        )
        logger.error(msg)
        return True, msg
    return False, ""
# -- 6 & 7. Duplicate apply and session limit ---------------------------------
def check_apply_allowed(
    rec_id: int,
    state: AntiExploitState,
) -> tuple[bool, str]:
    """Decide whether recommendation ``rec_id`` may be applied in this session.

    The session apply limit (``MAX_APPLIES_PER_SESSION``) is checked first.
    Only if it is not reached is ``rec_id`` checked against the IDs already
    applied this session.

    Returns:
        ``(allowed, error_message)``; the message is empty when allowed.
    """
    if state.applies_this_session >= MAX_APPLIES_PER_SESSION:
        reason = (
            f"Max {MAX_APPLIES_PER_SESSION} applies per query session reached. "
            "Please re-query for more options."
        )
    elif rec_id in state.applied_ids_this_session:
        reason = (
            f"Recommendation {rec_id} has already been applied this session. "
            "Duplicate apply not allowed."
        )
    else:
        return True, ""
    return False, reason
def record_apply(rec_id: int, state: AntiExploitState):
    """Remember ``rec_id`` as applied this session, then bump the apply count."""
    applied_ids = state.applied_ids_this_session
    applied_ids.add(rec_id)
    state.applies_this_session = state.applies_this_session + 1
def reset_session_apply_state(state: AntiExploitState):
    """Start a fresh apply session.

    Call this whenever a new query_X command resets the session. Installs a
    brand-new (empty) applied-ID set rather than clearing the old one, and
    zeroes the apply counter.
    """
    state.applies_this_session = 0
    state.applied_ids_this_session = set()
# -- 8. Episode timeout -------------------------------------------------------
def check_episode_timeout(state: AntiExploitState) -> tuple[bool, str]:
    """Report whether the episode has run longer than ``EPISODE_TIMEOUT_SECS``.

    Elapsed time is ``time.time()`` minus ``state.episode_start_time``.

    Returns:
        ``(timed_out, message)``. On timeout the message is also logged as a
        warning and announces a forced submit with a -0.10 penalty. The penalty
        itself is not applied by this function.
    """
    elapsed = time.time() - state.episode_start_time
    timed_out = elapsed > EPISODE_TIMEOUT_SECS
    if not timed_out:
        return False, ""
    notice = (
        f"Episode wall-clock timeout ({elapsed:.0f}s > {EPISODE_TIMEOUT_SECS}s). "
        "Forcing submit. Penalty: -0.10."
    )
    logger.warning(notice)
    return True, notice
# -- 9. Step timeout ----------------------------------------------------------
class StepTimeoutError(Exception):
    """Exception type for the step-timeout protection.

    NOTE(review): only this exception class is defined in this section;
    presumably step-running code raises it when a step exceeds
    ``STEP_TIMEOUT_SECS``. Confirm where it is raised.
    """
def validate_calls_remaining(state: AntiExploitState) -> int:
    """Return how many rewarded (free) validate calls are left, never below zero."""
    remaining = FREE_VALIDATES - state.validate_call_count
    return remaining if remaining > 0 else 0