File size: 6,300 Bytes
406cec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
SQL feature extraction for pg_plan_cache models.

Extracts structural features from raw SQL query text to feed into
the Cache Advisor, TTL Recommender, and Complexity Estimator models.
"""

import re


# Pre-compiled, case-insensitive patterns used to detect SQL constructs.

# Call to a common aggregate function, e.g. "count(" or "SUM (".
AGGREGATE_FUNCS = re.compile(
    r"\b(count|sum|avg|min|max|array_agg|string_agg|bool_and|bool_or|jsonb_agg)\s*\(",
    re.IGNORECASE,
)
# Call to a window-only function, e.g. "row_number(".
WINDOW_FUNCS = re.compile(
    r"\b(row_number|rank|dense_rank|ntile|lag|lead|first_value|last_value|nth_value)\s*\(",
    re.IGNORECASE,
)
# Any JOIN keyword; the qualified forms are listed first so that
# "left join" is consumed as one match rather than leaving a bare "join".
JOIN_PATTERN = re.compile(
    r"\b(inner\s+join|left\s+join|right\s+join|full\s+join|cross\s+join|join)\b",
    re.IGNORECASE,
)
# Opening parenthesis directly followed by SELECT.
SUBQUERY_PATTERN = re.compile(r"\(\s*select\b", re.IGNORECASE)
# Start of a common table expression.  Accepts the optional RECURSIVE
# keyword, an optional column list after the CTE name, and the optional
# [NOT] MATERIALIZED hint between AS and the opening parenthesis.
CTE_PATTERN = re.compile(
    r"\bwith\s+(?:recursive\s+)?\w+\s*(?:\([^)]*\)\s*)?\bas\s*(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)
# Set operations combining several SELECTs.
UNION_PATTERN = re.compile(r"\b(union|intersect|except)\b", re.IGNORECASE)
CASE_PATTERN = re.compile(r"\bcase\b", re.IGNORECASE)
IN_PATTERN = re.compile(r"\bin\s*\(", re.IGNORECASE)
LIKE_PATTERN = re.compile(r"\b(like|ilike)\b", re.IGNORECASE)
BETWEEN_PATTERN = re.compile(r"\bbetween\b", re.IGNORECASE)
EXISTS_PATTERN = re.compile(r"\bexists\s*\(", re.IGNORECASE)
HAVING_PATTERN = re.compile(r"\bhaving\b", re.IGNORECASE)
# CAST keyword or the "::" cast operator.  The word boundaries apply to
# the keyword only: "::" is made of non-word characters, so wrapping it in
# \b would require a word character on both sides and miss casts such as
# '42'::int or (a + b)::numeric.
CAST_PATTERN = re.compile(r"(\bcast\b|::)", re.IGNORECASE)

# Names of the entries in the feature vector, in output order.
FEATURE_NAMES = [
    # --- size and statement kind ---
    "query_length",
    "query_type",  # 0=SELECT, 1=INSERT, 2=UPDATE, 3=DELETE, 4=OTHER
    # --- structural counts ---
    "num_tables", "num_joins", "num_conditions",
    "num_aggregates", "num_subqueries", "num_columns",
    # --- clause presence flags ---
    "has_distinct", "has_order_by", "has_group_by", "has_having",
    "has_limit", "has_offset", "has_where",
    # --- predicate / expression presence flags ---
    "has_like", "has_in_clause", "has_between", "has_exists",
    "has_window_func", "has_cte", "has_union", "has_case", "has_cast",
    # --- nesting and literal counts ---
    "nesting_depth", "num_and_or",
    "num_string_literals", "num_numeric_literals",
]


def _count_tables(sql: str) -> int:
    """Return a rough count of the tables a query references.

    The estimate is the number of comma-separated entries in the first
    FROM clause plus one for every JOIN keyword found in the query.
    """
    # Text after the first FROM, up to the next clause keyword, a semicolon
    # or the end of the query.
    from_clause = re.search(
        r"\bfrom\s+(.+?)(?:\bwhere\b|\bjoin\b|\bgroup\b|\border\b|\blimit\b|\bhaving\b|;|$)",
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    # N commas separate N + 1 listed relations.
    listed = from_clause.group(1).count(",") + 1 if from_clause else 0
    # Every JOIN brings in one more relation.
    joined = sum(1 for _ in JOIN_PATTERN.finditer(sql))
    return listed + joined


def _count_columns(sql: str) -> int:
    """Estimate the number of columns in SELECT clause."""
    match = re.search(r"\bselect\s+(.*?)\bfrom\b", sql, re.IGNORECASE | re.DOTALL)
    if not match:
        return 0
    select_clause = match.group(1).strip()
    if select_clause == "*":
        return 1
    # Split by commas not inside parentheses
    depth = 0
    count = 1
    for ch in select_clause:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            count += 1
    return count


def _nesting_depth(sql: str) -> int:
    """Calculate maximum parenthesis nesting depth."""
    max_depth = 0
    depth = 0
    for ch in sql:
        if ch == '(':
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == ')':
            depth -= 1
    return max_depth


def extract_features(sql: str) -> list[float]:
    """
    Extract a fixed-length feature vector from a SQL query string.

    Single-quoted string literals are blanked out before any keyword,
    parenthesis or number matching, so text inside a literal (for example
    ``'order by 2024-01-01'``) is not mistaken for query structure.  The
    query length is still measured on the original, whitespace-stripped
    text.

    Returns a list of floats matching FEATURE_NAMES ordering.
    """
    sql = sql.strip()

    # Replace every string literal with an empty one and count them in the
    # same pass.  A doubled quote ('') inside a literal is an escaped quote,
    # so 'it''s' is one literal, not two.
    bare, num_string_lits = re.subn(r"'(?:''|[^'])*'", "''", sql)

    # Query type: index of the leading statement keyword, 4 for anything else.
    upper = sql.upper()
    qtype = 4
    for code, keyword in enumerate(("SELECT", "INSERT", "UPDATE", "DELETE")):
        if upper.startswith(keyword):
            qtype = code
            break

    num_joins = len(JOIN_PATTERN.findall(bare))
    num_aggs = len(AGGREGATE_FUNCS.findall(bare))
    num_subqueries = len(SUBQUERY_PATTERN.findall(bare))
    # NOTE: this AND/OR count feeds both num_conditions and num_and_or.
    num_conditions = len(re.findall(r"\b(and|or)\b", bare, re.IGNORECASE))
    # Counted on the masked text so digits inside string literals are ignored.
    num_numeric_lits = len(re.findall(r"\b\d+(?:\.\d+)?\b", bare))

    features = [
        float(len(sql)),                                        # query_length
        float(qtype),                                           # query_type
        float(_count_tables(bare)),                             # num_tables
        float(num_joins),                                       # num_joins
        float(num_conditions),                                  # num_conditions
        float(num_aggs),                                        # num_aggregates
        float(num_subqueries),                                  # num_subqueries
        float(_count_columns(bare)),                            # num_columns
        float(bool(re.search(r"\bdistinct\b", bare, re.I))),    # has_distinct
        float(bool(re.search(r"\border\s+by\b", bare, re.I))),  # has_order_by
        float(bool(re.search(r"\bgroup\s+by\b", bare, re.I))),  # has_group_by
        float(bool(HAVING_PATTERN.search(bare))),               # has_having
        float(bool(re.search(r"\blimit\b", bare, re.I))),       # has_limit
        float(bool(re.search(r"\boffset\b", bare, re.I))),      # has_offset
        float(bool(re.search(r"\bwhere\b", bare, re.I))),       # has_where
        float(bool(LIKE_PATTERN.search(bare))),                 # has_like
        float(bool(IN_PATTERN.search(bare))),                   # has_in_clause
        float(bool(BETWEEN_PATTERN.search(bare))),              # has_between
        float(bool(EXISTS_PATTERN.search(bare))),               # has_exists
        float(bool(WINDOW_FUNCS.search(bare))),                 # has_window_func
        float(bool(CTE_PATTERN.search(bare))),                  # has_cte
        float(bool(UNION_PATTERN.search(bare))),                # has_union
        float(bool(CASE_PATTERN.search(bare))),                 # has_case
        float(bool(CAST_PATTERN.search(bare))),                 # has_cast
        float(_nesting_depth(bare)),                            # nesting_depth
        float(num_conditions),                                  # num_and_or
        float(num_string_lits),                                 # num_string_literals
        float(num_numeric_lits),                                # num_numeric_literals
    ]

    return features