Upload 2 files
- 03_infer_halfedge.py  +198 -0
- brep_extractor_utils.py  +271 -0
03_infer_halfedge.py
ADDED
@@ -0,0 +1,198 @@
# file: 03_infer_halfedge.py
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv, GlobalAttention, JumpingKnowledge, BatchNorm
from torch_geometric.data import HeteroData

from brep_extractor_utils import load_coedge_arrays, make_heterodata


class HalfEdgeGNN(nn.Module):
    def __init__(
        self,
        coedge_in: int,
        face_in: int,
        edge_in: int,
        global_in: int,
        hidden=256,
        layers=6,
        dropout=0.2,
        num_classes=3,
        jk_mode="cat",
        gating_dim=None,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        # Per-node-type input encoders projecting raw features to `hidden`.
        self.encoders = nn.ModuleDict({
            "coedge": nn.Sequential(nn.Linear(coedge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "face": nn.Sequential(nn.Linear(face_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "edge": nn.Sequential(nn.Linear(edge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
        })
        for _ in range(layers):
            conv = HeteroConv({
                ('coedge', 'next', 'coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge', 'prev', 'coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge', 'mate', 'coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge', 'to_face', 'face'): SAGEConv((hidden, hidden), hidden),
                ('face', 'to_coedge', 'coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge', 'to_edge', 'edge'): SAGEConv((hidden, hidden), hidden),
                ('edge', 'to_coedge', 'coedge'): SAGEConv((hidden, hidden), hidden),
                ('face', 'to_edge', 'edge'): SAGEConv((hidden, hidden), hidden),
                ('edge', 'to_face', 'face'): SAGEConv((hidden, hidden), hidden),
            }, aggr='sum')
            self.convs.append(conv)
            self.bns.append(nn.ModuleDict({
                "coedge": BatchNorm(hidden),
                "face": BatchNorm(hidden),
                "edge": BatchNorm(hidden),
            }))
        self.jk = JumpingKnowledge(mode=jk_mode)
        self.jk_out = hidden * layers if jk_mode == "cat" else hidden
        if gating_dim is None:
            gating_dim = hidden
        self.gating_dim = gating_dim
        self.gate = nn.Sequential(
            nn.Linear(self.jk_out, self.jk_out // 2),
            nn.ReLU(),
            nn.Linear(self.jk_out // 2, 1),
        )
        self.pool = GlobalAttention(self.gate)
        self.proj = nn.Identity() if self.jk_out == gating_dim else nn.Linear(self.jk_out, gating_dim)
        self.global_mlp = nn.Sequential(
            nn.Linear(global_in, gating_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(gating_dim, 2 * gating_dim),
        )
        self.head = nn.Sequential(
            nn.Linear(gating_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, data: HeteroData):
        x = {
            "coedge": self.encoders["coedge"](data["coedge"].x),
            "face": self.encoders["face"](data["face"].x),
            "edge": self.encoders["edge"](data["edge"].x),
        }
        outs = []
        for conv, bn in zip(self.convs, self.bns):
            x_new = conv(x, data.edge_index_dict)
            # Residual update with per-type batch norm.
            x = {k: F.relu(bn[k](x_new[k]) + x[k]) for k in x}
            outs.append(x["coedge"])
        xj = self.jk(outs)
        g = self.pool(xj, data['coedge'].batch)
        g0 = self.proj(g)
        global_x = data["global"].x
        if global_x.dim() == 1:
            global_x = global_x.view(1, -1)
        if global_x.size(0) != g0.size(0):
            raise RuntimeError(
                f"Global feature batch mismatch: {global_x.size(0)} vs {g0.size(0)}"
            )
        gb = self.global_mlp(global_x)
        gamma, beta = gb.chunk(2, dim=-1)
        gamma = torch.sigmoid(gamma)
        g_mod = g0 * gamma + beta
        return self.head(g_mod)
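
# Illustrative note (comments only, nothing executed): the global-feature
# gating above is a FiLM-style modulation of the pooled graph embedding.
# With hypothetical 1-D values g0 = 1.0, gamma = sigmoid(0.0) = 0.5 and
# beta = 0.2, the gated embedding is g_mod = 1.0 * 0.5 + 0.2 = 0.7
# before it enters the classification head.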

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--npz", required=True, help="Path to a processed BRep extractor npz file")
    ap.add_argument("--tau", type=float, default=0.0,
                    help="Reject threshold; predictions with confidence below it fall back to 'random'")
    ap.add_argument("--min_conf", type=float, default=0.85, help="Hard minimum confidence for known classes")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    try:
        ckpt = torch.load(args.model, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        ckpt = torch.load(args.model, map_location="cpu")
    if "global_in" not in ckpt or "gating_dim" not in ckpt:
        raise RuntimeError(
            "Checkpoint missing gating metadata. Please retrain with global gating enabled."
        )
    labels = ckpt["labels"]
    inv_labels = {v: k for k, v in labels.items()}
    random_id = labels.get("random")
    if (args.tau > 0 or args.min_conf > 0) and random_id is None:
        raise RuntimeError("Model labels do not include 'random'; retrain a 4-class model.")
    stats = ckpt["stats"]
    if not all(k in stats for k in ("coedge", "face", "edge")):
        raise RuntimeError("Checkpoint missing heterograph stats; retrain required.")

    coedge_in = ckpt.get("coedge_in", ckpt.get("node_in"))
    face_in = ckpt.get("face_in")
    edge_in = ckpt.get("edge_in")
    if coedge_in is None or face_in is None or edge_in is None:
        raise RuntimeError("Checkpoint missing heterograph input dims; retrain required.")

    graph_data = load_coedge_arrays(Path(args.npz))
    if int(graph_data["coedge_x"].shape[1]) != int(coedge_in):
        raise RuntimeError(
            f"Coedge feature dim mismatch: npz={int(graph_data['coedge_x'].shape[1])} "
            f"ckpt={int(coedge_in)}"
        )
    if int(graph_data["face_x"].shape[1]) != int(face_in):
        raise RuntimeError(
            f"Face feature dim mismatch: npz={int(graph_data['face_x'].shape[1])} "
            f"ckpt={int(face_in)}"
        )
    if int(graph_data["edge_x"].shape[1]) != int(edge_in):
        raise RuntimeError(
            f"Edge feature dim mismatch: npz={int(graph_data['edge_x'].shape[1])} "
            f"ckpt={int(edge_in)}"
        )
    if int(graph_data["global_x"].shape[0]) != int(ckpt["global_in"]):
        raise RuntimeError(
            f"Global feature dim mismatch: npz={int(graph_data['global_x'].shape[0])} "
            f"ckpt={int(ckpt['global_in'])}"
        )
    data = make_heterodata(
        graph_data["coedge_x"],
        graph_data["face_x"],
        graph_data["edge_x"],
        graph_data["next"],
        graph_data["mate"],
        graph_data["coedge_face"],
        graph_data["coedge_edge"],
        graph_data["global_x"],
        label=None,
        norm_stats=stats,
    )
    # Single-graph batch vectors (all zeros) so attention pooling works.
    data['coedge'].batch = torch.zeros(data['coedge'].x.size(0), dtype=torch.long)
    data["global"].batch = torch.zeros(1, dtype=torch.long)
    data["face"].batch = torch.zeros(data["face"].x.size(0), dtype=torch.long)
    data["edge"].batch = torch.zeros(data["edge"].x.size(0), dtype=torch.long)

    global_in = ckpt["global_in"]
    gating_dim = ckpt["gating_dim"]
    model = HalfEdgeGNN(coedge_in=coedge_in, face_in=face_in, edge_in=edge_in, global_in=global_in,
                        hidden=ckpt["hp"]["hidden"],
                        layers=ckpt["hp"]["layers"], dropout=ckpt["hp"]["dropout"],
                        num_classes=len(labels), gating_dim=gating_dim).to(args.device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()

    with torch.no_grad():
        logits = model(data.to(args.device))
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
    pred = int(probs.argmax())
    conf = float(probs[pred])
    arg_label = inv_labels[pred]
    effective_tau = max(args.tau, args.min_conf)
    if conf < effective_tau and random_id is not None:
        final_label = "random"
    else:
        final_label = arg_label
    print(f"Argmax: {arg_label} (conf={conf:.4f})")
    print(f"Predicted: {final_label} (tau={effective_tau:.2f})")
    for i, p in enumerate(probs):
        print(f"{inv_labels[i]:>6s}: {p:.4f}")


if __name__ == "__main__":
    main()
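
With the checkpoint metadata checks above, single-part inference is one command. A hypothetical invocation (model and npz paths are placeholders):

    python 03_infer_halfedge.py --model runs/halfedge_best.pt --npz processed/part_0001.npz --min_conf 0.85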
brep_extractor_utils.py
ADDED
@@ -0,0 +1,271 @@
"""
Utility helpers for loading BRep extractor-processed STEP data as PyG graphs.
"""
from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, Tuple

import numpy as np
import torch
from torch_geometric.data import HeteroData

# Label mapping for the current project
LABELS: Dict[str, int] = {"pipe": 0, "elbow": 1, "tjoint": 2, "random": 3}
STEP_EXTS = ("*.step", "*.stp", "*.STEP", "*.STP")


def build_label_map(step_root: Path) -> Dict[str, int]:
    """
    Scan the STEP directory tree (containing /pipe, /elbow, /tjoint, ...)
    and build a mapping from file stem to integer label.
    """
    mapping: Dict[str, int] = {}
    for cls, label in LABELS.items():
        cls_dir = step_root / cls
        if not cls_dir.exists():
            continue
        for ext in STEP_EXTS:
            for file in cls_dir.glob(ext):
                mapping[file.stem] = label
    if not mapping:
        raise RuntimeError(f"No STEP files found under {step_root} for any of {tuple(LABELS)}")
    return mapping
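
# Hypothetical directory layout this expects:
#   step_root/pipe/*.step, step_root/elbow/*.stp, step_root/tjoint/*.step, ...
# e.g. label_map = build_label_map(Path("data/step"))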


def _flatten(arr: np.ndarray) -> np.ndarray:
    a = np.asarray(arr, dtype=np.float32)
    return a.reshape(a.shape[0], -1)


def _face_grid_stats(face_grids: np.ndarray) -> np.ndarray:
    """
    Summarize face point grids into compact stats per face.
    Returns [F, 10]: xyz_mean (3), xyz_std (3), nrm_mean (3), mask_frac (1).
    """
    face_grids = np.asarray(face_grids, dtype=np.float32)
    f = face_grids.shape[0]
    xyz = face_grids[:, 0:3, :, :].reshape(f, 3, -1)
    nrm = face_grids[:, 3:6, :, :].reshape(f, 3, -1)
    msk = face_grids[:, 6, :, :].reshape(f, -1)

    mask = (msk > 0.5).astype(np.float32)
    mask_frac = mask.mean(axis=1, keepdims=True)
    # Mask-weighted moments; the epsilon guards faces with empty masks.
    w = mask / (mask.sum(axis=1, keepdims=True) + 1e-6)

    xyz_mean = (xyz * w[:, None, :]).sum(axis=2)
    xyz_var = (w[:, None, :] * (xyz - xyz_mean[:, :, None]) ** 2).sum(axis=2)
    xyz_std = np.sqrt(np.maximum(xyz_var, 1e-12))
    nrm_mean = (nrm * w[:, None, :]).sum(axis=2)
    return np.concatenate([xyz_mean, xyz_std, nrm_mean, mask_frac], axis=1)


def compute_global_geom_features(data) -> np.ndarray:
    """
    Compute compact global geometry descriptors from face/coedge point samples.
    Returns [5] float32: pca_ev_ratio_1/2/3, line_fit_rmse, plane_fit_rmse.
    """
    points = []
    face_grids = np.asarray(data["face_point_grids"], dtype=np.float32)
    if face_grids.size:
        xyz = face_grids[:, 0:3, :, :].transpose(0, 2, 3, 1).reshape(-1, 3)
        mask = face_grids[:, 6, :, :].reshape(-1) > 0.5
        if mask.any():
            points.append(xyz[mask])

    coedge_grids = np.asarray(data["coedge_point_grids"], dtype=np.float32)
    if coedge_grids.size:
        co_xyz = coedge_grids[:, 0:3, :].transpose(0, 2, 1).reshape(-1, 3)
        points.append(co_xyz)

    if not points:
        return np.zeros(5, dtype=np.float32)

    pts = np.concatenate(points, axis=0)
    if pts.shape[0] < 3:
        return np.zeros(5, dtype=np.float32)
    pts = pts[np.isfinite(pts).all(axis=1)]
    if pts.shape[0] < 3:
        return np.zeros(5, dtype=np.float32)

    # Scale-normalised PCA: eigenvalue ratios become invariant to part size.
    mean = pts.mean(axis=0, keepdims=True)
    centered = pts - mean
    scale = np.sqrt(np.mean(np.sum(centered ** 2, axis=1)))
    centered = centered / (scale + 1e-6)
    cov = (centered.T @ centered) / max(1, centered.shape[0])
    if not np.isfinite(cov).all():
        return np.zeros(5, dtype=np.float32)

    ev = np.linalg.eigvalsh(cov)
    ev = np.sort(ev)[::-1]
    ev = np.maximum(ev, 0.0)
    total = ev.sum()
    if not np.isfinite(total) or total <= 0.0:
        return np.zeros(5, dtype=np.float32)

    ratios = ev / total
    # Residual energy off the dominant axis / off the dominant plane.
    line_rmse = np.sqrt(max(ev[1] + ev[2], 0.0))
    plane_rmse = np.sqrt(max(ev[2], 0.0))
    feats = np.array(
        [ratios[0], ratios[1], ratios[2], line_rmse, plane_rmse],
        dtype=np.float32,
    )
    if not np.isfinite(feats).all():
        return np.zeros(5, dtype=np.float32)
    return feats
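

def _demo_line_features() -> np.ndarray:
    """Hypothetical sanity check for compute_global_geom_features (sketch only)."""
    # Ten coedge samples on a straight segment: the first PCA ratio should be
    # ~1 and both RMSE terms ~0, i.e. the result is approximately [1, 0, 0, 0, 0].
    # The 12-channel grid layout follows the [N, 12*U] comment below.
    grids = np.zeros((1, 12, 10), dtype=np.float32)
    grids[0, 0, :] = np.linspace(0.0, 1.0, 10)  # x varies; y = z = 0
    data = {
        "face_point_grids": np.zeros((0, 7, 8, 8), dtype=np.float32),
        "coedge_point_grids": grids,
    }
    return compute_global_geom_features(data)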


def load_coedge_arrays(npz_path: Path) -> Dict[str, np.ndarray]:
    """
    Load node features and adjacency indices from a BRep extractor npz.
    Returns a dict with coedge/face/edge/global features and topology arrays.
    """
    with np.load(npz_path) as data:
        coedge_feats = _flatten(data["coedge_features"])
        scale = np.asarray(data["coedge_scale_factors"], dtype=np.float32)[:, None]
        reverse = np.asarray(data["coedge_reverse_flags"], dtype=np.float32)[:, None]
        point_grids = _flatten(data["coedge_point_grids"])  # [N, 12*U]
        lcs = _flatten(data["coedge_lcs"])  # [N, 16]

        face_idx = np.asarray(data["face"], dtype=np.int64)
        edge_idx = np.asarray(data["edge"], dtype=np.int64)
        face_feats = np.asarray(data["face_features"], dtype=np.float32)  # [F, 7]
        edge_feats = np.asarray(data["edge_features"], dtype=np.float32)  # [E, 10]

        face_grid_stats = _face_grid_stats(data["face_point_grids"])

        coedge_x = np.concatenate(
            [coedge_feats, scale, reverse, point_grids, lcs], axis=1
        )
        face_x = np.concatenate([face_feats, face_grid_stats], axis=1)
        edge_x = edge_feats
        next_index = np.asarray(data["next"], dtype=np.int64)
        mate_index = np.asarray(data["mate"], dtype=np.int64)
        global_features = compute_global_geom_features(data)

    return {
        "coedge_x": coedge_x,
        "face_x": face_x,
        "edge_x": edge_x,
        "next": next_index,
        "mate": mate_index,
        "coedge_face": face_idx,
        "coedge_edge": edge_idx,
        "global_x": global_features,
    }


def make_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of edge indices (with both directions, deduplicated).
    """
    pairs = np.stack([source, target], axis=1)
    flipped = pairs[:, ::-1]
    all_pairs = np.concatenate([pairs, flipped], axis=0)
    all_pairs = np.unique(all_pairs, axis=0)
    return torch.tensor(all_pairs.T, dtype=torch.long)


def make_directed_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of directed edge indices (no deduplication).
    """
    return torch.tensor(np.stack([source, target], axis=0), dtype=torch.long)


def make_bipartite_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of directed bipartite edge indices (deduplicated).
    """
    pairs = np.stack([source, target], axis=1)
    pairs = np.unique(pairs, axis=0)
    return torch.tensor(pairs.T, dtype=torch.long)
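
# Dedup semantics at a glance (illustrative values, not part of the pipeline):
#   make_edge_index([0, 0], [1, 1])          -> tensor([[0, 1], [1, 0]])
#       duplicates collapse and the reverse direction is added (used for 'mate')
#   make_directed_edge_index([0, 0], [1, 1]) -> tensor([[0, 0], [1, 1]])
#       kept exactly as given, one direction only (used for 'next'/'prev')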


def make_heterodata(
    coedge_x: np.ndarray,
    face_x: np.ndarray,
    edge_x: np.ndarray,
    next_index: np.ndarray,
    mate_index: np.ndarray,
    coedge_face: np.ndarray,
    coedge_edge: np.ndarray,
    global_features: np.ndarray,
    label: int | None,
    norm_stats: Dict[str, Dict[str, np.ndarray | torch.Tensor]] | None = None,
) -> HeteroData:
    """
    Create a PyG HeteroData graph from the coedge features/relations.
    When mean/std stats are provided the features are normalised element-wise.
    """
    def _normalize(x_arr: np.ndarray, stats: Dict[str, np.ndarray | torch.Tensor] | None) -> torch.Tensor:
        x_t = torch.tensor(x_arr, dtype=torch.float32)
        if stats is None:
            return x_t
        mean = stats.get("mean")
        std = stats.get("std")
        if mean is None or std is None:
            return x_t
        mean_t = torch.as_tensor(mean, dtype=torch.float32)
        std_t = torch.as_tensor(std, dtype=torch.float32)
        return (x_t - mean_t) / std_t

    coedge_stats = norm_stats.get("coedge") if norm_stats else None
    face_stats = norm_stats.get("face") if norm_stats else None
    edge_stats = norm_stats.get("edge") if norm_stats else None

    x_coedge = _normalize(coedge_x, coedge_stats)
    x_face = _normalize(face_x, face_stats)
    x_edge = _normalize(edge_x, edge_stats)

    idx = np.arange(coedge_x.shape[0], dtype=np.int64)
    edge_next = make_directed_edge_index(idx, next_index)
    edge_prev = make_directed_edge_index(next_index, idx)
    edge_mate = make_edge_index(idx, mate_index)
    edge_coedge_face = make_directed_edge_index(idx, coedge_face)
    edge_face_coedge = make_directed_edge_index(coedge_face, idx)
    edge_coedge_edge = make_directed_edge_index(idx, coedge_edge)
    edge_edge_coedge = make_directed_edge_index(coedge_edge, idx)
    edge_face_edge = make_bipartite_edge_index(coedge_face, coedge_edge)
    edge_edge_face = make_bipartite_edge_index(coedge_edge, coedge_face)

    data = HeteroData()
    data["coedge"].x = x_coedge
    data["face"].x = x_face
    data["edge"].x = x_edge
    data["global"].x = torch.tensor(global_features, dtype=torch.float32).view(1, -1)
    data["coedge", "next", "coedge"].edge_index = edge_next
    data["coedge", "prev", "coedge"].edge_index = edge_prev
    data["coedge", "mate", "coedge"].edge_index = edge_mate
    data["coedge", "to_face", "face"].edge_index = edge_coedge_face
    data["face", "to_coedge", "coedge"].edge_index = edge_face_coedge
    data["coedge", "to_edge", "edge"].edge_index = edge_coedge_edge
    data["edge", "to_coedge", "coedge"].edge_index = edge_edge_coedge
    data["face", "to_edge", "edge"].edge_index = edge_face_edge
    data["edge", "to_face", "face"].edge_index = edge_edge_face
    if label is not None:
        data.y = torch.tensor([int(label)], dtype=torch.long)
    return data


def compute_feature_stats(npz_paths: Iterable[Path]) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Compute mean and std (per feature dimension) across all node features in the dataset.
    """
    totals = {"coedge": 0, "face": 0, "edge": 0}
    sum_vec: Dict[str, np.ndarray | None] = {"coedge": None, "face": None, "edge": None}
    sum_sq: Dict[str, np.ndarray | None] = {"coedge": None, "face": None, "edge": None}

    for path in npz_paths:
        graph = load_coedge_arrays(path)
        for key, x in (("coedge", graph["coedge_x"]), ("face", graph["face_x"]), ("edge", graph["edge_x"])):
            if sum_vec[key] is None:
                sum_vec[key] = np.zeros(x.shape[1], dtype=np.float64)
                sum_sq[key] = np.zeros(x.shape[1], dtype=np.float64)
            sum_vec[key] += x.sum(axis=0)
            sum_sq[key] += (x * x).sum(axis=0)
            totals[key] += x.shape[0]

    out = {}
    for key in ("coedge", "face", "edge"):
        if sum_vec[key] is None or totals[key] == 0:
            raise RuntimeError(f"Cannot compute feature stats: no {key} features observed.")
        mean = sum_vec[key] / totals[key]
        # Clamp the variance so later normalisation never divides by ~0.
        var = sum_sq[key] / totals[key] - mean * mean
        var = np.maximum(var, 1e-12)
        std = np.sqrt(var)
        out[key] = {"mean": mean.astype(np.float32), "std": std.astype(np.float32)}
    return out
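
Taken together, a minimal end-to-end sketch of these helpers (the `processed/` directory and npz names are hypothetical placeholders):

    from pathlib import Path
    from brep_extractor_utils import LABELS, compute_feature_stats, load_coedge_arrays, make_heterodata

    npz_files = sorted(Path("processed").glob("*.npz"))
    stats = compute_feature_stats(npz_files)   # per-type mean/std for normalisation
    arrays = load_coedge_arrays(npz_files[0])
    graph = make_heterodata(
        arrays["coedge_x"], arrays["face_x"], arrays["edge_x"],
        arrays["next"], arrays["mate"],
        arrays["coedge_face"], arrays["coedge_edge"],
        arrays["global_x"],
        label=LABELS["pipe"],   # or None at inference time
        norm_stats=stats,
    )
    print(graph)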