Spaces:
Runtime error
Runtime error
File size: 5,013 Bytes
330b6e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
"""Database utilities and initialization for the chat agent."""
import os
from flask import Flask
from chat_agent.models import db, ChatSession, Message, LanguageContext
def init_database(app: Flask):
    """Attach the shared ``db`` object to *app* and create its tables.

    Registers *app* with ``db`` via ``db.init_app``, then calls
    ``db.create_all()`` while an application context obtained from *app*
    is active, and prints a confirmation line.

    Args:
        app: The Flask application to bind the database to.
    """
    db.init_app(app)

    # Table creation runs with the application context pushed.
    app_context = app.app_context()
    with app_context:
        db.create_all()
        print("Database tables created successfully")
def create_tables():
    """Create every table registered on ``db`` and print a confirmation.

    NOTE(review): presumably this must run inside an active Flask
    application context (``init_database`` pushes one around the same
    ``db.create_all()`` call) -- confirm against callers.
    """
    db.create_all()
    confirmation = "All database tables created"
    print(confirmation)
def drop_tables():
    """Drop every table registered on ``db`` and print a confirmation.

    NOTE(review): presumably this must run inside an active Flask
    application context -- confirm against callers.
    """
    db.drop_all()
    confirmation = "All database tables dropped"
    print(confirmation)
def reset_database():
    """Rebuild the schema from scratch.

    Prints a start message, runs ``drop_tables()`` followed by
    ``create_tables()``, then prints a completion message.
    """
    print("Resetting database...")
    # Order matters: tables are dropped first, then recreated.
    for step in (drop_tables, create_tables):
        step()
    print("Database reset completed")
def get_database_info():
    """Collect connection and table information for the current database.

    Returns:
        dict: On success, a mapping with the keys ``'database_url'``
        (string form of the engine URL), ``'tables'`` (list of table
        names), ``'table_count'`` (number of tables) and
        ``'table_counts'`` (table name -> row count, or an
        ``"Error: ..."`` string if counting that table failed).
        If anything else fails, ``{'error': <message>}`` is returned
        instead; this function does not raise.
    """
    try:
        # Local import, in line with the function-scope sqlalchemy import
        # used elsewhere in this module; kept inside ``try`` so the
        # "never raises" contract still holds.
        from sqlalchemy import text

        inspector = db.inspect(db.engine)
        tables = inspector.get_table_names()

        # NOTE(review): depending on the SQLAlchemy version, str(url) may
        # include the password in clear text -- confirm before exposing
        # this value outside trusted tooling.
        info = {
            'database_url': str(db.engine.url),
            'tables': tables,
            'table_count': len(tables)
        }

        # Table names come from the inspector rather than user input, but
        # quoting keeps reserved-word and mixed-case names valid SQL.
        quote = db.engine.dialect.identifier_preparer.quote

        table_counts = {}
        for table in tables:
            try:
                # Textual SQL must be wrapped in text(): SQLAlchemy 2.0 no
                # longer accepts a plain string in Session.execute().
                statement = text(f"SELECT COUNT(*) FROM {quote(table)}")
                table_counts[table] = db.session.execute(statement).scalar()
            except Exception as e:
                # Best effort: record the failure for this table and go on.
                table_counts[table] = f"Error: {e}"

        info['table_counts'] = table_counts
        return info
    except Exception as e:
        return {'error': str(e)}
def check_database_connection():
    """Check whether the database connection is working.

    Executes a trivial ``SELECT 1`` through ``db.session``.

    Returns:
        bool: ``True`` if the query ran, ``False`` if any exception was
        raised (the error is printed, not propagated).
    """
    try:
        # Local import, in line with the function-scope sqlalchemy import
        # used elsewhere in this module.
        from sqlalchemy import text

        # Textual SQL must be wrapped in text(): SQLAlchemy 2.0 no longer
        # accepts a plain string here, which would make this check report
        # a failure even on a healthy connection.
        db.session.execute(text('SELECT 1'))
        return True
    except Exception as e:
        print(f"Database connection failed: {e}")
        return False
class DatabaseManager:
    """Database management utilities.

    Groups initialization, sample-data creation, session cleanup and
    statistics helpers around the module-level ``db`` object and the
    ``ChatSession`` / ``Message`` / ``LanguageContext`` models.
    """

    def __init__(self, app=None):
        """Initialize the manager, optionally binding it to a Flask app.

        Args:
            app: Flask application to initialize the database with, or
                ``None`` to defer initialization to ``init_app``.
        """
        self.app = app
        # Compare against None explicitly instead of relying on the
        # truthiness of the application object.
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind the manager to *app* and initialize the database with it."""
        self.app = app
        init_database(app)

    def create_sample_data(self):
        """Create one sample session with two messages and a context.

        Returns:
            dict: ``'session_id'``, ``'user_id'`` and ``'message_count'``
            for the created sample data.

        Raises:
            Exception: Whatever the commit raises; the session is rolled
            back before the error is re-raised.
        """
        from uuid import uuid4

        # Create a sample user session
        user_id = uuid4()
        session = ChatSession.create_session(user_id=user_id, language='python')

        # Create sample messages
        user_message = Message.create_user_message(
            session_id=session.id,
            content="Hello! Can you help me with Python?",
            language='python'
        )
        assistant_message = Message.create_assistant_message(
            session_id=session.id,
            content="Hello! I'd be happy to help you with Python programming. What would you like to learn about?",
            language='python',
            metadata={'response_time': 0.5}
        )

        # Create language context
        context = LanguageContext.create_context(
            session_id=session.id,
            language='python'
        )

        # The context is added alongside the messages so that it is
        # persisted as well; adding an object that is already attached to
        # this session is harmless.
        # NOTE(review): assumes create_context returns a model instance,
        # like the Message factories above -- confirm.
        db.session.add_all([user_message, assistant_message, context])
        try:
            db.session.commit()
        except Exception:
            # Leave the session usable for the caller after a failed commit.
            db.session.rollback()
            raise

        print("Sample data created:")
        print(f"- Session ID: {session.id}")
        print(f"- User ID: {user_id}")
        print("- Messages: 2")
        print("- Language Context: Python")

        return {
            'session_id': session.id,
            'user_id': user_id,
            'message_count': 2
        }

    def cleanup_old_sessions(self, hours=24):
        """Clean up old inactive sessions.

        Args:
            hours: Inactivity window in hours; converted to seconds and
                passed to ``ChatSession.cleanup_expired_sessions``.

        Returns:
            The count reported by ``ChatSession.cleanup_expired_sessions``.
        """
        count = ChatSession.cleanup_expired_sessions(timeout_seconds=hours * 3600)
        print(f"Cleaned up {count} expired sessions")
        return count

    def get_stats(self):
        """Get database statistics.

        Returns:
            dict: ``'total_sessions'``, ``'active_sessions'``,
            ``'total_messages'``, ``'total_contexts'`` and
            ``'languages'`` (session language -> number of sessions).
        """
        from sqlalchemy import func

        stats = {
            'total_sessions': db.session.query(ChatSession).count(),
            # filter_by emits the same equality test as ``== True`` without
            # the comparison-to-True lint smell.
            'active_sessions': db.session.query(ChatSession).filter_by(is_active=True).count(),
            'total_messages': db.session.query(Message).count(),
            'total_contexts': db.session.query(LanguageContext).count(),
        }

        # Get language distribution
        language_stats = (
            db.session.query(ChatSession.language, func.count(ChatSession.id))
            .group_by(ChatSession.language)
            .all()
        )
        stats['languages'] = dict(language_stats)
        return stats