| """Simple CLI inference for Bee.""" |
|
|
| import argparse |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoTokenizer |
|
|
# Make the repo root importable so `bee` resolves when this file is run as a
# script (e.g. `python scripts/inference.py`) instead of as an installed module.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM


# Presumably registers the Bee architecture with transformers' Auto* machinery;
# must run before any from_pretrained call below — confirm in bee/register.py.
register()


logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
# Module-level logger for this CLI; handlers come from basicConfig above.
logger = logging.getLogger("bee.inference")
|
|
|
|
def get_args(argv=None):
    """Parse command-line arguments for Bee inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, so existing no-argument
            callers behave exactly as before; passing an explicit list
            makes the parser usable from tests and other code.

    Returns:
        argparse.Namespace with the model path, prompt, and sampling
        settings (max_new_tokens, temperature, top_p, repetition_penalty,
        device).
    """
    parser = argparse.ArgumentParser(description="Run inference with Bee")
    parser.add_argument("--model_path", type=str, required=True, help="Path to Bee checkpoint")
    parser.add_argument("--prompt", type=str, default="Once upon a time, ")
    parser.add_argument("--max_new_tokens", type=int, default=100)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.1)
    parser.add_argument("--device", type=str, default="auto")
    return parser.parse_args(argv)
|
|
|
|
def _resolve_device(requested):
    """Map a --device argument to a concrete torch device string.

    "auto" prefers CUDA, then Apple-silicon MPS, then CPU; any other
    value (e.g. "cpu", "cuda:1") is passed through unchanged.
    """
    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def main():
    """Load a Bee checkpoint, sample a continuation of the prompt, and print it."""
    args = get_args()
    logger.info("Loading model from %s", args.model_path)

    model = BeeForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # Some tokenizers ship without a pad token; generate() needs a
    # pad_token_id below, so fall back to the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    device = _resolve_device(args.device)
    model = model.to(device)
    model.eval()

    inputs = tokenizer(args.prompt, return_tensors="pt").to(device)

    # Inference only — no_grad skips autograd bookkeeping.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # outputs[0]: the single generated sequence (prompt + new tokens).
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("\n=== Generated Text ===\n")
    print(decoded)
    print("\n======================\n")
|
|
|
|
# Run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|