| | import requests |
| | from langchain_community.utilities import SQLDatabase |
| | from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool |
| | from sqlalchemy import ( |
| | create_engine, |
| | MetaData, |
| | inspect, |
| | Table, |
| | select, |
| | distinct |
| | ) |
| | from sqlalchemy.schema import CreateTable |
| | from sqlalchemy.exc import ProgrammingError |
| | from sqlalchemy.engine import Engine |
| | import re |
| |
|
| | def get_all_groq_model(api_key:str=None) -> list: |
| | """Uses Groq API to fetch all the available models.""" |
| | if api_key is None: |
| | raise ValueError("API key is required") |
| | url = "https://api.groq.com/openai/v1/models" |
| |
|
| | headers = { |
| | "Authorization": f"Bearer {api_key}", |
| | "Content-Type": "application/json" |
| | } |
| |
|
| | response = requests.get(url, headers=headers) |
| |
|
| | data = response.json()['data'] |
| | model_ids = [model['id'] for model in data] |
| |
|
| | return model_ids |
| |
|
| | def validate_api_key(api_key:str) -> bool: |
| | """Validates the Groq API key using the get_all_groq_model function.""" |
| | if len(api_key) == 0: |
| | return False |
| | try: |
| | get_all_groq_model(api_key=api_key) |
| | return True |
| | except Exception as e: |
| | return False |
| |
|
| | def validate_uri(uri:str) -> bool: |
| | """Validates the SQL Database URI using the SQLDatabase.from_uri function.""" |
| | try: |
| | SQLDatabase.from_uri(uri) |
| | return True |
| | except Exception as e: |
| | return False |
| |
|
| | def get_info(uri:str) -> dict[str, str] | None: |
| | """Gets the dialect name, accessible tables and table schemas using the SQLDatabase toolkit""" |
| | db = SQLDatabase.from_uri(uri) |
| | dialect = db.dialect |
| | |
| | access_tables = ListSQLDatabaseTool(db=db).invoke("") |
| | |
| | tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables) |
| | return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas} |
| |
|
| | def get_sample_rows(engine:Engine, table:Table, row_count: int = 3) -> str: |
| | """Gets the sample rows of a table using the SQLAlchemy engine""" |
| | |
| | command = select(table).limit(row_count) |
| |
|
| | |
| | columns_str = "\t".join([col.name for col in table.columns]) |
| |
|
| | try: |
| | |
| | with engine.connect() as connection: |
| | sample_rows_result = connection.execute(command) |
| | |
| | sample_rows = list( |
| | map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) |
| | ) |
| |
|
| | |
| | sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) |
| |
|
| | |
| | |
| | except ProgrammingError: |
| | sample_rows_str = "" |
| |
|
| | return ( |
| | f"{row_count} rows from {table.name} table:\n" |
| | f"{columns_str}\n" |
| | f"{sample_rows_str}" |
| | ) |
| |
|
| | def get_unique_values(engine:Engine, table:Table) -> str: |
| | """Gets the unique values of each column in a table using the SQLAlchemy engine""" |
| | unique_values = {} |
| | for column in table.c: |
| | command = select(distinct(column)) |
| |
|
| | try: |
| | |
| | with engine.connect() as connection: |
| | result = connection.execute(command) |
| | |
| | unique_values[column.name] = [str(u) for u in result] |
| |
|
| | |
| | |
| | |
| | |
| | except ProgrammingError: |
| | sample_rows_str = "" |
| |
|
| | output_str = f"Unique values of each column in {table.name}: \n" |
| | for column, values in unique_values.items(): |
| | output_str += f"{column} has {len(values)} unique values: {' '.join(values[:20])}" |
| | if len(values) > 20: |
| | output_str += ", ...." |
| | output_str += "\n" |
| |
|
| | return output_str |
| |
|
| | def get_info_sqlalchemy(uri:str) -> dict[str, str] | None: |
| | """Gets the dialect name, accessible tables and table schemas using the SQLAlchemy engine""" |
| | engine = create_engine(uri) |
| | |
| | inspector = inspect(engine) |
| | dialect = inspector.dialect.name |
| | |
| | m = MetaData() |
| | m.reflect(engine) |
| |
|
| | tables = {} |
| | for table in m.tables.values(): |
| | tables[table.name] = str(CreateTable(table).compile(engine)).rstrip() |
| | tables[table.name] += "\n\n/*" |
| | tables[table.name] += "\n" + get_sample_rows(engine, table)+"\n" |
| | tables[table.name] += "\n" + get_unique_values(engine, table)+"\n" |
| | tables[table.name] += "*/" |
| |
|
| | return {'sql_dialect': dialect, 'tables': ", ".join(tables.keys()), 'tables_schema': "\n\n".join(tables.values())} |
| |
|
| | def extract_code_blocks(text): |
| | pattern = r"```(?:\w+)?\n(.*?)\n```" |
| | matches = re.findall(pattern, text, re.DOTALL) |
| | return matches |
| |
|
| | if __name__ == "__main__": |
| | from dotenv import load_dotenv |
| | import os |
| | load_dotenv() |
| |
|
| | uri = os.getenv("POSTGRES_URI") |
| | print(get_info_sqlalchemy(uri)) |
| |
|