"""FastAPI server for Bee inference."""
import argparse
import logging
import os
import sys
import time
import uuid
from pathlib import Path
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
import uvicorn

# Make the repository root importable so `bee.*` resolves when this script is
# run directly (not as an installed package). Must happen before the bee imports.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from bee.register import register
from bee.modeling_bee import BeeForCausalLM

# Register the Bee architecture with transformers' auto classes at import time,
# before any model loading happens.
register()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.server")

# Module-level singletons populated by load_model(); None until startup completes.
MODEL = None       # BeeForCausalLM instance, set by load_model()
TOKENIZER = None   # AutoTokenizer instance, set by load_model()
DEVICE = None      # resolved device string: "cuda", "mps", or "cpu"
def load_model(model_path: str, device: str = "auto"):
    """Load the tokenizer and Bee model into the module-level singletons.

    With device="auto", prefers CUDA, then Apple MPS, then CPU.
    """
    global MODEL, TOKENIZER, DEVICE

    if device != "auto":
        DEVICE = device
    elif torch.cuda.is_available():
        DEVICE = "cuda"
    elif torch.backends.mps.is_available():
        DEVICE = "mps"
    else:
        DEVICE = "cpu"

    logger.info("Loading Bee model from %s onto %s", model_path, DEVICE)

    TOKENIZER = AutoTokenizer.from_pretrained(model_path)
    if TOKENIZER.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        TOKENIZER.pad_token = TOKENIZER.eos_token

    MODEL = BeeForCausalLM.from_pretrained(model_path).to(DEVICE)
    MODEL.eval()

    total_params = sum(p.numel() for p in MODEL.parameters())
    logger.info("Model loaded. Parameters: %.2fM", total_params / 1e6)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model on startup, log on shutdown.

    Configuration comes from BEE_MODEL_PATH / BEE_DEVICE environment variables
    (set by main() before uvicorn imports this module).
    """
    path = os.environ.get("BEE_MODEL_PATH", "")
    if path:
        load_model(path, os.environ.get("BEE_DEVICE", "auto"))
    else:
        # Deliberately do not crash: /health will report 503 until a model exists.
        logger.error("BEE_MODEL_PATH not set. Server will fail requests.")
    yield
    logger.info("Shutting down Bee server.")
# Single application instance; model loading is deferred to the lifespan hook.
app = FastAPI(title="Bee LLM API", version="0.1.0", lifespan=lifespan)
class GenerateRequest(BaseModel):
    """Request body for POST /v1/generate; bounds are enforced by pydantic."""

    prompt: str = Field(..., min_length=1, max_length=8192, description="Input prompt")
    max_new_tokens: int = Field(default=256, ge=1, le=4096)
    # NOTE(review): ge=0.0 admits temperature=0.0, which the generate endpoint
    # forwards with do_sample=True — verify the sampler tolerates it.
    temperature: float = Field(default=0.8, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
class GenerateResponse(BaseModel):
    """Response body for POST /v1/generate."""

    request_id: str          # server-generated UUID4 for tracing this request
    generated_text: str      # completion only; the prompt is not echoed back
    prompt_tokens: int       # token count of the tokenized prompt
    completion_tokens: int   # number of newly generated tokens
    total_tokens: int        # prompt_tokens + completion_tokens
    model: str               # model identifier (always "bee")
    duration_ms: float       # wall-clock generation time in milliseconds
@app.get("/health")
async def health():
    """Liveness/readiness probe: 503 until the model has been loaded."""
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "ok", "model": "bee", "device": DEVICE}
@app.post("/v1/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    """Generate a completion for the given prompt.

    Returns the generated text plus token accounting and wall-clock latency.
    Raises HTTP 503 when the model has not been loaded yet.
    """
    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    request_id = str(uuid.uuid4())
    start = time.perf_counter()

    inputs = TOKENIZER(req.prompt, return_tensors="pt").to(DEVICE)
    prompt_tokens = inputs["input_ids"].shape[1]

    # BUG FIX: the request schema allows temperature == 0.0, but passing that
    # to generate() with do_sample=True divides logits by zero inside HF
    # sampling. Treat temperature == 0 as a request for greedy decoding.
    do_sample = req.temperature > 0.0
    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": req.repetition_penalty,
        "pad_token_id": TOKENIZER.pad_token_id,
        "eos_token_id": TOKENIZER.eos_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = req.temperature
        gen_kwargs["top_p"] = req.top_p

    with torch.no_grad():
        outputs = MODEL.generate(**inputs, **gen_kwargs)

    completion_tokens = outputs.shape[1] - prompt_tokens
    # Decode only the newly generated suffix, not the echoed prompt tokens.
    generated_text = TOKENIZER.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
    duration_ms = (time.perf_counter() - start) * 1000

    return GenerateResponse(
        request_id=request_id,
        generated_text=generated_text,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        model="bee",
        duration_ms=duration_ms,
    )
def get_args(argv=None):
    """Parse command-line options for the server.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:].
              Added (backward-compatibly) so parsing is testable and reusable.

    Returns:
        argparse.Namespace with model_path, host, port, and device.
    """
    parser = argparse.ArgumentParser(description="Serve Bee via FastAPI")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--device", type=str, default="auto")
    return parser.parse_args(argv)
def main():
    """CLI entry point: forward options to the app via env vars, then serve.

    uvicorn re-imports the app from its module string, so environment
    variables are the channel that carries CLI options into lifespan().
    """
    cli = get_args()
    os.environ["BEE_DEVICE"] = cli.device
    os.environ["BEE_MODEL_PATH"] = cli.model_path
    uvicorn.run("scripts.server:app", host=cli.host, port=cli.port, reload=False)


if __name__ == "__main__":
    main()