"""FastAPI server for Bee inference."""
import argparse
import logging
import os
import sys
import time
import uuid
from pathlib import Path
from contextlib import asynccontextmanager
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
import uvicorn
# Make the repo root importable so `bee.*` resolves when run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM
register()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.server")
MODEL = None
TOKENIZER = None
DEVICE = None
def load_model(model_path: str, device: str = "auto"):
global MODEL, TOKENIZER, DEVICE
    if device == "auto":
        # Prefer CUDA, then Apple's MPS backend, then plain CPU.
        DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
else:
DEVICE = device
logger.info("Loading Bee model from %s onto %s", model_path, DEVICE)
TOKENIZER = AutoTokenizer.from_pretrained(model_path)
    if TOKENIZER.pad_token is None:
        # Fall back to EOS for padding when the tokenizer defines no pad token.
        TOKENIZER.pad_token = TOKENIZER.eos_token
MODEL = BeeForCausalLM.from_pretrained(model_path).to(DEVICE)
MODEL.eval()
logger.info("Model loaded. Parameters: %.2fM", sum(p.numel() for p in MODEL.parameters()) / 1e6)
@asynccontextmanager
async def lifespan(app: FastAPI):
model_path = os.environ.get("BEE_MODEL_PATH", "")
device = os.environ.get("BEE_DEVICE", "auto")
if not model_path:
logger.error("BEE_MODEL_PATH not set. Server will fail requests.")
else:
load_model(model_path, device)
yield
logger.info("Shutting down Bee server.")
app = FastAPI(title="Bee LLM API", version="0.1.0", lifespan=lifespan)
class GenerateRequest(BaseModel):
prompt: str = Field(..., min_length=1, max_length=8192, description="Input prompt")
max_new_tokens: int = Field(default=256, ge=1, le=4096)
temperature: float = Field(default=0.8, ge=0.0, le=2.0)
top_p: float = Field(default=0.95, ge=0.0, le=1.0)
repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
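# Illustrative request body (only "prompt" is required; the other fields fall
# back to the defaults above):
# {
#   "prompt": "Once upon a time",
#   "max_new_tokens": 128,
#   "temperature": 0.7,
#   "top_p": 0.9,
#   "repetition_penalty": 1.1
# }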
class GenerateResponse(BaseModel):
request_id: str
generated_text: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
model: str
duration_ms: float
@app.get("/health")
async def health():
if MODEL is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return {"status": "ok", "model": "bee", "device": DEVICE}
@app.post("/v1/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    # Deliberately a sync endpoint: FastAPI runs it in a worker thread, so the
    # blocking MODEL.generate() call below does not stall the event loop.
    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    request_id = str(uuid.uuid4())
    start = time.perf_counter()
inputs = TOKENIZER(req.prompt, return_tensors="pt").to(DEVICE)
prompt_tokens = inputs["input_ids"].shape[1]
    # Temperature 0 with do_sample=True makes HF generation raise, so fall
    # back to greedy decoding when the client requests deterministic output.
    do_sample = req.temperature > 0
    gen_kwargs = dict(
        max_new_tokens=req.max_new_tokens,
        repetition_penalty=req.repetition_penalty,
        pad_token_id=TOKENIZER.pad_token_id,
        eos_token_id=TOKENIZER.eos_token_id,
        do_sample=do_sample,
    )
    if do_sample:
        gen_kwargs["temperature"] = req.temperature
        gen_kwargs["top_p"] = req.top_p
    with torch.no_grad():
        outputs = MODEL.generate(**inputs, **gen_kwargs)
completion_tokens = outputs.shape[1] - prompt_tokens
generated_text = TOKENIZER.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
duration_ms = (time.perf_counter() - start) * 1000
return GenerateResponse(
request_id=request_id,
generated_text=generated_text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
model="bee",
duration_ms=duration_ms,
)
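# Minimal client sketch (assumes the server listens on localhost:8000 and the
# third-party `requests` package is installed; shown as a comment so it does
# not execute on import):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/generate",
#       json={"prompt": "Hello, Bee!", "max_new_tokens": 64},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   body = resp.json()  # matches GenerateResponse: generated_text, token counts, ...
#   print(body["generated_text"])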
def get_args():
parser = argparse.ArgumentParser(description="Serve Bee via FastAPI")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--device", type=str, default="auto")
return parser.parse_args()
def main():
    args = get_args()
    # The lifespan hook reads these at startup, so set them before serving.
    os.environ["BEE_MODEL_PATH"] = args.model_path
    os.environ["BEE_DEVICE"] = args.device
    # Pass the app object directly so the server starts regardless of the
    # working directory ("scripts.server:app" would require launching from
    # the repo root).
    uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()