| """Cross-Model Learning — Bee learns from multiple teacher LLMs simultaneously. |
| |
| Queries OpenAI, Anthropic, and local models for the same prompt, |
| distills their consensus into Bee through multi-teacher distillation. |
| This is how Bee learns from Claude, GPT-4, Gemini, etc. without |
| needing their weights. |
| |
| Requires OPENAI_API_KEY and/or ANTHROPIC_API_KEY env vars. |
| Falls back to local models if APIs unavailable. |
| """ |
|
|
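# Example invocation (all flags are defined in main() below); the script filename
# and output path here are placeholders, not names fixed by this repo:
#
#   python cross_model_learning.py \
#       --output_dir runs/cross_model \
#       --student_config nano \
#       --num_queries 20 \
#       --use_openai --use_anthropic
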
import argparse
import json
import logging
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Allow importing the local `bee` package (one directory above this script)
# regardless of the current working directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

# Register the custom Bee architecture with transformers before any load/save calls.
register()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.cross_model")


def query_openai(prompt, model="gpt-3.5-turbo"):
    """Query an OpenAI chat model; return the response text, or None on failure."""
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None
    try:
        import openai

        client = openai.OpenAI(api_key=api_key)
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=256,
        )
        return resp.choices[0].message.content
    except Exception as e:
        logger.warning("OpenAI query failed: %s", e)
        return None


def query_anthropic(prompt, model="claude-3-haiku-20240307"):
    """Query an Anthropic model; return the response text, or None on failure."""
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        import anthropic

        client = anthropic.Anthropic(api_key=api_key)
        resp = client.messages.create(
            model=model,
            max_tokens=256,
            messages=[{"role": "user", "content": prompt}],
        )
        return resp.content[0].text
    except Exception as e:
        logger.warning("Anthropic query failed: %s", e)
        return None


def query_local(prompt, model_id="HuggingFaceTB/SmolLM2-135M", device="cpu"):
    """Query a local Hugging Face model as a teacher; return text, or None on failure.

    Note: the model is loaded from scratch on every call, which is simple but
    slow when many prompts are processed.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
        inputs = tok(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
        return tok.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        logger.warning("Local model query failed: %s", e)
        return None


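# query_local reloads the teacher weights on every call, which keeps the function
# self-contained but repeats disk I/O per prompt. If that becomes a bottleneck,
# a cached loader along these lines is one option. This is a sketch only and is
# not used by the rest of the script; `_load_local_teacher` is a hypothetical name.
from functools import lru_cache


@lru_cache(maxsize=2)
def _load_local_teacher(model_id, device):
    """Load (tokenizer, model) once per (model_id, device) and reuse on later calls."""
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
    return tok, model

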
def distill_from_texts(student, tokenizer, texts, device, learning_rate=5e-4, steps_per_text=5):
    """Distill teacher-generated text into the student via next-token cross-entropy.

    API teachers expose only text (not logits), so this is sequence-level
    distillation: the student is trained to reproduce the teachers' outputs.
    Returns the average loss over all optimization steps.
    """
    optimizer = torch.optim.AdamW(student.parameters(), lr=learning_rate)
    student.train()
    total_loss = 0.0
    n = 0

    for text in texts:
        if not text:
            continue
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
        if inputs["input_ids"].shape[1] < 4:
            # Skip responses too short to provide a useful next-token signal.
            continue

        for _ in range(steps_per_text):
            optimizer.zero_grad()
            out = student(**inputs)
            logits = out.logits if hasattr(out, "logits") else out[0]
            # Standard causal LM shift: predict token t+1 from tokens up to t.
            shift_logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
            shift_labels = inputs["input_ids"][:, 1:].contiguous().view(-1)
            loss = F.cross_entropy(shift_logits, shift_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n += 1

    return total_loss / max(n, 1)


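# For reference: distill_from_texts trains on teacher *text* only, since API teachers
# do not return logits. When a teacher runs locally and shares Bee's tokenizer, the
# usual alternative is logit-level distillation on temperature-softened distributions.
# A minimal sketch under those assumptions; `kd_loss_from_logits` is a hypothetical
# helper and is not called anywhere in this script.
def kd_loss_from_logits(student_logits, teacher_logits, temperature=2.0):
    """Hinton-style KD loss: KL(teacher || student) on softened logits, scaled by T^2."""
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits / t, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (t * t)

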
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--student_config", type=str, default="nano",
                        choices=["nano", "tiny"], help="Student size")
    parser.add_argument("--num_queries", type=int, default=20)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="mps" if torch.backends.mps.is_available() else "cpu")
    parser.add_argument("--local_teacher", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--use_openai", action="store_true")
    parser.add_argument("--use_anthropic", action="store_true")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Build the student model from scratch.
    if args.student_config == "nano":
        cfg = BeeConfig(vocab_size=49152, hidden_size=512, num_hidden_layers=8,
                        num_attention_heads=8, intermediate_size=1024, max_position_embeddings=2048)
    else:
        cfg = BeeConfig(vocab_size=49152, hidden_size=1024, num_hidden_layers=16,
                        num_attention_heads=16, intermediate_size=2816, max_position_embeddings=4096)

    student = BeeForCausalLM(cfg).to(args.device)
    n_params = sum(p.numel() for p in student.parameters())
    logger.info("Student params: %.2fM", n_params / 1e6)

    # The student reuses the local teacher's tokenizer.
    tok = AutoTokenizer.from_pretrained(args.local_teacher, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Stream prompts from TinyStories; no full download is needed.
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    ds = ds.take(args.num_queries)

    results = []
    all_teacher_texts = []

    for i, ex in enumerate(ds):
        prompt = ex["text"][:128]
        logger.info("Query %d/%d: prompt='%s...'", i + 1, args.num_queries, prompt[:40])

        responses = {}
        if args.use_openai:
            r = query_openai(prompt)
            if r:
                responses["openai"] = r
        if args.use_anthropic:
            r = query_anthropic(prompt)
            if r:
                responses["anthropic"] = r

        # The local teacher is always queried, regardless of API availability.
        r = query_local(prompt, args.local_teacher, args.device)
        if r:
            responses["local"] = r

        logger.info(" Got %d teacher responses", len(responses))
        for src, txt in responses.items():
            all_teacher_texts.append(txt)
            results.append({"step": i, "source": src, "prompt": prompt, "response": txt})

        # Distill every 5 queries, then clear the buffer of teacher texts.
        if (i + 1) % 5 == 0 and all_teacher_texts:
            logger.info(" Distilling from %d teacher texts...", len(all_teacher_texts))
            avg_loss = distill_from_texts(student, tok, all_teacher_texts, args.device)
            logger.info(" Avg loss: %.4f", avg_loss)
            all_teacher_texts = []

    # Distill any leftover teacher texts (when num_queries is not a multiple of 5).
    if all_teacher_texts:
        logger.info("Distilling from %d remaining teacher texts...", len(all_teacher_texts))
        avg_loss = distill_from_texts(student, tok, all_teacher_texts, args.device)
        logger.info("Avg loss: %.4f", avg_loss)

    # Save the distilled student, its tokenizer, and a log of all teacher responses.
    student.save_pretrained(args.output_dir)
    tok.save_pretrained(args.output_dir)
    with open(os.path.join(args.output_dir, "cross_model_log.json"), "w") as f:
        json.dump(results, f, indent=2)

    logger.info("Cross-model learning complete. Model saved to %s", args.output_dir)
    logger.info("Total teacher responses collected: %d", len(results))


if __name__ == "__main__":
    main()
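
# After a run, the distilled student can be reloaded like any transformers checkpoint.
# A sketch; "runs/cross_model" stands in for whatever --output_dir was used:
#
#   from bee.register import register
#   from bee.modeling_bee import BeeForCausalLM
#   from transformers import AutoTokenizer
#
#   register()
#   student = BeeForCausalLM.from_pretrained("runs/cross_model")
#   tok = AutoTokenizer.from_pretrained("runs/cross_model")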