#!/usr/bin/env python3
"""Train Bee LoRA adapters on real instruction data.
Loads a pretrained model and instruction datasets, trains LoRA adapters,
and saves a checkpoint with adapter metadata.
Usage (MacBook, slow):
python scripts/train_lora.py --data ./datasets/train_mixed.jsonl --steps 100 --device mps
Usage (GPU cloud):
python scripts/train_lora.py --data ./datasets/train_mixed.jsonl --steps 1000 --batch_size 4 --device cuda
"""
import argparse
import json
import logging
import os
import sys
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
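# Make the repo root importable when this script is run from scripts/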
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.lora_adapter import DomainLoRAManager, LoRAConfig
from bee.model_profiles import DEFAULT_MODEL_PROFILE, resolve_model_id
logger = logging.getLogger("bee.train")
class InstructionDataset(Dataset):
"""Simple instruction-following dataset from JSONL."""
def __init__(self, data_path: str, tokenizer, max_length: int = 512):
self.samples = []
self.tokenizer = tokenizer
self.max_length = max_length
with open(data_path) as f:
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank lines, common at EOF in JSONL files
                    continue
                ex = json.loads(line)
instruction = ex.get("instruction", "")
input_text = ex.get("input", "")
output = ex.get("output", "")
# Use chat template if available
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
user_msg = instruction
if input_text:
user_msg += f"\n\n{input_text}"
chat = [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": output},
]
text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                else:
                    # Plain-text fallback prompt; append EOS so the model learns to stop
                    text = (
                        f"### Instruction:\n{instruction}\n### Input:\n{input_text}"
                        f"\n### Response:\n{output}{tokenizer.eos_token}"
                    )
self.samples.append(text)
logger.info("Loaded %d instruction samples from %s", len(self.samples), data_path)
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
text = self.samples[idx]
encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
input_ids = encoding["input_ids"].squeeze(0)
attention_mask = encoding["attention_mask"].squeeze(0)
# Labels = input_ids for causal LM (shifted internally)
labels = input_ids.clone()
labels[attention_mask == 0] = -100
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
def train(
data_path: str,
model_path: str = DEFAULT_MODEL_PROFILE,
device: str = "mps",
lora_r: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.05,
steps: int = 100,
batch_size: int = 1,
    learning_rate: float = 2e-4,
warmup_steps: int = 10,
max_length: int = 512,
save_path: str = "./lora_checkpoints/general",
):
    model_path = resolve_model_id(model_path)
    # Fail fast on a missing dataset before paying for the model load
    if not os.path.exists(data_path):
        logger.error("Dataset not found: %s", data_path)
        logger.info("Run: python scripts/download_datasets.py")
        return None
    # Load model
    logger.info("Loading model: %s", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
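    # Many causal-LM tokenizers ship without a pad token; reuse EOS for padding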
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Train in float32: float16 produces NaN losses on MPS with LoRA
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to(device)
# Setup LoRA
lora_cfg = LoRAConfig(r=lora_r, alpha=lora_alpha, dropout=lora_dropout)
manager = DomainLoRAManager(model, lora_cfg)
manager.add_adapter("general")
manager.activate_domain("general")
logger.info("LoRA adapters: %d trainable params", manager.count_adapter_params("general"))
# Load data
dataset = InstructionDataset(data_path, tokenizer, max_length)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
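    # NOTE: every sample is padded to max_length, so batch shapes are static;
    # dynamic (longest-in-batch) padding would be faster but is kept simple here.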
# Optimizer: only LoRA params
lora_params = []
for name, module in model.named_modules():
if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
lora_params.extend([module.lora_A, module.lora_B])
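    # Defensive check: an empty parameter list would make AdamW raise later
    # with a less helpful message.
    if not lora_params:
        raise RuntimeError("No LoRA parameters found; check DomainLoRAManager wiring")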
optimizer = torch.optim.AdamW(lora_params, lr=learning_rate)
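    # Linear warmup for `warmup_steps` steps, then linear decay to zero at `steps`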
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=steps
)
# Training loop
logger.info("Starting training: %d steps, batch_size=%d, lr=%.1e", steps, batch_size, learning_rate)
model.train()
global_step = 0
epoch = 0
losses = []
while global_step < steps:
epoch += 1
for batch in loader:
if global_step >= steps:
break
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
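            # Clip only the adapter gradients; base weights never enter the optimizer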
torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
losses.append(loss.item())
global_step += 1
if global_step % 10 == 0:
avg_loss = sum(losses[-10:]) / min(10, len(losses))
logger.info("Step %d/%d | loss=%.4f | lr=%.2e", global_step, steps, avg_loss, scheduler.get_last_lr()[0])
# Save
os.makedirs(save_path, exist_ok=True)
manager.save_adapter("general", save_path)
logger.info("Checkpoint saved: %s", save_path)
# Save adapter metadata
    meta = {
        "base_model": model_path,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "steps": steps,
        "epochs": epoch,  # epochs started; the last one may be partial
        "final_loss": (sum(losses[-10:]) / len(losses[-10:])) if losses else None,
        "trainable_params": manager.count_adapter_params("general"),
    }
with open(os.path.join(save_path, "bee_legacy_adapter_config.json"), "w") as f:
json.dump(meta, f, indent=2)
return model, tokenizer, manager
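# A minimal sketch of consuming the checkpoint, assuming a `load_adapter`
# counterpart to `save_adapter` exists on DomainLoRAManager (hypothetical name,
# not confirmed by this script; adjust to the actual interface):
#
#   model = AutoModelForCausalLM.from_pretrained(base_model).to(device)
#   manager = DomainLoRAManager(model, LoRAConfig(r=16, alpha=32))
#   manager.add_adapter("general")
#   manager.load_adapter("general", "./lora_checkpoints/general")  # assumed API
#   manager.activate_domain("general")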
def main():
parser = argparse.ArgumentParser(description="Train Bee LoRA on real instruction data")
parser.add_argument("--data", default="./datasets/train_mixed.jsonl", help="Path to instruction JSONL")
parser.add_argument("--model", default=DEFAULT_MODEL_PROFILE, help="Model profile, local path, or HF ID")
parser.add_argument("--device", default="mps" if torch.backends.mps.is_available() else "cpu")
parser.add_argument("--lora_r", type=int, default=16)
parser.add_argument("--lora_alpha", type=int, default=32)
parser.add_argument("--steps", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--lr", type=float, default=2e-4)
parser.add_argument("--save_path", default="./lora_checkpoints/general")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
train(
data_path=args.data,
model_path=args.model,
device=args.device,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
steps=args.steps,
batch_size=args.batch_size,
learning_rate=args.lr,
save_path=args.save_path,
)
if __name__ == "__main__":
main()