"""Simulate federated learning with N devices over M rounds using a real MLP."""

import numpy as np
import json
import os
import sys
import argparse
from typing import Optional

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from federated.local_trainer import LocalTrainer, LocalTrainingBuffer
from federated.dp_injector import DPInjector
from sentinel_edge.classifier.mlp_classifier import MLPClassifier

INPUT_DIM = 402


class FederatedSimulation:
    """Simulate federated learning with a real 3-layer MLP classifier.

    Each round:
      1. Each device copies global weights, fine-tunes with real numpy backprop
      2. Computes gradient delta (local_weights - global_weights)
      3. Applies DP noise (clip + Gaussian)
      4. Hub aggregates via FedAvg weighted by n_samples
      5. Global model updated, evaluated on hold-out test set
      6. Every device ingests a fresh batch of calls for the next round
    """

    # Scam categories; each selects its own 10-feature block in scam vectors.
    _SCAM_TYPES = ("irs", "tech", "bank", "generic")

    # Non-IID per-device profiles: overall scam rate plus the scam-type mix.
    # Shared by the initial data and the per-round data so both follow the
    # same skew.
    _DEVICE_PROFILES = {
        0: {"scam_rate": 0.60, "irs": 0.70, "tech": 0.10, "bank": 0.10, "generic": 0.10},
        1: {"scam_rate": 0.55, "irs": 0.10, "tech": 0.65, "bank": 0.10, "generic": 0.15},
        2: {"scam_rate": 0.50, "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25},
        3: {"scam_rate": 0.15, "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25},
        4: {"scam_rate": 0.55, "irs": 0.05, "tech": 0.10, "bank": 0.70, "generic": 0.15},
    }

    # Profile used for per-round data on devices without an entry above.
    _ROUND_DATA_FALLBACK_PROFILE = {
        "scam_rate": 0.40, "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25,
    }

    def __init__(self, n_devices: int = 5, n_rounds: int = 10,
                 epsilon: float = 0.3, use_dp: bool = True):
        """Configure the simulation.

        Data and the model are created later by initialize() and
        initialize_global_model(), both of which run() calls.

        Args:
            n_devices: Number of simulated devices.
            n_rounds: Number of federated rounds executed by run().
            epsilon: Privacy parameter passed to DPInjector.
            use_dp: When False, device deltas are aggregated without DP noise.
        """
        self.n_devices = n_devices
        self.n_rounds = n_rounds
        self.use_dp = use_dp
        self.devices: list = []
        self.global_model: Optional[MLPClassifier] = None
        self.round_results: list = []
        self.dp_injector = DPInjector(epsilon=epsilon)

        # Hold-out test set (already normalized), filled in by initialize()
        self.test_features: Optional[np.ndarray] = None
        self.test_labels: Optional[np.ndarray] = None

        # Per-feature z-score statistics, fitted in initialize()
        self._global_mean: Optional[np.ndarray] = None
        self._global_std: Optional[np.ndarray] = None

    # ------------------------------------------------------------------
    # Device & data initialisation
    # ------------------------------------------------------------------

    def initialize(self):
        """Create N simulated devices with non-IID data distributions.

        Device data profiles:
            Device 0: Heavy on IRS scams (60% scam rate)
            Device 1: Heavy on tech support scams (55% scam rate)
            Device 2: Mixed scam types (50% scam rate)
            Device 3: Mostly legitimate calls (15% scam rate)
            Device 4: Heavy on bank fraud (55% scam rate)
            Device 5+: Random scam rate in [0.3, 0.6), uniform scam-type mix

        Z-score statistics are fitted once on the devices' training data only.
        They are then applied to every device, the hold-out test set, and all
        data added in later rounds. Everything therefore shares one feature
        scale, without test-set statistics leaking into training.
        """
        # Seeds the global NumPy RNG. Presumably this is consumed by
        # MLPClassifier's weight init in initialize_global_model() -- TODO
        # confirm. Device/test data below use their own RandomState objects.
        np.random.seed(42)

        device_data = [
            self._generate_device_data(i, n_samples=100)
            for i in range(self.n_devices)
        ]

        # Balanced test set from an independent RNG
        X_test, y_test = self._generate_test_set(
            np.random.RandomState(999), n_samples=300
        )

        # Fit normalization on training data only (no test-set leakage)
        self._fit_normalizer([X for X, _ in device_data])

        # Create devices with normalized data
        self.devices = []
        for i, (X, y) in enumerate(device_data):
            device = LocalTrainer(device_id=f"device_{i}", input_dim=INPUT_DIM)
            for vec, label in zip(self._normalize(X), y):
                device.ingest_call_data(vec, int(label))
            self.devices.append(device)

        # Normalized test set
        self.test_features = self._normalize(X_test)
        self.test_labels = y_test

    def _fit_normalizer(self, train_arrays: list):
        """Fit per-feature z-score statistics on a list of (n_i, INPUT_DIM) arrays.

        Falls back to an identity transform (mean 0, std 1) when there are
        no training rows, e.g. when n_devices == 0.
        """
        stacked = np.vstack(train_arrays) if train_arrays else np.empty((0, INPUT_DIM))
        if stacked.shape[0] == 0:
            self._global_mean = np.zeros(INPUT_DIM)
            self._global_std = np.ones(INPUT_DIM)
            return
        self._global_mean = stacked.mean(axis=0)
        # Small epsilon keeps constant features from dividing by zero
        self._global_std = stacked.std(axis=0) + 1e-8

    def _normalize(self, X: np.ndarray) -> np.ndarray:
        """Apply the fitted z-score transform (works on a vector or a row matrix)."""
        return (X - self._global_mean) / self._global_std

    def _generate_device_data(self, device_idx: int, n_samples: int = 100) -> tuple:
        """Generate synthetic 402-dim feature vectors for a device.

        Scam vectors have a positive bias in features 0-29, plus a
        scam-type-specific block within features 30-69. Legit vectors have
        a negative bias in features 0-29. Each device gets a different
        class / scam-type mix (non-IID).

        Returns:
            (X, y): X has shape (n_samples, INPUT_DIM); y holds int labels
            (1 = scam, 0 = legit).
        """
        rng = np.random.RandomState(42 + device_idx * 1000)
        # Drawn for every device, even those with a fixed profile, so each
        # device's random stream always begins with this one uniform draw.
        fallback_rate = rng.uniform(0.3, 0.6)
        profile = self._DEVICE_PROFILES.get(device_idx, {
            "scam_rate": fallback_rate,
            "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25,
        })
        return self._sample_calls(rng, profile, n_samples)

    def _sample_calls(self, rng: np.random.RandomState, profile: dict,
                      n_samples: int) -> tuple:
        """Draw n_samples labelled calls following a device profile.

        For each sample, one uniform draw decides scam vs legit, using
        profile["scam_rate"] as the threshold. A scam then picks a type from
        the profile's type mix and builds a scam vector; a legit call builds
        a legit vector.
        """
        type_probs = [profile[t] for t in self._SCAM_TYPES]
        X = np.zeros((n_samples, INPUT_DIM))
        y = np.zeros(n_samples, dtype=int)
        for i in range(n_samples):
            if rng.random() < profile["scam_rate"]:
                y[i] = 1
                scam_type = rng.choice(self._SCAM_TYPES, p=type_probs)
                X[i] = self._make_scam_vector(rng, scam_type)
            else:
                X[i] = self._make_legit_vector(rng)
        return X, y

    # ------------------------------------------------------------------
    # Synthetic feature vector generators
    # ------------------------------------------------------------------

    def _make_scam_vector(self, rng: np.random.RandomState, scam_type: str) -> np.ndarray:
        """Create a 402-dim feature vector for a scam call.

        The discriminative signal is sparse: only a small subset of features
        carries class information, embedded in high-dimensional noise.
        Features 0-29 get a +2.0-mean shift. A 10-feature block starting at
        30 + 10 * type_index gets a further +1.5-mean shift; unknown types
        use the "generic" block.
        """
        n = INPUT_DIM
        v = rng.normal(0.0, 0.5, size=n)  # lower background noise

        # Strong discriminative signal in the first 30 features
        signal_end = 30
        v[:signal_end] += rng.normal(2.0, 0.5, size=signal_end)

        # Scam-type-specific sub-patterns
        type_start = 30
        type_block = 10
        offsets = {"irs": 0, "tech": 1, "bank": 2, "generic": 3}
        idx = offsets.get(scam_type, 3)
        start = type_start + idx * type_block
        v[start:start + type_block] += rng.normal(1.5, 0.4, size=type_block)

        return v

    def _make_legit_vector(self, rng: np.random.RandomState) -> np.ndarray:
        """Create a 402-dim feature vector for a legitimate call.

        Adds a negative (-2.0-mean) bias in the same features 0-29 that scam
        vectors shift positively, so the MLP must learn to separate the
        classes in that subspace.
        """
        n = INPUT_DIM
        v = rng.normal(0.0, 0.5, size=n)  # lower background noise

        # Opposite signal in the discriminative block
        signal_end = 30
        v[:signal_end] += rng.normal(-2.0, 0.5, size=signal_end)

        return v

    def _generate_test_set(self, rng: np.random.RandomState, n_samples: int = 300) -> tuple:
        """Generate a balanced test set (50/50 scam/legit), shuffled.

        Scam types are drawn uniformly. For odd n_samples the extra sample
        is legit.
        """
        n_half = n_samples // 2
        X = np.zeros((n_samples, INPUT_DIM))
        y = np.zeros(n_samples, dtype=int)

        for i in range(n_half):
            stype = rng.choice(self._SCAM_TYPES)
            X[i] = self._make_scam_vector(rng, stype)
            y[i] = 1

        for i in range(n_half, n_samples):
            X[i] = self._make_legit_vector(rng)
            y[i] = 0

        perm = rng.permutation(n_samples)
        return X[perm], y[perm]

    # ------------------------------------------------------------------
    # Global model
    # ------------------------------------------------------------------

    def initialize_global_model(self):
        """Initialize global MLPClassifier with random weights."""
        self.global_model = MLPClassifier(input_dim=INPUT_DIM)

    # ------------------------------------------------------------------
    # Federated round
    # ------------------------------------------------------------------

    def run_round(self, round_num: int) -> dict:
        """Execute one federated round.

        1. Each device fine-tunes on its local data (real backprop)
        2. Compute gradient delta
        3. Add DP noise (if enabled)
        4. Hub: FedAvg weighted by n_samples
        5. Update global model (skipped if no device had data)
        6. Evaluate on test set using real MLP forward pass
        7. Add fresh call data to every device for the next round

        Returns:
            The _evaluate() metrics plus "round", "n_devices",
            "epsilon_spent" and "avg_sigma".
        """
        global_weights = self.global_model.get_weights()
        updates, device_sigmas = self._collect_updates(global_weights, round_num)

        if updates:
            aggregated_delta = self._fedavg_aggregate(updates)
            # Server learning rate 1.0: apply the averaged delta directly
            self.global_model.set_weights(global_weights + aggregated_delta)

        metrics = self._evaluate()
        metrics.update({
            "round": round_num,
            "n_devices": len(updates),
            "epsilon_spent": (
                self.dp_injector.privacy_budget_spent(round_num + 1)
                if self.use_dp else 0.0
            ),
            "avg_sigma": float(np.mean(device_sigmas)) if device_sigmas else 0.0,
        })

        # Fresh data arrives every round, even if no device trained this round
        self._add_round_data(round_num)
        return metrics

    def _collect_updates(self, global_weights, round_num: int) -> tuple:
        """Fine-tune every device that has data and collect its (delta, n) update.

        Returns:
            (updates, device_sigmas). updates is a list of
            (delta, n_local_samples). device_sigmas holds the DP noise scale
            per device and is empty when DP is off.
        """
        updates = []
        device_sigmas = []
        for device in self.devices:
            n_local = device.buffer.size()
            if n_local == 0:
                continue

            # Aggressive local training -- high lr (0.5) and 20 epochs are
            # needed to produce a gradient delta large enough to survive DP
            # noise and FedAvg averaging.
            delta = device.fine_tune(global_weights, lr=0.5, n_epochs=20)

            if self.use_dp:
                delta, sigma, _ = self.dp_injector.add_noise(delta, n_local)
                device_sigmas.append(sigma)

            updates.append((delta, n_local))
            device.current_model_version = round_num + 1
        return updates, device_sigmas

    def _add_round_data(self, round_num: int):
        """Add new training samples to every device to simulate ongoing calls.

        Each device receives 30 + 10 * round_num fresh calls drawn from its
        own profile (scam rate and scam-type mix). The calls are normalized
        with the statistics fitted in initialize().
        """
        extra = 30 + round_num * 10
        for i, device in enumerate(self.devices):
            # Key the stream on (base seed, device, round). A linear int
            # formula such as 42 + 1000*i + 500*(round+1) aliases: different
            # (device, round) pairs, and a device's initial-data seed, end up
            # sharing one stream.
            rng = np.random.RandomState([42, i, round_num + 1])
            profile = self._DEVICE_PROFILES.get(i, self._ROUND_DATA_FALLBACK_PROFILE)
            X, y = self._sample_calls(rng, profile, extra)
            for vec, label in zip(self._normalize(X), y):
                device.ingest_call_data(vec, int(label))

    # ------------------------------------------------------------------
    # FedAvg aggregation
    # ------------------------------------------------------------------

    def _fedavg_aggregate(self, updates: list) -> np.ndarray:
        """FedAvg: weighted mean of gradient deltas.

            G_global = sum(n_i * G_i) / sum(n_i)

        Args:
            updates: Non-empty list of (delta, n_i) pairs.

        Returns:
            The weighted mean delta. Returns all zeros if the total weight
            is 0.
        """
        total_samples = sum(n for _, n in updates)
        if total_samples == 0:
            return np.zeros_like(updates[0][0])

        weighted_sum = np.zeros_like(updates[0][0])
        for delta, n_i in updates:
            weighted_sum += n_i * delta

        return weighted_sum / total_samples

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------

    def _evaluate(self) -> dict:
        """Evaluate the global MLP on the hold-out test set.

        Uses the real MLP forward pass and thresholds probabilities at 0.5.

        Returns:
            {"accuracy", "precision", "recall", "f1"} as Python floats.
            Undefined ratios (empty denominators) evaluate to 0.
        """
        X = self.test_features  # already globally normalized
        y = self.test_labels

        # Flatten whatever the forward pass returns -- a scalar, (n,), or
        # (n, 1) -- so the comparisons below never broadcast to (n, n).
        probs = np.asarray(self.global_model.forward(X), dtype=float).reshape(-1)
        preds = (probs >= 0.5).astype(int)

        tp = int(np.sum((preds == 1) & (y == 1)))
        tn = int(np.sum((preds == 0) & (y == 0)))
        fp = int(np.sum((preds == 1) & (y == 0)))
        fn = int(np.sum((preds == 0) & (y == 1)))

        accuracy = (tp + tn) / max(tp + tn + fp + fn, 1)
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        f1 = 2 * precision * recall / max(precision + recall, 1e-8)

        return {
            "accuracy": float(accuracy),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
        }

    # ------------------------------------------------------------------
    # Main run loop
    # ------------------------------------------------------------------

    def run(self) -> list:
        """Run full simulation: initialize + all rounds.

        Prints a per-round summary and returns the list of per-round metric
        dicts. Results from any earlier run() call are discarded.
        """
        self.round_results = []
        self.initialize()
        self.initialize_global_model()

        dp_label = f"epsilon={self.dp_injector.epsilon}" if self.use_dp else "OFF"
        print(f"\n{'='*60}")
        print("SentinelEdge Federated Learning Simulation (MLP)")
        print(f"Devices: {self.n_devices} | Rounds: {self.n_rounds}")
        print(f"Differential Privacy: {dp_label}")
        # NOTE(review): hidden sizes are hard-coded here; presumably they
        # match MLPClassifier's defaults -- confirm.
        print(f"MLP: {INPUT_DIM} -> 128 -> 64 -> 1")
        print(f"{'='*60}\n")

        for r in range(self.n_rounds):
            result = self.run_round(r)
            self.round_results.append(result)

            print(f"Round {r+1}/{self.n_rounds}:")
            print(f"  Accuracy:  {result['accuracy']:.4f}")
            print(f"  Precision: {result['precision']:.4f}")
            print(f"  Recall:    {result['recall']:.4f}")
            print(f"  F1 Score:  {result['f1']:.4f}")
            print(f"  Devices:   {result['n_devices']}")
            if self.use_dp:
                print(f"  Epsilon:   {result['epsilon_spent']:.4f}")
                print(f"  Avg sigma: {result['avg_sigma']:.6f}")
            print()

        return self.round_results


def run_dp_comparison(n_devices: int = 5, n_rounds: int = 10,
                      epsilon: float = 0.3) -> dict:
    """Run the simulation twice: with DP and without DP.

    Both runs reseed the global NumPy RNG with 42 so they start from the same
    state.

    Returns:
        A dict with keys 'with_dp' and 'without_dp', each containing the
        list of round results. Used by visualization.py for comparison plots.
    """
    print("=" * 60)
    print("  RUNNING COMPARISON: WITH DP vs WITHOUT DP")
    print("=" * 60)

    results = {}
    for key, use_dp in (("with_dp", True), ("without_dp", False)):
        np.random.seed(42)
        sim = FederatedSimulation(
            n_devices=n_devices, n_rounds=n_rounds,
            epsilon=epsilon, use_dp=use_dp,
        )
        results[key] = sim.run()
    return results


def _write_json(path: str, payload) -> None:
    """Write payload to path as indented JSON (UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main():
    """CLI entry point: run the simulation and save JSON results and plots.

    Outputs are written next to this file.
    """
    parser = argparse.ArgumentParser(
        description="Run federated learning simulation with real MLP"
    )
    parser.add_argument("--devices", type=int, default=5,
                        help="Number of simulated devices")
    parser.add_argument("--rounds", type=int, default=10,
                        help="Number of federated rounds")
    parser.add_argument("--compare", action="store_true",
                        help="Run DP vs no-DP comparison")
    args = parser.parse_args()

    output_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(output_dir, "simulation_results.json")

    if args.compare:
        comparison = run_dp_comparison(
            n_devices=args.devices, n_rounds=args.rounds
        )
        _write_json(output_path, {
            "with_dp": _make_serializable(comparison["with_dp"]),
            "without_dp": _make_serializable(comparison["without_dp"]),
        })
        print(f"\nSaved comparison results to {output_path}")

        # Also generate plots (optional dependency)
        try:
            from federated.visualization import (
                plot_accuracy_over_rounds,
                plot_dp_comparison,
            )
            plot_accuracy_over_rounds(
                comparison["with_dp"],
                output_path=os.path.join(output_dir, "federated_results.png"),
            )
            plot_dp_comparison(
                comparison,
                output_path=os.path.join(output_dir, "dp_comparison.png"),
            )
        except ImportError:
            print("(Skipping plots: matplotlib not available)")
    else:
        np.random.seed(42)
        sim = FederatedSimulation(
            n_devices=args.devices, n_rounds=args.rounds
        )
        results = sim.run()

        _write_json(output_path, _make_serializable(results))
        print(f"\nSaved results to {output_path}")

        # Generate plot (optional dependency)
        try:
            from federated.visualization import plot_accuracy_over_rounds
            plot_accuracy_over_rounds(
                results,
                output_path=os.path.join(output_dir, "federated_results.png"),
            )
        except ImportError:
            print("(Skipping plot: matplotlib not available)")


def _make_serializable(results: list) -> list:
    """Convert numpy scalar values in each result dict to plain Python types.

    np.floating and float become float, and np.integer becomes int. Other
    values are kept as-is.
    """
    out = []
    for r in results:
        row = {}
        for k, v in r.items():
            if isinstance(v, (np.floating, float)):
                row[k] = float(v)
            elif isinstance(v, np.integer):
                row[k] = int(v)
            else:
                row[k] = v
        out.append(row)
    return out


if __name__ == "__main__":
    main()