Spaces:
Sleeping
Sleeping
| import sqlite3 | |
| import json | |
| from contextlib import contextmanager | |
| from typing import List, Dict, Any, Tuple | |
| from config import DB_PATH | |
@contextmanager
def get_db_connection(db_path: str = None):
    """Yield a SQLite connection whose rows come back as ``sqlite3.Row``.

    Usage::

        with get_db_connection() as conn:
            ...

    The connection is always closed on exit, even if the body raises.
    Nothing is committed automatically: callers must call ``conn.commit()``
    themselves, because ``sqlite3.Connection.close()`` does not commit and
    uncommitted changes are lost.

    Args:
        db_path: Optional database path override (e.g. ``":memory:"`` for
            tests). When omitted, ``DB_PATH`` from ``config`` is used.
    """
    # @contextmanager is required: a plain generator function has no
    # __enter__/__exit__ and cannot be used in a ``with`` statement.
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
    finally:
        conn.close()
def fetch_all_embeddings(table: str) -> List[Tuple[int, str, List[float]]]:
    """Return ``(id, full_text, embedding)`` for every row of *table*.

    The ``embedding`` column is decoded with ``json.loads``. Rows whose
    embedding is NULL or not valid JSON are skipped silently.

    Note: *table* is interpolated directly into the SQL text, so it must be
    a trusted identifier and never user-supplied input.
    """
    query = f"SELECT id, full_text, embedding FROM {table}"
    with get_db_connection() as conn:
        records = conn.cursor().execute(query).fetchall()
        results: List[Tuple[int, str, List[float]]] = []
        for record in records:
            try:
                vector = json.loads(record['embedding'])
            except (json.JSONDecodeError, TypeError):
                # TypeError: NULL embedding; JSONDecodeError: malformed text.
                continue
            results.append((record['id'], record['full_text'], vector))
    return results
def fetch_row_by_id(table: str, row_id: int) -> Dict[str, Any]:
    """Return the row of *table* whose ``id`` equals *row_id* as a dict.

    An empty dict is returned when no such row exists. *table* is
    interpolated into the SQL text and must be a trusted identifier.
    """
    with get_db_connection() as conn:
        match = conn.cursor().execute(
            f"SELECT * FROM {table} WHERE id = ?", (row_id,)
        ).fetchone()
        if not match:
            return {}
        return dict(match)
def fetch_all_faq_embeddings() -> List[Tuple[int, str, str, List[float]]]:
    """Return ``(id, question, answer, embedding)`` for every FAQ entry.

    The ``embedding`` column is JSON-decoded. Entries whose embedding is
    NULL or not valid JSON are omitted from the result.
    """
    with get_db_connection() as conn:
        records = conn.cursor().execute(
            "SELECT id, question, answer, embedding FROM faq_entries"
        ).fetchall()
        results: List[Tuple[int, str, str, List[float]]] = []
        for record in records:
            try:
                vector = json.loads(record['embedding'])
            except (json.JSONDecodeError, TypeError):
                # Skip entries with a NULL or malformed embedding.
                continue
            results.append((record['id'], record['question'], record['answer'], vector))
    return results
def log_question(
    question: str,
    session_id: str = None,
    category: str = None,
    answer: str = None,
    detected_mode: str = None,
    routing_question: str = None,
    rule_triggered: str = None,
    link_provided: bool = False
):
    """Record a user question, with observability metadata, in ``question_logs``.

    Args:
        question: The user's question text.
        session_id: Session identifier.
        category: Question category (e.g. 'faq_match', 'llm_generated',
            'policy_violation').
        answer: The bot's response.
        detected_mode: Operating mode ('Mode A' or 'Mode B').
        routing_question: The routing question asked, if any.
        rule_triggered: Business rule that fired (e.g. 'audit_rule',
            'free_class_first').
        link_provided: Whether the response included a direct link; stored
            as the integer 1 or 0.

    If the full INSERT raises ``sqlite3.OperationalError`` (for example on a
    schema that predates the metadata columns), a warning is printed and a
    row containing only ``question`` is inserted instead. Errors raised by
    that fallback INSERT propagate to the caller.
    """
    full_insert_sql = """
        INSERT INTO question_logs (
            session_id, question, category, answer,
            detected_mode, routing_question, rule_triggered, link_provided
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """
    values = (
        session_id, question, category, answer,
        detected_mode, routing_question, rule_triggered,
        int(bool(link_provided)),
    )
    with get_db_connection() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(full_insert_sql, values)
        except sqlite3.OperationalError as err:
            # Fallback for older schema versions (shouldn't happen after migration).
            print(f"⚠️ Logging error: {err}. Falling back to basic logging.")
            cursor.execute("INSERT INTO question_logs (question) VALUES (?)", (question,))
        conn.commit()
def get_session_state(session_id: str) -> Dict[str, Any]:
    """Return the stored ``user_sessions`` row for *session_id* as a dict.

    When no row exists yet, a default state is returned and nothing is
    written. The default holds only ``preference``, ``msg_count``,
    ``clarification_count`` and ``knowledge_context``, whereas a stored row
    contains every column of ``user_sessions``.
    """
    with get_db_connection() as conn:
        record = conn.cursor().execute(
            "SELECT * FROM user_sessions WHERE session_id = ?", (session_id,)
        ).fetchone()
        if not record:
            return {"preference": None, "msg_count": 0, "clarification_count": 0, "knowledge_context": "{}"}
        return dict(record)
def update_session_state(session_id: str, preference: str = None, increment_count: bool = True, increment_clarification: bool = False, reset_clarification: bool = False, knowledge_update: Dict = None):
    """Create or update the ``user_sessions`` row for *session_id*.

    For an existing session:
      * ``preference`` replaces the stored preference only when truthy.
      * ``msg_count`` is incremented when ``increment_count`` is true. If the
        result exceeds 10, session memory is reset: the count restarts at 1,
        and the preference, knowledge context and clarification count are
        cleared.
      * Otherwise ``reset_clarification`` zeroes the clarification count, or
        else ``increment_clarification`` adds one to it.
      * ``knowledge_update`` is merged into the stored JSON knowledge
        dictionary, after any reset.

    For a new session, a row is inserted with ``msg_count`` and
    ``clarification_count`` set to 1 or 0 from the increment flags
    (``reset_clarification`` is ignored).

    The row is read and then written in separate statements; this is not an
    atomic upsert.

    Args:
        session_id: Session identifier.
        preference: New preference; falsy values keep the stored one.
        increment_count: Whether to count this call as a message.
        increment_clarification: Whether to add one to the clarification count.
        reset_clarification: Whether to zero the clarification count
            (takes precedence over ``increment_clarification``).
        knowledge_update: Keys to merge into the knowledge dictionary.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute(
            "SELECT preference, msg_count, clarification_count, knowledge_context FROM user_sessions WHERE session_id = ?",
            (session_id,),
        )
        row = cur.fetchone()
        current_knowledge: Dict[str, Any] = {}

        if row:
            curr_pref, curr_count, curr_clarification, curr_knowledge_json = row
            # Treat NULL counters as 0 instead of failing on arithmetic with None.
            curr_count = curr_count or 0
            curr_clarification = curr_clarification or 0
            try:
                loaded = json.loads(curr_knowledge_json)
            except (ValueError, TypeError):
                # ValueError: malformed JSON text; TypeError: NULL column.
                loaded = None
            # Only a JSON object can be merged into below; NULL, lists or
            # scalars start from an empty dictionary.
            if isinstance(loaded, dict):
                current_knowledge = loaded

            new_pref = preference if preference else curr_pref
            new_count = curr_count + 1 if increment_count else curr_count

            # 10-Message Memory Rule: Reset if we hit the limit
            if new_count > 10:
                print(f"🔄 Session {session_id} reached 10 messages. Resetting memory context.")
                new_count = 1
                # NOTE(review): this also discards a `preference` passed on this
                # call, while `knowledge_update` below is still applied after
                # the reset — confirm the asymmetry is intended.
                new_pref = None
                current_knowledge = {}
                new_clarification = 0
            elif reset_clarification:
                new_clarification = 0
            elif increment_clarification:
                new_clarification = curr_clarification + 1
            else:
                new_clarification = curr_clarification

            if knowledge_update:
                current_knowledge.update(knowledge_update)
            cur.execute("""
                UPDATE user_sessions
                SET preference = ?, msg_count = ?, clarification_count = ?, knowledge_context = ?, last_updated = CURRENT_TIMESTAMP
                WHERE session_id = ?
            """, (new_pref, new_count, new_clarification, json.dumps(current_knowledge), session_id))
        else:
            if knowledge_update:
                current_knowledge.update(knowledge_update)
            cur.execute("""
                INSERT INTO user_sessions (session_id, preference, msg_count, clarification_count, knowledge_context)
                VALUES (?, ?, ?, ?, ?)
            """, (
                session_id,
                preference,
                1 if increment_count else 0,
                1 if increment_clarification else 0,
                json.dumps(current_knowledge),
            ))
        conn.commit()
def update_faq_entry(faq_id: int, question: str, answer: str):
    """Overwrite the question and answer of the FAQ entry with id *faq_id*.

    The entry's ``embedding`` is set to NULL in the same statement; entries
    with a NULL embedding are skipped by ``fetch_all_faq_embeddings``.
    """
    params = (question, answer, faq_id)
    with get_db_connection() as conn:
        conn.cursor().execute(
            "UPDATE faq_entries SET question = ?, answer = ?, embedding = NULL WHERE id = ?",
            params,
        )
        conn.commit()
def delete_faq_entry(faq_id: int):
    """Remove the FAQ entry whose id is *faq_id* and commit the change."""
    with get_db_connection() as conn:
        conn.cursor().execute("DELETE FROM faq_entries WHERE id = ?", (faq_id,))
        conn.commit()
def add_faq_entry(question: str, answer: str):
    """Insert a new FAQ entry (no embedding is written) and commit it."""
    params = (question, answer)
    with get_db_connection() as conn:
        conn.cursor().execute(
            "INSERT INTO faq_entries (question, answer) VALUES (?, ?)",
            params,
        )
        conn.commit()
def bulk_update_faqs(entries: List[Dict[str, str]]):
    """Insert one FAQ row per usable dictionary in *entries*.

    Despite the name, rows are only appended; existing FAQs are neither
    updated nor removed. Each entry may use ``Question``/``Answer`` or
    ``question``/``answer`` keys (the capitalized key wins when truthy).
    Entries without a truthy question and answer are skipped. All inserts
    are committed together at the end.
    """
    def _pairs():
        for item in entries:
            q = item.get('Question') or item.get('question')
            a = item.get('Answer') or item.get('answer')
            if q and a:
                yield (q, a)

    with get_db_connection() as conn:
        conn.cursor().executemany(
            "INSERT INTO faq_entries (question, answer) VALUES (?, ?)",
            _pairs(),
        )
        conn.commit()
def bulk_update_podcasts(entries: List[Dict[str, str]]):
    """Insert one ``podcast_episodes`` row per usable dictionary in *entries*.

    Each entry may use ``Guest Name``/``YouTube URL``/``Summary`` or
    ``guest_name``/``youtube_url``/``summary`` keys (the first spelling wins
    when truthy); entries missing any of the three are skipped. Rows are
    only appended, never updated. For every row:

    * ``highlight_json`` is ``[{"summary": <summary>}]`` serialized as JSON,
      the shape ``fetch_all_podcast_metadata`` reads the summary back from;
    * ``full_text`` is ``"Guest: <guest>. Summary: <summary>"``.

    All inserts are committed together at the end.
    """
    def _rows():
        for item in entries:
            guest = item.get('Guest Name') or item.get('guest_name')
            url = item.get('YouTube URL') or item.get('youtube_url')
            summary = item.get('Summary') or item.get('summary')
            if not (guest and url and summary):
                continue
            yield (
                guest,
                url,
                json.dumps([{"summary": summary}]),
                f"Guest: {guest}. Summary: {summary}",
            )

    with get_db_connection() as conn:
        conn.cursor().executemany(
            "INSERT INTO podcast_episodes (guest_name, youtube_url, highlight_json, full_text) VALUES (?, ?, ?, ?)",
            _rows(),
        )
        conn.commit()
def fetch_all_podcast_metadata() -> List[Dict[str, Any]]:
    """Fetch all podcast metadata for the admin table.

    Returns one dict per ``podcast_episodes`` row with keys ``id``,
    ``guest_name``, ``youtube_url`` and ``highlight_json``, plus a derived
    ``summary`` key. ``summary`` is taken from ``highlight_json[0]['summary']``
    when the column holds a non-empty JSON list whose first element has that
    key; otherwise it falls back to the raw ``highlight_json`` value.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute("SELECT id, guest_name, youtube_url, highlight_json FROM podcast_episodes")
        rows = cur.fetchall()
        results = []
        for row in rows:
            d = dict(row)
            raw = d['highlight_json']
            try:
                h = json.loads(raw)
                d['summary'] = h[0]['summary'] if h and isinstance(h, list) else raw
            except (ValueError, TypeError, KeyError):
                # ValueError: malformed JSON (JSONDecodeError) or undecodable
                # bytes; TypeError: NULL column or a first element that is not
                # a dict; KeyError: first element lacks 'summary'. Narrowed from
                # a bare except so KeyboardInterrupt/SystemExit still propagate.
                d['summary'] = raw
            results.append(d)
        return results
def fetch_all_faq_metadata() -> List[Dict[str, Any]]:
    """Return ``id``, ``question`` and ``answer`` of every FAQ entry as dicts, for the admin table."""
    with get_db_connection() as conn:
        records = conn.cursor().execute(
            "SELECT id, question, answer FROM faq_entries"
        ).fetchall()
        return list(map(dict, records))