IRBXrocket commited on
Commit
dd272ec
Β·
1 Parent(s): ee68788

Add PhishNet API

Browse files
Files changed (5) hide show
  1. Dockerfile +23 -0
  2. api.py +215 -0
  3. pca_compressor.joblib +3 -0
  4. random_forest_model.joblib +3 -0
  5. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ gcc \
8
+ g++ \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements first for layer caching
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy app files
16
+ COPY api.py .
17
+ COPY random_forest_model.joblib .
18
+ COPY pca_compressor.joblib .
19
+
20
+ # HuggingFace Spaces runs on port 7860
21
+ EXPOSE 7860
22
+
23
+ CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
api.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PhishNet β€” FastAPI Backend
3
+ Serves phishing detection as a REST API.
4
+ Downloads BERT from HuggingFace on first startup.
5
+ """
6
+
7
+ from fastapi import FastAPI
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
+ import joblib
14
+ import numpy as np
15
+ import pandas as pd
16
+ import re
17
+ import os
18
+ import time
19
+ from urllib.parse import urlparse
20
+
21
+ app = FastAPI(title="PhishNet API", version="1.0.0")
22
+
23
+ # ── CORS β€” allow Chrome extension to call this API ────────────────────────────
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"], # extension can call from any origin
27
+ allow_methods=["GET", "POST"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # ── DEVICE ────────────────────────────────────────────────────────────────────
32
+ device = torch.device("cpu")
33
+
34
+ # ── MODEL LOADING ─────────────────────────────────────────────────────────────
35
+ # BERT is downloaded from HuggingFace Hub at startup if not cached locally.
36
+ # .joblib files are loaded from disk (included in the repo).
37
+
38
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "YOUR_HF_USERNAME/phishnet-bert")
39
+ MODEL_CACHE = "./bert_phishing_5k_benchmark"
40
+
41
+ print(f"Loading BERT from: {HF_MODEL_ID if not os.path.isdir(MODEL_CACHE) else MODEL_CACHE}")
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(
44
+ MODEL_CACHE if os.path.isdir(MODEL_CACHE) else HF_MODEL_ID
45
+ )
46
+ dl_model = AutoModelForSequenceClassification.from_pretrained(
47
+ MODEL_CACHE if os.path.isdir(MODEL_CACHE) else HF_MODEL_ID
48
+ ).to(device)
49
+ dl_model.eval()
50
+
51
+ rf_model = joblib.load("random_forest_model.joblib")
52
+ pca = joblib.load("pca_compressor.joblib")
53
+
54
+ print("βœ… All models loaded")
55
+
56
+ # ── ATTACK TIPS ───────────────────────────────────────────────────────────────
57
+ ATTACK_TIPS = {
58
+ "url_length" : ("Long URL Obfuscation", "Attackers pad URLs to hide the real domain."),
59
+ "count_at" : ("@ Symbol Trick", "Browsers ignore everything before @ in a URL."),
60
+ "count_hyphen" : ("Hyphen Stuffing", "Hyphens make fake domains look legitimate."),
61
+ "count_double_slash" : ("Double-Slash Redirect", "Extra slashes bypass simple URL filters."),
62
+ "count_percent" : ("Percent Encoding", "Encoding hides the true URL from scanners."),
63
+ "count_digits" : ("Digit Substitution", "Replacing letters with digits β€” paypa1 = paypal."),
64
+ "count_dots" : ("Subdomain Abuse", "Extra dots create deep subdomains to fool users."),
65
+ "digit_letter_ratio" : ("High Digit Density", "Legitimate domains rarely have many digits."),
66
+ "has_ip" : ("Raw IP Address", "Legitimate sites use domain names, not raw IPs."),
67
+ "is_shortened" : ("URL Shortener", "Shorteners hide the real malicious destination."),
68
+ }
69
+
70
+ FEATURE_META = {
71
+ "url_length" : {"label": "URL Length", "threshold": 75, "high_is_bad": True},
72
+ "count_at" : {"label": "@ Symbol", "threshold": 0, "high_is_bad": True},
73
+ "count_hyphen" : {"label": "Hyphens", "threshold": 3, "high_is_bad": True},
74
+ "count_double_slash" : {"label": "Double Slashes", "threshold": 0, "high_is_bad": True},
75
+ "count_percent" : {"label": "% Encoding", "threshold": 1, "high_is_bad": True},
76
+ "count_digits" : {"label": "Digit Count", "threshold": 8, "high_is_bad": True},
77
+ "count_dots" : {"label": "Dot Count", "threshold": 4, "high_is_bad": True},
78
+ "digit_letter_ratio" : {"label": "Digit/Letter Ratio", "threshold": 0.15, "high_is_bad": True},
79
+ "has_ip" : {"label": "Raw IP Address", "threshold": 0, "high_is_bad": True},
80
+ "is_shortened" : {"label": "URL Shortener", "threshold": 0, "high_is_bad": True},
81
+ }
82
+
83
+ # ── LEXICAL EXTRACTOR ─────────────────────────────────────────────────────────
84
+ def extract_url_math(url: str) -> dict:
85
+ url_for_parse = url if url.startswith(("http://", "https://")) else "http://" + url
86
+ domain = urlparse(url_for_parse).netloc
87
+ num_digits = sum(c.isdigit() for c in url)
88
+ num_letters = sum(c.isalpha() for c in url)
89
+ return {
90
+ "url_length" : len(url),
91
+ "count_at" : url.count("@"),
92
+ "count_hyphen" : url.count("-"),
93
+ "count_double_slash" : max(0, url.count("//") - 1),
94
+ "count_percent" : url.count("%"),
95
+ "count_digits" : num_digits,
96
+ "count_dots" : url.count("."),
97
+ "digit_letter_ratio" : num_digits / num_letters if num_letters > 0 else num_digits,
98
+ "has_ip" : 1 if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", domain) else 0,
99
+ "is_shortened" : 1 if any(s in domain for s in [
100
+ "bit.ly", "qrco.de", "t.co", "tinyurl.com",
101
+ "l.ead.me", "goo.gl", "ow.ly"]) else 0,
102
+ }
103
+
104
+ # ── REQUEST / RESPONSE MODELS ─────────────────────────────────────────────────
105
+ class ScanRequest(BaseModel):
106
+ url: str
107
+
108
+ class FeatureResult(BaseModel):
109
+ label: str
110
+ value: float
111
+ suspicious: bool
112
+ tip_title: str
113
+ tip_body: str
114
+
115
+ class ScanResponse(BaseModel):
116
+ url: str
117
+ verdict: str # "phishing" | "safe"
118
+ threat_score: float # 0.0 – 1.0
119
+ bert_score: float
120
+ rf_score: float
121
+ models_agree: bool
122
+ triggered_count: int
123
+ features: list[FeatureResult]
124
+ inference_ms: float
125
+
126
+ # ── ENDPOINTS ─────────────────────────────────────────────────────────────────
127
+ @app.get("/")
128
+ def root():
129
+ return {"status": "PhishNet API running", "version": "1.0.0"}
130
+
131
+ @app.get("/health")
132
+ def health():
133
+ return {"status": "ok"}
134
+
135
+ @app.post("/scan", response_model=ScanResponse)
136
+ def scan(req: ScanRequest):
137
+ t0 = time.time()
138
+ url = req.url.strip()
139
+ if not url.startswith(("http://", "https://")):
140
+ url = "https://" + url
141
+
142
+ # ── Tokenise ──────────────────────────────────────────────────────────────
143
+ inputs = tokenizer(
144
+ url, return_tensors="pt",
145
+ truncation=True, padding=True, max_length=128
146
+ ).to(device)
147
+
148
+ # ── BERT inference ────────────────────────────────────────────────────────
149
+ with torch.no_grad():
150
+ outputs = dl_model(**inputs, output_hidden_states=True)
151
+
152
+ bert_probs = F.softmax(outputs.logits, dim=1).squeeze().tolist()
153
+ bert_score = float(bert_probs[1])
154
+
155
+ # ── Hybrid: CLS β†’ PCA β†’ RF ────────────────────────────────────────────────
156
+ cls_emb = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
157
+ pca_feats = pca.transform(cls_emb)
158
+ lex = extract_url_math(url)
159
+ feat_dict = {f"pca_feature_{i}": pca_feats[0][i] for i in range(20)}
160
+ feat_dict.update(lex)
161
+ df = pd.DataFrame([feat_dict])
162
+ df_aligned = df[rf_model.feature_names_in_]
163
+ rf_proba = rf_model.predict_proba(df_aligned)[0]
164
+ rf_score = float(rf_proba[1])
165
+
166
+ # ── Primary verdict: average of both ─────────────────────────────────────
167
+ threat_score = (bert_score + rf_score) / 2
168
+ verdict = "phishing" if threat_score > 0.5 else "safe"
169
+ models_agree = (bert_score > 0.5) == (rf_score > 0.5)
170
+ bert_triggered = bert_score > 0.7
171
+
172
+ # ── Feature breakdown ─────────────────────────────────────────────────────
173
+ features = []
174
+ triggered = 0
175
+ for key, meta in FEATURE_META.items():
176
+ val = lex[key]
177
+ is_sus = val > meta["threshold"] if meta["high_is_bad"] else val < meta["threshold"]
178
+ if is_sus:
179
+ triggered += 1
180
+ tip_title, tip_body = ATTACK_TIPS.get(key, ("", ""))
181
+ features.append(FeatureResult(
182
+ label = meta["label"],
183
+ value = float(val),
184
+ suspicious = is_sus,
185
+ tip_title = tip_title if is_sus else "",
186
+ tip_body = tip_body if is_sus else "",
187
+ ))
188
+
189
+ # If no lexical features fired but BERT is highly confident,
190
+ # add a semantic trigger so the extension shows an explanation
191
+ if triggered == 0 and bert_triggered:
192
+ features.append(FeatureResult(
193
+ label = "BERT Semantic Pattern",
194
+ value = float(bert_score),
195
+ suspicious = True,
196
+ tip_title = "Semantic Phishing Pattern",
197
+ tip_body = "BERT detected token patterns in this URL that "
198
+ "strongly match known phishing page structures, "
199
+ "even though individual lexical features appear normal.",
200
+ ))
201
+ triggered = 1
202
+
203
+ elapsed = (time.time() - t0) * 1000
204
+
205
+ return ScanResponse(
206
+ url = url,
207
+ verdict = verdict,
208
+ threat_score = round(threat_score, 4),
209
+ bert_score = round(bert_score, 4),
210
+ rf_score = round(rf_score, 4),
211
+ models_agree = models_agree,
212
+ triggered_count = triggered,
213
+ features = features,
214
+ inference_ms = round(elapsed, 1),
215
+ )
pca_compressor.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:498deb69c28308b9edb8c5be99660b2ea456b13472cefbb8594e5bf68fada6d1
3
+ size 130455
random_forest_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa4b5d4d2b86a24be8d1fe9acf0029b1d9a0c86db1b29012679032194e7ef073
3
+ size 1689065
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ joblib
6
+ pandas
7
+ numpy
8
+ scikit-learn
9
+ huggingface_hub
10
+ pydantic