# Uploaded by Delta-Vector with huggingface_hub: "Upload eval_kl.py with huggingface_hub"
# (commit 93462da, verified).
#!/usr/bin/env python3
"""
Evaluate a student checkpoint against the frozen teacher using the same
single-process sharded setup and fixed eval cache as distill_sharded.py.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path

import torch

import distill_sharded as ds
def main():
p = argparse.ArgumentParser()
p.add_argument("--config", required=True)
p.add_argument("--student", default=None, help="Optional student override path")
p.add_argument("--samples", type=int, default=None, help="Optional eval sample override")
args = p.parse_args()
cfg = ds.load_config(args.config)
if args.student:
cfg["model"]["student"] = args.student
if args.samples:
cfg["eval"]["samples"] = args.samples
student_device = torch.device(cfg["model"]["student_device"])
teacher_devices = list(cfg["model"]["teacher_devices"])
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
pad_id = tokenizer.pad_token_id
student = ds.load_student(
cfg["model"]["student"],
ds.parse_dtype(cfg["train"]["student_dtype"]),
grad_ckpt=False,
attn_impl=cfg["train"]["attn_implementation"],
)
student.to(student_device)
student.eval()
teacher = ds.load_teacher(
cfg["model"]["teacher"],
ds.parse_dtype(cfg["train"]["teacher_dtype"]),
attn_impl=cfg["train"]["attn_implementation"],
devices=teacher_devices,
max_mem_gb=cfg["model"]["teacher_max_memory_gb"],
)
teacher_input_device, _ = ds.get_teacher_devices(teacher)
specs = ds.build_dataset_specs(cfg["data"])
if Path(cfg["eval"]["cache_path"]).exists():
eval_batches = ds.build_or_load_eval_cache(cfg["eval"]["cache_path"])
else:
eval_loader = ds.MixedStreamingLoader(
specs=specs,
tokenizer=tokenizer,
min_chars=cfg["data"]["min_chars"],
max_seq_len=cfg["data"]["max_seq_len"],
kl_start_pos=cfg["data"]["kl_start_pos"],
seed=cfg["eval"]["seed"],
shuffle_buffer=cfg["data"]["shuffle_buffer"],
)
eval_batches = ds.build_or_load_eval_cache(
cfg["eval"]["cache_path"],
eval_loader,
cfg["eval"]["samples"],
)
kl = ds.evaluate(
student,
teacher,
eval_batches,
pad_id,
cfg["data"]["kl_start_pos"],
cfg["train"]["kl_chunk_size"],
student_device,
teacher_input_device,
)
print(f"{kl:.6f}")
if __name__ == "__main__":
    # Script entry point: SystemExit(None) exits with status 0, the same as
    # falling off the end of the module after main() returns.
    raise SystemExit(main())