from typing import Dict, List, TYPE_CHECKING

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch, HeteroData
from torch_geometric.nn import GATConv, HeteroConv, global_mean_pool
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

if TYPE_CHECKING:
    import pandas as pd

from dataloader import CodeGraphBuilder


class RelationalGraphEncoder(nn.Module):
    """Relational GNN encoder over the AST+CFG heterogeneous code graph."""

    # (source node type, relation name, destination node type) triples
    # handled by every message-passing layer.
    EDGE_TYPES = (
        ("ast", "ast_parent_child", "ast"),
        ("ast", "ast_child_parent", "ast"),
        ("ast", "ast_next_sibling", "ast"),
        ("ast", "ast_prev_sibling", "ast"),
        ("token", "token_to_ast", "ast"),
        ("ast", "ast_to_token", "token"),
        ("stmt", "cfg", "stmt"),
        ("stmt", "cfg_rev", "stmt"),
        ("stmt", "stmt_to_ast", "ast"),
        ("ast", "ast_to_stmt", "stmt"),
    )

    def __init__(self, hidden_dim: int = 256, out_dim: int = 768, num_layers: int = 2) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # One embedding table per node type; table sizes are vocabulary capacities.
        self.ast_encoder = nn.Embedding(2048, hidden_dim)
        self.token_encoder = nn.Embedding(8192, hidden_dim)
        self.stmt_encoder = nn.Embedding(512, hidden_dim)

        # Stack of heterogeneous GAT layers; (-1, -1) lets PyG infer the
        # input dimensions lazily on the first forward pass.
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            hetero_modules = {
                edge_type: GATConv((-1, -1), hidden_dim, add_self_loops=False)
                for edge_type in self.EDGE_TYPES
            }
            self.convs.append(HeteroConv(hetero_modules, aggr="sum"))

        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def _encode_nodes(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        device = self.ast_encoder.weight.device

        def get_embed(node_type: str, encoder: nn.Embedding) -> torch.Tensor:
            # Missing node types (or node stores without features) map to an
            # empty tensor so downstream pooling can simply skip them.
            if node_type not in data.node_types:
                return torch.zeros((0, self.hidden_dim), device=device)

            x = data[node_type].get("x")
            if x is None:
                return torch.zeros((0, self.hidden_dim), device=device)

            # Embedding lookup expects integer vocabulary indices.
            return encoder(x.to(device).long())

        return {
            "ast": get_embed("ast", self.ast_encoder),
            "token": get_embed("token", self.token_encoder),
            "stmt": get_embed("stmt", self.stmt_encoder),
        }

    def forward(self, data: HeteroData) -> torch.Tensor:
        device = next(self.parameters()).device
        data = data.to(device)

        x_dict = self._encode_nodes(data)

        # Only pass edge types the encoder knows about; any extra relations
        # in the input graph are ignored.
        edge_index_dict = {
            edge_type: data.edge_index_dict[edge_type]
            for edge_type in self.EDGE_TYPES
            if edge_type in data.edge_index_dict
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        batch_size = data.num_graphs if hasattr(data, "num_graphs") else 1

        # Mean-pool each node type to one vector per graph, then average the
        # per-type summaries into a single graph representation.
        pooled_embeddings = []
        for key, x in x_dict.items():
            if x.size(0) == 0:
                continue

            batch_vec = getattr(data[key], "batch", None)
            if batch_vec is not None:
                pool = global_mean_pool(x, batch_vec, size=batch_size)
            else:
                # Single-graph input: pool over all nodes and broadcast to
                # the expected batch dimension.
                pool = x.mean(dim=0, keepdim=True).expand(batch_size, -1)
            pooled_embeddings.append(pool)

        if not pooled_embeddings:
            return torch.zeros((batch_size, self.out_dim), device=device)

        graph_repr = torch.stack(pooled_embeddings).mean(dim=0)
        return self.output_proj(graph_repr)
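

def _demo_graph_encoder() -> None:
    """Minimal sketch of RelationalGraphEncoder on a hand-built graph.

    Illustrative only: the node features and edge layout below are
    assumptions about what dataloader.CodeGraphBuilder produces (integer
    vocabulary indices in ``x`` and the relations listed in EDGE_TYPES).
    """
    data = HeteroData()
    # Three AST nodes, two tokens, one statement; features are vocab indices.
    data["ast"].x = torch.tensor([1, 2, 3])
    data["token"].x = torch.tensor([10, 20])
    data["stmt"].x = torch.tensor([5])
    # Forward and reverse relations so every node type receives messages.
    data["ast", "ast_parent_child", "ast"].edge_index = torch.tensor([[0, 0], [1, 2]])
    data["ast", "ast_child_parent", "ast"].edge_index = torch.tensor([[1, 2], [0, 0]])
    data["token", "token_to_ast", "ast"].edge_index = torch.tensor([[0, 1], [1, 2]])
    data["ast", "ast_to_token", "token"].edge_index = torch.tensor([[1, 2], [0, 1]])
    data["stmt", "stmt_to_ast", "ast"].edge_index = torch.tensor([[0], [0]])
    data["ast", "ast_to_stmt", "stmt"].edge_index = torch.tensor([[0], [0]])

    encoder = RelationalGraphEncoder(hidden_dim=32, out_dim=64, num_layers=2)
    out = encoder(data)
    print(out.shape)  # torch.Size([1, 64])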


class GatedFusion(nn.Module):
    """Sigmoid-gated blend of a text embedding and a projected graph embedding."""

    def __init__(self, text_dim: int, graph_dim: int) -> None:
        super().__init__()
        self.graph_proj = nn.Linear(graph_dim, text_dim)
        self.gate = nn.Linear(text_dim * 2, text_dim)

    def forward(self, h_text: torch.Tensor, h_graph: torch.Tensor) -> torch.Tensor:
        h_graph_proj = self.graph_proj(h_graph)
        joint = torch.cat([h_text, h_graph_proj], dim=-1)
        gate = torch.sigmoid(self.gate(joint))
        # Element-wise convex combination: gate -> 1 favours the text features.
        return gate * h_text + (1.0 - gate) * h_graph_proj
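

def _demo_gated_fusion() -> None:
    """Shape-check sketch for GatedFusion (the dimensions are assumptions).

    The gate g = sigmoid(W [h_text ; proj(h_graph)]) blends the two views
    element-wise, so the fused output keeps the text dimensionality.
    """
    fusion = GatedFusion(text_dim=768, graph_dim=256)
    h_text = torch.randn(4, 768)
    h_graph = torch.randn(4, 256)
    fused = fusion(h_text, h_graph)
    assert fused.shape == (4, 768)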


class StructuralEncoderV2(nn.Module):
    """Structural encoder that fuses GraphCodeBERT text features with AST+CFG graph context."""

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2) -> None:
        super().__init__()
        self.device = torch.device(device)

        self.text_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        self.text_model.to(self.device)

        self.graph_encoder = RelationalGraphEncoder(
            hidden_dim=graph_hidden_dim,
            out_dim=self.text_model.config.hidden_size,
            num_layers=graph_layers,
        )
        self.graph_encoder.to(self.device)

        self.fusion = GatedFusion(self.text_model.config.hidden_size, self.text_model.config.hidden_size)
        self.fusion.to(self.device)

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] position as the sequence-level text representation.
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch | HeteroData) -> torch.Tensor:
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)
        return self.fusion(text_embeddings, graph_embeddings)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural V2 embeddings") -> np.ndarray:
        builder = CodeGraphBuilder()
        tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

        codes = df["code"].tolist()
        all_embeddings: List[torch.Tensor] = []

        self.eval()
        for start in tqdm(range(0, len(codes), batch_size), desc=desc):
            batch_codes = codes[start:start + batch_size]

            # Build one heterogeneous graph per snippet and collate them.
            graph_batch = Batch.from_data_list([builder.build(c) for c in batch_codes])

            tok = tokenizer(batch_codes, padding=True, truncation=True, max_length=512, return_tensors="pt")

            with torch.no_grad():
                fused = self.forward(tok["input_ids"], tok["attention_mask"], graph_batch)
            all_embeddings.append(fused.cpu())

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        # Accept both raw state dicts and wrapped {"state_dict": ...} checkpoints.
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)
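

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): the snippet below is a
    # placeholder, and running the full encoder assumes this repo's
    # dataloader module is importable and the GraphCodeBERT weights
    # can be downloaded.
    _demo_graph_encoder()
    _demo_gated_fusion()

    import pandas as pd

    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = StructuralEncoderV2(device=device)
    df = pd.DataFrame({"code": ["def add(a, b):\n    return a + b"]})
    embeddings = encoder.generate_embeddings(df, batch_size=2)
    print(embeddings.shape)  # (1, 768): the GraphCodeBERT hidden size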