#!/usr/bin/env python3
"""Train Bee LoRA adapters on real instruction data.

Loads pretrained model + instruction datasets, trains LoRA adapters,
saves checkpoint, optionally evaluates before/after.

Usage (MacBook, slow):
    python scripts/train_lora.py --data ./datasets/train_mixed.jsonl --steps 100 --device mps

Usage (GPU cloud):
    python scripts/train_lora.py --data ./datasets/train_mixed.jsonl --steps 1000 --batch_size 4 --device cuda
"""
import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from bee.lora_adapter import DomainLoRAManager, LoRAConfig
from bee.model_profiles import DEFAULT_MODEL_PROFILE, resolve_model_id

logger = logging.getLogger("bee.train")


class InstructionDataset(Dataset):
    """Simple instruction-following dataset from JSONL."""

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(data_path) as f:
            for line in f:
                ex = json.loads(line)
                instruction = ex.get("instruction", "")
                input_text = ex.get("input", "")
                output = ex.get("output", "")

                # Use chat template if available
                if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
                    user_msg = instruction
                    if input_text:
                        user_msg += f"\n\n{input_text}"
                    chat = [
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": output},
                    ]
                    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                else:
                    text = f"### Instruction:\n{instruction}\n### Input:\n{input_text}\n### Response:\n{output}"

                self.samples.append(text)

        logger.info("Loaded %d instruction samples from %s", len(self.samples), data_path)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text = self.samples[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Labels = input_ids for causal LM (shifted internally)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
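
# Expected JSONL record format for --data (one JSON object per line, read by
# InstructionDataset above; "input" may be empty). The values below are purely
# illustrative, not taken from any shipped dataset:
#   {"instruction": "Summarize the passage.", "input": "LoRA freezes the base weights ...", "output": "It trains small low-rank update matrices ..."}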


def train(
    data_path: str,
    model_path: str = DEFAULT_MODEL_PROFILE,
    device: str = "mps",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    steps: int = 100,
    batch_size: int = 1,
    learning_rate: float = 5e-4,
    warmup_steps: int = 10,
    max_length: int = 512,
    save_path: str = "./lora_checkpoints/general",
    eval_before: bool = True,
):
    model_path = resolve_model_id(model_path)

    # Load model
    logger.info("Loading model: %s", model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Use float32 for training (float16 causes NaN on MPS with LoRA)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
    ).to(device)

    # Setup LoRA
    lora_cfg = LoRAConfig(r=lora_r, alpha=lora_alpha, dropout=lora_dropout)
    manager = DomainLoRAManager(model, lora_cfg)
    manager.add_adapter("general")
    manager.activate_domain("general")
    logger.info("LoRA adapters: %d trainable params", manager.count_adapter_params("general"))

    # Load data
    if not os.path.exists(data_path):
        logger.error("Dataset not found: %s", data_path)
        logger.info("Run: python scripts/download_datasets.py")
        return

    dataset = InstructionDataset(data_path, tokenizer, max_length)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer: only LoRA params
    lora_params = []
    for name, module in model.named_modules():
        if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
            lora_params.extend([module.lora_A, module.lora_B])
    optimizer = torch.optim.AdamW(lora_params, lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=steps
    )

    # Training loop
    logger.info("Starting training: %d steps, batch_size=%d, lr=%.1e", steps, batch_size, learning_rate)
    model.train()
    global_step = 0
    epoch = 0
    losses = []

    while global_step < steps:
        epoch += 1
        for batch in loader:
            if global_step >= steps:
                break
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            losses.append(loss.item())
            global_step += 1

            if global_step % 10 == 0:
                avg_loss = sum(losses[-10:]) / min(10, len(losses))
                logger.info("Step %d/%d | loss=%.4f | lr=%.2e", global_step, steps, avg_loss, scheduler.get_last_lr()[0])

    # Save
    os.makedirs(save_path, exist_ok=True)
    manager.save_adapter("general", save_path)
    logger.info("Checkpoint saved: %s", save_path)

    # Save adapter metadata
    meta = {
        "base_model": model_path,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "steps": steps,
        "final_loss": sum(losses[-10:]) / min(10, len(losses)),
        "trainable_params": manager.count_adapter_params("general"),
    }
    with open(os.path.join(save_path, "bee_legacy_adapter_config.json"), "w") as f:
        json.dump(meta, f, indent=2)

    return model, tokenizer, manager


def main():
    parser = argparse.ArgumentParser(description="Train Bee LoRA on real instruction data")
    parser.add_argument("--data", default="./datasets/train_mixed.jsonl", help="Path to instruction JSONL")
    parser.add_argument("--model", default=DEFAULT_MODEL_PROFILE, help="Model profile, local path, or HF ID")
    parser.add_argument("--device", default="mps" if torch.backends.mps.is_available() else "cpu")
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--save_path", default="./lora_checkpoints/general")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )

    train(
        data_path=args.data,
        model_path=args.model,
        device=args.device,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        steps=args.steps,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        save_path=args.save_path,
    )


if __name__ == "__main__":
    main()
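
# Optional post-training smoke test (a minimal sketch, not part of the CLI).
# Assumes the "general" adapter is still active on the returned model, that the
# tokenizer defines a chat template, and that train() did not return early:
#
#   model, tokenizer, manager = train(data_path="./datasets/train_mixed.jsonl", steps=10, device="cpu")
#   model.eval()
#   prompt = tokenizer.apply_chat_template(
#       [{"role": "user", "content": "Write one sentence about bees."}],
#       tokenize=False, add_generation_prompt=True,
#   )
#   inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#   with torch.no_grad():
#       out = model.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))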