File size: 4,362 Bytes
406cec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Inference API for pg_plan_cache models.

Loads trained models and provides prediction functions for:
  1. Cache benefit   (high / medium / low)
  2. Recommended TTL (seconds)
  3. Complexity score (1-100)
"""

import os
import json
import joblib
import numpy as np
from features import extract_features, FEATURE_NAMES

# Directory holding the serialized model artifacts: "trained/" next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "trained")

# Module-level model handles; they stay None until _load_models() fills them.
_cache_advisor = _ttl_recommender = _complexity_estimator = _label_encoder = None

# Flipped to True only after every artifact above has been loaded.
_loaded = False


def _load_models():
    """Populate the module-level model globals from MODEL_DIR on first use.

    Subsequent calls are no-ops once ``_loaded`` is True. The flag is set only
    after all four artifacts were read, so a failed load is retried on the
    next call.

    NOTE(review): joblib.load unpickles its input, so MODEL_DIR must only
    contain trusted files.
    """
    global _cache_advisor, _ttl_recommender, _complexity_estimator, _label_encoder, _loaded
    if not _loaded:

        def _from_disk(filename):
            # Resolve an artifact name inside MODEL_DIR and deserialize it.
            return joblib.load(os.path.join(MODEL_DIR, filename))

        _cache_advisor = _from_disk("cache_advisor.joblib")
        _ttl_recommender = _from_disk("ttl_recommender.joblib")
        _complexity_estimator = _from_disk("complexity_estimator.joblib")
        _label_encoder = _from_disk("label_encoder.joblib")
        _loaded = True


def predict(sql: str) -> dict:
    """
    Run all three models on a SQL query.

    The models are loaded lazily on the first call (see ``_load_models``).
    Features are extracted once with ``extract_features`` and fed to the
    cache advisor, the TTL recommender and the complexity estimator.

    Args:
        sql: The SQL text to analyse.

    Returns:
        {
            "query": str,
            "cache_benefit": "high" | "medium" | "low",
            "cache_benefit_probabilities": {"high": 0.8, "medium": 0.15, "low": 0.05},
            "recommended_ttl": int,          # seconds, never negative
            "ttl_human": str,                # e.g. "1h 0m" (whole minutes only)
            "complexity_score": int,          # clamped to 1-100
            "complexity_label": str,          # "simple" | "moderate" | "complex" | "very complex"
            "features": {name: value, ...},
        }
    """
    _load_models()

    features = extract_features(sql)
    # Wrap in a list so the models receive a one-sample batch.
    # NOTE(review): assumes extract_features returns a flat sequence aligned
    # with FEATURE_NAMES -- confirm against features.py.
    X = np.array([features])

    # Cache advisor
    benefit_idx = _cache_advisor.predict(X)[0]
    benefit_label = _label_encoder.inverse_transform([benefit_idx])[0]
    benefit_probs = _cache_advisor.predict_proba(X)[0]
    # predict_proba columns follow the classifier's ``classes_`` order, which
    # is not guaranteed to be 0..n-1 (e.g. when a class was absent from the
    # training data). Map through ``classes_`` so each probability gets the
    # right label; fall back to positional indices if the attribute is missing.
    class_ids = getattr(_cache_advisor, "classes_", None)
    if class_ids is None:
        class_ids = range(len(benefit_probs))
    class_labels = _label_encoder.inverse_transform(list(class_ids))
    prob_dict = {
        label: round(float(p), 4)
        for label, p in zip(class_labels, benefit_probs)
    }

    # TTL recommender
    ttl_raw = _ttl_recommender.predict(X)[0]
    # The regressor may predict a negative value; a TTL cannot be below zero.
    ttl = max(0, int(round(ttl_raw)))
    # Human-readable form drops leftover seconds (whole minutes only).
    hours, mins = divmod(ttl // 60, 60)
    ttl_human = f"{hours}h {mins}m" if hours else f"{mins}m"

    # Complexity estimator
    cplx_raw = _complexity_estimator.predict(X)[0]
    # Clamp the raw prediction into the documented 1-100 range.
    cplx = max(1, min(100, int(round(cplx_raw))))
    if cplx <= 20:
        cplx_label = "simple"
    elif cplx <= 45:
        cplx_label = "moderate"
    elif cplx <= 75:
        cplx_label = "complex"
    else:
        cplx_label = "very complex"

    return {
        "query": sql,
        "cache_benefit": benefit_label,
        "cache_benefit_probabilities": prob_dict,
        "recommended_ttl": ttl,
        "ttl_human": ttl_human,
        "complexity_score": cplx,
        "complexity_label": cplx_label,
        "features": dict(zip(FEATURE_NAMES, features)),
    }


def predict_batch(queries: list[str]) -> list[dict]:
    """Apply ``predict`` to each query in order and collect the result dicts."""
    return list(map(predict, queries))


def format_prediction(result: dict) -> str:
    """Render a prediction dict as a five-line, human-readable summary.

    The query is cut to its first 100 characters, with "..." appended when
    it was longer; the cache benefit label is upper-cased.
    """
    query = result["query"]
    suffix = "..." if len(query) > 100 else ""
    benefit = result["cache_benefit"].upper()
    probabilities = result["cache_benefit_probabilities"]
    ttl_seconds = result["recommended_ttl"]
    ttl_text = result["ttl_human"]
    score = result["complexity_score"]
    label = result["complexity_label"]

    return (
        f"  Query:       {query[:100]}{suffix}\n"
        f"  Cache Benefit: {benefit}\n"
        f"    Probabilities: {probabilities}\n"
        f"  Recommended TTL: {ttl_seconds}s ({ttl_text})\n"
        f"  Complexity:  {score}/100 ({label})"
    )


def get_model_info() -> dict:
    """Return model metadata read from ``metadata.json`` inside MODEL_DIR.

    Returns:
        The parsed JSON content of the metadata file, or a dict with a single
        "error" key when the file does not exist.
    """
    meta_path = os.path.join(MODEL_DIR, "metadata.json")
    try:
        # Read as UTF-8 explicitly instead of the platform's locale encoding,
        # and open directly rather than checking for existence first so there
        # is no window in which the file can vanish between check and read.
        with open(meta_path, encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return {"error": "metadata.json not found. Run train.py first."}


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import sys

    # Everything after the script name is treated as pieces of one SQL string.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python predict.py \"SELECT * FROM users WHERE id = 42\"")
        sys.exit(1)

    # Re-join the shell-split words so an unquoted query still works.
    print(format_prediction(predict(" ".join(cli_args))))