| """FastAPI server for Bee inference.""" |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
| import time |
| import uuid |
| from pathlib import Path |
| from contextlib import asynccontextmanager |
|
|
| import torch |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, Field |
| from transformers import AutoTokenizer |
| import uvicorn |
|
|
# Make the repo root importable so the `bee` package resolves when this file
# is run directly as a script rather than as an installed package.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM


# NOTE(review): presumably registers the custom Bee architecture with the
# transformers Auto* machinery -- confirm in bee.register.
register()


logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.server")


# Process-wide singletons populated by load_model() at startup; all three
# remain None until a model is loaded, and endpoints 503 in that state.
MODEL = None
TOKENIZER = None
DEVICE = None
|
|
|
|
def load_model(model_path: str, device: str = "auto"):
    """Load tokenizer + model and populate the MODEL/TOKENIZER/DEVICE globals.

    With device="auto", preference order is CUDA, then Apple MPS, then CPU.
    Any other value is used verbatim as the torch device string.
    """
    global MODEL, TOKENIZER, DEVICE

    if device != "auto":
        DEVICE = device
    elif torch.cuda.is_available():
        DEVICE = "cuda"
    elif torch.backends.mps.is_available():
        DEVICE = "mps"
    else:
        DEVICE = "cpu"

    logger.info("Loading Bee model from %s onto %s", model_path, DEVICE)
    TOKENIZER = AutoTokenizer.from_pretrained(model_path)
    if TOKENIZER.pad_token is None:
        # Reuse EOS as the pad token when the tokenizer ships without one.
        TOKENIZER.pad_token = TOKENIZER.eos_token

    MODEL = BeeForCausalLM.from_pretrained(model_path).to(DEVICE)
    MODEL.eval()
    param_count = sum(p.numel() for p in MODEL.parameters())
    logger.info("Model loaded. Parameters: %.2fM", param_count / 1e6)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: load the model from env-var config, then serve.

    Reads BEE_MODEL_PATH and BEE_DEVICE (set by main() before uvicorn boots).
    If no model path is configured the server still starts, but every
    endpoint will return 503 until a model is loaded.
    """
    model_path = os.environ.get("BEE_MODEL_PATH", "")
    if model_path:
        load_model(model_path, os.environ.get("BEE_DEVICE", "auto"))
    else:
        logger.error("BEE_MODEL_PATH not set. Server will fail requests.")
    yield
    logger.info("Shutting down Bee server.")
|
|
|
|
| app = FastAPI(title="Bee LLM API", version="0.1.0", lifespan=lifespan) |
|
|
|
|
class GenerateRequest(BaseModel):
    """Request schema for /v1/generate."""

    prompt: str = Field(..., min_length=1, max_length=8192, description="Input prompt")
    max_new_tokens: int = Field(default=256, ge=1, le=4096)
    # Strictly positive (gt, not ge): the endpoint samples with do_sample=True,
    # and HF generation rejects temperature == 0 -- ge=0.0 admitted a value
    # that validated fine but crashed mid-request.
    temperature: float = Field(default=0.8, gt=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
|
|
|
|
class GenerateResponse(BaseModel):
    """Response schema for /v1/generate."""

    request_id: str  # server-assigned UUID4 for tracing this request
    generated_text: str  # completion only; the prompt is not echoed back
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int  # prompt_tokens + completion_tokens
    model: str  # model identifier, currently always "bee"
    duration_ms: float  # wall-clock generation time in milliseconds
|
|
|
|
@app.get("/health")
async def health():
    """Readiness probe: 200 with device info once a model is loaded, else 503."""
    if MODEL is not None:
        return {"status": "ok", "model": "bee", "device": DEVICE}
    raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
@app.post("/v1/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a completion for req.prompt and return it with token counts.

    Deliberately a plain `def`, not `async def`: model.generate() is blocking
    CPU/GPU work, and inside an async endpoint it would stall the entire event
    loop (even /health would hang during generation). FastAPI runs sync
    endpoints in its thread pool, keeping the loop responsive.

    Raises:
        HTTPException: 503 if no model has been loaded yet.
    """
    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    request_id = str(uuid.uuid4())
    start = time.perf_counter()

    inputs = TOKENIZER(req.prompt, return_tensors="pt").to(DEVICE)
    prompt_tokens = inputs["input_ids"].shape[1]

    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "repetition_penalty": req.repetition_penalty,
        "pad_token_id": TOKENIZER.pad_token_id,
        "eos_token_id": TOKENIZER.eos_token_id,
    }
    if req.temperature > 0.0:
        gen_kwargs.update(do_sample=True, temperature=req.temperature, top_p=req.top_p)
    else:
        # HF sampling rejects temperature == 0; fall back to greedy decoding
        # instead of crashing mid-request.
        gen_kwargs["do_sample"] = False

    # inference_mode is no_grad plus disabled view/version tracking -- strictly
    # for inference, slightly cheaper than no_grad.
    with torch.inference_mode():
        outputs = MODEL.generate(**inputs, **gen_kwargs)

    completion_tokens = outputs.shape[1] - prompt_tokens
    # Decode only the newly generated suffix, not the echoed prompt.
    generated_text = TOKENIZER.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
    duration_ms = (time.perf_counter() - start) * 1000

    return GenerateResponse(
        request_id=request_id,
        generated_text=generated_text,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        model="bee",
        duration_ms=duration_ms,
    )
|
|
|
|
def get_args():
    """Parse and return command-line options for serving Bee."""
    cli = argparse.ArgumentParser(description="Serve Bee via FastAPI")
    cli.add_argument("--model_path", type=str, required=True)
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    cli.add_argument("--device", type=str, default="auto")
    return cli.parse_args()
|
|
|
|
def main():
    """CLI entry point: stash config in env vars and launch uvicorn.

    Configuration travels through environment variables because uvicorn
    imports the app fresh from the "scripts.server:app" string; the lifespan
    hook reads them back at startup.
    """
    cli = get_args()
    os.environ["BEE_MODEL_PATH"] = cli.model_path
    os.environ["BEE_DEVICE"] = cli.device
    uvicorn.run("scripts.server:app", host=cli.host, port=cli.port, reload=False)
|
|
|
|
# Direct script invocation entry point.
if __name__ == "__main__":
    main()
|
|