CyberSecurity_OWASP / training /trackio_utils.py
Humanlearning's picture
Update CyberSecurity_OWASP environment
e3d939d verified
"""Trackio helpers used by training and evaluation scripts."""
from __future__ import annotations

import os
import subprocess
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterator
# Metric keys in the ``train/`` namespace.
# NOTE(review): nothing in this module reads this list; presumably the
# training scripts import it to know which keys to log -- confirm with callers.
TRAIN_METRICS = [
    # reward_*_mean keys
    "train/reward_total_mean",
    "train/reward_discovery_mean",
    "train/reward_security_mean",
    "train/reward_regression_mean",
    "train/reward_public_routes_mean",
    "train/reward_patch_quality_mean",
    "train/reward_visible_tests_mean",
    "train/reward_safety_mean",
    "train/reward_anti_cheat_mean",
    # *_rate keys
    "train/success_rate",
    "train/exploit_block_rate",
    "train/regression_preservation_rate",
    "train/public_route_preservation_rate",
    "train/invalid_action_rate",
    "train/timeout_rate",
    "train/safety_violation_rate",
    "train/reward_hacking_suspected_rate",
    # episode length (mean and p95)
    "train/episode_length_mean",
    "train/episode_length_p95",
    # throughput (rollouts and tokens per second)
    "train/rollouts_per_second",
    "train/tokens_per_second",
    # loss, learning rate, KL and gradient norm
    "train/loss",
    "train/learning_rate",
    "train/kl",
    "train/grad_norm",
]
# Metric keys in the ``eval/`` namespace.
# NOTE(review): nothing in this module reads this list; presumably the
# evaluation scripts import it to know which keys to report -- confirm with callers.
EVAL_METRICS = [
    # baseline vs. trained success rate and their absolute difference
    "eval/baseline_success_rate",
    "eval/trained_success_rate",
    "eval/absolute_success_improvement",
    # baseline vs. trained mean reward and their absolute difference
    "eval/baseline_mean_reward",
    "eval/trained_mean_reward",
    "eval/absolute_reward_improvement",
    # held-out split keys
    "eval/heldout_success_rate",
    "eval/heldout_mean_reward",
    # *_rate keys
    "eval/exploit_block_rate",
    "eval/regression_preservation_rate",
    "eval/public_route_preservation_rate",
    "eval/anti_cheat_pass_rate",
    "eval/invalid_action_rate",
    "eval/timeout_rate",
    "eval/safety_violation_rate",
    # episode length
    "eval/mean_episode_length",
]
def build_run_name(model: str, algo: str, difficulty: int, git_sha: str = "nogit") -> str:
    """Build a timestamped run name for a CyberSecurity_OWASP run.

    The result has the form
    ``CyberSecurity_OWASP-<model>-<algo>-level<difficulty>-<YYYYmmdd-HHMMSS>-<sha8>``.

    Args:
        model: Model identifier; every ``/`` is replaced with ``-``.
        algo: Algorithm label, inserted verbatim.
        difficulty: Difficulty level, rendered as ``level<difficulty>``.
        git_sha: Commit hash; only its first eight characters are kept.

    Returns:
        The assembled run name, stamped with the current UTC time.
    """
    # ``datetime.utcnow()`` is deprecated since Python 3.12 and emits a
    # DeprecationWarning; an aware UTC "now" produces the same formatted stamp.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    model_slug = model.replace("/", "-")
    return f"CyberSecurity_OWASP-{model_slug}-{algo}-level{difficulty}-{stamp}-{git_sha[:8]}"
def get_git_sha(default: str = "nogit") -> str:
    """Return the full hash of ``HEAD`` as reported by ``git rev-parse HEAD``.

    Args:
        default: Value returned when the command fails for any reason or
            prints nothing.

    Returns:
        The stripped stdout of the git command, or ``default``.
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(command, check=True, capture_output=True, text=True)
    except Exception:
        # Best effort: any failure running git falls back to the default.
        return default
    sha = completed.stdout.strip()
    return sha if sha else default
def _load_trackio():
    """Import the ``trackio`` module on demand and return it.

    Before importing, ``TRACKIO_DIR`` is set to ``<cwd>/outputs/trackio``
    (resolved) unless the variable already exists in the environment.

    Returns:
        The imported ``trackio`` module.

    Raises:
        RuntimeError: If ``trackio`` cannot be imported; the original
            ``ImportError`` is chained as the cause.
    """
    default_dir = (Path.cwd() / "outputs" / "trackio").resolve()
    os.environ.setdefault("TRACKIO_DIR", str(default_dir))
    try:
        import trackio as module
    except ImportError as err:
        message = (
            "Trackio is required for CyberSecurity_OWASP runs. Install dependencies "
            "with `uv sync` and set TRACKIO_SPACE_ID when you want remote HF Spaces tracking."
        )
        raise RuntimeError(message) from err
    return module
def init_trackio_run(
    *,
    run_name: str,
    run_type: str,
    config: dict[str, Any] | None = None,
    project: str | None = None,
    space_id: str | None = None,
    group: str | None = None,
):
    """Start a Trackio run via ``trackio.init`` and return its result.

    Args:
        run_name: Passed to ``trackio.init`` as ``name``.
        run_type: Stored in the run config under ``"run_type"``.
        config: Extra config entries; they are merged after the built-in
            ``"environment"`` and ``"run_type"`` entries, so they win on
            key collisions.
        project: Project name. When falsy, ``TRACKIO_PROJECT`` is read from
            the environment, defaulting to ``"CyberSecurity_OWASP"``.
        space_id: Space id. When ``None``, ``TRACKIO_SPACE_ID`` is read from
            the environment (default empty). Only forwarded when non-empty.
        group: Only forwarded when truthy.

    Returns:
        Whatever ``trackio.init`` returns.
    """
    tracker = _load_trackio()

    if not project:
        project = os.getenv("TRACKIO_PROJECT", "CyberSecurity_OWASP")
    if space_id is None:
        space_id = os.getenv("TRACKIO_SPACE_ID", "")

    base_config = {"environment": "CyberSecurity_OWASP", "run_type": run_type}
    overrides = config if config else {}
    init_kwargs: dict[str, Any] = {
        "project": project,
        "name": run_name,
        "config": {**base_config, **overrides},
    }

    # Optional arguments are only passed along when they carry a value.
    optional = {"space_id": space_id, "group": group}
    init_kwargs.update({key: value for key, value in optional.items() if value})

    return tracker.init(**init_kwargs)
def log_trackio_metrics(metrics: dict[str, Any], step: int | None = None) -> None:
    """Log the numeric entries of ``metrics`` through ``trackio.log``.

    Args:
        metrics: Mapping of metric name to value. Entries whose value is not
            an ``int``, ``float`` or ``bool`` instance are left out.
        step: When given, forwarded as the ``step`` keyword; otherwise
            ``trackio.log`` is called without it.
    """
    tracker = _load_trackio()

    loggable = {}
    for name, value in metrics.items():
        if isinstance(value, (int, float, bool)):
            loggable[name] = value

    if step is not None:
        tracker.log(loggable, step=step)
        return
    tracker.log(loggable)
def finish_trackio_run() -> None:
    """Call ``trackio.finish()`` on the lazily loaded Trackio module."""
    _load_trackio().finish()
@contextmanager
def trackio_run(
    *,
    run_name: str,
    run_type: str,
    config: dict[str, Any] | None = None,
    project: str | None = None,
    space_id: str | None = None,
    group: str | None = None,
) -> Iterator[Any]:
    """Context manager wrapping a Trackio run.

    On entry the run is started with ``init_trackio_run`` (all arguments are
    forwarded unchanged) and its result is yielded. On exit
    ``finish_trackio_run`` is called, whether or not the body raised.
    """
    init_args = {
        "run_name": run_name,
        "run_type": run_type,
        "config": config,
        "project": project,
        "space_id": space_id,
        "group": group,
    }
    active_run = init_trackio_run(**init_args)
    try:
        yield active_run
    finally:
        finish_trackio_run()
def log_eval_summary(run_name: str, summary: dict[str, Any], config: dict[str, Any] | None = None) -> None:
    """Log an evaluation summary as a single Trackio run.

    Args:
        run_name: Name of the run to open.
        summary: Summary values. Entries that are ``int``, ``float`` or
            ``bool`` instances are converted to ``float`` and logged under
            ``eval/<key>``; all other entries are ignored.
        config: Optional run config forwarded to ``trackio_run``.

    The run is opened with ``run_type="eval"`` and ``group="eval"``, and the
    metrics are logged once at step 0.
    """
    eval_metrics = {}
    for name, value in summary.items():
        if isinstance(value, (int, float, bool)):
            eval_metrics[f"eval/{name}"] = float(value)

    with trackio_run(run_name=run_name, run_type="eval", config=config, group="eval"):
        log_trackio_metrics(eval_metrics, step=0)