fabricationworkshop commited on
Commit
a4c1dcd
·
1 Parent(s): b32b43e

Deploy StreamDiffusion real-time visual engine

Browse files
Files changed (4) hide show
  1. Dockerfile +20 -0
  2. README.md +36 -4
  3. app.py +230 -0
  4. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# CUDA runtime base (no compiler toolchain needed — wheels ship prebuilt kernels).
FROM nvidia/cuda:12.1.1-runtime-ubuntu22.04

# Non-interactive apt installs; unbuffered Python so logs appear live in the Space.
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

# git is required for pip installs from VCS and HF hub operations.
RUN apt-get update && apt-get install -y \
    python3 python3-pip git \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements first so the dependency layer is cached across app.py edits.
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

COPY app.py .

# HF Spaces expects port 7860
EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,42 @@
1
  ---
2
- title: Streamdiffusion
3
- emoji: 🦀
4
  colorFrom: purple
5
- colorTo: pink
6
  sdk: docker
 
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: StreamDiffusion Visual Engine
3
+ emoji: 🎨
4
  colorFrom: purple
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 7860
8
8
+ suggested_hardware: a10g-small
9
  pinned: false
10
  ---
11
 
12
+ # StreamDiffusion Real-Time Visual Engine
13
+
14
+ WebSocket server for real-time AI visual generation during immersive healing sessions.
15
+
16
+ ## Protocol
17
+
18
+ **WebSocket endpoint:** `wss://<space-url>/ws`
19
+
20
+ **Client → Server (JSON):**
21
+ ```json
22
+ {
23
+ "prompt": "sacred geometry, ethereal light",
24
+ "amplitude": 0.3,
25
+ "beat": false,
26
+ "phase": 1
27
+ }
28
+ ```
29
+
30
+ **Server → Client:** Binary JPEG frame
31
+
32
+ ## Configuration
33
+
34
+ | Env Var | Default | Description |
35
+ |---------|---------|-------------|
36
+ | `MODEL_ID` | `stabilityai/sd-turbo` | Diffusion model |
37
+ | `TINY_VAE_ID` | `madebyollin/taesd` | Fast VAE decoder |
38
+ | `WIDTH` | `512` | Output width |
39
+ | `HEIGHT` | `512` | Output height |
40
+ | `NUM_STEPS` | `1` | Inference steps (1 = fastest) |
41
+ | `GUIDANCE_SCALE` | `0.0` | CFG scale (0 for SD-Turbo) |
42
+ | `JPEG_QUALITY` | `75` | Output JPEG quality |
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ StreamDiffusion Real-Time Visual Engine
3
+ ========================================
4
+ A WebSocket server that accepts prompt+audio data and returns
5
+ JPEG frames generated via StreamDiffusion's img2img pipeline.
6
+
7
+ Protocol (matches VisualCanvas.tsx):
8
+ Client -> Server: JSON { prompt, amplitude, beat, phase }
9
+ Server -> Client: Binary JPEG frame
10
+ """
11
+
12
+ import asyncio
13
+ import io
14
+ import json
15
+ import logging
16
+ import os
17
+ import time
18
+ from contextlib import asynccontextmanager
19
+
20
+ import numpy as np
21
+ import torch
22
+ from diffusers import AutoencoderTiny, StableDiffusionPipeline
23
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
24
+ from fastapi.responses import HTMLResponse
25
+ from PIL import Image
26
+
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger("streamdiffusion")
29
+
30
+ # ── Configuration ───────────────────────────────────────────────────────────────
31
+ MODEL_ID = os.environ.get("MODEL_ID", "stabilityai/sd-turbo")
32
+ TINY_VAE_ID = os.environ.get("TINY_VAE_ID", "madebyollin/taesd")
33
+ WIDTH = int(os.environ.get("WIDTH", "512"))
34
+ HEIGHT = int(os.environ.get("HEIGHT", "512"))
35
+ NUM_INFERENCE_STEPS = int(os.environ.get("NUM_STEPS", "1"))
36
+ GUIDANCE_SCALE = float(os.environ.get("GUIDANCE_SCALE", "0.0"))
37
+ JPEG_QUALITY = int(os.environ.get("JPEG_QUALITY", "75"))
38
+
39
+ # ── Phase color palettes for seed images ────────────────────────────────────────
40
+ PHASE_COLORS = {
41
+ 1: [(10, 10, 30), (26, 16, 64)], # Deep indigo (Shadow)
42
+ 2: [(26, 16, 64), (42, 24, 96)], # Rising violet (Gift)
43
+ 3: [(42, 24, 96), (80, 40, 140)], # Bright purple (Siddhi)
44
+ }
45
+
46
+ # ── Global pipeline reference ───────────────────────────────────────────────────
47
+ pipeline = None
48
+ generator = None
49
+
50
+
51
def create_seed_image(phase: int = 1, amplitude: float = 0.3) -> Image.Image:
    """Create a gradient seed image based on the current phase and audio amplitude.

    Renders a WIDTH×HEIGHT vertical gradient between the two PHASE_COLORS of
    *phase* (unknown phases fall back to phase 1), with per-pixel uniform
    integer noise whose strength scales with *amplitude*.

    Fix: the original computed the image with a per-pixel Python loop
    (3 `np.random.randint` calls per pixel, ~786k calls at 512×512); this
    version vectorizes the same math with NumPy. The gradient values are
    identical (`int()` truncation preserved via integer astype on
    non-negative values); the noise has the same per-pixel distribution.
    """
    colors = PHASE_COLORS.get(phase, PHASE_COLORS[1])
    top = np.asarray(colors[0], dtype=np.float64)
    bottom = np.asarray(colors[1], dtype=np.float64)

    # Row interpolation factor t = y / HEIGHT, shaped (HEIGHT, 1, 1) to broadcast
    # across columns and channels.
    t = (np.arange(HEIGHT, dtype=np.float64) / HEIGHT)[:, None, None]
    gradient = top * (1.0 - t) + bottom * t            # (HEIGHT, 1, 3)
    # astype(int64) truncates like the original int(...) — values are >= 0 here.
    gradient = np.broadcast_to(gradient.astype(np.int64), (HEIGHT, WIDTH, 3))

    # Noise strength scales with audio amplitude; randint's high bound is exclusive,
    # matching the original's `noise_strength + 1`.
    noise_strength = int(amplitude * 40)
    noise = np.random.randint(
        -noise_strength, noise_strength + 1, size=(HEIGHT, WIDTH, 3)
    )

    pixels = np.clip(gradient + noise, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels, mode="RGB")
74
+
75
+
76
def load_pipeline():
    """Initialise the module-level `pipeline` and `generator`.

    Loads the SD-Turbo checkpoint, swaps in the tiny VAE decoder when it is
    available (falling back silently to the stock VAE otherwise), enables
    xformers attention on CUDA, and runs one warmup generation so the first
    client frame is served quickly.
    """
    global pipeline, generator

    logger.info(f"Loading model: {MODEL_ID}")
    use_cuda = torch.cuda.is_available()
    target_device = "cuda" if use_cuda else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32

    sd_pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # Prefer the tiny VAE for fast decoding; keep the default VAE on failure.
    try:
        sd_pipe.vae = AutoencoderTiny.from_pretrained(TINY_VAE_ID, torch_dtype=torch_dtype)
    except Exception as e:
        logger.warning(f"TinyVAE load failed, using default VAE: {e}")
    else:
        logger.info("TinyVAE loaded successfully")

    sd_pipe = sd_pipe.to(target_device)

    # Memory-efficient attention is a CUDA-only optimization.
    if use_cuda:
        try:
            sd_pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            logger.info("xformers not available, using default attention")
        else:
            logger.info("xformers enabled")

    # One throwaway generation compiles/caches the hot paths before clients connect.
    logger.info("Warming up pipeline...")
    _ = sd_pipe(
        prompt="warmup",
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        width=WIDTH,
        height=HEIGHT,
    )
    logger.info("Pipeline ready!")

    pipeline = sd_pipe
    generator = torch.Generator(device=target_device)
122
+
123
+
124
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup."""
    # Blocking load: the server does not start accepting traffic until the
    # pipeline (including its warmup generation) is ready.
    load_pipeline()
    yield
    # No shutdown cleanup needed — the process owns the model for its lifetime.
129
+
130
+
131
app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root():
    """Health check endpoint — also used by useWebSocket.ts to detect if the space is awake."""
    status_page = "<h1>StreamDiffusion Visual Engine</h1><p>Status: Running</p>"
    return HTMLResponse(content=status_page, status_code=200)
141
+
142
+
143
def _coerce_float(value, fallback: float) -> float:
    """Best-effort float conversion of an untrusted client value."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return fallback


def _coerce_int(value, fallback: int) -> int:
    """Best-effort int conversion of an untrusted client value."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return fallback


def _build_full_prompt(prompt: str, phase: int, beat, amplitude: float) -> str:
    """Enhance the client prompt with phase and audio-energy context."""
    phase_labels = {1: "shadow, dense, contained", 2: "gift, expanding, radiant", 3: "siddhi, transcendent, luminous"}
    energy = "pulsing, rhythmic beat" if beat else "flowing, gentle"
    intensity = "high energy, vibrant" if amplitude > 0.5 else "calm, subtle"
    return (
        f"{prompt}, {phase_labels.get(phase, '')}, "
        f"{energy}, {intensity}, "
        f"abstract art, sacred geometry, ethereal, 4k, high quality"
    )


def _render_frame(full_prompt: str, phase: int, amplitude: float) -> Image.Image:
    """Generate one frame via the pipeline; fall back to a gradient seed image on any failure."""
    if pipeline is None:
        return create_seed_image(phase, amplitude)
    try:
        # Phase/amplitude-derived seed: coherent visuals within a phase,
        # slight variation as the audio amplitude moves.
        seed = int(phase * 1000 + (amplitude * 100))
        generator.manual_seed(seed)

        result = pipeline(
            prompt=full_prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            width=WIDTH,
            height=HEIGHT,
            generator=generator,
        )
        return result.images[0]
    except Exception as e:
        logger.error(f"Generation error: {e}")
        return create_seed_image(phase, amplitude)


@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Stream generated JPEG frames to one client.

    Protocol:
        Client -> Server: JSON {prompt, amplitude, beat, phase}
        Server -> Client: binary JPEG frame (one per received/derived message)

    If the client is silent for 5 s, a frame is generated from the last known
    prompt/phase so the stream never stalls.

    Fixes vs. the original:
    - the timeout fallback built a JSON string only to immediately re-parse
      it; it now constructs the dict directly;
    - untrusted `amplitude`/`phase` values are coerced to float/int with
      defaults (a non-numeric amplitude previously raised TypeError and
      killed the whole session);
    - non-dict JSON payloads (e.g. "5", [1]) are skipped instead of crashing
      the session via AttributeError on `.get`.
    """
    await ws.accept()
    logger.info("Client connected")

    last_prompt = ""
    last_phase = 1
    frame_count = 0
    start_time = time.time()

    try:
        while True:
            # Receive prompt data from the client, or synthesize a message
            # from the last known state after 5 s of silence.
            try:
                raw = await asyncio.wait_for(ws.receive_text(), timeout=5.0)
            except asyncio.TimeoutError:
                data = {
                    "prompt": last_prompt or "ethereal sacred geometry, cosmic energy",
                    "amplitude": 0.3,
                    "beat": False,
                    "phase": last_phase,
                }
            else:
                try:
                    data = json.loads(raw)
                except json.JSONDecodeError:
                    continue  # ignore malformed messages, keep the session alive
                if not isinstance(data, dict):
                    continue  # valid JSON but not an object — skip it

            prompt = data.get("prompt", last_prompt or "ethereal sacred geometry")
            amplitude = _coerce_float(data.get("amplitude", 0.3), 0.3)
            beat = data.get("beat", False)
            phase = _coerce_int(data.get("phase", 1), 1)

            last_prompt = prompt
            last_phase = phase

            full_prompt = _build_full_prompt(prompt, phase, beat, amplitude)
            frame = _render_frame(full_prompt, phase, amplitude)

            # Encode to JPEG and send as binary
            buf = io.BytesIO()
            frame.save(buf, format="JPEG", quality=JPEG_QUALITY)
            await ws.send_bytes(buf.getvalue())

            frame_count += 1
            # Periodic throughput log every 30 frames.
            if frame_count % 30 == 0:
                elapsed = time.time() - start_time
                fps = frame_count / elapsed if elapsed > 0 else 0
                logger.info(f"FPS: {fps:.1f} | Frames: {frame_count} | Prompt: {prompt[:50]}...")

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        logger.info(f"Session ended. Total frames: {frame_count}")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# PyTorch + diffusion stack
torch>=2.1.0
diffusers>=0.25.0
transformers>=4.36.0
accelerate>=0.25.0
safetensors>=0.4.0

# Imaging / numerics
Pillow>=10.0.0
numpy>=1.24.0

# Web server (FastAPI over uvicorn with WebSocket support)
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
websockets>=12.0

# Memory-efficient attention — no wheels for macOS
xformers>=0.0.23; sys_platform != "darwin"