""" relgnn/model.py Core RelGNN — Atenção sobre Rotas Atômicas (sem grafo estático). Arquitetura: 1. TableEncoder: embeddings por tabela via MLP sobre features numéricas 2. RouteAggregator: attention ao longo de cada rota (sequência de tabelas) 3. HierarchicalAgg: agrega múltiplas rotas com pesos aprendidos 4. FraudHead: classificador binário final """ from dataclasses import dataclass from typing import List, Dict, Tuple, Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from data.routes import AtomicRoute # ─── CONFIG ─────────────────────────────────────────────────────────────────── @dataclass class RelGNNConfig: hidden_dim: int = 64 num_epochs: int = 50 learning_rate: float = 1e-3 dropout: float = 0.2 num_heads: int = 4 seed: int = 42 # ─── TABLE ENCODER ──────────────────────────────────────────────────────────── class TableEncoder(nn.Module): """ Codifica as features de uma tabela em um embedding de tamanho `hidden_dim`. Opera direto nas colunas numéricas — sem conversão para grafo. """ def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.2): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim * 2), nn.LayerNorm(hidden_dim * 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) # ─── ROUTE ATTENTION ────────────────────────────────────────────────────────── class RouteAttention(nn.Module): """ Mecanismo de atenção sobre uma Rota Atômica. Recebe sequência de embeddings [h1, h2, ..., hK] (K = n_hops + 1) e retorna um embedding agregado representando a rota. Implementa atenção scaled-dot-product entre os hops. """ def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.2): super().__init__() self.attn = nn.MultiheadAttention( embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True, ) self.norm = nn.LayerNorm(hidden_dim) self.mlp = nn.Sequential( nn.Linear(hidden_dim, hidden_dim * 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim * 2, hidden_dim), ) def forward(self, hop_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: hop_embeddings: [batch, n_hops, hidden_dim] Returns: route_emb: [batch, hidden_dim] — representação da rota alpha: [batch, n_hops] — pesos de atenção por hop """ # Self-attention entre os hops da rota attn_out, alpha = self.attn(hop_embeddings, hop_embeddings, hop_embeddings) # Residual + norm attn_out = self.norm(attn_out + hop_embeddings) # Agrega via mean-pooling ponderado (último hop = entidade alvo) # O primeiro token (tabela alvo) agrega informações dos vizinhos route_emb = attn_out[:, 0, :] # [batch, hidden_dim] route_emb = route_emb + self.mlp(route_emb) alpha_weights = alpha.mean(dim=1)[:, 0, :] # [batch, n_hops] return route_emb, alpha_weights # ─── HIERARCHICAL ROUTE AGGREGATOR ─────────────────────────────────────────── class HierarchicalRouteAgg(nn.Module): """ Agrega embeddings de múltiplas rotas com pesos aprendidos. Cada rota contribui de forma diferente para a predição final. """ def __init__(self, hidden_dim: int, num_routes: int): super().__init__() self.route_weights = nn.Parameter(torch.ones(num_routes)) self.output_proj = nn.Linear(hidden_dim, hidden_dim) def forward(self, route_embeddings: List[torch.Tensor]) -> torch.Tensor: """ Args: route_embeddings: lista de [batch, hidden_dim], uma por rota Returns: agg: [batch, hidden_dim] """ stacked = torch.stack(route_embeddings, dim=1) # [batch, R, hidden] weights = F.softmax(self.route_weights, dim=0) # [R] weighted = (stacked * weights.unsqueeze(0).unsqueeze(-1)).sum(dim=1) return self.output_proj(weighted) # ─── FRAUD HEAD ─────────────────────────────────────────────────────────────── class FraudHead(nn.Module): def __init__(self, hidden_dim: int, dropout: float = 0.2): super().__init__() self.net = nn.Sequential( nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim // 2, 1), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x).squeeze(-1) # [batch] # ─── RELGNN ─────────────────────────────────────────────────────────────────── class RelGNN(nn.Module): """ RelGNN completo. Fluxo: tabelas SQL → TableEncoder (por tabela) → RouteAttention (por rota atômica) → HierarchicalRouteAgg → FraudHead → sigmoid(logit) = P(fraude) """ def __init__(self, config: RelGNNConfig): super().__init__() self.config = config torch.manual_seed(config.seed) def build(self, feature_dims: Dict[str, int], routes: List[AtomicRoute]): """Instancia os módulos após conhecer as dimensões das features.""" H = self.config.hidden_dim D = self.config.dropout self.table_encoders = nn.ModuleDict({ table: TableEncoder(dim, H, D) for table, dim in feature_dims.items() }) self.route_attns = nn.ModuleList([ RouteAttention(H, self.config.num_heads, D) for _ in routes ]) self.hierarchical = HierarchicalRouteAgg(H, len(routes)) self.fraud_head = FraudHead(H, D) self.routes = routes def forward( self, table_features: Dict[str, torch.Tensor], ) -> Tuple[torch.Tensor, Dict]: """ Args: table_features: {table_name: [batch, feature_dim]} Returns: logits: [batch] attention_info: dict com pesos de atenção por rota """ # 1. Encoder por tabela table_embs = { table: encoder(table_features[table]) for table, encoder in self.table_encoders.items() if table in table_features } # 2. Attention por rota atômica route_embs = [] attention_info = {} for i, (route, attn_module) in enumerate(zip(self.routes, self.route_attns)): # Coleta embeddings das tabelas na rota available = [t for t in route.path if t in table_embs] if len(available) < 2: # Usa embedding da tabela alvo repetido se rota incompleta e = table_embs.get(route.path[0], list(table_embs.values())[0]) route_embs.append(e) continue hop_list = [table_embs[t] for t in available] hop_tensor = torch.stack(hop_list, dim=1) # [batch, K, H] route_emb, alpha = attn_module(hop_tensor) route_embs.append(route_emb) attention_info[f"route_{i}"] = alpha.detach().cpu().numpy() # 3. Agrega rotas hierarquicamente agg = self.hierarchical(route_embs) # 4. Classificador de fraude logits = self.fraud_head(agg) return logits, attention_info def fit(self, tables, routes, log_fn=print, progress_fn=None): """Wrapper de treinamento completo.""" from relgnn.trainer import Trainer trainer = Trainer(self, self.config) return trainer.fit(tables, routes, log_fn=log_fn, progress_fn=progress_fn)