Patch_Hawk / patchhawk /agent /environment.py
kanishcr7's picture
Updated the prompt
b70c5b9
"""
PatchHawkEnv – OpenEnv-compliant environment for
supply-chain vulnerability detection and patching.
Actions (PatchHawkAction.action_type):
0 = ANALYZE
1 = EXECUTE_SANDBOX
2 = BLOCK_PR
3 = SUBMIT_PATCH
4 = REQUEST_REVIEW
Reward table (set on Observation.reward):
Correct BLOCK on malicious β†’ +2.0
Correct SUBMIT_PATCH (validated) β†’ +3.0
BLOCK on benign β†’ βˆ’1.0
SUBMIT_PATCH on benign (applied) β†’ βˆ’1.5
Episode ends w/o block/patch on mal β†’ βˆ’5.0 (at max_steps)
EXECUTE_SANDBOX β†’ +0.1
"""
import json
import random
from typing import Optional, Any
from uuid import uuid4
from openenv.core import Environment
from patchhawk.env_models import (
PatchHawkAction,
PatchHawkObservation,
PatchHawkState,
)
from patchhawk.agent.sandbox import run_code, validate_patch
class PatchHawkEnv(Environment[PatchHawkAction, PatchHawkObservation, PatchHawkState]):
"""OpenEnv environment for PatchHawk."""
# Action constants
ACTION_ANALYZE = 0
ACTION_EXECUTE_SANDBOX = 1
ACTION_BLOCK_PR = 2
ACTION_SUBMIT_PATCH = 3
ACTION_REQUEST_REVIEW = 4
ACTION_NAMES = [
"ANALYZE",
"EXECUTE_SANDBOX",
"BLOCK_PR",
"SUBMIT_PATCH",
"REQUEST_REVIEW",
]
SUPPORTS_CONCURRENT_SESSIONS = True
def __init__(
self,
scenarios_path: str = "patchhawk/data/scenarios.json",
use_docker: bool = False,
max_steps: int = 5,
**kwargs: Any,
):
super().__init__(**kwargs)
self.use_docker = use_docker
self.max_steps = max_steps
self.scenarios_path = scenarios_path
self.scenarios = self._load_scenarios()
# Episode state
self.current_scenario: Optional[dict] = None
self.step_counter: int = 0
self.cumulative_reward: float = 0.0
self._last_action: Optional[int] = None
self._telemetry: Optional[dict] = None
self._patch_validated: bool = False
# Internal state object
self._state = PatchHawkState(
episode_id=str(uuid4()),
step_count=0,
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _load_scenarios(self) -> list:
try:
with open(self.scenarios_path, "r") as f:
return json.load(f)
except Exception as e:
print(f"Warning: Could not load scenarios from {self.scenarios_path}: {e}")
return []
def _compute_static_flags(self, code: str) -> list[int]:
"""
Static heuristics used to seed a risk score.
Important: these are intentionally lightweight (string-based) so they work
in CPU-only judge environments without extra dependencies.
"""
lowered = code.lower()
# We keep flags small and interpretable; risk_score is mean(flags).
flags = [0] * 8
# 0) code execution primitives
if "eval(" in lowered or "exec(" in lowered:
flags[0] = 1
# 1) shell / process execution
if "subprocess" in lowered or "os.system" in lowered:
flags[1] = 1
# 2) network primitives
if "socket" in lowered or "requests" in lowered or "urllib" in lowered:
flags[2] = 1
# 3) environment manipulation / credential access
if "os.environ" in lowered:
flags[3] = 1
# 4) obfuscation indicators
if "base64" in lowered or "zlib" in lowered:
flags[4] = 1
# 5) typosquatting / suspicious imports (common misspellings from our scenarios)
# We treat these as high-signal in this hackathon setting.
if "import pythonn" in lowered or "import reqeusts" in lowered:
flags[5] = 1
# 6) unsafe deserialization
if "pickle.loads" in lowered:
flags[6] = 1
# 7) dynamic import / code download patterns (very coarse)
if "__import__(" in lowered or "importlib.import_module" in lowered:
flags[7] = 1
return flags
def _build_observation(
self,
*,
done: bool = False,
reward: float = 0.0,
reason: str = "",
extra_meta: Optional[dict] = None,
) -> PatchHawkObservation:
"""Build a PatchHawkObservation from current episode state."""
code = self.current_scenario["code_snippet"] if self.current_scenario else ""
flags = self._compute_static_flags(code)
risk = sum(flags) / max(len(flags), 1)
meta: dict[str, Any] = {
"scenario_id": self.current_scenario.get("id", "none")
if self.current_scenario
else "none",
"step": self.step_counter,
"cumulative_reward": self.cumulative_reward,
"reward_reason": reason,
}
if extra_meta:
meta.update(extra_meta)
telemetry_str: Optional[str] = None
if self._telemetry:
telemetry_str = json.dumps(self._telemetry, default=str)
return PatchHawkObservation(
code_snippet=code,
static_flags=flags,
risk_score=round(risk, 4),
sandbox_telemetry=telemetry_str,
done=done,
reward=reward,
metadata=meta,
)
# ------------------------------------------------------------------
# OpenEnv API
# ------------------------------------------------------------------
def reset(
self,
seed: Optional[int] = None,
episode_id: Optional[str] = None,
**kwargs: Any,
) -> PatchHawkObservation:
"""Reset the environment and return the initial observation.
Keyword args:
task_id: Optional[str] β€” filter scenarios by task_id field
"""
self._reset_rubric()
task_id: Optional[str] = kwargs.get("task_id")
self.current_task: Optional[str] = task_id
if seed is not None:
random.seed(seed)
# Check for direct scenario override (used by GRPO training)
scenario_override = kwargs.get("scenario")
# Pick scenario
if scenario_override:
self.current_scenario = scenario_override
elif not self.scenarios:
self.current_scenario = {
"id": "fallback",
"type": "functional",
"label": "benign",
"code_snippet": "print('hello')",
"patch": None,
"unit_test_code": None,
"attack_type": None,
"task_id": None,
}
elif task_id:
# Primary: filter by task_id field in scenarios.json
matches = [s for s in self.scenarios if s.get("task_id") == task_id]
if not matches:
# Fallback: map task_id to attack_type for backwards compat
_TASK_FILTER = {
"easy_typosquat": "typosquatting",
"medium_obfuscated": "obfuscated_exec",
"hard_patch": None,
}
atk = _TASK_FILTER.get(task_id)
if atk:
matches = [s for s in self.scenarios if s.get("attack_type") == atk]
else:
matches = [
s
for s in self.scenarios
if s.get("label") == "malicious" and s.get("patch")
]
self.current_scenario = (
random.choice(matches) if matches else random.choice(self.scenarios)
)
else:
self.current_scenario = random.choice(self.scenarios)
# Reset episode counters
self.step_counter = 0
self.cumulative_reward = 0.0
self._last_action = None
self._telemetry = None
self._patch_validated = False
self._state = PatchHawkState(
episode_id=episode_id or str(uuid4()),
step_count=0,
scenario_id=self.current_scenario.get("id", "none"),
current_task=task_id,
last_action_type=None,
patch_validated=False,
sandbox_log=None,
)
return self._build_observation(reason="reset")
def step(
self,
action: PatchHawkAction,
timeout_s: Optional[float] = None,
**kwargs: Any,
) -> PatchHawkObservation:
"""Execute one step in the environment."""
self.step_counter += 1
act = action.action_type
self._last_action = act
reward = 0.0
done = False
reason = ""
extra_meta: dict[str, Any] = {"action_name": self.ACTION_NAMES[act]}
label = self.current_scenario["label"]
# ── ANALYZE ──────────────────────────────────────────────
if act == self.ACTION_ANALYZE:
reward = 0.0
reason = "analyze β€” no reward"
# ── EXECUTE_SANDBOX ──────────────────────────────────────
elif act == self.ACTION_EXECUTE_SANDBOX:
result = run_code(
self.current_scenario["code_snippet"],
timeout_sec=3,
use_docker=self.use_docker,
)
self._telemetry = result
extra_meta["telemetry"] = result
reward = 0.1
reason = "sandbox executed"
# ── BLOCK_PR ─────────────────────────────────────────────
elif act == self.ACTION_BLOCK_PR:
if label == "malicious":
reward = 2.0
reason = "correct block on malicious"
else:
reward = -1.0
reason = "incorrect block on benign"
done = True
# ── SUBMIT_PATCH ─────────────────────────────────────────
elif act == self.ACTION_SUBMIT_PATCH:
patch_code = action.patch_content or self.current_scenario.get("patch")
if not patch_code:
reward = -0.5
reason = "no patch available"
extra_meta["validation"] = "No patch available"
else:
success, msg, details = validate_patch(
self.current_scenario, patch_code, use_docker=self.use_docker
)
extra_meta["validation"] = msg
extra_meta["details"] = details
self._patch_validated = success
if success:
if label == "malicious":
reward = 3.0
reason = "valid patch on malicious"
else:
reward = -1.5
reason = "patch applied to benign code"
else:
reward = -1.0
reason = f"patch validation failed: {msg}"
done = True
# ── REQUEST_REVIEW ───────────────────────────────────────
elif act == self.ACTION_REQUEST_REVIEW:
reward = 0.0
reason = "requested human review"
done = True
# ── Max-step penalty ────────────────────────────────────
if self.step_counter >= self.max_steps and not done:
done = True
if label == "malicious":
reward -= 5.0
reason += " | max steps reached on malicious scenario"
# ── Dynamic Risk Bonus ───────────────────────────────────
predict_risk = getattr(action, "predicted_risk", None)
if predict_risk is not None:
actual_risk = 1.0 if label == "malicious" else 0.0
accuracy_bonus = (1.0 - abs(actual_risk - float(predict_risk))) * 2.0
reward += accuracy_bonus
reason += f" | AI risk accuracy bonus: +{accuracy_bonus:.2f}"
self.cumulative_reward += reward
# Update internal state
self._state.step_count = self.step_counter
self._state.last_action_type = act
self._state.patch_validated = self._patch_validated
if self._telemetry:
self._state.sandbox_log = json.dumps(self._telemetry, default=str)
obs = self._build_observation(
done=done,
reward=reward,
reason=reason,
extra_meta=extra_meta,
)
return self._apply_transform(obs)
@property
def state(self) -> PatchHawkState:
"""Return the current internal state."""
return self._state
def close(self) -> None:
"""Clean up any Docker containers launched during the episode."""
self.current_scenario = None
self._telemetry = None