from __future__ import annotations import asyncio import json import os import uuid from typing import Any, Dict, Optional import numpy as np from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import JSONResponse from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model from .gemini_text import ( gemini_chat_turn, get_session, deliver_function_result, ) app = FastAPI() # ---------------------------- # FastRTC Voice Chat (VAD + STT + TTS) # ---------------------------- # These are CPU-friendly, but still heavy on Spaces. Keep them global. STT_MODEL_NAME = os.getenv("FASTRTC_STT_MODEL", "moonshine/tiny") TTS_MODEL_NAME = os.getenv("FASTRTC_TTS_MODEL", "kokoro") stt = get_stt_model(model=STT_MODEL_NAME) tts = get_tts_model(model=TTS_MODEL_NAME) def _voice_reply_fn(audio: tuple[int, np.ndarray]): """ Called when the user pauses (VAD). Returns streamed audio frames (TTS). """ # audio is (sample_rate, int16 mono ndarray) # FastRTC STT expects "audio" in the same tuple form per docs examples. user_text = stt.stt(audio).strip() if not user_text: return # For voice sessions we create a synthetic session_id (not Scratch ws session) # because FastRTC’s ReplyOnPause fn signature doesn’t expose the RTC session id. # This keeps a stable conversation state per-process, but not per-user. # # If you need per-user memory for voice, we can switch to a stateful StreamHandler later. voice_session_id = "voice-global" async def run(): # No tool bounce for voice by default (still supported via same session registry if you want) async def noop_emit(_evt: dict): return text = await gemini_chat_turn( session_id=voice_session_id, user_text=user_text, emit_event=noop_emit, model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"), ) return text text = asyncio.get_event_loop().run_until_complete(run()) # Stream TTS back for chunk in tts.stream_tts_sync(text): # chunk is already an audio frame compatible with FastRTC yield chunk voice_stream = Stream( modality="audio", mode="send-receive", handler=ReplyOnPause(_voice_reply_fn), ) # Mount FastRTC endpoints (WebRTC + WebSocket) under /rtc voice_stream.mount(app, path="/rtc") # ---------------------------- # Scratch-friendly WebSocket API (text + function calling) # ---------------------------- @app.get("/") async def root(): return JSONResponse( { "ok": True, "service": "salexai-api", "ws": "/ws", "fastrtc": "/rtc", "notes": [ "Use /ws for Scratch JSON chat + function calling.", "Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).", ], } ) @app.websocket("/ws") async def ws_endpoint(ws: WebSocket): await ws.accept() session_id: Optional[str] = None async def emit(evt: dict): await ws.send_text(json.dumps(evt)) try: while True: raw = await ws.receive_text() msg = json.loads(raw) if raw else {} mtype = msg.get("type") if mtype == "connect": session_id = msg.get("session_id") or str(uuid.uuid4()) get_session(session_id) # ensure exists await emit({"type": "ready", "session_id": session_id}) continue if not session_id: await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."}) continue # -------- function registry -------- if mtype == "add_function": name = str(msg.get("name") or "").strip() schema = msg.get("schema") or {} if not name: await emit({"type": "error", "message": "add_function missing name"}) continue s = get_session(session_id) s.functions[name] = schema await emit({"type": "function_added", "name": name}) continue if mtype == "remove_function": name = str(msg.get("name") or "").strip() s = get_session(session_id) if name in s.functions: s.functions.pop(name, None) await emit({"type": "function_removed", "name": name}) else: await emit({"type": "warning", "message": f"Function not found: {name}"}) continue if mtype == "list_functions": s = get_session(session_id) await emit({"type": "functions", "items": list(s.functions.keys())}) continue # Client returns tool results if mtype == "function_result": call_id = msg.get("call_id") result = msg.get("result") if not call_id: await emit({"type": "error", "message": "function_result missing call_id"}) continue ok = deliver_function_result(session_id, call_id, result) if not ok: await emit({"type": "warning", "message": f"No pending call_id: {call_id}"}) else: await emit({"type": "function_result_ack", "call_id": call_id}) continue # -------- chat -------- if mtype == "send": text = str(msg.get("text") or "") if not text.strip(): await emit({"type": "error", "message": "Empty text"}) continue try: assistant_text = await gemini_chat_turn( session_id=session_id, user_text=text, emit_event=emit, # this is where tool calls get emitted model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"), ) await emit({"type": "assistant", "text": assistant_text}) except Exception as e: await emit({"type": "error", "message": f"Gemini error: {e}"}) continue await emit({"type": "error", "message": f"Unknown type: {mtype}"}) except WebSocketDisconnect: return except Exception as e: try: await emit({"type": "error", "message": f"WS crashed: {e}"}) except Exception: pass