"""
relgnn/model.py
Core RelGNN β€” AtenΓ§Γ£o sobre Rotas AtΓ΄micas (sem grafo estΓ‘tico).
Arquitetura:
1. TableEncoder: embeddings por tabela via MLP sobre features numΓ©ricas
2. RouteAggregator: attention ao longo de cada rota (sequΓͺncia de tabelas)
3. HierarchicalAgg: agrega mΓΊltiplas rotas com pesos aprendidos
4. FraudHead: classificador binΓ‘rio final
"""
from dataclasses import dataclass
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from data.routes import AtomicRoute
# ─── CONFIG ──────────────────────────────────────────────────────────────────
@dataclass
class RelGNNConfig:
hidden_dim: int = 64
num_epochs: int = 50
learning_rate: float = 1e-3
dropout: float = 0.2
num_heads: int = 4
seed: int = 42
# ─── TABLE ENCODER ───────────────────────────────────────────────────────────
class TableEncoder(nn.Module):
"""
Codifica as features de uma tabela em um embedding de tamanho `hidden_dim`.
Opera direto nas colunas numΓ©ricas β€” sem conversΓ£o para grafo.
"""
def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.2):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim * 2),
nn.LayerNorm(hidden_dim * 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
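# Shape sketch (illustrative comment only, not executed): a batch of 32 rows
# with 10 numeric columns is projected to the shared hidden size.
#
#     enc = TableEncoder(input_dim=10, hidden_dim=64)
#     out = enc(torch.randn(32, 10))  # -> [32, 64]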
# ─── ROUTE ATTENTION ─────────────────────────────────────────────────────────
class RouteAttention(nn.Module):
"""
Mecanismo de atenΓ§Γ£o sobre uma Rota AtΓ΄mica.
Recebe sequΓͺncia de embeddings [h1, h2, ..., hK] (K = n_hops + 1)
e retorna um embedding agregado representando a rota.
Implementa atenΓ§Γ£o scaled-dot-product entre os hops.
"""
def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.2):
super().__init__()
self.attn = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.norm = nn.LayerNorm(hidden_dim)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, hidden_dim),
)
def forward(self, hop_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
hop_embeddings: [batch, n_hops, hidden_dim]
Returns:
route_emb: [batch, hidden_dim] β€” representaΓ§Γ£o da rota
alpha: [batch, n_hops] β€” pesos de atenΓ§Γ£o por hop
"""
        # Self-attention among the hops of the route
        attn_out, alpha = self.attn(hop_embeddings, hop_embeddings, hop_embeddings)
        # Residual + norm
        attn_out = self.norm(attn_out + hop_embeddings)
        # Read out the first token (the target table), which has attended to
        # and aggregated information from the other hops on the route
        route_emb = attn_out[:, 0, :]  # [batch, hidden_dim]
        route_emb = route_emb + self.mlp(route_emb)
        # With the default average_attn_weights=True, alpha is already averaged
        # over heads and shaped [batch, L, S]; the target token's row gives the
        # per-hop attention weights
        alpha_weights = alpha[:, 0, :]  # [batch, n_hops]
        return route_emb, alpha_weights
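# Shape sketch (illustrative comment only): a route with 3 hops over a batch
# of 32 transactions.
#
#     attn = RouteAttention(hidden_dim=64, num_heads=4)
#     emb, alpha = attn(torch.randn(32, 3, 64))
#     # emb:   [32, 64] route representation read from the target token
#     # alpha: [32, 3]  attention the target token paid to each hop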
# ─── HIERARCHICAL ROUTE AGGREGATOR ───────────────────────────────────────────
class HierarchicalRouteAgg(nn.Module):
"""
Agrega embeddings de mΓΊltiplas rotas com pesos aprendidos.
Cada rota contribui de forma diferente para a prediΓ§Γ£o final.
"""
def __init__(self, hidden_dim: int, num_routes: int):
super().__init__()
self.route_weights = nn.Parameter(torch.ones(num_routes))
self.output_proj = nn.Linear(hidden_dim, hidden_dim)
def forward(self, route_embeddings: List[torch.Tensor]) -> torch.Tensor:
"""
Args:
route_embeddings: lista de [batch, hidden_dim], uma por rota
Returns:
agg: [batch, hidden_dim]
"""
stacked = torch.stack(route_embeddings, dim=1) # [batch, R, hidden]
weights = F.softmax(self.route_weights, dim=0) # [R]
weighted = (stacked * weights.unsqueeze(0).unsqueeze(-1)).sum(dim=1)
return self.output_proj(weighted)
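# Shape sketch (illustrative comment only): combining R = 2 route embeddings.
# softmax(route_weights) keeps the per-route weights positive and summing to
# one, so routes compete for influence on the prediction.
#
#     agg = HierarchicalRouteAgg(hidden_dim=64, num_routes=2)
#     out = agg([torch.randn(32, 64), torch.randn(32, 64)])  # -> [32, 64]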
# ─── FRAUD HEAD ──────────────────────────────────────────────────────────────
class FraudHead(nn.Module):
def __init__(self, hidden_dim: int, dropout: float = 0.2):
super().__init__()
self.net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x).squeeze(-1) # [batch]
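# Note: FraudHead returns raw logits. Apply torch.sigmoid at inference time
# (or train with nn.BCEWithLogitsLoss) to obtain P(fraud) in [0, 1]:
#
#     probs = torch.sigmoid(FraudHead(hidden_dim=64)(torch.randn(32, 64)))  # -> [32]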
# ─── RELGNN ──────────────────────────────────────────────────────────────────
class RelGNN(nn.Module):
"""
RelGNN completo.
Fluxo:
tabelas SQL
β†’ TableEncoder (por tabela)
β†’ RouteAttention (por rota atΓ΄mica)
β†’ HierarchicalRouteAgg
β†’ FraudHead
β†’ sigmoid(logit) = P(fraude)
"""
def __init__(self, config: RelGNNConfig):
super().__init__()
self.config = config
torch.manual_seed(config.seed)
def build(self, feature_dims: Dict[str, int], routes: List[AtomicRoute]):
"""Instancia os mΓ³dulos apΓ³s conhecer as dimensΓ΅es das features."""
H = self.config.hidden_dim
D = self.config.dropout
self.table_encoders = nn.ModuleDict({
table: TableEncoder(dim, H, D)
for table, dim in feature_dims.items()
})
self.route_attns = nn.ModuleList([
RouteAttention(H, self.config.num_heads, D)
for _ in routes
])
self.hierarchical = HierarchicalRouteAgg(H, len(routes))
self.fraud_head = FraudHead(H, D)
self.routes = routes
def forward(
self,
table_features: Dict[str, torch.Tensor],
) -> Tuple[torch.Tensor, Dict]:
"""
Args:
table_features: {table_name: [batch, feature_dim]}
Returns:
logits: [batch]
attention_info: dict com pesos de atenΓ§Γ£o por rota
"""
        # 1. Per-table encoder
table_embs = {
table: encoder(table_features[table])
for table, encoder in self.table_encoders.items()
if table in table_features
}
        # 2. Attention per atomic route
route_embs = []
attention_info = {}
for i, (route, attn_module) in enumerate(zip(self.routes, self.route_attns)):
            # Collect the embeddings of the tables on the route
available = [t for t in route.path if t in table_embs]
if len(available) < 2:
                # Fall back to the target table's embedding when the route is incomplete
e = table_embs.get(route.path[0], list(table_embs.values())[0])
route_embs.append(e)
continue
hop_list = [table_embs[t] for t in available]
hop_tensor = torch.stack(hop_list, dim=1) # [batch, K, H]
route_emb, alpha = attn_module(hop_tensor)
route_embs.append(route_emb)
attention_info[f"route_{i}"] = alpha.detach().cpu().numpy()
        # 3. Aggregate routes hierarchically
agg = self.hierarchical(route_embs)
        # 4. Fraud classifier
logits = self.fraud_head(agg)
return logits, attention_info
def fit(self, tables, routes, log_fn=print, progress_fn=None):
"""Wrapper de treinamento completo."""
from relgnn.trainer import Trainer
trainer = Trainer(self, self.config)
return trainer.fit(tables, routes, log_fn=log_fn, progress_fn=progress_fn)
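if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library): exercises
    # build() + forward() with random features. It assumes only that
    # AtomicRoute exposes a `path` attribute (the one field this module
    # reads), so a namedtuple stands in for it here.
    from collections import namedtuple

    Route = namedtuple("Route", ["path"])
    routes = [
        Route(path=["transactions", "users"]),
        Route(path=["transactions", "merchants"]),
    ]
    feature_dims = {"transactions": 8, "users": 5, "merchants": 6}

    model = RelGNN(RelGNNConfig(hidden_dim=32))
    model.build(feature_dims, routes)

    batch = 16
    feats = {t: torch.randn(batch, d) for t, d in feature_dims.items()}
    logits, attn_info = model(feats)
    print(logits.shape)              # torch.Size([16])
    print(sorted(attn_info.keys()))  # ['route_0', 'route_1']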