import torch
import torch.nn as nn
from typing import List
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch
from transformers import AutoTokenizer

from dataloader import CodeGraphBuilder
from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2


class StructuralEncoderOnlyGraph(nn.Module):
    """
    Ablation variant 1: Pure Structural Encoder.
    Removes GraphCodeBERT and uses only the graph path (R-GNN).
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256,
                 graph_layers: int = 2, out_dim: int = 768):
        super().__init__()
        self.device = torch.device(device)

        self.graph_encoder = RelationalGraphEncoder(
            hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers
        )
        self.graph_encoder.to(self.device)
|
| | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor: |
| | |
| | return self.graph_encoder(graph_batch) |
| | |
    def generate_embeddings(self, df: pd.DataFrame, batch_size: int = 8,
                            save_path: str | None = None,
                            desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []

        self.eval()
        for start in tqdm(batches, desc=desc):
            batch_codes = codes[start:start + batch_size]

            data_list = [builder.build(c) for c in batch_codes]
            # Move the batched graphs onto the same device as the graph encoder.
            graph_batch = Batch.from_data_list(data_list).to(self.device)

            # Placeholder text inputs: forward() ignores them but keeps the
            # shared (input_ids, attention_mask, graph_batch) signature.
            dummy_ids = torch.zeros((1, 1), dtype=torch.long, device=self.device)
            dummy_mask = torch.zeros((1, 1), dtype=torch.long, device=self.device)

            with torch.no_grad():
                out = self.forward(dummy_ids, dummy_mask, graph_batch)
            all_embeddings.append(out.cpu())

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str,
                        map_location: str | torch.device = "cpu",
                        strict: bool = True) -> None:
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        # Accept both raw state dicts and {"state_dict": ...} checkpoint wrappers.
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)
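
# Usage sketch for the graph-only ablation (assumptions: the DataFrame has a
# "code" column as generate_embeddings expects, and a trained checkpoint
# exists; "only_graph.pt" is a hypothetical file name used for illustration):
#
#   model = StructuralEncoderOnlyGraph(device="cpu")
#   model.load_checkpoint("only_graph.pt")
#   df = pd.DataFrame({"code": ["def add(a, b):\n    return a + b"]})
#   embeddings = model.generate_embeddings(df, batch_size=8)
#   # embeddings.shape == (len(df), 768) with the default out_dim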


class StructuralEncoderConcat(StructuralEncoderV2):
    """
    Ablation variant 2: Concatenation Fusion.
    Keeps both text and graph paths but fuses them via simple concatenation
    plus a linear projection instead of Gated Fusion.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256,
                 graph_layers: int = 2):
        super().__init__(device, graph_hidden_dim, graph_layers)

        # Graph embeddings are projected to the text hidden size upstream, so
        # both paths share one dimensionality here.
        text_dim = self.text_model.config.hidden_size
        graph_dim = self.text_model.config.hidden_size

        self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
        self.concat_proj.to(self.device)

        # Drop the inherited GatedFusion module; concat_proj replaces it.
        del self.fusion
|
| | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor: |
| | text_embeddings = self.encode_text(input_ids, attention_mask) |
| | graph_embeddings = self.graph_encoder(graph_batch) |
| | |
| | combined = torch.cat([text_embeddings, graph_embeddings], dim=-1) |
| | return self.concat_proj(combined) |
| |
|
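

if __name__ == "__main__":
    # Minimal smoke test for the concat-fusion variant (a sketch, not part of
    # the experiments). Assumptions: StructuralEncoderV2 wraps GraphCodeBERT,
    # so "microsoft/graphcodebert-base" is the matching tokenizer, and an
    # untrained model suffices for a shape check.
    device = "cpu"
    code = "def mul(a, b):\n    return a * b"

    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
    enc = tokenizer(code, return_tensors="pt", truncation=True)
    graph_batch = Batch.from_data_list([CodeGraphBuilder().build(code)]).to(device)

    model = StructuralEncoderConcat(device=device)
    model.eval()
    with torch.no_grad():
        emb = model(enc["input_ids"].to(device), enc["attention_mask"].to(device), graph_batch)
    print(emb.shape)  # expected: (1, text_hidden_size), e.g. (1, 768)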