| """ |
| relgnn/model.py |
| Core RelGNN β AtenΓ§Γ£o sobre Rotas AtΓ΄micas (sem grafo estΓ‘tico). |
| |
| Arquitetura: |
| 1. TableEncoder: embeddings por tabela via MLP sobre features numΓ©ricas |
| 2. RouteAggregator: attention ao longo de cada rota (sequΓͺncia de tabelas) |
| 3. HierarchicalAgg: agrega mΓΊltiplas rotas com pesos aprendidos |
| 4. FraudHead: classificador binΓ‘rio final |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import List, Dict, Tuple, Optional |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from data.routes import AtomicRoute |
|
|
|
|
| |
|
|
@dataclass
class RelGNNConfig:
    hidden_dim: int = 64
    num_epochs: int = 50
    learning_rate: float = 1e-3
    dropout: float = 0.2
    num_heads: int = 4
    seed: int = 42
|
|
class TableEncoder(nn.Module):
    """
    Encodes a table's features into an embedding of size `hidden_dim`.
    Operates directly on the numeric columns; no conversion to a graph.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
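

# A quick shape check for TableEncoder (a sketch only; dims are made up):
#
#     enc = TableEncoder(input_dim=10, hidden_dim=64)
#     x = torch.randn(32, 10)   # [batch, input_dim]
#     h = enc(x)                # -> [32, 64]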
|
|
class RouteAttention(nn.Module):
    """
    Attention mechanism over an Atomic Route.
    Takes a sequence of embeddings [h1, h2, ..., hK] (K = n_hops + 1)
    and returns an aggregated embedding representing the route.

    Implements scaled dot-product attention across the hops.
    """

    def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

    def forward(self, hop_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hop_embeddings: [batch, n_hops, hidden_dim]
        Returns:
            route_emb: [batch, hidden_dim], representation of the route
            alpha_weights: [batch, n_hops], per-hop attention weights
        """
        # Self-attention across hops, with a post-norm residual connection.
        attn_out, alpha = self.attn(hop_embeddings, hop_embeddings, hop_embeddings)
        attn_out = self.norm(attn_out + hop_embeddings)

        # The first hop is the route's anchor table: read the route
        # representation off its position, then refine with a residual MLP.
        route_emb = attn_out[:, 0, :]
        route_emb = route_emb + self.mlp(route_emb)

        # `alpha` is [batch, n_hops, n_hops], already averaged over heads.
        # Keep the anchor's row: how much hop 0 attends to each hop.
        alpha_weights = alpha[:, 0, :]
        return route_emb, alpha_weights
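

# Sketch of RouteAttention shapes (illustrative values only):
#
#     attn = RouteAttention(hidden_dim=64, num_heads=4)
#     hops = torch.randn(32, 3, 64)     # 3 tables along the route
#     route_emb, alpha = attn(hops)     # [32, 64], [32, 3]
#     # Each alpha[b] sums to 1, e.g. [0.55, 0.30, 0.15]: how strongly
#     # the anchor table attends to each hop.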
|
|
class HierarchicalRouteAgg(nn.Module):
    """
    Aggregates the embeddings of multiple routes with learned weights.
    Each route contributes differently to the final prediction.
    """

    def __init__(self, hidden_dim: int, num_routes: int):
        super().__init__()
        self.route_weights = nn.Parameter(torch.ones(num_routes))
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, route_embeddings: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            route_embeddings: list of [batch, hidden_dim] tensors, one per route
        Returns:
            agg: [batch, hidden_dim]
        """
        stacked = torch.stack(route_embeddings, dim=1)  # [batch, n_routes, hidden_dim]
        weights = F.softmax(self.route_weights, dim=0)  # [n_routes], sums to 1
        weighted = (stacked * weights.unsqueeze(0).unsqueeze(-1)).sum(dim=1)
        return self.output_proj(weighted)
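

# The route weights live in logit space; softmax turns them into a convex
# combination. For example, learned logits [0.0, 0.0, 1.0] would give
# weights softmax([0, 0, 1]) ~= [0.21, 0.21, 0.58]: the third route dominates.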
|
|
class FraudHead(nn.Module):
    """Binary classification head: maps [batch, hidden_dim] embeddings to one logit each."""

    def __init__(self, hidden_dim: int, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(-1)
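

# FraudHead returns raw logits. A sketch of how they would typically be
# consumed (the actual loss lives in relgnn/trainer.py, not shown here):
#
#     loss = F.binary_cross_entropy_with_logits(logits, labels.float())
#     probs = torch.sigmoid(logits)   # P(fraud), as in the RelGNN docstring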
|
|
class RelGNN(nn.Module):
    """
    The full RelGNN.

    Flow:
        SQL tables
        → TableEncoder (per table)
        → RouteAttention (per atomic route)
        → HierarchicalRouteAgg
        → FraudHead
        → sigmoid(logit) = P(fraud)
    """

    def __init__(self, config: RelGNNConfig):
        super().__init__()
        self.config = config
        torch.manual_seed(config.seed)

    def build(self, feature_dims: Dict[str, int], routes: List[AtomicRoute]):
        """Instantiates the submodules once the feature dimensions are known."""
        H = self.config.hidden_dim
        D = self.config.dropout

        self.table_encoders = nn.ModuleDict({
            table: TableEncoder(dim, H, D)
            for table, dim in feature_dims.items()
        })

        self.route_attns = nn.ModuleList([
            RouteAttention(H, self.config.num_heads, D)
            for _ in routes
        ])

        self.hierarchical = HierarchicalRouteAgg(H, len(routes))
        self.fraud_head = FraudHead(H, D)
        self.routes = routes

    def forward(
        self,
        table_features: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            table_features: {table_name: [batch, feature_dim]}
        Returns:
            logits: [batch]
            attention_info: dict with per-route attention weights
        """
        # 1) Encode every table present in the batch into the shared space.
        table_embs = {
            table: encoder(table_features[table])
            for table, encoder in self.table_encoders.items()
            if table in table_features
        }

        # 2) Run attention along each atomic route.
        route_embs = []
        attention_info = {}

        for i, (route, attn_module) in enumerate(zip(self.routes, self.route_attns)):
            # Keep only the hops whose tables are available in this batch.
            available = [t for t in route.path if t in table_embs]
            if len(available) < 2:
                # Degenerate route: fall back to the single available hop,
                # or to any table embedding if none of the hops is present.
                e = table_embs[available[0]] if available else next(iter(table_embs.values()))
                route_embs.append(e)
                continue

            hop_list = [table_embs[t] for t in available]
            hop_tensor = torch.stack(hop_list, dim=1)  # [batch, n_hops, hidden_dim]

            route_emb, alpha = attn_module(hop_tensor)
            route_embs.append(route_emb)
            attention_info[f"route_{i}"] = alpha.detach().cpu().numpy()

        # 3) Combine routes with learned weights, then classify.
        agg = self.hierarchical(route_embs)
        logits = self.fraud_head(agg)

        return logits, attention_info

    def fit(self, tables, routes, log_fn=print, progress_fn=None):
        """Convenience wrapper around the full training loop."""
        from relgnn.trainer import Trainer

        trainer = Trainer(self, self.config)
        return trainer.fit(tables, routes, log_fn=log_fn, progress_fn=progress_fn)
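

if __name__ == "__main__":
    # Smoke test on synthetic data. A minimal sketch: it assumes AtomicRoute
    # can be constructed as AtomicRoute(path=[...]); adjust if the dataclass
    # in data.routes has a different signature. Table names and feature dims
    # are made up for illustration.
    config = RelGNNConfig(hidden_dim=32, num_heads=4)
    model = RelGNN(config)

    feature_dims = {"transactions": 12, "users": 8, "merchants": 6}
    routes = [
        AtomicRoute(path=["transactions", "users"]),
        AtomicRoute(path=["transactions", "merchants", "users"]),
    ]
    model.build(feature_dims, routes)

    batch = 16
    table_features = {t: torch.randn(batch, d) for t, d in feature_dims.items()}
    logits, attn_info = model(table_features)

    print(logits.shape)                                 # torch.Size([16])
    print({k: v.shape for k, v in attn_info.items()})   # route_i -> (16, n_hops)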