# Standard library
import ast
import logging
import os
import re
import threading
import traceback
import urllib
from urllib.parse import urlparse

# Third-party
from bson import json_util
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template
from flask_socketio import SocketIO, emit
from fuzzywuzzy import fuzz
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
from langchain.tools import Tool
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from pymongo import MongoClient
from tabulate import tabulate
from werkzeug.exceptions import HTTPException

# --------------------------
# BASIC CONFIG
# --------------------------
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Static SQL Server ODBC connection string (hard-coded fallbacks as requested).
# Every part can be overridden through the environment variables named below.
# NOTE(review): the fallback credentials are committed in source; prefer setting
# DB_USER / DB_PASS in the environment (.env) instead.
ODBC_CONN = (
    "DRIVER={ODBC Driver 17 for SQL Server};"
    f"SERVER={os.getenv('DB_SERVER','192.168.1.37')},"
    f"{os.getenv('DB_PORT','1433')};"
    f"DATABASE={os.getenv('DB_NAME','TunisSyncV1')};"
    f"UID={os.getenv('DB_USER','sa')};"
    f"PWD={os.getenv('DB_PASS','sa123')}"
)

# To hand ODBC_CONN to SQLAlchemy/LangChain, URL-encode it into an odbc_connect URI:
# params = urllib.parse.quote_plus(ODBC_CONN)
# DB_URI = f"mssql+pyodbc:///?odbc_connect={params}"

# Static MongoDB URI: read it from the environment; never commit credentials here.
# NOTE(review): an earlier revision of this comment contained a MongoDB Atlas
# username/password in plaintext; that password should be rotated.
# MONGO_URI = os.getenv('MONGO_URI')

# --------------------------
# Flask + SocketIO + LLM
# --------------------------
app = Flask(__name__)
# A stable SECRET_KEY from the environment keeps signed sessions valid across
# restarts; without one, fall back to a fresh random key per process.
app.config['SECRET_KEY'] = os.getenv('SECRET_KEY') or os.urandom(32)
app.config['UPLOAD_FOLDER'] = 'uploads'
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
socketio = SocketIO(app, cors_allowed_origins='*')

llm = ChatGoogleGenerativeAI(
    temperature=0.2,
    model="gemini-2.0-flash",
    max_retries=50,
    api_key=os.getenv('GEMINI_API_KEY'),
)

# Alternative Groq-hosted model (swap in to replace Gemini):
# llm = ChatGroq(
#     # model="meta-llama/llama-4-scout-17b-16e-instruct",
#     # model="deepseek-r1-distill-llama-70b",
#     model="meta-llama/llama-4-maverick-17b-128e-instruct",
#     # model="openai/gpt-oss-120b",
#     temperature=0,
#     max_tokens=None,
#     max_retries=50,
#     api_key=os.getenv('GROQ_API_KEY'),
# )

# Globals.
# NOTE(review): this is process-wide state; every client shares the same
# active agent and the same conversation memory.
agent_executor = None  # set by init_sql_agent_from_uri / init_mongo_agent_from_uri
memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, input_key='input')
mongo_client = None
mongo_db = None
db_mode = None  # 'sql' or 'mongo'


# --------------------------
# Helpers / Safety checks
# --------------------------
def error_safe(f):
    """Wrap a Flask view so uncaught errors become JSON error responses.

    ``HTTPException``s keep their own status code and description; any other
    exception is logged with its traceback and reported as a 500 whose message
    is ``str(exc)``.
    """
    import functools

    @functools.wraps(f)  # preserves __name__, which Flask uses as the endpoint name
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except HTTPException as he:
            return jsonify({"status": "error", "message": he.description}), he.code
        except Exception as e:
            logger.exception('Uncaught exception in %s', f.__name__)
            return jsonify({"status": "error", "message": str(e)}), 500

    return wrapper


# Disabled prompt filters; re-enable together with the commented checks in generate().
# def is_schema_request(prompt: str) -> bool:
#     pattern = re.compile(r'\b(schema|table names|tables|columns|structure|column names|collections?|field names|metadata|describe|show)\b', re.IGNORECASE)
#     return bool(pattern.search(prompt))
#
# def is_sensitive_request(prompt: str) -> bool:
#     sensitive_keywords = [
#         "password", "token", "credential", "secret", "api key", "schema", "structure",
#         "collection name", "field name", "user_id", "order_id", "payment_id",
#         "internal", "database structure", "table structure", "email", "phone", "contact", "ssn"
#     ]
#     lowered = prompt.lower()
#     return any(keyword in lowered for keyword in sensitive_keywords)
#
# intent_prompt = ChatPromptTemplate.from_messages([
#     ("system", "Classify if the user is asking schema/structure/sensitive info (tables, columns, schema): YES or NO."),
#     ("human", "{prompt}")
# ])
# try:
#     intent_checker = intent_prompt | llm
# except Exception:
#     intent_checker = None
#
# def is_schema_leak_request(prompt):
#     if intent_checker is None:
#         return False
#     try:
#         classification = intent_checker.invoke({"prompt": prompt})
#         if hasattr(classification, 'content'):
#             text = classification.content
#         elif hasattr(classification, 'text'):
#             text = classification.text
#         else:
#             text = str(classification)
#         return 'yes' in text.strip().lower()
#     except Exception as e:
#         logger.warning('Schema intent classifier failed: %s', e)
#         return False


# --------------------------
# SQL agent initialization
# --------------------------
def init_sql_agent_from_uri(sql_uri: str, dialect: str = None):
    """Build a LangChain SQL agent for ``sql_uri`` and make it the active agent.

    Args:
        sql_uri: SQLAlchemy database URI.
        dialect: Optional dialect hint supplied by the caller (``connect_db``
            passes "postgresql", "mysql" or "sqlite"). It is only logged; the
            connection itself is built from ``sql_uri``.

    Raises:
        Exception: any connection or agent-construction error is logged and
            re-raised so the caller can report the failure. The previously
            active agent is left in place in that case.
    """
    global agent_executor, db_mode
    try:
        sql_db = SQLDatabase.from_uri(sql_uri)
        toolkit = SQLDatabaseToolkit(db=sql_db, llm=llm)
        prefix = '''You are a helpful SQL expert agent that ALWAYS returns natural language answers using the tools.
Always format your responses in Markdown. For example:
- Use bullet points
- Use bold for headers
- Wrap code in triple backticks
- Tables should use Markdown table syntax

You must NEVER:
- Show or mention SQL syntax.
- Reveal table names, column names, or database schema.
- Respond with any technical details or structure of the database.
- Return code or tool names.
- Give wrong Answers.
'''
        agent = create_sql_agent(
            llm=llm,
            toolkit=toolkit,
            verbose=True,
            prefix=prefix,
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            # NOTE(review): create_sql_agent forwards unknown kwargs to the agent,
            # not the executor; memory may need to move into agent_executor_kwargs
            # to take effect. Verify against the installed langchain version.
            memory=memory,
            agent_executor_kwargs={"handle_parsing_errors": True},
        )
    except Exception:
        logger.exception('Failed to initialize SQL agent')
        raise

    agent_executor = agent
    db_mode = 'sql'
    logger.info('SQL agent initialized using URI (dialect hint: %s)', dialect or 'none')


# --------------------------
# Mongo agent initialization
# --------------------------
def _parse_kv_args(text: str) -> dict:
    """Parse a ``key=value, key=value`` tool argument string into a dict.

    Keys and values are stripped of surrounding whitespace (LLMs usually put a
    space after each comma). Segments without ``=`` are ignored, and values
    cannot contain commas.
    """
    pairs = {}
    for segment in str(text).split(','):
        name, sep, value = segment.partition('=')
        if sep:
            pairs[name.strip()] = value.strip()
    return pairs


def _markdown_table(rows) -> str:
    """Render a list of dicts as a GitHub-style table in the tools' reply format."""
    return "\n markdown\n" + tabulate(rows, headers='keys', tablefmt='github') + "\n"


def find_docs_tool_func(query: str) -> str:
    """Search a collection by ``collection=..., key=..., value=...``.

    - key and value: exact match on that field (value is compared as a string).
    - value only: documents where any field equals value (case-insensitive).
    - neither: every document in the collection.

    When a value was given and the named collection has no hit, the other
    collections are searched too, in case the agent guessed the wrong
    collection name. Returns a Markdown table or a short message; errors are
    returned as text rather than raised, since the agent reads the tool output.
    """
    try:
        parts = _parse_kv_args(query)
        collection = parts.get('collection')
        key = parts.get('key')
        value = parts.get('value')
        if not collection:
            return "❌ 'collection' is required."

        def query_collection(coll_name):
            coll = mongo_db[coll_name]
            if key and value:
                return list(coll.find({key: value}, {'_id': 0}))
            if value:
                needle = value.lower()
                return [doc for doc in coll.find({}, {'_id': 0})
                        if any(str(v).lower() == needle for v in doc.values())]
            return list(coll.find({}, {'_id': 0}))

        docs = query_collection(collection)
        if docs:
            return _markdown_table(docs)
        # Without a value there is nothing to search for elsewhere; falling back
        # would just dump an unrelated collection.
        if not value:
            return "**No documents found.**"
        for coll in mongo_db.list_collection_names():
            if coll == collection:
                continue
            docs = query_collection(coll)
            if docs:
                return _markdown_table(docs)
        return "**No documents found.**"
    except Exception as e:
        return f"Invalid input format or error: {str(e)}"


def aggregate_group_by(_input: str):
    """Count documents per distinct value of a field.

    Accepts either ``collection=..., field=...`` or a Python dict literal with
    ``collection``/``collection_name`` and ``field``/``group_by`` keys.
    Returns a Markdown table of ``<field>`` / ``count`` rows or an error message.
    """
    try:
        text = _input.strip()
        if text.startswith('{'):
            # ast.literal_eval only evaluates literals, so LLM input cannot run code.
            args = ast.literal_eval(text)
            collection = args.get('collection_name') or args.get('collection')
            field = args.get('group_by') or args.get('field')
        else:
            args = _parse_kv_args(text)
            collection = args['collection']
            field = args['field']
        if not collection or not field:
            return "Aggregation failed: both a collection and a field are required."
        field = str(field).lstrip('$')  # the pipeline adds the '$' itself

        pipeline = [
            {'$group': {'_id': f"${field}", 'count': {'$sum': 1}}},
            {'$project': {'_id': 0, field: '$_id', 'count': 1}},
        ]
        result = list(mongo_db[collection].aggregate(pipeline))
        if not result:
            return "**No data found.**"
        return _markdown_table(result)
    except Exception as e:
        return f"Aggregation failed: {e}"


def get_all_documents(collection: str):
    """Return every document of ``collection`` (without ``_id``) as a Markdown table."""
    try:
        # The agent passes raw tool input; tolerate surrounding whitespace/quotes.
        name = str(collection).strip().strip('\'"')
        docs = list(mongo_db[name].find({}, {'_id': 0}))
        if not docs:
            return "**No documents found.**"
        return _markdown_table(docs)
    except Exception as e:
        return f"Error fetching documents: {e}"


def fuzzy_find_documents(query: str):
    """Fuzzy-search ``collection`` for ``value`` across all fields of each document.

    Input: ``collection=..., value=..., threshold=...`` (threshold 0-100,
    default 80, compared against ``fuzz.partial_ratio`` on lower-cased text).
    """
    try:
        parts = _parse_kv_args(query)
        collection = parts['collection']
        needle = parts['value'].lower()
        threshold = int(parts.get('threshold', 80))

        matches = [
            doc for doc in mongo_db[collection].find({}, {'_id': 0})
            if any(fuzz.partial_ratio(str(v).lower(), needle) >= threshold for v in doc.values())
        ]
        if not matches:
            return "**No fuzzy matches found.**"
        return _markdown_table(matches)
    except Exception as e:
        return f"Fuzzy match error: {e}"


def join_collections_tool_func(_input: str):
    """Join ``from`` to ``to`` (``from.key == to.match``) and report ``return``.

    Optional ``next_key``/``next_to``/``next_match`` add a second hop whose
    ``return`` field (or 'Unknown') is written onto each joined row.
    Returns a one-column Markdown table of the ``return`` field.
    """
    try:
        args = _parse_kv_args(_input)
        from_coll = args['from']
        key = args['key']
        to_coll = args['to']
        match = args['match']
        return_field = args['return']
        next_key = args.get('next_key')
        next_to = args.get('next_to')
        next_match = args.get('next_match')

        # Index the target collection by its match field for O(1) lookups.
        to_docs = {doc[match]: doc for doc in mongo_db[to_coll].find() if match in doc}
        joined = []
        for doc in mongo_db[from_coll].find({}, {'_id': 0}):
            foreign_doc = to_docs.get(doc.get(key))
            if foreign_doc:
                joined.append({**doc, **foreign_doc})

        if next_key and next_to and next_match:
            next_docs = {doc[next_match]: doc for doc in mongo_db[next_to].find() if next_match in doc}
            for doc in joined:
                user_doc = next_docs.get(doc.get(next_key))
                doc[return_field] = user_doc.get(return_field, 'Unknown') if user_doc else 'Unknown'

        final = [{return_field: doc.get(return_field)} for doc in joined if return_field in doc]
        if not final:
            return "**No documents found.**"
        return _markdown_table(final)
    except Exception as e:
        return f"Join failed: {e}"


def smart_join_router(prompt: str) -> str:
    """Suggest a JoinCollections argument string from keywords in a question.

    Only two hard-coded join paths are known (Payments->Orders->Users and
    Orders->Users); anything else gets a "provide more context" message.
    """
    text = prompt.lower()
    if 'payment' in text and any(term in text for term in ['who', 'name', 'user', 'person']):
        return 'from=Payments, key=order_id, to=Orders, match=order_id, next_key=user_id, next_to=Users, next_match=user_id, return=name'
    if 'order' in text and 'name' in text:
        return 'from=Orders, key=user_id, to=Users, match=user_id, return=name'
    return 'Unable to auto-generate join path. Please provide more context.'


def init_mongo_agent_from_uri(mongo_uri: str, database_name: str = None):
    """Connect to MongoDB and make a tool-using conversational agent the active agent.

    Args:
        mongo_uri: ``mongodb://`` or ``mongodb+srv://`` connection string.
        database_name: Database to use; defaults to the URI path, then "test".

    Raises:
        Exception: any connection or agent-construction error is logged and
            re-raised so the caller can report the failure. On failure the new
            client is closed and the previously active client/agent are kept.
    """
    global mongo_client, mongo_db, agent_executor, db_mode
    client = None
    try:
        client = MongoClient(mongo_uri, serverSelectionTimeoutMS=5000)
        # Ping so a bad host or bad credentials fail here rather than on first query.
        client.admin.command('ping')
        if database_name is None:
            database_name = urlparse(mongo_uri).path.lstrip('/') or "test"  # fallback if URI has no db

        # Tool descriptions are embedded in a PromptTemplate, so they must not contain braces.
        tools = [
            Tool(name='FindDocuments', func=find_docs_tool_func,
                 description='Flexible MongoDB search. Input: "collection=<name>, key=<field>, value=<value>". '
                             'key and value are optional; value alone matches any field exactly.'),
            Tool(name='ListCollections', func=lambda _: mongo_db.list_collection_names(),
                 description='List all collections in the database. Input is ignored.'),
            Tool(name='AggregateGroupBy', func=aggregate_group_by,
                 description='Group and count by any field. Input: "collection=<name>, field=<field>".'),
            Tool(name='GetAllDocuments', func=get_all_documents,
                 description='Fetch all documents from a collection. Input: the collection name.'),
            Tool(name='FuzzyFindDocuments', func=fuzzy_find_documents,
                 description='Fuzzy match documents across all fields of one collection. '
                             'Input: "collection=<name>, value=<text>, threshold=<0-100, default 80>".'),
            Tool(name='JoinCollections', func=join_collections_tool_func,
                 description='Join related collections to return names instead of IDs. Input: '
                             '"from=<coll>, key=<field>, to=<coll>, match=<field>, return=<field>", optionally '
                             'followed by "next_key=<field>, next_to=<coll>, next_match=<field>" for a second hop.'),
            Tool(name='SmartJoinCollections', func=smart_join_router,
                 description='Suggest a JoinCollections input for a question about payments/orders and user names. '
                             'Input: the question.'),
        ]

        prefix = """
You are MongoDBQueryBot, an intelligent assistant for interacting with a MongoDB database.
You have read-only access to the database and can answer questions using the provided tools.
"""
        # Longer prompt kept for reference (currently disabled):
        # Guidelines for all queries:
        # 1. Always answer in clear, natural language. Use Markdown formatting, bullet points, and tables when helpful.
        # 2. Explain the content of collections and fields based on the summary.
        # 3. If asked about the purpose or meaning of the database, synthesize a complete description from collections and sample data.
        # 4. For factual questions, query the database using the available tools:
        #    - FindDocuments: Flexible search by key/value
        #    - AggregateGroupBy: Summarize counts by fields
        #    - FuzzyFindDocuments: Approximate text search
        #    - GetAllDocuments: Retrieve all documents from a collection
        #    - JoinCollections / SmartJoinCollections: Combine related collections for meaningful answers
        # 5. NEVER expose raw database connection info, credentials, or sensitive information.
        # 6. NEVER provide raw schema details unless explicitly requested.
        # 7. If the user query is vague, ambiguous, or general, make the **best effort explanation** using collection names, field names, and sample documents.
        # 8. When presenting query results, format them as human-readable tables or bullet lists.
        # 9. When a user asks a question you cannot answer confidently, politely explain that the answer may be limited.
        # Examples:
        # - User: "What is this database about?"
        #   Assistant: "This database contains Users, Orders, and Payments collections. It stores e-commerce information including user profiles, order histories, and payment records."
        # - User: "Show me all orders for user John Doe"
        #   Assistant: Use FindDocuments or JoinCollections to fetch relevant results, and present in a table format.
        # - User: "How many users registered this month?"
        #   Assistant: Use AggregateGroupBy and summarize results in a clear sentence.

        # initialize_agent takes the agent type as `agent` and prompt options via
        # `agent_kwargs`; other keyword arguments go to the AgentExecutor.
        agent = initialize_agent(
            tools=tools,
            llm=llm,
            agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
            agent_kwargs={'prefix': prefix},
            memory=memory,
            verbose=True,
            handle_parsing_errors=True,
        )
    except Exception:
        logger.exception('Failed to initialize Mongo agent')
        if client is not None:
            client.close()
        raise

    # NOTE(review): a previously active MongoClient is not closed here because an
    # in-flight request may still be using it.
    mongo_client = client
    mongo_db = client[database_name]
    agent_executor = agent
    db_mode = 'mongo'
    logger.info('Connected to MongoDB database: %s', database_name)
    logger.info('Mongo agent initialized')


# --------------------------
# Routes
# --------------------------
@app.route('/')
def index():
    """Serve the chat UI."""
    return render_template('index_db_json.html')


# Upload endpoint intentionally disabled (dynamic upload removed).
@app.route('/upload_db', methods=['POST'])
@error_safe
def upload_db():
    """Reject uploads: database files cannot be uploaded to this server."""
    return jsonify({'success': False, 'message': 'Dynamic DB upload is disabled. This server uses static configured DB URIs.'}), 403


# Base URI scheme (the part before any "+driver" suffix) -> (dialect hint, display name).
_SQL_DIALECTS = {
    'postgresql': ('postgresql', 'PostgreSQL'),
    'mysql': ('mysql', 'MySQL'),
    'sqlite': ('sqlite', 'SQLite'),
}


@app.route('/connect_db', methods=['POST'])
@error_safe
def connect_db():
    """Make the agent for the JSON body's ``uri`` the active agent.

    MongoDB URIs build the Mongo tool agent. postgresql/mysql/sqlite URIs (with
    or without a ``+driver`` suffix) build the SQL agent with a dialect hint.
    Anything else (e.g. SQL Server) is handed to the SQL agent as-is.
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        data = {}
    uri = str(data.get('uri') or '').strip()
    if not uri:
        return jsonify({"success": False, "message": "❌ No connection string provided."}), 400

    scheme = uri.split('://', 1)[0].lower() if '://' in uri else ''
    base_scheme = scheme.split('+', 1)[0]  # "mongodb+srv" -> "mongodb", "mysql+pymysql" -> "mysql"
    try:
        if base_scheme == 'mongodb':
            init_mongo_agent_from_uri(uri)
            return jsonify({"success": True, "message": "✅ Connected to MongoDB agent."})
        if base_scheme in _SQL_DIALECTS:
            dialect, label = _SQL_DIALECTS[base_scheme]
            init_sql_agent_from_uri(uri, dialect=dialect)
            return jsonify({"success": True, "message": f"✅ Connected to {label} agent."})
        # SQL Server or any other SQLAlchemy URI.
        init_sql_agent_from_uri(uri)
        return jsonify({"success": True, "message": "✅ Connected to SQL agent."})
    except Exception as e:
        logger.error("Failed to connect DB: %s", e)
        return jsonify({"success": False, "message": f"❌ Connection failed: {e}"}), 500


# Result keys checked, in order, when normalizing an agent's invoke() output.
_ANSWER_KEYS = ('final_answer', 'final', 'output', 'answer', 'text')


def _extract_final_answer(result) -> str:
    """Return the first non-empty answer value from an agent result, as a string."""
    if isinstance(result, dict):
        for key in _ANSWER_KEYS:
            value = result.get(key)
            if value:
                return str(value)
        return ""
    return str(result or "")


def _emit_final(message: str) -> None:
    """Best-effort push of a final message to Socket.IO listeners.

    NOTE(review): socketio.emit without ``to=`` reaches every connected client,
    not only the one that made the request.
    """
    try:
        socketio.emit('final', {'message': message})
    except Exception:
        app.logger.debug("socket emit failed, continuing")


@app.route('/generate', methods=['POST'])
@error_safe
def generate():
    """Run the active agent on the JSON body's ``prompt`` and return its answer.

    The answer (or a user-facing error message) is also emitted on the
    Socket.IO ``final`` event.
    """
    try:
        data = request.get_json(force=True)
        prompt = data.get('prompt', '').strip()
        if not prompt:
            return jsonify({'status': 'error', 'message': 'Prompt is required'}), 400

        # Optional safety checks (disabled; re-enable with the helpers near error_safe):
        # if is_schema_leak_request(prompt) or is_schema_request(prompt):
        #     msg = "⛔ Sorry, you're not allowed to access structure/schema information."
        #     socketio.emit('final', {'message': msg})
        #     return jsonify({'status': 'blocked', 'message': msg}), 403
        #
        # if is_sensitive_request(prompt):
        #     msg = "⛔ This query may involve sensitive or protected information. Please rephrase your question."
        #     socketio.emit('final', {'message': msg})
        #     return jsonify({'status': 'blocked', 'message': msg}), 403
    except Exception:
        logger.exception('Invalid /generate payload')
        return jsonify({'status': 'error', 'message': 'Invalid input'}), 400

    if agent_executor is None:
        msg = '❌ No database connected. Connect a database via /connect_db first.'
        _emit_final(msg)
        return jsonify({'status': 'error', 'prompt': prompt, 'final_answer': '', 'message': msg}), 400

    try:
        # Run the agent synchronously.
        result = agent_executor.invoke({'input': prompt})
    except Exception as e:
        app.logger.exception("Agent invocation failed")
        err_text = str(e)
        # Friendly message for provider rate limiting (HTTP 429 / quota exhausted).
        if '429' in err_text and 'quota' in err_text.lower():
            user_msg = '🚦 Agent is busy, try again after sometime.'
        else:
            user_msg = f'Agent error: {err_text[:500]}'
        _emit_final(user_msg)
        return jsonify({'status': 'error', 'prompt': prompt, 'final_answer': '', 'message': user_msg}), 500

    final_answer = _extract_final_answer(result).strip()
    if not final_answer:
        final_answer = "Please, rephrase your query because I can't exactly understand, what you want !"
    _emit_final(final_answer)
    return jsonify({'status': 'ok', 'prompt': prompt, 'final_answer': final_answer}), 200


# --------------------------
# Error handlers
# --------------------------
@app.errorhandler(Exception)
def handle_all_errors(e):
    """Last-resort JSON error handler.

    Flask also routes HTTP errors (404, 405, ...) here because HTTPException
    subclasses Exception; those keep their own status code instead of
    becoming a 500.
    """
    if isinstance(e, HTTPException):
        return jsonify({'status': 'error', 'message': e.description}), e.code or 500
    logger.exception('Global handler caught an exception: %s', e)
    return jsonify({'status': 'error', 'message': 'An unexpected error occurred'}), 500


from werkzeug.exceptions import TooManyRequests


@app.errorhandler(TooManyRequests)
def handle_429_error(e):
    """Friendly JSON body for 429 responses (more specific than the handler above)."""
    return jsonify({
        'status': 'error',
        'message': '🚦 Agent is busy, try again after sometime.'
    }), 429


# --------------------------
# Startup: agents are created on demand via /connect_db
# --------------------------
if __name__ == '__main__':
    # Static initialization (disabled):
    # init_sql_agent_from_uri(DB_URI)
    # init_mongo_agent_from_uri(MONGO_URI, database_name='shopdb')
    socketio.run(app, host="0.0.0.0", port=7860, allow_unsafe_werkzeug=True)