Silly98 committed on
Commit
dc71d7e
·
verified ·
1 Parent(s): 3bd8c76

Upload 2 files

Browse files
Files changed (2) hide show
  1. 03_infer_halfedge.py +198 -0
  2. brep_extractor_utils.py +271 -0
03_infer_halfedge.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: 03_infer_halfedge.py
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from pathlib import Path
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch_geometric.nn import HeteroConv, SAGEConv, GlobalAttention, JumpingKnowledge, BatchNorm
9
+ from torch_geometric.data import HeteroData
10
+
11
+ from brep_extractor_utils import load_coedge_arrays, make_heterodata
12
+
13
class HalfEdgeGNN(nn.Module):
    """Heterogeneous half-edge GNN classifier for B-rep part graphs.

    Node types: ``coedge``, ``face``, ``edge`` (plus a per-graph ``global``
    feature vector used for output gating). Each layer is a HeteroConv of
    SAGEConv message passes over the half-edge relations (next/prev/mate and
    coedge<->face/edge, face<->edge), followed by per-type BatchNorm, a
    residual add, and ReLU. Coedge embeddings from every layer are combined
    with JumpingKnowledge, attention-pooled to a graph vector, and modulated
    by the global features before the classification head.

    NOTE(review): attribute names (``convs``, ``bns``, ``encoders``, ``jk``,
    ``gate``, ``pool``, ``proj``, ``global_mlp``, ``head``) determine the
    state_dict keys that ``03_infer_halfedge.py`` loads from the checkpoint —
    do not rename them. ``GlobalAttention`` is deprecated in newer
    torch_geometric releases (renamed ``AttentionalAggregation``) — presumably
    this file targets an older PyG; confirm before upgrading.
    """

    def __init__(
        self,
        coedge_in: int,
        face_in: int,
        edge_in: int,
        global_in: int,
        hidden=256,
        layers=6,
        dropout=0.2,
        num_classes=3,
        jk_mode="cat",
        gating_dim=None,
    ):
        """Build encoders, message-passing stack, pooling, gating, and head.

        Args:
            coedge_in / face_in / edge_in: raw feature widths per node type.
            global_in: width of the per-graph global feature vector.
            hidden: embedding width used throughout the conv stack.
            layers: number of HeteroConv layers (also the JK depth).
            dropout: dropout rate in the encoders and the final head.
            num_classes: output logits count.
            jk_mode: JumpingKnowledge mode; "cat" concatenates all layers.
            gating_dim: width of the gated graph embedding; defaults to
                ``hidden`` when None.
        """
        super().__init__()
        self.convs = nn.ModuleList(); self.bns = nn.ModuleList()
        # Per-type linear encoders lift raw features to the shared hidden width.
        self.encoders = nn.ModuleDict({
            "coedge": nn.Sequential(nn.Linear(coedge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "face": nn.Sequential(nn.Linear(face_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "edge": nn.Sequential(nn.Linear(edge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
        })
        for _ in range(layers):
            # One SAGEConv per relation; per-destination results are summed.
            conv = HeteroConv({
                ('coedge','next','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','prev','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','mate','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','to_face','face'): SAGEConv((hidden, hidden), hidden),
                ('face','to_coedge','coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge','to_edge','edge'): SAGEConv((hidden, hidden), hidden),
                ('edge','to_coedge','coedge'): SAGEConv((hidden, hidden), hidden),
                ('face','to_edge','edge'): SAGEConv((hidden, hidden), hidden),
                ('edge','to_face','face'): SAGEConv((hidden, hidden), hidden),
            }, aggr='sum')
            self.convs.append(conv)
            self.bns.append(nn.ModuleDict({
                "coedge": BatchNorm(hidden),
                "face": BatchNorm(hidden),
                "edge": BatchNorm(hidden),
            }))
        self.jk = JumpingKnowledge(mode=jk_mode)
        # "cat" concatenates the per-layer coedge embeddings; other modes keep width.
        self.jk_out = hidden * layers if jk_mode == "cat" else hidden
        if gating_dim is None:
            gating_dim = hidden
        self.gating_dim = gating_dim
        # Scalar attention score per coedge for the global attention pool.
        self.gate = nn.Sequential(
            nn.Linear(self.jk_out, self.jk_out//2),
            nn.ReLU(),
            nn.Linear(self.jk_out//2, 1),
        )
        self.pool = GlobalAttention(self.gate)
        # Project the pooled JK vector down to gating_dim only when needed.
        self.proj = nn.Identity() if self.jk_out == gating_dim else nn.Linear(self.jk_out, gating_dim)
        # Maps global features to a (gamma, beta) pair for affine modulation.
        self.global_mlp = nn.Sequential(
            nn.Linear(global_in, gating_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(gating_dim, 2 * gating_dim),
        )
        self.head = nn.Sequential(
            nn.Linear(gating_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, data: HeteroData):
        """Return class logits of shape [num_graphs, num_classes].

        Requires ``data['coedge'].batch`` and a ``data['global'].x`` row per
        graph (the caller in 03_infer_halfedge.py sets both for batch size 1).
        """
        x = {
            "coedge": self.encoders["coedge"](data["coedge"].x),
            "face": self.encoders["face"](data["face"].x),
            "edge": self.encoders["edge"](data["edge"].x),
        }
        outs = []
        for conv, bn in zip(self.convs, self.bns):
            x_new = conv(x, data.edge_index_dict)
            # BatchNorm + residual add + ReLU per node type.
            x = {k: F.relu(bn[k](x_new[k]) + x[k]) for k in x}
            outs.append(x["coedge"])
        xj = self.jk(outs)
        g = self.pool(xj, data['coedge'].batch)
        g0 = self.proj(g)
        global_x = data["global"].x
        if global_x.dim() == 1:
            # Tolerate an unbatched [global_in] vector.
            global_x = global_x.view(1, -1)
        if global_x.size(0) != g0.size(0):
            raise RuntimeError(
                f"Global feature batch mismatch: {global_x.size(0)} vs {g0.size(0)}"
            )
        gb = self.global_mlp(global_x)
        # Feature-wise affine modulation: sigmoid-gated scale plus shift.
        gamma, beta = gb.chunk(2, dim=-1)
        gamma = torch.sigmoid(gamma)
        g_mod = g0 * gamma + beta
        return self.head(g_mod)
103
+
104
def _require_dim(kind: str, actual: int, expected: int) -> None:
    """Raise RuntimeError when an npz feature dimension disagrees with the checkpoint.

    Replaces four copy-pasted mismatch checks; the message format matches the
    original ("<kind> feature dim mismatch: npz=A ckpt=B").
    """
    if int(actual) != int(expected):
        raise RuntimeError(
            f"{kind} feature dim mismatch: npz={int(actual)} ckpt={int(expected)}"
        )


def main():
    """CLI entry point: classify one processed BRep npz with a trained checkpoint.

    Loads the checkpoint metadata (labels, normalisation stats, input dims),
    validates the npz feature dimensions against it, builds a single-graph
    HeteroData batch, runs the model, and prints the (possibly rejected)
    prediction plus per-class probabilities.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--npz", required=True, help="Path to a processed BRep extractor npz file")
    ap.add_argument("--tau", type=float, default=0.0, help="Reject threshold; below this outputs random")
    ap.add_argument("--min_conf", type=float, default=0.85, help="Hard minimum confidence for known classes")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    # Older torch versions reject the weights_only kwarg; fall back gracefully.
    try:
        ckpt = torch.load(args.model, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(args.model, map_location="cpu")
    if "global_in" not in ckpt or "gating_dim" not in ckpt:
        raise RuntimeError(
            "Checkpoint missing gating metadata. Please retrain with global gating enabled."
        )
    labels = ckpt["labels"]; inv_labels = {v: k for k, v in labels.items()}
    random_id = labels.get("random")
    # Rejection only makes sense when the model was trained with a "random" class.
    if (args.tau > 0 or args.min_conf > 0) and random_id is None:
        raise RuntimeError("Model labels do not include 'random'; retrain a 4-class model.")
    stats = ckpt["stats"]
    if not all(k in stats for k in ("coedge", "face", "edge")):
        raise RuntimeError("Checkpoint missing heterograph stats; retrain required.")

    # "node_in" is the legacy key name for the coedge input width.
    coedge_in = ckpt.get("coedge_in", ckpt.get("node_in"))
    face_in = ckpt.get("face_in")
    edge_in = ckpt.get("edge_in")
    if coedge_in is None or face_in is None or edge_in is None:
        raise RuntimeError("Checkpoint missing heterograph input dims; retrain required.")

    graph_data = load_coedge_arrays(Path(args.npz))
    # Fail fast on any feature-width mismatch between the npz and the checkpoint.
    _require_dim("Coedge", graph_data["coedge_x"].shape[1], coedge_in)
    _require_dim("Face", graph_data["face_x"].shape[1], face_in)
    _require_dim("Edge", graph_data["edge_x"].shape[1], edge_in)
    _require_dim("Global", graph_data["global_x"].shape[0], ckpt["global_in"])

    data = make_heterodata(
        graph_data["coedge_x"],
        graph_data["face_x"],
        graph_data["edge_x"],
        graph_data["next"],
        graph_data["mate"],
        graph_data["coedge_face"],
        graph_data["coedge_edge"],
        graph_data["global_x"],
        label=None,
        norm_stats=stats,
    )
    # Single-graph batch: every node belongs to graph 0.
    data['coedge'].batch = torch.zeros(data['coedge'].x.size(0), dtype=torch.long)
    data["global"].batch = torch.zeros(1, dtype=torch.long)
    data["face"].batch = torch.zeros(data["face"].x.size(0), dtype=torch.long)
    data["edge"].batch = torch.zeros(data["edge"].x.size(0), dtype=torch.long)

    global_in = ckpt["global_in"]
    gating_dim = ckpt["gating_dim"]
    model = HalfEdgeGNN(coedge_in=coedge_in, face_in=face_in, edge_in=edge_in, global_in=global_in,
                        hidden=ckpt["hp"]["hidden"],
                        layers=ckpt["hp"]["layers"], dropout=ckpt["hp"]["dropout"],
                        num_classes=len(labels), gating_dim=gating_dim).to(args.device)
    model.load_state_dict(ckpt["state_dict"]); model.eval()

    with torch.no_grad():
        logits = model(data.to(args.device))
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
    pred = int(probs.argmax())
    conf = float(probs[pred])
    arg_label = inv_labels[pred]
    # The stricter of the two thresholds wins; below it we fall back to "random".
    effective_tau = max(args.tau, args.min_conf)
    if conf < effective_tau and random_id is not None:
        final_label = "random"
    else:
        final_label = arg_label
    print(f"Argmax: {arg_label} (conf={conf:.4f})")
    print(f"Predicted: {final_label} (tau={effective_tau:.2f})")
    for i, p in enumerate(probs):
        print(f"{inv_labels[i]:>6s}: {p:.4f}")


if __name__ == "__main__":
    main()
brep_extractor_utils.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility helpers for loading BRep extractor-processed STEP data as PyG graphs.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from pathlib import Path
7
+ from typing import Dict, Iterable, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch_geometric.data import HeteroData
12
+
13
+ # Label mapping for the current project
14
+ LABELS: Dict[str, int] = {"pipe": 0, "elbow": 1, "tjoint": 2, "random": 3}
15
+ STEP_EXTS = ("*.step", "*.stp", "*.STEP", "*.STP")
16
+
17
+
18
def build_label_map(step_root: Path) -> Dict[str, int]:
    """
    Scan the STEP directory tree (containing /pipe, /elbow, /tjoint, ...)
    and build a mapping from file stem to integer label.
    """
    mapping: Dict[str, int] = {}
    for class_name, label in LABELS.items():
        class_dir = step_root / class_name
        if not class_dir.exists():
            continue
        # Collect every STEP file under this class folder, any extension case.
        for pattern in STEP_EXTS:
            mapping.update((path.stem, label) for path in class_dir.glob(pattern))
    if not mapping:
        raise RuntimeError(f"No STEP files found under {step_root} for any of {tuple(LABELS)}")
    return mapping
34
+
35
+
36
+ def _flatten(arr: np.ndarray) -> np.ndarray:
37
+ return np.asarray(arr, dtype=np.float32).reshape(arr.shape[0], -1)
38
+
39
+ def _face_grid_stats(face_grids: np.ndarray) -> np.ndarray:
40
+ """
41
+ Summarize face point grids into compact stats per face.
42
+ Returns [F, 10]: xyz_mean (3), xyz_std (3), nrm_mean (3), mask_frac (1).
43
+ """
44
+ face_grids = np.asarray(face_grids, dtype=np.float32)
45
+ f = face_grids.shape[0]
46
+ xyz = face_grids[:, 0:3, :, :].reshape(f, 3, -1)
47
+ nrm = face_grids[:, 3:6, :, :].reshape(f, 3, -1)
48
+ msk = face_grids[:, 6, :, :].reshape(f, -1)
49
+
50
+ mask = (msk > 0.5).astype(np.float32)
51
+ mask_frac = mask.mean(axis=1, keepdims=True)
52
+ w = mask / (mask.sum(axis=1, keepdims=True) + 1e-6)
53
+
54
+ xyz_mean = (xyz * w[:, None, :]).sum(axis=2)
55
+ xyz_var = (w[:, None, :] * (xyz - xyz_mean[:, :, None]) ** 2).sum(axis=2)
56
+ xyz_std = np.sqrt(np.maximum(xyz_var, 1e-12))
57
+ nrm_mean = (nrm * w[:, None, :]).sum(axis=2)
58
+ return np.concatenate([xyz_mean, xyz_std, nrm_mean, mask_frac], axis=1)
59
+
60
def compute_global_geom_features(data) -> np.ndarray:
    """
    Compute compact global geometry descriptors from face/coedge point samples.

    Pools all valid face-grid points and all coedge curve points into one
    cloud, normalises it to unit RMS radius, and runs a PCA on the covariance.
    Returns [5] float32: pca_ev_ratio_1/2/3, line_fit_rmse, plane_fit_rmse;
    an all-zero vector when there are too few finite points or the PCA
    degenerates.
    """
    samples = []
    grids = np.asarray(data["face_point_grids"], dtype=np.float32)
    if grids.size:
        face_xyz = grids[:, 0:3, :, :].transpose(0, 2, 3, 1).reshape(-1, 3)
        keep = grids[:, 6, :, :].reshape(-1) > 0.5
        if keep.any():
            samples.append(face_xyz[keep])

    curves = np.asarray(data["coedge_point_grids"], dtype=np.float32)
    if curves.size:
        samples.append(curves[:, 0:3, :].transpose(0, 2, 1).reshape(-1, 3))

    if not samples:
        return np.zeros(5, dtype=np.float32)

    cloud = np.concatenate(samples, axis=0)
    if cloud.shape[0] < 3:
        return np.zeros(5, dtype=np.float32)
    cloud = cloud[np.isfinite(cloud).all(axis=1)]
    if cloud.shape[0] < 3:
        return np.zeros(5, dtype=np.float32)

    centered = cloud - cloud.mean(axis=0, keepdims=True)
    # Scale-normalise so descriptors are size-invariant.
    rms = np.sqrt(np.mean(np.sum(centered ** 2, axis=1)))
    centered = centered / (rms + 1e-6)
    cov = (centered.T @ centered) / max(1, centered.shape[0])
    if not np.isfinite(cov).all():
        return np.zeros(5, dtype=np.float32)

    eigvals = np.maximum(np.sort(np.linalg.eigvalsh(cov))[::-1], 0.0)
    ev_sum = eigvals.sum()
    if not np.isfinite(ev_sum) or ev_sum <= 0.0:
        return np.zeros(5, dtype=np.float32)

    ratio = eigvals / ev_sum
    # Residual energy off the first axis / first plane, as an RMSE proxy.
    rmse_line = np.sqrt(max(eigvals[1] + eigvals[2], 0.0))
    rmse_plane = np.sqrt(max(eigvals[2], 0.0))
    feats = np.array(
        [ratio[0], ratio[1], ratio[2], rmse_line, rmse_plane],
        dtype=np.float32,
    )
    if not np.isfinite(feats).all():
        return np.zeros(5, dtype=np.float32)
    return feats
113
+
114
def load_coedge_arrays(npz_path: Path) -> Dict[str, np.ndarray]:
    """
    Load node features and adjacency indices from a BRep extractor npz.
    Returns a dict with coedge/face/edge/global features and topology arrays.
    """
    with np.load(npz_path) as raw:
        # Per-coedge descriptors, flattened and stacked side by side.
        base = _flatten(raw["coedge_features"])
        scale_col = np.asarray(raw["coedge_scale_factors"], dtype=np.float32)[:, None]
        reverse_col = np.asarray(raw["coedge_reverse_flags"], dtype=np.float32)[:, None]
        curve_grids = _flatten(raw["coedge_point_grids"])  # [N, 12*U]
        frames = _flatten(raw["coedge_lcs"])  # [N, 16]

        # Topology: owning face/edge per coedge plus next/mate permutations.
        to_face = np.asarray(raw["face"], dtype=np.int64)
        to_edge = np.asarray(raw["edge"], dtype=np.int64)
        face_feats = np.asarray(raw["face_features"], dtype=np.float32)  # [F, 7]
        edge_feats = np.asarray(raw["edge_features"], dtype=np.float32)  # [E, 10]
        grid_stats = _face_grid_stats(raw["face_point_grids"])
        next_idx = np.asarray(raw["next"], dtype=np.int64)
        mate_idx = np.asarray(raw["mate"], dtype=np.int64)
        global_feats = compute_global_geom_features(raw)

    return {
        "coedge_x": np.concatenate([base, scale_col, reverse_col, curve_grids, frames], axis=1),
        "face_x": np.concatenate([face_feats, grid_stats], axis=1),
        "edge_x": edge_feats,
        "next": next_idx,
        "mate": mate_idx,
        "coedge_face": to_face,
        "coedge_edge": to_edge,
        "global_x": global_feats,
    }
152
+
153
+
154
def make_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of edge indices (with both directions, deduplicated).
    """
    forward = np.stack([source, target], axis=1)
    backward = np.stack([target, source], axis=1)
    # np.unique sorts rows lexicographically and drops duplicates.
    unique_pairs = np.unique(np.concatenate([forward, backward], axis=0), axis=0)
    return torch.tensor(unique_pairs.T, dtype=torch.long)
163
+
164
def make_directed_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of directed edge indices (no deduplication).
    """
    # Row 0 = sources, row 1 = targets, in the order given.
    return torch.tensor(np.vstack((source, target)), dtype=torch.long)
169
+
170
def make_bipartite_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of directed bipartite edge indices (deduplicated).
    """
    # Deduplicate (src, dst) rows; np.unique also sorts them lexicographically.
    deduped = np.unique(np.stack([source, target], axis=1), axis=0)
    return torch.tensor(deduped.T, dtype=torch.long)
177
+
178
def make_heterodata(
    coedge_x: np.ndarray,
    face_x: np.ndarray,
    edge_x: np.ndarray,
    next_index: np.ndarray,
    mate_index: np.ndarray,
    coedge_face: np.ndarray,
    coedge_edge: np.ndarray,
    global_features: np.ndarray,
    label: int | None,
    norm_stats: Dict[str, Dict[str, np.ndarray | torch.Tensor]] | None = None,
) -> HeteroData:
    """
    Create a PyG HeteroData graph for the coedge features/relations.
    When mean/std are provided the features are normalised element-wise.
    """
    def _standardize(arr: np.ndarray, stats: Dict[str, np.ndarray | torch.Tensor] | None) -> torch.Tensor:
        tensor = torch.tensor(arr, dtype=torch.float32)
        if stats is None:
            return tensor
        mu = stats.get("mean")
        sigma = stats.get("std")
        if mu is None or sigma is None:
            return tensor
        mu_t = torch.as_tensor(mu, dtype=torch.float32)
        sigma_t = torch.as_tensor(sigma, dtype=torch.float32)
        return (tensor - mu_t) / sigma_t

    stats_for = (norm_stats or {}).get

    data = HeteroData()
    data["coedge"].x = _standardize(coedge_x, stats_for("coedge"))
    data["face"].x = _standardize(face_x, stats_for("face"))
    data["edge"].x = _standardize(edge_x, stats_for("edge"))
    data["global"].x = torch.tensor(global_features, dtype=torch.float32).view(1, -1)

    coedge_ids = np.arange(coedge_x.shape[0], dtype=np.int64)
    # next/prev are the loop permutation and its inverse; mate is symmetric.
    relations = {
        ("coedge", "next", "coedge"): make_directed_edge_index(coedge_ids, next_index),
        ("coedge", "prev", "coedge"): make_directed_edge_index(next_index, coedge_ids),
        ("coedge", "mate", "coedge"): make_edge_index(coedge_ids, mate_index),
        ("coedge", "to_face", "face"): make_directed_edge_index(coedge_ids, coedge_face),
        ("face", "to_coedge", "coedge"): make_directed_edge_index(coedge_face, coedge_ids),
        ("coedge", "to_edge", "edge"): make_directed_edge_index(coedge_ids, coedge_edge),
        ("edge", "to_coedge", "coedge"): make_directed_edge_index(coedge_edge, coedge_ids),
        ("face", "to_edge", "edge"): make_bipartite_edge_index(coedge_face, coedge_edge),
        ("edge", "to_face", "face"): make_bipartite_edge_index(coedge_edge, coedge_face),
    }
    for relation, index in relations.items():
        data[relation].edge_index = index

    if label is not None:
        data.y = torch.tensor([int(label)], dtype=torch.long)
    return data
242
+
243
+
244
def compute_feature_stats(npz_paths: Iterable[Path]) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Compute mean and std (per feature dimension) across all node features in the dataset.

    Streams first and second moments in float64 over all files so no
    concatenation is needed. The return annotation was previously
    ``Dict[str, np.ndarray]``, which did not match the actual nested shape.

    Returns:
        {"coedge" | "face" | "edge": {"mean": float32 [D], "std": float32 [D]}}

    Raises:
        RuntimeError: if no features of some node type were observed.
    """
    totals = {"coedge": 0, "face": 0, "edge": 0}
    sum_vec: Dict[str, np.ndarray | None] = {"coedge": None, "face": None, "edge": None}
    sum_sq: Dict[str, np.ndarray | None] = {"coedge": None, "face": None, "edge": None}

    for path in npz_paths:
        graph = load_coedge_arrays(path)
        for key, x in (("coedge", graph["coedge_x"]), ("face", graph["face_x"]), ("edge", graph["edge_x"])):
            if sum_vec[key] is None:
                # Lazily sized from the first file; all files must share dims.
                sum_vec[key] = np.zeros(x.shape[1], dtype=np.float64)
                sum_sq[key] = np.zeros(x.shape[1], dtype=np.float64)
            sum_vec[key] += x.sum(axis=0)
            sum_sq[key] += (x * x).sum(axis=0)
            totals[key] += x.shape[0]

    out = {}
    for key in ("coedge", "face", "edge"):
        if sum_vec[key] is None or totals[key] == 0:
            raise RuntimeError(f"Cannot compute feature stats: no {key} features observed.")
        mean = sum_vec[key] / totals[key]
        # E[x^2] - E[x]^2 can dip slightly negative numerically; clamp before sqrt.
        var = np.maximum(sum_sq[key] / totals[key] - mean * mean, 1e-12)
        std = np.sqrt(var)
        out[key] = {"mean": mean.astype(np.float32), "std": std.astype(np.float32)}
    return out