import asyncio
import base64
import os
import re
import time
from datetime import datetime
from typing import Any, List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# LangChain / Google GenAI imports
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()

app = FastAPI(title="Socratic Sentiment Chatbot API")

# Enable CORS for frontend integration.
# Origins come from the comma-separated CORS_ALLOW_ORIGINS variable and default
# to "*". Credentials are only allowed together with an explicit origin list:
# a wildcard origin combined with credentials would let any site make
# credentialed requests against this API.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_cors_allow_all = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if _cors_allow_all else _cors_origins,
    allow_credentials=not _cors_allow_all,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic Schemas
class ChatMessage(BaseModel):
    """One earlier turn of the conversation."""

    role: str  # "user" or "assistant"
    content: str


class ChatRequest(BaseModel):
    """Request body of POST /api/chat."""

    message: str
    # Optional per-request key; the endpoint falls back to GEMINI_API_KEY.
    gemini_api_key: Optional[str] = None
    history: Optional[List[ChatMessage]] = None


class ChatResponse(BaseModel):
    """Response body of POST /api/chat."""

    sentiment: str
    response: str
    latency: float
    prompt_context: str
    tokens: int
    cost: float


# Token estimation helper (using standard ~4 characters per token multiplier for English)
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text* (one token per 4 characters).

    The result is never below 1, including for empty or missing text.
    """
    if not text:
        return 1
    return max(1, len(text) // 4)


# Cost calculation helper
def calculate_cost(
    input_tokens: int,
    output_tokens: int,
    input_price_per_million: float = 0.075,
    output_price_per_million: float = 0.30,
) -> float:
    """Return the estimated USD cost of one model call.

    Args:
        input_tokens: Number of prompt tokens.
        output_tokens: Number of generated tokens.
        input_price_per_million: USD per 1M input tokens. The default is the
            Gemini 3.1 Flash Lite price used by this service ($0.075).
        output_price_per_million: USD per 1M output tokens (default $0.30).
    """
    input_cost = (input_tokens / 1_000_000.0) * input_price_per_million
    output_cost = (output_tokens / 1_000_000.0) * output_price_per_million
    return input_cost + output_cost


# Helper to extract text from LangChain message content
def get_text_content(content: Any) -> str:
    """Flatten a LangChain message ``content`` value into plain text.

    A string is returned unchanged. For a list, the ``text`` of every
    ``{"type": "text"}`` dict and every bare string are concatenated and all
    other parts are ignored. ``None`` yields an empty string; any other value
    is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        text_parts = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                # "text" may be present but None; never hand None to join().
                text_parts.append(part.get("text") or "")
            elif isinstance(part, str):
                text_parts.append(part)
        return "".join(text_parts)
    return str(content)
# Regex PII scrubbing helper
# The patterns are compiled once at import time and applied in this order.
_EMAIL_RE = re.compile(r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)+')
# The leading alternation lets a match start on a "+" country prefix or on a
# balanced "(ddd)" area code; a plain \b would leave that character behind.
_PHONE_RE = re.compile(
    r'(?:\b|(?<!\w)(?=\+\d|\(\d{3}\)))'
    r'(?:\+?\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b'
)
_IP_RE = re.compile(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b')
_SSN_RE = re.compile(r'\b\d{3}-\d{2}-\d{4}\b')


def scrub_pii(text: str) -> str:
    """Replace e-mail addresses, phone numbers, IPv4 addresses and SSNs.

    Matches are replaced with ``[EMAIL]``, ``[PHONE]``, ``[IP_ADDRESS]`` and
    ``[SSN]``. Empty or ``None`` input is returned unchanged.
    """
    if not text:
        return text
    text = _EMAIL_RE.sub('[EMAIL]', text)
    text = _PHONE_RE.sub('[PHONE]', text)
    text = _IP_RE.sub('[IP_ADDRESS]', text)
    text = _SSN_RE.sub('[SSN]', text)
    return text


# Markdown Logging helper
MD_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sentiment_log.md")


def log_to_md(question: str, sentiment: str, latency: float, cost: float, tokens_in: int, tokens_out: int, reply: str):
    """Append one chat exchange to the Markdown log file ``MD_FILE``.

    A title and description are written first when the file is missing or
    empty. Each entry has a timestamped heading with the query, an HTML table
    of metrics and the tutor reply. Logging is best-effort: any failure is
    printed and swallowed so it cannot fail the request.
    """
    import html

    needs_header = not os.path.exists(MD_FILE) or os.path.getsize(MD_FILE) == 0
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # A Markdown heading must stay on one line, so collapse whitespace runs.
    heading_question = " ".join(str(question).split())

    chunks = []
    if needs_header:
        chunks.append("# Socratic Chatbot Sentiment & Response Log\n\n")
        chunks.append("This file tracks detected user sentiments, response latencies, costs, and Socratic replies.\n\n")
    chunks.append(f"## [{timestamp}] Query: \"{heading_question}\"\n\n")
    chunks.append("<table>\n")
    chunks.append("  <thead>\n")
    chunks.append("    <tr><th>Metric</th><th>Value</th></tr>\n")
    chunks.append("  </thead>\n")
    chunks.append("  <tbody>\n")
    # The sentiment label comes from the model, so escape it before embedding in HTML.
    chunks.append(f"    <tr><td>Detected Sentiment</td><td>{html.escape(str(sentiment))}</td></tr>\n")
    chunks.append(f"    <tr><td>Latency</td><td>{round(latency, 3)}s</td></tr>\n")
    chunks.append(f"    <tr><td>Estimated Cost</td><td>${cost:.7f}</td></tr>\n")
    chunks.append(f"    <tr><td>Tokens</td><td>{tokens_in + tokens_out} ({tokens_in} in / {tokens_out} out)</td></tr>\n")
    chunks.append("  </tbody>\n")
    chunks.append("</table>\n\n")
    chunks.append(f"### Socratic Tutor Reply\n{reply}\n\n")
    chunks.append("---\n\n")

    try:
        with open(MD_FILE, mode="a", encoding="utf-8") as f:
            # One write keeps the entry contiguous in the file.
            f.write("".join(chunks))
    except Exception as e:
        print(f"Error writing to MD log: {e}")


# Reply used whenever the model output cannot be turned into a usable answer.
_FALLBACK_REPLY = "Let's take a look at this concept step by step. What do you think is the first part?"


def _parse_tutor_json(raw_response: str):
    """Return ``(sentiment, reply)`` parsed from the model's JSON output.

    Falls back to ``("neutral", _FALLBACK_REPLY)`` when the text is not a JSON
    object or has no usable reply, and to ``"neutral"`` alone when only the
    sentiment is missing or not a string.
    """
    import json

    cleaned_json = raw_response.strip()
    if cleaned_json.startswith("```"):
        # Tolerate a Markdown code fence wrapped around the JSON payload.
        cleaned_json = re.sub(r'^```(?:json)?\s*|\s*```$', '', cleaned_json)
    try:
        parsed = json.loads(cleaned_json)
        if not isinstance(parsed, dict):
            raise ValueError("expected a JSON object")
    except ValueError as e:
        print(f"Failed to parse LLM JSON response: {e}. Raw response: {raw_response}")
        return "neutral", _FALLBACK_REPLY

    state_val = parsed.get("s")
    if not isinstance(state_val, str) or not state_val:
        state_val = "neutral"
    reply_val = parsed.get("r")
    if not isinstance(reply_val, str) or not reply_val.strip():
        # An empty or non-string reply would fail response validation downstream.
        reply_val = _FALLBACK_REPLY
    return state_val, reply_val


# Option B response helper doing both sentiment detection and response generation in one pass
def run_flow_b(message: str, api_key: str, history: Optional[List[ChatMessage]] = None):
    """Detect sentiment and generate the Socratic reply in a single LLM call.

    Args:
        message: The current user message.
        api_key: Gemini API key passed to the LangChain client.
        history: Earlier turns. Only the last 4 are sent, each truncated to
            60 characters, to keep the prompt small.

    Returns:
        ``(sentiment, reply, prompt_context, est_in, est_out)`` where the two
        token counts are estimates from ``estimate_tokens``.
    """
    # Enforce structural JSON natively.
    llm = ChatGoogleGenerativeAI(
        model="gemini-3.1-flash-lite",
        google_api_key=api_key,
        temperature=0.5,
        max_tokens=450,
        generation_config={"response_mime_type": "application/json"}
    )

    num_user_turns = sum(1 for m in history if m.role == "user") if history else 0

    custom_system = (
        f"Socratic tutor: guide with clear, substantial hints. "
        f"Current conversation: {num_user_turns} user turns so far. "
        "CRITICAL RULES:\n"
        "- If the user has answered correctly, solved the problem, or if hints exceed topic complexity "
        "(e.g., 2 user turns for simple topics, 4 user turns for complex topics): DO NOT ask any more math/science/concept questions. "
        "Immediately confirm their success (or provide the direct solution) and ask exactly: 'Do you want to learn something else?'\n"
        "- If the user is close, highly frustrated, or asks directly: give the solution and ask exactly: 'Do you want to learn something else?'\n"
        "- Otherwise, guide with a hint and ask exactly 1 Socratic question."
    )

    tone_instruction = (
        "JSON: {\"s\":\"sentiment\",\"r\":\"reply\"}\n"
        "s values: confusion|frustration|confused_but_engaged|confused_and_frustrated|starting_to_get_bored|confident_and_engaged|neutral\n"
        "Rules:\n"
        "- Sympathize with s implicitly (tone/style); never name or mention the sentiment/emotion itself.\n"
        "- NEVER use 'if you' (use direct phrasing: 'think about', 'imagine').\n"
        "- Ask 1 question max.\n"
        "Responses (unless wrapping up / giving final answer):\n"
        "- frustration: acknowledge sentiment but not explicitly + simplify + question.\n"
        "- starting_to_get_bored: acknowledge the specific source of boredom (e.g. repetition, dry theory) + puzzle/analogy + question.\n"
        "- other: hint + question."
    )

    messages = [SystemMessage(content=f"{custom_system}\n\n{tone_instruction}")]

    # Minimize tokens: slice history to last 4 messages and truncate to 60 characters
    history_texts = []
    for msg in (history or [])[-4:]:
        content = msg.content
        if len(content) > 60:
            content = content[:60] + "..."
        history_texts.append(content)
        if msg.role == "user":
            messages.append(HumanMessage(content=content))
        else:
            messages.append(AIMessage(content=content))

    messages.append(HumanMessage(content=message))

    res = llm.invoke(messages)
    raw_response = get_text_content(res.content)
    state_val, reply_val = _parse_tutor_json(raw_response)

    prompt_context = f"{custom_system}\n{tone_instruction}\nUser Query: {message}"
    # The history turns are part of the prompt too, so count them as input.
    est_in = estimate_tokens(prompt_context + "".join(history_texts))
    est_out = estimate_tokens(raw_response)

    return state_val, reply_val, prompt_context, est_in, est_out


# API Routes
@app.get("/api/status")
def get_status():
    """Report readiness and whether GEMINI_API_KEY is set in the environment."""
    return {
        "status": "ready",
        "gemini_api_key_configured": bool(os.environ.get("GEMINI_API_KEY"))
    }


@app.post("/api/chat", response_model=ChatResponse)
def chat_endpoint(request: ChatRequest):
    """Answer one chat message with a detected sentiment and a Socratic reply.

    The message and history are PII-scrubbed before they reach the model or
    the Markdown log.

    Raises:
        HTTPException: 400 when no API key is available, 500 when the model
            call or post-processing fails.
    """
    # Retrieve Gemini API Key
    api_key = request.gemini_api_key or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=400,
            detail="Gemini API Key is missing. Please provide it in the Settings panel or environment."
        )

    start_time = time.perf_counter()

    # Scrub PII
    scrubbed_message = scrub_pii(request.message)
    scrubbed_history = None
    if request.history:
        scrubbed_history = [
            ChatMessage(role=turn.role, content=scrub_pii(turn.content))
            for turn in request.history
        ]

    try:
        sentiment, reply, prompt_context, est_in, est_out = run_flow_b(
            message=scrubbed_message,
            api_key=api_key,
            history=scrubbed_history
        )

        latency = time.perf_counter() - start_time
        cost = calculate_cost(est_in, est_out)
        tokens = est_in + est_out

        # Log to Markdown. Use the scrubbed text so PII never lands on disk.
        log_to_md(
            question=scrubbed_message,
            sentiment=sentiment,
            latency=latency,
            cost=cost,
            tokens_in=est_in,
            tokens_out=est_out,
            reply=reply
        )

        return ChatResponse(
            sentiment=sentiment,
            response=reply,
            latency=round(latency, 3),
            prompt_context=prompt_context,
            tokens=tokens,
            cost=cost
        )
    except Exception as e:
        print(f"Chat endpoint error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred: {str(e)}"
        ) from e


@app.websocket("/api/live-ws")
async def websocket_live_endpoint(websocket: WebSocket):
    """Proxy a browser WebSocket to a Gemini Live session.

    Client frames are JSON: ``{"type": "audio", "data": <base64 PCM>}`` or
    ``{"type": "text", "data": <str>}``. Frames sent back are ``audio``
    (base64), ``text`` and ``turn_complete``. The socket is closed with code
    4000 when no API key is available and 4001 when the google-genai SDK is
    not installed.
    """
    await websocket.accept()

    # Retrieve Gemini API Key from query params or environment
    api_key = websocket.query_params.get("api_key") or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        await websocket.close(code=4000, reason="GEMINI_API_KEY is missing.")
        return

    log_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ws_debug.log")

    def ws_log(msg: str):
        """Append a timestamped line to ws_debug.log (best-effort)."""
        try:
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"[{datetime.now().strftime('%H:%M:%S.%f')[:-3]}] {msg}\n")
        except (OSError, ValueError):
            # Debug logging must never break the live session.
            pass

    ws_log("Client WebSocket connected. Initializing Live connection...")

    try:
        from google import genai
        from google.genai import types
    except ImportError as e:
        ws_log(f"ImportError: google-genai not installed. {e}")
        await websocket.close(code=4001, reason="google-genai SDK not installed.")
        return

    client = genai.Client(api_key=api_key)

    # Configure Socratic Tutor instruction for Gemini Live API
    config = types.LiveConnectConfig(
        response_modalities=["AUDIO"],  # Audio modality
        system_instruction=types.Content(
            parts=[types.Part.from_text(
                text="Socratic tutor: guide with clear, substantial hints (no tiny nudges) to solve faster. "
                     "Confidence is not mastery—continue Socratic hints unless they are close. "
                     "Only when close to the solution, give the final answer & ask: 'Do you want to learn something else?' "
                     "NEVER use the phrase 'if you' anywhere in your response (e.g. do not say 'if you think', 'if you were', etc.). Instead, frame instructions or scenarios directly (e.g., say 'think about', 'imagine', 'when looking at', or 'sometimes'). "
                     "Only ask one question at a time to avoid overwhelming the user. "
                     "Keep replies extremely concise (maximum 3 brief sentences) and conversational."
            )]
        )
    )

    try:
        # Establish async WebSocket connection to Gemini Live using the Gemini 3.1 Flash Live model
        async with client.aio.live.connect(model="gemini-3.1-flash-live-preview", config=config) as session:
            ws_log("Successfully connected to Gemini Live session.")

            async def receive_from_client():
                """Forward audio and text frames from the browser to Gemini."""
                try:
                    audio_chunk_count = 0
                    while True:
                        # Receive JSON from browser client
                        message = await websocket.receive_json()
                        if not isinstance(message, dict):
                            ws_log("Ignoring client frame that is not a JSON object.")
                            continue
                        msg_type = message.get("type")

                        if msg_type == "audio":
                            audio_chunk_count += 1
                            if audio_chunk_count % 50 == 1:
                                ws_log(f"Received audio chunk {audio_chunk_count} from client.")
                            # Decode base64 PCM audio chunk sent from frontend
                            try:
                                audio_bytes = base64.b64decode(message["data"])
                            except (KeyError, TypeError, ValueError) as e:
                                # One bad chunk should not end the whole session.
                                ws_log(f"Skipping malformed audio chunk: {e}")
                                continue
                            # Stream real-time audio (using 'audio' instead of deprecated 'media')
                            await session.send_realtime_input(
                                audio=types.Blob(data=audio_bytes, mime_type="audio/pcm;rate=16000")
                            )
                        elif msg_type == "text":
                            # Scrub like /api/chat does, for the debug log and the model alike.
                            text = scrub_pii(str(message.get("data") or ""))
                            if not text:
                                continue
                            ws_log(f"Received text query from client: {text}")
                            # Send real-time text input
                            await session.send_realtime_input(text=text)
                except WebSocketDisconnect:
                    ws_log("Client WebSocket disconnected (WebSocketDisconnect).")
                except Exception as e:
                    ws_log(f"[WebSocket Proxy Client -> Gemini] Error: {e}")
                finally:
                    ws_log("receive_from_client loop exited.")

            async def send_to_client():
                """Forward audio, text and turn-complete events from Gemini to the browser."""
                try:
                    while True:
                        async for response in session.receive():
                            server_content = response.server_content
                            if server_content is None:
                                continue

                            model_turn = server_content.model_turn
                            if model_turn is not None:
                                for part in model_turn.parts or []:
                                    if part.inline_data is not None:
                                        # Stream PCM audio output back to client as Base64
                                        audio_b64 = base64.b64encode(part.inline_data.data).decode('utf-8')
                                        await websocket.send_json({
                                            "type": "audio",
                                            "data": audio_b64
                                        })
                                    elif part.text is not None:
                                        # Stream text transcription back to client
                                        ws_log(f"Streaming text chunk from Gemini: {part.text}")
                                        await websocket.send_json({
                                            "type": "text",
                                            "data": part.text
                                        })

                            # Handle turn completion (model finished speaking)
                            if server_content.turn_complete:
                                ws_log("Gemini sent turn_complete.")
                                await websocket.send_json({"type": "turn_complete"})
                        # Avoid tight loop if iterator finishes instantly
                        await asyncio.sleep(0.1)
                except Exception as e:
                    ws_log(f"[WebSocket Proxy Gemini -> Client] Error: {e}")
                finally:
                    ws_log("send_to_client loop exited.")

            # Run both tasks concurrently and terminate when the first one finishes
            tasks = [
                asyncio.create_task(receive_from_client()),
                asyncio.create_task(send_to_client()),
            ]
            try:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # cancel() is a no-op on the finished task. Wait for the other
                # one to unwind before the Gemini session is closed under it.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)

    except Exception as e:
        ws_log(f"WebSocket Gemini Live connection failed: {e}")
    finally:
        ws_log("Closing WebSocket and cleaning up.")
        try:
            await websocket.close()
        except Exception as e:
            # Usually the socket is already closed; nothing more to do.
            ws_log(f"WebSocket close skipped: {e}")


# Mount frontend static files in production if dist folder is built
frontend_dist_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "frontend", "dist")
if os.path.exists(frontend_dist_path):
    app.mount("/", StaticFiles(directory=frontend_dist_path, html=True), name="frontend")