linvest21's picture
download
raw
6.67 kB
from __future__ import annotations

import json
import math
from pathlib import Path
from typing import Any

from n21.config import write_json
from n21.settings import CONFIG_ROOT, SHFT_WORKSPACE_ROOT
from observability.audit_log import utc_now
from eval.model_quality_gate import load_model_quality_thresholds
def _load_json(path: Path) -> dict[str, Any] | None:
    """Read a JSON object from *path*.

    Returns ``None`` when any of these holds:

    - the file is missing or unreadable;
    - its bytes are not valid UTF-8 (a leading BOM is accepted);
    - its content is not valid JSON;
    - it holds a JSON value other than an object.

    A non-``None`` result is therefore always a dict, and callers can use
    ``.get`` on it safely.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8-sig"))
    except (OSError, ValueError):
        # OSError covers a missing file (FileNotFoundError) and unreadable paths.
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return None
    # Callers treat the result as a mapping, so reject arrays and scalars.
    return data if isinstance(data, dict) else None
def _number(value: Any, default: float = 0.0) -> float:
    """Coerce *value* to a finite float, falling back to *default*.

    Each of these yields *default*:

    - ``None`` and other non-numeric objects;
    - unparseable strings;
    - integers too large for a float;
    - non-finite results (NaN, +/-inf), which Python's ``json`` module
      produces from the ``NaN``/``Infinity`` literals.

    Results must stay finite because they feed the tuple comparisons in
    ``_rank`` and the numbers recorded in the tracker report.

    Args:
        value: Any value read from a report or threshold config.
        default: Value returned when *value* cannot be used as a finite
            number. Defaults to ``0.0``.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN never compares greater, and inf always wins; neither is a real score.
    return number if math.isfinite(number) else default
def _mapping(value: Any) -> dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict (keeps ``.get`` chains safe)."""
    return value if isinstance(value, dict) else {}


def _metrics(run_path: Path) -> dict[str, Any]:
    """Summarize the paired-eval and model-quality-gate evidence for one run.

    The paired report is read from ``<run_path>/eval/paired_eval_report.json``.
    The gate is read from ``<run_path>/eval/model_quality_gate.json``, falling
    back to the ``model_quality_gate`` section of the paired report.

    Missing files, ``null`` sections and non-object sections are treated as
    empty. Numeric fields are coerced with ``_number``.

    Args:
        run_path: Directory of the run. Its name is reported as ``run_id``.

    Returns:
        A flat dict of the metrics used for ranking. It also includes
        ``distance_to_thresholds``: how far the candidate's aggregate and
        critical pass rate fall short of the configured ``paired_eval``
        thresholds (``0.0`` when a threshold is met).
    """
    eval_dir = run_path / "eval"
    paired = _mapping(_load_json(eval_dir / "paired_eval_report.json"))
    gate = _mapping(_load_json(eval_dir / "model_quality_gate.json")) or _mapping(paired.get("model_quality_gate"))
    candidate = _mapping(paired.get("candidate"))
    baseline = _mapping(paired.get("baseline"))
    improvement = _mapping(paired.get("improvement"))
    promotion_gate = _mapping(paired.get("promotion_gate"))

    # Threshold config comes from a project loader. Only guard against a None
    # result or section, so a non-dict mapping type is still honored.
    thresholds = load_model_quality_thresholds() or {}
    paired_thresholds = thresholds.get("paired_eval") or {}
    # Fallback thresholds used when the config does not define them.
    min_aggregate = _number(paired_thresholds.get("min_candidate_aggregate", 0.6))
    min_critical = _number(paired_thresholds.get("min_candidate_critical_pass_rate", 0.7))

    aggregate = _number(candidate.get("aggregate"))
    critical = _number(candidate.get("critical_pass_rate"))

    # Normalize gate errors to a list. A bare string must not be split into characters.
    raw_errors = gate.get("errors")
    if raw_errors is None:
        gate_errors: list[Any] = []
    elif isinstance(raw_errors, (list, tuple)):
        gate_errors = list(raw_errors)
    else:
        gate_errors = [raw_errors]

    return {
        "run_id": run_path.name,
        "run_dir": str(run_path),
        "paired_eval_present": bool(paired),
        "model_quality_ok": bool(gate.get("ok")),
        "eligible_for_promotion": bool(gate.get("eligible_for_promotion") or promotion_gate.get("eligible_for_promotion")),
        "baseline_aggregate": _number(baseline.get("aggregate")),
        "candidate_aggregate": aggregate,
        "candidate_critical_pass_rate": critical,
        "aggregate_abs": _number(improvement.get("aggregate_abs")),
        "pairwise_win_rate": _number(improvement.get("pairwise_win_rate")),
        "pairwise_loss_rate": _number(improvement.get("pairwise_loss_rate")),
        "sample_count": int(_number(paired.get("sample_count") or candidate.get("sample_count"))),
        "distance_to_thresholds": {
            "candidate_aggregate_gap": round(max(0.0, min_aggregate - aggregate), 6),
            "critical_pass_gap": round(max(0.0, min_critical - critical), 6),
        },
        "gate_errors": gate_errors,
    }
def _rank(metrics: dict[str, Any]) -> tuple[float, float, float, float, float, float]:
    """Build the sort key used to compare measured runs (a larger tuple is better).

    Tuples compare lexicographically. The fields, in priority order, are:

    1. whether the model-quality gate passed (``1.0`` or ``0.0``);
    2. the candidate aggregate score;
    3. the candidate critical pass rate;
    4. the paired report's ``aggregate_abs`` improvement;
    5. the pairwise win rate;
    6. the negated pairwise loss rate, so fewer losses rank higher.
    """
    gate_passed = float(bool(metrics.get("model_quality_ok")))
    loss_rate = _number(metrics.get("pairwise_loss_rate"))
    return (
        gate_passed,
        _number(metrics.get("candidate_aggregate")),
        _number(metrics.get("candidate_critical_pass_rate")),
        _number(metrics.get("aggregate_abs")),
        _number(metrics.get("pairwise_win_rate")),
        -loss_rate,
    )
def _comparison(current: dict[str, Any], previous: dict[str, Any] | None) -> dict[str, Any]:
    """Compare the current run's metrics with the previously recorded best run.

    When there is no previous best (``None`` or an empty dict), every delta
    is ``None``. The current run then counts as an improvement exactly when
    it has a paired eval.

    Otherwise each delta is ``current - previous``, rounded to 6 decimal
    places, and improvement is decided by comparing ``_rank`` keys.
    """
    if previous:
        def delta(key: str) -> float:
            # Signed difference for one metric; positive means current is higher.
            return round(_number(current.get(key)) - _number(previous.get(key)), 6)

        return {
            "previous_best_run_id": previous.get("run_id"),
            "aggregate_delta_vs_previous_best": delta("candidate_aggregate"),
            "critical_pass_delta_vs_previous_best": delta("candidate_critical_pass_rate"),
            "pairwise_loss_rate_delta_vs_previous_best": delta("pairwise_loss_rate"),
            "improved_vs_previous_best": _rank(current) > _rank(previous),
        }

    # First measured run for this release: nothing to diff against.
    return {
        "previous_best_run_id": None,
        "aggregate_delta_vs_previous_best": None,
        "critical_pass_delta_vs_previous_best": None,
        "pairwise_loss_rate_delta_vs_previous_best": None,
        "improved_vs_previous_best": bool(current.get("paired_eval_present")),
    }
def _source_batch_acceptance(current: dict[str, Any], previous: dict[str, Any] | None, updated: bool, reason: str) -> dict[str, Any]:
    """Decide whether the source batch behind the current run is accepted.

    The possible decisions are:

    - ``pending_no_paired_eval``: the current run has no paired eval.
    - ``accepted_new_best_measured_checkpoint``: the run has a paired eval
      and *updated* is true (it became the new best).
    - ``rejected_did_not_improve_previous_best``: every other case.

    Args:
        current: Metrics dict for the run being recorded.
        previous: The previously recorded best run's metrics, or ``None``.
        updated: Whether the current run replaced the recorded best.
        reason: Tracker reason string, copied into the result unchanged.

    Returns:
        The acceptance record, including the ``_comparison`` against *previous*.
    """
    has_paired_eval = bool(current.get("paired_eval_present"))
    if has_paired_eval and updated:
        decision, accepted = "accepted_new_best_measured_checkpoint", True
    elif has_paired_eval:
        decision, accepted = "rejected_did_not_improve_previous_best", False
    else:
        decision, accepted = "pending_no_paired_eval", False

    return {
        "schema_version": "shft_source_batch_acceptance_v1",
        "accepted_for_future_training": accepted,
        "decision": decision,
        "reason": reason,
        "comparison": _comparison(current, previous),
        "rule": "A downloaded source batch is accepted only when its trained adapter becomes the best measured checkpoint for this release.",
    }
def update_best_run(*, run_id: str, release_id: str, output_path: Path | None = None) -> dict[str, Any]:
    """Record the best measured SHFT run for a release.

    This is evidence tracking only. It does not certify or package a failed run.

    The current run's metrics are read from ``SHFT_WORKSPACE_ROOT/runs/<run_id>``
    and compared with ``_rank`` against the ``best_run`` stored in the
    release's tracker file. The current run replaces the stored best only if
    it has a paired eval and ranks strictly higher. The tracker file is
    rewritten on every call.

    Args:
        run_id: Directory name of the run under ``SHFT_WORKSPACE_ROOT/runs``.
        release_id: Release identifier. It names the default tracker file,
            ``SHFT_WORKSPACE_ROOT/best_runs/<release_id>.json``.
        output_path: Optional override for the tracker file location.

    Returns:
        The report dict that was written to the tracker file.
    """
    current = _metrics(SHFT_WORKSPACE_ROOT / "runs" / run_id)
    path = output_path or (SHFT_WORKSPACE_ROOT / "best_runs" / f"{release_id}.json")
    # Ensure the directory actually written to exists. That is best_runs/ by
    # default, or the parent of an explicit output_path.
    path.parent.mkdir(parents=True, exist_ok=True)

    # A missing, corrupt, or non-object tracker file means "no recorded best".
    existing = _load_json(path)
    stored_best = existing.get("best_run") if isinstance(existing, dict) else None
    existing_best = stored_best if isinstance(stored_best, dict) else None
    previous_best = existing_best

    # NOTE(review): if run_id is already the recorded best, the run is compared
    # against its own stored metrics. A re-record then reports updated=False and
    # "rejected_did_not_improve_previous_best" for a batch that was previously
    # accepted. Confirm whether re-recording the same run should keep that acceptance.
    if not current["paired_eval_present"]:
        best, updated, reason = existing_best, False, "current_run_has_no_paired_eval"
    elif existing_best is None or _rank(current) > _rank(existing_best):
        best, updated, reason = current, True, "current_run_is_best_measured"
    else:
        best, updated, reason = existing_best, False, "kept_existing_best"

    source_batch_acceptance = _source_batch_acceptance(current, previous_best, updated, reason)
    report = {
        "schema_version": "shft_best_run_tracker_v1",
        "release_id": release_id,
        "current_run": current,
        "best_run": best,
        "updated": updated,
        "reason": reason,
        "previous_best_comparison": source_batch_acceptance["comparison"],
        "source_batch_acceptance": source_batch_acceptance,
        "path": str(path),
        "thresholds_path": str(CONFIG_ROOT / "thresholds" / "model_quality.yaml"),
        "created_at": utc_now(),
        # Truthy once any measured run has ever been recorded for the release.
        "ok": bool(best),
    }
    write_json(path, report)
    return report

Xet Storage Details

Size:
6.67 kB
·
Xet hash:
37b8cdc14a2e1b58db69108daa993d8aff1001fc152d1449cbcf2f9e76ddfa6c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.