| from fastapi import FastAPI, HTTPException
|
| from pydantic import BaseModel
|
| import sqlite3
|
| import pandas as pd
|
| import os
|
| from dotenv import load_dotenv
|
| import google.generativeai as genai
|
|
|
# Application object that the route decorator further down attaches to.
app = FastAPI()

# Load variables from a local .env file into the process environment, then
# hand the Gemini client its API key (None when GOOGLE_API_KEY is unset).
load_dotenv()
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
|
|
|
class Query(BaseModel):
    """Request body accepted by the ``/query`` endpoint."""

    # Natural-language question that is forwarded to the Gemini model.
    question: str
    # Backend selector: the exact value "SQL Database" selects the SQLite
    # path; any other value is handled as a CSV/pandas query.
    data_source: str
|
|
|
def get_gemini_response(question, prompt, model_name='gemini-pro'):
    """Ask a Gemini model to turn *question* into a query string.

    Args:
        question: The user's natural-language question.
        prompt: Instruction text describing the data source and the
            expected output format; it is sent ahead of the question.
        model_name: Name of the generative model to instantiate. Defaults
            to ``'gemini-pro'``, the value that used to be hard-coded, so
            existing two-argument calls behave as before.

    Returns:
        The ``text`` attribute of the model's response, unmodified.
    """
    model = genai.GenerativeModel(model_name)
    # Instructions first, question second: both travel in a single request.
    response = model.generate_content([prompt, question])
    return response.text
|
|
|
def get_csv_columns(csv_path='employee.csv'):
    """Return the column labels of a CSV file, in file order.

    Args:
        csv_path: Path of the CSV file to inspect. Defaults to
            ``'employee.csv'``, the file this function has always read.

    Returns:
        A list with one entry per column of the CSV header.
    """
    # nrows=0 parses only the header row, so the labels are obtained without
    # loading every data row into memory.
    header_only = pd.read_csv(csv_path, nrows=0)
    return header_only.columns.tolist()
|
|
|
# Read once at import time: the list is interpolated into csv_prompt and is
# returned alongside CSV results by the /query endpoint.
csv_columns = get_csv_columns()
|
|
|
# Instructions sent to the model ahead of the user's question when the SQL
# data source is selected; the model's reply is executed against SQLite as-is.
# NOTE(review): the text forbids ''' delimiters; if the intent was to forbid
# Markdown code fences (```), the wording should be confirmed and adjusted.
# NOTE(review): the second example quotes a string literal with double quotes,
# which SQLite only accepts as a legacy fallback; confirm this is intended.
sql_prompt = """
You are an expert in converting English questions to SQL code!
The SQL database has the name STUDENT and has the following Columns - NAME, CLASS, SECTION

For example:
- How many entries of records are present? SQL command: SELECT COUNT(*) FROM STUDENT;
- Tell me all the students studying in Data Science class? SQL command: SELECT * FROM STUDENT where CLASS="Data Science";

Also, the SQL code should not have ''' in the beginning or at the end, and SQL word in output.
Ensure that you only generate valid SQL queries, not pandas or Python code.
"""
|
|
|
# Instructions sent to the model ahead of the user's question when the CSV
# data source is selected. The column list is baked in when this module is
# imported, so it reflects employee.csv as it was at start-up.
# NOTE(review): the examples mention 'Department', 'ID' and a `specific_id`
# placeholder; verify these match the real CSV header, since the model's reply
# is evaluated with only `df` and `pd` in scope.
csv_prompt = f"""
You are an expert in analyzing CSV data and converting English questions to pandas query syntax.
The CSV file is named 'employee.csv' and contains employee information.
The available columns in the CSV file are: {', '.join(csv_columns)}

For example:
- How many employees are there? Pandas query: len(df)
- List all employees in the Sales department. Pandas query: df[df['Department'] == 'Sales']
- Show employees with a specific ID. Pandas query: df[df['ID'] == specific_id]

Provide only the pandas query syntax without any additional explanation or markdown formatting.
Do not include 'df = ' or any variable assignment in your response.
Make sure to use only the columns that are available in the CSV file.
Ensure that you only generate valid pandas queries, not SQL or other types of code.
"""
|
|
|
def execute_sql_query(query, db_path='student.db'):
    """Run a single SQL statement against the SQLite database, read-only.

    The statement comes from a language model driven by end-user input, so
    the database is opened in read-only mode: statements that would modify
    data or schema (INSERT, UPDATE, DELETE, DROP, ...) fail instead of
    taking effect.

    Args:
        query: SQL text to execute.
        db_path: Path of the SQLite database file. Defaults to
            ``'student.db'``, the file this function has always used.

    Returns:
        All result rows, as returned by ``fetchall()``.

    Raises:
        HTTPException: Status 400 when the database cannot be opened or
            SQLite rejects the statement.
    """
    # Function-scope import: pathlib is only needed to build the URI.
    from pathlib import Path

    # as_uri() needs an absolute path and percent-encodes special characters;
    # "mode=ro" makes SQLite open the file without write access and without
    # creating it when it is missing.
    read_only_uri = Path(db_path).resolve().as_uri() + "?mode=ro"

    try:
        conn = sqlite3.connect(read_only_uri, uri=True)
    except sqlite3.Error as e:
        raise HTTPException(status_code=400, detail=f"SQL Error: {str(e)}") from e

    try:
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; catching it too
        # keeps any rejection by the driver on the same 400 path.
        raise HTTPException(status_code=400, detail=f"SQL Error: {str(e)}") from e
    finally:
        conn.close()
|
|
|
def _to_json_safe(result):
    """Convert a pandas/NumPy result into plain Python values for JSON."""
    if isinstance(result, (pd.DataFrame, pd.Series)):
        # Missing values become None (JSON null); NaN is not valid JSON.
        # Casting to object first lets None survive in numeric columns.
        cleaned = result.astype(object).where(result.notna(), None)
        if isinstance(result, pd.DataFrame):
            return cleaned.to_dict(orient='records')
        return cleaned.to_dict()
    if hasattr(result, 'tolist'):
        # NumPy scalars and arrays, and pandas Index objects: tolist()
        # yields built-in Python scalars / lists.
        return result.tolist()
    return result


def execute_pandas_query(query, csv_path='employee.csv'):
    """Evaluate a pandas expression against the CSV data and return the result.

    Args:
        query: A Python expression that may reference ``df`` (the loaded
            DataFrame) and ``pd`` (the pandas module).
        csv_path: Path of the CSV file to load. Defaults to
            ``'employee.csv'``, the file this function has always read.

    Returns:
        A list of row dicts for a DataFrame result, a dict for a Series
        result, a plain Python scalar or list for NumPy scalars/arrays, and
        the value unchanged otherwise. Missing values in DataFrame/Series
        results are returned as ``None``.

    Raises:
        HTTPException: Status 400 when evaluating or converting fails.
    """
    df = pd.read_csv(csv_path)
    try:
        # SECURITY: `query` is model output derived from end-user input and
        # eval() runs it as arbitrary Python. Builtins stay available (the
        # prompt's own example relies on len()), and `pd` alone is enough to
        # reach the filesystem, so this is NOT a sandbox. It should be
        # replaced by a restricted expression evaluator before this endpoint
        # is exposed to untrusted callers.
        result = eval(query, {'df': df, 'pd': pd})
        return _to_json_safe(result)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Pandas Error: {str(e)}") from e
|
|
|
@app.post("/query")
async def process_query(query: Query):
    """Translate a natural-language question into a query and run it.

    When ``query.data_source`` is exactly ``"SQL Database"`` the question is
    turned into SQL and executed against SQLite; any other value is handled
    as a pandas expression over the CSV data.

    Returns:
        A dict with the generated ``query`` text and its ``result``; CSV
        responses additionally carry the CSV's ``columns``.

    Raises:
        HTTPException: Status 400 when executing the generated query fails.
    """
    # The model call, SQLite access and pandas evaluation are all blocking.
    # Running them directly inside this coroutine would stall the event loop
    # for every other request, so each one is pushed to the worker threadpool.
    from fastapi.concurrency import run_in_threadpool

    if query.data_source == "SQL Database":
        ai_response = await run_in_threadpool(
            get_gemini_response, query.question, sql_prompt
        )
        try:
            result = await run_in_threadpool(execute_sql_query, ai_response)
        except HTTPException as e:
            raise HTTPException(
                status_code=400, detail=f"Error in SQL query: {e.detail}"
            ) from e
        return {"query": ai_response, "result": result}

    ai_response = await run_in_threadpool(
        get_gemini_response, query.question, csv_prompt
    )
    try:
        result = await run_in_threadpool(execute_pandas_query, ai_response)
    except HTTPException as e:
        raise HTTPException(
            status_code=400, detail=f"Error in pandas query: {e.detail}"
        ) from e
    return {"query": ai_response, "result": result, "columns": csv_columns}