# phishnet-api / api.py
# Commit dd272ec by IRBXrocket: "Add PhishNet API"
"""
PhishNet — FastAPI Backend
Serves phishing detection as a REST API.
Loads the BERT model from a local directory when present; otherwise it is
fetched from the HuggingFace Hub (downloaded on first startup).
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import joblib
import numpy as np
import pandas as pd
import re
import os
import time
from urllib.parse import urlparse
app = FastAPI(title="PhishNet API", version="1.0.0")
# ── CORS — allow Chrome extension to call this API ────────────────────────────
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable and default to "*" (any origin), so the extension works out of the
# box. Set it (e.g. to "chrome-extension://<id>") to restrict callers.
# Blank entries are ignored; a value with no usable entries falls back to "*".
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ── DEVICE ────────────────────────────────────────────────────────────────────
# Inference runs on CPU only.
device = torch.device("cpu")
# ── MODEL LOADING ─────────────────────────────────────────────────────────────
# BERT (tokenizer + classifier) is loaded from the local directory MODEL_CACHE
# when it exists, otherwise from the HuggingFace Hub repo HF_MODEL_ID.
# Both can be overridden through environment variables of the same name.
# The .joblib files are loaded from disk (included in the repo), relative to
# the current working directory.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "YOUR_HF_USERNAME/phishnet-bert")
MODEL_CACHE = os.getenv("MODEL_CACHE", "./bert_phishing_5k_benchmark")
# Resolve the source once so the tokenizer and the model are guaranteed to be
# loaded from the same place.
_model_source = MODEL_CACHE if os.path.isdir(MODEL_CACHE) else HF_MODEL_ID
print(f"Loading BERT from: {_model_source}")
tokenizer = AutoTokenizer.from_pretrained(_model_source)
dl_model = AutoModelForSequenceClassification.from_pretrained(_model_source).to(device)
# Switch to inference mode (disables dropout).
dl_model.eval()
# NOTE(security): joblib.load unpickles arbitrary objects; only point these
# paths at trusted, repo-controlled files.
rf_model = joblib.load("random_forest_model.joblib")
pca = joblib.load("pca_compressor.joblib")
print("✅ All models loaded")
# ── ATTACK TIPS ───────────────────────────────────────────────────────────────
# Lexical feature name -> (tip title, tip body). Both strings are returned to
# the client in FeatureResult.tip_title / tip_body for suspicious features.
ATTACK_TIPS = {
    "url_length"         : ("Long URL Obfuscation", "Attackers pad URLs to hide the real domain."),
    "count_at"           : ("@ Symbol Trick", "Browsers ignore everything before @ in a URL."),
    "count_hyphen"       : ("Hyphen Stuffing", "Hyphens make fake domains look legitimate."),
    "count_double_slash" : ("Double-Slash Redirect", "Extra slashes bypass simple URL filters."),
    "count_percent"      : ("Percent Encoding", "Encoding hides the true URL from scanners."),
    "count_digits"       : ("Digit Substitution", "Replacing letters with digits — paypa1 = paypal."),
    "count_dots"         : ("Subdomain Abuse", "Extra dots create deep subdomains to fool users."),
    "digit_letter_ratio" : ("High Digit Density", "Legitimate domains rarely have many digits."),
    "has_ip"             : ("Raw IP Address", "Legitimate sites use domain names, not raw IPs."),
    "is_shortened"       : ("URL Shortener", "Shorteners hide the real malicious destination."),
}
# Display metadata per lexical feature: a label for the UI, and the threshold
# beyond which a value counts as suspicious. Every feature here flags values
# ABOVE its threshold (high_is_bad=True).
FEATURE_META = {
    name: {"label": label, "threshold": limit, "high_is_bad": True}
    for name, label, limit in (
        ("url_length",         "URL Length",         75),
        ("count_at",           "@ Symbol",           0),
        ("count_hyphen",       "Hyphens",            3),
        ("count_double_slash", "Double Slashes",     0),
        ("count_percent",      "% Encoding",         1),
        ("count_digits",       "Digit Count",        8),
        ("count_dots",         "Dot Count",          4),
        ("digit_letter_ratio", "Digit/Letter Ratio", 0.15),
        ("has_ip",             "Raw IP Address",     0),
        ("is_shortened",       "URL Shortener",      0),
    )
}
# ── LEXICAL EXTRACTOR ─────────────────────────────────────────────────────────
# Dotted-quad IPv4 literal (e.g. "192.168.0.1").
_IPV4_RE = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
# Known URL-shortener hosts. A URL matches when its hostname equals one of
# these or is a subdomain of one (so "microsoft.com" does NOT match "t.co").
_SHORTENER_DOMAINS = (
    "bit.ly", "qrco.de", "t.co", "tinyurl.com",
    "l.ead.me", "goo.gl", "ow.ly",
)


def extract_url_math(url: str) -> dict:
    """Compute the hand-crafted lexical features for *url*.

    Character counts are taken over the raw ``url`` string. Host-based
    features (``has_ip``, ``is_shortened``) use the hostname parsed from
    *url*, assuming ``http://`` when no http(s) scheme is present.

    NOTE(review): if the random forest was trained with an older version of
    this extractor (netloc substring matching), re-validate it against these
    host-based features.

    Args:
        url: URL to analyse; the scheme is optional.

    Returns:
        Dict with keys url_length, count_at, count_hyphen,
        count_double_slash, count_percent, count_digits, count_dots,
        digit_letter_ratio, has_ip (0/1) and is_shortened (0/1).
    """
    has_scheme = url.lower().startswith(("http://", "https://"))
    url_for_parse = url if has_scheme else "http://" + url
    try:
        # .hostname (unlike .netloc) drops any "user@" prefix and ":port"
        # suffix, strips IPv6 brackets and lowercases the host, so
        # "http://paypal.com@1.2.3.4:80/" yields the raw IP "1.2.3.4".
        host = urlparse(url_for_parse).hostname or ""
    except ValueError:
        # Malformed input (e.g. an unbalanced "[" IPv6 bracket): keep the
        # lexical counts but treat the host as unknown instead of failing.
        host = ""
    num_digits = sum(c.isdigit() for c in url)
    num_letters = sum(c.isalpha() for c in url)
    # After bracket stripping, ":" can only remain in an IPv6 literal.
    is_ip = bool(_IPV4_RE.match(host)) or ":" in host
    is_short = any(host == d or host.endswith("." + d) for d in _SHORTENER_DOMAINS)
    return {
        "url_length"         : len(url),
        "count_at"           : url.count("@"),
        "count_hyphen"       : url.count("-"),
        # One "//" belongs to the scheme; count on url_for_parse so a
        # scheme-less input is not under-counted by one.
        "count_double_slash" : max(0, url_for_parse.count("//") - 1),
        "count_percent"      : url.count("%"),
        "count_digits"       : num_digits,
        "count_dots"         : url.count("."),
        # With no letters, fall back to the raw digit count.
        "digit_letter_ratio" : num_digits / num_letters if num_letters > 0 else num_digits,
        "has_ip"             : 1 if is_ip else 0,
        "is_shortened"       : 1 if is_short else 0,
    }
# ── REQUEST / RESPONSE MODELS ─────────────────────────────────────────────────
class ScanRequest(BaseModel):
    """JSON request body for ``POST /scan``."""
    url: str  # URL to check; scan() strips whitespace and prepends "https://" when no http(s) scheme is given
class FeatureResult(BaseModel):
    """One row of the per-feature explanation returned by ``POST /scan``.

    Built either from a lexical URL feature or from the BERT "semantic
    pattern" fallback; the tip fields are filled only for suspicious rows.
    """
    label: str  # human-readable feature name shown by the client
    value: float  # raw feature value (the BERT score for the semantic row)
    suspicious: bool  # True when the value crossed the feature's threshold
    tip_title: str  # attack-pattern title; "" when not suspicious
    tip_body: str  # short explanation of the attack; "" when not suspicious
class ScanResponse(BaseModel):
    """JSON response body for ``POST /scan``; scores are rounded to 4 decimals."""
    url: str  # the normalized URL that was actually scanned
    verdict: str  # "phishing" | "safe" (threat_score > 0.5 means phishing)
    threat_score: float  # 0.0 – 1.0; mean of bert_score and rf_score
    bert_score: float  # BERT probability at class index 1 (presumably "phishing" — confirm)
    rf_score: float  # random-forest probability at class index 1
    models_agree: bool  # whether BERT and RF fall on the same side of 0.5
    triggered_count: int  # number of suspicious rows in `features`
    features: list[FeatureResult]  # per-feature breakdown, plus optional BERT semantic row
    inference_ms: float  # time spent in the scan handler, in milliseconds
# ── ENDPOINTS ─────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Landing endpoint: confirm the API is running and report its version.

    The version is read from the FastAPI app object so it always matches the
    value declared where ``app`` is created.
    """
    return {"status": "PhishNet API running", "version": app.version}
@app.get("/health")
def health():
    """Liveness probe: answers ``{"status": "ok"}`` whenever the app is serving."""
    return dict(status="ok")
# BERT probability above which the explanation adds a "semantic pattern" row
# when no lexical feature fired.
_BERT_EXPLAIN_THRESHOLD = 0.7


def _bert_predict(url: str) -> tuple[float, np.ndarray]:
    """Run the BERT classifier on *url*.

    Returns:
        ``(bert_score, cls_emb)``: ``bert_score`` is the softmax probability
        at class index 1 (presumably "phishing" — confirm against the
        fine-tuning labels); ``cls_emb`` is the last-layer hidden state of
        token 0 for the single input, as a NumPy array with one row.
    """
    inputs = tokenizer(
        url, return_tensors="pt",
        truncation=True, padding=True, max_length=128,
    ).to(device)
    with torch.no_grad():
        outputs = dl_model(**inputs, output_hidden_states=True)
    # Batch of one: row 0, class index 1.
    bert_score = float(F.softmax(outputs.logits, dim=1)[0, 1])
    cls_emb = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
    return bert_score, cls_emb


def _rf_predict(cls_emb: np.ndarray, lex: dict) -> float:
    """Score the hybrid model: [CLS] embedding -> PCA, plus lexical features -> RF.

    Returns the random forest's probability at class index 1.
    """
    pca_feats = pca.transform(cls_emb)[0]
    # One column per PCA component actually produced (pca_feature_0, ...),
    # so a PCA with a different number of components does not IndexError.
    feat_dict = {f"pca_feature_{i}": value for i, value in enumerate(pca_feats)}
    feat_dict.update(lex)
    # Select and order the columns exactly as the forest was fitted.
    df_aligned = pd.DataFrame([feat_dict])[rf_model.feature_names_in_]
    return float(rf_model.predict_proba(df_aligned)[0][1])


def _explain_features(lex: dict, bert_score: float) -> tuple[list[FeatureResult], int]:
    """Build the per-feature breakdown shown by the client.

    Returns ``(features, triggered)``: one FeatureResult per FEATURE_META
    entry, plus a BERT "semantic pattern" row when no lexical feature is
    suspicious but ``bert_score`` exceeds _BERT_EXPLAIN_THRESHOLD.
    ``triggered`` is the number of suspicious rows.
    """
    features = []
    triggered = 0
    for key, meta in FEATURE_META.items():
        val = lex[key]
        is_sus = val > meta["threshold"] if meta["high_is_bad"] else val < meta["threshold"]
        if is_sus:
            triggered += 1
        tip_title, tip_body = ATTACK_TIPS.get(key, ("", ""))
        features.append(FeatureResult(
            label=meta["label"],
            value=float(val),
            suspicious=is_sus,
            tip_title=tip_title if is_sus else "",
            tip_body=tip_body if is_sus else "",
        ))
    # If no lexical feature fired but BERT is highly confident, add a
    # semantic trigger so the extension still shows an explanation.
    if triggered == 0 and bert_score > _BERT_EXPLAIN_THRESHOLD:
        features.append(FeatureResult(
            label="BERT Semantic Pattern",
            value=float(bert_score),
            suspicious=True,
            tip_title="Semantic Phishing Pattern",
            tip_body="BERT detected token patterns in this URL that "
                     "strongly match known phishing page structures, "
                     "even though individual lexical features appear normal.",
        ))
        triggered = 1
    return features, triggered


@app.post("/scan", response_model=ScanResponse)
def scan(req: ScanRequest):
    """Classify ``req.url`` as phishing or safe.

    The URL is stripped and given an ``https://`` scheme when it has no
    http(s) scheme. It is then scored by BERT and by the hybrid
    PCA + random-forest model. The verdict is "phishing" when the mean of the
    two scores exceeds 0.5. The response also carries a per-feature
    explanation and the handler's elapsed time.
    """
    # perf_counter is monotonic, unlike time.time(), so it is safe for durations.
    t0 = time.perf_counter()
    url = req.url.strip()
    # Case-insensitive, so "HTTPS://example.com" is not prefixed a second time.
    if not url.lower().startswith(("http://", "https://")):
        url = "https://" + url

    bert_score, cls_emb = _bert_predict(url)
    lex = extract_url_math(url)
    rf_score = _rf_predict(cls_emb, lex)

    # Primary verdict: average of both models.
    threat_score = (bert_score + rf_score) / 2
    verdict = "phishing" if threat_score > 0.5 else "safe"
    models_agree = (bert_score > 0.5) == (rf_score > 0.5)

    features, triggered = _explain_features(lex, bert_score)

    elapsed = (time.perf_counter() - t0) * 1000
    return ScanResponse(
        url=url,
        verdict=verdict,
        threat_score=round(threat_score, 4),
        bert_score=round(bert_score, 4),
        rf_score=round(rf_score, 4),
        models_agree=models_agree,
        triggered_count=triggered,
        features=features,
        inference_ms=round(elapsed, 1),
    )