# bee/scripts/cross_model_learn.py
# Committed by ceocxx in db82745 (verified):
#   "chore: deploy Bee API backend (bee/, Dockerfile, requirements)"
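"""Cross-Model Learning — Bee learns from several teacher LLMs at once.

For each prompt (the first 128 characters of a TinyStories example), this
script collects completions from a local Hugging Face teacher model and,
when enabled, from the OpenAI and Anthropic chat APIs. Every collected
response is then used as plain next-token-prediction training text for the
Bee student (sequence-level distillation). No teacher logits or weights are
needed, which is what lets closed API models such as GPT and Claude act as
teachers.

The local teacher is always queried. OpenAI / Anthropic are queried only when
``--use_openai`` / ``--use_anthropic`` is passed AND the matching
``OPENAI_API_KEY`` / ``ANTHROPIC_API_KEY`` environment variable is set;
otherwise that teacher is skipped.
"""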
import argparse
import json
import logging
import os
import sys
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
# Make the parent of this script's directory importable so the project-local
# `bee` package resolves when the script is run directly. This must execute
# before the `bee` imports below.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM
# NOTE(review): presumably registers the Bee config/model classes with the
# transformers Auto* factories — confirm in bee/register.py.
register()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
# Logger shared by the query helpers and main() in this script.
logger = logging.getLogger("bee.cross_model")
def query_openai(prompt, model="gpt-3.5-turbo", *, temperature=0.7, max_tokens=256):
    """Send *prompt* to an OpenAI chat model as a single user message.

    Args:
        prompt: Text sent as the user turn.
        model: OpenAI chat model name.
        temperature: Sampling temperature (0.7 matches the other teachers).
        max_tokens: Upper bound on the length of the generated reply.

    Returns:
        The reply text. Returns None if OPENAI_API_KEY is unset or empty, or
        if importing ``openai`` or the request fails; failures are logged as
        a warning instead of raised, so one broken teacher does not stop a run.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None
    try:
        # Imported lazily so the package is only needed when this teacher is used.
        import openai

        client = openai.OpenAI(api_key=api_key)
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.choices[0].message.content
    except Exception as e:  # best-effort teacher: log and skip
        logger.warning("OpenAI query failed: %s", e)
        return None
def query_anthropic(prompt, model="claude-3-haiku-20240307", *, temperature=0.7, max_tokens=256):
    """Send *prompt* to an Anthropic model as a single user message.

    Args:
        prompt: Text sent as the user turn.
        model: Anthropic model name.
        temperature: Sampling temperature. Set explicitly to 0.7 so this teacher
            samples like the OpenAI and local teachers instead of using the
            API's own default.
        max_tokens: Upper bound on the length of the generated reply.

    Returns:
        The concatenated text of all text blocks in the reply. Returns None if
        ANTHROPIC_API_KEY is unset or empty, if importing ``anthropic`` or the
        request fails (logged as a warning), or if the reply has no text.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        # Imported lazily so the package is only needed when this teacher is used.
        import anthropic

        client = anthropic.Anthropic(api_key=api_key)
        resp = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=[{"role": "user", "content": prompt}],
        )
        # A reply is a list of content blocks; keep the text ones rather than
        # assuming block 0 exists and is text.
        text = "".join(
            block.text for block in resp.content if getattr(block, "type", None) == "text"
        )
        return text or None
    except Exception as e:  # best-effort teacher: log and skip
        logger.warning("Anthropic query failed: %s", e)
        return None
# Loaded (tokenizer, model) pairs keyed by (model_id, device). Repeated queries
# then reuse one teacher instead of re-reading the weights from disk every call.
_LOCAL_TEACHERS = {}


def _get_local_teacher(model_id, device):
    """Return the cached ``(tokenizer, model)`` for *model_id* on *device*, loading it on first use."""
    key = (model_id, str(device))
    if key not in _LOCAL_TEACHERS:
        tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
        # Only cache after both loads succeed, so a failed load is retried next time.
        _LOCAL_TEACHERS[key] = (tok, model)
    return _LOCAL_TEACHERS[key]


def query_local(prompt, model_id="HuggingFaceTB/SmolLM2-135M", device="cpu", *,
                max_new_tokens=128, temperature=0.7):
    """Sample a continuation of *prompt* from a local Hugging Face causal LM.

    The model and tokenizer are loaded once per (model_id, device) and reused
    on later calls.

    Args:
        prompt: Text to continue.
        model_id: Hugging Face model id of the teacher.
        device: Device the teacher model and inputs are placed on.
        max_new_tokens: Maximum number of generated tokens.
        temperature: Sampling temperature.

    Returns:
        The decoded sequence. This is the prompt followed by the sampled
        continuation, because ``out[0]`` includes the input tokens. Returns
        None if loading or generation fails; the error is logged as a warning.
    """
    try:
        tok, model = _get_local_teacher(model_id, device)
        inputs = tok(prompt, return_tensors="pt").to(device)
        # Pass an explicit pad id so generate() does not have to infer one.
        pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=pad_id,
            )
        return tok.decode(out[0], skip_special_tokens=True)
    except Exception as e:  # best-effort teacher: log and skip
        logger.warning("Local model query failed: %s", e)
        return None
def distill_from_texts(student, tokenizer, texts, device, learning_rate=5e-4, steps_per_text=5,
                       *, max_length=256, optimizer=None):
    """Fine-tune *student* with next-token prediction on teacher-generated texts.

    Each text is tokenized on its own (no batching or padding). The student
    then takes ``steps_per_text`` optimizer steps on it, with gradient-norm
    clipping at 1.0.

    Args:
        student: Causal LM called as ``student(**tokenizer_output)``. It must
            return an object with ``.logits``, or a tuple whose first element
            is the logits.
        tokenizer: Callable returning a mapping with ``"input_ids"`` that
            supports ``.to(device)`` (e.g. a Hugging Face tokenizer).
        texts: Iterable of strings. Falsy entries (``None``, ``""``) are skipped.
        device: Device the tokenized inputs are moved to.
        learning_rate: AdamW learning rate. Only used when *optimizer* is None.
        steps_per_text: Optimizer steps taken on each usable text.
        max_length: Each text is truncated to this many tokens.
        optimizer: Optional optimizer to reuse across calls, so its state
            (e.g. Adam moments) survives between incremental distillation
            passes. A fresh AdamW is created when None.

    Returns:
        Mean loss over every optimizer step taken, or 0.0 if no text was usable.
    """
    if optimizer is None:
        optimizer = torch.optim.AdamW(student.parameters(), lr=learning_rate)
    student.train()
    total_loss = 0.0
    n_steps = 0
    for text in texts:
        if not text:
            continue
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        input_ids = inputs["input_ids"]
        # Too few tokens to give a useful next-token training signal.
        if input_ids.shape[1] < 4:
            continue
        # Targets are the inputs shifted left by one; they don't change across steps.
        shift_labels = input_ids[:, 1:].reshape(-1)
        for _ in range(steps_per_text):
            optimizer.zero_grad()
            out = student(**inputs)
            logits = out.logits if hasattr(out, "logits") else out[0]
            shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            loss = F.cross_entropy(shift_logits, shift_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n_steps += 1
    return total_loss / max(n_steps, 1)
# Vocabulary size shared by both student presets; the teacher tokenizer must fit in it.
_STUDENT_VOCAB_SIZE = 49152


def _build_student_config(size):
    """Return the BeeConfig for the ``--student_config`` preset *size* ("nano" or "tiny")."""
    if size == "nano":
        return BeeConfig(vocab_size=_STUDENT_VOCAB_SIZE, hidden_size=512, num_hidden_layers=8,
                         num_attention_heads=8, intermediate_size=1024, max_position_embeddings=2048)
    if size == "tiny":
        return BeeConfig(vocab_size=_STUDENT_VOCAB_SIZE, hidden_size=1024, num_hidden_layers=16,
                         num_attention_heads=16, intermediate_size=2816, max_position_embeddings=4096)
    raise ValueError(f"Unknown student config: {size!r}")


def _collect_teacher_responses(prompt, args):
    """Query every enabled teacher for *prompt* and return ``{source: text}``, dropping empty replies."""
    responses = {}
    if args.use_openai:
        responses["openai"] = query_openai(prompt)
    if args.use_anthropic:
        responses["anthropic"] = query_anthropic(prompt)
    # The local teacher is always queried.
    responses["local"] = query_local(prompt, args.local_teacher, args.device)
    return {src: txt for src, txt in responses.items() if txt}


def _distill_batch(student, tok, texts, device):
    """Run one distillation pass over *texts* and log the resulting average loss."""
    logger.info(" Distilling from %d teacher texts...", len(texts))
    avg_loss = distill_from_texts(student, tok, texts, device)
    logger.info(" Avg loss: %.4f", avg_loss)


def main():
    """Collect teacher responses to TinyStories prompts and distill them into a fresh Bee student.

    Writes the trained student, its tokenizer and ``cross_model_log.json``
    (one record per teacher response) to ``--output_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Distill responses from several teacher LLMs into a fresh Bee student."
    )
    parser.add_argument("--student_config", type=str, default="nano",
                        choices=["nano", "tiny"], help="Student size")
    parser.add_argument("--num_queries", type=int, default=20)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="mps" if torch.backends.mps.is_available() else "cpu")
    parser.add_argument("--local_teacher", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--use_openai", action="store_true")
    parser.add_argument("--use_anthropic", action="store_true")
    parser.add_argument("--distill_every", type=int, default=5,
                        help="Run a distillation pass after this many queries")
    args = parser.parse_args()
    if args.distill_every < 1:
        parser.error("--distill_every must be >= 1")
    os.makedirs(args.output_dir, exist_ok=True)

    cfg = _build_student_config(args.student_config)

    # The student is trained on the local teacher's token ids, so the tokenizer
    # must fit the student's embedding table. Check before spending any API calls.
    tok = AutoTokenizer.from_pretrained(args.local_teacher, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if len(tok) > _STUDENT_VOCAB_SIZE:
        raise ValueError(
            f"Tokenizer of {args.local_teacher!r} has {len(tok)} tokens, but the student "
            f"vocab_size is {_STUDENT_VOCAB_SIZE}"
        )

    student = BeeForCausalLM(cfg).to(args.device)
    n_params = sum(p.numel() for p in student.parameters())
    logger.info("Student params: %.2fM", n_params / 1e6)

    # Prompts are the openings of TinyStories examples.
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True).take(args.num_queries)

    results = []
    pending = []
    for i, ex in enumerate(ds):
        prompt = ex["text"][:128]  # Use first 128 chars as prompt
        logger.info("Query %d/%d: prompt='%s...'", i + 1, args.num_queries, prompt[:40])
        responses = _collect_teacher_responses(prompt, args)
        logger.info(" Got %d teacher responses", len(responses))
        for src, txt in responses.items():
            pending.append(txt)
            results.append({"step": i, "source": src, "prompt": prompt, "response": txt})
        # Incremental distillation; each text is trained on exactly once.
        if (i + 1) % args.distill_every == 0 and pending:
            _distill_batch(student, tok, pending, args.device)
            pending = []
    # Responses that arrived after the last full interval (e.g. queries 6-7 of 7)
    # would otherwise be logged but never trained on.
    if pending:
        _distill_batch(student, tok, pending, args.device)

    student.save_pretrained(args.output_dir)
    tok.save_pretrained(args.output_dir)
    with open(os.path.join(args.output_dir, "cross_model_log.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.info("Cross-model learning complete. Model saved to %s", args.output_dir)
    logger.info("Total teacher responses collected: %d", len(results))


if __name__ == "__main__":
    main()