bbkdevops's picture
download
raw
6.97 kB
from __future__ import annotations
from datetime import datetime, timezone
import json
from pathlib import Path
import re
import time
from typing import Any
import torch
from model.native_axiom_regenesis import TinyMindAxiomReGenesis, config_from_dict
from runtime.constrained_decode import apply_constrained_repair
from train.native_axiom_regenesis_train import _decode, _encode, _retrieved_from_ids
# Broad capability probe: one prompt per "axis".
#   axis      -- label copied into each report sample
#   prompt    -- text sent to the model
#   must      -- strings that `score` looks for as case-insensitive substrings
#   json_only -- optional; when truthy, `score` also requires the whole
#                response to parse as a JSON object
TASKS: list[dict[str, Any]] = [
    {
        "axis": "thai_language",
        # "Explain the difference between facts, evidence, and conclusions
        #  in natural Thai, with short examples."
        "prompt": "อธิบายความต่างระหว่างข้อมูลจริง หลักฐาน และข้อสรุป เป็นภาษาไทยธรรมชาติ พร้อมตัวอย่างสั้น ๆ",
        "must": ["ข้อมูลจริง", "หลักฐาน", "ข้อสรุป"],
    },
    {
        "axis": "english_reasoning",
        "prompt": "Explain why a low eval loss can still be misleading when the evaluation set is contaminated.",
        "must": ["contamination", "holdout", "generalization"],
    },
    {
        "axis": "math_bound",
        "prompt": "Prove briefly that x_t = a x_{t-1} + b_t is bounded when |a|<1 and |b_t|<=B.",
        "must": ["geometric", "B", "1-|a|"],
    },
    {
        "axis": "math_probability",
        # F1 = 2 * 0.8 * 0.5 / (0.8 + 0.5) = 0.6153...
        "prompt": "A classifier has precision 0.8 and recall 0.5. Define precision and recall, then compute F1.",
        "must": ["precision", "recall", "0.615"],
    },
    {
        "axis": "code_python",
        "prompt": "Write a small Python function validate_tool_call(obj) that returns True only if obj has string name and dict arguments.",
        "must": ["def validate_tool_call", "name", "arguments", "dict"],
    },
    {
        "axis": "raw_code_bits",
        # 6-bit field: mask 0x3f, sign bit 0x20, two's-complement range -32..31.
        "prompt": "Explain signed 6-bit sign extension from a packed byte lane. Include the mask, sign bit, min, and max.",
        "must": ["0x3f", "0x20", "-32", "31"],
    },
    {
        "axis": "systems_ffi",
        "prompt": "Give three rules for safe Rust/C FFI ABI compatibility.",
        "must": ["extern", "repr(C)", "ownership"],
    },
    {
        "axis": "grounding",
        "prompt": "You have one source saying A and another saying not-A. What should an evidence-grounded assistant do before answering?",
        "must": ["compare", "source", "uncertainty"],
    },
    {
        "axis": "tool_json",
        "prompt": 'Return only JSON for a tool call named "search_web" with arguments query="Thai AI benchmark" and k=3.',
        "must": ['"name"', '"search_web"', '"arguments"', '"k"'],
        "json_only": True,
    },
    {
        "axis": "translation_th_en",
        # Source sentence: "Good evaluation must strictly separate training
        # data from test data."
        "prompt": "Translate to English: การวัดผลที่ดีต้องแยกข้อมูลฝึกออกจากข้อมูลทดสอบอย่างเด็ดขาด",
        "must": ["training", "test", "separate"],
    },
    {
        "axis": "long_answer_control",
        # "Answer in Thai, exactly 4 items: how to reduce hallucination in a RAG system."
        # NOTE(review): "1".."4" are plain substring checks, so any response
        # that merely contains those digits satisfies them; the item count
        # itself is not verified.
        "prompt": "ตอบเป็นภาษาไทย 4 ข้อเท่านั้น: วิธีลด hallucination ในระบบ RAG",
        "must": ["1", "2", "3", "4"],
    },
    {
        "axis": "self_critique",
        "prompt": "Give a concise answer, then add one sentence explaining what evidence would falsify your answer.",
        "must": ["falsify", "evidence"],
    },
]
def _repeated_ngrams(text: str) -> bool:
words = re.findall(r"\w+", text.lower(), flags=re.UNICODE)
grams = [" ".join(words[i : i + 5]) for i in range(max(0, len(words) - 4))]
return len(grams) != len(set(grams))
def score(task: dict[str, Any], response: str) -> tuple[int, list[str]]:
    """Grade one model response against a probe task on a 0-4 scale.

    Flags are emitted in this order:
      * ``missing:<t1>,<t2>,...`` -- required ``task["must"]`` strings absent
        from the response (case-insensitive substring match);
      * ``repetition`` -- ``_repeated_ngrams`` reports a repeated word n-gram;
      * ``too_short`` -- fewer than 40 characters after stripping whitespace;
      * ``invalid_json`` / ``json_not_object`` -- only for tasks with a truthy
        ``json_only``, when the raw response is not valid JSON or not an object.

    Each flag costs one point; a ``missing:`` flag costs one additional point.
    The result is floored at 0.

    Returns:
        ``(points, flags)``.
    """
    haystack = response.lower()
    absent = [needle for needle in task["must"] if needle.lower() not in haystack]

    flags: list[str] = []
    if absent:
        flags.append("missing:" + ",".join(absent))
    if _repeated_ngrams(response):
        flags.append("repetition")
    if len(response.strip()) < 40:
        flags.append("too_short")
    if task.get("json_only"):
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError:
            flags.append("invalid_json")
        else:
            if not isinstance(parsed, dict):
                flags.append("json_not_object")

    # Missing required terms are penalised twice: once as a flag, once extra.
    penalty = len(flags) + (1 if absent else 0)
    return max(4 - penalty, 0), flags
def _load_native(checkpoint: str | Path, device: torch.device) -> TinyMindAxiomReGenesis:
    """Rebuild a TinyMindAxiomReGenesis model from a checkpoint file in eval mode.

    The checkpoint payload is read with ``map_location=device`` and must
    provide a ``"config"`` entry (fed to ``config_from_dict``) and a
    ``"state_dict"`` entry (loaded into the freshly built model).
    """
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects; only point this at trusted checkpoint files.
    bundle = torch.load(checkpoint, map_location=device)
    net = TinyMindAxiomReGenesis(config_from_dict(bundle["config"]))
    net = net.to(device)
    net.load_state_dict(bundle["state_dict"])
    net.eval()  # inference only: puts dropout/normalisation layers in eval behaviour
    return net
@torch.no_grad()
def _generate(model: TinyMindAxiomReGenesis, prompt: str, device: torch.device, max_new_tokens: int) -> str:
    """Generate a reply to ``prompt`` and return the decoded continuation, stripped.

    The prompt is wrapped in a fixed SYSTEM/USER/ASSISTANT template, encoded
    with a length cap of ``min(model.cfg.max_seq_len, 192)``, and augmented
    with retrieval tokens from ``_retrieved_from_ids`` before calling
    ``model.generate``. Gradients are disabled by the decorator.
    """
    cfg = model.cfg
    chat = (
        "SYSTEM: Answer exactly. Follow formatting constraints. Avoid repetition.\n"
        f"USER: {prompt}\n"
        "ASSISTANT:"
    )
    encode_limit = min(cfg.max_seq_len, 192)
    prompt_ids = _encode(chat, encode_limit, cfg.vocab_size, cfg.tokenizer_mode)
    prompt_ids = prompt_ids.unsqueeze(0).to(device)  # add a leading batch dimension of size 1
    hits = _retrieved_from_ids(prompt_ids, cfg.regen_top_k)
    generated = model.generate(prompt_ids, max_new_tokens=max_new_tokens, retrieved_tokens=hits)
    # NOTE(review): assumes model.generate returns prompt + new tokens along
    # dim 1, so slicing past the prompt length keeps only the reply -- confirm.
    prompt_len = prompt_ids.shape[1]
    continuation = generated[0, prompt_len:]
    return _decode(continuation, cfg.vocab_size, cfg.tokenizer_mode).strip()
def run_native_broad_probe(
    out_dir: str | Path,
    *,
    native_checkpoint: str | Path,
    controlled_repair: bool = False,
    max_new_tokens: int = 160,
    device: str | None = None,
) -> dict[str, Any]:
    """Run every task in ``TASKS`` against a native checkpoint and write a JSON report.

    Args:
        out_dir: Directory for ``native_broad_probe_report.json``; created if missing.
        native_checkpoint: Checkpoint path passed to ``_load_native``.
        controlled_repair: When True, each raw generation is passed through
            ``apply_constrained_repair`` before scoring, and the report's
            ``claim_gate`` marks the result as not a raw-model capability claim.
        max_new_tokens: Generation budget per task.
        device: Torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        The report dict that was written to disk, including ``json_path``.
        ``elapsed_s`` covers generation and scoring only, not model loading.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    run_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    model = _load_native(native_checkpoint, run_device)

    # perf_counter is monotonic; time.time() follows the wall clock and can jump
    # (NTP sync, manual adjustment), which would corrupt elapsed_s.
    started = time.perf_counter()
    samples: list[dict[str, Any]] = []
    total = 0
    for task in TASKS:
        raw = _generate(model, task["prompt"], run_device, max_new_tokens)
        response, events = apply_constrained_repair(task, raw) if controlled_repair else (raw, [])
        points, flags = score(task, response)
        total += points
        sample: dict[str, Any] = {"axis": task["axis"], "score": points, "flags": flags, "response": response}
        # Keep the pre-repair text for auditing only when repair reported events.
        if events:
            sample["raw_response"] = raw
            sample["constraint_events"] = events
        samples.append(sample)
    elapsed_s = time.perf_counter() - started

    # score() awards at most 4 points per task.
    max_score = len(TASKS) * 4
    report: dict[str, Any] = {
        "schema": "tinymind.native_broad_probe.v1",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "elapsed_s": elapsed_s,
        "native_checkpoint": str(native_checkpoint),
        "controlled_repair_enabled": controlled_repair,
        "score": total,
        "max_score": max_score,
        "percent": 100.0 * total / max(1, max_score),
        "samples": samples,
        "claim_gate": {
            "broad_probe_complete": True,
            "raw_model_capability_claim": not controlled_repair,
            "world_best_claim_allowed": False,
            "reason": "Local broad probe. Controlled repair is runtime-system capability, not raw model capability.",
        },
    }
    path = out / "native_broad_probe_report.json"
    report["json_path"] = str(path)
    path.write_text(json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    return report

Xet Storage Details

Size:
6.97 kB
·
Xet hash:
fdf53f66f8b007236c1ad75bae7b9eea9f0a3ef80789197278e90acf691c48c2

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.