| """ |
| Evaluate the SFT model and run RL continuation. |
| |
| This script: |
| 1. Evaluates the SFT checkpoint from our full_pipeline run |
| 2. Continues RL training from the SFT state checkpoint |
| 3. Evaluates the final RL model |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| import time |
| import numpy as np |
| from collections import Counter |
| from datetime import datetime |
| from dotenv import load_dotenv |
|
|
# Load .env (API keys/credentials) BEFORE importing tinker below, which may
# read credentials from the environment at import time.
load_dotenv()
|
|
| import tinker |
| from tinker import types |
| from tinker_cookbook import renderers |
| from tinker_cookbook.tokenizer_utils import get_tokenizer |
|
|
| |
# Base model / LoRA configuration — must match the SFT run being continued.
BASE_MODEL = "meta-llama/Llama-3.1-8B"
LORA_RANK = 32


# Checkpoints produced by the earlier SFT phase:
#   - STATE checkpoint: full training state, used to resume RL training
#   - SAMPLER checkpoint: weights-only, used for evaluation sampling
SFT_STATE_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/weights/sft_final"
SFT_SAMPLER_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/sft_final_sampler"


# RL hyperparameters.
RL_ITERATIONS = 10    # sample -> train rounds
RL_BATCH_SIZE = 16    # prompts drawn per iteration
RL_GROUP_SIZE = 4     # completions sampled per prompt
RL_LR = 1e-5          # Adam learning rate for the RL phase
RL_TEMPERATURE = 0.7  # rollout sampling temperature


# Processed dataset locations (JSON lists of chat examples).
TRAIN_DATA_PATH = "training/processed_data/train_data.json"
TEST_DATA_PATH = "training/processed_data/test_data.json"
|
|
| |
# Closed set of memory-routing categories the model may emit. Anything
# outside this set is silently dropped when parsing a prediction.
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none"
}


# Persistence class per category; drives the temporal reward component.
CATEGORY_PERSISTENCE = {
    "company.brand_core": "long", "company.strategic_signatures": "long",
    "company.knowledge_artifacts": "long", "company.business_priorities": "short",
    "company.tools_config": "medium", "company.performance_context": "rolling",
    "user.communication_style": "long", "user.strategic_approach": "long",
    "user.role_context": "medium", "user.workflow_patterns": "medium",
    "user.session_history": "short", "user.interaction_preferences": "evolving",
    "none": "short"
}


def compute_reward(predicted_text: str, gold_categories: list) -> tuple:
    """Score a predicted category string against the gold labels.

    The reward is a weighted sum of four components:
      * r_f1     (0.6) - set-level F1 between predicted and gold categories
      * r_temp   (0.2) - 1.0 if the majority persistence class matches
      * r_parity (0.1) - 1.0 if the company/user/mixed scope matches
      * r_eff    (0.1) - efficiency score penalizing > 3 predicted categories

    Args:
        predicted_text: Raw model output; expected to be a comma-separated
            list of category names (case-insensitive).
        gold_categories: Ground-truth category names for the example.

    Returns:
        (reward, info): reward is -1.0 for unparseable output, otherwise in
        [0, 1]; info is a per-component breakdown dict including a
        "format_valid" flag.
    """
    info = {"format_valid": True, "r_f1": 0, "r_temp": 0, "r_parity": 0, "r_eff": 0}

    # Empty / whitespace-only output is a hard format failure.
    if not predicted_text or not predicted_text.strip():
        info["format_valid"] = False
        return -1.0, info

    # Parse: keep only recognized category names.
    predicted = {c.strip().lower() for c in predicted_text.split(",")
                 if c.strip().lower() in VALID_CATEGORIES}

    if not predicted:
        info["format_valid"] = False
        return -1.0, info

    # "none" is only meaningful alone; drop it when real categories appear.
    if "none" in predicted and len(predicted) > 1:
        predicted.discard("none")

    gold = {c.lower() for c in gold_categories}

    # --- r_f1: set-level F1. `predicted` is guaranteed non-empty here, so
    # the former `elif not predicted and not gold` branch was unreachable
    # dead code and has been removed (no behavior change).
    if gold:
        tp = len(predicted & gold)
        precision = tp / len(predicted)
        recall = tp / len(gold)
        info["r_f1"] = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # --- r_temp: do the majority persistence classes agree? ---
    def majority_persistence(cats):
        if not cats:
            return "medium"
        persis = [CATEGORY_PERSISTENCE.get(c, "medium") for c in cats]
        # Counter.most_common breaks ties by first-encountered order;
        # kept verbatim to preserve that tie-breaking behavior.
        return Counter(persis).most_common(1)[0][0]

    if majority_persistence(predicted) == majority_persistence(gold):
        info["r_temp"] = 1.0

    # --- r_parity: does the company/user/mixed scope agree? ---
    def get_scope(cats):
        scopes = set()
        for c in cats:
            if c.startswith("company."):
                scopes.add("company")
            elif c.startswith("user."):
                scopes.add("user")
        if len(scopes) == 2:
            return "mixed"
        return scopes.pop() if scopes else "none"

    if get_scope(predicted) == get_scope(gold):
        info["r_parity"] = 1.0

    # --- r_eff: discourage over-prediction (soft cap at 3 categories) ---
    n = len(predicted)
    info["r_eff"] = 1.0 if n <= 3 else (0.7 if n == 4 else 0.4)

    r_total = 0.6 * info["r_f1"] + 0.2 * info["r_temp"] + 0.1 * info["r_parity"] + 0.1 * info["r_eff"]
    return r_total, info
|
|
|
|
async def evaluate_model(service_client, checkpoint, tokenizer, renderer, test_data, name, n_samples=100):
    """Evaluate a model checkpoint on up to `n_samples` held-out examples.

    Samples near-greedily (temperature 0.1) from the checkpoint, parses each
    response into a category set, and aggregates match/F1/reward metrics.

    Args:
        service_client: tinker ServiceClient used to build a sampling client.
        checkpoint: tinker:// path of the sampler weights to evaluate.
        tokenizer: Unused here; kept for signature compatibility with callers.
        renderer: Chat renderer for prompt building / response parsing.
        test_data: List of examples with "messages" and "categories" keys.
        name: Human-readable label used only for logging.
        n_samples: Maximum number of examples to evaluate.

    Returns:
        Dict of aggregate metrics: any_match, exact_match, precision, recall,
        mean_reward, format_valid, f1. All zeros when no example was usable.
    """
    print(f"\n{'='*60}")
    print(f"EVALUATING: {name}")
    print(f"{'='*60}")

    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    stop_sequences = renderer.get_stop_sequences()

    results = []

    for i, example in enumerate(test_data[:n_samples]):
        gold = example.get("categories", [])
        messages = example.get("messages", [])
        # Keep only the prompt side; the assistant turn is what the model
        # must produce.
        prompt_messages = [m for m in messages if m.get("role") != "assistant"]

        if not prompt_messages:
            continue

        prompt = renderer.build_generation_prompt(prompt_messages)
        params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)

        # NOTE: .result() blocks the event loop; acceptable for this batch
        # evaluation script, but not safe for concurrent async use.
        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"] if success else ""

        predicted_set = {c.strip().lower() for c in predicted_text.split(",")
                         if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        reward, info = compute_reward(predicted_text, gold)

        results.append({
            "any_match": len(predicted_set & gold_set) > 0,
            "exact_match": predicted_set == gold_set,
            "precision": len(predicted_set & gold_set) / len(predicted_set) if predicted_set else 0,
            "recall": len(predicted_set & gold_set) / len(gold_set) if gold_set else 0,
            "reward": reward,
            "format_valid": info["format_valid"]
        })

        if (i + 1) % 25 == 0:
            any_match = np.mean([r["any_match"] for r in results])
            print(f"  Progress: {i+1}/{n_samples}, Any Match: {any_match:.1%}")

    # FIX: guard the empty case. Previously np.mean([]) emitted a
    # RuntimeWarning and returned nan, poisoning every metric when no
    # example had usable prompt messages (or n_samples <= 0).
    if not results:
        print(f"\nNo usable examples for {name}; returning zeroed metrics.")
        return {k: 0.0 for k in ("any_match", "exact_match", "precision",
                                 "recall", "mean_reward", "format_valid", "f1")}

    metrics = {
        "any_match": np.mean([r["any_match"] for r in results]),
        "exact_match": np.mean([r["exact_match"] for r in results]),
        "precision": np.mean([r["precision"] for r in results]),
        "recall": np.mean([r["recall"] for r in results]),
        "mean_reward": np.mean([r["reward"] for r in results]),
        "format_valid": np.mean([r["format_valid"] for r in results]),
    }
    metrics["f1"] = 2 * metrics["precision"] * metrics["recall"] / (metrics["precision"] + metrics["recall"]) if (metrics["precision"] + metrics["recall"]) > 0 else 0

    print(f"\nResults for {name}:")
    print(f"  Any Match: {metrics['any_match']:.1%}")
    print(f"  Exact Match: {metrics['exact_match']:.1%}")
    print(f"  F1 Score: {metrics['f1']:.1%}")
    print(f"  Mean Reward: {metrics['mean_reward']:.3f}")

    return metrics
|
|
|
|
async def run_rl_phase(service_client, training_client, tokenizer, renderer, train_data):
    """Run RL training phase.

    Resumes from SFT_STATE_CHECKPOINT, then performs RL_ITERATIONS rounds of:
    snapshot current weights -> sample rollouts -> score with compute_reward
    -> batch-normalize rewards into advantages -> one importance-sampling
    gradient step. Returns the final sampler checkpoint path and a
    per-iteration metrics log.
    """
    print(f"\n{'='*60}")
    print("PHASE 2: REINFORCEMENT LEARNING")
    print(f"{'='*60}")

    # Resume optimizer/weight state from the SFT run.
    print(f"Loading SFT state from: {SFT_STATE_CHECKPOINT}")
    await training_client.load_state_async(SFT_STATE_CHECKPOINT)
    print("SFT weights loaded successfully!")

    stop_sequences = renderer.get_stop_sequences()
    metrics_log = []

    for iteration in range(RL_ITERATIONS):
        iter_start = time.time()
        print(f"\n--- RL Iteration {iteration + 1}/{RL_ITERATIONS} ---")

        # Snapshot the current policy weights so this iteration's rollouts
        # are sampled on-policy.
        save_future = await training_client.save_weights_for_sampler_async(
            name=f"rl_iter_{iteration:03d}"
        )
        save_result = await save_future.result_async()
        sampling_client = service_client.create_sampling_client(model_path=save_result.path)

        # Draw a batch of distinct prompts for this iteration.
        batch_indices = np.random.choice(len(train_data), size=RL_BATCH_SIZE, replace=False)

        all_rollouts = []
        all_rewards = []

        for idx in batch_indices:
            example = train_data[idx]
            gold_categories = example.get("categories", [])
            messages = example.get("messages", [])
            # Prompt side only; the assistant turn is what we sample.
            prompt_messages = [m for m in messages if m.get("role") != "assistant"]

            if not prompt_messages:
                continue

            prompt = renderer.build_generation_prompt(prompt_messages)
            params = types.SamplingParams(
                max_tokens=50, temperature=RL_TEMPERATURE, stop=stop_sequences
            )

            # RL_GROUP_SIZE completions per prompt; all rollouts in the
            # batch are normalized jointly below.
            result = sampling_client.sample(
                prompt=prompt, sampling_params=params, num_samples=RL_GROUP_SIZE
            ).result()

            for seq in result.sequences:
                response, success = renderer.parse_response(seq.tokens)
                predicted = response["content"] if success else ""
                reward, _ = compute_reward(predicted, gold_categories)

                all_rollouts.append({
                    "prompt": prompt,
                    "tokens": seq.tokens,
                    "logprobs": seq.logprobs or [],
                    "predicted": predicted,
                    "gold": gold_categories
                })
                all_rewards.append(reward)

        # Batch-level advantage normalization (zero mean, ~unit variance);
        # the epsilon guards against zero std when all rewards are equal.
        rewards_arr = np.array(all_rewards)
        mean_reward = rewards_arr.mean()
        std_reward = rewards_arr.std() + 1e-8
        advantages = (rewards_arr - mean_reward) / std_reward

        # Convert rollouts to next-token-prediction datums: input is
        # prompt + generation shifted one token left, target is shifted
        # one token right (teacher forcing).
        training_data = []
        for i, rollout in enumerate(all_rollouts):
            # Skip rollouts where the sampler returned no logprobs.
            if not rollout["logprobs"]:
                continue

            prompt_tokens = rollout["prompt"].to_ints()
            gen_tokens = rollout["tokens"]
            logprobs = rollout["logprobs"]
            adv = advantages[i]

            n_prompt = len(prompt_tokens) - 1
            n_gen = len(gen_tokens)

            # Require exactly one sampler logprob per generated token.
            if len(logprobs) != n_gen:
                continue

            # NOTE: ternary precedence makes this
            # (prompt_tokens + gen_tokens[:-1]) if n_gen > 1 else prompt_tokens;
            # both arms keep len(input) == len(prompt) + n_gen - 1.
            full_input = prompt_tokens + gen_tokens[:-1] if n_gen > 1 else prompt_tokens
            full_target = prompt_tokens[1:] + gen_tokens
            # Zero logprobs/advantages on prompt positions so the loss only
            # applies to generated tokens.
            full_logprobs = [0.0] * n_prompt + logprobs
            full_advantages = [0.0] * n_prompt + [adv] * n_gen

            # Defensive length check; drop any misaligned rollout.
            if len(full_target) != len(full_input) or len(full_logprobs) != len(full_input):
                continue

            training_data.append(types.Datum(
                model_input=types.ModelInput.from_ints(full_input),
                loss_fn_inputs=dict(
                    target_tokens=full_target,
                    logprobs=full_logprobs,
                    advantages=full_advantages
                )
            ))

        # One importance-sampling policy-gradient step over the whole batch.
        # Both futures are created before awaiting either result.
        if training_data:
            fwd_future = await training_client.forward_backward_async(
                training_data, loss_fn="importance_sampling"
            )
            optim_future = await training_client.optim_step_async(
                types.AdamParams(learning_rate=RL_LR, beta1=0.9, beta2=0.95, eps=1e-8)
            )
            await fwd_future.result_async()
            await optim_future.result_async()

        iter_time = time.time() - iter_start
        # "accuracy" = fraction of rollouts with strictly positive reward.
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0

        metrics = {
            "iteration": iteration,
            "mean_reward": float(mean_reward),
            "accuracy": accuracy,
            "num_rollouts": len(all_rollouts),
            "time": iter_time
        }
        metrics_log.append(metrics)

        print(f"  Reward: {mean_reward:.3f}, Accuracy: {accuracy:.1%}, Time: {iter_time:.1f}s")

    # Persist the final RL policy as sampler weights for evaluation.
    print("\nSaving final RL checkpoint...")
    final_future = await training_client.save_weights_for_sampler_async(name="rl_final")
    final_result = await final_future.result_async()
    rl_checkpoint = final_result.path

    print(f"RL checkpoint: {rl_checkpoint}")

    return rl_checkpoint, metrics_log
|
|
|
|
async def main():
    """Orchestrate the pipeline: eval SFT -> RL continuation -> eval RL.

    Side effects: prints progress throughout and writes a timestamped JSON
    results file under training/experiments/.
    """
    print("=" * 70)
    print("MEMORY ROUTING AGENT - EVALUATION & RL CONTINUATION")
    print("=" * 70)
    print(f"Timestamp: {datetime.now()}")
    print(f"Base Model: {BASE_MODEL}")
    print(f"SFT State Checkpoint: {SFT_STATE_CHECKPOINT}")

    # Shared service client, tokenizer, and llama3 chat renderer.
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    # Load processed datasets (JSON lists of chat examples).
    with open(TRAIN_DATA_PATH, "r") as f:
        train_data = json.load(f)
    with open(TEST_DATA_PATH, "r") as f:
        test_data = json.load(f)

    print(f"Train: {len(train_data)}, Test: {len(test_data)}")

    # Baseline: evaluate the SFT checkpoint before any RL.
    sft_metrics = await evaluate_model(
        service_client, SFT_SAMPLER_CHECKPOINT, tokenizer, renderer, test_data, "SFT Model", n_samples=100
    )

    # Fresh LoRA training client; run_rl_phase loads the SFT state into it.
    training_client = await service_client.create_lora_training_client_async(
        base_model=BASE_MODEL,
        rank=LORA_RANK,
    )

    rl_checkpoint, rl_metrics = await run_rl_phase(
        service_client, training_client, tokenizer, renderer, train_data
    )

    # Evaluate the RL-tuned model on the same test split for comparison.
    rl_eval_metrics = await evaluate_model(
        service_client, rl_checkpoint, tokenizer, renderer, test_data, "RL Model", n_samples=100
    )

    # Summary: SFT vs RL.
    print("\n" + "=" * 70)
    print("TRAINING COMPLETE - SUMMARY")
    print("=" * 70)
    print(f"\nSFT Model:")
    print(f"  Checkpoint: {SFT_SAMPLER_CHECKPOINT}")
    print(f"  Any Match: {sft_metrics['any_match']:.1%}")
    print(f"  F1 Score: {sft_metrics['f1']:.1%}")

    print(f"\nRL Model:")
    print(f"  Checkpoint: {rl_checkpoint}")
    print(f"  Any Match: {rl_eval_metrics['any_match']:.1%}")
    print(f"  F1 Score: {rl_eval_metrics['f1']:.1%}")

    improvement = rl_eval_metrics['any_match'] - sft_metrics['any_match']
    print(f"\nImprovement: {improvement:+.1%}")

    # Persist everything; default=str stringifies non-JSON-native values
    # (e.g. numpy scalars) rather than raising.
    results = {
        "sft_checkpoint": SFT_SAMPLER_CHECKPOINT,
        "rl_checkpoint": rl_checkpoint,
        "sft_metrics": sft_metrics,
        "rl_metrics": rl_eval_metrics,
        "rl_training_log": rl_metrics
    }

    results_path = f"training/experiments/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    print(f"\nResults saved to: {results_path}")

    return results
|
|
|
|
# Script entry point: run the full evaluate -> RL -> evaluate pipeline.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|
|