| """ |
| CTR Prediction Model: FinalMLP |
| Based on: Mao et al. "FinalMLP: An Enhanced Two-Stream MLP Model for CTR Prediction" (AAAI 2023) |
| arXiv: 2304.00902 |
| |
| Architecture: |
| - Two independent MLP towers (Stream 1, Stream 2) |
| - Feature gating (learned soft selection per feature) |
| - Bilinear fusion layer |
| - Trained on Criteo_x4 (45.8M rows, 13 dense + 26 categorical) |
| """ |
import os
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch.utils.data import DataLoader, TensorDataset

warnings.filterwarnings('ignore')
|
|
|
|
| class FeatureGating(nn.Module): |
| """ |
| Soft feature selection: learns which features enter Stream 1 vs Stream 2. |
    Output: gate_weights in (0, 1) per feature; higher means a larger share routed to Stream 1.
| """ |
| def __init__(self, input_dim, hidden_dim=64): |
| super().__init__() |
| self.gate_net = nn.Sequential( |
| nn.Linear(input_dim, hidden_dim), |
| nn.ReLU(), |
| nn.Linear(hidden_dim, input_dim), |
| nn.Sigmoid() |
| ) |
|
|
| def forward(self, x): |
| return self.gate_net(x) |
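

# Minimal sketch (our own sanity check, not from the paper's code): the gate
# preserves the feature dimension, so x * g and x * (1 - g) form complementary
# soft views of the same input for Stream 1 and Stream 2.
def _demo_feature_gating():
    torch.manual_seed(0)
    x = torch.randn(8, 39)                 # 13 dense + 26 categorical = 39 features
    g = FeatureGating(input_dim=39)(x)
    assert g.shape == x.shape
    assert float(g.min()) > 0.0 and float(g.max()) < 1.0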
|
|
|
|
| class BilinearFusion(nn.Module): |
| """Bilinear interaction between the two stream outputs.""" |
    def __init__(self, dim1, dim2, output_dim=64):
        super().__init__()
        # W has shape (dim1, dim2, output_dim); the small random init keeps
        # the fused activations near zero at the start of training.
        self.W = nn.Parameter(torch.randn(dim1, dim2, output_dim) * 0.01)
        self.b = nn.Parameter(torch.zeros(output_dim))
|
|
    def forward(self, s1, s2):
        # out[b, k] = sum_{i,j} s1[b, i] * W[i, j, k] * s2[b, j] + b[k]
        return torch.einsum('bi,ijk,bj->bk', s1, self.W, s2) + self.b
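

# Minimal sketch (added for illustration; the helper name is ours): checks
# that the einsum bilinear product matches an explicit per-output-dimension
# reference computation.
def _demo_bilinear_fusion():
    torch.manual_seed(0)
    fusion = BilinearFusion(dim1=4, dim2=4, output_dim=2)
    s1, s2 = torch.randn(3, 4), torch.randn(3, 4)
    out = fusion(s1, s2)  # (3, 2)
    # Reference: for each output dim k, the diagonal of s1 @ W[:, :, k] @ s2^T
    ref = torch.stack(
        [(s1 @ fusion.W[:, :, k] * s2).sum(-1) + fusion.b[k] for k in range(2)],
        dim=-1,
    )
    assert torch.allclose(out, ref, atol=1e-5)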
|
|
|
|
| class FinalMLP(nn.Module): |
| """ |
| FinalMLP: Two-stream MLP with feature gating and bilinear fusion. |
| |
| Args: |
| input_dim: Number of input features |
| hidden_units: List of hidden layer sizes for each MLP stream |
        dropout: Dropout rate applied after each hidden layer
| """ |
| def __init__(self, input_dim, hidden_units=(400, 400, 400), dropout=0.2): |
| super().__init__() |
| self.input_dim = input_dim |
| |
| |
| self.gate = FeatureGating(input_dim) |
| |
| |
| layers1 = [] |
| in_dim = input_dim |
| for h in hidden_units: |
| layers1 += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)] |
| in_dim = h |
| self.stream1 = nn.Sequential(*layers1) |
| |
| |
| layers2 = [] |
| in_dim = input_dim |
| for h in hidden_units: |
| layers2 += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)] |
| in_dim = h |
| self.stream2 = nn.Sequential(*layers2) |
| |
| |
        # Bilinear fusion of the two stream outputs, followed by a small head
        # that maps the fused representation to a click probability.
        last_dim = hidden_units[-1]
        self.fusion = BilinearFusion(last_dim, last_dim, output_dim=64)
        self.head = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
|
|
    def forward(self, x):
        gate_w = self.gate(x)                    # (batch, input_dim), values in (0, 1)
        s1_out = self.stream1(x * gate_w)        # Stream 1: gated view
        s2_out = self.stream2(x * (1 - gate_w))  # Stream 2: complementary view
        fused = self.fusion(s1_out, s2_out)      # bilinear interaction of the streams
        return self.head(fused).squeeze(-1)      # p(click) per example
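

# Minimal sketch (illustrative): a forward pass on random data yields one
# probability per row. input_dim=39 assumes the Criteo layout (13 dense +
# 26 categorical) produced by CTRDataProcessor below.
def _demo_finalmlp_forward():
    torch.manual_seed(0)
    model = FinalMLP(input_dim=39, hidden_units=(64, 64))
    p = model(torch.randn(8, 39))
    assert p.shape == (8,)
    assert float(p.min()) >= 0.0 and float(p.max()) <= 1.0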
|
|
|
|
| class CTRDataProcessor: |
| """Preprocess Criteo_x4 data for CTR model training.""" |
| |
| def __init__(self, max_rows=None): |
| self.max_rows = max_rows |
| self.dense_cols = [f'I{i}' for i in range(1, 14)] |
| self.sparse_cols = [f'C{i}' for i in range(1, 27)] |
        self.label_encoders = {}
        self.scaler = StandardScaler()
        self.sparse_mean = None  # train-time statistics, reused at inference
        self.sparse_std = None
        self.feature_dim = None
| |
| def load_and_process(self, split_ratios=(0.8, 0.1, 0.1)): |
| """Load Criteo_x4, preprocess, and split.""" |
| print("Loading Criteo_x4 dataset...") |
| ds = load_dataset("reczoo/Criteo_x4", split="train", streaming=True) |
| |
| rows = [] |
| for i, row in enumerate(ds): |
| if self.max_rows and i >= self.max_rows: |
| break |
| rows.append(row) |
| |
| df = pd.DataFrame(rows) |
| print(f"Loaded {len(df)} rows, CTR: {df['Label'].mean():.4f}") |
| |
| |
| for col in self.dense_cols: |
| df[col] = df[col].fillna(df[col].median()) |
| for col in self.sparse_cols: |
| df[col] = df[col].fillna("MISSING") |
| |
| |
| for col in self.sparse_cols: |
| le = LabelEncoder() |
| df[col] = le.fit_transform(df[col].astype(str)) |
| self.label_encoders[col] = le |
| |
| |
| dense_data = df[self.dense_cols].values |
| dense_data = self.scaler.fit_transform(dense_data) |
| for i, col in enumerate(self.dense_cols): |
| df[col] = dense_data[:, i] |
| |
| |
        # Simplification vs. the paper: label-encoded categorical IDs are
        # standardized as scalars instead of going through embedding tables.
        # Train-time statistics are kept so inference applies the same transform.
        sparse_data = df[self.sparse_cols].values.astype(np.float32)
        self.sparse_mean = sparse_data.mean(axis=0)
        self.sparse_std = sparse_data.std(axis=0) + 1e-8
        sparse_data = (sparse_data - self.sparse_mean) / self.sparse_std
        for i, col in enumerate(self.sparse_cols):
            df[col] = sparse_data[:, i]
| |
| feature_cols = self.dense_cols + self.sparse_cols |
| self.feature_dim = len(feature_cols) |
| X = df[feature_cols].values.astype(np.float32) |
| y = df['Label'].values.astype(np.float32) |
| |
| |
        # Two-stage split: carve off the test set first, then split the
        # remainder so that val_r of the original data lands in validation.
        # Stratify on the label to preserve the CTR in each split.
        train_r, val_r, test_r = split_ratios
        X_temp, X_test, y_temp, y_test = train_test_split(
            X, y, test_size=test_r, random_state=42, stratify=y
        )
        val_ratio = val_r / (train_r + val_r)
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=val_ratio, random_state=42, stratify=y_temp
        )
| |
| print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}") |
| return (X_train, y_train), (X_val, y_val), (X_test, y_test) |
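

# Worked check of the two-stage split arithmetic (our illustration): with
# ratios (0.8, 0.1, 0.1), the second split takes val_r / (train_r + val_r)
# = 1/9 of the remaining 90%, i.e. 10% of the original data, as intended.
def _demo_split_ratios():
    train_r, val_r, test_r = 0.8, 0.1, 0.1
    val_ratio = val_r / (train_r + val_r)
    assert abs((1 - test_r) * val_ratio - val_r) < 1e-12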
|
|
|
|
def train_finalmlp(
    train_data, val_data, test_data,
    hidden_units=(400, 400, 400),
    batch_size=4096,
    learning_rate=1e-3,
    epochs=10,
    device='cuda',
    save_path='/app/models/finalmlp_ctr.pt'
):
| """Train FinalMLP on preprocessed data.""" |
| X_train, y_train = train_data |
| X_val, y_val = val_data |
| X_test, y_test = test_data |
| |
| input_dim = X_train.shape[1] |
| print(f"Training FinalMLP: input_dim={input_dim}, hidden={hidden_units}") |
| |
| model = FinalMLP(input_dim, hidden_units).to(device) |
| optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6) |
| criterion = nn.BCELoss() |
| |
| |
| train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train)) |
| val_ds = TensorDataset(torch.tensor(X_val), torch.tensor(y_val)) |
| test_ds = TensorDataset(torch.tensor(X_test), torch.tensor(y_test)) |
| |
| train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) |
| val_loader = DataLoader(val_ds, batch_size=batch_size * 2) |
| test_loader = DataLoader(test_ds, batch_size=batch_size * 2) |
| |
| best_val_auc = 0.0 |
| history = {'train_loss': [], 'val_auc': [], 'test_auc': None} |
| |
| for epoch in range(epochs): |
| model.train() |
| total_loss = 0.0 |
| |
| for batch_x, batch_y in train_loader: |
| batch_x, batch_y = batch_x.to(device), batch_y.to(device) |
| optimizer.zero_grad() |
| preds = model(batch_x) |
| loss = criterion(preds, batch_y) |
| loss.backward() |
| optimizer.step() |
| total_loss += loss.item() |
| |
| avg_loss = total_loss / len(train_loader) |
| history['train_loss'].append(avg_loss) |
| |
| |
| val_auc = evaluate_auc(model, val_loader, device) |
| history['val_auc'].append(val_auc) |
| |
| print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Val AUC: {val_auc:.4f}") |
| |
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
            torch.save(model.state_dict(), save_path)
| |
| |
    # Reload the best checkpoint (by validation AUC) before the final test pass
    model.load_state_dict(torch.load(save_path, map_location=device))
| test_auc = evaluate_auc(model, test_loader, device) |
| history['test_auc'] = test_auc |
| print(f"\nTest AUC: {test_auc:.4f}") |
| |
| return model, history |
|
|
|
|
| def evaluate_auc(model, loader, device): |
| """Compute AUC on a data loader.""" |
| model.eval() |
| all_preds, all_labels = [], [] |
| with torch.no_grad(): |
| for batch_x, batch_y in loader: |
| batch_x = batch_x.to(device) |
| preds = model(batch_x).cpu().numpy() |
| all_preds.extend(preds) |
| all_labels.extend(batch_y.numpy()) |
| |
    return roc_auc_score(all_labels, all_preds)
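

# Minimal sketch (our own sanity check): an untrained FinalMLP on random
# labels should score roughly chance-level AUC; here we only assert the
# returned value is a valid AUC in [0, 1].
def _demo_evaluate_auc():
    torch.manual_seed(0)
    model = FinalMLP(input_dim=39, hidden_units=(32, 32))
    X = torch.randn(256, 39)
    y = torch.randint(0, 2, (256,)).float()
    loader = DataLoader(TensorDataset(X, y), batch_size=64)
    assert 0.0 <= evaluate_auc(model, loader, device='cpu') <= 1.0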
|
|
|
|
| class CTRPredictor: |
| """Production-ready CTR predictor wrapping FinalMLP.""" |
| |
| def __init__(self, model, processor, device='cpu'): |
| self.model = model.to(device) |
| self.processor = processor |
| self.device = device |
| self.model.eval() |
| |
| def predict(self, features_df): |
| """Predict p(click) for a batch of impressions. |
| |
| Args: |
| features_df: DataFrame with Criteo columns (I1-I13, C1-C26) |
| Returns: |
| pCTR: numpy array of click probabilities |
| """ |
| |
| df = features_df.copy() |
| for col in self.processor.dense_cols: |
| if col not in df.columns: |
| df[col] = 0.0 |
| df[col] = df[col].fillna(0.0) |
| for col in self.processor.sparse_cols: |
| if col not in df.columns: |
| df[col] = "MISSING" |
| df[col] = df[col].fillna("MISSING") |
| |
| |
        # Map categories through the fitted encoders via a dict lookup;
        # unseen values fall back to index 0.
        for col in self.processor.sparse_cols:
            le = self.processor.label_encoders.get(col)
            if le:
                mapping = {c: i for i, c in enumerate(le.classes_)}
                df[col] = df[col].astype(str).map(mapping).fillna(0).astype(int)
| |
| |
| dense_vals = df[self.processor.dense_cols].values.astype(np.float32) |
| dense_vals = self.processor.scaler.transform(dense_vals) |
| for i, col in enumerate(self.processor.dense_cols): |
| df[col] = dense_vals[:, i] |
| |
        # Standardize with the training statistics stored on the processor;
        # per-request batch statistics would zero out single-row inputs.
        sparse_vals = df[self.processor.sparse_cols].values.astype(np.float32)
        sparse_vals = (sparse_vals - self.processor.sparse_mean) / self.processor.sparse_std
| for i, col in enumerate(self.processor.sparse_cols): |
| df[col] = sparse_vals[:, i] |
| |
| feature_cols = self.processor.dense_cols + self.processor.sparse_cols |
| X = df[feature_cols].values.astype(np.float32) |
| |
| with torch.no_grad(): |
| X_tensor = torch.tensor(X).to(self.device) |
| return self.model(X_tensor).cpu().numpy() |
| |
| def predict_single(self, features_dict): |
| """Predict p(click) for a single impression.""" |
| df = pd.DataFrame([features_dict]) |
| return self.predict(df)[0] |
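

# Hypothetical usage sketch (our illustration; feature values are made up):
# after training, wrap the model and the fitted processor, then score a single
# impression dict keyed by the Criteo column names.
#
#   predictor = CTRPredictor(model, processor, device='cpu')
#   pctr = predictor.predict_single({'I1': 3.0, 'I5': 120.0, 'C1': '05db9164'})
#   print(f"p(click) = {pctr:.4f}")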
|
|
|
|
| if __name__ == '__main__': |
| import argparse |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--max_rows', type=int, default=100000, help='Max rows to load') |
| parser.add_argument('--epochs', type=int, default=5, help='Training epochs') |
| parser.add_argument('--batch_size', type=int, default=4096) |
| parser.add_argument('--lr', type=float, default=1e-3) |
| parser.add_argument('--save_path', type=str, default='/app/models/finalmlp_ctr.pt') |
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
| args = parser.parse_args() |
| |
| processor = CTRDataProcessor(max_rows=args.max_rows) |
| train_data, val_data, test_data = processor.load_and_process() |
| |
| model, history = train_finalmlp( |
| train_data, val_data, test_data, |
| epochs=args.epochs, |
| batch_size=args.batch_size, |
| learning_rate=args.lr, |
| save_path=args.save_path, |
| device=args.device |
| ) |
| |
| print(f"\nFinal Test AUC: {history['test_auc']:.4f}") |
| print(f"Model saved to {args.save_path}") |
|
|