Spaces:
Sleeping
Sleeping
File size: 10,474 Bytes
"""
PhishNet — FastAPI Backend
Serves phishing detection as a REST API.
Downloads BERT from HuggingFace on first startup.
"""
import os
import re
import time
from urllib.parse import urlparse

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ── APP + CORS ────────────────────────────────────────────────────────────────
# The Chrome extension calls this API, so requests from any origin are accepted.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["GET", "POST"],
    "allow_headers": ["*"],
}

app = FastAPI(title="PhishNet API", version="1.0.0")
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# ── DEVICE ────────────────────────────────────────────────────────────────────
device = torch.device("cpu")

# ── MODEL LOADING ─────────────────────────────────────────────────────────────
# BERT is loaded from the local MODEL_CACHE directory when it exists; otherwise
# the HF_MODEL_ID repository id is handed to from_pretrained, which fetches it
# from the HuggingFace Hub. The .joblib artifacts are loaded from disk using
# paths relative to the current working directory.
# NOTE(review): the default repo id below is a placeholder — set HF_MODEL_ID in
# the environment unless the local MODEL_CACHE directory is shipped.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "YOUR_HF_USERNAME/phishnet-bert")
MODEL_CACHE = "./bert_phishing_5k_benchmark"

# Resolve the source exactly once so the log line, the tokenizer and the model
# can never disagree about where BERT came from.
_bert_source = MODEL_CACHE if os.path.isdir(MODEL_CACHE) else HF_MODEL_ID
print(f"Loading BERT from: {_bert_source}")

tokenizer = AutoTokenizer.from_pretrained(_bert_source)
dl_model = AutoModelForSequenceClassification.from_pretrained(_bert_source).to(device)
dl_model.eval()  # inference only: switch the module out of training mode

rf_model = joblib.load("random_forest_model.joblib")
pca = joblib.load("pca_compressor.joblib")
print("✅ All models loaded")
# ── ATTACK TIPS ───────────────────────────────────────────────────────────────
# Maps each lexical feature key to a (title, explanation) pair that is shown to
# the user when that feature is flagged as suspicious.
ATTACK_TIPS = {
    "url_length": ("Long URL Obfuscation", "Attackers pad URLs to hide the real domain."),
    "count_at": ("@ Symbol Trick", "Browsers ignore everything before @ in a URL."),
    "count_hyphen": ("Hyphen Stuffing", "Hyphens make fake domains look legitimate."),
    "count_double_slash": ("Double-Slash Redirect", "Extra slashes bypass simple URL filters."),
    "count_percent": ("Percent Encoding", "Encoding hides the true URL from scanners."),
    # The dash in this tip had been corrupted to a stray Greek letter.
    "count_digits": ("Digit Substitution", "Replacing letters with digits — paypa1 = paypal."),
    "count_dots": ("Subdomain Abuse", "Extra dots create deep subdomains to fool users."),
    "digit_letter_ratio": ("High Digit Density", "Legitimate domains rarely have many digits."),
    "has_ip": ("Raw IP Address", "Legitimate sites use domain names, not raw IPs."),
    "is_shortened": ("URL Shortener", "Shorteners hide the real malicious destination."),
}
# Display label and flagging threshold for every lexical feature. Each entry is
# (feature key, label, threshold); all features are marked "high_is_bad".
_FEATURE_ROWS = (
    ("url_length", "URL Length", 75),
    ("count_at", "@ Symbol", 0),
    ("count_hyphen", "Hyphens", 3),
    ("count_double_slash", "Double Slashes", 0),
    ("count_percent", "% Encoding", 1),
    ("count_digits", "Digit Count", 8),
    ("count_dots", "Dot Count", 4),
    ("digit_letter_ratio", "Digit/Letter Ratio", 0.15),
    ("has_ip", "Raw IP Address", 0),
    ("is_shortened", "URL Shortener", 0),
)

FEATURE_META = {
    key: {"label": label, "threshold": threshold, "high_is_bad": True}
    for key, label, threshold in _FEATURE_ROWS
}
# ── LEXICAL EXTRACTOR ─────────────────────────────────────────────────────────
# Known URL-shortener hosts, matched against the whole host or a parent domain.
_SHORTENER_DOMAINS = (
    "bit.ly", "qrco.de", "t.co", "tinyurl.com",
    "l.ead.me", "goo.gl", "ow.ly",
)
_IPV4_HOST_RE = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}")


def extract_url_math(url: str) -> dict:
    """Compute the lexical features of ``url``.

    Character counts are taken over the string exactly as passed in. The two
    host-based flags are derived from the parsed hostname (lower-cased, without
    userinfo or port), so ``http://1.2.3.4:8080/`` counts as a raw IP and
    ``microsoft.com`` is not mistaken for the shortener ``t.co``.

    NOTE(review): these values are also fed to the random forest. If it was
    trained with the previous substring/netloc rules, re-validate it against
    this extractor.

    Args:
        url: The URL to analyse; a scheme is optional.

    Returns:
        A dict with the keys ``url_length``, ``count_at``, ``count_hyphen``,
        ``count_double_slash``, ``count_percent``, ``count_digits``,
        ``count_dots``, ``digit_letter_ratio``, ``has_ip`` and ``is_shortened``.
    """
    # urlparse only fills in the network location when a scheme is present.
    has_scheme = url.lower().startswith(("http://", "https://"))
    url_for_parse = url if has_scheme else "http://" + url
    try:
        host = urlparse(url_for_parse).hostname or ""
    except ValueError:
        # Malformed bracketed hosts (e.g. "http://[bad") cannot be parsed;
        # the host-based flags are then simply left unset.
        host = ""

    num_digits = sum(c.isdigit() for c in url)
    num_letters = sum(c.isalpha() for c in url)

    is_short = any(
        host == shortener or host.endswith("." + shortener)
        for shortener in _SHORTENER_DOMAINS
    )

    return {
        "url_length": len(url),
        "count_at": url.count("@"),
        "count_hyphen": url.count("-"),
        # The "//" that follows the scheme is expected, so it is not counted.
        "count_double_slash": max(0, url.count("//") - 1),
        "count_percent": url.count("%"),
        "count_digits": num_digits,
        "count_dots": url.count("."),
        # With no letters at all, fall back to the raw digit count.
        "digit_letter_ratio": num_digits / num_letters if num_letters > 0 else num_digits,
        "has_ip": 1 if _IPV4_HOST_RE.fullmatch(host) else 0,
        "is_shortened": 1 if is_short else 0,
    }
# ── REQUEST / RESPONSE MODELS ─────────────────────────────────────────────────
class ScanRequest(BaseModel):
    """Request body accepted by POST /scan."""

    # URL to analyse; a scheme is optional (the /scan handler adds "https://"
    # when none is present).
    url: str
class FeatureResult(BaseModel):
    """One row of the per-feature breakdown included in a scan response."""

    label: str  # display name (a FEATURE_META label, or "BERT Semantic Pattern")
    value: float  # feature value as a float (the BERT score for the semantic row)
    suspicious: bool  # True when the value crossed the feature's threshold
    tip_title: str  # short attack name; empty string when not suspicious
    tip_body: str  # one-sentence explanation; empty string when not suspicious
class ScanResponse(BaseModel):
    """Response body returned by POST /scan."""

    url: str  # the URL that was actually scored (after normalisation)
    verdict: str # "phishing" | "safe"
    threat_score: float # 0.0 – 1.0; mean of bert_score and rf_score
    bert_score: float  # BERT softmax probability at class index 1
    rf_score: float  # random-forest predict_proba value at class index 1
    models_agree: bool  # True when both scores fall on the same side of 0.5
    triggered_count: int  # number of features reported as suspicious
    features: list[FeatureResult]  # per-feature breakdown, in FEATURE_META order
    inference_ms: float  # time spent inside the handler, in milliseconds
# ── ENDPOINTS ─────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Landing route: report that the API is running and which version it is."""
    info = {"status": "PhishNet API running", "version": "1.0.0"}
    return info
@app.get("/health")
def health():
    """Health-check route: always answers with a fixed ok payload."""
    return dict(status="ok")
@app.post("/scan", response_model=ScanResponse)
def scan(req: ScanRequest):
    """Score a URL with BERT and the PCA + random-forest hybrid and explain the result.

    Steps:
      1. Normalise the URL (trim whitespace, ensure an http/https scheme).
      2. Run BERT; the softmax probability at class index 1 is the BERT score.
      3. Project the last layer's first-token embedding through PCA, join it
         with the lexical features and score that row with the random forest.
      4. Average both scores into ``threat_score``; above 0.5 means "phishing".
      5. Build the per-feature breakdown shown to the user.

    Args:
        req: Request body carrying the URL to analyse.

    Returns:
        A ScanResponse with the verdict, both model scores, the feature
        breakdown and the handler latency in milliseconds.

    Raises:
        HTTPException: 422 when the URL is empty or only whitespace.
    """
    # perf_counter is monotonic, so the latency cannot go negative or jump
    # when the system clock is adjusted.
    t0 = time.perf_counter()

    url = req.url.strip()
    if not url:
        # Without this guard the bare string "https://" would be scored.
        raise HTTPException(status_code=422, detail="url must not be empty")

    # Detect the scheme case-insensitively so "HTTP://x" is not turned into
    # "https://HTTP://x"; an existing scheme is canonicalised to lower case.
    if url.lower().startswith(("http://", "https://")):
        scheme, _, rest = url.partition("://")
        url = scheme.lower() + "://" + rest
    else:
        url = "https://" + url

    # ── Tokenise ──────────────────────────────────────────────────────────────
    inputs = tokenizer(
        url,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    # ── BERT inference ────────────────────────────────────────────────────────
    with torch.no_grad():
        outputs = dl_model(**inputs, output_hidden_states=True)
    bert_probs = F.softmax(outputs.logits, dim=1).squeeze().tolist()
    bert_score = float(bert_probs[1])

    # ── Hybrid: CLS → PCA → RF ────────────────────────────────────────────────
    cls_emb = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
    pca_row = pca.transform(cls_emb)[0]
    lex = extract_url_math(url)
    # Name one column per PCA component actually produced instead of assuming
    # exactly 20; the forest's own column list selects what it needs below.
    feat_dict = {f"pca_feature_{i}": value for i, value in enumerate(pca_row)}
    feat_dict.update(lex)
    df = pd.DataFrame([feat_dict])
    df_aligned = df[rf_model.feature_names_in_]
    rf_proba = rf_model.predict_proba(df_aligned)[0]
    rf_score = float(rf_proba[1])

    # ── Primary verdict: average of both ──────────────────────────────────────
    threat_score = (bert_score + rf_score) / 2
    verdict = "phishing" if threat_score > 0.5 else "safe"
    models_agree = (bert_score > 0.5) == (rf_score > 0.5)
    bert_triggered = bert_score > 0.7

    # ── Feature breakdown ─────────────────────────────────────────────────────
    features = []
    triggered = 0
    for key, meta in FEATURE_META.items():
        val = lex[key]
        is_sus = val > meta["threshold"] if meta["high_is_bad"] else val < meta["threshold"]
        if is_sus:
            triggered += 1
        tip_title, tip_body = ATTACK_TIPS.get(key, ("", ""))
        features.append(FeatureResult(
            label=meta["label"],
            value=float(val),
            suspicious=is_sus,
            # Tips are only sent for features that actually fired.
            tip_title=tip_title if is_sus else "",
            tip_body=tip_body if is_sus else "",
        ))

    # If no lexical features fired but BERT is highly confident,
    # add a semantic trigger so the extension shows an explanation.
    if triggered == 0 and bert_triggered:
        features.append(FeatureResult(
            label="BERT Semantic Pattern",
            value=float(bert_score),
            suspicious=True,
            tip_title="Semantic Phishing Pattern",
            tip_body="BERT detected token patterns in this URL that "
                     "strongly match known phishing page structures, "
                     "even though individual lexical features appear normal.",
        ))
        triggered = 1

    elapsed = (time.perf_counter() - t0) * 1000
    return ScanResponse(
        url=url,
        verdict=verdict,
        threat_score=round(threat_score, 4),
        bert_score=round(bert_score, 4),
        rf_score=round(rf_score, 4),
        models_agree=models_agree,
        triggered_count=triggered,
        features=features,
        inference_ms=round(elapsed, 1),
    )