data-centric-env / server /anti_exploit.py
Aswini-Kumar's picture
Data-Centric AI RL Environment — OpenEnv Hackathon Submission
71dc210
"""
Anti-Exploit Protections for Data-Centric RL Environment.
Centralised module for all anti-hacking checks:
1. Input truncation (>200 chars → truncate, -0.02 penalty)
2. Validate spam prevention (cooldown + diminishing returns)
3. Recommendation ID staleness check
4. Ground truth immutability assertion
5. Catastrophic data loss detection
6. Duplicate apply prevention
7. Max applies per session (3)
8. Episode wall-clock timeout (5 min → forced submit, -0.10)
9. Step timeout (5 sec → timeout obs, -0.05)
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Optional, Set
# Module-level logger shared by every check in this file.
logger = logging.getLogger(__name__)
# Longest action string accepted as-is; check_and_truncate_input clips anything longer.
MAX_ACTION_CHARS = 200
# Upper bound on applies per query session, enforced by check_apply_allowed.
MAX_APPLIES_PER_SESSION = 3
# Number of validate calls that earn the positive reward in get_validate_reward.
FREE_VALIDATES = 3
VALIDATE_COOLDOWN = 2 # must take this many non-validate actions before next validate
# Wall-clock budget for one episode, compared against elapsed time in check_episode_timeout.
EPISODE_TIMEOUT_SECS = 5 * 60 # 5 minutes
# NOTE(review): not referenced by any code visible in this module — presumably
# enforced by the caller that raises StepTimeoutError; confirm.
STEP_TIMEOUT_SECS = 5 # 5 seconds per step
# ── Exploit tracker (per episode state) ──────────────────────────────────────
@dataclass
class AntiExploitState:
    """Mutable per-episode bookkeeping read and updated by the anti-exploit checks.

    The field order below is also the positional constructor order, so it
    must not be rearranged.
    """

    # --- validate bookkeeping ---
    # How many validate calls have been recorded so far.
    validate_call_count: int = 0
    # Cooldown counter: non-validate steps taken since the last validate.
    steps_since_last_validate: int = 0

    # --- apply bookkeeping (cleared by reset_session_apply_state) ---
    # Recommendation IDs already applied in the current query session.
    applied_ids_this_session: Set[int] = field(default_factory=set)
    # Number of applies recorded in the current query session.
    applies_this_session: int = 0

    # --- timing ---
    # Wall-clock timestamp (time.time()) captured when the state is created.
    episode_start_time: float = field(default_factory=time.time)

    # --- ground truth ---
    # Ground truth row count (set at reset).
    ground_truth_row_count: int = 0
# ── 1. Input truncation ───────────────────────────────────────────────────────
def check_and_truncate_input(action: str) -> tuple[str, float, bool]:
    """Clip an action string to at most ``MAX_ACTION_CHARS`` characters.

    Returns:
        A ``(action, penalty, was_truncated)`` triple. When the input fits,
        it is returned unchanged with a ``0.0`` penalty; otherwise the first
        ``MAX_ACTION_CHARS`` characters are returned with a ``-0.02`` penalty
        and a warning is logged.
    """
    original_length = len(action)

    # Fast path: nothing to clip, nothing to penalise.
    if original_length <= MAX_ACTION_CHARS:
        return action, 0.0, False

    logger.warning(
        "Input truncated: original length %d > %d", original_length, MAX_ACTION_CHARS
    )
    clipped = action[:MAX_ACTION_CHARS]
    return clipped, -0.02, True
# ── 2. Validate cooldown ──────────────────────────────────────────────────────
def check_validate_cooldown(state: AntiExploitState) -> tuple[bool, str]:
    """Decide whether a validate action may run right now.

    Returns:
        ``(allowed, error_message)``. Validate is allowed when no validate
        has been recorded yet, or when at least ``VALIDATE_COOLDOWN``
        non-validate steps have passed since the last one. Otherwise the
        message says how many more actions are required.
    """
    never_validated = state.validate_call_count <= 0
    cooled_down = state.steps_since_last_validate >= VALIDATE_COOLDOWN
    if never_validated or cooled_down:
        return True, ""

    steps_needed = VALIDATE_COOLDOWN - state.steps_since_last_validate
    return False, (
        f"Validate on cooldown. Take {steps_needed} "
        f"more action(s) before validating again."
    )
def get_validate_reward(state: AntiExploitState) -> float:
    """Return the reward for a validate call given how many were already made.

    While fewer than ``FREE_VALIDATES`` calls have been recorded the reward
    is ``+0.02``; after that every call costs ``-0.01``.
    """
    still_free = state.validate_call_count < FREE_VALIDATES
    return 0.02 if still_free else -0.01
def record_validate(state: AntiExploitState):
    """Register one validate call: bump the call count and restart the cooldown."""
    state.validate_call_count = state.validate_call_count + 1
    # A fresh validate means zero non-validate steps have happened since.
    state.steps_since_last_validate = 0
def record_non_validate_step(state: AntiExploitState):
    """Advance the validate cooldown counter by one non-validate step."""
    state.steps_since_last_validate = state.steps_since_last_validate + 1
# ── 3. Recommendation staleness ───────────────────────────────────────────────
def check_recommendation_staleness(
    rec_id: int,
    current_session_id: str,
    recommendation_session_id: str,
) -> tuple[bool, str]:
    """Check that a recommendation belongs to the current session.

    Returns:
        ``(is_fresh, error_message)``. The recommendation is fresh when the
        two session IDs are equal; otherwise the message names the stale
        ``rec_id`` and asks for a re-query.
    """
    if current_session_id == recommendation_session_id:
        return True, ""

    stale_notice = (
        f"Stale recommendation ID {rec_id}. "
        "Please re-query for fresh recommendations."
    )
    return False, stale_notice
# ── 4. Ground truth immutability ──────────────────────────────────────────────
def assert_ground_truth_intact(
    ground_truth_len: int,
    original_gt_len: int,
) -> tuple[bool, str]:
    """Report whether the ground-truth row count still equals its original value.

    Despite the name, this never raises: it returns a status tuple and logs
    at CRITICAL level when the counts differ. Only the two row counts passed
    in are compared.

    Args:
        ground_truth_len: Current number of ground-truth rows.
        original_gt_len: Number of ground-truth rows recorded originally.

    Returns:
        ``(is_intact, message)`` — ``(True, "")`` when the counts match,
        otherwise ``(False, <violation message>)``.
    """
    if ground_truth_len == original_gt_len:
        return True, ""

    # The arrow in this message was previously a garbled character sequence.
    msg = (
        f"INTEGRITY VIOLATION: ground_truth row count changed "
        f"({original_gt_len} → {ground_truth_len}). This should never happen."
    )
    logger.critical(msg)
    return False, msg
# ── 5. Catastrophic data loss ─────────────────────────────────────────────────
def check_catastrophic_data_loss(
    current_rows: int,
    original_rows: int,
) -> tuple[bool, str]:
    """Detect whether fewer than half of the original rows remain.

    Args:
        current_rows: Number of rows currently present.
        original_rows: Number of rows at the start.

    Returns:
        ``(is_catastrophic, message)`` — ``(True, <message>)`` and an ERROR
        log entry when ``current_rows / original_rows`` is below 0.50,
        otherwise ``(False, "")``.
    """
    # With no rows to begin with there is nothing that can be lost; without
    # this guard an empty dataset (0 of 0 rows) was reported as a
    # catastrophic loss.
    if original_rows <= 0:
        return False, ""

    ratio = current_rows / original_rows
    if ratio >= 0.50:
        return False, ""

    msg = (
        f"CATASTROPHIC DATA LOSS: only {current_rows}/{original_rows} rows remain "
        f"({ratio*100:.1f}%). Episode terminated."
    )
    logger.error(msg)
    return True, msg
# ── 6 & 7. Duplicate apply and session limit ──────────────────────────────────
def check_apply_allowed(
    rec_id: int,
    state: AntiExploitState,
) -> tuple[bool, str]:
    """Decide whether applying recommendation ``rec_id`` is permitted.

    Returns:
        ``(allowed, error_message)``. The apply is refused when the session
        has already reached ``MAX_APPLIES_PER_SESSION`` applies, or when
        ``rec_id`` was already applied in this session. The session limit is
        checked first, so it takes precedence in the message.
    """
    denial = ""
    if state.applies_this_session >= MAX_APPLIES_PER_SESSION:
        denial = (
            f"Max {MAX_APPLIES_PER_SESSION} applies per query session reached. "
            "Please re-query for more options."
        )
    elif rec_id in state.applied_ids_this_session:
        denial = (
            f"Recommendation {rec_id} has already been applied this session. "
            "Duplicate apply not allowed."
        )
    # Both denial messages are non-empty, so an empty string means "allowed".
    return not denial, denial
def record_apply(rec_id: int, state: AntiExploitState):
    """Register that ``rec_id`` was applied in the current query session."""
    state.applies_this_session = state.applies_this_session + 1
    state.applied_ids_this_session.add(rec_id)
def reset_session_apply_state(state: AntiExploitState):
    """Clear per-session apply tracking.

    Call this whenever a new query_X command resets the session.
    """
    state.applies_this_session = 0
    # Bind a brand-new set rather than clearing the existing one in place.
    state.applied_ids_this_session = set()
# ── 8. Episode timeout ────────────────────────────────────────────────────────
def check_episode_timeout(state: AntiExploitState) -> tuple[bool, str]:
    """Check whether the episode has exceeded its wall-clock budget.

    Returns:
        ``(timed_out, message)``. When more than ``EPISODE_TIMEOUT_SECS``
        seconds have elapsed since ``state.episode_start_time`` a warning is
        logged and ``(True, <message>)`` is returned; otherwise
        ``(False, "")``.
    """
    now = time.time()
    elapsed = now - state.episode_start_time
    timed_out = elapsed > EPISODE_TIMEOUT_SECS
    if not timed_out:
        return False, ""

    msg = (
        f"Episode wall-clock timeout ({elapsed:.0f}s > {EPISODE_TIMEOUT_SECS}s). "
        "Forcing submit. Penalty: -0.10."
    )
    logger.warning(msg)
    return True, msg
# ── 9. Step timeout context manager ──────────────────────────────────────────
class StepTimeoutError(Exception):
    """Exception type for a step that exceeded its time budget.

    NOTE(review): nothing in the visible part of this module raises it;
    presumably the step-timeout enforcement lives in the caller — confirm.
    """
def validate_calls_remaining(state: AntiExploitState) -> int:
    """Return how many of the ``FREE_VALIDATES`` validate calls are still unused (never negative)."""
    unused = FREE_VALIDATES - state.validate_call_count
    return unused if unused > 0 else 0