# "Spaces: Sleeping" — Hugging Face Spaces status banner captured with the page;
# not part of the program. Kept as a comment so the file parses as Python.
| import os | |
| import time | |
| import threading | |
| import socket | |
| import subprocess | |
| import asyncio | |
| import json | |
| import uuid | |
| import requests | |
| from fastapi import FastAPI, Request, Header | |
| from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse | |
app = FastAPI()
# ----------------------------
# Config
# ----------------------------
# Bearer token required by the API endpoints.
# NOTE(review): the fallback default "821274" is weak and committed to source —
# always set API_KEY in the environment for a real deployment.
API_KEY = os.getenv("API_KEY", "821274")
# Ollama model tag to pull at boot and use for every generation.
MODEL = os.getenv("MODEL", "llama3.2:1b")
# Base URL of the local Ollama HTTP API.
OLLAMA_BASE = os.getenv("OLLAMA_BASE", "http://127.0.0.1:11434")
# System prompt sent with every /api/generate call (runtime string — do not edit casually).
SYSTEM_PROMPT = """
You are a helpful, friendly AI assistant.
Rules:
- If the user asks casual conversation, respond normally in plain text.
- If the user asks for code/scripts/automation or says “just do it” / “write something”, choose a reasonable task and output a complete useful script.
- Don’t refuse normal conversation.
- Only output code when appropriate.
- When you write code: output code first, then a short explanation.
Always be helpful. Never say you cannot fulfill a request unless it is unsafe.
""".strip()
# Free CPU: serialize requests to avoid overload / timeouts
# (a single-slot semaphore means at most one generation is in flight).
GEN_SEM = asyncio.Semaphore(1)
| # ---------------------------- | |
| # Ollama helpers | |
| # ---------------------------- | |
def is_port_open(host="127.0.0.1", port=11434) -> bool:
    """Return True iff a TCP connection to (host, port) succeeds within 0.5 s."""
    try:
        probe = socket.create_connection((host, port), timeout=0.5)
    except OSError:
        return False
    probe.close()
    return True
def ollama_healthy() -> bool:
    """Probe Ollama's /api/tags endpoint; healthy only on an HTTP 200 reply."""
    try:
        resp = requests.get(f"{OLLAMA_BASE}/api/tags", timeout=1.5)
    except Exception:
        # Connection refused / timeout / DNS — all mean "not healthy".
        return False
    return resp.status_code == 200
def ensure_ollama_running():
    """Spawn `ollama serve` in the background unless its port already answers."""
    if is_port_open("127.0.0.1", 11434):
        return
    # Fire-and-forget: readiness is checked separately via wait_for_ollama().
    subprocess.Popen(["ollama", "serve"])
def wait_for_ollama(timeout_s=120) -> bool:
    """Poll ollama_healthy() once per second; True if it passes before timeout_s."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if ollama_healthy():
            return True
        time.sleep(1)
    return False
def pull_and_warm_model():
    """
    Best-effort boot task: ensure Ollama is up, pull MODEL (slow on free CPU),
    then run one throwaway generation so the first real request isn't the cold
    one. Every failure is logged and swallowed — the Space must still boot.
    """
    try:
        ensure_ollama_running()
        if not wait_for_ollama(120):
            print("Ollama not ready yet; skipping model pull.")
            return
        print(f"Pulling model: {MODEL}")
        pull_resp = requests.post(
            f"{OLLAMA_BASE}/api/pull",
            json={"name": MODEL},
            timeout=60 * 30,
        )
        if pull_resp.status_code != 200:
            print("Pull failed:", pull_resp.text[:2000])
            return
        # Warmup loads the weights into memory before any user hits the API.
        print("Warming up…")
        warmup_payload = {
            "model": MODEL,
            "system": SYSTEM_PROMPT,
            "prompt": "Say: ready.",
            "stream": False,
        }
        requests.post(f"{OLLAMA_BASE}/api/generate", json=warmup_payload, timeout=180)
        print("Warmup done.")
    except Exception as e:
        print("Boot task error (non-fatal):", str(e))
# Run the pull/warm task in a daemon thread so app startup is never blocked.
threading.Thread(target=pull_and_warm_model, daemon=True).start()
def generate_with_recovery(prompt: str, attempts: int = 3):
    """
    Call Ollama /api/generate (non-streaming), self-healing and retrying.

    Each attempt first restarts Ollama if it looks dead, then posts the
    prompt with the global SYSTEM_PROMPT. Failures back off 1s, 2s, 4s
    (capped).

    Returns:
        (text, error) — on success, error is None; after `attempts` failures,
        text is a user-facing fallback message and error is the last
        exception string.
    """
    last_err = None
    for i in range(1, attempts + 1):
        try:
            # Self-heal: relaunch the backend before this attempt if needed.
            if not ollama_healthy():
                ensure_ollama_running()
                wait_for_ollama(60)
            r = requests.post(
                f"{OLLAMA_BASE}/api/generate",
                json={
                    "model": MODEL,
                    "system": SYSTEM_PROMPT,
                    "prompt": prompt,
                    "stream": False,
                },
                timeout=600,
            )
            r.raise_for_status()
            data = r.json()
            return data.get("response", ""), None
        except Exception as e:
            last_err = str(e)
            # Fix: only back off between attempts — the original also slept
            # after the final failure, pointlessly delaying the error reply.
            if i < attempts:
                time.sleep(min(2 ** (i - 1), 4))
    return (
        "⚠️ Backend hiccup while generating. Retrying usually works.\n\n"
        "Debug error:\n" + (last_err or "unknown"),
        last_err,
    )
def messages_to_prompt(messages):
    """
    Flatten OpenAI-style chat messages into one prompt string for Ollama's
    /api/generate: "ROLE:\\ncontent" chunks joined by blank lines.

    Messages with empty/whitespace content are dropped; a missing role
    defaults to "user". Roles are upper-cased, content is stripped.
    """
    chunks = [
        f"{(m.get('role') or 'user').strip().upper()}:\n{(m.get('content') or '').strip()}"
        for m in (messages or [])
        if (m.get("content") or "").strip()
    ]
    return "\n\n".join(chunks).strip()
| # ---------------------------- | |
| # Health | |
| # ---------------------------- | |
def health():
    """Liveness payload: Ollama reachability plus the configured model tag."""
    status = {"ok": ollama_healthy(), "model": MODEL}
    return status
| # ---------------------------- | |
| # OpenAI compatibility endpoints | |
| # ---------------------------- | |
def openai_models(authorization: str = Header(default="")):
    """
    OpenAI-compatible model listing.

    Auth here is lenient by design: only a *wrong* bearer token is rejected;
    an absent Authorization header is allowed through.
    """
    if authorization and authorization != f"Bearer {API_KEY}":
        return JSONResponse({"error": {"message": "Invalid API key"}}, status_code=401)
    model_entry = {
        "id": MODEL,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "private-ai",
    }
    return {"object": "list", "data": [model_entry]}
async def openai_chat_completions(request: Request, authorization: str = Header(default="")):
    """
    OpenAI-compatible chat completion (non-streaming).

    Requires `Authorization: Bearer <API_KEY>`. The message list is flattened
    into a single Ollama prompt; backend failures are folded into the reply
    text so the response shape stays valid for OpenAI clients.
    """
    # OpenAI-style auth: strict here (missing header is rejected too).
    if authorization != f"Bearer {API_KEY}":
        return JSONResponse(
            {"error": {"message": "Invalid API key", "type": "auth_error"}},
            status_code=401
        )
    body = await request.json()
    model = body.get("model") or MODEL
    messages = body.get("messages") or []
    # Convert into a single prompt for Ollama generate
    prompt = messages_to_prompt(messages)
    if not prompt:
        prompt = "USER:\nHello"
    async with GEN_SEM:
        # Fix: generate_with_recovery does blocking `requests` I/O (up to
        # 600 s). Calling it directly in this coroutine froze the whole event
        # loop (health checks, UI) while holding GEN_SEM; run it in a worker
        # thread instead.
        text, err = await asyncio.to_thread(generate_with_recovery, prompt, attempts=3)
    # If error, still return a valid OpenAI-shaped response with debug info.
    if err:
        text += f"\n\n---\nBackend error:\n{err}"
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop"
            }
        ],
        # Token accounting is not tracked; zeros keep the schema valid.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    }
| # ---------------------------- | |
| # UI (REAL HTML, not escaped) | |
| # ---------------------------- | |
def ui():
    """
    Return the single-page chat UI as one HTML document.

    The page polls /health every 5 s for the status dot, stores the API key
    in localStorage (prompting once), and streams replies from
    /v1/chat/stream via SSE. The template is an f-string, so every literal
    CSS/JS brace is doubled ({{ }}); keep that escaping when editing.
    """
    return f"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Private AI</title>
<style>
:root {{
--bg:#0b0f17;
--panel:rgba(255,255,255,.06);
--border:rgba(255,255,255,.10);
--text:rgba(255,255,255,.92);
--muted:rgba(255,255,255,.55);
--radius:18px;
--shadow:0 12px 40px rgba(0,0,0,.35);
}}
*{{box-sizing:border-box}}
body{{
margin:0;height:100vh;overflow:hidden;
font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto;
background:
radial-gradient(900px 600px at 20% 0%, rgba(59,130,246,.18), transparent 60%),
radial-gradient(800px 600px at 80% 0%, rgba(168,85,247,.16), transparent 60%),
var(--bg);
color:var(--text);
}}
.app{{display:grid;grid-template-rows:auto 1fr auto;height:100vh}}
header{{
padding:14px 18px;border-bottom:1px solid var(--border);
backdrop-filter:blur(14px);background:rgba(10,14,22,.6);
}}
header .inner{{max-width:980px;margin:0 auto;display:flex;align-items:center;justify-content:space-between}}
header h1{{font-size:14px;margin:0;font-weight:650}}
.status{{font-size:12px;color:var(--muted);display:flex;gap:8px;align-items:center}}
.dot{{width:8px;height:8px;border-radius:50%;background:#555;box-shadow:0 0 0 6px rgba(255,255,255,.06)}}
.dot.online{{background:#22c55e;box-shadow:0 0 0 6px rgba(34,197,94,.12)}}
.dot.busy{{background:#60a5fa;box-shadow:0 0 0 6px rgba(96,165,250,.14)}}
.dot.degraded{{background:#f59e0b;box-shadow:0 0 0 6px rgba(245,158,11,.14)}}
main{{overflow-y:auto;padding:20px 18px}}
.chat{{max-width:980px;margin:0 auto}}
.msg{{display:flex;gap:12px;margin:12px 0}}
.avatar{{width:36px;height:36px;border-radius:10px;background:var(--panel);border:1px solid var(--border);display:grid;place-items:center;font-size:13px}}
.bubble{{padding:12px 14px;border-radius:var(--radius);background:var(--panel);border:1px solid var(--border);box-shadow:var(--shadow);font-size:14px;line-height:1.45;white-space:pre-wrap;overflow-wrap:anywhere}}
.me .bubble{{background:linear-gradient(180deg,rgba(59,130,246,.22),rgba(255,255,255,.06));border-color:rgba(59,130,246,.25)}}
footer{{padding:14px 18px 18px;border-top:1px solid var(--border);backdrop-filter:blur(14px);background:rgba(10,14,22,.6)}}
.composer{{max-width:980px;margin:0 auto;display:grid;grid-template-columns:1fr auto;gap:10px}}
textarea{{resize:none;min-height:44px;max-height:180px;padding:12px 14px;border-radius:16px;border:1px solid var(--border);background:var(--panel);color:var(--text);font-size:14px;outline:none}}
button{{height:44px;padding:0 16px;border-radius:16px;border:1px solid var(--border);background:rgba(255,255,255,.08);color:var(--text);font-weight:650;cursor:pointer}}
button:disabled{{opacity:.6}}
.hint{{max-width:980px;margin:10px auto 0;color:rgba(255,255,255,.55);font-size:12px}}
kbd{{font-family:ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",monospace;font-size:12px;padding:2px 6px;border-radius:8px;background:rgba(255,255,255,.08);border:1px solid rgba(255,255,255,.12)}}
</style>
</head>
<body>
<div class="app">
<header>
<div class="inner">
<h1>Private AI</h1>
<div class="status">
<span class="dot" id="dot"></span>
<span id="status">Connecting…</span>
</div>
</div>
</header>
<main><div class="chat" id="chat"></div></main>
<footer>
<div class="composer">
<textarea id="input" placeholder="Message Private AI…"></textarea>
<button id="send">Send</button>
</div>
<div class="hint" id="hint"></div>
</footer>
</div>
<script>
const chat = document.getElementById("chat");
const input = document.getElementById("input");
const sendBtn = document.getElementById("send");
const dot = document.getElementById("dot");
const statusEl = document.getElementById("status");
const hint = document.getElementById("hint");
let failCount = 0;
function setStatus(mode, txt) {{
dot.classList.remove("online","busy","degraded");
if (mode === "online") dot.classList.add("online");
if (mode === "busy") dot.classList.add("busy");
if (mode === "degraded") dot.classList.add("degraded");
statusEl.textContent = txt;
}}
function add(role, text) {{
const el = document.createElement("div");
el.className = "msg " + role;
el.innerHTML = `
<div class="avatar">${{role === "me" ? "You" : "AI"}}</div>
<div class="bubble"></div>
`;
el.querySelector(".bubble").textContent = text || "";
chat.appendChild(el);
chat.scrollTop = chat.scrollHeight;
return el;
}}
async function healthCheck() {{
try {{
const r = await fetch("/health", {{ cache: "no-store" }});
const d = await r.json();
if (!r.ok || !d.ok) throw new Error("unhealthy");
failCount = 0;
setStatus("online", "Online");
return true;
}} catch {{
failCount++;
setStatus("degraded", failCount >= 5 ? "Recovering…" : "Reconnecting…");
return false;
}}
}}
function getKey() {{
let k = localStorage.getItem("API_KEY");
if (!k) {{
k = prompt("Enter API key:");
if (k) localStorage.setItem("API_KEY", k);
}}
return k || "";
}}
async function sendStream(msg, bubbleEl) {{
const key = getKey();
const r = await fetch("/v1/chat/stream", {{
method: "POST",
headers: {{
"Authorization": "Bearer " + key,
"Content-Type": "application/json"
}},
body: JSON.stringify({{ prompt: msg }})
}});
if (!r.ok || !r.body) throw new Error("No stream");
const reader = r.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {{
const {{ value, done }} = await reader.read();
if (done) break;
buffer += decoder.decode(value, {{ stream: true }});
const parts = buffer.split("\\n\\n");
buffer = parts.pop() || "";
for (const part of parts) {{
const lines = part.split("\\n");
let event = "message";
let dataLine = "";
for (const line of lines) {{
if (line.startsWith("event:")) event = line.slice(6).trim();
if (line.startsWith("data:")) dataLine += line.slice(5).trim();
}}
if (!dataLine) continue;
const payload = JSON.parse(dataLine);
if (event === "error") {{
bubbleEl.textContent += "\\n\\n---\\nError:\\n" + (payload.error || "Unknown error");
setStatus("degraded", "Recovering…");
}} else if (event === "done") {{
setStatus("online", "Online");
}} else {{
bubbleEl.textContent += payload.delta || "";
}}
}}
}}
}}
async function send() {{
const msg = input.value.trim();
if (!msg) return;
input.value = "";
add("me", msg);
setStatus("busy", "Thinking…");
const aiMsg = add("ai", "");
const bubble = aiMsg.querySelector(".bubble");
sendBtn.disabled = true;
try {{
await sendStream(msg, bubble);
}} catch (e) {{
bubble.textContent =
"Temporary error. Try again in a moment.\\n\\nTip: verify your API key (stored in browser localStorage).";
setStatus("degraded", "Reconnecting…");
}} finally {{
sendBtn.disabled = false;
setTimeout(healthCheck, 800);
}}
}}
sendBtn.onclick = send;
input.addEventListener("keydown", (e) => {{
if (e.key === "Enter" && !e.shiftKey) {{
e.preventDefault();
send();
}}
}});
hint.innerHTML = "Press <kbd>Enter</kbd> to send, <kbd>Shift</kbd>+<kbd>Enter</kbd> for newline.";
add("ai", "Hey 👋 I’m ready when you are.");
healthCheck();
setInterval(healthCheck, 5000);
</script>
</body>
</html>
"""
| # ---------------------------- | |
| # Streaming SSE endpoint | |
| # ---------------------------- | |
async def chat_stream(request: Request):
    """
    Server-Sent-Events chat endpoint used by the browser UI.

    Emits `data: {"delta": ...}` chunks as Ollama streams tokens, then
    `event: done`. Auth and empty-prompt failures are reported as SSE events
    (HTTP 200) rather than error status codes, so the UI can render them
    inline in the chat bubble.
    """
    auth = request.headers.get("authorization", "")
    if auth != f"Bearer {API_KEY}":
        # Deny via the event stream itself — the client always gets a stream.
        async def deny():
            yield "event: error\ndata: " + json.dumps({"error": "403: Invalid API key"}) + "\n\n"
            yield "event: done\ndata: {}\n\n"
        return StreamingResponse(deny(), media_type="text/event-stream")
    body = await request.json()
    prompt = (body.get("prompt") or "").strip()
    if not prompt:
        async def empty():
            yield "data: " + json.dumps({"delta": "Send a message and I’ll respond."}) + "\n\n"
            yield "event: done\ndata: {}\n\n"
        return StreamingResponse(empty(), media_type="text/event-stream")
    async def event_gen():
        # Single-slot semaphore: only one generation runs at a time.
        async with GEN_SEM:
            try:
                # Self-heal the backend before generating, same as the
                # non-streaming path.
                if not ollama_healthy():
                    ensure_ollama_running()
                    wait_for_ollama(60)
                # NOTE(review): `requests` is synchronous — this post and the
                # iter_lines loop below block the event loop between chunks.
                # Consider asyncio.to_thread / an async HTTP client; confirm
                # impact before changing.
                r = requests.post(
                    f"{OLLAMA_BASE}/api/generate",
                    json={
                        "model": MODEL,
                        "system": SYSTEM_PROMPT,
                        "prompt": prompt,
                        "stream": True,
                    },
                    stream=True,
                    timeout=600,
                )
                if r.status_code != 200:
                    yield "event: error\ndata: " + json.dumps({"error": r.text[:2000]}) + "\n\n"
                    yield "event: done\ndata: {}\n\n"
                    return
                # Ollama streams one JSON object per line.
                for line in r.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except Exception:
                        # Skip partial/garbled lines rather than aborting.
                        continue
                    delta = obj.get("response", "")
                    if delta:
                        yield "data: " + json.dumps({"delta": delta}) + "\n\n"
                    if obj.get("done"):
                        break
                yield "event: done\ndata: {}\n\n"
            except Exception as e:
                # Any failure mid-stream is surfaced as an SSE error event.
                yield "event: error\ndata: " + json.dumps({"error": str(e)}) + "\n\n"
                yield "event: done\ndata: {}\n\n"
    return StreamingResponse(event_gen(), media_type="text/event-stream")
| # ---------------------------- | |
| # Non-stream fallback | |
| # ---------------------------- | |
async def chat_api(request: Request):
    """
    Non-streaming fallback chat endpoint for the UI.

    Accepts {"prompt": str}; returns {"response": str, "error": str|None}.
    Auth failure deliberately returns HTTP 200 with an error field so the
    browser client can render the message inline.
    """
    auth = request.headers.get("authorization", "")
    if auth != f"Bearer {API_KEY}":
        return JSONResponse({"response": "", "error": "403: Invalid API key"}, status_code=200)
    body = await request.json()
    prompt = (body.get("prompt") or "").strip()
    if not prompt:
        return {"response": "Send a message and I’ll respond.", "error": None}
    async with GEN_SEM:
        # Fix: generate_with_recovery blocks on synchronous `requests` I/O
        # (up to 600 s); run it in a worker thread so the event loop stays
        # responsive while GEN_SEM is held.
        text, err = await asyncio.to_thread(generate_with_recovery, prompt, attempts=3)
    return {"response": text, "error": err}