Spaces:
Running
Running
Apurv committed on
Commit ·
b8630cb
0
Parent(s):
Deploying AegisAI Hackathon Backend
Browse files- .gitattributes +1 -0
- agents/__init__.py +5 -0
- agents/agent1_external.py +168 -0
- agents/agent2_content.py +281 -0
- agents/agent3_synthesizer.py +236 -0
- agents/agent4_prompt.py +117 -0
- app.py +120 -0
- models/model_new.pkl +3 -0
- models/phishing_new.pkl +3 -0
- models/vectorizer_new.pkl +3 -0
- models/vectorizerurl_new.pkl +3 -0
- requirements.txt +8 -0
- utils/__init__.py +3 -0
- utils/preprocessor.py +40 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package initializer for the AegisAI agent pipeline: re-exports the three
# analysis agents so callers can simply do ``from agents import ...``.
from .agent1_external import ExternalAnalysisAgent
from .agent2_content import ContentAnalysisAgent
from .agent3_synthesizer import SynthesizerAgent

# Explicit public API of the ``agents`` package.
__all__ = ['ExternalAnalysisAgent', 'ContentAnalysisAgent', 'SynthesizerAgent']
|
agents/agent1_external.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
import numpy as np
|
| 4 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
from difflib import SequenceMatcher
|
| 7 |
+
import os
|
| 8 |
+
import pickle
|
| 9 |
+
|
| 10 |
+
class ExternalAnalysisAgent:
    """Agent 1: scores external signals of a message — URLs, domains, and
    similarity to known phishing phrasing.

    Combines rule-based URL heuristics, an optional pickled scikit-learn
    URL classifier, fuzzy domain comparison against well-known brands, and
    sentence-embedding similarity against canonical phishing patterns.
    """

    def __init__(self):
        print("Loading External Analysis Agent...")
        # Embedding model used for semantic similarity to phishing phrasing.
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Load pickled URL models; they are optional — heuristics still work
        # without them, so failures are downgraded to a flag.
        model_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models')
        try:
            with open(os.path.join(model_dir, 'phishing_new.pkl'), 'rb') as f:
                self.url_ml_model = pickle.load(f)
            with open(os.path.join(model_dir, 'vectorizerurl_new.pkl'), 'rb') as f:
                self.url_vectorizer = pickle.load(f)
            self.has_url_ml = True
            print("Successfully loaded URL ML models.")
        except Exception as e:
            # Missing or version-incompatible pickles are non-fatal.
            print(f"Failed to load URL ML models: {e}")
            self.has_url_ml = False

        # Canonical phishing phrasings used for embedding similarity.
        self.phishing_patterns = [
            "verify your account immediately",
            "suspicious activity detected",
            "click here to confirm",
            "your account will be suspended",
            "update your payment information",
            "unusual sign-in attempt",
            "secure your account now",
            "limited time offer",
            "you have won a prize",
            "inheritance money transfer"
        ]

        # TLDs disproportionately used by throwaway phishing domains.
        self.suspicious_tlds = ['.xyz', '.top', '.club', '.online', '.site', '.win', '.bid']

        # Frequently impersonated brands, used for look-alike detection.
        self.legitimate_domains = ['google.com', 'microsoft.com', 'amazon.com', 'paypal.com', 'apple.com']

        # Pre-compute pattern embeddings once so analyze() only embeds the message.
        self.pattern_embeddings = self.model.encode(self.phishing_patterns)
        print("External Analysis Agent loaded successfully!")

    def analyze_url_risk(self, url):
        """Score a single URL for phishing risk.

        Returns a tuple ``(risk_score, reasons, url_ml_prob)``:
        ``risk_score`` clamped to [0, 1], ``reasons`` a list of
        human-readable findings, and ``url_ml_prob`` the raw ML phishing
        probability (0.0 when the ML model is unavailable or errors).
        """
        risk_score = 0.0
        reasons = []
        # Hoisted: every heuristic below is case-insensitive.
        url_lower = url.lower()

        # Suspicious TLD anywhere in the URL (the substring test subsumes the
        # previous redundant endswith() check and also catches "evil.xyz/path").
        for tld in self.suspicious_tlds:
            if tld in url_lower:
                risk_score += 0.3
                reasons.append(f"Suspicious TLD: {tld}")
                break

        # Raw IPv4 address — legitimate services virtually never link this way.
        if re.search(r'\d+\.\d+\.\d+\.\d+', url):
            risk_score += 0.4
            reasons.append("IP address used instead of domain name")

        # Many dots usually means nested subdomains used to disguise the host.
        if url.count('.') > 3:
            risk_score += 0.2
            reasons.append("Excessive subdomains")

        # URL shorteners hide the true destination.
        shortening_services = ['bit.ly', 'tinyurl', 'goo.gl', 'ow.ly', 'tiny.cc']
        for service in shortening_services:
            if service in url_lower:
                risk_score += 0.3
                reasons.append(f"URL shortening service detected: {service}")
                break

        # Authentication-flavoured keywords often mark credential-harvest pages.
        suspicious_keywords = ['login', 'signin', 'verify', 'account', 'secure', 'update', 'confirm']
        for keyword in suspicious_keywords:
            if keyword in url_lower:
                risk_score += 0.1
                reasons.append(f"Suspicious keyword in URL: '{keyword}'")
                break

        # Look-alike (typosquatting) check against known brands.
        domain_similarity = self.check_domain_similarity(url)
        if domain_similarity > 0.7:
            risk_score += 0.3
            reasons.append("Domain similar to legitimate brand")

        url_ml_prob = 0.0
        if self.has_url_ml:
            try:
                features = self.url_vectorizer.transform([url])
                # phishing_new.pkl exposes predict_proba; column 1 is the
                # phishing class probability.
                url_ml_prob = self.url_ml_model.predict_proba(features)[0][1]

                # Hybrid logic: a confident ML verdict overrides weaker heuristics.
                if url_ml_prob > 0.8:
                    risk_score = max(risk_score, 0.9)
                    reasons.append(f"ML model identified highly malicious URL structure (Score: {url_ml_prob:.1%})")
                elif url_ml_prob > 0.5:
                    risk_score = max(risk_score, 0.6)
                    reasons.append(f"ML model flagged suspicious URL structure (Score: {url_ml_prob:.1%})")
            except Exception as e:
                # Prediction failure keeps the heuristic score; ML prob stays 0.
                print(f"Error predicting URL with ML model: {e}")

        return min(risk_score, 1.0), reasons, url_ml_prob

    def check_domain_similarity(self, url):
        """Return the highest fuzzy-match ratio (0..1) between this URL's
        domain and any known legitimate brand domain."""
        domain = self.extract_domain(url)
        max_similarity = 0.0

        for legit_domain in self.legitimate_domains:
            similarity = SequenceMatcher(None, domain.lower(), legit_domain).ratio()
            max_similarity = max(max_similarity, similarity)

        return max_similarity

    def extract_domain(self, url):
        """Extract the host part of *url*; falls back to the first path
        segment for scheme-less strings like ``example.com/page``."""
        parsed = urlparse(url)
        domain = parsed.netloc or parsed.path.split('/')[0]
        return domain

    def analyze(self, input_data):
        """Run the full external analysis.

        ``input_data`` must provide ``cleaned_text`` (str) and ``urls``
        (list[str]).  Returns a dict with per-signal scores, collected
        ``risk_factors`` strings, and a blended ``overall_risk`` in [0, 1].
        """
        text = input_data['cleaned_text']
        urls = input_data['urls']

        results = {
            'url_risk': 0.0,
            'url_ml_risk': 0.0,
            'domain_similarity': 0.0,
            'suspicious_patterns': [],
            'risk_factors': [],
            'overall_risk': 0.0
        }

        if urls:
            url_risks = []
            url_ml_risks = []
            for url in urls:
                risk, reasons, ml_prob = self.analyze_url_risk(url)
                url_risks.append(risk)
                url_ml_risks.append(ml_prob)
                results['risk_factors'].extend(reasons)

            results['url_risk'] = np.mean(url_risks) if url_risks else 0
            # Worst-case ML verdict across all URLs.
            results['url_ml_risk'] = max(url_ml_risks) if url_ml_risks else 0
            # NOTE(review): only the first URL is checked for brand look-alikes
            # — presumably a speed trade-off; confirm this is intended.
            results['domain_similarity'] = self.check_domain_similarity(urls[0])

        try:
            text_embedding = self.model.encode([text])
            similarities = cosine_similarity(text_embedding, self.pattern_embeddings)[0]

            if max(similarities) > 0.6:
                results['suspicious_patterns'].append("Text similar to known phishing patterns")
        except Exception as e:
            print(f"Error in semantic similarity: {e}")

        # Weighted blend; each detected text pattern contributes 0.1 via the
        # len() term.  (A previous in-place "+= 0.3" on overall_risk was dead
        # code — unconditionally overwritten by this assignment — and has
        # been removed without changing behavior.)
        results['overall_risk'] = min(
            results['url_risk'] * 0.6 +
            results['domain_similarity'] * 0.4 +
            len(results['suspicious_patterns']) * 0.1,
            1.0
        )

        return results
|
agents/agent2_content.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, pipeline
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
class ContentAnalysisAgent:
    """Agent 2: analyses the message body itself.

    Blends keyword/urgency heuristics, a DeBERTa sequence classifier, an
    optional pickled scikit-learn spam model, MiniLM embeddings for
    text-vs-URL coherence, and optional Hugging Face pipelines for
    sentiment.  All model loading happens in ``__init__``; ``analyze`` is
    the single entry point.
    """
    def __init__(self):
        # Device detection: prefer Apple-Silicon MPS, otherwise CPU.
        # NOTE(review): CUDA is never considered here — confirm that is intended
        # for the deployment target.
        self.device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
        print(f"Using device: {self.device} for inference optimization.")

        # Primary phishing/benign classifier: 2-label head on DeBERTa-v3-small.
        # ignore_mismatched_sizes lets the fresh classification head load over
        # the pretrained checkpoint.
        self.model_name = "microsoft/deberta-v3-small"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2,
            ignore_mismatched_sizes=True
        ).to(self.device)

        # Sentence embedder (all-MiniLM-L6-v2) loaded through AutoModel so we
        # can do our own mean pooling in get_minilm_embeddings().
        self.minilm_name = "sentence-transformers/all-MiniLM-L6-v2"
        self.minilm_tokenizer = AutoTokenizer.from_pretrained(self.minilm_name)
        self.minilm_model = AutoModel.from_pretrained(self.minilm_name).to(self.device)

        # Optimization: half-precision weights when running on MPS.
        if self.device.type == "mps":
            self.model = self.model.half()
            self.minilm_model = self.minilm_model.half()

        # Inference only — disable dropout / training-mode layers.
        self.model.eval()
        self.minilm_model.eval()

        # Optional auxiliary pipelines; failures only disable sentiment scoring.
        print("Loading Hugging Face pipelines...")
        try:
            self.mask_pipeline = pipeline("fill-mask", model="microsoft/deberta-v3-small")
            self.sentiment_pipeline = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
            self.has_pipelines = True
            print("Successfully loaded HF pipelines.")
        except Exception as e:
            print(f"Failed to load HF pipelines: {e}")
            self.has_pipelines = False

        # Optional pickled scikit-learn spam model + vectorizer; failures are
        # non-fatal (heuristics take over).
        print("Loading local text ML models...")
        model_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models')
        try:
            with open(os.path.join(model_dir, 'model_new.pkl'), 'rb') as f:
                self.scikit_model = pickle.load(f)
            with open(os.path.join(model_dir, 'vectorizer_new.pkl'), 'rb') as f:
                self.scikit_vectorizer = pickle.load(f)
            self.has_text_ml = True
            print("Successfully loaded text ML models.")
        except Exception as e:
            print(f"Failed to load text ML models: {e}")
            self.has_text_ml = False

        # Substrings counted by analyze_phishing(); matched case-insensitively.
        self.phishing_keywords = [
            'verify', 'account', 'bank', 'login', 'password', 'credit card',
            'ssn', 'social security', 'suspended', 'limited', 'unusual activity',
            'confirm identity', 'update information', 'click here', 'urgent'
        ]

        # Pressure-tactic phrases; each match raises the urgency sub-score.
        self.urgency_phrases = [
            'immediately', 'within 24 hours', 'as soon as possible',
            'urgent', 'action required', 'deadline', 'expire soon'
        ]

        # Classic jailbreak phrasings; any single match flags prompt injection.
        self.prompt_injection_patterns = [
            'ignore previous instructions',
            'ignore all previous',
            'disregard previous',
            'system prompt',
            'you are now',
            'act as',
            'new role:',
            'forget your instructions'
        ]

    def analyze_phishing(self, text):
        """Heuristic phishing score for *text*.

        Returns ``(phishing_score, keyword_matches, urgency_matches)`` where
        the score is a weighted blend of keyword density (0.4), urgency
        density (0.3), and a personal-information-request flag (0.3).
        """
        text_lower = text.lower()

        # Collect every phishing keyword present in the text.
        keyword_matches = []
        for keyword in self.phishing_keywords:
            if keyword in text_lower:
                keyword_matches.append(keyword)

        # Collect every urgency phrase present in the text.
        urgency_matches = []
        for phrase in self.urgency_phrases:
            if phrase in text_lower:
                urgency_matches.append(phrase)

        # Saturate: 5 keywords or 3 urgency phrases already count as maximal.
        keyword_score = min(len(keyword_matches) / 5, 1.0)
        urgency_score = min(len(urgency_matches) / 3, 1.0)

        # Requests for credentials / identity data are a strong signal on their own.
        has_personal_info_request = any([
            'password' in text_lower and 'send' in text_lower,
            'credit card' in text_lower,
            'ssn' in text_lower,
            'social security' in text_lower
        ])

        if has_personal_info_request:
            personal_info_score = 0.8
        else:
            personal_info_score = 0.0

        phishing_score = (keyword_score * 0.4 + urgency_score * 0.3 + personal_info_score * 0.3)

        return phishing_score, keyword_matches, urgency_matches

    def analyze_prompt_injection(self, text):
        """Check *text* for known prompt-injection phrasings.

        Returns ``(detected, patterns)`` — ``detected`` is True on the first
        matching pattern; ``patterns`` holds a single descriptive string.
        """
        text_lower = text.lower()

        for pattern in self.prompt_injection_patterns:
            if pattern in text_lower:
                return True, [f"Prompt injection pattern detected: '{pattern}'"]

        return False, []

    def analyze_ai_generated(self, text):
        """Heuristic detection of AI-assistant boilerplate in *text*.

        Returns ``(score, matches)``: 0.7 for two or more indicator phrases,
        0.4 for exactly one, otherwise 0.0.
        """
        # Phrases typical of LLM-generated replies.
        ai_indicators = [
            'as an ai', 'i am an ai', 'as a language model',
            'i cannot', 'i apologize', 'i am unable to',
            'unfortunately', 'i must inform you'
        ]

        text_lower = text.lower()
        matches = [ind for ind in ai_indicators if ind in text_lower]

        if len(matches) > 1:
            return 0.7, matches
        elif len(matches) > 0:
            return 0.4, matches
        else:
            return 0.0, []

    def analyze_with_transformer(self, text):
        """Run the DeBERTa classifier and return the phishing-class
        probability; falls back to a neutral 0.5 on any error."""
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)

            with torch.inference_mode():  # Faster than no_grad for pure inference
                outputs = self.model(**inputs)
                # Cast logits up to float32 before softmax for numerical stability
                # (weights may be half-precision on MPS).
                probabilities = F.softmax(outputs.logits.float(), dim=-1)

            # Index 1 is treated as the phishing class.
            phishing_prob = probabilities[0][1].item()
            return phishing_prob

        except Exception as e:
            print(f"Transformer error: {e}")
            return 0.5

    def get_minilm_embeddings(self, text):
        """Embed *text* with all-MiniLM-L6-v2 using attention-masked mean
        pooling; returns a (batch, hidden) tensor."""
        inputs = self.minilm_tokenizer(text, padding=True, truncation=True, return_tensors='pt', max_length=512).to(self.device)
        with torch.inference_mode():
            model_output = self.minilm_model(**inputs)

        # Mean pooling over real (non-padding) tokens only.
        attention_mask = inputs['attention_mask']
        # Cast half-precision (float16) activations up to float32 for pooling stability.
        token_embeddings = model_output[0].float()
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp avoids division by zero for fully-masked (degenerate) inputs.
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return embeddings

    def analyze_connection(self, text, urls):
        """Measure semantic coherence between the message text and its URLs.

        Returns ``(avg_cosine_similarity, message)``; low similarity
        (divergence) suggests the link does not belong to the message and is
        treated as a phishing indicator downstream.
        """
        if not urls:
            return 1.0, "No URLs to analyze"

        text_emb = self.get_minilm_embeddings(text)

        connection_scores = []
        for url in urls:
            # Strip scheme/www and split on separators so the URL reads like
            # words for the embedder.
            url_parts = url.replace('http://', '').replace('https://', '').replace('www.', '')
            url_parts = re.sub(r'[/.\-_]', ' ', url_parts)
            url_emb = self.get_minilm_embeddings(url_parts)

            similarity = F.cosine_similarity(text_emb, url_emb).item()
            connection_scores.append(similarity)

        avg_connection = sum(connection_scores) / len(connection_scores)

        # A very low connection score (divergence) is an indicator of phishing.
        if avg_connection < 0.2:
            return avg_connection, "High divergence: URL content does not match email context"
        elif avg_connection < 0.4:
            return avg_connection, "Moderate divergence: URL seems loosely related to email context"
        else:
            return avg_connection, "Stable: URL matches email context"

    def analyze(self, input_data):
        """Main entry point: combine all heuristics and models into one
        result dict.

        ``input_data`` must provide ``cleaned_text`` (str) and ``urls``
        (list[str]).  Short benign greetings short-circuit with a near-zero
        risk payload.
        """
        text = input_data['cleaned_text']
        urls = input_data['urls']

        # Benign baseline check: trivially safe one-liners skip the models
        # entirely (cheap guard against false positives on "hi", "test", ...).
        benign_greetings = ['hi', 'hii', 'hiii', 'hello', 'hey', 'how are you', 'how is this', 'test']
        clean_msg = text.lower().strip().replace('?', '').replace('!', '')
        if clean_msg in benign_greetings and not urls:
            return {
                'phishing_probability': 0.01,
                'urgency_matches': [],
                'keyword_matches': [],
                'prompt_injection': False,
                'ai_generated_probability': 0.05,
                'spam_probability': 0.01,
                'connection_score': 1.0,
                'connection_message': "Safe: Benign conversational text",
                'sentiment_label': "POSITIVE",
                'sentiment_score': 0.99
            }

        phishing_score, keyword_matches, urgency_matches = self.analyze_phishing(text)
        prompt_injection, injection_patterns = self.analyze_prompt_injection(text)
        ai_generated_score, ai_patterns = self.analyze_ai_generated(text)
        transformer_score = self.analyze_with_transformer(text)

        # Hybrid text analysis: blend the pickled scikit-learn score (30%)
        # into the transformer score (70%) when the local model is available.
        spam_probability = 0.0
        spam_ml_prob = 0.0
        if self.has_text_ml:
            try:
                features = self.scikit_vectorizer.transform([text])
                spam_ml_prob = self.scikit_model.predict_proba(features)[0][1]
                transformer_score = (transformer_score * 0.7) + (spam_ml_prob * 0.3)
                spam_probability = spam_ml_prob
            except Exception as e:
                print(f"Text ML model unavailable (sklearn version mismatch), using fallback: {e}")
                self.has_text_ml = False  # disable to avoid repeated errors

        # Text-vs-URL coherence.
        connection_score, connection_msg = self.analyze_connection(text, urls)

        # High divergence (low connection) raises the phishing probability by
        # up to 0.5; no penalty once connection >= 0.4.
        connection_penalty = max(0, 0.5 - connection_score) if connection_score < 0.4 else 0
        combined_phishing = min(max(phishing_score, transformer_score) + connection_penalty, 1.0)

        # Heuristic spam fallback when the ML spam signal is weak:
        # 1 match ≈ 0.17 (safe), 3 matches = 0.5 (medium), 6+ = 1.0.
        if spam_probability < 0.3:
            spam_indicators = ['free', 'win', 'winner', 'prize', 'click here', 'offer', 'limited time', 'lottery', 'congratulations', 'cash', 'money', 'claim', 'award']
            spam_matches = [ind for ind in spam_indicators if ind in text.lower()]
            heuristic_spam = min(len(spam_matches) / 6, 1.0)
            spam_probability = max(spam_probability, heuristic_spam)

        # Optional sentiment analysis; score is oriented so that higher means
        # more negative (NEGATIVE keeps its confidence, POSITIVE is inverted).
        sentiment_score = 0.0
        sentiment_label = "UNKNOWN"
        if self.has_pipelines:
            try:
                sent_result = self.sentiment_pipeline(text[:512])[0]
                sentiment_label = sent_result['label']
                sentiment_score = sent_result['score'] if sentiment_label == 'NEGATIVE' else (1.0 - sent_result['score'])
            except Exception as e:
                print(f"Error predicting sentiment: {e}")

        results = {
            'phishing_probability': combined_phishing,
            'prompt_injection': prompt_injection,
            'prompt_injection_patterns': injection_patterns,
            'ai_generated_probability': ai_generated_score,
            'spam_probability': spam_probability,
            'spam_ml_score': spam_ml_prob,
            'keyword_matches': keyword_matches,
            'urgency_matches': urgency_matches,
            'ai_patterns': ai_patterns,
            'transformer_score': transformer_score,
            'using_transformer': True,
            'sentiment_score': sentiment_score,
            'sentiment_label': sentiment_label,
            'connection_score': connection_score,
            'connection_message': connection_msg
        }

        return results
|
agents/agent3_synthesizer.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class SynthesizerAgent:
|
| 2 |
+
    def __init__(self):
        """Configure the scoring thresholds and component weights used by
        calculate_risk_score() and determine_risk_level()."""
        # Minimum score for each categorical risk level (inclusive bounds).
        self.thresholds = {
            'low': 0.22,
            'medium': 0.5,
            'high': 0.8
        }

        # Relative weight of each agent signal in the blended risk score.
        # The first five sum to 1.0; prompt_injection is additive on top.
        self.weights = {
            'phishing': 0.4,
            'url_risk': 0.3,
            'spam': 0.15,
            'ai_generated': 0.1,
            'domain_similarity': 0.05,
            'prompt_injection': 0.3  # High impact when detected
        }
|
| 17 |
+
|
| 18 |
+
def calculate_risk_score(self, agent1_results, agent2_results, agent4_results):
|
| 19 |
+
"""Calculate overall risk score"""
|
| 20 |
+
risk_score = 0.0
|
| 21 |
+
|
| 22 |
+
# Give higher priority to ML based scores if available
|
| 23 |
+
url_risk_val = agent1_results['url_ml_risk'] if agent1_results.get('url_ml_risk', 0) > agent1_results['url_risk'] else agent1_results['url_risk']
|
| 24 |
+
spam_val = agent2_results['spam_ml_score'] if agent2_results.get('spam_ml_score', 0) > agent2_results.get('spam_probability', 0) else agent2_results.get('spam_probability', 0)
|
| 25 |
+
|
| 26 |
+
risk_score += agent2_results['phishing_probability'] * self.weights['phishing']
|
| 27 |
+
risk_score += url_risk_val * self.weights['url_risk']
|
| 28 |
+
risk_score += spam_val * self.weights['spam']
|
| 29 |
+
risk_score += agent2_results['ai_generated_probability'] * self.weights['ai_generated']
|
| 30 |
+
risk_score += agent1_results['domain_similarity'] * self.weights['domain_similarity']
|
| 31 |
+
|
| 32 |
+
# Integrate Agent 4 Prompt Injection Score
|
| 33 |
+
risk_score += agent4_results['confidence'] * self.weights['prompt_injection']
|
| 34 |
+
|
| 35 |
+
# New: Factor in connection score (divergence)
|
| 36 |
+
connection_score = agent2_results.get('connection_score', 1.0)
|
| 37 |
+
if connection_score < 0.4:
|
| 38 |
+
# Low connection = higher risk
|
| 39 |
+
divergence_penalty = (0.4 - connection_score) * 0.5
|
| 40 |
+
risk_score += divergence_penalty
|
| 41 |
+
|
| 42 |
+
# Adjust based on aggressive sentiment
|
| 43 |
+
if agent2_results.get('sentiment_label') == 'NEGATIVE' and agent2_results.get('sentiment_score', 0) > 0.8:
|
| 44 |
+
risk_score += 0.1
|
| 45 |
+
|
| 46 |
+
# Combine Prompt Injection flags from Agent 2 (heuristic) and Agent 4 (transformer)
|
| 47 |
+
if agent2_results['prompt_injection'] or agent4_results['prompt_injection_detected']:
|
| 48 |
+
risk_score = max(risk_score, 0.7) # Ensure at least HIGH risk if injection is detected
|
| 49 |
+
|
| 50 |
+
return min(risk_score, 1.0)
|
| 51 |
+
|
| 52 |
+
def determine_risk_level(self, risk_score):
|
| 53 |
+
"""Convert numerical score to risk level"""
|
| 54 |
+
if risk_score >= self.thresholds['high']:
|
| 55 |
+
return "HIGH"
|
| 56 |
+
elif risk_score >= self.thresholds['medium']:
|
| 57 |
+
return "MEDIUM"
|
| 58 |
+
elif risk_score >= self.thresholds['low']:
|
| 59 |
+
return "LOW"
|
| 60 |
+
else:
|
| 61 |
+
return "MINIMAL"
|
| 62 |
+
|
| 63 |
+
def determine_threat_type(self, risk_score, agent1_results, agent2_results, agent4_results):
|
| 64 |
+
"""Classify the type of threat"""
|
| 65 |
+
threats = []
|
| 66 |
+
|
| 67 |
+
if agent2_results['phishing_probability'] > 0.7:
|
| 68 |
+
threats.append("Phishing")
|
| 69 |
+
|
| 70 |
+
if agent1_results['url_risk'] > 0.7 or agent1_results.get('url_ml_risk', 0) > 0.7:
|
| 71 |
+
threats.append("Malicious URL")
|
| 72 |
+
|
| 73 |
+
if agent2_results['prompt_injection'] or agent4_results['prompt_injection_detected']:
|
| 74 |
+
threats.append("Prompt Injection")
|
| 75 |
+
|
| 76 |
+
if agent2_results['ai_generated_probability'] > 0.6:
|
| 77 |
+
threats.append("AI-Generated Scam")
|
| 78 |
+
|
| 79 |
+
if agent2_results.get('spam_probability', 0) > 0.7 or agent2_results.get('spam_ml_score', 0) > 0.7:
|
| 80 |
+
threats.append("Spam")
|
| 81 |
+
|
| 82 |
+
if not threats and risk_score > 0.3:
|
| 83 |
+
threats.append("Suspicious Content")
|
| 84 |
+
elif not threats:
|
| 85 |
+
threats.append("Benign")
|
| 86 |
+
|
| 87 |
+
return threats
|
| 88 |
+
|
| 89 |
+
def generate_explanation(self, agent1_results, agent2_results, agent4_results, threat_types, risk_score):
    """Generate detailed, context-aware forensic reasoning like a security expert.

    Builds a capped list of human-readable 'reasons' from the three agents'
    findings, plus a capped list of recommended 'actions' keyed off the
    classified threat types. Never returns empty lists.

    Args:
        agent1_results: URL/domain analysis output (risk_factors, domain_similarity, url_risk, ...).
        agent2_results: content analysis output (keyword/urgency matches, sentiment, spam scores, ...).
        agent4_results: prompt-injection agent output (detected flag, confidence, attack_categories).
        threat_types: labels from determine_threat_type(), used to pick actions.
        risk_score: overall combined risk score in [0, 1].

    Returns:
        dict: {'reasons': list[str] (max 6), 'actions': list[str] (max 4)}.
    """
    reasons = []

    # ── URL / Domain Forensics ──
    # Enrich each raw risk factor from Agent 1 with an explanation of why
    # that signal matters; unrecognized factors pass through unchanged.
    for factor in agent1_results.get('risk_factors', []):
        factor_lower = factor.lower()
        if 'suspicious tld' in factor_lower:
            reasons.append(f"URL Analysis: {factor} — uncommon TLDs are frequently used by phishing campaigns to evade domain blocklists")
        elif 'ip address' in factor_lower:
            reasons.append(f"URL Analysis: {factor} — legitimate services almost never use raw IP addresses in their links")
        elif 'shortening' in factor_lower:
            reasons.append(f"URL Analysis: {factor} — URL shorteners hide the true destination, commonly abused by attackers")
        elif 'ml model' in factor_lower:
            reasons.append(f"URL Analysis (ML): {factor}")
        elif 'similar to legitimate' in factor_lower:
            reasons.append(f"Sender Spoofing: {factor} — this domain uses visual similarity (homoglyph attack) to impersonate a trusted brand")
        elif 'suspicious keyword' in factor_lower:
            reasons.append(f"URL Analysis: {factor} — authentication keywords in URLs often indicate credential-harvesting pages")
        elif 'subdomain' in factor_lower:
            reasons.append(f"URL Analysis: {factor} — excessive subdomains are a technique to disguise malicious domains")
        else:
            reasons.append(f"URL Analysis: {factor}")

    # Domain similarity warning (homoglyph / brand impersonation).
    if agent1_results.get('domain_similarity', 0) > 0.5:
        reasons.append(f"Sender Spoofing: Domain is {agent1_results['domain_similarity']:.0%} similar to a known legitimate brand — possible impersonation attempt")

    # ── Content Forensics ──
    keyword_matches = agent2_results.get('keyword_matches', [])
    if keyword_matches:
        kw_str = ', '.join(f"'{k}'" for k in keyword_matches[:4])
        reasons.append(f"Content Analysis: Detected high-risk keywords [{kw_str}] — these are hallmarks of social engineering and credential theft attempts")

    urgency_matches = agent2_results.get('urgency_matches', [])
    if urgency_matches:
        urg_str = ', '.join(f"'{u}'" for u in urgency_matches[:3])
        reasons.append(f"Behavioral Threat: Urgency/pressure language detected [{urg_str}] — creates artificial time pressure to bypass critical thinking")

    # ── Prompt Injection (Agent 4 Integration) ──
    if agent4_results.get('prompt_injection_detected'):
        cats = agent4_results.get('attack_categories', [])
        detail = f"Detected Categories: {', '.join(cats)}" if cats else "AI instruction override attempt"
        # FIX: use .get() for 'confidence' (was a bare [] subscript) so a
        # partial Agent 4 payload cannot raise KeyError mid-explanation —
        # consistent with every other access in this method.
        reasons.append(f"Prompt Injection Agent: {detail} (Risk: {agent4_results.get('confidence', 0):.0%}) — advanced hijacking pattern identified via transformer analysis")
    elif agent2_results.get('prompt_injection'):
        reasons.append("Prompt Injection: Heuristic pattern match — suspicious instruction override pattern detected in input text")

    # ── AI Generated Content ──
    ai_prob = agent2_results.get('ai_generated_probability', 0)
    if ai_prob > 0.5:
        reasons.append(f"Content Analysis: Text shows AI-generation patterns (Score: {ai_prob:.0%}) — machine-written scam content designed to appear legitimate")

    # ── Semantic Divergence ──
    # Low connection_score means link text and actual URL targets disagree.
    connection_score = agent2_results.get('connection_score', 1.0)
    connection_msg = agent2_results.get('connection_message', '')
    if connection_score < 0.4:
        reasons.append(f"Hidden Threat: {connection_msg} (Divergence Score: {connection_score:.0%}) — link text says one thing but URL points somewhere completely different")
    elif connection_score < 0.6 and agent1_results.get('url_risk', 0) > 0.3:
        reasons.append(f"Content Analysis: Weak semantic link between email text and embedded URLs ({connection_score:.0%}) — potentially deceptive link labels")

    # ── Sentiment / Tone ──
    sentiment_label = agent2_results.get('sentiment_label', 'UNKNOWN')
    sentiment_score = agent2_results.get('sentiment_score', 0)
    if sentiment_label == 'NEGATIVE' and sentiment_score > 0.8:
        reasons.append(f"Behavioral Threat: Highly aggressive/threatening tone detected (Score: {sentiment_score:.1%}) — intimidation tactics used to provoke panic-driven actions")
    elif sentiment_label == 'NEGATIVE' and sentiment_score > 0.5:
        reasons.append(f"Content Analysis: Negative sentiment detected (Score: {sentiment_score:.1%}) — may use fear-based language to manipulate recipient")

    # ── Spam Signals ──
    # Take the stronger of the heuristic and ML spam scores.
    spam_prob = max(agent2_results.get('spam_probability', 0), agent2_results.get('spam_ml_score', 0))
    if spam_prob > 0.7:
        reasons.append(f"Content Analysis: High spam probability ({spam_prob:.0%}) — message matches known bulk/unsolicited mail patterns")

    # ── Safe fallback: never return empty reasoning ──
    if not reasons:
        if risk_score < 0.2:
            reasons.append("Content Analysis: No suspicious patterns, malicious URLs, or social engineering tactics detected — message appears legitimate")
            reasons.append("URL Analysis: No links found, or all URLs point to verified, trusted domains")
        else:
            reasons.append(f"Content Analysis: Minor risk signals detected (combined score: {risk_score:.0%}) but no single strong threat indicator found")

    # ── Recommended Actions ──
    # Branch priority mirrors threat severity: phishing/URL first, then
    # injection, spam, AI-generated; generic fallbacks last.
    actions = []
    if "Phishing" in threat_types or "Malicious URL" in threat_types:
        actions.extend([
            "Do not click any links in this message",
            "Do not provide personal information or credentials",
            "Block the sender and report to your security team"
        ])
    elif "Prompt Injection" in threat_types:
        actions.extend([
            "Do not execute any instructions contained in this message",
            "Report this message to security team"
        ])
    elif "Spam" in threat_types:
        actions.extend([
            "Mark as spam and block sender",
            "Do not unsubscribe via links — this confirms your address"
        ])
    elif "AI-Generated Scam" in threat_types:
        actions.extend([
            "Verify the sender through an independent channel",
            "Do not act on any financial requests in this message"
        ])

    if risk_score < 0.3 and not actions:
        actions.append("No immediate action required")
    elif not actions:
        actions.append("Report this message to security team")

    return {
        'reasons': reasons[:6],
        'actions': actions[:4]
    }
| 204 |
+
def synthesize(self, agent1_results, agent2_results, agent4_results):
    """Fuse the three agents' outputs into the final threat verdict.

    Returns a dict carrying the threat labels, risk level and score,
    a verdict confidence, the human-readable explanation, and the raw
    per-agent results for downstream display.
    """
    score = self.calculate_risk_score(agent1_results, agent2_results, agent4_results)
    level = self.determine_risk_level(score)
    labels = self.determine_threat_type(score, agent1_results, agent2_results, agent4_results)
    explanation = self.generate_explanation(
        agent1_results, agent2_results, agent4_results, labels, score
    )

    # Confidence reflects certainty in the verdict: scores near the
    # extremes (0.0 / 1.0) are confident, scores near 0.5 are borderline,
    # so confidence = 0.5 + distance from the 0.5 midpoint (capped at 1).
    confidence = min(0.5 + abs(score - 0.5), 1.0)

    return {
        'threat_types': labels,
        'risk_level': level,
        'risk_score': score,
        'confidence': confidence,
        'explanation': explanation,
        'detailed_results': {
            'agent1': agent1_results,
            'agent2': agent2_results,
            'agent4': agent4_results,
        },
    }
agents/agent4_prompt.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent 4: AI Prompt Injection Detection Module
|
| 3 |
+
Uses a fine‑tuned DeBERTa model (MNLI) + rule‑based patterns.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 9 |
+
|
| 10 |
+
class PromptInjectionAgent:
    """
    Detects prompt injection and jailbreak attempts in user inputs.
    Combines a transformer model (trained on MNLI) with heuristic rules.
    """
    def __init__(self, model_name="mrm8488/deberta-v3-small-finetuned-mnli", threshold=0.6):
        """
        Args:
            model_name: Hugging Face model identifier for a DeBERTa MNLI model.
            threshold: Confidence threshold above which input is flagged as injection.
        """
        print("Loading Prompt Injection Agent (MNLI-based)...")
        self.threshold = threshold

        # Load tokenizer and model once; inference-only, so eval mode.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()  # inference mode
        print("✓ Model loaded successfully")

        # Rule-based patterns (covers common jailbreak attempts).
        # FIX: anchor each phrase with \b word boundaries — the previous
        # bare-substring patterns produced false positives (e.g. r"dan"
        # matched "abundant"/"Danielle", r"act as" matched "impact
        # assessment"). Patterns are compiled once here instead of being
        # re-parsed by re.search on every analyze() call.
        raw_patterns = [
            (r"\bignore previous instructions\b", "instruction_override"),
            (r"\bignore all previous\b", "instruction_override"),
            (r"\bdisregard previous\b", "instruction_override"),
            (r"\bsystem prompt\b", "system_override"),
            (r"\byou are now\b", "role_playing"),
            (r"\bact as\b", "role_playing"),
            (r"\bnew role:", "role_playing"),
            (r"\bforget your instructions\b", "instruction_override"),
            (r"\bdo anything now\b", "privilege_escalation"),
            (r"\byou must\b", "privilege_escalation"),
            (r"\byou are free\b", "jailbreak"),
            (r"\bno restrictions\b", "jailbreak"),
            (r"\boverride\b", "instruction_override"),
            (r"\bjailbreak\b", "jailbreak"),
            (r"\bdan\b", "jailbreak"),  # DAN mode
            (r"\bdeveloper mode\b", "jailbreak"),
            (r"\bchatgpt, you are now\b", "role_playing"),
            (r"\byou are an ai with no ethics\b", "role_playing"),
            (r"\boutput raw\b", "attention_diversion"),
            (r"\bbase64 decode\b", "attention_diversion"),
        ]
        self.injection_patterns = [(re.compile(p), cat) for p, cat in raw_patterns]

    def analyze(self, text: str) -> dict:
        """
        Analyze input text for prompt injection.

        Returns:
            dict with keys:
                prompt_injection_detected (bool): final decision
                confidence (float): combined risk score
                risk_score (float): same as confidence (for backward compatibility)
                matched_patterns (list): regex patterns that fired
                attack_categories (list): types of injection detected
                explanation (list): human-readable reasons
        """
        # -------------------- Rule-based scan --------------------
        # Patterns are lowercase, so scan a lowercased copy of the input.
        text_lower = text.lower()
        rule_score = 0.0
        matched_patterns = []
        attack_categories = []

        for pattern, category in self.injection_patterns:
            if pattern.search(text_lower):
                rule_score += 0.3  # fixed increment per matched heuristic
                matched_patterns.append(pattern.pattern)
                attack_categories.append(category)

        # -------------------- Transformer inference --------------------
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # MNLI classes: 0 = entailment, 1 = neutral, 2 = contradiction
        contradiction_prob = probs[0][2].item()

        # -------------------- Combine scores --------------------
        # 70% weight on contradiction probability, 30% on rule-based
        # (rule score is capped at 1.0 before weighting).
        combined_risk = 0.7 * contradiction_prob + 0.3 * min(rule_score, 1.0)
        detected = combined_risk > self.threshold

        # -------------------- Build explanation --------------------
        explanation = []
        explanation.append(f"Contradiction probability: {contradiction_prob:.1%}")
        if attack_categories:
            unique_cats = list(set(attack_categories))
            explanation.append(f"Rule matches: {', '.join(unique_cats)}")
        if detected:
            explanation.append(f"Combined risk {combined_risk:.1%} exceeds threshold {self.threshold:.0%}")

        return {
            "prompt_injection_detected": detected,
            "confidence": combined_risk,
            "risk_score": combined_risk,  # alias for compatibility
            "matched_patterns": matched_patterns,
            "attack_categories": list(set(attack_categories)),
            "explanation": explanation
        }
app.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import json
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from utils.preprocessor import TextPreprocessor
|
| 7 |
+
from agents.agent1_external import ExternalAnalysisAgent
|
| 8 |
+
from agents.agent2_content import ContentAnalysisAgent
|
| 9 |
+
from agents.agent3_synthesizer import SynthesizerAgent
|
| 10 |
+
from agents.agent4_prompt import PromptInjectionAgent
|
| 11 |
+
|
| 12 |
+
class ThreatDetectionSystem:
    """Wires the preprocessor and all detection agents into one pipeline."""

    def __init__(self):
        # Constructing the agents loads their ML models, so this is slow;
        # the instance is created once at module import time.
        print("Initializing Threat Detection System...")
        self.preprocessor = TextPreprocessor()
        self.agent1 = ExternalAnalysisAgent()
        self.agent2 = ContentAnalysisAgent()
        self.agent3 = SynthesizerAgent()
        self.agent4 = PromptInjectionAgent()
        print("System initialized!")

    def analyze(self, user_input):
        """Run the full preprocess → agents → synthesis pipeline."""
        started = time.time()

        # Normalise the raw text and pull out URLs/domains.
        prepared = self.preprocessor.preprocess(user_input)

        # Independent analyses: external/URL, content, prompt injection.
        # Note: agent4 works on the raw input, not the preprocessed form.
        external = self.agent1.analyze(prepared)
        content = self.agent2.analyze(prepared)
        injection = self.agent4.analyze(user_input)

        # Fuse everything into the final verdict and record latency.
        verdict = self.agent3.synthesize(external, content, injection)
        verdict['processing_time'] = time.time() - started
        return verdict
| 39 |
+
|
| 40 |
+
# Initialize the system globally for HF
|
| 41 |
+
system = ThreatDetectionSystem()
|
| 42 |
+
|
| 43 |
+
# --- Gradio UI Logic ---
|
| 44 |
+
def ui_analyze(text):
    """Gradio handler: analyze *text* and format a human-readable report.

    Returns a 3-tuple for the UI outputs: (report string, per-agent
    detail dict, recommended actions list).
    """
    if not text or not text.strip():
        return "Please enter some text", {}, {}

    result = system.analyze(text)

    # Emoji badge per risk level; anything below LOW renders green.
    badges = {"HIGH": "🔴", "MEDIUM": "🟠", "LOW": "🟡"}
    badge = badges.get(result['risk_level'], "🟢")

    lines = [
        f"{badge} {result['risk_level']} RISK DETECTED",
        f"Confidence: {result['confidence']:.1%}",
        f"Type: {', '.join(result['threat_types'])}",
        "",
        "Forensic Reasons:",
    ]
    lines.extend(f"- {reason}" for reason in result['explanation']['reasons'])
    report = "\n".join(lines)

    return report, result['detailed_results'], result['explanation']['actions']
| 58 |
+
|
| 59 |
+
# --- Next.js Backend Compatibility API ---
|
| 60 |
+
# This endpoint is what the Vercel frontend calls
|
| 61 |
+
def api_analyze(text):
    """JSON API handler consumed by the Next.js/Vercel frontend.

    Translates the internal analysis result into the frontend's schema;
    any failure (including empty input) is reported as {"error": ...}
    rather than raised, so the HTTP endpoint never 500s on bad input.
    """
    try:
        if not text or not text.strip():
            return {"error": "No input provided"}

        result = system.analyze(text)
        details = result["detailed_results"]
        agent1 = details["agent1"]
        agent2 = details["agent2"]
        agent4 = details["agent4"]

        # Translate internal risk levels to the frontend's vocabulary;
        # unknown levels default to "Medium".
        level_names = {"MINIMAL": "Safe", "LOW": "Low", "MEDIUM": "Medium", "HIGH": "High"}

        return {
            "riskLevel": level_names.get(result["risk_level"], "Medium"),
            "threatType": ", ".join(result["threat_types"]),
            "confidenceScore": round(result["confidence"] * 100, 1),
            "riskScore": round(result["risk_score"], 4),
            "explanation": " ".join(result["explanation"]["reasons"]),
            "indicators": result["explanation"]["reasons"],
            "recommendations": result["explanation"]["actions"],
            "detailedScores": {
                "phishingProb": round(agent2.get("phishing_probability", 0), 3),
                "spamProb": round(agent2.get("spam_probability", 0), 3),
                "urlRisk": round(agent1.get("url_risk", 0), 3),
                "sentimentLabel": agent2.get("sentiment_label", "UNKNOWN"),
                "sentimentScore": round(agent2.get("sentiment_score", 0), 3),
                "promptInjectionScore": round(agent4.get("confidence", 0), 3),
                "promptInjectionDetected": agent4.get("prompt_injection_detected", False),
            },
        }
    except Exception as e:
        return {"error": str(e)}
| 92 |
+
|
| 93 |
+
# --- Theme and Layout ---
# Two-column layout: input + buttons on the left, report/JSON panels on the right.
with gr.Blocks(theme="soft", title="🛡️ AegisAI Security") as demo:
    gr.Markdown("# 🛡️ AegisAI: Advanced Phishing & Fraud Detector")
    gr.Markdown("Drop an email body or URL here to get a full forensic breakdown.")

    with gr.Row():
        with gr.Column(scale=2):
            input_box = gr.Textbox(label="Message Content", lines=8, placeholder="Paste email content...")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Analyze Threat", variant="primary")

        with gr.Column(scale=3):
            out_report = gr.Textbox(label="Analysis Report", lines=10, interactive=False)
            out_actions = gr.JSON(label="Recommended Actions")
            out_scores = gr.JSON(label="Agent Confidence Scores")

    # Connect UI
    # ui_analyze returns (report, detailed_results, actions) matching this output order.
    submit_btn.click(fn=ui_analyze, inputs=input_box, outputs=[out_report, out_scores, out_actions])
    # Clear resets all four components (input + the three output panels).
    clear_btn.click(lambda: ["", "", {}, {}], outputs=[input_box, out_report, out_scores, out_actions])

    # HIDDEN API ENDPOINT FOR VERCEL
    # Note: Hugging Face exposes this as an endpoint /run/predict or via api_name
    # The invisible button only exists so Gradio registers /analyze for the frontend.
    api_endpoint = gr.Button("API", visible=False)
    api_endpoint.click(fn=api_analyze, inputs=input_box, outputs=out_scores, api_name="analyze")

if __name__ == "__main__":
    demo.launch()
models/model_new.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b127302853205a24d23d79d543a17b0bc1aeecf152754f7ad9d1d77106acbe64
|
| 3 |
+
size 40720
|
models/phishing_new.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:910e6db7e9bf78ed8f52ceed3a813541c749f412ceeeddfbb3758726eb7267e8
|
| 3 |
+
size 40720
|
models/vectorizer_new.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a16317dca2e5aef5768222113e2a345bbef0e84e9f4115cd5394670506f5938b
|
| 3 |
+
size 191803
|
models/vectorizerurl_new.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e594d26eb86e52c21d67640d7a0485fd519bbeb83411fe49a721f317a14ab181
|
| 3 |
+
size 189473
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
pandas
|
| 3 |
+
numpy
|
| 4 |
+
scikit-learn
|
| 5 |
+
torch
|
| 6 |
+
transformers
|
| 7 |
+
sentence-transformers
|
| 8 |
+
# Optional: flask, gunicorn, flask-cors are not strictly needed for Gradio Spaces
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .preprocessor import TextPreprocessor
|
| 2 |
+
|
| 3 |
+
__all__ = ['TextPreprocessor']
|
utils/preprocessor.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from urllib.parse import urlparse
|
| 3 |
+
|
| 4 |
+
class TextPreprocessor:
    """Normalises raw message text and extracts URLs/domains for the agents."""

    def __init__(self):
        # Stateless; kept for interface symmetry with the agent classes.
        pass

    def clean_text(self, text):
        """Collapse all runs of whitespace (including newlines) into single spaces."""
        return ' '.join(text.split())

    def extract_urls(self, text):
        """Return every http/https URL found in *text* (possibly empty list)."""
        # Matches the scheme followed by any run of URL-safe characters;
        # a URL ends at the first whitespace character.
        url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        return re.findall(url_pattern, text)

    def extract_domain(self, url):
        """Extract the host part of *url*; returns '' when it cannot be parsed."""
        try:
            parsed = urlparse(url)
            # Scheme-less inputs land in .path, so fall back to its first segment.
            return parsed.netloc or parsed.path.split('/')[0]
        except ValueError:
            # FIX: was a bare `except:` which also swallowed KeyboardInterrupt
            # and SystemExit. urlparse raises ValueError for malformed ports /
            # IPv6 literals; treat those as "no domain" instead of crashing.
            return ""

    def preprocess(self, text):
        """Clean *text* and bundle it with extracted URLs/domains.

        Returns:
            dict with keys: 'cleaned_text', 'urls', 'domains',
            'has_urls', 'text_length'.
        """
        cleaned_text = self.clean_text(text)
        urls = self.extract_urls(cleaned_text)
        domains = [self.extract_domain(url) for url in urls]

        return {
            'cleaned_text': cleaned_text,
            'urls': urls,
            'domains': domains,
            'has_urls': len(urls) > 0,
            'text_length': len(cleaned_text)
        }
|