"""
Python port of query_normalize.c from pg_plan_cache.

Replicates the normalization and SHA-256 hashing logic of the C extension so
the agent can compute cache keys without needing PostgreSQL. Because the keys
must match the ones the extension produces, the normalizer intentionally
reproduces the C scanner's quirks (documented on ``normalize_query``) rather
than implementing full SQL lexing.
"""

import hashlib
from enum import Enum, auto


# Character classes restricted to ASCII. Python's str.isdigit / isspace /
# isalpha / lower are Unicode-aware ("²".isdigit() and "\xa0".isspace() are
# True, "É".lower() is "é"), whereas a C scanner classifying bytes with
# <ctype.h> in the "C" locale only recognises ASCII, and PostgreSQL itself
# only downcases ASCII letters of unquoted identifiers in multibyte encodings.
# Using these sets keeps non-ASCII characters verbatim, as the C code would.
# NOTE(review): assumes query_normalize.c uses <ctype.h>-style ASCII
# classification -- confirm against the C source.
_ASCII_DIGITS = frozenset("0123456789")
_ASCII_UPPER = frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
_C_WHITESPACE = frozenset(" \t\n\v\f\r")
# Characters that continue a numeric literal once its first digit was seen.
_NUMBER_CONT = _ASCII_DIGITS | frozenset(".eE+-")


class _State(Enum):
    """Scanner states of the normalize_query state machine."""

    DEFAULT = auto()  # ordinary SQL text
    IN_SINGLE_QUOTE = auto()  # inside a '...' string literal (dropped)
    IN_DOUBLE_QUOTE = auto()  # inside a "..." quoted identifier (copied)
    IN_NUMBER = auto()  # rest of a numeric literal (dropped)
    IN_LINE_COMMENT = auto()  # after --, up to and including the newline
    IN_BLOCK_COMMENT = auto()  # inside /* ... */, nesting is tracked


def normalize_query(query: str) -> str:
    """
    Normalize a SQL query string into the form used for cache keys.

    - Replace each string literal ('...') with the next positional
      parameter ($1, $2, ...); the literal's contents are dropped.
    - Replace each numeric literal with the next positional parameter.
    - Collapse runs of whitespace to a single space; leading and trailing
      whitespace is dropped.
    - Lowercase ASCII letters outside quotes (keywords and unquoted
      identifiers); non-ASCII characters are kept unchanged.
    - Copy double-quoted identifiers verbatim (case preserved, "" kept).
    - Strip -- line comments and /* */ block comments (nesting supported).

    This mirrors the C implementation in query_normalize.c, so the scanner
    is deliberately simple. Known quirks, kept for key compatibility:

    NOTE(review): confirm each of these against query_normalize.c before
    changing anything -- diverging from C silently breaks cache-key matching.

    - Digits are parameterized wherever they appear outside quotes, even
      inside an unquoted identifier ("t1" -> "t$1"); an existing "$1"
      placeholder becomes "$$N".
    - After the first digit, every following digit, '.', 'e', 'E', '+' or
      '-' belongs to the number, so "1+2" -> "$1" and "1--x" is not treated
      as a comment.
    - Comments emit no separator and a line comment also consumes its
      newline, so "a--x\\nb" -> "ab" and "a/**/b" -> "ab".
    - Inside '...' only a doubled '' is an escape; backslash escapes
      (E'...') and dollar-quoted strings ($$...$$) are not recognised.
    - An unterminated literal, quoted identifier or comment runs to the end
      of the input.

    Args:
        query: Raw SQL text.

    Returns:
        The normalized query, or "" for empty input.
    """
    if not query:
        return ""

    buf: list[str] = []
    param_num = 0
    i = 0
    n = len(query)
    state = _State.DEFAULT
    last_was_space = False
    block_comment_depth = 0

    while i < n:
        c = query[i]

        if state == _State.DEFAULT:
            # Line comment --
            if c == "-" and i + 1 < n and query[i + 1] == "-":
                state = _State.IN_LINE_COMMENT
                i += 2
                continue

            # Block comment /*
            if c == "/" and i + 1 < n and query[i + 1] == "*":
                state = _State.IN_BLOCK_COMMENT
                block_comment_depth = 1
                i += 2
                continue

            # String literal: emit the placeholder now, skip the body later.
            if c == "'":
                param_num += 1
                buf.append(f"${param_num}")
                state = _State.IN_SINGLE_QUOTE
                i += 1
                last_was_space = False
                continue

            # Quoted identifier: copied verbatim, including the quotes.
            if c == '"':
                state = _State.IN_DOUBLE_QUOTE
                buf.append(c)
                i += 1
                last_was_space = False
                continue

            # Numeric literal, including a leading-dot form such as ".5".
            if c in _ASCII_DIGITS or (
                c == "." and i + 1 < n and query[i + 1] in _ASCII_DIGITS
            ):
                param_num += 1
                buf.append(f"${param_num}")
                state = _State.IN_NUMBER
                i += 1
                last_was_space = False
                continue

            # Whitespace: emit at most one space, and none before any output.
            if c in _C_WHITESPACE:
                if not last_was_space and buf:
                    buf.append(" ")
                    last_was_space = True
                i += 1
                continue

            # Regular character: only ASCII letters are case-folded.
            buf.append(c.lower() if c in _ASCII_UPPER else c)
            last_was_space = False
            i += 1

        elif state == _State.IN_SINGLE_QUOTE:
            if c == "'" and i + 1 < n and query[i + 1] == "'":
                i += 2  # escaped quote ''
            elif c == "'":
                state = _State.DEFAULT
                i += 1
            else:
                i += 1

        elif state == _State.IN_DOUBLE_QUOTE:
            buf.append(c)
            if c == '"' and not (i + 1 < n and query[i + 1] == '"'):
                state = _State.DEFAULT
            elif c == '"':
                # Escaped "" inside the identifier: copy the second quote too.
                i += 1
                if i < n:
                    buf.append(query[i])
            i += 1

        elif state == _State.IN_NUMBER:
            if c in _NUMBER_CONT:
                i += 1
            else:
                # Do not advance: re-scan this character in DEFAULT state.
                state = _State.DEFAULT

        elif state == _State.IN_LINE_COMMENT:
            if c == "\n":
                state = _State.DEFAULT
            i += 1

        elif state == _State.IN_BLOCK_COMMENT:
            if c == "/" and i + 1 < n and query[i + 1] == "*":
                block_comment_depth += 1
                i += 2
            elif c == "*" and i + 1 < n and query[i + 1] == "/":
                block_comment_depth -= 1
                i += 2
                if block_comment_depth == 0:
                    state = _State.DEFAULT
            else:
                i += 1

    return "".join(buf).rstrip(" ")


def compute_query_hash(normalized_query: str) -> str:
    """
    Compute the SHA-256 cache key of an already-normalized query.

    Args:
        normalized_query: Output of ``normalize_query``.

    Returns:
        64 lowercase hex characters: the SHA-256 digest of the UTF-8 encoded
        text, or 64 '0' characters for an empty query, matching
        compute_query_hash() in C.
    """
    if not normalized_query:
        # Sentinel key for empty input rather than sha256(b"").
        return "0" * 64
    return hashlib.sha256(normalized_query.encode("utf-8")).hexdigest()