"""
Evaluate a student checkpoint against the frozen teacher using the same
single-process sharded setup and fixed eval cache as distill_sharded.py.
"""
|
|
from __future__ import annotations

import argparse
from pathlib import Path

import torch

import distill_sharded as ds
|
|
|
|
def main() -> None:
    """Evaluate a student checkpoint against the frozen teacher and print the KL.

    Loads the distillation config, applies optional CLI overrides for the
    student checkpoint path and eval sample count, rebuilds the same
    single-process sharded student/teacher setup as distill_sharded.py,
    evaluates on the fixed eval cache, and prints the KL divergence to stdout.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--student", default=None, help="Optional student override path")
    p.add_argument("--samples", type=int, default=None, help="Optional eval sample override")
    args = p.parse_args()

    cfg = ds.load_config(args.config)
    # Compare against None (not truthiness) so explicit falsy values
    # (e.g. an empty string or 0) are not silently ignored.
    if args.student is not None:
        cfg["model"]["student"] = args.student
    if args.samples is not None:
        cfg["eval"]["samples"] = args.samples

    student_device = torch.device(cfg["model"]["student_device"])
    teacher_devices = list(cfg["model"]["teacher_devices"])

    # Imported lazily: transformers is heavy and only needed at runtime.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["tokenizer"], trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so padding is always defined for batch collation.
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id

    student = ds.load_student(
        cfg["model"]["student"],
        ds.parse_dtype(cfg["train"]["student_dtype"]),
        grad_ckpt=False,  # eval only: gradient checkpointing is pointless here
        attn_impl=cfg["train"]["attn_implementation"],
    )
    student.to(student_device)
    student.eval()

    teacher = ds.load_teacher(
        cfg["model"]["teacher"],
        ds.parse_dtype(cfg["train"]["teacher_dtype"]),
        attn_impl=cfg["train"]["attn_implementation"],
        devices=teacher_devices,
        max_mem_gb=cfg["model"]["teacher_max_memory_gb"],
    )
    teacher_input_device, _ = ds.get_teacher_devices(teacher)

    # Reuse the cached eval batches when present so results stay comparable
    # across checkpoints; otherwise build the cache once from the streaming mix.
    if Path(cfg["eval"]["cache_path"]).exists():
        eval_batches = ds.build_or_load_eval_cache(cfg["eval"]["cache_path"])
    else:
        # Dataset specs are only needed to materialize a fresh cache,
        # so build them on this branch only.
        specs = ds.build_dataset_specs(cfg["data"])
        eval_loader = ds.MixedStreamingLoader(
            specs=specs,
            tokenizer=tokenizer,
            min_chars=cfg["data"]["min_chars"],
            max_seq_len=cfg["data"]["max_seq_len"],
            kl_start_pos=cfg["data"]["kl_start_pos"],
            seed=cfg["eval"]["seed"],
            shuffle_buffer=cfg["data"]["shuffle_buffer"],
        )
        eval_batches = ds.build_or_load_eval_cache(
            cfg["eval"]["cache_path"],
            eval_loader,
            cfg["eval"]["samples"],
        )
    kl = ds.evaluate(
        student,
        teacher,
        eval_batches,
        pad_id,
        cfg["data"]["kl_start_pos"],
        cfg["train"]["kl_chunk_size"],
        student_device,
        teacher_input_device,
    )
    print(f"{kl:.6f}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|