Mentors4EDU's picture
Upload 41 files
3f2dde4 verified
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
@dataclass(slots=True)
class TrainerConfig:
model_name: str = "OpenPeerAI/OpenPeerLLM"
fallback_model_name: str = "sshleifer/tiny-gpt2"
gates: int = 512
steps: int = 40
out_path: str = "runs/openpeer_controller.pt"
train_jsonl: str | None = None
device: str = "auto"
demo_mode: bool = False
def _build_demo_records() -> list[dict[str, str]]:
return [
{"prompt": "Question: 14 + 27 = ?\nAnswer:", "completion": " 41"},
{"prompt": "Question: 36 + 18 = ?\nAnswer:", "completion": " 54"},
{"prompt": "Question: 47 + 36 = ?\nAnswer:", "completion": " 83"},
{"prompt": "Question: 19 + 8 = ?\nAnswer:", "completion": " 27"},
]
def fit_controller(config: TrainerConfig) -> str:
    """Fit a controller on JSONL examples and save it to ``config.out_path``.

    Steps performed:

    1. Import ``transformers`` and ``ntkmirror`` lazily.
    2. Resolve the training file, writing the built-in demo records to it if
       it does not exist yet.
    3. Load the tokenizer and causal LM named by ``config.model_name``; in
       demo mode, fall back to ``config.fallback_model_name`` on failure.
    4. Run ``ForwardFineTuner.fit`` and save the result.

    Args:
        config: Training settings. ``config.device`` is not consulted here.

    Returns:
        The output path, as a string, that was passed to ``tuner.save``.

    Raises:
        RuntimeError: If ``transformers`` or ``ntkmirror`` cannot be imported.
        Exception: Whatever the primary load raised when ``config.demo_mode``
            is false, or whatever the fallback load raises in demo mode.
    """
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from ntkmirror import ForwardFineTuner, load_jsonl_examples
    except ImportError as exc:  # pragma: no cover - dependency gate
        raise RuntimeError(
            "ntkmirror mode requires transformers, torch, and ntkmirror to be installed"
        ) from exc

    train_path = Path(config.train_jsonl) if config.train_jsonl else Path("runs/demo_train.jsonl")
    if not train_path.exists():
        # NOTE: this also applies when a caller-supplied ``train_jsonl`` path
        # is missing -- the demo records are written to that path.
        train_path.parent.mkdir(parents=True, exist_ok=True)
        import json

        with train_path.open("w", encoding="utf-8") as handle:
            for record in _build_demo_records():
                handle.write(json.dumps(record) + "\n")

    def _load(name: str):
        """Load the tokenizer and causal LM for checkpoint *name*."""
        loaded_tokenizer = AutoTokenizer.from_pretrained(name)
        loaded_model = AutoModelForCausalLM.from_pretrained(name, torch_dtype="auto")
        return loaded_model, loaded_tokenizer

    # The tokenizer load sits inside the guarded region on purpose: an
    # unavailable primary checkpoint fails at the tokenizer step first, and
    # that failure must also be able to trigger the demo-mode fallback.
    try:
        model, tokenizer = _load(config.model_name)
    except Exception:
        if not config.demo_mode:
            raise
        model, tokenizer = _load(config.fallback_model_name)

    tuner = ForwardFineTuner(model, tokenizer, gates=config.gates)
    tuner.fit(load_jsonl_examples(str(train_path)), steps=config.steps)

    out_path = Path(config.out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tuner.save(str(out_path))
    return str(out_path)