"""
StreamDiffusion Real-Time Visual Engine
========================================
A WebSocket server that accepts prompt + audio data and returns
JPEG frames generated by a diffusers Stable Diffusion text-to-image
pipeline (SD-Turbo by default).

Protocol (matches VisualCanvas.tsx):
    Client -> Server: JSON { prompt, amplitude, beat, phase }
    Server -> Client: Binary JPEG frame
"""
|
|
| import asyncio |
| import io |
| import json |
| import logging |
| import os |
| import time |
| from contextlib import asynccontextmanager |
|
|
| import numpy as np |
| import torch |
| from diffusers import AutoencoderTiny, StableDiffusionPipeline |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.responses import HTMLResponse |
| from PIL import Image |
|
|
# Configure root logging once at import and expose the module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("streamdiffusion")


# Runtime settings. Each one can be overridden through the environment
# variable named in the first argument (note: NUM_INFERENCE_STEPS is read
# from "NUM_STEPS").
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sd-turbo")
TINY_VAE_ID = os.getenv("TINY_VAE_ID", "madebyollin/taesd")
WIDTH = int(os.getenv("WIDTH", "512"))
HEIGHT = int(os.getenv("HEIGHT", "512"))
NUM_INFERENCE_STEPS = int(os.getenv("NUM_STEPS", "1"))
GUIDANCE_SCALE = float(os.getenv("GUIDANCE_SCALE", "0.0"))
JPEG_QUALITY = int(os.getenv("JPEG_QUALITY", "75"))


# Gradient endpoints per phase: [top RGB, bottom RGB], consumed by
# create_seed_image().
PHASE_COLORS = {
    1: [
        (10, 10, 30),
        (26, 16, 64),
    ],
    2: [
        (26, 16, 64),
        (42, 24, 96),
    ],
    3: [
        (42, 24, 96),
        (80, 40, 140),
    ],
}


# Shared model state; both names stay None until load_pipeline() assigns them.
pipeline = None
generator = None
|
|
|
|
def create_seed_image(phase: int = 1, amplitude: float = 0.3) -> Image.Image:
    """Create a noisy vertical-gradient image for the given phase and amplitude.

    The gradient runs from the first to the second colour of
    ``PHASE_COLORS[phase]`` (top to bottom); every channel of every pixel is
    then offset by uniform integer noise in ``[-n, n]`` where
    ``n = int(amplitude * 40)``, and clamped to ``0..255``.

    Args:
        phase: Key into ``PHASE_COLORS``. Unknown or unhashable values fall
            back to phase 1.
        amplitude: Noise scale. Negative values are treated as 0;
            non-numeric / NaN values fall back to the default of 0.3.

    Returns:
        An RGB ``PIL.Image.Image`` of size ``WIDTH`` x ``HEIGHT``.
    """
    try:
        top, bottom = PHASE_COLORS.get(phase, PHASE_COLORS[1])
    except TypeError:
        # Unhashable phase (e.g. a list decoded from client JSON).
        top, bottom = PHASE_COLORS[1]

    try:
        noise_strength = int(float(amplitude) * 40)
    except (TypeError, ValueError, OverflowError):
        noise_strength = int(0.3 * 40)
    # A negative strength would make randint's low >= high (ValueError), and
    # anything above 255 saturates after clamping anyway.
    noise_strength = max(0, min(255, noise_strength))

    # Per-row base colour, same arithmetic as c0 * (1 - t) + c1 * t with
    # t = y / HEIGHT, truncated to int. Shape: (HEIGHT, 3).
    t = (np.arange(HEIGHT, dtype=np.float64) / HEIGHT)[:, None]
    top_rgb = np.asarray(top, dtype=np.float64)
    bottom_rgb = np.asarray(bottom, dtype=np.float64)
    rows = (top_rgb * (1 - t) + bottom_rgb * t).astype(np.int16)

    # One vectorised draw instead of three Python-level randint calls per
    # pixel (about 786k calls for a 512x512 frame).
    noise = np.random.randint(
        -noise_strength,
        noise_strength + 1,
        size=(HEIGHT, WIDTH, 3),
        dtype=np.int16,
    )
    pixels = np.clip(rows[:, None, :] + noise, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
def load_pipeline():
    """Load, configure and warm up the diffusion pipeline.

    Steps, in order:
      1. Load ``MODEL_ID`` with ``StableDiffusionPipeline`` (float16 on CUDA,
         float32 on CPU, safety checker disabled).
      2. Try to replace the VAE with the ``TINY_VAE_ID`` autoencoder; on any
         failure keep the default VAE and log a warning.
      3. Move the pipeline to the selected device and, on CUDA only, try to
         enable xformers attention.
      4. Run one warm-up generation with the configured step count, guidance
         scale and resolution.

    Side effects:
        Assigns the module globals ``pipeline`` and ``generator``.
    """
    global pipeline, generator

    logger.info(f"Loading model: {MODEL_ID}")
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32

    sd_pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # The tiny autoencoder is optional: a failed load is logged, not raised.
    try:
        sd_pipe.vae = AutoencoderTiny.from_pretrained(TINY_VAE_ID, torch_dtype=dtype)
    except Exception as e:
        logger.warning(f"TinyVAE load failed, using default VAE: {e}")
    else:
        logger.info("TinyVAE loaded successfully")

    sd_pipe = sd_pipe.to(device)

    # xformers is attempted on CUDA only and is likewise best-effort.
    if use_cuda:
        try:
            sd_pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            logger.info("xformers not available, using default attention")
        else:
            logger.info("xformers enabled")

    # Run one throwaway generation before publishing the pipeline.
    logger.info("Warming up pipeline...")
    warmup_kwargs = {
        "prompt": "warmup",
        "num_inference_steps": NUM_INFERENCE_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "width": WIDTH,
        "height": HEIGHT,
    }
    sd_pipe(**warmup_kwargs)
    logger.info("Pipeline ready!")

    pipeline = sd_pipe
    generator = torch.Generator(device=device)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model before the app starts serving.

    ``load_pipeline()`` is called synchronously during startup; nothing is
    torn down after the ``yield`` on shutdown.
    """
    # Startup: runs before control is handed to FastAPI.
    load_pipeline()
    # Control returns to FastAPI here; no shutdown work follows.
    yield


# Application instance; startup work is wired in through the lifespan hook.
app = FastAPI(lifespan=lifespan)
|
|
|
|
@app.get("/")
async def root():
    """Health check endpoint - also used by useWebSocket.ts to detect if the space is awake.

    Returns:
        A static HTML page with HTTP status 200.
    """
    status_page = "<h1>StreamDiffusion Visual Engine</h1><p>Status: Running</p>"
    return HTMLResponse(content=status_page, status_code=200)
|
|
|
|
# Prompt fragment appended for each known phase; unknown phases add nothing.
_PHASE_LABELS = {
    1: "shadow, dense, contained",
    2: "gift, expanding, radiant",
    3: "siddhi, transcendent, luminous",
}

# Amplitude used when the client omits the field or sends an unusable value.
_DEFAULT_AMPLITUDE = 0.3
# NOTE(review): assumes the client sends a normalised amplitude in [0, 1]
# (the 0.3 default and the 0.5 "high energy" threshold suggest so) - raise
# this cap if the client uses a wider range.
_MAX_AMPLITUDE = 1.0

# Serialises access to the shared pipeline/generator. Created lazily so the
# lock is always constructed inside the running event loop.
_inference_lock = None


def _coerce_amplitude(value):
    """Return ``value`` as a float clamped to [0, _MAX_AMPLITUDE], or the default if unusable."""
    try:
        amp = float(value)
    except (TypeError, ValueError, OverflowError):
        return _DEFAULT_AMPLITUDE
    if amp != amp:  # NaN never equals itself
        return _DEFAULT_AMPLITUDE
    return min(max(amp, 0.0), _MAX_AMPLITUDE)


def _coerce_phase(value, fallback):
    """Return ``value`` as an int, or ``fallback`` if it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return fallback


def _build_prompt(prompt, phase, amplitude, beat):
    """Compose the full diffusion prompt from the client prompt and audio state."""
    energy = "pulsing, rhythmic beat" if beat else "flowing, gentle"
    intensity = "high energy, vibrant" if amplitude > 0.5 else "calm, subtle"
    return (
        f"{prompt}, {_PHASE_LABELS.get(phase, '')}, "
        f"{energy}, {intensity}, "
        f"abstract art, sacred geometry, ethereal, 4k, high quality"
    )


def _render_jpeg(full_prompt, phase, amplitude):
    """Generate one frame and return it JPEG-encoded (blocking; run off the event loop).

    Uses the diffusion pipeline when it is loaded; if it is not loaded, or
    generation raises, falls back to ``create_seed_image``.
    """
    frame = None
    if pipeline is not None:
        try:
            # Seed derived from phase and amplitude, so identical audio state
            # reseeds the generator identically.
            seed = int(phase * 1000 + (amplitude * 100))
            generator.manual_seed(seed)
            result = pipeline(
                prompt=full_prompt,
                num_inference_steps=NUM_INFERENCE_STEPS,
                guidance_scale=GUIDANCE_SCALE,
                width=WIDTH,
                height=HEIGHT,
                generator=generator,
            )
            frame = result.images[0]
        except Exception as e:
            # Best-effort: keep the stream alive with a fallback frame, but
            # record the traceback so the failure is diagnosable.
            logger.exception("Generation error: %s", e)
    if frame is None:
        frame = create_seed_image(phase, amplitude)

    buf = io.BytesIO()
    frame.save(buf, format="JPEG", quality=JPEG_QUALITY)
    return buf.getvalue()


def _get_inference_lock():
    """Return the process-wide inference lock, creating it on first use."""
    global _inference_lock
    if _inference_lock is None:
        _inference_lock = asyncio.Lock()
    return _inference_lock


@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Stream JPEG frames to one WebSocket client.

    Protocol:
        Client -> Server: JSON text ``{prompt, amplitude, beat, phase}``.
        Server -> Client: one binary JPEG frame per accepted message.

    Behaviour:
        * If no message arrives within 5 seconds, a frame is generated from
          the last prompt and phase so the stream keeps flowing.
        * Messages that are not valid JSON objects are skipped.
        * Individual fields are sanitised: a non-string prompt falls back to
          the previous prompt, an unusable amplitude to the default, and an
          unusable phase to the previous phase. Malformed client input can
          therefore no longer end the session.
        * Frame generation and JPEG encoding run in the default executor so
          the event loop stays responsive; a lock keeps inference on the
          shared pipeline serialised across clients.

    The session ends when the client disconnects or an unexpected error
    occurs; both are logged and never propagated.
    """
    await ws.accept()
    logger.info("Client connected")

    last_prompt = ""
    last_phase = 1
    frame_count = 0
    start_time = time.monotonic()
    loop = asyncio.get_running_loop()

    try:
        while True:
            try:
                raw = await asyncio.wait_for(ws.receive_text(), timeout=5.0)
            except asyncio.TimeoutError:
                # Client idle: synthesise a message from the last known state.
                data = {
                    "prompt": last_prompt or "ethereal sacred geometry, cosmic energy",
                    "amplitude": _DEFAULT_AMPLITUDE,
                    "beat": False,
                    "phase": last_phase,
                }
            else:
                try:
                    data = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    # Valid JSON but not an object (e.g. a list or number).
                    continue

            prompt = data.get("prompt")
            if not isinstance(prompt, str):
                prompt = last_prompt or "ethereal sacred geometry"
            amplitude = _coerce_amplitude(data.get("amplitude", _DEFAULT_AMPLITUDE))
            beat = bool(data.get("beat", False))
            phase = _coerce_phase(data.get("phase", 1), last_phase)

            last_prompt = prompt
            last_phase = phase

            full_prompt = _build_prompt(prompt, phase, amplitude, beat)

            # Inference is blocking and CPU/GPU-bound: run it in a worker
            # thread, one generation at a time.
            async with _get_inference_lock():
                jpeg_bytes = await loop.run_in_executor(
                    None, _render_jpeg, full_prompt, phase, amplitude
                )
            await ws.send_bytes(jpeg_bytes)

            frame_count += 1
            if frame_count % 30 == 0:
                elapsed = time.monotonic() - start_time
                fps = frame_count / elapsed if elapsed > 0 else 0
                logger.info(
                    "FPS: %.1f | Frames: %d | Prompt: %s...",
                    fps,
                    frame_count,
                    prompt[:50],
                )

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        logger.info("Session ended. Total frames: %d", frame_count)
|
|