pg-plan-cache-models / dataset.py
nilenpatel's picture
Upload pg_plan_cache models
406cec4 verified
"""
Synthetic training data generator for pg_plan_cache models.
Generates realistic SQL queries across a wide range of complexity levels
with labels for cache benefit, recommended TTL, and complexity score.
"""
import random
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
# Table names the query generators draw from.
TABLES = [
    "users", "orders", "products", "payments", "sessions",
    "logs", "events", "accounts", "invoices", "shipments",
    "categories", "reviews", "inventory", "notifications", "messages",
    "employees", "departments", "projects", "tasks", "comments",
]
# NOTE(review): SCHEMAS is not referenced by any generator visible in this file.
SCHEMAS = ["public", "app", "analytics", "billing"]
# Column names per table, keyed by the entries of TABLES.
COLUMNS = {
    "users": ["id", "name", "email", "created_at", "status", "age", "country"],
    "orders": ["id", "user_id", "total", "status", "created_at", "shipped_at"],
    "products": ["id", "name", "price", "category_id", "stock", "rating"],
    "payments": ["id", "order_id", "amount", "method", "paid_at", "status"],
    "sessions": ["id", "user_id", "started_at", "ended_at", "ip_address"],
    "logs": ["id", "level", "message", "created_at", "source"],
    "events": ["id", "type", "user_id", "data", "created_at"],
    "accounts": ["id", "owner_id", "balance", "currency", "opened_at"],
    "invoices": ["id", "account_id", "amount", "due_date", "status"],
    "shipments": ["id", "order_id", "carrier", "tracking", "shipped_at"],
    "categories": ["id", "name", "parent_id", "sort_order"],
    "reviews": ["id", "product_id", "user_id", "rating", "body", "created_at"],
    "inventory": ["id", "product_id", "warehouse_id", "quantity", "updated_at"],
    "notifications": ["id", "user_id", "type", "read", "created_at"],
    "messages": ["id", "sender_id", "receiver_id", "body", "sent_at"],
    "employees": ["id", "name", "department_id", "salary", "hired_at"],
    "departments": ["id", "name", "budget", "manager_id"],
    "projects": ["id", "name", "department_id", "deadline", "status"],
    "tasks": ["id", "project_id", "assignee_id", "title", "status", "due_date"],
    "comments": ["id", "task_id", "user_id", "body", "created_at"],
}
# Aggregate function names used by the GROUP BY / subquery / CTE generators.
AGG_FUNCS = ["COUNT", "SUM", "AVG", "MIN", "MAX"]
# Comparison operators used when building random WHERE predicates.
COMPARISONS = ["=", ">", "<", ">=", "<=", "!="]
# Pre-quoted SQL string literals compared against the "status" column.
STRING_VALS = ["'active'", "'pending'", "'completed'", "'cancelled'", "'new'", "'shipped'"]
# JOIN keywords picked at random by the join generators.
JOIN_TYPES = ["JOIN", "LEFT JOIN", "INNER JOIN", "RIGHT JOIN"]
# NOTE(review): WINDOW_FUNCS is not referenced in the visible code;
# _window_query picks from its own inline list of three functions.
WINDOW_FUNCS = ["ROW_NUMBER()", "RANK()", "DENSE_RANK()", "LAG(t.id, 1)", "LEAD(t.id, 1)"]
def _rand_table():
    """Pick one table name at random from ``TABLES``."""
    picked = random.choice(TABLES)
    return picked
def _rand_cols(table, n=None):
    """Return a random selection of distinct column names for ``table``.

    Args:
        table: Table name looked up in ``COLUMNS``; unknown tables fall back
            to ``["id", "name"]``.
        n: Number of columns wanted. Any falsy value (``None`` or ``0``)
            means "pick a random count between 1 and 4", capped at the number
            of available columns.

    Returns:
        A list of column names in random order, never longer than the table's
        column list.
    """
    available = COLUMNS.get(table, ["id", "name"])
    if not n:
        n = random.randint(1, min(4, len(available)))
    # Never ask random.sample for more items than exist.
    count = min(n, len(available))
    return random.sample(available, count)
def _rand_where(alias="t"):
    """Build one random ``<alias>.<column> <op> <value>`` predicate.

    The column comes from a fixed list that is independent of any particular
    table. The value depends on the column: a quoted status literal for
    ``status``, an integer in 1..10000 for ``id``/``age``/``amount``, and a
    quoted 2024 date for everything else.

    Args:
        alias: Qualifier placed in front of the column name.

    Returns:
        The predicate text, without the ``WHERE`` keyword.
    """
    column = random.choice(["id", "status", "created_at", "name", "amount", "age"])
    operator = random.choice(COMPARISONS)
    if column == "status":
        value = random.choice(STRING_VALS)
    elif column in ("id", "age", "amount"):
        value = random.randint(1, 10000)
    else:
        # Month is drawn before day, matching the order in the rendered date.
        month = random.randint(1, 12)
        day = random.randint(1, 28)
        value = f"'2024-{month:02d}-{day:02d}'"
    return f"{alias}.{column} {operator} {value}"
# ---------------------------------------------------------------------------
# Query generators by complexity tier
# ---------------------------------------------------------------------------
def _simple_select():
    """Tier 1: Simple SELECT with optional WHERE and optional LIMIT.

    The table gets a one-letter alias (its first character) in the FROM
    clause, so the alias-qualified predicate produced by ``_rand_where``
    refers to a name the query actually declares.

    Returns:
        Tuple ``(sql, "low", ttl, complexity)`` with ttl drawn from 300..900
        and complexity from 5..20.
    """
    t = _rand_table()
    alias = t[:1]
    cols = ", ".join(_rand_cols(t))
    sql = f"SELECT {cols} FROM {t} {alias}"
    if random.random() > 0.3:
        sql += f" WHERE {_rand_where(alias)}"
    if random.random() > 0.7:
        sql += f" LIMIT {random.choice([10, 20, 50, 100])}"
    return sql, "low", random.randint(300, 900), random.randint(5, 20)
def _select_with_order():
    """Tier 1.5: SELECT with WHERE, ORDER BY and LIMIT.

    The table gets a one-letter alias (its first character) in the FROM
    clause, so the alias-qualified WHERE predicate refers to a declared name.

    Returns:
        Tuple ``(sql, "low", ttl, complexity)`` with ttl drawn from 600..1200
        and complexity from 10..25.
    """
    t = _rand_table()
    alias = t[:1]
    cols = ", ".join(_rand_cols(t))
    order_col = random.choice(COLUMNS.get(t, ["id"]))
    direction = random.choice(["ASC", "DESC"])
    # The predicate is drawn before the limit, matching their order in the SQL.
    where = _rand_where(alias)
    limit = random.choice([10, 25, 50])
    sql = (
        f"SELECT {cols} FROM {t} {alias} WHERE {where} "
        f"ORDER BY {order_col} {direction} LIMIT {limit}"
    )
    return sql, "low", random.randint(600, 1200), random.randint(10, 25)
def _single_join():
    """Tier 2: two distinct tables joined once, optionally filtered.

    The left table is aliased ``a`` and the right ``b``. The join column on
    the right side is the left table's name with its last character dropped,
    plus ``_id``.

    Returns:
        Tuple ``(sql, "medium", ttl, complexity)`` with ttl drawn from
        1800..3600 and complexity from 25..45.
    """
    left, right = random.sample(TABLES, 2)
    left_cols = ", ".join([f"a.{name}" for name in _rand_cols(left, 2)])
    right_cols = ", ".join([f"b.{name}" for name in _rand_cols(right, 2)])
    join_kw = random.choice(JOIN_TYPES)
    fk = f"{left[:-1]}_id"
    sql = f"SELECT {left_cols}, {right_cols} FROM {left} a {join_kw} {right} b ON a.id = b.{fk}"
    if random.random() > 0.4:
        sql = f"{sql} WHERE {_rand_where('a')}"
    ttl = random.randint(1800, 3600)
    complexity = random.randint(25, 45)
    return sql, "medium", ttl, complexity
def _multi_join():
    """Tier 3: chain of JOINs across three to five distinct tables.

    Tables are aliased ``a``, ``b``, ``c``, ... in order. One random column
    is projected per table, and each table joins to the previous one on
    ``<prev>.id = <alias>.<previous table name minus last char>_id``.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..7200 and complexity from 45..70.
    """
    how_many = random.randint(3, 5)
    tables = random.sample(TABLES, how_many)
    aliases = [chr(97 + pos) for pos in range(len(tables))]

    # All projected columns are drawn before any join keyword.
    projection = []
    for alias, name in zip(aliases, tables):
        picked = random.choice(COLUMNS.get(name, ["id"]))
        projection.append(f"{alias}.{picked}")

    pieces = [f"SELECT {', '.join(projection)} FROM {tables[0]} a"]
    for pos in range(1, len(tables)):
        join_kw = random.choice(JOIN_TYPES)
        fk = f"{tables[pos - 1][:-1]}_id"
        pieces.append(
            f"{join_kw} {tables[pos]} {aliases[pos]} "
            f"ON {aliases[pos - 1]}.id = {aliases[pos]}.{fk}"
        )

    if random.random() > 0.3:
        pieces.append(f"WHERE {_rand_where('a')}")
    if random.random() > 0.5:
        pieces.append(f"ORDER BY a.id LIMIT {random.choice([50, 100, 200])}")

    sql = " ".join(pieces)
    return sql, "high", random.randint(3600, 7200), random.randint(45, 70)
def _aggregate_query():
    """Tier 3: Aggregation with GROUP BY, plus optional WHERE/HAVING/ORDER BY.

    The table gets a one-letter alias (its first character) in the FROM
    clause, so the alias-qualified WHERE predicate refers to a declared name.
    The grouping column is one of the table's first three columns; the
    aggregated column comes from a fixed list independent of the table.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..7200 and complexity from 40..65.
    """
    t = _rand_table()
    alias = t[:1]
    group_col = random.choice(COLUMNS.get(t, ["id"])[:3])
    agg = random.choice(AGG_FUNCS)
    agg_col = random.choice(["id", "amount", "total", "price", "salary"])
    sql = f"SELECT {group_col}, {agg}({agg_col}) FROM {t} {alias}"
    if random.random() > 0.4:
        sql += f" WHERE {_rand_where(alias)}"
    sql += f" GROUP BY {group_col}"
    if random.random() > 0.6:
        sql += f" HAVING {agg}({agg_col}) > {random.randint(1, 1000)}"
    if random.random() > 0.5:
        sql += f" ORDER BY {agg}({agg_col}) DESC"
    return sql, "high", random.randint(3600, 7200), random.randint(40, 65)
def _aggregate_join():
    """Tier 4: two-table JOIN with aggregation, GROUP BY, ORDER BY and LIMIT.

    Groups on one of the left table's first two columns (alias ``a``) and
    aggregates a column of the right table (alias ``b``). A HAVING clause is
    added about half of the time.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..7200 and complexity from 55..80.
    """
    left, right = random.sample(TABLES, 2)
    func = random.choice(AGG_FUNCS)
    group_by = "a." + random.choice(COLUMNS.get(left, ["id"])[:2])
    target = "b." + random.choice(["id", "amount", "total"])
    join_kw = random.choice(JOIN_TYPES)
    fk = f"{left[:-1]}_id"

    clauses = [
        f"SELECT {group_by}, {func}({target}) as agg_val",
        f"FROM {left} a {join_kw} {right} b ON a.id = b.{fk}",
        f"WHERE {_rand_where('a')}",
        f"GROUP BY {group_by}",
    ]
    if random.random() > 0.5:
        clauses.append(f"HAVING {func}({target}) > {random.randint(1, 500)}")
    clauses.append(f"ORDER BY agg_val DESC LIMIT {random.choice([10, 20, 50])}")

    sql = " ".join(clauses)
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)
def _subquery():
    """Tier 4: ``id IN (SELECT ...)`` subquery against a second table.

    The inner table gets a one-letter alias (its first character), so the
    alias-qualified predicate inside the subquery refers to a declared name.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..5400 and complexity from 50..75.
    """
    t1, t2 = random.sample(TABLES, 2)
    cols = ", ".join(_rand_cols(t1, 2))
    # These two draws were assigned to locals that were never used. The draws
    # are kept (results discarded) so the shared random stream, and with it
    # every sample generated afterwards under a fixed seed, does not shift.
    random.choice(AGG_FUNCS)
    random.choice([">", "<", ">="])
    inner_alias = t2[:1]
    sql = (
        f"SELECT {cols} FROM {t1} "
        f"WHERE id IN (SELECT {t1[:-1]}_id FROM {t2} {inner_alias} "
        f"WHERE {_rand_where(inner_alias)})"
    )
    return sql, "high", random.randint(3600, 5400), random.randint(50, 75)
def _correlated_subquery():
    """Tier 5: scalar subquery in the SELECT list, correlated to the outer row.

    The outer table is aliased ``a``. The inner table (``b``) is filtered on
    ``b.<outer table name minus last char>_id = a.id``.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..7200 and complexity from 60..85.
    """
    outer, inner = random.sample(TABLES, 2)
    func = random.choice(AGG_FUNCS)
    fk = f"{outer[:-1]}_id"
    scalar = f"(SELECT {func}(b.id) FROM {inner} b WHERE b.{fk} = a.id)"
    predicate = _rand_where("a")
    sql = f"SELECT a.id, a.name, {scalar} as sub_val FROM {outer} a WHERE {predicate}"
    ttl = random.randint(3600, 7200)
    complexity = random.randint(60, 85)
    return sql, "high", ttl, complexity
def _cte_query():
    """Tier 5: Common Table Expression (WITH) joined back to a base table.

    The CTE aggregates the second table grouped by a key named after the
    first table (its name minus the last character, plus ``_id``). The outer
    query joins the first table to the CTE on that key and filters on the
    aggregate.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..7200 and complexity from 65..85.
    """
    base, detail = random.sample(TABLES, 2)
    func = random.choice(AGG_FUNCS)
    key = f"{base[:-1]}_id"
    cte_body = f"SELECT {key}, {func}(id) as cnt FROM {detail} GROUP BY {key}"
    threshold = random.randint(1, 50)
    sql = (
        f"WITH cte AS ({cte_body}) "
        f"SELECT a.id, a.name, c.cnt "
        f"FROM {base} a JOIN cte c ON a.id = c.{key} "
        f"WHERE c.cnt > {threshold} "
        f"ORDER BY c.cnt DESC"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(65, 85)
def _window_query():
    """Tier 5: Window function with PARTITION BY / ORDER BY.

    The table gets a one-letter alias (its first character) in the FROM
    clause, so the alias-qualified WHERE predicate refers to a declared name.
    The ranking function comes from an inline list of three functions, and
    the partition column is one of the table's first two columns.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..7200 and complexity from 55..80.
    """
    t = _rand_table()
    alias = t[:1]
    wfunc = random.choice(["ROW_NUMBER()", "RANK()", "DENSE_RANK()"])
    partition_col = random.choice(COLUMNS.get(t, ["id"])[:2])
    order_col = random.choice(["id", "created_at"])
    sql = (
        f"SELECT id, {partition_col}, "
        f"{wfunc} OVER (PARTITION BY {partition_col} ORDER BY {order_col} DESC) as rn "
        f"FROM {t} {alias} WHERE {_rand_where(alias)}"
    )
    return sql, "high", random.randint(3600, 7200), random.randint(55, 80)
def _union_query():
    """Tier 4: UNION ALL of two filtered single-table SELECTs.

    Each branch declares a one-letter alias (the table's first character) in
    its own FROM clause, so the alias-qualified predicate of that branch
    refers to a declared name. The two branches are independent scopes, so
    tables sharing a first letter do not clash.

    Returns:
        Tuple ``(sql, "medium", ttl, complexity)`` with ttl drawn from
        1800..3600 and complexity from 35..55.
    """
    t1, t2 = random.sample(TABLES, 2)
    a1, a2 = t1[:1], t2[:1]
    sql = (
        f"SELECT id, name FROM {t1} {a1} WHERE {_rand_where(a1)} "
        f"UNION ALL "
        f"SELECT id, name FROM {t2} {a2} WHERE {_rand_where(a2)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(35, 55)
def _complex_analytics():
    """Tier 6: CTE with two LEFT JOINs, GROUP BY/HAVING, and an outer RANK().

    Three distinct tables are chained: ``a`` to ``b`` and ``b`` to ``c``,
    each on ``<prev>.id = <next>.<prev table name minus last char>_id``. The
    outer query ranks the CTE rows by ``cnt`` and keeps the first 100.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        5400..7200 and complexity from 80..100.
    """
    first, second, third = random.sample(TABLES, 3)
    count_func = random.choice(AGG_FUNCS)
    total_func = random.choice(AGG_FUNCS)
    floor = random.randint(1, 20)

    inner = " ".join([
        f"SELECT a.id, a.name, {count_func}(b.id) as cnt, {total_func}(c.id) as total",
        f"FROM {first} a",
        f"LEFT JOIN {second} b ON a.id = b.{first[:-1]}_id",
        f"LEFT JOIN {third} c ON b.id = c.{second[:-1]}_id",
        "WHERE a.created_at >= '2024-01-01'",
        "GROUP BY a.id, a.name",
        f"HAVING {count_func}(b.id) > {floor}",
    ])
    outer = (
        "SELECT name, cnt, total, "
        "RANK() OVER (ORDER BY cnt DESC) as rank "
        "FROM monthly ORDER BY rank LIMIT 100"
    )
    sql = f"WITH monthly AS ({inner}) {outer}"
    return sql, "high", random.randint(5400, 7200), random.randint(80, 100)
def _insert_query():
    """INSERT of one row into a random table; labelled not cacheable (TTL 0).

    Three random columns are populated: ``id`` and ``age`` get an integer in
    1..9999, every other column gets a quoted ``'val_<1..99>'`` string.

    Returns:
        Tuple ``(sql, "low", 0, complexity)`` with complexity from 5..15.
    """
    table = _rand_table()
    columns = _rand_cols(table, 3)
    rendered = []
    for column in columns:
        if column in ("id", "age"):
            rendered.append(f"{random.randint(1, 9999)}")
        else:
            rendered.append(f"'val_{random.randint(1,99)}'")
    column_list = ", ".join(columns)
    value_list = ", ".join(rendered)
    sql = f"INSERT INTO {table} ({column_list}) VALUES ({value_list})"
    return sql, "low", 0, random.randint(5, 15)
def _update_query():
    """UPDATE of one column on a random table; labelled not cacheable (TTL 0).

    The table gets a one-letter alias (its first character) right after its
    name, so the alias-qualified WHERE predicate refers to a declared name.
    The SET target stays unqualified. The first column of the table is never
    chosen as the SET target.

    Returns:
        Tuple ``(sql, "low", 0, complexity)`` with complexity from 5..15.
    """
    t = _rand_table()
    alias = t[:1]
    col = random.choice(COLUMNS.get(t, ["name"])[1:])
    sql = f"UPDATE {t} {alias} SET {col} = 'updated' WHERE {_rand_where(alias)}"
    return sql, "low", 0, random.randint(5, 15)
def _delete_query():
    """DELETE with a WHERE predicate; labelled not cacheable (TTL 0).

    The table gets a one-letter alias (its first character) right after its
    name, so the alias-qualified WHERE predicate refers to a declared name.

    Returns:
        Tuple ``(sql, "low", 0, complexity)`` with complexity from 5..10.
    """
    t = _rand_table()
    alias = t[:1]
    sql = f"DELETE FROM {t} {alias} WHERE {_rand_where(alias)}"
    return sql, "low", 0, random.randint(5, 10)
def _exists_query():
    """Tier 4: SELECT filtered by a correlated EXISTS subquery.

    The outer table is aliased ``a`` and the inner table ``b``. The inner
    query correlates on ``b.<outer table name minus last char>_id = a.id``
    and adds one random predicate on ``b``.

    Returns:
        Tuple ``(sql, "high", ttl, complexity)`` with ttl drawn from
        3600..5400 and complexity from 50..70.
    """
    outer, inner = random.sample(TABLES, 2)
    projection = ", ".join(_rand_cols(outer, 2))
    fk = f"{outer[:-1]}_id"
    extra = _rand_where("b")
    probe = f"SELECT 1 FROM {inner} b WHERE b.{fk} = a.id AND {extra}"
    sql = f"SELECT {projection} FROM {outer} a WHERE EXISTS ({probe})"
    return sql, "high", random.randint(3600, 5400), random.randint(50, 70)
def _case_query():
    """Tier 3: SELECT with a CASE expression mapping ``status`` to a code.

    The table gets a one-letter alias (its first character) in the FROM
    clause, so the alias-qualified WHERE predicate refers to a declared name.

    Returns:
        Tuple ``(sql, "medium", ttl, complexity)`` with ttl drawn from
        1800..3600 and complexity from 25..40.
    """
    t = _rand_table()
    alias = t[:1]
    sql = (
        f"SELECT id, "
        f"CASE WHEN status = 'active' THEN 'A' "
        f"WHEN status = 'pending' THEN 'P' "
        f"ELSE 'X' END as status_code, "
        f"name FROM {t} {alias} WHERE {_rand_where(alias)}"
    )
    return sql, "medium", random.randint(1800, 3600), random.randint(25, 40)
def _distinct_query():
    """Tier 2: SELECT DISTINCT on one column, filtered and ordered by it.

    The table gets a one-letter alias (its first character) in the FROM
    clause, so the alias-qualified WHERE predicate refers to a declared name.
    The column is one of the table's first three columns.

    Returns:
        Tuple ``(sql, "medium", ttl, complexity)`` with ttl drawn from
        1200..2400 and complexity from 20..35.
    """
    t = _rand_table()
    alias = t[:1]
    col = random.choice(COLUMNS.get(t, ["name"])[:3])
    sql = (
        f"SELECT DISTINCT {col} FROM {t} {alias} "
        f"WHERE {_rand_where(alias)} ORDER BY {col}"
    )
    return sql, "medium", random.randint(1200, 2400), random.randint(20, 35)
# ---------------------------------------------------------------------------
# Generator registry
# ---------------------------------------------------------------------------
# Registry of (generator function, integer weight) pairs. A generator is
# repeated `weight` times in the sampling pool built from this list, so a
# larger weight makes that query shape proportionally more frequent.
GENERATORS = [
    (_simple_select, 15),
    (_select_with_order, 10),
    (_single_join, 12),
    (_multi_join, 8),
    (_aggregate_query, 10),
    (_aggregate_join, 8),
    (_subquery, 7),
    (_correlated_subquery, 5),
    (_cte_query, 5),
    (_window_query, 5),
    (_union_query, 4),
    (_complex_analytics, 3),
    # Write statements: these generators return a TTL of 0.
    (_insert_query, 8),
    (_update_query, 5),
    (_delete_query, 4),
    (_exists_query, 5),
    (_case_query, 4),
    (_distinct_query, 4),
]
# Expand the registry into a flat pool in which each generator appears
# `weight` times; a uniform pick from the pool is then a weighted pick
# over the generators.
_WEIGHTED = []
for gen, weight in GENERATORS:
    _WEIGHTED += [gen] * weight
def generate_sample():
    """Generate one (sql, cache_benefit, ttl, complexity) sample.

    A generator is picked from the weighted pool and its labels are jittered:
    the TTL by up to +/-60 and the complexity by up to +/-3, clamped to
    1..100.

    A base TTL of 0 is what the INSERT/UPDATE/DELETE generators return for
    statements documented as not cacheable. Jitter is not applied to it,
    because otherwise roughly half of those samples would end up with a
    positive TTL of up to 60. The jitter value is still drawn in that case so
    the random stream, and therefore every later sample under a fixed seed,
    is unaffected.

    Returns:
        Tuple ``(sql, benefit, ttl, complexity)``.
    """
    gen = random.choice(_WEIGHTED)
    sql, benefit, ttl, complexity = gen()
    # Always draw both noise values, in this order, to keep the stream stable.
    ttl_noise = random.randint(-60, 60)
    complexity_noise = random.randint(-3, 3)
    if ttl > 0:
        ttl = max(0, ttl + ttl_noise)
    complexity = max(1, min(100, complexity + complexity_noise))
    return sql, benefit, ttl, complexity
def generate_dataset(n: int = 5000, seed: int = 42):
    """
    Build a labelled dataset of ``n`` synthetic SQL samples.

    The module-level ``random`` generator is seeded with ``seed`` before any
    sample is drawn, so the same ``(n, seed)`` pair reproduces the same
    dataset. The global random state stays reseeded after the call.

    Returns:
        queries: list[str]
        benefits: list[str] — "low", "medium", "high"
        ttls: list[int] — recommended TTL in seconds
        complexities: list[int] — 1-100 complexity score
    """
    random.seed(seed)
    samples = [generate_sample() for _ in range(n)]
    # Split the list of 4-tuples into four parallel lists.
    queries = [row[0] for row in samples]
    benefits = [row[1] for row in samples]
    ttls = [row[2] for row in samples]
    complexities = [row[3] for row in samples]
    return queries, benefits, ttls, complexities