# Embedding-Calculation / database.py
# SQLite data-access helpers. Uploaded by saim1309 ("Upload 4 files", commit 42db8ec).
import json
import sqlite3
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

from config import DB_PATH
@contextmanager
def get_db_connection():
    """Open a connection to the SQLite file at ``DB_PATH`` for a ``with`` block.

    Rows are produced as :class:`sqlite3.Row`, so columns can be read by name
    (``row['id']``) as well as by position. The connection is closed when the
    block exits, whether normally or via an exception. Nothing is committed
    here: callers that write must call ``conn.commit()`` themselves.
    """
    handle = sqlite3.connect(DB_PATH)
    try:
        handle.row_factory = sqlite3.Row
        yield handle
    finally:
        # Always release the connection, even if the caller's block raised.
        handle.close()
def fetch_all_embeddings(table: str) -> List[Tuple[int, str, List[float]]]:
    """Load ``(id, full_text, embedding)`` triples from *table*.

    The ``embedding`` column is decoded from JSON. Rows whose embedding is
    NULL or not valid JSON are skipped silently.

    NOTE(review): *table* is interpolated directly into the SQL text (it
    cannot be a bound parameter), so it must be a trusted identifier, never
    user input.
    """
    query = f"SELECT id, full_text, embedding FROM {table}"
    with get_db_connection() as conn:
        records = conn.cursor().execute(query).fetchall()
        result: List[Tuple[int, str, List[float]]] = []
        for rec in records:
            try:
                vector = json.loads(rec['embedding'])
            except (json.JSONDecodeError, TypeError):
                # NULL (TypeError) or malformed JSON: leave this row out.
                continue
            result.append((rec['id'], rec['full_text'], vector))
        return result
def fetch_row_by_id(table: str, row_id: int) -> Dict[str, Any]:
    """Return the row of *table* whose ``id`` equals *row_id* as a plain dict.

    An empty dict is returned when no such row exists. Only *row_id* is a
    bound parameter; *table* is placed into the SQL text verbatim, so it must
    come from trusted code.
    """
    with get_db_connection() as conn:
        match = conn.cursor().execute(
            f"SELECT * FROM {table} WHERE id = ?", (row_id,)
        ).fetchone()
        if not match:
            return {}
        return dict(match)
def fetch_all_faq_embeddings() -> List[Tuple[int, str, str, List[float]]]:
    """Load ``(id, question, answer, embedding)`` tuples from ``faq_entries``.

    The ``embedding`` column is JSON-decoded. FAQs whose embedding is NULL
    (for example after ``update_faq_entry`` clears it) or holds malformed
    JSON are omitted from the result.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT id, question, answer, embedding FROM faq_entries")
        entries: List[Tuple[int, str, str, List[float]]] = []
        for rec in cursor.fetchall():
            try:
                vector = json.loads(rec['embedding'])
            except (json.JSONDecodeError, TypeError):
                continue
            entries.append((rec['id'], rec['question'], rec['answer'], vector))
        return entries
def log_question(
    question: str,
    session_id: str = None,
    category: str = None,
    answer: str = None,
    detected_mode: str = None,
    routing_question: str = None,
    rule_triggered: str = None,
    link_provided: bool = False
):
    """Record one user question, with observability metadata, in ``question_logs``.

    Args:
        question: Text the user asked.
        session_id: Identifier of the session the question belongs to.
        category: How the question was handled (e.g. 'faq_match',
            'llm_generated', 'policy_violation').
        answer: Response the bot gave.
        detected_mode: Operating mode ('Mode A' or 'Mode B').
        routing_question: Routing question asked back to the user, if any.
        rule_triggered: Business rule that fired (e.g. 'audit_rule',
            'free_class_first').
        link_provided: Whether the response contained a direct link; stored
            as the integer 1 or 0.

    If the full INSERT raises ``sqlite3.OperationalError`` (for instance on an
    older table that lacks the metadata columns), a warning is printed and only
    the question text is inserted. The transaction is then committed.
    """
    link_flag = int(bool(link_provided))
    full_row = (
        session_id, question, category, answer,
        detected_mode, routing_question, rule_triggered, link_flag,
    )
    with get_db_connection() as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                INSERT INTO question_logs (
                    session_id, question, category, answer,
                    detected_mode, routing_question, rule_triggered, link_provided
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                full_row,
            )
        except sqlite3.OperationalError as err:
            # Older schema without the metadata columns: keep at least the question.
            print(f"⚠️ Logging error: {err}. Falling back to basic logging.")
            cursor.execute("INSERT INTO question_logs (question) VALUES (?)", (question,))
        conn.commit()
def get_session_state(session_id: str) -> Dict[str, Any]:
    """Return the stored ``user_sessions`` row for *session_id* as a dict.

    Unknown sessions get a fresh default state instead: no preference, zero
    message and clarification counts, and an empty JSON knowledge context
    (the string ``"{}"``).
    """
    with get_db_connection() as conn:
        found = conn.execute(
            "SELECT * FROM user_sessions WHERE session_id = ?", (session_id,)
        ).fetchone()
        if not found:
            return {"preference": None, "msg_count": 0, "clarification_count": 0, "knowledge_context": "{}"}
        return dict(found)
def update_session_state(
    session_id: str,
    preference: Optional[str] = None,
    increment_count: bool = True,
    increment_clarification: bool = False,
    reset_clarification: bool = False,
    knowledge_update: Optional[Dict] = None,
):
    """Create or update the ``user_sessions`` row for *session_id*.

    Existing sessions:
      * ``preference`` replaces the stored value only when it is truthy.
      * ``msg_count`` is incremented when *increment_count* is true. If the new
        count exceeds 10 (the "10-message memory rule"), the session is reset.
        The count goes back to 1, the preference and knowledge context are
        cleared, and the clarification count is set to 0. The clarification
        flags are ignored on that call.
      * Otherwise *reset_clarification* sets the clarification count to 0 and
        takes precedence over *increment_clarification*, which adds 1.

    New sessions are inserted with ``msg_count`` 1 (0 when *increment_count*
    is false) and ``clarification_count`` 1 or 0 per *increment_clarification*.

    In both cases *knowledge_update* is merged (``dict.update``) into the
    stored ``knowledge_context`` JSON object after any reset. The change is
    then committed.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute(
            "SELECT preference, msg_count, clarification_count, knowledge_context "
            "FROM user_sessions WHERE session_id = ?",
            (session_id,),
        )
        row = cur.fetchone()
        current_knowledge: Dict = {}
        if row:
            curr_pref, curr_count, curr_clarification, curr_knowledge_json = row
            # A NULL counter would otherwise crash on "+ 1" / "> 10" below.
            curr_count = curr_count or 0
            try:
                loaded = json.loads(curr_knowledge_json)
            except (ValueError, TypeError):
                # ValueError covers json.JSONDecodeError; TypeError covers a NULL column.
                loaded = {}
            # Valid JSON that is not an object ("[]", "null", ...) cannot be merged into.
            current_knowledge = loaded if isinstance(loaded, dict) else {}

            new_pref = preference if preference else curr_pref
            new_count = curr_count + 1 if increment_count else curr_count
            # 10-Message Memory Rule: Reset if we hit the limit
            if new_count > 10:
                print(f"🔄 Session {session_id} reached 10 messages. Resetting memory context.")
                new_count = 1
                new_pref = None
                current_knowledge = {}
                new_clarification = 0
            elif reset_clarification:
                new_clarification = 0
            elif increment_clarification:
                new_clarification = (curr_clarification or 0) + 1
            else:
                new_clarification = curr_clarification

            # Merge knowledge updates (after any reset, so they survive it).
            if knowledge_update:
                current_knowledge.update(knowledge_update)
            new_knowledge_json = json.dumps(current_knowledge)
            cur.execute("""
                UPDATE user_sessions
                SET preference = ?, msg_count = ?, clarification_count = ?, knowledge_context = ?, last_updated = CURRENT_TIMESTAMP
                WHERE session_id = ?
            """, (new_pref, new_count, new_clarification, new_knowledge_json, session_id))
        else:
            new_pref = preference
            new_count = 1 if increment_count else 0
            new_clarification = 1 if increment_clarification else 0
            if knowledge_update:
                current_knowledge.update(knowledge_update)
            new_knowledge_json = json.dumps(current_knowledge)
            cur.execute("""
                INSERT INTO user_sessions (session_id, preference, msg_count, clarification_count, knowledge_context)
                VALUES (?, ?, ?, ?, ?)
            """, (session_id, new_pref, new_count, new_clarification, new_knowledge_json))
        conn.commit()
def update_faq_entry(faq_id: int, question: str, answer: str):
    """Overwrite the question and answer text of the FAQ with id *faq_id*.

    The row's ``embedding`` is set to NULL at the same time, presumably so
    that it gets recomputed for the new text (confirm). Until then,
    ``fetch_all_faq_embeddings`` skips the row because its embedding cannot
    be decoded.
    """
    params = (question, answer, faq_id)
    with get_db_connection() as conn:
        conn.cursor().execute(
            "UPDATE faq_entries SET question = ?, answer = ?, embedding = NULL WHERE id = ?",
            params,
        )
        conn.commit()
def delete_faq_entry(faq_id: int):
    """Remove the FAQ whose ``id`` is *faq_id* and commit.

    An id with no matching row deletes nothing and raises no error.
    """
    with get_db_connection() as conn:
        conn.cursor().execute("DELETE FROM faq_entries WHERE id = ?", (faq_id,))
        conn.commit()
def add_faq_entry(question: str, answer: str):
    """Insert a new FAQ row holding *question* and *answer*, then commit.

    Only those two columns are written. ``embedding`` is left to the column
    default (presumably NULL — confirm against the schema).
    """
    new_row = (question, answer)
    with get_db_connection() as conn:
        conn.cursor().execute("INSERT INTO faq_entries (question, answer) VALUES (?, ?)", new_row)
        conn.commit()
def bulk_update_faqs(entries: List[Dict[str, str]]):
    """Insert one FAQ row per usable dict in *entries*, in a single transaction.

    Each dict may use ``'Question'``/``'Answer'`` or ``'question'``/``'answer'``
    keys. The capitalised key is used when its value is truthy. Dicts missing
    a truthy question or answer are skipped. Despite the name, rows are always
    inserted; existing FAQs are neither updated nor de-duplicated.
    """
    rows = []
    for item in entries:
        q = item.get('Question') or item.get('question')
        a = item.get('Answer') or item.get('answer')
        if q and a:
            rows.append((q, a))
    with get_db_connection() as conn:
        conn.cursor().executemany(
            "INSERT INTO faq_entries (question, answer) VALUES (?, ?)",
            rows,
        )
        conn.commit()
def bulk_update_podcasts(entries: List[Dict[str, str]]):
    """Insert one ``podcast_episodes`` row per usable dict in *entries*.

    Each dict may use ``'Guest Name'``/``'YouTube URL'``/``'Summary'`` or
    ``'guest_name'``/``'youtube_url'``/``'summary'`` keys. The first form is
    used when its value is truthy. Dicts lacking any of the three values are
    skipped.

    For every episode:
      * ``full_text`` is stored as ``"Guest: <guest>. Summary: <summary>"``.
      * ``highlight_json`` is stored as the one-element JSON list
        ``[{"summary": <summary>}]``.

    All rows are committed together. Nothing existing is updated or
    de-duplicated.
    """
    batch = []
    for item in entries:
        guest = item.get('Guest Name') or item.get('guest_name')
        url = item.get('YouTube URL') or item.get('youtube_url')
        summary = item.get('Summary') or item.get('summary')
        if not (guest and url and summary):
            continue
        # Keep the full_text / highlight_json shapes existing readers expect
        # (fetch_all_podcast_metadata reads highlight_json[0]['summary']).
        text_blob = f"Guest: {guest}. Summary: {summary}"
        highlights = json.dumps([{"summary": summary}])
        batch.append((guest, url, highlights, text_blob))
    with get_db_connection() as conn:
        conn.cursor().executemany(
            "INSERT INTO podcast_episodes (guest_name, youtube_url, highlight_json, full_text) VALUES (?, ?, ?, ?)",
            batch,
        )
        conn.commit()
def fetch_all_podcast_metadata() -> List[Dict[str, Any]]:
    """Return every podcast episode as a dict for the admin table.

    Each dict has the ``id``, ``guest_name``, ``youtube_url`` and
    ``highlight_json`` columns plus a derived ``summary`` key. ``summary`` is
    the ``'summary'`` field of the first element when ``highlight_json``
    decodes to a non-empty list. Otherwise it falls back to the raw
    ``highlight_json`` value. That covers invalid JSON, NULL, an empty list,
    a non-list, or a first element without a ``'summary'`` field.
    """
    with get_db_connection() as conn:
        cur = conn.cursor()
        cur.execute("SELECT id, guest_name, youtube_url, highlight_json FROM podcast_episodes")
        results = []
        for row in cur.fetchall():
            d = dict(row)
            raw = d['highlight_json']
            try:
                h = json.loads(raw)
                d['summary'] = h[0]['summary'] if h and isinstance(h, list) else raw
            except (ValueError, TypeError, KeyError):
                # ValueError: malformed JSON (JSONDecodeError); TypeError: NULL column or
                # a first element that is not a mapping; KeyError: no 'summary' field.
                # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
                d['summary'] = raw
            results.append(d)
        return results
def fetch_all_faq_metadata() -> List[Dict[str, Any]]:
    """List every FAQ for the admin table.

    Each FAQ is a dict with ``id``, ``question`` and ``answer`` keys. The
    embedding column is not included.
    """
    with get_db_connection() as conn:
        faq_rows = conn.execute("SELECT id, question, answer FROM faq_entries").fetchall()
        return list(map(dict, faq_rows))