# bee/scripts/inference.py
# Author: ceocxx
# chore: deploy Bee API backend (bee/, Dockerfile, requirements)
# Commit: db82745 (verified)
"""Simple CLI inference for Bee."""
import argparse
import logging
import sys
from pathlib import Path
import torch
from transformers import AutoTokenizer
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM
register()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.inference")
def get_args(argv=None):
    """Parse command-line options for Bee inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, preserving the original CLI behavior; passing an
            explicit list makes the parser usable and testable from code.

    Returns:
        argparse.Namespace with fields: model_path (required), prompt,
        max_new_tokens, temperature, top_p, repetition_penalty, device.
    """
    parser = argparse.ArgumentParser(description="Run inference with Bee")
    parser.add_argument("--model_path", type=str, required=True, help="Path to Bee checkpoint")
    parser.add_argument("--prompt", type=str, default="Once upon a time, ")
    parser.add_argument("--max_new_tokens", type=int, default=100)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.1)
    # "auto" defers device selection to main(): cuda -> mps -> cpu.
    parser.add_argument("--device", type=str, default="auto")
    return parser.parse_args(argv)
def main():
    """Load a Bee checkpoint, sample a continuation of the prompt, and print it."""
    args = get_args()
    logger.info("Loading model from %s", args.model_path)
    model = BeeForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # Some checkpoints ship without a pad token; fall back to EOS so
    # generate() has a valid pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    device = args.device
    if device == "auto":
        # Preference order: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    model = model.to(device)
    model.eval()

    encoded = tokenizer(args.prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            do_sample=True,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(generated[0], skip_special_tokens=True)
    print("\n=== Generated Text ===\n")
    print(text)
    print("\n======================\n")


if __name__ == "__main__":
    main()