import os
import time
import threading
import socket
import subprocess
import asyncio
import json
import uuid
import requests
from fastapi import FastAPI, Request, Header
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
app = FastAPI()
# ----------------------------
# Config
# ----------------------------
API_KEY = os.getenv("API_KEY", "821274")
MODEL = os.getenv("MODEL", "llama3.2:1b")
OLLAMA_BASE = os.getenv("OLLAMA_BASE", "http://127.0.0.1:11434")
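# All of the above can be overridden per deployment; a hypothetical local run
# (flags and port are illustrative, not prescribed by this repo):
#   API_KEY=my-secret MODEL=llama3.2:3b uvicorn app:app --host 0.0.0.0 --port 7860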
SYSTEM_PROMPT = """
You are a helpful, friendly AI assistant.
Rules:
- If the user makes casual conversation, respond normally in plain text.
- If the user asks for code/scripts/automation or says “just do it” / “write something”, choose a reasonable task and output a complete useful script.
- Don’t refuse normal conversation.
- Only output code when appropriate.
- When you write code: output code first, then a short explanation.
Always be helpful. Never say you cannot fulfill a request unless it is unsafe.
""".strip()
# Free CPU: serialize requests to avoid overload / timeouts
GEN_SEM = asyncio.Semaphore(1)
# ----------------------------
# Ollama helpers
# ----------------------------
def is_port_open(host="127.0.0.1", port=11434) -> bool:
try:
with socket.create_connection((host, port), timeout=0.5):
return True
except OSError:
return False
def ollama_healthy() -> bool:
try:
r = requests.get(f"{OLLAMA_BASE}/api/tags", timeout=1.5)
return r.status_code == 200
except Exception:
return False
def ensure_ollama_running():
# Only start if not reachable
if not is_port_open("127.0.0.1", 11434):
subprocess.Popen(["ollama", "serve"])
def wait_for_ollama(timeout_s=120) -> bool:
start = time.time()
while time.time() - start < timeout_s:
if ollama_healthy():
return True
time.sleep(1)
return False
def pull_and_warm_model():
"""
Best-effort: pull the model (may be slow on free CPU) and warm it once.
    Safe to fail (the Space still boots).
"""
try:
ensure_ollama_running()
if not wait_for_ollama(120):
print("Ollama not ready yet; skipping model pull.")
return
print(f"Pulling model: {MODEL}")
        # Request a single (non-streaming) reply so the status code actually
        # reflects whether the pull succeeded.
        r = requests.post(f"{OLLAMA_BASE}/api/pull", json={"name": MODEL, "stream": False}, timeout=60 * 30)
if r.status_code != 200:
print("Pull failed:", r.text[:2000])
return
# Warmup: avoids first real user request being extra flaky/slow
print("Warming up…")
requests.post(
f"{OLLAMA_BASE}/api/generate",
json={"model": MODEL, "system": SYSTEM_PROMPT, "prompt": "Say: ready.", "stream": False},
timeout=180,
)
print("Warmup done.")
except Exception as e:
print("Boot task error (non-fatal):", str(e))
threading.Thread(target=pull_and_warm_model, daemon=True).start()
def generate_with_recovery(prompt: str, attempts: int = 3):
last_err = None
for i in range(1, attempts + 1):
try:
if not ollama_healthy():
ensure_ollama_running()
wait_for_ollama(60)
r = requests.post(
f"{OLLAMA_BASE}/api/generate",
json={
"model": MODEL,
"system": SYSTEM_PROMPT,
"prompt": prompt,
"stream": False,
},
timeout=600,
)
r.raise_for_status()
data = r.json()
return data.get("response", ""), None
except Exception as e:
last_err = str(e)
            time.sleep(min(2 ** (i - 1), 4))  # exponential backoff: 1s, 2s, capped at 4s
return (
"⚠️ Backend hiccup while generating. Retrying usually works.\n\n"
"Debug error:\n" + (last_err or "unknown"),
last_err,
)
def messages_to_prompt(messages):
"""
Convert OpenAI-style messages into a single prompt string for Ollama /api/generate.
"""
parts = []
for m in messages or []:
role = (m.get("role") or "user").strip().upper()
content = (m.get("content") or "").strip()
if content:
parts.append(f"{role}:\n{content}")
return "\n\n".join(parts).strip()
# ----------------------------
# Health
# ----------------------------
@app.get("/health")
def health():
return {"ok": ollama_healthy(), "model": MODEL}
# ----------------------------
# OpenAI compatibility endpoints
# ----------------------------
@app.get("/v1/models")
def openai_models(authorization: str = Header(default="")):
# Optional auth check (recommended)
if authorization and authorization != f"Bearer {API_KEY}":
return JSONResponse({"error": {"message": "Invalid API key"}}, status_code=401)
return {
"object": "list",
"data": [
{
"id": MODEL,
"object": "model",
"created": int(time.time()),
"owned_by": "private-ai"
}
]
}
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request, authorization: str = Header(default="")):
# OpenAI-style auth
if authorization != f"Bearer {API_KEY}":
return JSONResponse(
{"error": {"message": "Invalid API key", "type": "auth_error"}},
status_code=401
)
body = await request.json()
    model = body.get("model") or MODEL  # echoed in the response; generation always uses MODEL
messages = body.get("messages") or []
# Convert into a single prompt for Ollama generate
prompt = messages_to_prompt(messages)
if not prompt:
prompt = "USER:\nHello"
    async with GEN_SEM:
        # generate_with_recovery blocks on requests; run it in a worker thread
        # so the event loop (and /health) stays responsive during generation.
        text, err = await asyncio.to_thread(generate_with_recovery, prompt, attempts=3)
    # On error, still return a valid OpenAI-shaped response
if err:
text += f"\n\n---\nBackend error:\n{err}"
return {
"id": f"chatcmpl-{uuid.uuid4().hex}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text},
"finish_reason": "stop"
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
}
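# Minimal client sketch for the endpoint above, using the official `openai`
# package (assumptions: openai>=1.0 installed and the app reachable at
# http://localhost:7860; any model name works since generation always uses MODEL):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:7860/v1", api_key="821274")
#   resp = client.chat.completions.create(
#       model="llama3.2:1b",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)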
# ----------------------------
# UI (served as real HTML via HTMLResponse, not JSON-escaped)
# ----------------------------
@app.get("/", response_class=HTMLResponse)
def ui():
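    # Note: the template below is an f-string, so every literal { and } in the
    # CSS/JS is doubled as {{ and }}; the browser receives single braces.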
return f"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Private AI</title>
<style>
:root {{
--bg:#0b0f17;
--panel:rgba(255,255,255,.06);
--border:rgba(255,255,255,.10);
--text:rgba(255,255,255,.92);
--muted:rgba(255,255,255,.55);
--radius:18px;
--shadow:0 12px 40px rgba(0,0,0,.35);
}}
*{{box-sizing:border-box}}
body{{
margin:0;height:100vh;overflow:hidden;
font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto;
background:
radial-gradient(900px 600px at 20% 0%, rgba(59,130,246,.18), transparent 60%),
radial-gradient(800px 600px at 80% 0%, rgba(168,85,247,.16), transparent 60%),
var(--bg);
color:var(--text);
}}
.app{{display:grid;grid-template-rows:auto 1fr auto;height:100vh}}
header{{
padding:14px 18px;border-bottom:1px solid var(--border);
backdrop-filter:blur(14px);background:rgba(10,14,22,.6);
}}
header .inner{{max-width:980px;margin:0 auto;display:flex;align-items:center;justify-content:space-between}}
header h1{{font-size:14px;margin:0;font-weight:650}}
.status{{font-size:12px;color:var(--muted);display:flex;gap:8px;align-items:center}}
.dot{{width:8px;height:8px;border-radius:50%;background:#555;box-shadow:0 0 0 6px rgba(255,255,255,.06)}}
.dot.online{{background:#22c55e;box-shadow:0 0 0 6px rgba(34,197,94,.12)}}
.dot.busy{{background:#60a5fa;box-shadow:0 0 0 6px rgba(96,165,250,.14)}}
.dot.degraded{{background:#f59e0b;box-shadow:0 0 0 6px rgba(245,158,11,.14)}}
main{{overflow-y:auto;padding:20px 18px}}
.chat{{max-width:980px;margin:0 auto}}
.msg{{display:flex;gap:12px;margin:12px 0}}
.avatar{{width:36px;height:36px;border-radius:10px;background:var(--panel);border:1px solid var(--border);display:grid;place-items:center;font-size:13px}}
.bubble{{padding:12px 14px;border-radius:var(--radius);background:var(--panel);border:1px solid var(--border);box-shadow:var(--shadow);font-size:14px;line-height:1.45;white-space:pre-wrap;overflow-wrap:anywhere}}
.me .bubble{{background:linear-gradient(180deg,rgba(59,130,246,.22),rgba(255,255,255,.06));border-color:rgba(59,130,246,.25)}}
footer{{padding:14px 18px 18px;border-top:1px solid var(--border);backdrop-filter:blur(14px);background:rgba(10,14,22,.6)}}
.composer{{max-width:980px;margin:0 auto;display:grid;grid-template-columns:1fr auto;gap:10px}}
textarea{{resize:none;min-height:44px;max-height:180px;padding:12px 14px;border-radius:16px;border:1px solid var(--border);background:var(--panel);color:var(--text);font-size:14px;outline:none}}
button{{height:44px;padding:0 16px;border-radius:16px;border:1px solid var(--border);background:rgba(255,255,255,.08);color:var(--text);font-weight:650;cursor:pointer}}
button:disabled{{opacity:.6}}
.hint{{max-width:980px;margin:10px auto 0;color:rgba(255,255,255,.55);font-size:12px}}
kbd{{font-family:ui-monospace,SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",monospace;font-size:12px;padding:2px 6px;border-radius:8px;background:rgba(255,255,255,.08);border:1px solid rgba(255,255,255,.12)}}
</style>
</head>
<body>
<div class="app">
<header>
<div class="inner">
<h1>Private AI</h1>
<div class="status">
<span class="dot" id="dot"></span>
<span id="status">Connecting…</span>
</div>
</div>
</header>
<main><div class="chat" id="chat"></div></main>
<footer>
<div class="composer">
<textarea id="input" placeholder="Message Private AI…"></textarea>
<button id="send">Send</button>
</div>
<div class="hint" id="hint"></div>
</footer>
</div>
<script>
const chat = document.getElementById("chat");
const input = document.getElementById("input");
const sendBtn = document.getElementById("send");
const dot = document.getElementById("dot");
const statusEl = document.getElementById("status");
const hint = document.getElementById("hint");
let failCount = 0;
function setStatus(mode, txt) {{
dot.classList.remove("online","busy","degraded");
if (mode === "online") dot.classList.add("online");
if (mode === "busy") dot.classList.add("busy");
if (mode === "degraded") dot.classList.add("degraded");
statusEl.textContent = txt;
}}
function add(role, text) {{
const el = document.createElement("div");
el.className = "msg " + role;
el.innerHTML = `
<div class="avatar">${{role === "me" ? "You" : "AI"}}</div>
<div class="bubble"></div>
`;
el.querySelector(".bubble").textContent = text || "";
chat.appendChild(el);
el.scrollIntoView({{ block: "end" }}); // main is the scroll container, not .chat, so scroll the new message into view
return el;
}}
async function healthCheck() {{
try {{
const r = await fetch("/health", {{ cache: "no-store" }});
const d = await r.json();
if (!r.ok || !d.ok) throw new Error("unhealthy");
failCount = 0;
setStatus("online", "Online");
return true;
}} catch {{
failCount++;
setStatus("degraded", failCount >= 5 ? "Recovering…" : "Reconnecting…");
return false;
}}
}}
function getKey() {{
let k = localStorage.getItem("API_KEY");
if (!k) {{
k = prompt("Enter API key:");
if (k) localStorage.setItem("API_KEY", k);
}}
return k || "";
}}
async function sendStream(msg, bubbleEl) {{
const key = getKey();
const r = await fetch("/v1/chat/stream", {{
method: "POST",
headers: {{
"Authorization": "Bearer " + key,
"Content-Type": "application/json"
}},
body: JSON.stringify({{ prompt: msg }})
}});
if (!r.ok || !r.body) throw new Error("No stream");
const reader = r.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
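// SSE frames are separated by a blank line; split on "\\n\\n" and keep any
// trailing partial frame in buffer until the next network chunk arrives.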
while (true) {{
const {{ value, done }} = await reader.read();
if (done) break;
buffer += decoder.decode(value, {{ stream: true }});
const parts = buffer.split("\\n\\n");
buffer = parts.pop() || "";
for (const part of parts) {{
const lines = part.split("\\n");
let event = "message";
let dataLine = "";
for (const line of lines) {{
if (line.startsWith("event:")) event = line.slice(6).trim();
if (line.startsWith("data:")) dataLine += line.slice(5).trim();
}}
if (!dataLine) continue;
const payload = JSON.parse(dataLine);
if (event === "error") {{
bubbleEl.textContent += "\\n\\n---\\nError:\\n" + (payload.error || "Unknown error");
setStatus("degraded", "Recovering…");
}} else if (event === "done") {{
setStatus("online", "Online");
}} else {{
bubbleEl.textContent += payload.delta || "";
}}
}}
}}
}}
async function send() {{
const msg = input.value.trim();
if (!msg) return;
input.value = "";
add("me", msg);
setStatus("busy", "Thinking…");
const aiMsg = add("ai", "");
const bubble = aiMsg.querySelector(".bubble");
sendBtn.disabled = true;
try {{
await sendStream(msg, bubble);
}} catch (e) {{
bubble.textContent =
"Temporary error. Try again in a moment.\\n\\nTip: verify your API key (stored in browser localStorage).";
setStatus("degraded", "Reconnecting…");
}} finally {{
sendBtn.disabled = false;
setTimeout(healthCheck, 800);
}}
}}
sendBtn.onclick = send;
input.addEventListener("keydown", (e) => {{
if (e.key === "Enter" && !e.shiftKey) {{
e.preventDefault();
send();
}}
}});
hint.innerHTML = "Press <kbd>Enter</kbd> to send, <kbd>Shift</kbd>+<kbd>Enter</kbd> for newline.";
add("ai", "Hey 👋 I’m ready when you are.");
healthCheck();
setInterval(healthCheck, 5000);
</script>
</body>
</html>
"""
# ----------------------------
# Streaming SSE endpoint
# ----------------------------
@app.post("/v1/chat/stream")
async def chat_stream(request: Request):
auth = request.headers.get("authorization", "")
if auth != f"Bearer {API_KEY}":
async def deny():
yield "event: error\ndata: " + json.dumps({"error": "403: Invalid API key"}) + "\n\n"
yield "event: done\ndata: {}\n\n"
return StreamingResponse(deny(), media_type="text/event-stream")
body = await request.json()
prompt = (body.get("prompt") or "").strip()
if not prompt:
async def empty():
yield "data: " + json.dumps({"delta": "Send a message and I’ll respond."}) + "\n\n"
yield "event: done\ndata: {}\n\n"
return StreamingResponse(empty(), media_type="text/event-stream")
    async def event_gen():
        async with GEN_SEM:
            try:
                # The requests/healthcheck calls below are blocking, so run
                # them in worker threads to keep the event loop responsive
                # while the response streams.
                if not await asyncio.to_thread(ollama_healthy):
                    ensure_ollama_running()
                    await asyncio.to_thread(wait_for_ollama, 60)
                r = await asyncio.to_thread(
                    requests.post,
                    f"{OLLAMA_BASE}/api/generate",
                    json={
                        "model": MODEL,
                        "system": SYSTEM_PROMPT,
                        "prompt": prompt,
                        "stream": True,
                    },
                    stream=True,
                    timeout=600,
                )
if r.status_code != 200:
yield "event: error\ndata: " + json.dumps({"error": r.text[:2000]}) + "\n\n"
yield "event: done\ndata: {}\n\n"
return
                line_iter = r.iter_lines(decode_unicode=True)
                while True:
                    # next() blocks on the socket, so poll it from a worker thread
                    line = await asyncio.to_thread(next, line_iter, None)
                    if line is None:  # stream closed
                        break
                    if not line:  # keep-alive blank line
                        continue
try:
obj = json.loads(line)
except Exception:
continue
delta = obj.get("response", "")
if delta:
yield "data: " + json.dumps({"delta": delta}) + "\n\n"
if obj.get("done"):
break
yield "event: done\ndata: {}\n\n"
except Exception as e:
yield "event: error\ndata: " + json.dumps({"error": str(e)}) + "\n\n"
yield "event: done\ndata: {}\n\n"
return StreamingResponse(event_gen(), media_type="text/event-stream")
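# Manual test of the stream (hypothetical host/port; -N disables curl buffering;
# put your key in $API_KEY first):
#   curl -N -X POST http://localhost:7860/v1/chat/stream \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello"}'
# Emits data: {"delta": "..."} events, then event: done.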
# ----------------------------
# Non-stream fallback
# ----------------------------
@app.post("/v1/chat")
async def chat_api(request: Request):
auth = request.headers.get("authorization", "")
if auth != f"Bearer {API_KEY}":
return JSONResponse({"response": "", "error": "403: Invalid API key"}, status_code=200)
body = await request.json()
prompt = (body.get("prompt") or "").strip()
if not prompt:
return {"response": "Send a message and I’ll respond.", "error": None}
    async with GEN_SEM:
        # Blocking call; run it in a worker thread (same reasoning as above).
        text, err = await asyncio.to_thread(generate_with_recovery, prompt, attempts=3)
return {"response": text, "error": err}