from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, pipeline
import torch
import torch.nn.functional as F
import os
import pickle
import re


class ContentAnalysisAgent:
    """Hybrid phishing / spam / prompt-injection analyzer.

    Combines several signal sources:
      * keyword/urgency heuristics,
      * a DeBERTa-v3-small sequence classifier (transformer score),
      * all-MiniLM-L6-v2 sentence embeddings to measure semantic
        divergence between an email body and the URLs it contains,
      * optional Hugging Face pipelines (fill-mask, sentiment),
      * optional locally pickled scikit-learn spam model + vectorizer.
    """

    def __init__(self):
        # Device detection: prefer CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        print(f"Using device: {self.device} for inference optimization.")

        self.model_name = "microsoft/deberta-v3-small"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
        # NOTE(review): num_labels=2 with ignore_mismatched_sizes=True means the
        # classification head is randomly initialized unless fine-tuned weights
        # replace it — the transformer score is noisy without fine-tuning.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, num_labels=2, ignore_mismatched_sizes=True
        ).to(self.device)

        # Sentence embedder used for email-text vs. URL semantic comparison.
        self.minilm_name = "sentence-transformers/all-MiniLM-L6-v2"
        self.minilm_tokenizer = AutoTokenizer.from_pretrained(self.minilm_name)
        self.minilm_model = AutoModel.from_pretrained(self.minilm_name).to(self.device)

        # Half precision on MPS reduces memory and speeds up inference; outputs
        # are cast back to float32 before softmax / pooling for stability.
        if self.device.type == "mps":
            self.model = self.model.half()
            self.minilm_model = self.minilm_model.half()

        self.model.eval()
        self.minilm_model.eval()

        print("Loading Hugging Face pipelines...")
        try:
            self.mask_pipeline = pipeline("fill-mask", model="microsoft/deberta-v3-small")
            self.sentiment_pipeline = pipeline(
                "text-classification",
                model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
            )
            self.has_pipelines = True
            print("Successfully loaded HF pipelines.")
        except Exception as e:
            # Best-effort: pipelines are optional, so degrade gracefully.
            print(f"Failed to load HF pipelines: {e}")
            self.has_pipelines = False

        print("Loading local text ML models...")
        model_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models')
        try:
            # SECURITY: pickle.load executes arbitrary code during
            # deserialization — only load model files from a trusted
            # local directory, never from user-supplied paths.
            with open(os.path.join(model_dir, 'model_new.pkl'), 'rb') as f:
                self.scikit_model = pickle.load(f)
            with open(os.path.join(model_dir, 'vectorizer_new.pkl'), 'rb') as f:
                self.scikit_vectorizer = pickle.load(f)
            self.has_text_ml = True
            print("Successfully loaded text ML models.")
        except Exception as e:
            print(f"Failed to load text ML models: {e}")
            self.has_text_ml = False

        # Heuristic vocabularies (substring matches against lowercased text).
        self.phishing_keywords = [
            'verify', 'account', 'bank', 'login', 'password', 'credit card',
            'ssn', 'social security', 'suspended', 'limited',
            'unusual activity', 'confirm identity', 'update information',
            'click here', 'urgent'
        ]
        self.urgency_phrases = [
            'immediately', 'within 24 hours', 'as soon as possible', 'urgent',
            'action required', 'deadline', 'expire soon'
        ]
        self.prompt_injection_patterns = [
            'ignore previous instructions', 'ignore all previous',
            'disregard previous', 'system prompt', 'you are now', 'act as',
            'new role:', 'forget your instructions'
        ]

    def analyze_phishing(self, text):
        """Analyze text for phishing indicators.

        Returns:
            tuple: (phishing_score in [0, 1], matched keywords, matched
            urgency phrases).
        """
        text_lower = text.lower()

        keyword_matches = [kw for kw in self.phishing_keywords if kw in text_lower]
        urgency_matches = [ph for ph in self.urgency_phrases if ph in text_lower]

        # Saturate at 5 keywords / 3 urgency phrases so scores stay in [0, 1].
        keyword_score = min(len(keyword_matches) / 5, 1.0)
        urgency_score = min(len(urgency_matches) / 3, 1.0)

        # Direct requests for credentials or identifiers are weighted heavily.
        has_personal_info_request = any([
            'password' in text_lower and 'send' in text_lower,
            'credit card' in text_lower,
            'ssn' in text_lower,
            'social security' in text_lower,
        ])
        personal_info_score = 0.8 if has_personal_info_request else 0.0

        phishing_score = (
            keyword_score * 0.4 + urgency_score * 0.3 + personal_info_score * 0.3
        )
        return phishing_score, keyword_matches, urgency_matches

    def analyze_prompt_injection(self, text):
        """Check for prompt injection attempts.

        Returns:
            tuple: (bool detected, list of human-readable pattern messages).
        """
        text_lower = text.lower()
        for pattern in self.prompt_injection_patterns:
            if pattern in text_lower:
                return True, [f"Prompt injection pattern detected: '{pattern}'"]
        return False, []

    def analyze_ai_generated(self, text):
        """Basic detection of AI-generated content patterns.

        Returns:
            tuple: (probability-like score in {0.0, 0.4, 0.7}, matched
            indicator phrases).
        """
        ai_indicators = [
            'as an ai', 'i am an ai', 'as a language model', 'i cannot',
            'i apologize', 'i am unable to', 'unfortunately',
            'i must inform you'
        ]
        text_lower = text.lower()
        matches = [ind for ind in ai_indicators if ind in text_lower]
        if len(matches) > 1:
            return 0.7, matches
        elif len(matches) > 0:
            return 0.4, matches
        else:
            return 0.0, []

    def analyze_with_transformer(self, text):
        """Use transformer model for classification with optimized inference.

        Returns:
            float: probability of the positive (phishing) class, or the
            neutral value 0.5 if inference fails.
        """
        try:
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)
            with torch.inference_mode():  # faster than no_grad
                outputs = self.model(**inputs)
                # Cast back to float32 for a numerically stable softmax
                # (logits may be fp16 on MPS).
                probabilities = F.softmax(outputs.logits.float(), dim=-1)
                phishing_prob = probabilities[0][1].item()
            return phishing_prob
        except Exception as e:
            print(f"Transformer error: {e}")
            return 0.5

    def get_minilm_embeddings(self, text):
        """Get embeddings using all-MiniLM-L6-v2 with mean pooling (optimized).

        Returns:
            torch.Tensor: (batch, hidden) mean-pooled sentence embedding.
        """
        inputs = self.minilm_tokenizer(
            text, padding=True, truncation=True, return_tensors='pt', max_length=512
        ).to(self.device)
        with torch.inference_mode():
            model_output = self.minilm_model(**inputs)

        # Attention-mask-weighted mean pooling over token embeddings;
        # cast fp16 -> fp32 for pooling stability on MPS.
        attention_mask = inputs['attention_mask']
        token_embeddings = model_output[0].float()
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        return embeddings

    def analyze_connection(self, text, urls):
        """Analyze the connection between email text and URLs.

        Returns:
            tuple: (mean cosine similarity in roughly [-1, 1], message).
            A low score (divergence) is a phishing indicator.
        """
        if not urls:
            return 1.0, "No URLs to analyze"

        text_emb = self.get_minilm_embeddings(text)
        connection_scores = []
        for url in urls:
            # Extract meaningful parts of the URL for semantic comparison:
            # drop scheme/www, split path and host separators into words.
            url_parts = url.replace('http://', '').replace('https://', '').replace('www.', '')
            url_parts = re.sub(r'[/.\-_]', ' ', url_parts)
            url_emb = self.get_minilm_embeddings(url_parts)
            similarity = F.cosine_similarity(text_emb, url_emb).item()
            connection_scores.append(similarity)

        # urls is non-empty here, so this division is safe.
        avg_connection = sum(connection_scores) / len(connection_scores)

        # A very low connection score (divergence) is an indicator of phishing.
        if avg_connection < 0.2:
            return avg_connection, "High divergence: URL content does not match email context"
        elif avg_connection < 0.4:
            return avg_connection, "Moderate divergence: URL seems loosely related to email context"
        else:
            return avg_connection, "Stable: URL matches email context"

    def analyze(self, input_data):
        """Main analysis function with hybrid and connection logic.

        Args:
            input_data: dict with 'cleaned_text' (required) and optionally
                'urls' (list of URL strings).

        Returns:
            dict: analysis results; both the benign fast path and the full
            path emit the same key set so consumers can index uniformly.
        """
        text = input_data['cleaned_text']
        # FIX: tolerate callers that omit 'urls' entirely.
        urls = input_data.get('urls', [])

        # Benign baseline check for short / common messages.
        benign_greetings = ['hi', 'hii', 'hiii', 'hello', 'hey', 'how are you', 'how is this', 'test']
        clean_msg = text.lower().strip().replace('?', '').replace('!', '')
        if clean_msg in benign_greetings and not urls:
            # FIX: this early return previously omitted several keys present
            # in the normal return path (prompt_injection_patterns, ai_patterns,
            # spam_ml_score, transformer_score, using_transformer), which could
            # raise KeyError downstream depending on the branch taken.
            return {
                'phishing_probability': 0.01,
                'urgency_matches': [],
                'keyword_matches': [],
                'prompt_injection': False,
                'prompt_injection_patterns': [],
                'ai_generated_probability': 0.05,
                'ai_patterns': [],
                'spam_probability': 0.01,
                'spam_ml_score': 0.0,
                'transformer_score': 0.0,
                'using_transformer': True,
                'connection_score': 1.0,
                'connection_message': "Safe: Benign conversational text",
                'sentiment_label': "POSITIVE",
                'sentiment_score': 0.99
            }

        phishing_score, keyword_matches, urgency_matches = self.analyze_phishing(text)
        prompt_injection, injection_patterns = self.analyze_prompt_injection(text)
        ai_generated_score, ai_patterns = self.analyze_ai_generated(text)
        transformer_score = self.analyze_with_transformer(text)

        # Hybrid text analysis: blend the pickled scikit-learn spam score
        # into the transformer score (70/30).
        spam_probability = 0.0
        spam_ml_prob = 0.0
        if self.has_text_ml:
            try:
                features = self.scikit_vectorizer.transform([text])
                spam_ml_prob = self.scikit_model.predict_proba(features)[0][1]
                # Fine-tune transformer score using the pickle model baseline.
                transformer_score = (transformer_score * 0.7) + (spam_ml_prob * 0.3)
                spam_probability = spam_ml_prob
            except Exception as e:
                print(f"Text ML model unavailable (sklearn version mismatch), using fallback: {e}")
                self.has_text_ml = False  # disable to avoid repeated errors

        # Connection analysis: low email/URL similarity raises the phishing score.
        connection_score, connection_msg = self.analyze_connection(text, urls)
        connection_penalty = max(0, 0.5 - connection_score) if connection_score < 0.4 else 0
        combined_phishing = min(max(phishing_score, transformer_score) + connection_penalty, 1.0)

        if spam_probability < 0.3:
            # Keyword fallback when the ML spam score is low/unavailable.
            spam_indicators = [
                'free', 'win', 'winner', 'prize', 'click here', 'offer',
                'limited time', 'lottery', 'congratulations', 'cash',
                'money', 'claim', 'award'
            ]
            spam_matches = [ind for ind in spam_indicators if ind in text.lower()]
            # 1 match = 0.16 (Safe), 2 matches = 0.33 (Low), 3 matches = 0.5 (Medium)
            heuristic_spam = min(len(spam_matches) / 6, 1.0)
            spam_probability = max(spam_probability, heuristic_spam)

        # Optional sentiment analysis using pipeline. The score is folded so
        # that higher values mean more negative sentiment.
        sentiment_score = 0.0
        sentiment_label = "UNKNOWN"
        if self.has_pipelines:
            try:
                sent_result = self.sentiment_pipeline(text[:512])[0]
                sentiment_label = sent_result['label']
                sentiment_score = (
                    sent_result['score']
                    if sentiment_label == 'NEGATIVE'
                    else (1.0 - sent_result['score'])
                )
            except Exception as e:
                print(f"Error predicting sentiment: {e}")

        results = {
            'phishing_probability': combined_phishing,
            'prompt_injection': prompt_injection,
            'prompt_injection_patterns': injection_patterns,
            'ai_generated_probability': ai_generated_score,
            'spam_probability': spam_probability,
            'spam_ml_score': spam_ml_prob,
            'keyword_matches': keyword_matches,
            'urgency_matches': urgency_matches,
            'ai_patterns': ai_patterns,
            'transformer_score': transformer_score,
            'using_transformer': True,
            'sentiment_score': sentiment_score,
            'sentiment_label': sentiment_label,
            'connection_score': connection_score,
            'connection_message': connection_msg
        }
        return results