# Dream Diffusion LLM API — OpenAI-compatible FastAPI server.
# NOTE(review): the original paste carried Hugging Face Spaces UI residue
# ("Spaces:" / "Runtime error") here; replaced with a valid comment header.
import asyncio
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer
# --- Config ---------------------------------------------------
MODEL_NAME = "Dream-org/Dream-v0-Instruct-1B"  # HF repo to load weights from
API_MODEL_ID = "dream-diffusion-1b"            # model id reported to API clients
PORT = int(os.environ.get("PORT", 7860))       # 7860 is the HF Spaces default
QUANTIZE = os.environ.get("QUANTIZE", "true").lower() == "true"

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("dream-api")

# --- Global model references ----------------------------------
# Populated by load_model() at startup; model_loaded gates the routes.
model = None
tokenizer = None
model_loaded = False
# --- Model loading --------------------------------------------
def load_model():
    """Load the tokenizer and model into module globals.

    Blocking and slow (downloads weights on first run); intended to be
    executed once in a worker thread during startup. Sets `model_loaded`
    when the service is ready to run inference.
    """
    global model, tokenizer, model_loaded

    log.info(f"Loading tokenizer: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    log.info(f"Loading model: {MODEL_NAME}")
    start = time.time()
    model = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()

    # Optional INT8 dynamic quantization of all Linear layers (CPU speed/memory win).
    # Best-effort: failure is logged and the float32 model is kept.
    if QUANTIZE:
        try:
            from torch.ao.quantization import quantize_dynamic

            model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
            log.info("β INT8 quantization applied")
        except Exception as e:
            log.warning(f"β οΈ Quantization failed: {e}")

    elapsed = time.time() - start
    log.info(f"β Model loaded in {elapsed:.1f}s")
    model_loaded = True
# --- Lifespan -------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model at startup, log at shutdown.

    FIX: the original was a bare async generator function. FastAPI's
    `lifespan=` parameter expects an async context manager factory, so the
    function must be wrapped with @asynccontextmanager (it was imported at
    the top of the file but never applied).
    """
    # Run the blocking load in the default executor so the event loop
    # stays responsive (e.g. /health can answer 503 while loading).
    loop = asyncio.get_running_loop()  # get_event_loop() is deprecated in coroutines
    await loop.run_in_executor(None, load_model)
    yield
    # Shutdown
    log.info("Shutting down")
# --- FastAPI app ----------------------------------------------
app = FastAPI(
    title="Dream Diffusion LLM API",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully permissive CORS: the API is meant to be called from arbitrary web origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Pydantic models ------------------------------------------
class Message(BaseModel):
    """A single chat turn: a role (e.g. "user"/"assistant") and its text."""
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """OpenAI-style request body for POST /v1/chat/completions."""
    model: str = API_MODEL_ID
    messages: list[Message]
    max_tokens: Optional[int] = Field(default=256, le=1024, ge=1)
    temperature: Optional[float] = Field(default=0.35, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0)
    stream: Optional[bool] = False
    # Diffusion-specific: number of denoising steps.
    steps: Optional[int] = Field(default=64, le=256, ge=1)
class ChatCompletionMessage(BaseModel):
    """The assistant message carried inside a completion choice."""
    role: str = "assistant"
    content: str
class Choice(BaseModel):
    """One completion alternative (this API always returns exactly one)."""
    index: int = 0
    message: ChatCompletionMessage
    finish_reason: str = "stop"
class Usage(BaseModel):
    """Token accounting block (estimated, not exact — see estimate_tokens)."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """OpenAI-style non-streaming response for /v1/chat/completions."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[Choice]
    usage: Usage
# --- Inference ------------------------------------------------
def run_inference(
    messages: list[Message],
    max_tokens: int,
    steps: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Run diffusion generation. Returns (text, elapsed_ms).

    Blocking and CPU-bound; callers should dispatch it to an executor.
    Uses the module-global `model` and `tokenizer` set by load_model().
    """
    # Render the conversation through the model's chat template.
    msgs = [{"role": m.role, "content": m.content} for m in messages]
    input_text = tokenizer.apply_chat_template(
        msgs,
        tokenize=False,
        add_generation_prompt=True,
    )

    # FIX: use the attention mask the tokenizer produced instead of
    # unconditionally assuming all-ones (falls back to ones if absent).
    enc = tokenizer(input_text, return_tensors="pt")
    input_ids = enc["input_ids"]
    attention_mask = enc.get("attention_mask", torch.ones_like(input_ids))
    prompt_len = input_ids.shape[1]

    # FIX: perf_counter is monotonic; time.time() can jump (NTP) and skew timings.
    start = time.perf_counter()
    with torch.no_grad():
        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            output_history=False,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg="entropy",       # entropy-guided unmasking order
            alg_temp=0.1,
        )
    elapsed_ms = (time.perf_counter() - start) * 1000

    # Decode only the newly generated tokens (strip the prompt prefix).
    generated_ids = output[0, prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return text, elapsed_ms
# --- Token estimator ------------------------------------------
def estimate_tokens(text: str) -> int:
    """Rough token estimate: ~0.75 words per token, never less than 1."""
    word_count = len(text.split())
    estimate = int(word_count / 0.75)
    return estimate if estimate > 1 else 1
# --- Request-id generator -------------------------------------
def gen_id() -> str:
    """Return an OpenAI-style request id, e.g. 'chatcmpl-1a2b3c4d5e6f'."""
    suffix = uuid.uuid4().hex[:12]
    return "chatcmpl-" + suffix
# --- SSE streaming generator ----------------------------------
async def stream_generator(text: str, req_id: str):
    """Yield OpenAI-style SSE chunks, replaying `text` word-by-word."""
    created = int(time.time())

    def sse(delta: dict, finish) -> str:
        # Format one chat.completion.chunk as an SSE `data:` line.
        payload = {
            "id": req_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": API_MODEL_ID,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish,
            }],
        }
        return f"data: {json.dumps(payload)}\n\n"

    # 1) Announce the assistant role.
    yield sse({"role": "assistant"}, None)

    # 2) One chunk per word; a trailing space joins all but the last word.
    words = text.split()
    last = len(words) - 1
    for i, word in enumerate(words):
        piece = word if i == last else word + " "
        yield sse({"content": piece}, None)
        await asyncio.sleep(0.015)  # typing effect

    # 3) Terminate the stream.
    yield sse({}, "stop")
    yield "data: [DONE]\n\n"
# --- Routes ---------------------------------------------------
@app.get("/")
async def root():
    """Service metadata and endpoint directory.

    FIX: the handler was defined but never registered with the app —
    added the missing @app.get("/") decorator.
    """
    return {
        "name": "Dream Diffusion LLM API",
        "model": API_MODEL_ID,
        "version": "1.0.0",
        "openai_compatible": True,
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "models": "GET /v1/models",
            "health": "GET /health",
        },
    }
@app.get("/health")
async def health():
    """Readiness probe: 503 while the model is loading, 200 once ready.

    FIX: added the missing @app.get("/health") decorator — the route was
    advertised by root() but never registered.
    """
    if not model_loaded:
        return JSONResponse(
            status_code=503,
            content={"status": "loading", "model": MODEL_NAME},
        )
    return {"status": "healthy", "model": MODEL_NAME}
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing (a single static entry).

    FIX: added the missing @app.get("/v1/models") decorator — the route
    was advertised by root() but never registered.
    """
    return {
        "object": "list",
        "data": [
            {
                "id": API_MODEL_ID,
                "object": "model",
                "created": 1700000000,  # static placeholder timestamp
                "owned_by": "dream-org",
            }
        ],
    }
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    """OpenAI-compatible chat completion; optionally streamed as SSE.

    FIX: added the missing @app.post("/v1/chat/completions") decorator —
    without it the main endpoint returned 404. Also replaced deprecated
    asyncio.get_event_loop() with get_running_loop().

    Raises:
        HTTPException 503 if the model has not finished loading,
        400 for an empty messages array, 500 on inference failure.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages array is required")

    log.info(
        f"Request: steps={req.steps}, max_tokens={req.max_tokens}, "
        f"temp={req.temperature}, stream={req.stream}"
    )

    # Inference is blocking/CPU-bound: run it in the default thread pool.
    loop = asyncio.get_running_loop()
    try:
        text, elapsed_ms = await loop.run_in_executor(
            None,
            run_inference,
            req.messages,
            req.max_tokens,
            req.steps,
            req.temperature,
            req.top_p,
        )
    except Exception as e:
        log.error(f"Inference error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")

    log.info(f"Generated {len(text)} chars in {elapsed_ms:.0f}ms")
    req_id = gen_id()

    # -- Streaming response --
    # NOTE: generation already completed above; streaming replays the
    # finished text word-by-word for OpenAI-client compatibility.
    if req.stream:
        return StreamingResponse(
            stream_generator(text, req_id),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # disable nginx response buffering
            },
        )

    # -- Non-streaming response --
    prompt_tokens = sum(estimate_tokens(m.content) for m in req.messages)
    completion_tokens = estimate_tokens(text)
    return ChatCompletionResponse(
        id=req_id,
        created=int(time.time()),
        model=API_MODEL_ID,
        choices=[
            Choice(
                message=ChatCompletionMessage(content=text),
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
# --- Entry point ----------------------------------------------
if __name__ == "__main__":
    # Imported here so merely importing this module never pulls in uvicorn.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=PORT,
        workers=1,       # single worker: the model lives in-process
        log_level="info",
    )