Spaces:
Sleeping
Sleeping
| """ | |
PhishNet — FastAPI Backend
Serves phishing detection as a REST API.
Downloads BERT from HuggingFace on first startup.
| """ | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| import re | |
| import os | |
| import time | |
| from urllib.parse import urlparse | |
# Application object that the middleware below is attached to.
app = FastAPI(title="PhishNet API", version="1.0.0")

# CORS: the Chrome extension calls this API from arbitrary page origins, so
# every origin and header is accepted; only GET and POST are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# Device: all inference in this module runs on the CPU.
device = torch.device("cpu")

# Model loading.
# The BERT checkpoint is read from the local MODEL_CACHE directory when it
# exists; otherwise the HuggingFace model id is handed to from_pretrained,
# which downloads it on first startup. The id can be overridden with the
# HF_MODEL_ID environment variable.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "YOUR_HF_USERNAME/phishnet-bert")
MODEL_CACHE = "./bert_phishing_5k_benchmark"

# Decide once where BERT comes from so the log line and both loaders agree.
_bert_source = MODEL_CACHE if os.path.isdir(MODEL_CACHE) else HF_MODEL_ID
print(f"Loading BERT from: {_bert_source}")

tokenizer = AutoTokenizer.from_pretrained(_bert_source)
dl_model = AutoModelForSequenceClassification.from_pretrained(_bert_source)
dl_model = dl_model.to(device)
dl_model.eval()  # inference mode: disables dropout and similar training layers

# The random forest and PCA transformer are loaded from .joblib files in the
# working directory. joblib files are pickle-based, so only load trusted ones.
rf_model = joblib.load("random_forest_model.joblib")
pca = joblib.load("pca_compressor.joblib")
print("β All models loaded")
# ── ATTACK TIPS ───────────────────────────────────────────────────────────────
# Maps each lexical feature key to a (title, explanation) pair describing the
# attack technique that feature hints at.
# NOTE(review): the "β" in the count_digits tip looks like a mis-encoded
# dash or arrow — confirm the intended character before changing it.
ATTACK_TIPS = {
    "url_length": (
        "Long URL Obfuscation",
        "Attackers pad URLs to hide the real domain.",
    ),
    "count_at": (
        "@ Symbol Trick",
        "Browsers ignore everything before @ in a URL.",
    ),
    "count_hyphen": (
        "Hyphen Stuffing",
        "Hyphens make fake domains look legitimate.",
    ),
    "count_double_slash": (
        "Double-Slash Redirect",
        "Extra slashes bypass simple URL filters.",
    ),
    "count_percent": (
        "Percent Encoding",
        "Encoding hides the true URL from scanners.",
    ),
    "count_digits": (
        "Digit Substitution",
        "Replacing letters with digits β paypa1 = paypal.",
    ),
    "count_dots": (
        "Subdomain Abuse",
        "Extra dots create deep subdomains to fool users.",
    ),
    "digit_letter_ratio": (
        "High Digit Density",
        "Legitimate domains rarely have many digits.",
    ),
    "has_ip": (
        "Raw IP Address",
        "Legitimate sites use domain names, not raw IPs.",
    ),
    "is_shortened": (
        "URL Shortener",
        "Shorteners hide the real malicious destination.",
    ),
}
# Display metadata for each lexical feature: a human-readable label, the
# threshold used to flag the value, and the direction of the comparison
# (high_is_bad=True means values above the threshold are flagged).
FEATURE_META = {
    "url_length": {
        "label": "URL Length", "threshold": 75, "high_is_bad": True,
    },
    "count_at": {
        "label": "@ Symbol", "threshold": 0, "high_is_bad": True,
    },
    "count_hyphen": {
        "label": "Hyphens", "threshold": 3, "high_is_bad": True,
    },
    "count_double_slash": {
        "label": "Double Slashes", "threshold": 0, "high_is_bad": True,
    },
    "count_percent": {
        "label": "% Encoding", "threshold": 1, "high_is_bad": True,
    },
    "count_digits": {
        "label": "Digit Count", "threshold": 8, "high_is_bad": True,
    },
    "count_dots": {
        "label": "Dot Count", "threshold": 4, "high_is_bad": True,
    },
    "digit_letter_ratio": {
        "label": "Digit/Letter Ratio", "threshold": 0.15, "high_is_bad": True,
    },
    "has_ip": {
        "label": "Raw IP Address", "threshold": 0, "high_is_bad": True,
    },
    "is_shortened": {
        "label": "URL Shortener", "threshold": 0, "high_is_bad": True,
    },
}
# ── LEXICAL EXTRACTOR ─────────────────────────────────────────────────────────
# Hosts of known URL-shortening services, matched exactly or as a parent domain.
_SHORTENER_DOMAINS = (
    "bit.ly", "qrco.de", "t.co", "tinyurl.com",
    "l.ead.me", "goo.gl", "ow.ly",
)


def extract_url_math(url: str) -> dict:
    """Compute the lexical (character-level) features of a URL.

    Character counts are taken over the URL string exactly as given. The two
    host-based flags are derived from the parsed hostname, with any userinfo
    and port removed.

    Args:
        url: The URL to analyse; a missing http/https scheme is tolerated.

    Returns:
        A dict with the keys ``url_length``, ``count_at``, ``count_hyphen``,
        ``count_double_slash``, ``count_percent``, ``count_digits``,
        ``count_dots``, ``digit_letter_ratio``, ``has_ip`` and
        ``is_shortened``.
    """
    # NOTE(review): scan() feeds these values to the random-forest model. If
    # that model was trained with the earlier netloc/substring logic, confirm
    # the stricter host matching below is consistent with the training data.

    # urlparse only populates the host when a scheme is present; schemes are
    # case-insensitive, so compare against a lowered copy.
    has_scheme = url.lower().startswith(("http://", "https://"))
    url_for_parse = url if has_scheme else "http://" + url
    try:
        # .hostname drops userinfo and port and lower-cases the host, so
        # "user@1.2.3.4:8080" is still recognised as a raw IP address.
        host = urlparse(url_for_parse).hostname or ""
    except ValueError:
        # Malformed bracketed (IPv6-style) hosts make urlparse raise; treat
        # the host as unknown rather than failing the whole extraction.
        host = ""
    host = host.rstrip(".")  # ignore the trailing dot of a fully-qualified name

    num_digits = sum(c.isdigit() for c in url)
    num_letters = sum(c.isalpha() for c in url)

    # Exact or parent-domain match only: plain substring matching flagged
    # hosts such as "microsoft.com" because they contain "t.co".
    is_shortened = any(
        host == short or host.endswith("." + short)
        for short in _SHORTENER_DOMAINS
    )
    has_ip = re.fullmatch(r"\d{1,3}(?:\.\d{1,3}){3}", host) is not None

    return {
        "url_length": len(url),
        "count_at": url.count("@"),
        "count_hyphen": url.count("-"),
        # The first "//" is assumed to belong to the scheme separator.
        "count_double_slash": max(0, url.count("//") - 1),
        "count_percent": url.count("%"),
        "count_digits": num_digits,
        "count_dots": url.count("."),
        # With no letters at all the raw digit count stands in for the ratio.
        "digit_letter_ratio": (
            num_digits / num_letters if num_letters > 0 else num_digits
        ),
        "has_ip": 1 if has_ip else 0,
        "is_shortened": 1 if is_shortened else 0,
    }
# ── REQUEST / RESPONSE MODELS ─────────────────────────────────────────────────
class ScanRequest(BaseModel):
    # Request body accepted by scan(): the URL to be checked.
    url: str
class FeatureResult(BaseModel):
    # One entry of the per-feature breakdown returned with a scan.
    label: str        # human-readable feature name
    value: float      # measured value of the feature for this URL
    suspicious: bool  # True when the value crossed the feature's threshold
    tip_title: str    # attack-technique title; empty when not suspicious
    tip_body: str     # attack-technique explanation; empty when not suspicious
class ScanResponse(BaseModel):
    # Full result of scanning one URL.
    url: str                       # the URL that was scored (scheme added if missing)
    verdict: str                   # "phishing" or "safe"
    threat_score: float            # mean of bert_score and rf_score
    bert_score: float              # BERT probability for class index 1
    rf_score: float                # random-forest probability for class index 1
    models_agree: bool             # both scores fall on the same side of 0.5
    triggered_count: int           # number of explanation entries flagged suspicious
    features: list[FeatureResult]  # per-feature breakdown
    inference_ms: float            # handler time in milliseconds
# ── ENDPOINTS ─────────────────────────────────────────────────────────────────
# Registered on the app so the handler is actually served.
# NOTE(review): the "/" path is inferred from the function name — confirm.
@app.get("/")
def root():
    """Return a short banner with the API status and version."""
    return {"status": "PhishNet API running", "version": "1.0.0"}
# Registered on the app so the handler is actually served.
# NOTE(review): the "/health" path is inferred from the function name — confirm.
@app.get("/health")
def health():
    """Return a constant liveness payload."""
    return {"status": "ok"}
# Registered on the app so the handler is actually served.
# NOTE(review): the "/scan" path is inferred from the function name — confirm
# it matches what the browser extension calls.
@app.post("/scan", response_model=ScanResponse)
def scan(req: ScanRequest):
    """Score a URL with BERT and the hybrid random forest, and explain why.

    The URL is classified twice: directly by the BERT sequence classifier,
    and by a random forest fed with PCA-reduced BERT [CLS] embeddings plus
    the lexical features from ``extract_url_math``. The final threat score
    is the mean of both probabilities.

    Args:
        req: Request body carrying the URL to check.

    Returns:
        A ``ScanResponse`` with the verdict, both model scores, the
        per-feature breakdown and the elapsed handler time in milliseconds.

    Raises:
        HTTPException: 422 when the URL is empty after stripping whitespace.
    """
    # Imported locally so this handler does not depend on edits to the
    # file-level import block.
    from fastapi import HTTPException

    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by wall-clock adjustments.
    t0 = time.perf_counter()

    url = req.url.strip()
    if not url:
        raise HTTPException(status_code=422, detail="url must not be empty")
    # URL schemes are case-insensitive: "HTTP://x" must not get a second one.
    if not url.lower().startswith(("http://", "https://")):
        url = "https://" + url

    # Tokenise
    inputs = tokenizer(
        url,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    # BERT inference (hidden states are needed for the hybrid model below)
    with torch.no_grad():
        outputs = dl_model(**inputs, output_hidden_states=True)
    bert_probs = F.softmax(outputs.logits, dim=1).squeeze().tolist()
    bert_score = float(bert_probs[1])

    # Hybrid: last-layer embedding of the first token -> PCA -> random forest
    cls_emb = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
    pca_feats = pca.transform(cls_emb)
    lex = extract_url_math(url)
    # Name every PCA component instead of assuming exactly 20; the column
    # selection below keeps only what the forest was fitted on.
    feat_dict = {
        f"pca_feature_{i}": component
        for i, component in enumerate(pca_feats[0])
    }
    feat_dict.update(lex)
    df = pd.DataFrame([feat_dict])
    # Reorder/select columns to match the order used when the forest was fitted.
    df_aligned = df[rf_model.feature_names_in_]
    rf_proba = rf_model.predict_proba(df_aligned)[0]
    rf_score = float(rf_proba[1])

    # Primary verdict: average of both models
    threat_score = (bert_score + rf_score) / 2
    verdict = "phishing" if threat_score > 0.5 else "safe"
    models_agree = (bert_score > 0.5) == (rf_score > 0.5)
    bert_triggered = bert_score > 0.7

    # Feature breakdown
    features = []
    triggered = 0
    for key, meta in FEATURE_META.items():
        val = lex[key]
        if meta["high_is_bad"]:
            is_sus = val > meta["threshold"]
        else:
            is_sus = val < meta["threshold"]
        if is_sus:
            triggered += 1
        tip_title, tip_body = ATTACK_TIPS.get(key, ("", ""))
        features.append(FeatureResult(
            label=meta["label"],
            value=float(val),
            suspicious=is_sus,
            # Tips are only surfaced for features that actually fired.
            tip_title=tip_title if is_sus else "",
            tip_body=tip_body if is_sus else "",
        ))

    # If no lexical features fired but BERT is highly confident, add a
    # semantic trigger so the extension still shows an explanation.
    if triggered == 0 and bert_triggered:
        features.append(FeatureResult(
            label="BERT Semantic Pattern",
            value=float(bert_score),
            suspicious=True,
            tip_title="Semantic Phishing Pattern",
            tip_body="BERT detected token patterns in this URL that "
                     "strongly match known phishing page structures, "
                     "even though individual lexical features appear normal.",
        ))
        triggered = 1

    elapsed = (time.perf_counter() - t0) * 1000
    return ScanResponse(
        url=url,
        verdict=verdict,
        threat_score=round(threat_score, 4),
        bert_score=round(bert_score, 4),
        rf_score=round(rf_score, 4),
        models_agree=models_agree,
        triggered_count=triggered,
        features=features,
        inference_ms=round(elapsed, 1),
    )