"""OpenAI-compatible FastAPI server wrapping the Dream diffusion LLM.

Exposes ``POST /v1/chat/completions`` (streaming and non-streaming),
``GET /v1/models``, and a ``GET /health`` readiness probe. The model is
loaded once at startup in a worker thread; generation runs in the default
thread-pool executor so the event loop stays responsive.
"""

import asyncio
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer

# ─── Config ──────────────────────────────────────────────────
MODEL_NAME = "Dream-org/Dream-v0-Instruct-1B"
API_MODEL_ID = "dream-diffusion-1b"
PORT = int(os.environ.get("PORT", 7860))
QUANTIZE = os.environ.get("QUANTIZE", "true").lower() == "true"

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("dream-api")

# ─── Global Model References ─────────────────────────────────
# Populated by load_model() at startup; model_loaded gates the endpoints.
model = None
tokenizer = None
model_loaded = False


# ─── Model Loading ───────────────────────────────────────────
def load_model():
    """Load the tokenizer and model, optionally applying INT8 quantization.

    Runs synchronously (called via an executor from the lifespan hook).
    Sets the module-level ``model``, ``tokenizer`` and ``model_loaded``
    globals as a side effect.
    """
    global model, tokenizer, model_loaded

    log.info(f"Loading tokenizer: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    log.info(f"Loading model: {MODEL_NAME}")
    start = time.time()
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()

    # INT8 dynamic quantization of Linear layers — best effort: a failure
    # (e.g. unsupported backend) must not prevent the server from starting.
    if QUANTIZE:
        try:
            from torch.ao.quantization import quantize_dynamic

            model = quantize_dynamic(
                model,
                {torch.nn.Linear},
                dtype=torch.qint8,
            )
            log.info("✅ INT8 quantization applied")
        except Exception as e:
            log.warning(f"⚠️ Quantization failed: {e}")

    elapsed = time.time() - start
    log.info(f"✅ Model loaded in {elapsed:.1f}s")
    model_loaded = True


# ─── Lifespan ────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: load the model off-loop, log on shutdown."""
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, load_model)
    yield
    # Shutdown
    log.info("Shutting down")


# ─── FastAPI App ─────────────────────────────────────────────
app = FastAPI(
    title="Dream Diffusion LLM API",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─── Pydantic Models ─────────────────────────────────────────
class Message(BaseModel):
    # One chat turn ("system" / "user" / "assistant").
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    # OpenAI-style chat request, plus the diffusion-specific `steps` knob.
    model: str = API_MODEL_ID
    messages: list[Message]
    max_tokens: Optional[int] = Field(default=256, le=1024, ge=1)
    temperature: Optional[float] = Field(default=0.35, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0)
    stream: Optional[bool] = False
    # Diffusion-specific
    steps: Optional[int] = Field(default=64, le=256, ge=1)


class ChatCompletionMessage(BaseModel):
    role: str = "assistant"
    content: str


class Choice(BaseModel):
    index: int = 0
    message: ChatCompletionMessage
    finish_reason: str = "stop"


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[Choice]
    usage: Usage


# ─── Inference Function ──────────────────────────────────────
def run_inference(
    messages: list[Message],
    max_tokens: int,
    steps: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Run diffusion generation.

    Returns (text, elapsed_ms)."""
    # Build chat prompt via the model's chat template.
    msgs = [{"role": m.role, "content": m.content} for m in messages]
    input_text = tokenizer.apply_chat_template(
        msgs,
        tokenize=False,
        add_generation_prompt=True,
    )
    input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    # Generate (blocking; caller is expected to run this in an executor).
    start = time.time()
    with torch.no_grad():
        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            output_history=False,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg="entropy",
            alg_temp=0.1,
        )
    elapsed_ms = (time.time() - start) * 1000

    # Decode only the newly generated suffix.
    # NOTE(review): this indexes the generate output directly as a tensor;
    # some diffusion_generate implementations return an object whose token
    # ids live in `.sequences` — confirm against the model's remote code.
    generated_ids = output[0, prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return text, elapsed_ms


# ─── Token Estimator ─────────────────────────────────────────
def estimate_tokens(text: str) -> int:
    """Rough token count: ~0.75 words per token, minimum 1."""
    words = len(text.split())
    return max(int(words / 0.75), 1)


# ─── Generate Request ID ─────────────────────────────────────
def gen_id() -> str:
    """Return an OpenAI-style request id like ``chatcmpl-1a2b3c4d5e6f``."""
    return f"chatcmpl-{uuid.uuid4().hex[:12]}"


# ─── SSE Streaming Generator ─────────────────────────────────
async def stream_generator(text: str, req_id: str):
    """Yield SSE chunks word-by-word from the generated text."""
    now = int(time.time())

    # 1) Role chunk
    role_chunk = {
        "id": req_id,
        "object": "chat.completion.chunk",
        "created": now,
        "model": API_MODEL_ID,
        "choices": [{
            "index": 0,
            "delta": {"role": "assistant"},
            "finish_reason": None,
        }],
    }
    yield f"data: {json.dumps(role_chunk)}\n\n"

    # 2) Content chunks — word by word (no trailing space on the last word)
    words = text.split()
    for i, word in enumerate(words):
        content = word + ("" if i == len(words) - 1 else " ")
        chunk = {
            "id": req_id,
            "object": "chat.completion.chunk",
            "created": now,
            "model": API_MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": {"content": content},
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.015)  # typing effect

    # 3) Stop chunk
    stop_chunk = {
        "id": req_id,
        "object": "chat.completion.chunk",
        "created": now,
        "model": API_MODEL_ID,
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": "stop",
        }],
    }
    yield f"data: {json.dumps(stop_chunk)}\n\n"
    yield "data: [DONE]\n\n"


# ─── Routes ──────────────────────────────────────────────────
@app.get("/")
async def root():
    """Service banner with the available endpoints."""
    return {
        "name": "Dream Diffusion LLM API",
        "model": API_MODEL_ID,
        "version": "1.0.0",
        "openai_compatible": True,
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "models": "GET /v1/models",
            "health": "GET /health",
        },
    }


@app.get("/health")
async def health():
    """Readiness probe: 503 while the model is still loading."""
    if not model_loaded:
        return JSONResponse(
            status_code=503,
            content={"status": "loading", "model": MODEL_NAME},
        )
    return {"status": "healthy", "model": MODEL_NAME}


@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing (single static entry)."""
    return {
        "object": "list",
        "data": [
            {
                "id": API_MODEL_ID,
                "object": "model",
                "created": 1700000000,
                "owned_by": "dream-org",
            }
        ],
    }


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    """OpenAI-compatible chat completion (streaming and non-streaming).

    Generation always runs to completion first; streaming mode replays the
    finished text word-by-word as SSE chunks.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages array is required")

    log.info(
        f"Request: steps={req.steps}, max_tokens={req.max_tokens}, "
        f"temp={req.temperature}, stream={req.stream}"
    )

    # Run inference in thread pool (blocking call).
    # get_running_loop() replaces the deprecated get_event_loop().
    loop = asyncio.get_running_loop()
    try:
        text, elapsed_ms = await loop.run_in_executor(
            None,
            run_inference,
            req.messages,
            req.max_tokens,
            req.steps,
            req.temperature,
            req.top_p,
        )
    except Exception as e:
        log.error(f"Inference error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")

    log.info(f"Generated {len(text)} chars in {elapsed_ms:.0f}ms")
    req_id = gen_id()

    # ── Streaming Response ──
    if req.stream:
        return StreamingResponse(
            stream_generator(text, req_id),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # ── Non-Streaming Response ──
    prompt_tokens = sum(estimate_tokens(m.content) for m in req.messages)
    completion_tokens = estimate_tokens(text)

    return ChatCompletionResponse(
        id=req_id,
        created=int(time.time()),
        model=API_MODEL_ID,
        choices=[
            Choice(
                message=ChatCompletionMessage(content=text),
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )


# ─── Run ─────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=PORT,
        workers=1,
        log_level="info",
    )