"""FastAPI server for Bee inference."""

import argparse
import logging
import os
import sys
import time
import uuid
from pathlib import Path
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
import uvicorn

# Make the repository root importable so the local bee package resolves when this file is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM

# Register the custom Bee model classes before any checkpoints are loaded.
register()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.server")

MODEL = None
TOKENIZER = None
DEVICE = None


def load_model(model_path: str, device: str = "auto"):
    global MODEL, TOKENIZER, DEVICE
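    # Auto-select the best available backend: CUDA if present, then Apple MPS, then CPU.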
    if device == "auto":
        DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    else:
        DEVICE = device
    logger.info("Loading Bee model from %s onto %s", model_path, DEVICE)
    TOKENIZER = AutoTokenizer.from_pretrained(model_path)
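    # Fall back to the EOS token for padding if the tokenizer defines no pad token.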
    if TOKENIZER.pad_token is None:
        TOKENIZER.pad_token = TOKENIZER.eos_token
    MODEL = BeeForCausalLM.from_pretrained(model_path).to(DEVICE)
    MODEL.eval()
    logger.info("Model loaded. Parameters: %.2fM", sum(p.numel() for p in MODEL.parameters()) / 1e6)


@asynccontextmanager
async def lifespan(app: FastAPI):
    model_path = os.environ.get("BEE_MODEL_PATH", "")
    device = os.environ.get("BEE_DEVICE", "auto")
    if not model_path:
        logger.error("BEE_MODEL_PATH not set. Server will fail requests.")
    else:
        load_model(model_path, device)
    yield
    logger.info("Shutting down Bee server.")


app = FastAPI(title="Bee LLM API", version="0.1.0", lifespan=lifespan)


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=8192, description="Input prompt")
    max_new_tokens: int = Field(default=256, ge=1, le=4096)
    temperature: float = Field(default=0.8, gt=0.0, le=2.0)  # strictly positive: sampling is always enabled downstream
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)


class GenerateResponse(BaseModel):
    request_id: str
    generated_text: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    model: str
    duration_ms: float


@app.get("/health")
async def health():
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "ok", "model": "bee", "device": DEVICE}


@app.post("/v1/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    if MODEL is None or TOKENIZER is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    request_id = str(uuid.uuid4())
    start = time.perf_counter()

    inputs = TOKENIZER(req.prompt, return_tensors="pt").to(DEVICE)
    prompt_tokens = inputs["input_ids"].shape[1]

    # Generation runs synchronously and blocks the event loop for the duration of the
    # request; for concurrent traffic, consider offloading to a worker thread.
    with torch.no_grad():
        outputs = MODEL.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=req.temperature,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            pad_token_id=TOKENIZER.pad_token_id,
            eos_token_id=TOKENIZER.eos_token_id,
        )

    completion_tokens = outputs.shape[1] - prompt_tokens
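    # Decode only the newly generated continuation, dropping the echoed prompt tokens.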
    generated_text = TOKENIZER.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)
    duration_ms = (time.perf_counter() - start) * 1000

    return GenerateResponse(
        request_id=request_id,
        generated_text=generated_text,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        model="bee",
        duration_ms=duration_ms,
    )


def get_args():
    parser = argparse.ArgumentParser(description="Serve Bee via FastAPI")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--device", type=str, default="auto")
    return parser.parse_args()


def main():
    args = get_args()
    os.environ["BEE_MODEL_PATH"] = args.model_path
    os.environ["BEE_DEVICE"] = args.device
    uvicorn.run("scripts.server:app", host=args.host, port=args.port, reload=False)


if __name__ == "__main__":
    main()