# Database_Agent / app.py
# prthm11's picture
# Update app.py
# 8a66204 verified
from flask import Flask, request, jsonify, render_template
from flask_socketio import SocketIO, emit
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import initialize_agent, AgentType
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.tools import Tool
from langchain.memory import ConversationBufferMemory
from pymongo import MongoClient
import threading
import os, re, traceback, ast
from bson import json_util
from dotenv import load_dotenv
from werkzeug.exceptions import HTTPException
from langchain.prompts import ChatPromptTemplate
from tabulate import tabulate
from fuzzywuzzy import fuzz
import urllib
import logging
from urllib.parse import urlparse
from langchain_groq import ChatGroq
# --------------------------
# BASIC CONFIG
# --------------------------
# Called before any os.getenv() lookup below so that values supplied through
# a .env file (python-dotenv) are visible to those lookups.
load_dotenv()
# Configure the root logger at INFO so this module's logger.info() calls show up.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the agent initialisers and the routes.
logger = logging.getLogger(__name__)
# ODBC connection string for SQL Server. Every component is read from the
# environment (DB_SERVER, DB_PORT, DB_NAME, DB_USER, DB_PASS) and falls back
# to the built-in default when the variable is not set.
ODBC_CONN = ";".join((
    "DRIVER={ODBC Driver 17 for SQL Server}",
    # Server and port share one field, separated by a comma.
    "SERVER={},{}".format(os.getenv('DB_SERVER', '192.168.1.37'), os.getenv('DB_PORT', '1433')),
    "DATABASE={}".format(os.getenv('DB_NAME', 'TunisSyncV1')),
    "UID={}".format(os.getenv('DB_USER', 'sa')),
    "PWD={}".format(os.getenv('DB_PASS', 'sa123')),
))
# params = urllib.parse.quote_plus(ODBC_CONN)
# DB_URI = f"mssql+pyodbc:///?odbc_connect={params}"
# mssql+pyodbc:///?odbc_connect=DRIVER%3D%7BODBC+Driver+17+for+SQL+Server%7D%3BSERVER%3D192.168.1.37%2C1433%3BDATABASE%3DTunisSyncV1%3BUID%3Dsa%3BPWD%3Dsa123
# # Static MongoDB URI (Atlas)
# # MONGO_URI = os.getenv('MONGO_URI', 'mongodb+srv://dixitmwa:DixitWa%40123!@cluster0.qiysaz9.mongodb.net/shopdb')
# MONGO_URI = 'mongodb+srv://dixitmwa:DixitWa%40123!@cluster0.qiysaz9.mongodb.net/shopdb'
# --------------------------
# Flask + SocketIO + LLM
# --------------------------
# Flask application. The secret key is freshly random on every process start.
app = Flask(__name__)
app.config.update(
    SECRET_KEY=os.urandom(32),
    UPLOAD_FOLDER='uploads',
)
# Make sure the configured upload directory exists.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# Socket.IO server attached to the Flask app; connections from any origin are allowed.
socketio = SocketIO(app, cors_allowed_origins='*')
# Chat model handed to both the SQL and the MongoDB agent builders.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.2,
    max_retries=50,
    api_key=os.getenv('GEMINI_API_KEY'),
)
# llm = ChatGroq(
# # model="meta-llama/llama-4-scout-17b-16e-instruct",
# # model="deepseek-r1-distill-llama-70b",
# model= "meta-llama/llama-4-maverick-17b-128e-instruct",
# # model="openai/gpt-oss-120b",
# temperature=0,
# max_tokens=None,
# max_retries=50,
# api_key=os.getenv('GROQ_API_KEY')
# )
# Shared module state; the init_*_agent_from_uri() functions rebind these.
# Conversation memory passed to the agents, stored under 'chat_history'.
memory = ConversationBufferMemory(
    input_key='input',
    memory_key='chat_history',
    return_messages=True,
)
# Currently active agent; None until a database has been connected.
agent_executor = None
# MongoDB client and database handles; None until a MongoDB URI is connected.
mongo_client = mongo_db = None
# Backend of the active agent: 'sql' or 'mongo'.
db_mode = None
# --------------------------
# Helpers / Safety checks
# --------------------------
def error_safe(f):
    """Decorate a Flask view so that uncaught exceptions become JSON responses.

    Args:
        f: The view function to wrap.

    Returns:
        A wrapper that forwards all arguments to ``f``. An ``HTTPException``
        is turned into ``{"status": "error", "message": <description>}`` with
        the exception's own status code; any other exception is logged with
        its traceback and reported as a 500 with ``str(e)`` as the message.
    """
    # Local import so this block does not depend on the file's import header.
    import functools

    # functools.wraps copies __name__ (which Flask uses as the endpoint name)
    # as well as __doc__, __module__ and __wrapped__, and tolerates callables
    # that lack some of those attributes.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except HTTPException as he:
            return jsonify({"status": "error", "message": he.description}), he.code
        except Exception as e:
            logger.exception('[ERROR] Uncaught Exception in %s', getattr(f, '__name__', repr(f)))
            return jsonify({"status": "error", "message": str(e)}), 500

    return wrapper
# def is_schema_request(prompt: str) -> bool:
# pattern = re.compile(r'\b(schema|table names|tables|columns|structure|column names|collections?|field names|metadata|describe|show)\b', re.IGNORECASE)
# return bool(pattern.search(prompt))
# def is_sensitive_request(prompt: str) -> bool:
# sensitive_keywords = [
# "password", "token", "credential", "secret", "api key", "schema", "structure",
# "collection name", "field name", "user_id", "order_id", "payment_id",
# "internal", "database structure", "table structure", "email", "phone", "contact", "ssn"
# ]
# lowered = prompt.lower()
# return any(keyword in lowered for keyword in sensitive_keywords)
# intent_prompt = ChatPromptTemplate.from_messages([
# ("system", "Classify if the user is asking schema/structure/sensitive info (tables, columns, schema): YES or NO."),
# ("human", "{prompt}")
# ])
# try:
# intent_checker = intent_prompt | llm
# except Exception:
# intent_checker = None
# def is_schema_leak_request(prompt):
# if intent_checker is None:
# return False
# try:
# classification = intent_checker.invoke({"prompt": prompt})
# text = ''
# if hasattr(classification, 'content'):
# text = classification.content
# elif hasattr(classification, 'text'):
# text = classification.text
# else:
# text = str(classification)
# return 'yes' in text.strip().lower()
# except Exception as e:
# logger.warning('Schema intent classifier failed: %s', e)
# return False
# --------------------------
# SQL agent initialization
# --------------------------
def init_sql_agent_from_uri(sql_uri: str, dialect: str = None):
    """Build the SQL agent for ``sql_uri`` and install it as the active agent.

    Args:
        sql_uri: SQLAlchemy-style database URI handed to
            ``SQLDatabase.from_uri``.
        dialect: Optional human-readable dialect hint (for example
            ``"postgresql"``). It is only used for logging; the connection
            itself is driven entirely by ``sql_uri``. The parameter exists so
            that callers passing ``dialect=...`` do not fail with a TypeError.

    Returns:
        True when the agent was created and the module globals
        ``agent_executor`` / ``db_mode`` were updated; False when anything
        failed (the error is logged and the previous agent stays active).
    """
    global agent_executor, db_mode
    try:
        sql_db = SQLDatabase.from_uri(sql_uri)
        toolkit = SQLDatabaseToolkit(db=sql_db, llm=llm)
        # The prompt text is kept flush-left on purpose so the string sent to
        # the model is unchanged.
        prefix = '''You are a helpful SQL expert agent that ALWAYS returns natural language answers using the tools.
Always format your responses in Markdown. For example:
- Use bullet points
- Use bold for headers
- Wrap code in triple backticks
- Tables should use Markdown table syntax
You must NEVER:
- Show or mention SQL syntax.
- Reveal table names, column names, or database schema.
- Respond with any technical details or structure of the database.
- Return code or tool names.
- Give wrong Answers.
'''
        agent = create_sql_agent(
            llm=llm,
            toolkit=toolkit,
            verbose=True,
            prefix=prefix,
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            memory=memory,
            agent_executor_kwargs={"handle_parsing_errors": True},
        )
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception('Failed to initialize SQL agent: %s', e)
        return False

    # Commit to the globals only after every step above succeeded.
    agent_executor = agent
    db_mode = 'sql'
    if dialect:
        logger.info('SQL agent initialized using URI (dialect hint: %s)', dialect)
    else:
        logger.info('SQL agent initialized using URI')
    return True
# --------------------------
# Mongo agent initialization
# --------------------------
def find_docs_tool_func(query: str) -> str:
    """Search MongoDB using a ``collection=..., key=..., value=...`` query string.

    The named collection is searched first; when it yields nothing, every
    other collection is tried in turn and the first non-empty result wins.
    With ``key`` and ``value`` an exact field match is used, with only
    ``value`` every field is compared case-insensitively, and with neither
    all documents are returned. Any exception is reported as a string.
    """
    try:
        params = {}
        for chunk in query.split(','):
            if '=' not in chunk:
                continue
            name, _, raw = chunk.strip().partition('=')
            params[name] = raw

        collection = params.get('collection')
        if not collection:
            return "❌ 'collection' is required."
        key = params.get('key')
        value = params.get('value')

        def fetch(coll_name):
            # Exact match on one field.
            if key and value:
                return list(mongo_db[coll_name].find({key: value}, {'_id': 0}))
            # No value at all: return the whole collection.
            if not value:
                return list(mongo_db[coll_name].find({}, {'_id': 0}))
            # Value only: case-insensitive comparison against every field.
            wanted = value.lower()
            return [
                doc
                for doc in mongo_db[coll_name].find({}, {'_id': 0})
                if any(str(field).lower() == wanted for field in doc.values())
            ]

        docs = fetch(collection)
        if not docs:
            # Fall back to the remaining collections, stopping at the first hit.
            for other in mongo_db.list_collection_names():
                if other != collection:
                    docs = fetch(other)
                    if docs:
                        break

        if not docs:
            return "**No documents found.**"
        return "\n markdown\n" + tabulate(docs, headers='keys', tablefmt='github') + "\n"
    except Exception as err:
        return f"Invalid input format or error: {str(err)}"
def aggregate_group_by(_input: str):
    """Count the documents of a collection grouped by one field.

    Args:
        _input: Either a dict literal such as
            ``{'collection': 'Orders', 'group_by': 'status'}`` (keys
            ``collection_name``/``collection`` and ``group_by``/``field``) or
            a ``collection=Orders, field=status`` string.

    Returns:
        A GitHub-style table of ``count`` per field value, ``**No data
        found.**`` when the aggregation is empty, or a string starting with
        ``Aggregation failed:`` when the input is unusable or an error occurs.
    """
    try:
        text = _input.strip()
        if text.startswith('{'):
            # Parse the stripped text: before Python 3.10 literal_eval
            # rejects input with leading whitespace.
            args = ast.literal_eval(text)
            collection = args.get('collection_name') or args.get('collection')
            field = args.get('group_by') or args.get('field')
        else:
            # Tolerate spaces around ',' and '=' and keep any further '='
            # inside the value.
            args = {}
            for part in text.split(','):
                name, sep, raw = part.partition('=')
                if sep:
                    args[name.strip()] = raw.strip()
            collection = args.get('collection')
            field = args.get('field')

        if not collection or not field:
            return "Aggregation failed: both 'collection' and 'field' are required."

        pipeline = [
            {'$group': {'_id': f"${field}", 'count': {'$sum': 1}}},
            # Expose the group key under the field's own name instead of _id.
            {'$project': {'_id': 0, field: '$_id', 'count': 1}}
        ]
        result = list(mongo_db[collection].aggregate(pipeline))
        if not result:
            return "**No data found.**"
        return "\n markdown\n" + tabulate(result, headers='keys', tablefmt='github') + "\n"
    except Exception as e:
        return f"Aggregation failed: {e}"
def get_all_documents(collection: str):
    """Return every document of ``collection`` (without ``_id``) as a table.

    Yields ``**No documents found.**`` for an empty collection and an
    ``Error fetching documents: ...`` string when anything raises.
    """
    try:
        rows = list(mongo_db[collection].find({}, {'_id': 0}))
        if rows:
            return "\n markdown\n" + tabulate(rows, headers='keys', tablefmt='github') + "\n"
        return "**No documents found.**"
    except Exception as err:
        return f"Error fetching documents: {err}"
def fuzzy_find_documents(query: str):
    """Fuzzy-search one collection using ``collection=..., value=...[, threshold=N]``.

    A document matches when ``fuzz.partial_ratio`` between the lower-cased
    ``value`` and the lower-cased string form of any of its fields reaches
    ``threshold`` (default 80). Errors are reported as a string.
    """
    try:
        pairs = [piece.strip().split('=', 1) for piece in query.split(',')]
        params = dict(pairs)
        collection, value = params['collection'], params['value']
        threshold = int(params.get('threshold', 80))
        needle = value.lower()

        def is_close(doc):
            return any(
                fuzz.partial_ratio(str(field).lower(), needle) >= threshold
                for field in doc.values()
            )

        matches = [doc for doc in mongo_db[collection].find({}, {'_id': 0}) if is_close(doc)]
        if matches:
            return "\n markdown\n" + tabulate(matches, headers='keys', tablefmt='github') + "\n"
        return "**No fuzzy matches found.**"
    except Exception as err:
        return f"Fuzzy match error: {err}"
def join_collections_tool_func(_input: str):
    """Join two (optionally three) collections and list one resulting field.

    ``_input`` is a ``from=..., key=..., to=..., match=..., return=...``
    string. Documents of ``from`` are merged with the ``to`` document whose
    ``match`` field equals their ``key`` field. When ``next_key``,
    ``next_to`` and ``next_match`` are all given, ``return`` is then looked
    up in a third collection (``'Unknown'`` when there is no partner).
    Errors are reported as a ``Join failed: ...`` string.
    """
    try:
        params = dict(item.strip().split('=', 1) for item in _input.split(','))
        # Required arguments, looked up in this fixed order.
        from_coll, key, to_coll, match, return_field = (
            params[name] for name in ('from', 'key', 'to', 'match', 'return')
        )
        second_hop = [params.get(name) for name in ('next_key', 'next_to', 'next_match')]

        # Index the target collection by its match field.
        lookup = {}
        for target in mongo_db[to_coll].find():
            if match in target:
                lookup[target[match]] = target

        joined = []
        for source in mongo_db[from_coll].find({}, {'_id': 0}):
            partner = lookup.get(source.get(key))
            if partner:
                combined = dict(source)
                combined.update(partner)  # target fields win on conflicts
                joined.append(combined)

        if all(second_hop):
            next_key, next_to, next_match = second_hop
            next_lookup = {
                entry[next_match]: entry
                for entry in mongo_db[next_to].find()
                if next_match in entry
            }
            for row in joined:
                hit = next_lookup.get(row.get(next_key))
                row[return_field] = hit.get(return_field, 'Unknown') if hit else 'Unknown'

        if not joined:
            return "**No documents found.**"
        final = [{return_field: row.get(return_field)} for row in joined if return_field in row]
        return "\n markdown\n" + tabulate(final, headers='keys', tablefmt='github') + "\n"
    except Exception as err:
        return f"Join failed: {err}"
def smart_join_router(prompt: str) -> str:
    """Suggest a JoinCollections argument string for a natural-language prompt.

    Payment questions that mention a person map to the
    Payments -> Orders -> Users path, order questions that mention a name
    map to Orders -> Users, and everything else gets a fallback message.
    """
    text = prompt.lower()
    person_terms = ('who', 'name', 'user', 'person')

    asks_payment_owner = 'payment' in text and any(term in text for term in person_terms)
    if asks_payment_owner:
        return 'from=Payments, key=order_id, to=Orders, match=order_id, next_key=user_id, next_to=Users, next_match=user_id, return=name'

    if 'order' in text and 'name' in text:
        return 'from=Orders, key=user_id, to=Users, match=user_id, return=name'

    return 'Unable to auto-generate join path. Please provide more context.'
def init_mongo_agent_from_uri(mongo_uri: str, database_name: str = None):
    """Connect to MongoDB and install a conversational agent over the Mongo tools.

    Args:
        mongo_uri: MongoDB connection string.
        database_name: Database to use. When omitted, the path component of
            ``mongo_uri`` is used, falling back to ``"test"`` if the URI
            names no database.

    Returns:
        True when the connection was verified, the agent was built and the
        module globals (``mongo_client``, ``mongo_db``, ``agent_executor``,
        ``db_mode``) were updated. False when any step failed; the failure is
        logged and the globals are left untouched.
    """
    global mongo_client, mongo_db, agent_executor, db_mode
    client = None
    try:
        client = MongoClient(mongo_uri, serverSelectionTimeoutMS=5000)
        # Trigger a ping to validate the connection.
        client.admin.command('ping')
        # Try to get the DB name from the URI if not provided.
        if database_name is None:
            database_name = urlparse(mongo_uri).path.lstrip('/') or "test"
        database = client[database_name]

        # The tool functions read the module-level mongo_db when they are
        # called, so it is enough to publish it after the agent is built.
        tools = [
            Tool(name='FindDocuments', func=find_docs_tool_func, description='Flexible MongoDB search...'),
            Tool(name='ListCollections', func=lambda x: mongo_db.list_collection_names(), description='List all collections...'),
            Tool(name='AggregateGroupBy', func=aggregate_group_by, description='Group and count by any field...'),
            Tool(name='GetAllDocuments', func=get_all_documents, description='Fetch all documents from a collection...'),
            Tool(name='FuzzyFindDocuments', func=fuzzy_find_documents, description='Fuzzy match documents across all fields...'),
            Tool(name='JoinCollections', func=join_collections_tool_func, description='Join related collections to return names instead of IDs...'),
            Tool(name='SmartJoinCollections', func=smart_join_router, description='Suggest join formats...')
        ]
        # Prompt text is kept flush-left so the string itself is unchanged.
        prefix = """
You are MongoDBQueryBot, an intelligent assistant for interacting with a MongoDB database.
You have read-only access to the database and can answer questions using the provided tools.
"""
        # The remainder of an earlier, longer prompt is kept here for reference
        # only; it is not part of `prefix`.
        # Guidelines for all queries:
        # 1. Always answer in clear, natural language. Use Markdown formatting, bullet points, and tables when helpful.
        # 2. Explain the content of collections and fields based on the summary.
        # 3. If asked about the purpose or meaning of the database, synthesize a complete description from collections and sample data.
        # 4. For factual questions, query the database using the available tools:
        # - FindDocuments: Flexible search by key/value
        # - AggregateGroupBy: Summarize counts by fields
        # - FuzzyFindDocuments: Approximate text search
        # - GetAllDocuments: Retrieve all documents from a collection
        # - JoinCollections / SmartJoinCollections: Combine related collections for meaningful answers
        # 5. NEVER expose raw database connection info, credentials, or sensitive information.
        # 6. NEVER provide raw schema details unless explicitly requested.
        # 7. If the user query is vague, ambiguous, or general, make the **best effort explanation** using collection names, field names, and sample documents.
        # 8. When presenting query results, format them as human-readable tables or bullet lists.
        # 9. When a user asks a question you cannot answer confidently, politely explain that the answer may be limited.
        # Examples:
        # - User: "What is this database about?"
        # Assistant: "This database contains Users, Orders, and Payments collections. It stores e-commerce information including user profiles, order histories, and payment records."
        # - User: "Show me all orders for user John Doe"
        # Assistant: Use FindDocuments or JoinCollections to fetch relevant results, and present in a table format.
        # - User: "How many users registered this month?"
        # Assistant: Use AggregateGroupBy and summarize results in a clear sentence.

        # initialize_agent selects the agent class through `agent=` and takes
        # prompt overrides through `agent_kwargs`; other keyword arguments are
        # forwarded to the executor, so `agent_type=` / `prefix=` passed
        # directly did not configure the agent.
        agent = initialize_agent(
            tools=tools,
            llm=llm,
            agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
            agent_kwargs={'prefix': prefix},
            memory=memory,
            verbose=True,
            handle_parsing_errors=True,
        )
    except Exception as e:
        logger.exception('Failed to initialize Mongo agent: %s', e)
        # Do not leave a half-opened client behind.
        if client is not None:
            client.close()
        return False

    # Publish the new connection and agent together, only after full success.
    mongo_client = client
    mongo_db = database
    agent_executor = agent
    db_mode = 'mongo'
    logger.info('Connected to MongoDB Atlas database: %s', database_name)
    logger.info('Mongo agent initialized')
    return True
# --------------------------
# Routes
# --------------------------
@app.route('/')
def index():
    """Serve the main page by rendering the 'index_db_json.html' template."""
    return render_template('index_db_json.html')
# Dynamic database upload has been removed; the route stays registered and
# answers every request with a refusal.
@app.route('/upload_db', methods=['POST'])
@error_safe
def upload_db():
    """Reject the upload with a 403 and an explanatory JSON body."""
    refusal = {
        'success': False,
        'message': 'Dynamic DB upload is disabled. This server uses static configured DB URIs.',
    }
    return jsonify(refusal), 403
@app.route('/connect_db', methods=['POST'])
@error_safe
def connect_db():
    """Connect to the database named by the posted ``uri`` and build its agent.

    Expects a JSON object with a string ``uri``. ``mongodb://`` and
    ``mongodb+srv://`` URIs are routed to the MongoDB initializer; every
    other URI goes to the SQL initializer.

    Returns:
        ``{"success": true, "message": ...}`` on success, a 400 when no
        usable URI was supplied, or a 500 when the connection failed.
    """
    # silent=True: a malformed body is treated like a missing URI.
    data = request.get_json(force=True, silent=True)
    uri = data.get('uri') if isinstance(data, dict) else None
    uri = uri.strip() if isinstance(uri, str) else ''
    if not uri:
        return jsonify({"success": False, "message": "❌ No connection string provided."}), 400

    # Pick the initializer and the label used in the success message.
    if uri.startswith(("mongodb://", "mongodb+srv://")):
        label, initializer = "MongoDB", init_mongo_agent_from_uri
    else:
        initializer = init_sql_agent_from_uri
        if uri.startswith("postgresql://"):
            label = "PostgreSQL"
        elif uri.startswith(("mysql://", "mysql+pymysql://")):
            label = "MySQL"
        elif uri.startswith("sqlite://"):  # also covers sqlite:///
            label = "SQLite"
        else:
            label = "SQL"  # e.g. SQL Server

    # The initializers log their own failures instead of raising, so success
    # is detected by checking that a new agent object was installed.
    previous_agent = agent_executor
    try:
        # Only the URI is passed: the SQL initializer derives everything it
        # needs from it.
        initializer(uri)
    except Exception as e:
        logger.error("Failed to connect DB: %s", e)
        return jsonify({"success": False, "message": f"❌ Connection failed: {e}"}), 500

    if agent_executor is previous_agent:
        logger.error("Failed to connect DB: initializer did not install an agent")
        return jsonify({
            "success": False,
            "message": "❌ Connection failed: could not initialize the database agent.",
        }), 500
    return jsonify({"success": True, "message": f"✅ Connected to {label} agent."})
# @app.route('/generate', methods=['POST'])
# @error_safe
# def generate():
# try:
# data = request.get_json(force=True)
# prompt = data.get('prompt', '').strip()
# if not prompt:
# return jsonify({'status': 'error', 'message': 'Prompt is required'}), 400
# # if is_schema_leak_request(prompt) or is_schema_request(prompt):
# # msg = '⛔ Sorry, you\'re not allowed to access structure/schema information.'
# # socketio.emit('final', {'message': msg})
# # return jsonify({'status': 'blocked', 'message': msg}), 403
# # if is_sensitive_request(prompt):
# # msg = '⛔ This query may involve sensitive or protected information. Please rephrase your question.'
# # socketio.emit('final', {'message': msg})
# # return jsonify({'status': 'blocked', 'message': msg}), 403
# except Exception as e:
# traceback.print_exc()
# return jsonify({'status': 'error', 'message': 'Invalid input'}), 400
# def run_agent():
# try:
# result = agent_executor.invoke({'input': prompt})
# final_answer = result.get('output', '')
# if not final_answer.strip():
# final_answer = "Please, rephrase your query because I can't exactly understand, what you want !"
# socketio.emit('final', {'message': final_answer})
# except Exception as e:
# error_message = str(e)
# if '429' in error_message and 'quota' in error_message.lower():
# user_friendly_msg = '🚦 Agent is busy, try again after sometime.'
# else:
# user_friendly_msg = f'Agent failed: {error_message}'
# socketio.emit('final', {'message': user_friendly_msg})
# traceback.print_exc()
# threading.Thread(target=run_agent).start()
# return jsonify({'status': 'ok'}), 200
def _emit_final(message):
    """Best-effort broadcast of ``message`` on the 'final' Socket.IO event."""
    try:
        socketio.emit('final', {'message': message})
    except Exception:
        app.logger.debug("socket emit failed, continuing")


@app.route('/generate', methods=['POST'])
@error_safe
def generate():
    """Run the active agent on the posted ``prompt`` and return its answer.

    Expects a JSON object with a non-empty string ``prompt``. The answer (or
    an error message) is also emitted on the 'final' Socket.IO event.

    Returns:
        200 with ``final_answer`` on success; 400 for invalid input, a
        missing prompt, or when no database has been connected yet; 500 when
        the agent itself fails.
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({'status': 'error', 'message': 'Invalid input'}), 400
    prompt = data.get('prompt', '')
    if not isinstance(prompt, str):
        return jsonify({'status': 'error', 'message': 'Invalid input'}), 400
    prompt = prompt.strip()
    if not prompt:
        return jsonify({'status': 'error', 'message': 'Prompt is required'}), 400

    # Snapshot the agent so a concurrent reconnect cannot swap it mid-request.
    executor = agent_executor
    if executor is None:
        user_msg = 'No database is connected yet. Please connect a database first.'
        _emit_final(user_msg)
        return jsonify({'status': 'error', 'prompt': prompt, 'final_answer': '', 'message': user_msg}), 400

    try:
        # Run the agent synchronously and normalize the response.
        result = executor.invoke({'input': prompt})
        if isinstance(result, dict):
            # First non-empty value among the keys an agent may answer under.
            answer_keys = ('final_answer', 'final', 'output', 'answer', 'text')
            raw = next((result[k] for k in answer_keys if result.get(k)), "")
        else:
            raw = result or ""
        # str() guards against a non-string answer (e.g. a list of parts).
        final_answer = str(raw).strip()
        if not final_answer:
            final_answer = "Please, rephrase your query because I can't exactly understand, what you want !"
        _emit_final(final_answer)
        return jsonify({'status': 'ok', 'prompt': prompt, 'final_answer': final_answer}), 200
    except Exception as e:
        app.logger.exception("Agent invocation failed")
        # Friendly message for certain common failures (example: quota/429).
        err_text = str(e)
        if '429' in err_text and 'quota' in err_text.lower():
            user_msg = '🚦 Agent is busy, try again after sometime.'
        else:
            user_msg = f'Agent error: {err_text[:500]}'
        # Still emit so UIs listening on the socket get notified.
        _emit_final(user_msg)
        return jsonify({'status': 'error', 'prompt': prompt, 'final_answer': '', 'message': user_msg}), 500
# --------------------------
# Error handlers
# --------------------------
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Turn any otherwise unhandled exception into a JSON error response.

    A handler registered for ``Exception`` also receives ``HTTPException``
    instances (404, 405, ...). Those keep their own status code and
    description; everything else is logged and answered with a generic 500.
    """
    if isinstance(e, HTTPException):
        return jsonify({'status': 'error', 'message': e.description}), e.code or 500
    # exc_info=e attaches the traceback of the exception being handled.
    logger.error('[ERROR] Global handler caught an exception: %s', e, exc_info=e)
    return jsonify({'status': 'error', 'message': 'An unexpected error occurred'}), 500
from werkzeug.exceptions import TooManyRequests
@app.errorhandler(TooManyRequests)
def handle_429_error(e):
    """Answer a TooManyRequests error with a JSON 'busy' message and status 429."""
    busy = {
        'status': 'error',
        'message': '🚦 Agent is busy, try again after sometime.',
    }
    return jsonify(busy), 429
# --------------------------
# Startup: run the Socket.IO server. Agents are created on demand via
# /connect_db; the static-URI initialization below is commented out.
# --------------------------
if __name__ == '__main__':
    # The static-URI initialization calls below are disabled; DB_URI and
    # MONGO_URI are themselves only defined in commented-out code.
    # Initialize SQL agent (static)
    # init_sql_agent_from_uri(DB_URI)
    # # Initialize Mongo agent (static)
    # init_mongo_agent_from_uri(MONGO_URI, database_name='shopdb')
    # Listen on all interfaces, port 7860, using the Werkzeug server
    # (allow_unsafe_werkzeug=True is passed to permit that).
    socketio.run(app, host="0.0.0.0", port=7860, allow_unsafe_werkzeug=True)