| | """ |
| | Training Monitor - Check progress and evaluate completed models. |
| | """ |

import asyncio
import json
import sys

from dotenv import load_dotenv

# Load environment variables (e.g., the Tinker API key) before any clients are created.
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

# Base model used to build the tokenizer and chat renderer for evaluation.
BASE_MODEL = "meta-llama/Llama-3.1-8B"

# Label set the classifier may emit; predictions outside this set are discarded.
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}


def list_training_runs():
    """List all training runs and their checkpoints."""
    service_client = tinker.ServiceClient()
    rest_client = service_client.create_rest_client()

    runs = rest_client.list_training_runs().result()

    print("=" * 70)
    print("TRAINING RUNS")
    print("=" * 70)

    for run in runs.training_runs[:10]:
        ckpts = rest_client.list_checkpoints(run.training_run_id).result()

        # Split checkpoints by training stage, based on the checkpoint naming convention.
        sft_ckpts = [c for c in ckpts.checkpoints if "sft" in c.checkpoint_id]
        rl_ckpts = [c for c in ckpts.checkpoints if "rl_" in c.checkpoint_id]

        print(f"\nRun: {run.training_run_id}")
        print(f"  Last request: {run.last_request_time}")
        print(f"  SFT checkpoints: {len(sft_ckpts)}")
        print(f"  RL checkpoints: {len(rl_ckpts)}")

        if rl_ckpts:
            # Most recent RL checkpoint by creation time.
            latest = max(rl_ckpts, key=lambda c: c.time)
            print(f"  Latest RL: {latest.checkpoint_id}")

            # A 'final' checkpoint means the RL stage ran to completion.
            if "final" in latest.checkpoint_id:
                print("  STATUS: RL COMPLETE")
                print(f"  Final checkpoint: tinker://{run.training_run_id}/{latest.checkpoint_id}")


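# Each test example is assumed to look like the following (inferred from the
# field accesses in quick_eval; not confirmed against the data-prep code):
#   {"messages": [{"role": "user", "content": "..."}, ...],
#    "categories": ["company.brand_core", ...]}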
async def quick_eval(checkpoint_path: str, n_samples: int = 20):
    """Quick evaluation of a checkpoint on the held-out test set."""
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    # Load the held-out test examples.
    with open("training/processed_data/test_data.json", "r") as f:
        test_data = json.load(f)

    print(f"\nEvaluating: {checkpoint_path}")
    print(f"Samples: {n_samples}")

    sampling_client = service_client.create_sampling_client(model_path=checkpoint_path)
    stop_sequences = renderer.get_stop_sequences()

    correct = 0
    total = 0

    for example in test_data[:n_samples]:
        gold = example.get("categories", [])
        messages = example.get("messages", [])
        # Keep only the prompt side of the conversation; the assistant turn is the label.
        prompt_messages = [m for m in messages if m.get("role") != "assistant"]

        if not prompt_messages:
            continue

        prompt = renderer.build_generation_prompt(prompt_messages)
        # Low temperature keeps the classifier output close to deterministic.
        params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)

        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"] if success else ""

        # Parse the comma-separated prediction, discarding anything outside the label set.
        predicted_set = {c.strip().lower() for c in predicted_text.split(",")
                         if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        # "Any match": the example counts as correct if at least one predicted
        # category appears among the gold categories.
        if predicted_set & gold_set:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Any Match Accuracy: {accuracy:.1%} ({correct}/{total})")

    return accuracy


def find_best_checkpoint():
    """Find the most recent final SFT and RL checkpoints across all runs."""
    service_client = tinker.ServiceClient()
    rest_client = service_client.create_rest_client()

    runs = rest_client.list_training_runs().result()

    best_rl_checkpoint = None
    best_sft_checkpoint = None

    for run in runs.training_runs:
        ckpts = rest_client.list_checkpoints(run.training_run_id).result()

        for ckpt in ckpts.checkpoints:
            # Track the newest (path, time) pair for each checkpoint flavor.
            if "rl_final" in ckpt.checkpoint_id:
                path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"
                if best_rl_checkpoint is None or ckpt.time > best_rl_checkpoint[1]:
                    best_rl_checkpoint = (path, ckpt.time)

            if "sft_final_sampler" in ckpt.checkpoint_id:
                path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"
                if best_sft_checkpoint is None or ckpt.time > best_sft_checkpoint[1]:
                    best_sft_checkpoint = (path, ckpt.time)

    return best_sft_checkpoint, best_rl_checkpoint


async def main():
    if len(sys.argv) > 1 and sys.argv[1] == "eval":
        # Evaluate the best SFT and RL checkpoints on the test set.
        sft_ckpt, rl_ckpt = find_best_checkpoint()

        print("=" * 70)
        print("CHECKPOINT EVALUATION")
        print("=" * 70)

        if sft_ckpt:
            print(f"\nBest SFT: {sft_ckpt[0]}")
            await quick_eval(sft_ckpt[0], n_samples=50)

        if rl_ckpt:
            print(f"\nBest RL: {rl_ckpt[0]}")
            await quick_eval(rl_ckpt[0], n_samples=50)
    else:
        # Default: list all runs, then report the best checkpoints found.
        list_training_runs()

        sft_ckpt, rl_ckpt = find_best_checkpoint()

        print("\n" + "=" * 70)
        print("BEST CHECKPOINTS")
        print("=" * 70)

        if sft_ckpt:
            print(f"\nSFT: {sft_ckpt[0]}")
            print(f"  Time: {sft_ckpt[1]}")

        if rl_ckpt:
            print(f"\nRL: {rl_ckpt[0]}")
            print(f"  Time: {rl_ckpt[1]}")

        print("\nTo evaluate, run: python training/monitor.py eval")


if __name__ == "__main__":
    asyncio.run(main())