"""
SQL feature extraction for pg_plan_cache models.
Extracts structural features from raw SQL query text to feed into
the Cache Advisor, TTL Recommender, and Complexity Estimator models.
"""
import re
# Precompiled, case-insensitive patterns used by the feature extractor.
# They are purely lexical heuristics: the SQL is never parsed, so keywords
# inside string literals or comments are matched as well.

# Aggregate function calls (function name directly followed by "(").
AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)

# Window function calls.
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)

# JOIN keywords. The qualified forms are listed before the bare "join" so
# that e.g. "left join" yields a single match rather than two.
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)

# An opening parenthesis immediately followed by SELECT.
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)

# Common table expressions. Besides the plain "WITH name AS (" form this also
# accepts "WITH RECURSIVE", an optional column list after the CTE name, and
# the "AS [NOT] MATERIALIZED (" hint.
CTE_PATTERN = re.compile(
    r"\bwith\s+(?:recursive\s+)?\w+(?:\s*\([^)]*\))?\s+as\s*"
    r"(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)

# Set operations.
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)

CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)

# Type casts: the CAST keyword or the "::" operator. "::" must not be wrapped
# in \b, because a word boundary next to ":" only exists when a word character
# is adjacent; that would miss casts such as '5'::int or (a + b)::numeric.
CAST_PATTERN = re.compile(r"\bcast\b|::", re.IGNORECASE)
# Names of the features, in the exact order in which extract_features()
# emits their values. Keep both in sync when adding or removing a feature.
FEATURE_NAMES = [
    # Size and statement kind
    "query_length",
    "query_type",  # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    # Structural counts
    "num_tables", "num_joins", "num_conditions",
    "num_aggregates", "num_subqueries", "num_columns",
    # Clause presence flags
    "has_distinct", "has_order_by", "has_group_by", "has_having",
    "has_limit", "has_offset", "has_where",
    # Predicate / expression presence flags
    "has_like", "has_in_clause", "has_between", "has_exists",
    "has_window_func", "has_cte", "has_union", "has_case", "has_cast",
    # Nesting and literal statistics
    "nesting_depth", "num_and_or",
    "num_string_literals", "num_numeric_literals",
]
def _count_tables(sql: str) -> int:
    """Estimate the number of tables referenced by a query.

    The estimate is the number of comma-separated items in the first FROM
    clause plus one for every JOIN keyword found anywhere in the statement.
    Only commas outside parentheses separate FROM items, so the arguments of
    a set-returning function such as ``generate_series(1, 10)`` or the select
    list of a derived table are not mistaken for extra tables.

    Args:
        sql: Raw SQL text.

    Returns:
        The estimated table count; 0 when there is neither a FROM clause
        nor a JOIN.
    """
    count = 0

    # The FROM clause runs lazily up to the first following clause keyword,
    # a semicolon, or the end of the text.
    from_match = re.search(
        r"\bfrom\s+(.+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|\bhaving\b|;|$)",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if from_match:
        count = 1
        depth = 0
        for ch in from_match.group(1):
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif ch == "," and depth == 0:
                count += 1

    # Each JOIN brings in one more table.
    count += len(JOIN_PATTERN.findall(sql))
    return count
def _count_columns(sql: str) -> int:
"""Estimate the number of columns in SELECT clause."""
match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL)
if not match:
return 0
select_clause = match.group(1).strip()
if select_clause == "*":
return 1
# Split by commas not inside parentheses
depth = 0
count = 1
for ch in select_clause:
if ch == '(':
depth += 1
elif ch == ')':
depth -= 1
elif ch == ',' and depth == 0:
count += 1
return count
def _nesting_depth(sql: str) -> int:
"""Calculate maximum parenthesis nesting depth."""
max_depth = 0
depth = 0
for ch in sql:
if ch == '(':
depth += 1
max_depth = max(max_depth, depth)
elif ch == ')':
depth -= 1
return max_depth
def extract_features(sql: str) -> list[float]:
    """
    Build the fixed-length numeric feature vector for one SQL statement.

    The input is stripped of surrounding whitespace and then inspected with
    regular expressions only; it is never parsed.

    Args:
        sql: Raw SQL query text.

    Returns:
        A list of floats, one per entry of FEATURE_NAMES and in the same
        order. Presence flags are 1.0 or 0.0.
    """
    sql = sql.strip()

    # Statement kind is decided purely by the leading keyword:
    # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, anything else is 4.
    head = sql.upper()
    qtype = 4
    for code, keyword in enumerate(("SELECT", "INSERT", "UPDATE", "DELETE")):
        if head.startswith(keyword):
            qtype = code
            break

    def keyword_flag(pattern: str) -> float:
        """1.0 when the case-insensitive pattern occurs in the query."""
        return 1.0 if re.search(pattern, sql, re.I) else 0.0

    def compiled_flag(pattern) -> float:
        """1.0 when the precompiled pattern occurs in the query."""
        return 1.0 if pattern.search(sql) else 0.0

    # AND/OR occurrences; this one number feeds both num_conditions and
    # num_and_or.
    and_or_total = float(len(re.findall(r"\b(and|or)\b", sql, re.IGNORECASE)))

    return [
        float(len(sql)),                                     # query_length
        float(qtype),                                        # query_type
        float(_count_tables(sql)),                           # num_tables
        float(len(JOIN_PATTERN.findall(sql))),               # num_joins
        and_or_total,                                        # num_conditions
        float(len(AGGREGATE_FUNCS.findall(sql))),            # num_aggregates
        float(len(SUBQUERY_PATTERN.findall(sql))),           # num_subqueries
        float(_count_columns(sql)),                          # num_columns
        keyword_flag(r"\bdistinct\b"),                       # has_distinct
        keyword_flag(r"\border\s+by\b"),                     # has_order_by
        keyword_flag(r"\bgroup\s+by\b"),                     # has_group_by
        compiled_flag(HAVING_PATTERN),                       # has_having
        keyword_flag(r"\blimit\b"),                          # has_limit
        keyword_flag(r"\boffset\b"),                         # has_offset
        keyword_flag(r"\bwhere\b"),                          # has_where
        compiled_flag(LIKE_PATTERN),                         # has_like
        compiled_flag(IN_PATTERN),                           # has_in_clause
        compiled_flag(BETWEEN_PATTERN),                      # has_between
        compiled_flag(EXISTS_PATTERN),                       # has_exists
        compiled_flag(WINDOW_FUNCS),                         # has_window_func
        compiled_flag(CTE_PATTERN),                          # has_cte
        compiled_flag(UNION_PATTERN),                        # has_union
        compiled_flag(CASE_PATTERN),                         # has_case
        compiled_flag(CAST_PATTERN),                         # has_cast
        float(_nesting_depth(sql)),                          # nesting_depth
        and_or_total,                                        # num_and_or
        float(len(re.findall(r"'[^']*'", sql))),             # num_string_literals
        float(len(re.findall(r"\b\d+(?:\.\d+)?\b", sql))),   # num_numeric_literals
    ]
|