"""Cross-Model Learning — Bee learns from multiple teacher LLMs simultaneously.

Queries OpenAI, Anthropic, and local models for the same prompt, distills
their consensus into Bee through multi-teacher distillation.

This is how Bee learns from Claude, GPT-4, Gemini, etc. without needing their
weights.

Requires OPENAI_API_KEY and/or ANTHROPIC_API_KEY env vars.
Falls back to local models if APIs unavailable.
"""

import argparse
import functools
import json
import logging
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Make the repository root importable so the `bee` package resolves when this
# file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

register()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.cross_model")

# Default AdamW learning rate used for distillation steps.
DISTILL_LEARNING_RATE = 5e-4

# Number of prompts processed between two incremental distillation rounds.
DISTILL_EVERY = 5


def query_openai(prompt, model="gpt-3.5-turbo"):
    """Ask an OpenAI chat model to respond to ``prompt``.

    Args:
        prompt: User message sent as the only chat turn.
        model: OpenAI chat model name.

    Returns:
        The text of the first choice, or ``None`` when ``OPENAI_API_KEY`` is
        not set or the request fails for any reason (the failure is logged).
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None
    try:
        # Imported lazily so the script still runs without the SDK installed.
        import openai

        client = openai.OpenAI(api_key=api_key)
        resp = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=256,
        )
        return resp.choices[0].message.content
    except Exception as e:
        # Best effort: a failing teacher must not abort the whole run.
        logger.warning("OpenAI query failed: %s", e)
        return None


def query_anthropic(prompt, model="claude-3-haiku-20240307"):
    """Ask an Anthropic model to respond to ``prompt``.

    Args:
        prompt: User message sent as the only chat turn.
        model: Anthropic model name.

    Returns:
        The text of the first content block, or ``None`` when
        ``ANTHROPIC_API_KEY`` is not set or the request fails for any reason
        (the failure is logged).
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        return None
    try:
        # Imported lazily so the script still runs without the SDK installed.
        import anthropic

        client = anthropic.Anthropic(api_key=api_key)
        resp = client.messages.create(
            model=model,
            max_tokens=256,
            messages=[{"role": "user", "content": prompt}],
        )
        return resp.content[0].text
    except Exception as e:
        # Best effort: a failing teacher must not abort the whole run.
        logger.warning("Anthropic query failed: %s", e)
        return None


@functools.lru_cache(maxsize=1)
def _load_local_teacher(model_id, device):
    """Load (once) and return ``(tokenizer, model)`` for a local teacher.

    Cached so that repeated queries do not reload the weights for every
    prompt. ``maxsize=1`` keeps at most one teacher resident in memory.
    Exceptions are not cached, so a failed load is retried on the next call.
    """
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only point this at model repositories you trust.
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
    return tok, model


def query_local(prompt, model_id="HuggingFaceTB/SmolLM2-135M", device="cpu"):
    """Query a local model as a teacher.

    Args:
        prompt: Text to continue.
        model_id: Hugging Face model identifier of the teacher.
        device: Device the teacher runs on.

    Returns:
        The decoded generated sequence (``out[0]``, special tokens skipped),
        or ``None`` if loading or generation fails (the failure is logged).
    """
    try:
        tok, model = _load_local_teacher(model_id, device)
        inputs = tok(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
        return tok.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        # Best effort: a failing teacher must not abort the whole run.
        logger.warning("Local model query failed: %s", e)
        return None


def distill_from_texts(
    student,
    tokenizer,
    texts,
    device,
    learning_rate=DISTILL_LEARNING_RATE,
    steps_per_text=5,
    optimizer=None,
):
    """Distill from teacher-generated text strings into student.

    Each usable text is tokenized (truncated to 256 tokens) and the student
    is trained on it with a next-token cross-entropy loss for
    ``steps_per_text`` gradient steps. Gradients are clipped to norm 1.0.

    Args:
        student: Model being trained; called as ``student(**inputs)``.
        tokenizer: Tokenizer used to encode the texts.
        texts: Iterable of teacher responses; falsy entries are skipped, as
            are texts that tokenize to fewer than 4 tokens.
        device: Device the encoded inputs are moved to.
        learning_rate: AdamW learning rate, used only when ``optimizer`` is
            ``None``.
        steps_per_text: Number of optimizer steps taken per text.
        optimizer: Optional existing optimizer to reuse so its state carries
            over between calls. When ``None`` a fresh AdamW is created.

    Returns:
        The mean loss over all optimizer steps taken, or ``0.0`` if no step
        was taken.
    """
    if optimizer is None:
        optimizer = torch.optim.AdamW(student.parameters(), lr=learning_rate)
    student.train()
    total_loss = 0.0
    n = 0
    for text in texts:
        if not text:
            continue
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
        if inputs["input_ids"].shape[1] < 4:
            continue
        for _ in range(steps_per_text):
            optimizer.zero_grad()
            out = student(**inputs)
            logits = out.logits if hasattr(out, "logits") else out[0]
            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two sequences line up.
            shift_logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
            shift_labels = inputs["input_ids"][:, 1:].contiguous().view(-1)
            loss = F.cross_entropy(shift_logits, shift_labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n += 1
    return total_loss / max(n, 1)


def _distill_pending(student, tok, texts, device, optimizer):
    """Run one distillation round over ``texts`` and log the average loss."""
    logger.info("  Distilling from %d teacher texts...", len(texts))
    avg_loss = distill_from_texts(student, tok, texts, device, optimizer=optimizer)
    logger.info("  Avg loss: %.4f", avg_loss)


def main():
    """Collect teacher responses for TinyStories prompts and distill them.

    Builds a fresh student, queries the enabled teachers for each prompt,
    distills the collected responses into the student every
    ``DISTILL_EVERY`` prompts (plus once more for any remainder), then saves
    the student, the tokenizer and a JSON log to ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--student_config", type=str, default="nano", choices=["nano", "tiny"], help="Student size")
    parser.add_argument("--num_queries", type=int, default=20)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="mps" if torch.backends.mps.is_available() else "cpu")
    parser.add_argument("--local_teacher", type=str, default="HuggingFaceTB/SmolLM2-135M")
    parser.add_argument("--use_openai", action="store_true")
    parser.add_argument("--use_anthropic", action="store_true")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Init student
    if args.student_config == "nano":
        cfg = BeeConfig(
            vocab_size=49152,
            hidden_size=512,
            num_hidden_layers=8,
            num_attention_heads=8,
            intermediate_size=1024,
            max_position_embeddings=2048,
        )
    else:
        cfg = BeeConfig(
            vocab_size=49152,
            hidden_size=1024,
            num_hidden_layers=16,
            num_attention_heads=16,
            intermediate_size=2816,
            max_position_embeddings=4096,
        )
    student = BeeForCausalLM(cfg).to(args.device)
    n_params = sum(p.numel() for p in student.parameters())
    logger.info("Student params: %.2fM", n_params / 1e6)

    # One optimizer for the whole run so AdamW state carries over between
    # incremental distillation rounds instead of being reset each time.
    optimizer = torch.optim.AdamW(student.parameters(), lr=DISTILL_LEARNING_RATE)

    # Use SmolLM tokenizer (vocab compatible)
    tok = AutoTokenizer.from_pretrained(args.local_teacher, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Load prompts from TinyStories
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    ds = ds.take(args.num_queries)

    results = []
    pending_texts = []

    for i, ex in enumerate(ds):
        prompt = ex["text"][:128]  # Use first 128 chars as prompt
        logger.info("Query %d/%d: prompt='%s...'", i + 1, args.num_queries, prompt[:40])

        responses = {}
        if args.use_openai:
            r = query_openai(prompt)
            if r:
                responses["openai"] = r
        if args.use_anthropic:
            r = query_anthropic(prompt)
            if r:
                responses["anthropic"] = r

        # Always query local teacher
        r = query_local(prompt, args.local_teacher, args.device)
        if r:
            responses["local"] = r

        logger.info("  Got %d teacher responses", len(responses))

        for src, txt in responses.items():
            pending_texts.append(txt)
            results.append({"step": i, "source": src, "prompt": prompt, "response": txt})

        # Incremental distillation every DISTILL_EVERY queries
        if (i + 1) % DISTILL_EVERY == 0 and pending_texts:
            _distill_pending(student, tok, pending_texts, args.device, optimizer)
            pending_texts = []  # Clear to avoid re-distilling

    # Distill whatever is left when the number of prompts is not a multiple
    # of DISTILL_EVERY; otherwise those responses would never be trained on.
    if pending_texts:
        _distill_pending(student, tok, pending_texts, args.device, optimizer)
        pending_texts = []

    # Final save
    student.save_pretrained(args.output_dir)
    tok.save_pretrained(args.output_dir)
    with open(os.path.join(args.output_dir, "cross_model_log.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    logger.info("Cross-model learning complete. Model saved to %s", args.output_dir)
    logger.info("Total teacher responses collected: %d", len(results))


if __name__ == "__main__":
    main()