# SentinelEdge — federated/simulate.py
# (Deployed to HF Spaces by shiven99; commit 8ee5513.)
"""Simulate federated learning with N devices over M rounds using a real MLP."""
import numpy as np
import json
import os
import sys
import argparse
from typing import Optional
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from federated.local_trainer import LocalTrainer, LocalTrainingBuffer
from federated.dp_injector import DPInjector
from sentinel_edge.classifier.mlp_classifier import MLPClassifier
# Feature-vector dimensionality shared by the synthetic data generators,
# the devices, the test set, and the MLP input layer.
INPUT_DIM = 402
class FederatedSimulation:
    """Simulate federated learning with a real 3-layer MLP classifier.
    Each round:
    1. Each device copies global weights, fine-tunes with real numpy backprop
    2. Computes gradient delta (local_weights - global_weights)
    3. Applies DP noise (clip + Gaussian)
    4. Hub aggregates via FedAvg weighted by n_samples
    5. Global model updated, evaluated on hold-out test set
    """
    def __init__(self, n_devices: int = 5, n_rounds: int = 10,
                 epsilon: float = 0.3, use_dp: bool = True):
        self.n_devices = n_devices
        self.n_rounds = n_rounds
        self.use_dp = use_dp
        # Populated by initialize(): one LocalTrainer per simulated device.
        self.devices: list = []
        # Populated by initialize_global_model().
        self.global_model: Optional[MLPClassifier] = None
        # One metrics dict per completed round (filled by run()).
        self.round_results: list = []
        self.dp_injector = DPInjector(epsilon=epsilon)
        # Hold-out test set
        self.test_features: Optional[np.ndarray] = None
        self.test_labels: Optional[np.ndarray] = None
    # ------------------------------------------------------------------
    # Device & data initialisation
    # ------------------------------------------------------------------
    def initialize(self) -> None:
        """Create N simulated devices with non-IID data distributions.
        Device data profiles:
        Device 0: Heavy on IRS scams (60% scam rate)
        Device 1: Heavy on tech support scams (55% scam rate)
        Device 2: Mixed scam types (50% scam rate)
        Device 3: Mostly legitimate calls (15% scam rate)
        Device 4: Heavy on bank fraud (55% scam rate)
        Device 5+: Random profile
        All data is globally normalized once (z-score) so that all
        devices and the test set share the same feature scale.
        """
        np.random.seed(42)  # deterministic data generation across runs
        # Generate all device data first to compute global normalization
        all_X = []
        all_y = []
        device_splits = []
        # NOTE(review): device_splits is never read afterwards, and its
        # offsets are list lengths rather than cumulative sample counts —
        # candidate for removal.
        for i in range(self.n_devices):
            X, y = self._generate_device_data(i, n_samples=100)
            device_splits.append((len(all_X), len(all_X) + len(X)))
            all_X.append(X)
            all_y.append(y)
        # Balanced test set
        rng = np.random.RandomState(999)
        X_test, y_test = self._generate_test_set(rng, n_samples=300)
        # Test features are deliberately included in the normalization
        # statistics so train and test share the same feature scale.
        all_X.append(X_test)
        # Global z-score normalization
        all_data = np.vstack(all_X)
        self._global_mean = all_data.mean(axis=0)
        self._global_std = all_data.std(axis=0) + 1e-8  # guard div-by-zero
        # Create devices with normalized data
        self.devices = []
        for i in range(self.n_devices):
            device = LocalTrainer(device_id=f"device_{i}", input_dim=INPUT_DIM)
            X = all_X[i]
            y = all_y[i]
            X_norm = (X - self._global_mean) / self._global_std
            for j in range(X_norm.shape[0]):
                device.ingest_call_data(X_norm[j], int(y[j]))
            self.devices.append(device)
        # Normalized test set
        self.test_features = (X_test - self._global_mean) / self._global_std
        self.test_labels = y_test
    def _generate_device_data(self, device_idx: int,
                              n_samples: int = 100) -> tuple:
        """Generate synthetic 402-dim feature vectors for a device.
        Scam vectors: positive bias in the first half of dimensions.
        Legit vectors: negative bias in the first half.
        Each device gets different class distributions (non-IID).
        Returns (X, y): X of shape (n_samples, INPUT_DIM), y in {0, 1}.
        """
        # Per-device seed keeps each device's data stream independent
        # and reproducible.
        rng = np.random.RandomState(42 + device_idx * 1000)
        # scam_rate plus a distribution over scam sub-types per device.
        device_profiles = {
            0: {"scam_rate": 0.60, "irs": 0.70, "tech": 0.10, "bank": 0.10, "generic": 0.10},
            1: {"scam_rate": 0.55, "irs": 0.10, "tech": 0.65, "bank": 0.10, "generic": 0.15},
            2: {"scam_rate": 0.50, "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25},
            3: {"scam_rate": 0.15, "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25},
            4: {"scam_rate": 0.55, "irs": 0.05, "tech": 0.10, "bank": 0.70, "generic": 0.15},
        }
        # Devices beyond index 4 get a random scam rate and uniform types.
        profile = device_profiles.get(device_idx, {
            "scam_rate": rng.uniform(0.3, 0.6),
            "irs": 0.25, "tech": 0.25, "bank": 0.25, "generic": 0.25,
        })
        X = np.zeros((n_samples, INPUT_DIM))
        y = np.zeros(n_samples, dtype=int)
        for i in range(n_samples):
            is_scam = rng.random() < profile["scam_rate"]
            y[i] = 1 if is_scam else 0
            if is_scam:
                scam_type = rng.choice(
                    ["irs", "tech", "bank", "generic"],
                    p=[profile["irs"], profile["tech"],
                       profile["bank"], profile["generic"]],
                )
                X[i] = self._make_scam_vector(rng, scam_type)
            else:
                X[i] = self._make_legit_vector(rng)
        return X, y
    # ------------------------------------------------------------------
    # Synthetic feature vector generators
    # ------------------------------------------------------------------
    def _make_scam_vector(self, rng: np.random.RandomState,
                          scam_type: str) -> np.ndarray:
        """Create a 402-dim feature vector for a scam call.
        The discriminative signal is sparse: only a small subset of
        features carry class information, embedded in high-dimensional
        noise. This makes the classification problem realistically
        difficult for federated learning with DP.
        """
        n = INPUT_DIM
        v = rng.normal(0.0, 0.5, size=n)  # lower background noise
        # Strong discriminative signal in the first 30 features
        signal_end = 30
        v[:signal_end] += rng.normal(2.0, 0.5, size=signal_end)
        # Scam-type-specific sub-patterns
        type_start = 30
        type_block = 10
        # Each scam type boosts its own disjoint 10-feature block.
        offsets = {"irs": 0, "tech": 1, "bank": 2, "generic": 3}
        idx = offsets.get(scam_type, 3)  # unknown types fall back to "generic"
        start = type_start + idx * type_block
        v[start:start + type_block] += rng.normal(1.5, 0.4, size=type_block)
        return v
    def _make_legit_vector(self, rng: np.random.RandomState) -> np.ndarray:
        """Create a 402-dim feature vector for a legitimate call.
        Negative bias in the same sparse feature block that scam
        vectors use, so the MLP must learn to separate in that subspace.
        """
        n = INPUT_DIM
        v = rng.normal(0.0, 0.5, size=n)  # lower background noise
        # Opposite signal in the discriminative block
        signal_end = 30
        v[:signal_end] += rng.normal(-2.0, 0.5, size=signal_end)
        return v
    def _generate_test_set(self, rng: np.random.RandomState,
                           n_samples: int = 300) -> tuple:
        """Generate a balanced test set (50/50 scam/legit)."""
        n_half = n_samples // 2
        X = np.zeros((n_samples, INPUT_DIM))
        y = np.zeros(n_samples, dtype=int)
        scam_types = ["irs", "tech", "bank", "generic"]
        # First half: scam calls with a uniformly random sub-type.
        for i in range(n_half):
            stype = rng.choice(scam_types)
            X[i] = self._make_scam_vector(rng, stype)
            y[i] = 1
        # Second half: legitimate calls.
        for i in range(n_half, n_samples):
            X[i] = self._make_legit_vector(rng)
            y[i] = 0
        # Shuffle so sample order carries no class signal.
        perm = rng.permutation(n_samples)
        return X[perm], y[perm]
    # ------------------------------------------------------------------
    # Global model
    # ------------------------------------------------------------------
    def initialize_global_model(self) -> None:
        """Initialize global MLPClassifier with random weights."""
        self.global_model = MLPClassifier(input_dim=INPUT_DIM)
    # ------------------------------------------------------------------
    # Federated round
    # ------------------------------------------------------------------
    def run_round(self, round_num: int) -> dict:
        """Execute one federated round.
        1. Each device fine-tunes on its local data (real backprop)
        2. Compute gradient delta
        3. Add DP noise (if enabled)
        4. Hub: FedAvg weighted by n_samples
        5. Update global model
        6. Evaluate on test set using real MLP forward pass
        Returns a metrics dict (accuracy/precision/recall/f1 plus round
        bookkeeping keys).
        """
        global_weights = self.global_model.get_weights()
        updates = []        # (delta, n_samples) pairs for FedAvg
        device_sigmas = []  # per-device DP noise scales for reporting
        for device in self.devices:
            n_local = device.buffer.size()
            if n_local == 0:
                continue  # device has no local data; skip this round
            # Fine-tune locally with aggressive local training --
            # high lr (0.5) and 20 epochs needed to produce a gradient
            # delta large enough to survive DP noise and FedAvg averaging
            delta = device.fine_tune(global_weights, lr=0.5, n_epochs=20)
            if self.use_dp:
                # DP noise injection
                noised_delta, sigma, eps_round = self.dp_injector.add_noise(
                    delta, n_local
                )
                device_sigmas.append(sigma)
                updates.append((noised_delta, n_local))
            else:
                updates.append((delta, n_local))
            device.current_model_version = round_num + 1
        if len(updates) == 0:
            # No device contributed: report metrics without touching the
            # global model. round_num is 0-indexed, hence +1 when charging
            # the privacy budget.
            metrics = self._evaluate()
            metrics.update({
                "round": round_num,
                "n_devices": 0,
                "epsilon_spent": self.dp_injector.privacy_budget_spent(
                    round_num + 1
                ) if self.use_dp else 0.0,
                "avg_sigma": 0.0,
            })
            return metrics
        # FedAvg aggregation
        aggregated_delta = self._fedavg_aggregate(updates)
        # Apply aggregated update to global model (server lr = 1.0, no inflation)
        new_weights = global_weights + aggregated_delta
        self.global_model.set_weights(new_weights)
        # Evaluate
        metrics = self._evaluate()
        metrics.update({
            "round": round_num,
            "n_devices": len(updates),
            "epsilon_spent": self.dp_injector.privacy_budget_spent(
                round_num + 1
            ) if self.use_dp else 0.0,
            "avg_sigma": float(np.mean(device_sigmas)) if device_sigmas else 0.0,
        })
        # Inject fresh data each round to simulate ongoing call activity
        self._add_round_data(round_num)
        return metrics
    def _add_round_data(self, round_num: int) -> None:
        """Add new training samples each round to simulate ongoing calls."""
        # Sample volume grows with the round number.
        extra = 30 + round_num * 10
        # Scam rates mirror the device profiles used in initialize().
        profiles = {0: 0.60, 1: 0.55, 2: 0.50, 3: 0.15, 4: 0.55}
        scam_types = ["irs", "tech", "bank", "generic"]
        for i, device in enumerate(self.devices):
            # Fresh per-device, per-round seed keeps the stream reproducible.
            rng = np.random.RandomState(42 + i * 1000 + (round_num + 1) * 500)
            scam_rate = profiles.get(i, 0.4)
            for j in range(extra):
                is_scam = rng.random() < scam_rate
                if is_scam:
                    stype = rng.choice(scam_types)
                    vec = self._make_scam_vector(rng, stype)
                    label = 1
                else:
                    vec = self._make_legit_vector(rng)
                    label = 0
                # Apply global normalization
                vec = (vec - self._global_mean) / self._global_std
                device.ingest_call_data(vec, label)
    # ------------------------------------------------------------------
    # FedAvg aggregation
    # ------------------------------------------------------------------
    def _fedavg_aggregate(self, updates: list) -> np.ndarray:
        """FedAvg: weighted mean of gradient deltas.
        G_global = sum(n_i * G_i) / sum(n_i)
        updates: list of (delta, n_samples) pairs; deltas are assumed to
        be flat weight vectors of equal length.
        """
        total_samples = sum(n for _, n in updates)
        if total_samples == 0:
            return np.zeros_like(updates[0][0])
        weighted_sum = np.zeros_like(updates[0][0])
        for delta, n_i in updates:
            weighted_sum += n_i * delta
        return weighted_sum / total_samples
    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    def _evaluate(self) -> dict:
        """Evaluate global MLP on test set.
        Uses the real MLP forward pass (not a linear classifier).
        Returns accuracy, precision, recall, F1.
        """
        X = self.test_features  # already globally normalized
        y = self.test_labels
        # Forward pass through the real MLP
        probs = self.global_model.forward(X)
        # NOTE(review): forward() appears to return a bare float for a
        # single sample — confirm against MLPClassifier.
        if isinstance(probs, float):
            probs = np.array([probs])
        preds = (probs >= 0.5).astype(int)
        tp = int(np.sum((preds == 1) & (y == 1)))
        tn = int(np.sum((preds == 0) & (y == 0)))
        fp = int(np.sum((preds == 1) & (y == 0)))
        fn = int(np.sum((preds == 0) & (y == 1)))
        # max(..., 1) guards against division by zero on degenerate counts.
        accuracy = (tp + tn) / max(tp + tn + fp + fn, 1)
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        f1 = 2 * precision * recall / max(precision + recall, 1e-8)
        return {
            "accuracy": float(accuracy),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
        }
    # ------------------------------------------------------------------
    # Main run loop
    # ------------------------------------------------------------------
    def run(self) -> list:
        """Run full simulation: initialize + all rounds.
        Returns the list of per-round metrics dicts (also stored in
        self.round_results).
        """
        self.initialize()
        self.initialize_global_model()
        dp_label = f"epsilon={self.dp_injector.epsilon}" if self.use_dp else "OFF"
        print(f"\n{'='*60}")
        print(f"SentinelEdge Federated Learning Simulation (MLP)")
        print(f"Devices: {self.n_devices} | Rounds: {self.n_rounds}")
        print(f"Differential Privacy: {dp_label}")
        print(f"MLP: {INPUT_DIM} -> 128 -> 64 -> 1")
        print(f"{'='*60}\n")
        for r in range(self.n_rounds):
            result = self.run_round(r)
            self.round_results.append(result)
            print(f"Round {r+1}/{self.n_rounds}:")
            print(f" Accuracy: {result['accuracy']:.4f}")
            print(f" Precision: {result['precision']:.4f}")
            print(f" Recall: {result['recall']:.4f}")
            print(f" F1 Score: {result['f1']:.4f}")
            print(f" Devices: {result['n_devices']}")
            if self.use_dp:
                print(f" Epsilon: {result['epsilon_spent']:.4f}")
                print(f" Avg sigma: {result['avg_sigma']:.6f}")
            print()
        return self.round_results
def run_dp_comparison(n_devices: int = 5, n_rounds: int = 10) -> dict:
    """Run the simulation twice: once with DP enabled, once without.

    Returns {"with_dp": [...], "without_dp": [...]}, each value being the
    list of per-round metric dicts. Used by visualization.py for the
    comparison plots.
    """
    banner = "=" * 60
    print(banner)
    print(" RUNNING COMPARISON: WITH DP vs WITHOUT DP")
    print(banner)

    def _simulate(dp_enabled: bool) -> list:
        # Reseed so both runs draw identical synthetic data; the only
        # difference between them is the DP noise injection.
        np.random.seed(42)
        sim = FederatedSimulation(
            n_devices=n_devices,
            n_rounds=n_rounds,
            epsilon=0.3,
            use_dp=dp_enabled,
        )
        return sim.run()

    # Dict values are evaluated left to right, so the DP run happens first,
    # matching the original execution order.
    return {"with_dp": _simulate(True), "without_dp": _simulate(False)}
def main():
    """CLI entry point: parse args, run the simulation, persist results.

    Writes simulation_results.json next to this file and, when matplotlib
    is importable, the corresponding PNG plots.
    """
    parser = argparse.ArgumentParser(
        description="Run federated learning simulation with real MLP"
    )
    parser.add_argument("--devices", type=int, default=5,
                        help="Number of simulated devices")
    parser.add_argument("--rounds", type=int, default=10,
                        help="Number of federated rounds")
    parser.add_argument("--compare", action="store_true",
                        help="Run DP vs no-DP comparison")
    args = parser.parse_args()

    out_dir = os.path.dirname(os.path.abspath(__file__))
    results_path = os.path.join(out_dir, "simulation_results.json")

    if args.compare:
        comparison = run_dp_comparison(
            n_devices=args.devices, n_rounds=args.rounds
        )
        payload = {
            "with_dp": _make_serializable(comparison["with_dp"]),
            "without_dp": _make_serializable(comparison["without_dp"]),
        }
        with open(results_path, "w") as f:
            json.dump(payload, f, indent=2)
        print(f"\nSaved comparison results to {results_path}")
        # Plots are optional: skip quietly when matplotlib is missing.
        try:
            from federated.visualization import (
                plot_accuracy_over_rounds, plot_dp_comparison,
            )
            plot_accuracy_over_rounds(
                comparison["with_dp"],
                output_path=os.path.join(out_dir, "federated_results.png"),
            )
            plot_dp_comparison(
                comparison,
                output_path=os.path.join(out_dir, "dp_comparison.png"),
            )
        except ImportError:
            print("(Skipping plots: matplotlib not available)")
        return

    # Single run (DP enabled by default).
    np.random.seed(42)
    simulation = FederatedSimulation(
        n_devices=args.devices, n_rounds=args.rounds
    )
    round_results = simulation.run()
    with open(results_path, "w") as f:
        json.dump(_make_serializable(round_results), f, indent=2)
    print(f"\nSaved results to {results_path}")
    # Optional plot, same matplotlib caveat as above.
    try:
        from federated.visualization import plot_accuracy_over_rounds
        plot_accuracy_over_rounds(
            round_results,
            output_path=os.path.join(out_dir, "federated_results.png"),
        )
    except ImportError:
        print("(Skipping plot: matplotlib not available)")
def _make_serializable(results: list) -> list:
"""Convert numpy types to JSON-serializable Python types."""
out = []
for r in results:
out.append({
k: float(v) if isinstance(v, (np.floating, float)) else v
for k, v in r.items()
})
return out
# Script entry point: run the federated-simulation CLI defined in main().
if __name__ == "__main__":
    main()