# diffusionLLM / app.py — Hugging Face Space by triflix
# OpenAI-compatible API server for the Dream diffusion language model.
import os
import time
import uuid
import json
import logging
import asyncio
from contextlib import asynccontextmanager
from typing import Optional
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer
# ─── Config ──────────────────────────────────────────────────
MODEL_NAME = "Dream-org/Dream-v0-Instruct-1B"  # Hugging Face repo to load
API_MODEL_ID = "dream-diffusion-1b"  # model id this API reports to clients
PORT = int(os.environ.get("PORT", 7860))  # 7860 is the Hugging Face Spaces default
QUANTIZE = os.environ.get("QUANTIZE", "true").lower() == "true"  # INT8 dynamic quantization toggle
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("dream-api")
# ─── Global Model References ─────────────────────────────────
# Populated by load_model() during startup; model_loaded gates all inference routes.
model = None
tokenizer = None
model_loaded = False
# ─── Model Loading ────────────────────────────────────────────
def load_model():
    """Load tokenizer and model into the module-level globals (blocking).

    Optionally applies INT8 dynamic quantization to all Linear layers when
    the QUANTIZE env var is truthy. Sets model_loaded = True when finished.
    Intended to run in an executor thread at startup (see lifespan).
    """
    global model, tokenizer, model_loaded
    log.info(f"Loading tokenizer: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        # Dream ships custom modeling/tokenizer code in its repo
        trust_remote_code=True,
    )
    log.info(f"Loading model: {MODEL_NAME}")
    start = time.time()
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        # fp32 on CPU; quantization below converts Linear layers to int8
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()
    # INT8 Dynamic Quantization
    if QUANTIZE:
        try:
            from torch.ao.quantization import quantize_dynamic
            model = quantize_dynamic(
                model,
                {torch.nn.Linear},
                dtype=torch.qint8,
            )
            log.info("βœ… INT8 quantization applied")
        except Exception as e:
            # Best-effort: fall back to the fp32 model if quantization fails
            log.warning(f"⚠️ Quantization failed: {e}")
    elapsed = time.time() - start
    log.info(f"βœ… Model loaded in {elapsed:.1f}s")
    model_loaded = True
# ─── Lifespan ─────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: load the model at startup without blocking the event loop.

    The heavy, blocking load_model() runs in a worker thread; requests hitting
    the API before it finishes get a 503 from the routes (model_loaded gate).
    """
    # asyncio.to_thread (3.9+) replaces the deprecated
    # asyncio.get_event_loop()/run_in_executor pattern inside coroutines.
    await asyncio.to_thread(load_model)
    yield
    # Shutdown
    log.info("Shutting down")
# ─── FastAPI App ──────────────────────────────────────────────
# FastAPI application; model loading is driven by the lifespan handler above.
app = FastAPI(
    title="Dream Diffusion LLM API",
    version="1.0.0",
    lifespan=lifespan,
)
# Permissive CORS so browser clients on any origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ─── Pydantic Models ─────────────────────────────────────────
# One chat turn: role (e.g. "user"/"assistant"/"system"; not validated here)
# plus its raw text content.
class Message(BaseModel):
    role: str
    content: str
# OpenAI-style /v1/chat/completions request body.
class ChatCompletionRequest(BaseModel):
    model: str = API_MODEL_ID  # accepted for compatibility; this server serves one model
    messages: list[Message]
    max_tokens: Optional[int] = Field(default=256, le=1024, ge=1)
    temperature: Optional[float] = Field(default=0.35, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0)
    stream: Optional[bool] = False  # True → SSE stream of the generated text
    # Diffusion-specific: passed through to model.diffusion_generate
    steps: Optional[int] = Field(default=64, le=256, ge=1)
# Assistant message inside a completion choice.
class ChatCompletionMessage(BaseModel):
    role: str = "assistant"
    content: str
# Single completion choice (this server always returns exactly one).
class Choice(BaseModel):
    index: int = 0
    message: ChatCompletionMessage
    finish_reason: str = "stop"
# Token accounting; values are word-based estimates (see estimate_tokens),
# not true tokenizer counts.
class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
# OpenAI-style non-streaming chat completion response envelope.
class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[Choice]
    usage: Usage
# ─── Inference Function ──────────────────────────────────────
def run_inference(
    messages: list[Message],
    max_tokens: int,
    steps: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Run diffusion generation (blocking). Returns (text, elapsed_ms)."""
    # Render the conversation through the model's chat template.
    chat = [{"role": m.role, "content": m.content} for m in messages]
    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    mask = torch.ones_like(ids)  # single sequence, no padding → all-ones mask
    n_prompt = ids.shape[1]

    # Generate (timed).
    t0 = time.time()
    with torch.no_grad():
        result = model.diffusion_generate(
            ids,
            attention_mask=mask,
            max_new_tokens=max_tokens,
            output_history=False,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg="entropy",
            alg_temp=0.1,
        )
    ms = (time.time() - t0) * 1000

    # Decode only the freshly generated tail, dropping the prompt tokens.
    completion = tokenizer.decode(result[0, n_prompt:], skip_special_tokens=True).strip()
    return completion, ms
# ─── Token Estimator ─────────────────────────────────────────
def estimate_tokens(text: str) -> int:
    """Rough token estimate: ~0.75 words per token, never less than 1."""
    word_count = len(text.split())
    # Exact-integer form of int(word_count / 0.75).
    return max(word_count * 4 // 3, 1)
# ─── Generate Request ID ─────────────────────────────────────
def gen_id() -> str:
    """Return an OpenAI-style request id: 'chatcmpl-' plus 12 random hex chars."""
    suffix = uuid.uuid4().hex[:12]
    return "chatcmpl-" + suffix
# ─── SSE Streaming Generator ─────────────────────────────────
async def stream_generator(text: str, req_id: str):
    """Yield OpenAI-style SSE chunks, replaying the finished text word by word."""
    created = int(time.time())

    # Build one SSE event carrying the given delta / finish_reason.
    def sse(delta, finish):
        payload = {
            "id": req_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": API_MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n"

    # 1) Role chunk
    yield sse({"role": "assistant"}, None)

    # 2) Content chunks — word by word, trailing space except on the last word
    words = text.split()
    last = len(words) - 1
    for idx, word in enumerate(words):
        piece = word if idx == last else word + " "
        yield sse({"content": piece}, None)
        await asyncio.sleep(0.015)  # typing effect

    # 3) Stop chunk, then the SSE terminator
    yield sse({}, "stop")
    yield "data: [DONE]\n\n"
# ─── Routes ──────────────────────────────────────────────────
@app.get("/")
async def root():
    """Service metadata plus a map of the available endpoints."""
    info = {
        "name": "Dream Diffusion LLM API",
        "model": API_MODEL_ID,
        "version": "1.0.0",
        "openai_compatible": True,
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "models": "GET /v1/models",
            "health": "GET /health",
        },
    }
    return info
@app.get("/health")
async def health():
    """Readiness probe: 200 once the model is loaded, 503 while still loading."""
    if model_loaded:
        return {"status": "healthy", "model": MODEL_NAME}
    return JSONResponse(
        status_code=503,
        content={"status": "loading", "model": MODEL_NAME},
    )
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing (a single static entry)."""
    entry = {
        "id": API_MODEL_ID,
        "object": "model",
        "created": 1700000000,
        "owned_by": "dream-org",
    }
    return {"object": "list", "data": [entry]}
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint.

    Runs the blocking diffusion generation in a worker thread, then returns
    either a standard JSON completion or a simulated SSE stream of the
    already-generated text, depending on req.stream.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages array is required")
    log.info(
        f"Request: steps={req.steps}, max_tokens={req.max_tokens}, "
        f"temp={req.temperature}, stream={req.stream}"
    )
    # Run inference off the event loop. asyncio.to_thread (3.9+) replaces the
    # deprecated asyncio.get_event_loop()/run_in_executor pattern in coroutines.
    try:
        text, elapsed_ms = await asyncio.to_thread(
            run_inference,
            req.messages,
            req.max_tokens,
            req.steps,
            req.temperature,
            req.top_p,
        )
    except Exception as e:
        log.error(f"Inference error: {e}", exc_info=True)
        # Chain the cause so the original traceback survives into debuggers/logs.
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e
    log.info(f"Generated {len(text)} chars in {elapsed_ms:.0f}ms")
    req_id = gen_id()
    # ── Streaming Response ──
    if req.stream:
        return StreamingResponse(
            stream_generator(text, req_id),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # disable reverse-proxy buffering for SSE
            },
        )
    # ── Non-Streaming Response ──
    prompt_tokens = sum(estimate_tokens(m.content) for m in req.messages)
    completion_tokens = estimate_tokens(text)
    return ChatCompletionResponse(
        id=req_id,
        created=int(time.time()),
        model=API_MODEL_ID,
        choices=[
            Choice(
                message=ChatCompletionMessage(content=text),
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
# ─── Run ──────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    # Single worker: the model lives in-process, so extra workers would each
    # load (and quantize) their own full copy of the weights.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=PORT,
        workers=1,
        log_level="info",
    )