"""Train Bee LoRA adapters on real instruction data.

Loads a pretrained model and instruction datasets, trains LoRA adapters,
saves a checkpoint, and optionally runs a quick generation check before
and after training.

Usage (MacBook, slow):
    python scripts/train_lora.py --data ./datasets/train_mixed.jsonl --steps 100 --device mps

Usage (GPU cloud):
    python scripts/train_lora.py --data ./datasets/train_mixed.jsonl --steps 1000 --batch_size 4 --device cuda
"""

import argparse
import json
import logging
import os
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

# Make the repo root importable when this file is run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.lora_adapter import DomainLoRAManager, LoRAConfig
from bee.model_profiles import DEFAULT_MODEL_PROFILE, resolve_model_id

logger = logging.getLogger("bee.train")


class InstructionDataset(Dataset):
    """Instruction-following dataset loaded from JSONL.

    Each line is one JSON object with an "instruction" key, an optional
    "input" key, and an "output" key, for example:

        {"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"}
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(data_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines in the JSONL
                ex = json.loads(line)
                instruction = ex.get("instruction", "")
                input_text = ex.get("input", "")
                output = ex.get("output", "")

                # Prefer the tokenizer's own chat template; fall back to an
                # Alpaca-style prompt when no template is defined.
                if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
                    user_msg = instruction
                    if input_text:
                        user_msg += f"\n\n{input_text}"
                    chat = [
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": output},
                    ]
                    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                else:
                    text = f"### Instruction:\n{instruction}\n### Input:\n{input_text}\n### Response:\n{output}"

                self.samples.append(text)

        logger.info("Loaded %d instruction samples from %s", len(self.samples), data_path)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text = self.samples[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Mask padding positions out of the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
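

def quick_eval(model, tokenizer, device, prompt="Briefly explain what a LoRA adapter does."):
    """Tiny generation smoke test run before/after training.

    A minimal sketch of the "optionally evaluates before/after" behavior
    described in the module docstring: greedily generate a short completion
    from one fixed prompt and return it for logging. The prompt and the
    decoding settings here are illustrative assumptions, not project
    defaults.
    """
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    model.train()
    return tokenizer.decode(out[0], skip_special_tokens=True)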


def train(
    data_path: str,
    model_path: str = DEFAULT_MODEL_PROFILE,
    device: str = "mps",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    steps: int = 100,
    batch_size: int = 1,
    learning_rate: float = 2e-4,
    warmup_steps: int = 10,
    max_length: int = 512,
    save_path: str = "./lora_checkpoints/general",
    eval_before: bool = True,
):
    model_path = resolve_model_id(model_path)

    # Load the tokenizer; reuse EOS as the pad token when none is defined.
    logger.info("Loading model: %s", model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
    ).to(device)

    # Attach and activate a LoRA adapter; only its parameters are optimized below.
    lora_cfg = LoRAConfig(r=lora_r, alpha=lora_alpha, dropout=lora_dropout)
    manager = DomainLoRAManager(model, lora_cfg)
    manager.add_adapter("general")
    manager.activate_domain("general")
    logger.info("LoRA adapters: %d trainable params", manager.count_adapter_params("general"))

    if not os.path.exists(data_path):
        logger.error("Dataset not found: %s", data_path)
        logger.info("Run: python scripts/download_datasets.py")
        return

    dataset = InstructionDataset(data_path, tokenizer, max_length)
    if len(dataset) == 0:
        logger.error("Dataset is empty: %s", data_path)
        return  # an empty loader would make the training loop spin forever
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
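
    # Optional quick check before training; `eval_before` also gates the
    # matching post-training check so the two samples can be compared.
    if eval_before:
        logger.info("Pre-training sample: %s", quick_eval(model, tokenizer, device))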

    # Collect only the LoRA parameters for optimization.
    lora_params = []
    for module in model.modules():
        if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
            lora_params.extend([module.lora_A, module.lora_B])
    if not lora_params:
        logger.error("No LoRA parameters found on the model; aborting.")
        return

    optimizer = torch.optim.AdamW(lora_params, lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=steps
    )

    logger.info("Starting training: %d steps, batch_size=%d, lr=%.1e", steps, batch_size, learning_rate)
    model.train()
    global_step = 0
    epoch = 0
    losses = []

    # Re-iterate over the data until the requested number of optimizer steps is reached.
    while global_step < steps:
        epoch += 1
        for batch in loader:
            if global_step >= steps:
                break

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            losses.append(loss.item())
            global_step += 1

            if global_step % 10 == 0:
                avg_loss = sum(losses[-10:]) / min(10, len(losses))
                logger.info("Step %d/%d | loss=%.4f | lr=%.2e", global_step, steps, avg_loss, scheduler.get_last_lr()[0])

    # Save the trained adapter weights.
    os.makedirs(save_path, exist_ok=True)
    manager.save_adapter("general", save_path)
    logger.info("Checkpoint saved: %s", save_path)

    # Record training metadata next to the checkpoint.
    meta = {
        "base_model": model_path,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "steps": steps,
        "final_loss": sum(losses[-10:]) / min(10, len(losses)) if losses else None,
        "trainable_params": manager.count_adapter_params("general"),
    }
    with open(os.path.join(save_path, "bee_legacy_adapter_config.json"), "w") as f:
        json.dump(meta, f, indent=2)

    return model, tokenizer, manager
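
# To use the saved checkpoint later, the restore side should mirror the save
# call above. DomainLoRAManager's loading API is not shown in this script, so
# the sketch below is hypothetical -- check bee/lora_adapter.py for the actual
# method name:
#
#     manager = DomainLoRAManager(model, LoRAConfig(r=16, alpha=32))
#     manager.add_adapter("general")
#     manager.load_adapter("general", "./lora_checkpoints/general")  # hypothetical
#     manager.activate_domain("general")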


def main():
    parser = argparse.ArgumentParser(description="Train Bee LoRA on real instruction data")
    parser.add_argument("--data", default="./datasets/train_mixed.jsonl", help="Path to instruction JSONL")
    parser.add_argument("--model", default=DEFAULT_MODEL_PROFILE, help="Model profile, local path, or HF ID")
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
    )
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--save_path", default="./lora_checkpoints/general")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )

    train(
        data_path=args.data,
        model_path=args.model,
        device=args.device,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        steps=args.steps,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        save_path=args.save_path,
    )


if __name__ == "__main__":
    main()