Tabular Classification
Scikit-learn
Joblib
postgresql
sql
query-cache
plan-cache
redis
database
tabular-regression
Instructions to use nilenpatel/pg-plan-cache-models with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Scikit-learn
How to use nilenpatel/pg-plan-cache-models with Scikit-learn:
```python
from huggingface_hub import hf_hub_download
import joblib

model = joblib.load(
    hf_hub_download("nilenpatel/pg-plan-cache-models", "sklearn_model.joblib")
)
# only load pickle files from sources you trust
# read more about it here https://skops.readthedocs.io/en/stable/persistence.html
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Synthetic training data generator for pg_plan_cache models. | |
| Generates realistic SQL queries across a wide range of complexity levels | |
| with labels for cache benefit, recommended TTL, and complexity score. | |
| """ | |
| import random | |
| # --------------------------------------------------------------------------- | |
| # Building blocks | |
| # --------------------------------------------------------------------------- | |
# Column catalogue per table. The generators draw column names from here, and
# the key order also defines TABLES below.
COLUMNS = {
    "users": ["id", "name", "email", "created_at", "status", "age", "country"],
    "orders": ["id", "user_id", "total", "status", "created_at", "shipped_at"],
    "products": ["id", "name", "price", "category_id", "stock", "rating"],
    "payments": ["id", "order_id", "amount", "method", "paid_at", "status"],
    "sessions": ["id", "user_id", "started_at", "ended_at", "ip_address"],
    "logs": ["id", "level", "message", "created_at", "source"],
    "events": ["id", "type", "user_id", "data", "created_at"],
    "accounts": ["id", "owner_id", "balance", "currency", "opened_at"],
    "invoices": ["id", "account_id", "amount", "due_date", "status"],
    "shipments": ["id", "order_id", "carrier", "tracking", "shipped_at"],
    "categories": ["id", "name", "parent_id", "sort_order"],
    "reviews": ["id", "product_id", "user_id", "rating", "body", "created_at"],
    "inventory": ["id", "product_id", "warehouse_id", "quantity", "updated_at"],
    "notifications": ["id", "user_id", "type", "read", "created_at"],
    "messages": ["id", "sender_id", "receiver_id", "body", "sent_at"],
    "employees": ["id", "name", "department_id", "salary", "hired_at"],
    "departments": ["id", "name", "budget", "manager_id"],
    "projects": ["id", "name", "department_id", "deadline", "status"],
    "tasks": ["id", "project_id", "assignee_id", "title", "status", "due_date"],
    "comments": ["id", "task_id", "user_id", "body", "created_at"],
}

# All table names, in catalogue order. Deriving them from COLUMNS keeps the two
# collections from drifting apart.
TABLES = list(COLUMNS)

SCHEMAS = ["public", "app", "analytics", "billing"]

# SQL vocabulary used when assembling query text.
AGG_FUNCS = ["COUNT", "SUM", "AVG", "MIN", "MAX"]
COMPARISONS = ["=", ">", "<", ">=", "<=", "!="]
STRING_VALS = ["'active'", "'pending'", "'completed'", "'cancelled'", "'new'", "'shipped'"]
JOIN_TYPES = ["JOIN", "LEFT JOIN", "INNER JOIN", "RIGHT JOIN"]
# NOTE(review): not referenced by the generators visible in this file; the
# LAG/LEAD entries assume a table alias named `t`.
WINDOW_FUNCS = ["ROW_NUMBER()", "RANK()", "DENSE_RANK()", "LAG(t.id, 1)", "LEAD(t.id, 1)"]
def _rand_table():
    """Return one table name drawn uniformly from TABLES."""
    table = random.choice(TABLES)
    return table
def _rand_cols(table, n=None):
    """Return a random sample of column names for *table*.

    Args:
        table: table name; unknown tables fall back to ``["id", "name"]``.
        n: how many columns to pick. A falsy value (``None`` or ``0``) means
            "pick between 1 and min(4, available) at random".

    Returns:
        A list of distinct column names, never longer than the table's column list.
    """
    available = COLUMNS.get(table, ["id", "name"])
    if not n:
        n = random.randint(1, min(4, len(available)))
    how_many = min(n, len(available))
    return random.sample(available, how_many)
def _rand_where(alias="t"):
    """Build one random ``<alias>.<column> <op> <literal>`` predicate.

    The column is drawn from a fixed list, independent of any table. Its
    literal depends on the column: a quoted status string for ``status``, an
    integer in 1..10000 for ``id``/``age``/``amount``, and otherwise a quoted
    2024 date string.
    """
    column = random.choice(["id", "status", "created_at", "name", "amount", "age"])
    operator = random.choice(COMPARISONS)
    if column == "status":
        literal = random.choice(STRING_VALS)
    elif column in ("id", "age", "amount"):
        literal = str(random.randint(1, 10000))
    else:
        # Month is drawn before day, matching left-to-right evaluation.
        month = random.randint(1, 12)
        day = random.randint(1, 28)
        literal = f"'2024-{month:02d}-{day:02d}'"
    return f"{alias}.{column} {operator} {literal}"
| # --------------------------------------------------------------------------- | |
| # Query generators by complexity tier | |
| # --------------------------------------------------------------------------- | |
def _simple_select():
    """Tier 1: Simple SELECT with optional WHERE and optional LIMIT.

    Returns:
        ``(sql, "low", ttl, complexity)`` with ttl in [300, 900] and
        complexity in [5, 20].
    """
    t = _rand_table()
    alias = t[:1]
    cols = ", ".join(_rand_cols(t))
    # Declare the alias in FROM: the WHERE predicate qualifies columns with
    # it, and an undeclared alias is a "missing FROM-clause entry" error.
    sql = f"SELECT {cols} FROM {t} {alias}"
    if random.random() > 0.3:
        sql += f" WHERE {_rand_where(alias)}"
    if random.random() > 0.7:
        sql += f" LIMIT {random.choice([10, 20, 50, 100])}"
    return sql, "low", random.randint(300, 900), random.randint(5, 20)
def _select_with_order():
    """Tier 1.5: SELECT with WHERE, ORDER BY and LIMIT.

    Returns:
        ``(sql, "low", ttl, complexity)`` with ttl in [600, 1200] and
        complexity in [10, 25].
    """
    t = _rand_table()
    alias = t[:1]
    cols = ", ".join(_rand_cols(t))
    order_col = random.choice(COLUMNS.get(t, ["id"]))
    direction = random.choice(["ASC", "DESC"])
    # The alias is declared in FROM because the WHERE predicate is qualified with it.
    sql = (
        f"SELECT {cols} FROM {t} {alias} WHERE {_rand_where(alias)} "
        f"ORDER BY {order_col} {direction} LIMIT {random.choice([10, 25, 50])}"
    )
    return sql, "low", random.randint(600, 1200), random.randint(10, 25)
def _single_join():
    """Tier 2: two-table JOIN, with an optional WHERE on the left table.

    The right table is joined on ``b.<left table minus last char>_id = a.id``.

    Returns:
        ``(sql, "medium", ttl, complexity)`` with ttl in [1800, 3600] and
        complexity in [25, 45].
    """
    left, right = random.sample(TABLES, 2)
    left_cols = [f"a.{name}" for name in _rand_cols(left, 2)]
    right_cols = [f"b.{name}" for name in _rand_cols(right, 2)]
    join_kw = random.choice(JOIN_TYPES)
    projection = ", ".join(left_cols + right_cols)
    sql = f"SELECT {projection} FROM {left} a {join_kw} {right} b ON a.id = b.{left[:-1]}_id"
    if random.random() > 0.4:
        sql = f"{sql} WHERE {_rand_where('a')}"
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 45)
def _multi_join():
    """Tier 3: chain of 3-5 tables joined one after another.

    Tables are aliased ``a``, ``b``, ``c``, … in order. Each new table is
    joined to its predecessor on ``<prev>.id = <cur>.<prev table minus last
    char>_id``.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 7200] and
        complexity in [45, 70].
    """
    tables = random.sample(TABLES, random.randint(3, 5))
    aliases = [chr(ord("a") + pos) for pos in range(len(tables))]
    projection = ", ".join(
        f"{al}.{random.choice(COLUMNS.get(tbl, ['id']))}"
        for al, tbl in zip(aliases, tables)
    )
    clauses = [f"SELECT {projection} FROM {tables[0]} a"]
    for pos in range(1, len(tables)):
        parent, child = aliases[pos - 1], aliases[pos]
        join_kw = random.choice(JOIN_TYPES)
        fk_col = f"{tables[pos - 1][:-1]}_id"
        clauses.append(f"{join_kw} {tables[pos]} {child} ON {parent}.id = {child}.{fk_col}")
    if random.random() > 0.3:
        clauses.append(f"WHERE {_rand_where('a')}")
    if random.random() > 0.5:
        clauses.append(f"ORDER BY a.id LIMIT {random.choice([50, 100, 200])}")
    return " ".join(clauses), "high", random.randint(3600, 7200), random.randint(45, 70)
def _aggregate_query():
    """Tier 3: single-table aggregation with GROUP BY.

    Optionally adds a WHERE, a HAVING on the aggregate, and an ORDER BY on
    the aggregate (descending).

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 7200] and
        complexity in [40, 65].
    """
    t = _rand_table()
    alias = t[:1]
    group_col = random.choice(COLUMNS.get(t, ["id"])[:3])
    agg = random.choice(AGG_FUNCS)
    agg_col = random.choice(["id", "amount", "total", "price", "salary"])
    # Declare the alias so the qualified WHERE predicate resolves.
    sql = f"SELECT {group_col}, {agg}({agg_col}) FROM {t} {alias}"
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where(alias)}"
    sql += f" GROUP BY {group_col}"
    if random.random() > 0.6:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 1000)}"
    if random.random() > 0.5:
        sql += f" ORDER BY {agg}({agg_col}) DESC"
    return sql, "high", random.randint(3600, 7200), random.randint(40, 65)
def _aggregate_join():
    """Tier 4: two-table JOIN combined with GROUP BY aggregation.

    Always has a WHERE on the left table and an ``ORDER BY agg_val DESC
    LIMIT`` tail; a HAVING clause is added about half the time.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 7200] and
        complexity in [55, 80].
    """
    left, right = random.sample(TABLES, 2)
    func = random.choice(AGG_FUNCS)
    grouping = "a." + random.choice(COLUMNS.get(left, ["id"])[:2])
    measured = "b." + random.choice(["id", "amount", "total"])
    join_kw = random.choice(JOIN_TYPES)
    predicate = _rand_where("a")
    pieces = [
        f"SELECT {grouping}, {func}({measured}) as agg_val",
        f"FROM {left} a {join_kw} {right} b ON a.id = b.{left[:-1]}_id",
        f"WHERE {predicate}",
        f"GROUP BY {grouping}",
    ]
    if random.random() > 0.5:
        pieces.append(f"HAVING {func}({measured}) > {random.randint(1, 500)}")
    pieces.append(f"ORDER BY agg_val DESC LIMIT {random.choice([10, 20, 50])}")
    return " ".join(pieces), "high", random.randint(3600, 7200), random.randint(55, 80)
def _subquery():
    """Tier 4: ``IN (SELECT ...)`` subquery against a second table.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 5400] and
        complexity in [50, 75].
    """
    t1, t2 = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(t1, 2))
    # These two draws do not appear in the SQL (formerly unused locals). They
    # are kept so the RNG stream, and therefore every seeded dataset, is unchanged.
    random.choice(AGG_FUNCS)
    random.choice([">", "<", ">="])
    inner_alias = t2[:1]
    # The inner table is aliased because its WHERE predicate is qualified with
    # that alias; without it PostgreSQL reports a missing FROM-clause entry.
    sql = (
        f"SELECT {cols} FROM {t1} "
        f"WHERE id IN (SELECT {t1[:-1]}_id FROM {t2} {inner_alias} "
        f"WHERE {_rand_where(inner_alias)})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 75)
def _correlated_subquery():
    """Tier 5: scalar correlated subquery in the SELECT list.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 7200] and
        complexity in [60, 85].
    """
    outer, inner = random.sample(TABLES, 2)
    func = random.choice(AGG_FUNCS)
    fk_col = outer[:-1] + "_id"
    scalar = f"(SELECT {func}(b.id) FROM {inner} b WHERE b.{fk_col} = a.id)"
    filter_expr = _rand_where("a")
    sql = f"SELECT a.id, a.name, {scalar} as sub_val FROM {outer} a WHERE {filter_expr}"
    return sql, "high", random.randint(3600, 7200), random.randint(60, 85)
def _cte_query():
    """Tier 5: Common Table Expression (WITH) joined back to a base table.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 7200] and
        complexity in [65, 85].
    """
    base, detail = random.sample(TABLES, 2)
    func = random.choice(AGG_FUNCS)
    fk_col = f"{base[:-1]}_id"
    threshold = random.randint(1, 50)
    cte_body = f"SELECT {fk_col}, {func}(id) as cnt FROM {detail} GROUP BY {fk_col}"
    sql = (
        f"WITH cte AS ({cte_body}) "
        f"SELECT a.id, a.name, c.cnt FROM {base} a JOIN cte c ON a.id = c.{fk_col} "
        f"WHERE c.cnt > {threshold} ORDER BY c.cnt DESC"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(65, 85)
def _window_query():
    """Tier 5: window function partitioned over one column.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 7200] and
        complexity in [55, 80].
    """
    t = _rand_table()
    alias = t[:1]
    wfunc = random.choice(["ROW_NUMBER()", "RANK()", "DENSE_RANK()"])
    partition_col = random.choice(COLUMNS.get(t, ["id"])[:2])
    order_col = random.choice(["id", "created_at"])
    # Declare the alias so the qualified WHERE predicate resolves.
    sql = (
        f"SELECT id, {partition_col}, "
        f"{wfunc} OVER (PARTITION BY {partition_col} ORDER BY {order_col} DESC) as rn "
        f"FROM {t} {alias} WHERE {_rand_where(alias)}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)
def _union_query():
    """Tier 4: UNION ALL of two filtered single-table SELECTs.

    Returns:
        ``(sql, "medium", ttl, complexity)`` with ttl in [1800, 3600] and
        complexity in [35, 55].
    """
    t1, t2 = random.sample(TABLES, 2)
    a1, a2 = t1[:1], t2[:1]
    # Each branch has its own scope, so each declares the alias its WHERE uses.
    # The two aliases may coincide, which is legal across UNION branches.
    sql = (
        f"SELECT id, name FROM {t1} {a1} WHERE {_rand_where(a1)} "
        f"UNION ALL "
        f"SELECT id, name FROM {t2} {a2} WHERE {_rand_where(a2)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(35, 55)
def _complex_analytics():
    """Tier 6: CTE over a three-table LEFT JOIN chain, ranked in the outer query.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [5400, 7200] and
        complexity in [80, 100].
    """
    anchor, middle, leaf = random.sample(TABLES, 3)
    first = random.choice(AGG_FUNCS)
    second = random.choice(AGG_FUNCS)
    inner = " ".join([
        f"SELECT a.id, a.name, {first}(b.id) as cnt, {second}(c.id) as total",
        f"FROM {anchor} a",
        f"LEFT JOIN {middle} b ON a.id = b.{anchor[:-1]}_id",
        f"LEFT JOIN {leaf} c ON b.id = c.{middle[:-1]}_id",
        "WHERE a.created_at >= '2024-01-01'",
        "GROUP BY a.id, a.name",
        f"HAVING {first}(b.id) > {random.randint(1, 20)}",
    ])
    outer = (
        "SELECT name, cnt, total, "
        "RANK() OVER (ORDER BY cnt DESC) as rank "
        "FROM monthly ORDER BY rank LIMIT 100"
    )
    sql = f"WITH monthly AS ({inner}) {outer}"
    return sql, "high", random.randint(5400, 7200), random.randint(80, 100)
def _insert_query():
    """INSERT of three columns; labelled non-cacheable with a TTL of 0.

    ``id``/``age`` columns receive integer literals and every other column
    receives a quoted ``'val_N'`` string.

    Returns:
        ``(sql, "low", 0, complexity)`` with complexity in [5, 15].
    """
    t = _rand_table()
    cols = _rand_cols(t, 3)
    literals = []
    for column in cols:
        if column in ("id", "age"):
            literals.append(str(random.randint(1, 9999)))
        else:
            literals.append(f"'val_{random.randint(1, 99)}'")
    sql = "INSERT INTO {} ({}) VALUES ({})".format(t, ", ".join(cols), ", ".join(literals))
    return sql, "low", 0, random.randint(5, 15)
def _update_query():
    """UPDATE of one non-``id`` column; labelled non-cacheable with a TTL of 0.

    Returns:
        ``(sql, "low", 0, complexity)`` with complexity in [5, 15].
    """
    t = _rand_table()
    alias = t[:1]
    # Skip the first column (``id``). The fallback keeps at least one column
    # after slicing; the old fallback ``["name"][1:]`` was empty, which would
    # make random.choice raise IndexError.
    col = random.choice(COLUMNS.get(t, ["id", "name"])[1:])
    # PostgreSQL allows ``UPDATE table alias``; the alias is needed because
    # the WHERE predicate is qualified with it.
    sql = f"UPDATE {t} {alias} SET {col} = 'updated' WHERE {_rand_where(alias)}"
    return sql, "low", 0, random.randint(5, 15)
def _delete_query():
    """DELETE with a single predicate; labelled non-cacheable with a TTL of 0.

    Returns:
        ``(sql, "low", 0, complexity)`` with complexity in [5, 10].
    """
    t = _rand_table()
    alias = t[:1]
    # Declare the alias (``DELETE FROM table alias``) used by the WHERE predicate.
    sql = f"DELETE FROM {t} {alias} WHERE {_rand_where(alias)}"
    return sql, "low", 0, random.randint(5, 10)
def _exists_query():
    """Tier 4: correlated EXISTS subquery with an extra inner predicate.

    Returns:
        ``(sql, "high", ttl, complexity)`` with ttl in [3600, 5400] and
        complexity in [50, 70].
    """
    outer, inner = random.sample(TABLES, 2)
    projection = ", ".join(_rand_cols(outer, 2))
    correlation = f"b.{outer[:-1]}_id = a.id"
    extra = _rand_where("b")
    sql = (
        f"SELECT {projection} FROM {outer} a "
        f"WHERE EXISTS (SELECT 1 FROM {inner} b WHERE {correlation} AND {extra})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 70)
def _case_query():
    """Tier 3: SELECT with a CASE expression mapping ``status`` to a code.

    Returns:
        ``(sql, "medium", ttl, complexity)`` with ttl in [1800, 3600] and
        complexity in [25, 40].
    """
    t = _rand_table()
    alias = t[:1]
    # Declare the alias so the qualified WHERE predicate resolves.
    sql = (
        f"SELECT id, "
        f"CASE WHEN status = 'active' THEN 'A' "
        f"WHEN status = 'pending' THEN 'P' "
        f"ELSE 'X' END as status_code, "
        f"name FROM {t} {alias} WHERE {_rand_where(alias)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 40)
def _distinct_query():
    """Tier 2: SELECT DISTINCT on one column, filtered and ordered by that column.

    Returns:
        ``(sql, "medium", ttl, complexity)`` with ttl in [1200, 2400] and
        complexity in [20, 35].
    """
    t = _rand_table()
    alias = t[:1]
    col = random.choice(COLUMNS.get(t, ["name"])[:3])
    # Declare the alias so the qualified WHERE predicate resolves.
    sql = f"SELECT DISTINCT {col} FROM {t} {alias} WHERE {_rand_where(alias)} ORDER BY {col}"
    return sql, "medium", random.randint(1200, 2400), random.randint(20, 35)
| # --------------------------------------------------------------------------- | |
| # Generator registry | |
| # --------------------------------------------------------------------------- | |
# (generator, weight) pairs. Weights are integer repeat counts, so a generator
# with weight 10 is drawn twice as often as one with weight 5.
GENERATORS = [
    (_simple_select, 15),
    (_select_with_order, 10),
    (_single_join, 12),
    (_multi_join, 8),
    (_aggregate_query, 10),
    (_aggregate_join, 8),
    (_subquery, 7),
    (_correlated_subquery, 5),
    (_cte_query, 5),
    (_window_query, 5),
    (_union_query, 4),
    (_complex_analytics, 3),
    (_insert_query, 8),
    (_update_query, 5),
    (_delete_query, 4),
    (_exists_query, 5),
    (_case_query, 4),
    (_distinct_query, 4),
]

# Flattened pool: each generator is repeated `weight` times, in registry order,
# so a uniform random.choice over the pool samples in proportion to weight.
# A comprehension is used so the loop variables do not leak into the module
# namespace as stray `gen`/`weight` globals.
_WEIGHTED = [gen for gen, weight in GENERATORS for _ in range(weight)]
def generate_sample():
    """Generate one (sql, cache_benefit, ttl, complexity) sample.

    A generator is drawn from the weighted pool. Uniform noise of +/-60 s is
    then added to the TTL and +/-3 to the complexity, with complexity clamped
    to 1..100. Samples whose generator returns a TTL of 0 (the INSERT, UPDATE
    and DELETE generators) keep a TTL of exactly 0.
    """
    gen = random.choice(_WEIGHTED)
    sql, benefit, ttl, complexity = gen()
    # Always draw the TTL noise so the RNG stream (and seeded datasets) stays
    # identical to earlier versions, but apply it only to cacheable samples.
    # Adding up to +60 s to a zero TTL labelled write statements as cacheable.
    ttl_noise = random.randint(-60, 60)
    if ttl > 0:
        ttl = max(0, ttl + ttl_noise)
    complexity = max(1, min(100, complexity + random.randint(-3, 3)))
    return sql, benefit, ttl, complexity
def generate_dataset(n: int = 5000, seed: int = 42):
    """Build a reproducible synthetic training set of ``n`` samples.

    Reseeds the module-level ``random`` generator with ``seed`` (this affects
    global RNG state) and then calls ``generate_sample`` ``n`` times.

    Args:
        n: number of samples to generate.
        seed: value passed to ``random.seed`` before sampling.

    Returns:
        A 4-tuple of parallel lists ``(queries, benefits, ttls, complexities)``:
        SQL strings, cache-benefit labels ("low" / "medium" / "high"),
        TTLs in seconds, and complexity scores in 1..100.
    """
    random.seed(seed)
    samples = [generate_sample() for _ in range(n)]
    queries = [sample[0] for sample in samples]
    benefits = [sample[1] for sample in samples]
    ttls = [sample[2] for sample in samples]
    complexities = [sample[3] for sample in samples]
    return queries, benefits, ttls, complexities