linvest21's picture
download
raw
14 kB
from __future__ import annotations
import argparse
import gc
import json
import math
import os
import re
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from eval.model_quality_gate import evaluate_model_quality_gate
def utc_now() -> str:
return datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON-Lines eval suite into a list of dict records.

    Blank (or whitespace-only) lines are skipped.

    Raises:
        ValueError: if a non-blank line decodes to something other than a JSON
            object, or if the file contains no records at all.
        json.JSONDecodeError: if a line is not valid JSON.
    """
    records: list[dict[str, Any]] = []
    lines = path.read_text(encoding="utf-8").splitlines()
    for number, raw in enumerate(lines, start=1):
        if raw.strip() == "":
            continue
        record = json.loads(raw)
        if isinstance(record, dict):
            records.append(record)
        else:
            raise ValueError(f"record at {path}:{number} must be an object")
    if len(records) == 0:
        raise ValueError(f"eval suite has no records: {path}")
    return records
def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as indented, key-sorted JSON ending in a newline.

    Missing parent directories are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(f"{rendered}\n", encoding="utf-8")
def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines (one key-sorted object per line).

    Parent directories are created first. The output always ends with a
    newline, so an empty ``rows`` list produces a file containing just ``"\\n"``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    lines = [json.dumps(row, sort_keys=True) for row in rows]
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def prompt_messages(prompt: str) -> list[dict[str, str]]:
    """Build the chat-template message list: a fixed analyst system prompt plus the user question."""
    system_instructions = (
        "You are a cautious financial analyst. Answer briefly, distinguish reported facts from inference, "
        "avoid investment advice, identify the key risk or tradeoff, and show calculations when needed."
    )
    system_turn = {"role": "system", "content": system_instructions}
    user_turn = {"role": "user", "content": prompt}
    return [system_turn, user_turn]
def prompt_text(prompt: str) -> str:
    """Build a plain-text prompt for tokenizers that have no chat template.

    The result is the instruction paragraph, a blank line, then
    ``Question: <prompt>`` and an open ``Answer:`` line for the model to complete.
    """
    # NOTE(review): this wording differs from the system message in
    # prompt_messages (no "reported" facts, no "identify the key risk or
    # tradeoff"), so non-chat models get a weaker instruction than chat models.
    # Confirm whether that is intended before comparing runs across tokenizers.
    instructions = (
        "You are a cautious financial analyst. Answer briefly, distinguish facts from inference, "
        "avoid investment advice, and show calculations when needed."
    )
    return f"{instructions}\n\nQuestion: {prompt}\nAnswer:"
def generate(model: Any, tokenizer: Any, prompt: str, *, max_new_tokens: int) -> str:
    """Greedily decode an answer to ``prompt`` and return only the newly generated text.

    If the tokenizer defines a chat template, the prompt is wrapped with
    ``prompt_messages``; otherwise the plain ``prompt_text`` format is used.
    Prompt tokens are sliced off the output before decoding, and special tokens
    and surrounding whitespace are removed.
    """
    import torch

    has_chat_template = bool(getattr(tokenizer, "chat_template", None))
    if has_chat_template:
        inputs = tokenizer.apply_chat_template(
            prompt_messages(prompt),
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
    else:
        inputs = tokenizer(prompt_text(prompt), return_tensors="pt")
    inputs = inputs.to(model.device)
    # Length of the encoded prompt; everything after it in the output is the answer.
    prompt_length = int(inputs["input_ids"].shape[-1])

    with torch.no_grad():
        sequences = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    answer_ids = sequences[0][prompt_length:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
def score_answer(item: dict[str, Any], answer: str) -> dict[str, Any]:
    """Score one model answer with deterministic, case-insensitive keyword heuristics.

    Args:
        item: Eval record. ``prompt`` is scanned for numeric cues, and a ``task``
            of ``"quantitative_qa"`` always requires numeric reasoning.
        answer: The generated answer text.

    Returns:
        A dict with:
        - ``score``: the fraction of four point checks that passed (neutral
          language, fact/inference separation, risk/tradeoff identified, numeric
          reasoning), halved when any critical check fails, rounded to 4 places.
        - ``critical_pass``: True when there is no investment advice, no
          unsupported certainty, and numeric reasoning is present where required.
        - ``checks``: the per-check booleans.
        - ``scoring_mode``: the heuristic identifier.
    """
    text = answer.lower()
    prompt = str(item.get("prompt", "")).lower()
    scores = {
        "no_investment_advice": not bool(re.search(r"\b(buy|sell|hold|short|go long|price target)\b", text)),
        "no_unsupported_certainty": not bool(re.search(r"\b(certainly|guaranteed|will definitely|must happen)\b", text)),
        "neutral_language": bool(re.search(r"\b(risk|could|may|suggests|indicates|monitor|investigate|tradeoff|pressure)\b", text)),
        "fact_inference_separation": bool(re.search(r"\b(reported|fact|while|however|but|suggests|implies|inference)\b", text)),
        # Word stems need a ``\w*`` tail. With a bare trailing ``\b``, "deteriorat"
        # and "declin" can never match "deteriorating"/"declines", and "slow" misses "slowing".
        "risk_or_tradeoff_identified": bool(
            re.search(
                r"\b(risk|tradeoff|pressure|concern|deteriorat\w*|slow\w*|declin\w*|margin|cash flow|leverage|inventory)\b",
                text,
            )
        ),
    }
    numeric_required = str(item.get("task")) == "quantitative_qa" or bool(re.search(r"\bcalculate\b|from \$?\d|to \$?\d|%|\bratio\b", prompt))
    if numeric_required:
        # "%" is not a word character, so ``\b%\b`` never matches "12%". In "2.5x"
        # there is no word boundary before the "x". Match percentages and multiples
        # explicitly, and accept inflections such as "increased" or "ratios".
        has_number = bool(re.search(r"\d", text))
        has_numeric_cue = bool(re.search(r"%|\d\s*x\b|\b(to|from|increase\w*|decrease\w*|ratio\w*|x)\b", text))
        scores["numeric_reasoning_present"] = has_number and has_numeric_cue
    else:
        scores["numeric_reasoning_present"] = True
    critical_pass = (
        scores["no_investment_advice"]
        and scores["no_unsupported_certainty"]
        and scores["numeric_reasoning_present"]
    )
    # Numeric reasoning counts both as a critical check and as a scoring point.
    point_values = [
        scores["neutral_language"],
        scores["fact_inference_separation"],
        scores["risk_or_tradeoff_identified"],
        scores["numeric_reasoning_present"],
    ]
    score = sum(1 for value in point_values if value) / len(point_values)
    if not critical_pass:
        score *= 0.5
    return {
        "score": round(score, 4),
        "critical_pass": critical_pass,
        "checks": scores,
        "scoring_mode": "deterministic_heuristic_v0",
    }
def aggregate(rows: list[dict[str, Any]], prefix: str) -> dict[str, Any]:
    """Summarise one side (e.g. ``"baseline"`` or ``"candidate"``) of the evaluated rows.

    Each row is expected to carry ``task`` and a ``f"{prefix}_score"`` dict with
    ``score`` and ``critical_pass`` entries.

    Returns:
        ``aggregate`` (mean score), ``critical_pass_rate``, ``sample_count`` and
        ``task_scores`` (mean score per task, keys sorted), all means rounded to
        4 places.

    Raises:
        ZeroDivisionError: if ``rows`` is empty.
    """
    score_key = f"{prefix}_score"
    values = [float(row[score_key]["score"]) for row in rows]
    passed = [bool(row[score_key]["critical_pass"]) for row in rows]

    by_task: dict[str, list[float]] = {}
    for row, value in zip(rows, values):
        by_task.setdefault(str(row["task"]), []).append(value)

    def _mean(numbers: list[float]) -> float:
        # Shared rounding for every reported average.
        return round(sum(numbers) / len(numbers), 4)

    return {
        "aggregate": _mean(values),
        "critical_pass_rate": round(passed.count(True) / len(passed), 4),
        "sample_count": len(rows),
        "task_scores": {task: _mean(by_task[task]) for task in sorted(by_task)},
    }
def paired_report(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Compare baseline and candidate scores row by row and build the promotion report.

    Args:
        rows: Evaluated rows. Each needs ``task`` plus ``baseline_score`` and
            ``candidate_score`` dicts shaped like ``score_answer`` output.

    Returns:
        A report dict with:
        - ``baseline`` / ``candidate``: aggregates.
        - ``improvement``: absolute and percent aggregate deltas, the paired mean
          delta and its standard error, and win/tie/loss counts and rates.
        - ``paired_eval_gate``: the legacy threshold gate.
        - ``model_quality_gate``: the full result of ``evaluate_model_quality_gate``.
        - ``promotion_gate``: a summary whose ``eligible_for_promotion`` is taken
          from the full gate, not the legacy one.

    Raises:
        ZeroDivisionError: if ``rows`` is empty (raised from ``aggregate``).
    """
    baseline = aggregate(rows, "baseline")
    candidate = aggregate(rows, "candidate")
    # Per-sample score difference; positive means the candidate beat the baseline.
    deltas = [float(row["candidate_score"]["score"]) - float(row["baseline_score"]["score"]) for row in rows]
    wins = sum(1 for delta in deltas if delta > 0)
    ties = sum(1 for delta in deltas if delta == 0)
    losses = sum(1 for delta in deltas if delta < 0)
    # Legacy-gate thresholds can be overridden via environment variables; the defaults are the string literals.
    max_loss_rate = float(os.environ.get("SHFT_MAX_PAIRWISE_LOSS_RATE", "0.05"))
    min_critical_pass_rate = float(os.environ.get("SHFT_MIN_CRITICAL_PASS_RATE", "0.50"))
    min_aggregate = float(os.environ.get("SHFT_MIN_CANDIDATE_AGGREGATE", "0.50"))
    critical_not_regressed = candidate["critical_pass_rate"] >= baseline["critical_pass_rate"]
    loss_rate_ok = losses <= len(rows) * max_loss_rate
    # The candidate must clear the absolute floor AND strictly beat the baseline.
    aggregate_ok = candidate["aggregate"] >= min_aggregate and candidate["aggregate"] > baseline["aggregate"]
    critical_ok = candidate["critical_pass_rate"] >= min_critical_pass_rate and critical_not_regressed
    mean_delta = sum(deltas) / len(deltas)
    # Sample variance (n - 1 denominator); max(1, ...) avoids division by zero for a single row.
    variance = sum((delta - mean_delta) ** 2 for delta in deltas) / max(1, len(deltas) - 1)
    # Standard error of the paired mean delta.
    stderr = math.sqrt(variance / len(deltas))
    # Threshold-only gate. It is still reported, but promotion eligibility below comes from the full gate.
    legacy_gate = {
        "eligible_for_promotion": aggregate_ok and loss_rate_ok and critical_ok,
        "rules": [
            f"candidate aggregate must exceed baseline aggregate and be at least {min_aggregate:.2f}",
            f"pairwise loss count must be no more than {max_loss_rate:.0%} of eval samples",
            "critical pass rate must not regress",
            f"candidate critical pass rate must be at least {min_critical_pass_rate:.0%}",
        ],
        "checks": {
            "aggregate_ok": aggregate_ok,
            "loss_rate_ok": loss_rate_ok,
            "critical_not_regressed": critical_not_regressed,
            "critical_minimum_ok": candidate["critical_pass_rate"] >= min_critical_pass_rate,
        },
        "warnings": [
            "baseline aggregate is zero; improvement percent is undefined and candidate quality must be judged by absolute thresholds"
        ]
        if baseline["aggregate"] == 0
        else [],
    }
    report = {
        "baseline": baseline,
        "candidate": candidate,
        "improvement": {
            "aggregate_abs": round(candidate["aggregate"] - baseline["aggregate"], 4),
            # Relative improvement is undefined for a zero baseline, so it is reported as None.
            "aggregate_pct": round(((candidate["aggregate"] - baseline["aggregate"]) / baseline["aggregate"] * 100), 4)
            if baseline["aggregate"]
            else None,
            "critical_pass_rate_abs": round(candidate["critical_pass_rate"] - baseline["critical_pass_rate"], 4),
            "paired_mean_delta": round(mean_delta, 4),
            "paired_delta_stderr": round(stderr, 4),
            "pairwise_win_rate": round(wins / len(deltas), 4),
            "pairwise_tie_rate": round(ties / len(deltas), 4),
            "pairwise_loss_rate": round(losses / len(deltas), 4),
            "wins": wins,
            "ties": ties,
            "losses": losses,
        },
        "paired_eval_gate": legacy_gate,
    }
    # NOTE(review): only the paired-eval report is passed in here. How the gate
    # evaluates the other rules listed below (training budget, corpus coverage,
    # judge and human reports) is not visible in this file.
    full_gate = evaluate_model_quality_gate(paired_eval={"sample_count": len(rows), **report})
    report["model_quality_gate"] = full_gate
    report["promotion_gate"] = {
        "eligible_for_promotion": full_gate["eligible_for_promotion"],
        "quality_signal": full_gate["quality_signal"],
        # Descriptive only; the actual pass/fail results are in "checks"/"errors" from the full gate.
        "rules": [
            "paired model-vs-model eval must pass absolute thresholds",
            "training budget must satisfy production minimums",
            "corpus coverage must retain source records across train/valid/test",
            "model-as-judge rubric report must pass",
            "human spot-check report must approve with no critical failures",
        ],
        "checks": full_gate["checks"],
        "errors": full_gate["errors"],
        "warnings": full_gate["warnings"],
        "paired_eval_gate": legacy_gate,
    }
    return report
def main() -> int:
    """Run the paired baseline-vs-candidate frozen-suite evaluation end to end.

    Steps:
    1. Load up to ``--max-samples`` records from ``--eval-jsonl``.
    2. Generate answers with the baseline model: the 4-bit base model, plus
       ``--baseline-adapter-dir`` if given.
    3. Free the baseline model, then generate with the candidate adapter.
    4. Score both sides with ``score_answer`` and write
       ``paired_predictions.jsonl`` and ``paired_eval_report.json`` under
       ``--output-dir``.

    ``progress.json`` is refreshed and partial predictions are flushed every
    five samples (and at the end of each pass). If anything fails after the
    first heartbeat, ``progress.json`` is rewritten with ``status: "failed"``
    and the error before the exception is re-raised. That way a poller never
    waits on a job that died while reporting ``"running"``.

    Returns:
        ``0`` on success.
    """
    parser = argparse.ArgumentParser(description="Run paired Linvest21 frozen-suite evaluation on Hugging Face Jobs.")
    parser.add_argument("--run-id", required=True)
    parser.add_argument("--base-model-id", required=True)
    parser.add_argument("--baseline-adapter-dir")
    parser.add_argument("--candidate-adapter-dir", required=True)
    parser.add_argument("--eval-jsonl", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--max-new-tokens", type=int, default=160)
    parser.add_argument("--max-samples", type=int, default=120)
    args = parser.parse_args()
    # Slicing with 0 yields an empty suite, which later fails with ZeroDivisionError
    # in the report. A negative value silently drops records from the end.
    # Reject both before any model is loaded.
    if args.max_samples < 1:
        parser.error("--max-samples must be a positive integer")
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be a positive integer")

    # Heavy ML dependencies are imported only after argument parsing succeeds.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    output_dir = Path(args.output_dir)
    progress_path = output_dir / "progress.json"
    partial_path = output_dir / "paired_predictions.partial.jsonl"
    rows = load_jsonl(Path(args.eval_jsonl))[: args.max_samples]
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, use_fast=True)
    if tokenizer.pad_token is None:
        # generate() passes pad_token_id; fall back to EOS when the tokenizer has no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    quantization = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)

    def load_eval_model(adapter_dir: str | None = None) -> Any:
        """Load the quantized base model, optionally wrapped with a non-trainable PEFT adapter, in eval mode."""
        model = AutoModelForCausalLM.from_pretrained(
            args.base_model_id,
            quantization_config=quantization,
            device_map="auto",
            trust_remote_code=False,
        )
        if adapter_dir:
            model = PeftModel.from_pretrained(model, adapter_dir, is_trainable=False)
        model.eval()
        return model

    progress: dict[str, Any] = {
        "run_id": args.run_id,
        "status": "running",
        "stage": "baseline_generation",
        "completed_baseline": 0,
        "completed_candidate": 0,
        "total": len(rows),
    }
    write_json(progress_path, progress)
    try:
        baseline_model = load_eval_model(args.baseline_adapter_dir)
        evaluated: list[dict[str, Any]] = []
        for idx, item in enumerate(rows, start=1):
            baseline_answer = generate(baseline_model, tokenizer, str(item["prompt"]), max_new_tokens=args.max_new_tokens)
            evaluated.append(
                {
                    "id": item["id"],
                    "task": item["task"],
                    "prompt": item["prompt"],
                    "baseline_answer": baseline_answer,
                    "baseline_score": score_answer(item, baseline_answer),
                }
            )
            if idx % 5 == 0 or idx == len(rows):
                progress = {**progress, "completed_baseline": idx}
                write_json(progress_path, progress)
                write_jsonl(partial_path, evaluated)
        # Release the baseline model and cached GPU memory before loading the candidate.
        del baseline_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        candidate_model = load_eval_model(args.candidate_adapter_dir)
        progress = {**progress, "stage": "candidate_generation"}
        write_json(progress_path, progress)
        for idx, (row, item) in enumerate(zip(evaluated, rows, strict=True), start=1):
            candidate_answer = generate(candidate_model, tokenizer, str(item["prompt"]), max_new_tokens=args.max_new_tokens)
            row["candidate_answer"] = candidate_answer
            row["candidate_score"] = score_answer(item, candidate_answer)
            row["delta"] = round(float(row["candidate_score"]["score"]) - float(row["baseline_score"]["score"]), 4)
            if idx % 5 == 0 or idx == len(rows):
                progress = {**progress, "completed_candidate": idx}
                write_json(progress_path, progress)
                write_jsonl(partial_path, evaluated)
        report = {
            "run_id": args.run_id,
            "status": "completed",
            "completed_at": utc_now(),
            "base_model_id": args.base_model_id,
            "baseline_adapter_dir": args.baseline_adapter_dir,
            "candidate_adapter_dir": args.candidate_adapter_dir,
            "eval_jsonl": args.eval_jsonl,
            "scoring_mode": "deterministic_heuristic_v0",
            "sample_count": len(evaluated),
            **paired_report(evaluated),
        }
        write_jsonl(output_dir / "paired_predictions.jsonl", evaluated)
        write_json(output_dir / "paired_eval_report.json", report)
        write_json(progress_path, {**progress, "status": "completed", "stage": "completed"})
    except Exception as exc:
        # Leave a terminal status so pollers do not keep waiting on a dead "running" job.
        write_json(
            progress_path,
            {
                **progress,
                "status": "failed",
                "error": f"{type(exc).__name__}: {exc}",
                "failed_at": utc_now(),
            },
        )
        raise
    print(json.dumps(report, indent=2, sort_keys=True))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

Xet Storage Details

Size:
14 kB
·
Xet hash:
e45d69a92d6b18286d14c206fac1bab749ba7e3da8700b81efffcf14a3f12c35

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.