"""FastAPI server for Bee inference.""" import argparse import logging import os import sys import time import uuid from pathlib import Path from contextlib import asynccontextmanager import torch from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field from transformers import AutoTokenizer import uvicorn sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from bee.register import register from bee.modeling_bee import BeeForCausalLM register() logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") logger = logging.getLogger("bee.server") MODEL = None TOKENIZER = None DEVICE = None def load_model(model_path: str, device: str = "auto"): global MODEL, TOKENIZER, DEVICE if device == "auto": DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" else: DEVICE = device logger.info("Loading Bee model from %s onto %s", model_path, DEVICE) TOKENIZER = AutoTokenizer.from_pretrained(model_path) if TOKENIZER.pad_token is None: TOKENIZER.pad_token = TOKENIZER.eos_token MODEL = BeeForCausalLM.from_pretrained(model_path).to(DEVICE) MODEL.eval() logger.info("Model loaded. Parameters: %.2fM", sum(p.numel() for p in MODEL.parameters()) / 1e6) @asynccontextmanager async def lifespan(app: FastAPI): model_path = os.environ.get("BEE_MODEL_PATH", "") device = os.environ.get("BEE_DEVICE", "auto") if not model_path: logger.error("BEE_MODEL_PATH not set. Server will fail requests.") else: load_model(model_path, device) yield logger.info("Shutting down Bee server.") app = FastAPI(title="Bee LLM API", version="0.1.0", lifespan=lifespan) class GenerateRequest(BaseModel): prompt: str = Field(..., min_length=1, max_length=8192, description="Input prompt") max_new_tokens: int = Field(default=256, ge=1, le=4096) temperature: float = Field(default=0.8, ge=0.0, le=2.0) top_p: float = Field(default=0.95, ge=0.0, le=1.0) repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0) class GenerateResponse(BaseModel): request_id: str generated_text: str prompt_tokens: int completion_tokens: int total_tokens: int model: str duration_ms: float @app.get("/health") async def health(): if MODEL is None: raise HTTPException(status_code=503, detail="Model not loaded") return {"status": "ok", "model": "bee", "device": DEVICE} @app.post("/v1/generate", response_model=GenerateResponse) async def generate(req: GenerateRequest): if MODEL is None or TOKENIZER is None: raise HTTPException(status_code=503, detail="Model not loaded") request_id = str(uuid.uuid4()) start = time.perf_counter() inputs = TOKENIZER(req.prompt, return_tensors="pt").to(DEVICE) prompt_tokens = inputs["input_ids"].shape[1] with torch.no_grad(): outputs = MODEL.generate( **inputs, max_new_tokens=req.max_new_tokens, do_sample=True, temperature=req.temperature, top_p=req.top_p, repetition_penalty=req.repetition_penalty, pad_token_id=TOKENIZER.pad_token_id, eos_token_id=TOKENIZER.eos_token_id, ) completion_tokens = outputs.shape[1] - prompt_tokens generated_text = TOKENIZER.decode(outputs[0][prompt_tokens:], skip_special_tokens=True) duration_ms = (time.perf_counter() - start) * 1000 return GenerateResponse( request_id=request_id, generated_text=generated_text, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, model="bee", duration_ms=duration_ms, ) def get_args(): parser = argparse.ArgumentParser(description="Serve Bee via FastAPI") parser.add_argument("--model_path", type=str, 
required=True) parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--device", type=str, default="auto") return parser.parse_args() def main(): args = get_args() os.environ["BEE_MODEL_PATH"] = args.model_path os.environ["BEE_DEVICE"] = args.device uvicorn.run("scripts.server:app", host=args.host, port=args.port, reload=False) if __name__ == "__main__": main()
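
# Example invocation (a sketch; the checkpoint path below is a placeholder,
# not something shipped with this repo):
#
#   python scripts/server.py --model_path ./checkpoints/bee --port 8000
#
#   curl -X POST http://localhost:8000/v1/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "The honeybee waggle dance", "max_new_tokens": 64}'
#
# The response is a GenerateResponse JSON object: request_id, generated_text,
# prompt/completion/total token counts, model name, and duration_ms.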