# bee/scripts/train_remote.py
# NOTE: the three lines below were captured from the hosting page's file
# listing (author avatar caption, commit message, commit hash); kept as
# comments so the module remains valid Python.
# ceocxx's picture
# chore: deploy Bee API backend (bee/, Dockerfile, requirements)
# db82745 verified
#!/usr/bin/env python3
"""Remote training script for Bee — runs on GPU cloud (RunPod, Vast.ai, Lambda, Colab).
Downloads autopilot checkpoints from your MacBook via HuggingFace Hub,
trains LoRA adapters on GPU, uploads results back.
Usage on GPU instance:
pip install -r requirements.txt
export HF_TOKEN=your_huggingface_token
python train_remote.py --model_id your-username/bee-checkpoint --iterations 1000
Environment:
HF_TOKEN HuggingFace token for push/pull
BEE_HUB_ID HF Hub repo ID (e.g., "cfrost/bee")
WANDB_PROJECT Optional Weights & Biases project
"""
import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
import torch
from huggingface_hub import HfApi, hf_hub_download, upload_file
from transformers import AutoTokenizer
sys.path.insert(0, str(Path(__file__).resolve().parent))
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM
from bee.lora_adapter import LoRAConfig
from bee.model_profiles import DEFAULT_MODEL_PROFILE, resolve_model_id
from scripts.autopilot import Autopilot
logger = logging.getLogger("bee.remote_train")
def download_checkpoint(hub_id: str, local_dir: str = "./checkpoint_in") -> str:
    """Fetch the latest checkpoint files for *hub_id* from the HuggingFace Hub.

    Only weight/config artifacts (``.bin``, ``.safetensors``, ``.json``,
    ``.pt``) are downloaded; other repo files are skipped.  Returns the
    local directory the files were written into.
    """
    repo_files = HfApi().list_repo_files(hub_id)
    os.makedirs(local_dir, exist_ok=True)
    wanted_suffixes = (".bin", ".safetensors", ".json", ".pt")
    for filename in repo_files:
        if not filename.endswith(wanted_suffixes):
            continue
        logger.info("Downloading %s", filename)
        hf_hub_download(repo_id=hub_id, filename=filename, local_dir=local_dir)
    return local_dir
def upload_checkpoint(hub_id: str, checkpoint_dir: str) -> None:
    """Push every file under *checkpoint_dir* to the HuggingFace Hub repo *hub_id*.

    Walks the directory recursively and uploads each regular file under its
    path relative to *checkpoint_dir* (POSIX-style separators, as required by
    the Hub API).

    Args:
        hub_id: Hub repo ID, e.g. ``"user/bee"``.
        checkpoint_dir: Local directory containing the checkpoint to upload.
    """
    # Fix: the original constructed an ``HfApi()`` client that was never used;
    # ``upload_file`` is the module-level helper and needs no client object.
    root = Path(checkpoint_dir)
    for f in root.rglob("*"):
        if f.is_file():
            rel = f.relative_to(root).as_posix()
            logger.info("Uploading %s", rel)
            upload_file(path_or_fileobj=str(f), path_in_repo=rel, repo_id=hub_id)
    logger.info("Checkpoint uploaded to %s", hub_id)
def train(
    hub_id: str,
    iterations: int = 1000,
    device: str = "cuda",
    batch_size: int = 4,
    learning_rate: float = 5e-4,
    push_every: int = 50,
) -> None:
    """Run the remote LoRA training loop, syncing checkpoints with the Hub.

    Builds a Bee model from a fixed architecture config, transfers weights
    from a pretrained base, resumes from the latest Hub checkpoint if one
    exists, then trains per-domain LoRA adapters round-robin, pushing a
    checkpoint to the Hub every *push_every* iterations and once at the end.

    Args:
        hub_id: HuggingFace Hub repo ID used for both checkpoint pull and push.
        iterations: Number of outer loop iterations (each runs 10 inner steps).
        device: Requested device; silently falls back to "cpu" when CUDA is
            unavailable.
        batch_size: Batch size forwarded to the adapter trainer.
        learning_rate: Learning rate forwarded to the adapter trainer.
        push_every: Checkpoint/upload cadence in outer iterations.
    """
    # Force CPU when CUDA is absent, regardless of what the caller requested.
    device = device if torch.cuda.is_available() else "cpu"
    logger.info("Training on %s", device)
    # Load model — profile env vars take precedence over the built-in default.
    model_path = resolve_model_id(os.getenv("BEE_MODEL_PROFILE") or os.getenv("BEE_MODEL_PATH") or DEFAULT_MODEL_PROFILE)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Exact architecture match — these hyperparameters must mirror the
    # pretrained base so transfer_weights can copy tensors one-to-one.
    # NOTE(review): values look like a SmolLM2-style config — confirm against
    # the base model's config.json before changing any of them.
    cfg = BeeConfig(
        vocab_size=49152,
        hidden_size=960,
        num_hidden_layers=32,
        num_attention_heads=15,
        num_key_value_heads=5,
        intermediate_size=2560,
        max_position_embeddings=8192,
        rms_norm_eps=1e-05,
        tie_word_embeddings=False,
    )
    model = BeeForCausalLM(cfg).to(device)
    # Transfer weights from pretrained.  Local import keeps the dependency
    # out of module import time.  NOTE(review): this rebinds `model`, so the
    # instance constructed above is presumably discarded by transfer_weights
    # or used internally — confirm its contract.
    from bee.weight_transfer import transfer_weights
    model = transfer_weights(model_path, cfg, device)
    logger.info("Model loaded: %.1fM params", sum(p.numel() for p in model.parameters()) / 1e6)
    # Autopilot drives per-domain LoRA adapter training.
    autopilot = Autopilot(
        model=model,
        tokenizer=tokenizer,
        device=device,
        domains=["general", "programming", "quantum", "cybersecurity", "fintech"],
        lora_config=LoRAConfig(r=16, alpha=32, dropout=0.05),
        checkpoint_dir="./remote_checkpoints",
        use_quantum=False,
    )
    # Try loading a previous checkpoint from the Hub; any failure (repo
    # missing, network error, incompatible files) means a fresh start.
    try:
        local_ckpt = download_checkpoint(hub_id)
        autopilot.load_checkpoint(local_ckpt)
        logger.info("Resumed from Hub checkpoint")
    except Exception as e:
        logger.warning("No checkpoint on Hub, starting fresh: %s", e)
    # Training loop: cycle through domains round-robin, one adapter per iter.
    start_iter = autopilot.step_count
    for i in range(start_iter, start_iter + iterations):
        domain = autopilot.domains[i % len(autopilot.domains)]
        loss = autopilot.train_domain_adapter(
            domain=domain,
            num_steps=10,
            batch_size=batch_size,
            learning_rate=learning_rate,
            use_synthetic=True,  # presumably self-generated training data — confirm in Autopilot
        )
        logger.info("Iter %d | domain=%s | loss=%.4f", i, domain, loss)
        # Save + push every N iterations (skips iter 0 / a fresh start).
        if i % push_every == 0 and i > 0:
            ckpt_dir = f"./remote_checkpoints/iter_{i}"
            autopilot.save_checkpoint(ckpt_dir)
            upload_checkpoint(hub_id, ckpt_dir)
    # Final save + push, regardless of where the loop ended.
    final_dir = "./remote_checkpoints/iter_final"
    autopilot.save_checkpoint(final_dir)
    upload_checkpoint(hub_id, final_dir)
    logger.info("Training complete. Final checkpoint: %s", final_dir)
def main():
    """CLI entry point: configure logging, parse arguments, run training."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    ap = argparse.ArgumentParser(description="Bee Remote GPU Training")
    ap.add_argument("--hub_id", default=os.getenv("BEE_HUB_ID", "cfrost/bee"), help="HF Hub repo ID")
    ap.add_argument("--iterations", type=int, default=1000, help="Training iterations")
    ap.add_argument("--device", default="cuda", help="Device (cuda/cpu)")
    ap.add_argument("--batch_size", type=int, default=4, help="Batch size")
    ap.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    ap.add_argument("--push_every", type=int, default=50, help="Push to Hub every N iterations")
    opts = ap.parse_args()
    train(
        hub_id=opts.hub_id,
        iterations=opts.iterations,
        device=opts.device,
        batch_size=opts.batch_size,
        learning_rate=opts.lr,
        push_every=opts.push_every,
    )
if __name__ == "__main__":
    main()