"""Simple CLI inference for Bee.""" import argparse import logging import sys from pathlib import Path import torch from transformers import AutoTokenizer sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from bee.register import register from bee.modeling_bee import BeeForCausalLM register() logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") logger = logging.getLogger("bee.inference") def get_args(): parser = argparse.ArgumentParser(description="Run inference with Bee") parser.add_argument("--model_path", type=str, required=True, help="Path to Bee checkpoint") parser.add_argument("--prompt", type=str, default="Once upon a time, ") parser.add_argument("--max_new_tokens", type=int, default=100) parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top_p", type=float, default=0.95) parser.add_argument("--repetition_penalty", type=float, default=1.1) parser.add_argument("--device", type=str, default="auto") return parser.parse_args() def main(): args = get_args() logger.info("Loading model from %s", args.model_path) model = BeeForCausalLM.from_pretrained(args.model_path) tokenizer = AutoTokenizer.from_pretrained(args.model_path) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" else: device = args.device model = model.to(device) model.eval() inputs = tokenizer(args.prompt, return_tensors="pt").to(device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=args.max_new_tokens, do_sample=True, temperature=args.temperature, top_p=args.top_p, repetition_penalty=args.repetition_penalty, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) decoded = tokenizer.decode(outputs[0], skip_special_tokens=True) print("\n=== Generated Text ===\n") print(decoded) print("\n======================\n") if __name__ == "__main__": main()