"""
Utility helpers for loading BRep extractor-processed STEP data as PyG graphs.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, Tuple

import numpy as np
import torch
from torch_geometric.data import HeteroData

# Label mapping for the current project
LABELS: Dict[str, int] = {"pipe": 0, "elbow": 1, "tjoint": 2, "random": 3}
STEP_EXTS = ("*.step", "*.stp", "*.STEP", "*.STP")


def build_label_map(step_root: Path) -> Dict[str, int]:
    """
    Scan the STEP directory tree (containing /pipe, /elbow, /tjoint, ...) and
    build a mapping from file stem to integer label.

    Only the class sub-directories named in ``LABELS`` are visited, and only
    files directly inside them that match one of the ``STEP_EXTS`` glob
    patterns are recorded (the glob is not recursive). If the same stem shows
    up in more than one class directory, the class iterated later in
    ``LABELS`` overwrites the earlier entry.

    Raises:
        RuntimeError: if no STEP file was found for any class.
    """
    mapping: Dict[str, int] = {}
    for cls, label in LABELS.items():
        cls_dir = step_root / cls
        if not cls_dir.exists():
            continue
        for ext in STEP_EXTS:
            for file in cls_dir.glob(ext):
                mapping[file.stem] = label
    if not mapping:
        raise RuntimeError(f"No STEP files found under {step_root} for any of {tuple(LABELS)}")
    return mapping


def _flatten(arr: np.ndarray) -> np.ndarray:
    """
    Return ``arr`` as a float32 matrix of shape [N, prod(remaining dims)].

    The column count is computed explicitly instead of using ``-1`` because
    NumPy cannot infer a ``-1`` dimension for zero-size arrays, which would
    make empty inputs (N == 0) raise.
    """
    flat = np.asarray(arr, dtype=np.float32)
    rows = flat.shape[0]
    cols = int(np.prod(flat.shape[1:], dtype=np.int64))
    return flat.reshape(rows, cols)


def _face_grid_stats(face_grids: np.ndarray) -> np.ndarray:
    """
    Summarize face point grids into compact stats per face.

    Channels 0-2 are read as xyz, 3-5 as normals and channel 6 as a mask that
    is thresholded at 0.5. Means and standard deviations are weighted by the
    thresholded mask; a face whose mask is entirely off gets zero means.

    Returns [F, 10]: xyz_mean (3), xyz_std (3), nrm_mean (3), mask_frac (1).
    """
    face_grids = np.asarray(face_grids, dtype=np.float32)
    if face_grids.size == 0:
        # Zero-size input cannot be reshaped with -1; there is nothing to
        # summarise, so return an empty/zero stats block of the right width.
        num_faces = face_grids.shape[0] if face_grids.ndim else 0
        return np.zeros((num_faces, 10), dtype=np.float32)

    f = face_grids.shape[0]
    xyz = face_grids[:, 0:3, :, :].reshape(f, 3, -1)
    nrm = face_grids[:, 3:6, :, :].reshape(f, 3, -1)
    msk = face_grids[:, 6, :, :].reshape(f, -1)

    mask = (msk > 0.5).astype(np.float32)
    mask_frac = mask.mean(axis=1, keepdims=True)
    # Per-face weights that sum to ~1 over the valid samples; the epsilon
    # keeps fully masked-out faces from dividing by zero.
    w = mask / (mask.sum(axis=1, keepdims=True) + 1e-6)

    xyz_mean = (xyz * w[:, None, :]).sum(axis=2)
    xyz_var = (w[:, None, :] * (xyz - xyz_mean[:, :, None]) ** 2).sum(axis=2)
    xyz_std = np.sqrt(np.maximum(xyz_var, 1e-12))
    nrm_mean = (nrm * w[:, None, :]).sum(axis=2)

    return np.concatenate([xyz_mean, xyz_std, nrm_mean, mask_frac], axis=1)


def compute_global_geom_features(data) -> np.ndarray:
    """
    Compute compact global geometry descriptors from face/coedge point samples.

    ``data`` must support ``data["face_point_grids"]`` and
    ``data["coedge_point_grids"]``. Face samples are kept only where the mask
    channel (index 6) exceeds 0.5; all coedge samples are used. The point
    cloud is centred and scaled to unit RMS radius before its covariance
    eigenvalues are taken, so the RMSE values are relative to that scale.

    Returns [5] float32: pca_ev_ratio_1/2/3, line_fit_rmse, plane_fit_rmse.
    A zero vector is returned whenever fewer than three finite points are
    available or any intermediate result is non-finite.
    """
    points = []

    face_grids = np.asarray(data["face_point_grids"], dtype=np.float32)
    if face_grids.size:
        xyz = face_grids[:, 0:3, :, :].transpose(0, 2, 3, 1).reshape(-1, 3)
        mask = face_grids[:, 6, :, :].reshape(-1) > 0.5
        if mask.any():
            points.append(xyz[mask])

    coedge_grids = np.asarray(data["coedge_point_grids"], dtype=np.float32)
    if coedge_grids.size:
        co_xyz = coedge_grids[:, 0:3, :].transpose(0, 2, 1).reshape(-1, 3)
        points.append(co_xyz)

    if not points:
        return np.zeros(5, dtype=np.float32)

    pts = np.concatenate(points, axis=0)
    if pts.shape[0] < 3:
        return np.zeros(5, dtype=np.float32)

    # Drop samples with NaN/inf coordinates before any statistics are taken.
    pts = pts[np.isfinite(pts).all(axis=1)]
    if pts.shape[0] < 3:
        return np.zeros(5, dtype=np.float32)

    mean = pts.mean(axis=0, keepdims=True)
    centered = pts - mean
    scale = np.sqrt(np.mean(np.sum(centered ** 2, axis=1)))
    centered = centered / (scale + 1e-6)

    cov = (centered.T @ centered) / max(1, centered.shape[0])
    if not np.isfinite(cov).all():
        return np.zeros(5, dtype=np.float32)

    ev = np.linalg.eigvalsh(cov)
    ev = np.sort(ev)[::-1]
    # Tiny negative eigenvalues can appear from round-off; clamp them.
    ev = np.maximum(ev, 0.0)
    total = ev.sum()
    if not np.isfinite(total) or total <= 0.0:
        return np.zeros(5, dtype=np.float32)

    ratios = ev / total
    # Residual of the best-fit line is the variance outside the first
    # principal axis; for the best-fit plane it is the smallest eigenvalue.
    line_rmse = np.sqrt(max(ev[1] + ev[2], 0.0))
    plane_rmse = np.sqrt(max(ev[2], 0.0))

    feats = np.array(
        [ratios[0], ratios[1], ratios[2], line_rmse, plane_rmse],
        dtype=np.float32,
    )
    if not np.isfinite(feats).all():
        return np.zeros(5, dtype=np.float32)
    return feats


def load_coedge_arrays(npz_path: Path) -> Dict[str, np.ndarray]:
    """
    Load node features and adjacency indices from a BRep extractor npz.

    Everything is read while the archive is open, so the returned arrays do
    not depend on the file handle afterwards.

    Returns a dict with coedge/face/edge/global features and topology arrays:
    ``coedge_x``, ``face_x``, ``edge_x``, ``next``, ``mate``, ``coedge_face``,
    ``coedge_edge`` and ``global_x``.
    """
    with np.load(npz_path) as data:
        coedge_feats = _flatten(data["coedge_features"])
        scale = np.asarray(data["coedge_scale_factors"], dtype=np.float32)[:, None]
        reverse = np.asarray(data["coedge_reverse_flags"], dtype=np.float32)[:, None]
        point_grids = _flatten(data["coedge_point_grids"])  # [N, 12*U]
        lcs = _flatten(data["coedge_lcs"])  # [N, 16]

        face_idx = np.asarray(data["face"], dtype=np.int64)
        edge_idx = np.asarray(data["edge"], dtype=np.int64)

        face_feats = np.asarray(data["face_features"], dtype=np.float32)  # [F, 7]
        edge_feats = np.asarray(data["edge_features"], dtype=np.float32)  # [E, 10]
        face_grid_stats = _face_grid_stats(data["face_point_grids"])

        coedge_x = np.concatenate(
            [coedge_feats, scale, reverse, point_grids, lcs], axis=1
        )
        face_x = np.concatenate([face_feats, face_grid_stats], axis=1)
        edge_x = edge_feats

        next_index = np.asarray(data["next"], dtype=np.int64)
        mate_index = np.asarray(data["mate"], dtype=np.int64)
        global_features = compute_global_geom_features(data)

    return {
        "coedge_x": coedge_x,
        "face_x": face_x,
        "edge_x": edge_x,
        "next": next_index,
        "mate": mate_index,
        "coedge_face": face_idx,
        "coedge_edge": edge_idx,
        "global_x": global_features,
    }


def make_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of edge indices (with both directions, deduplicated).

    Each (source, target) pair is added together with its reverse; duplicate
    rows are then removed, which also leaves the pairs in sorted order.
    """
    pairs = np.stack([source, target], axis=1)
    flipped = pairs[:, ::-1]
    all_pairs = np.concatenate([pairs, flipped], axis=0)
    all_pairs = np.unique(all_pairs, axis=0)
    return torch.tensor(all_pairs.T, dtype=torch.long)


def make_directed_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of directed edge indices (no deduplication).

    Row 0 holds ``source`` and row 1 holds ``target``, in input order.
    """
    return torch.tensor(np.stack([source, target], axis=0), dtype=torch.long)


def make_bipartite_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a 2 x E tensor of directed bipartite edge indices (deduplicated).

    Unlike ``make_edge_index`` no reverse pairs are added, because source and
    target index different node sets.
    """
    pairs = np.stack([source, target], axis=1)
    pairs = np.unique(pairs, axis=0)
    return torch.tensor(pairs.T, dtype=torch.long)


def make_heterodata(
    coedge_x: np.ndarray,
    face_x: np.ndarray,
    edge_x: np.ndarray,
    next_index: np.ndarray,
    mate_index: np.ndarray,
    coedge_face: np.ndarray,
    coedge_edge: np.ndarray,
    global_features: np.ndarray,
    label: int | None,
    norm_stats: Dict[str, Dict[str, np.ndarray | torch.Tensor]] | None = None,
) -> HeteroData:
    """
    Create a PyG HeteroData graph for the coedge features/relations.

    Node types are ``coedge``, ``face``, ``edge`` and a single-row ``global``
    node. Relations are coedge next/prev/mate, coedge<->face, coedge<->edge
    and face<->edge.

    When mean/std are provided the features are normalised element-wise.
    ``norm_stats`` is looked up under the keys ``"coedge"``, ``"face"`` and
    ``"edge"``; a node type whose entry, ``"mean"`` or ``"std"`` is missing is
    left unnormalised. The division uses ``std`` as given, so it must not
    contain zeros.

    ``data.y`` is only set when ``label`` is not None.
    """

    def _normalize(x_arr: np.ndarray, stats: Dict[str, np.ndarray | torch.Tensor] | None) -> torch.Tensor:
        # Convert to a float32 tensor and standardise if both stats exist.
        x_t = torch.tensor(x_arr, dtype=torch.float32)
        if stats is None:
            return x_t
        mean = stats.get("mean")
        std = stats.get("std")
        if mean is None or std is None:
            return x_t
        mean_t = torch.as_tensor(mean, dtype=torch.float32)
        std_t = torch.as_tensor(std, dtype=torch.float32)
        return (x_t - mean_t) / std_t

    coedge_stats = norm_stats.get("coedge") if norm_stats else None
    face_stats = norm_stats.get("face") if norm_stats else None
    edge_stats = norm_stats.get("edge") if norm_stats else None

    x_coedge = _normalize(coedge_x, coedge_stats)
    x_face = _normalize(face_x, face_stats)
    x_edge = _normalize(edge_x, edge_stats)

    # Coedge i is implicitly row i of coedge_x; the topology arrays give, for
    # each coedge, the index of its next/mate coedge and owning face/edge.
    idx = np.arange(coedge_x.shape[0], dtype=np.int64)
    edge_next = make_directed_edge_index(idx, next_index)
    edge_prev = make_directed_edge_index(next_index, idx)
    edge_mate = make_edge_index(idx, mate_index)
    edge_coedge_face = make_directed_edge_index(idx, coedge_face)
    edge_face_coedge = make_directed_edge_index(coedge_face, idx)
    edge_coedge_edge = make_directed_edge_index(idx, coedge_edge)
    edge_edge_coedge = make_directed_edge_index(coedge_edge, idx)
    edge_face_edge = make_bipartite_edge_index(coedge_face, coedge_edge)
    edge_edge_face = make_bipartite_edge_index(coedge_edge, coedge_face)

    data = HeteroData()
    data["coedge"].x = x_coedge
    data["face"].x = x_face
    data["edge"].x = x_edge
    data["global"].x = torch.tensor(global_features, dtype=torch.float32).view(1, -1)

    data["coedge", "next", "coedge"].edge_index = edge_next
    data["coedge", "prev", "coedge"].edge_index = edge_prev
    data["coedge", "mate", "coedge"].edge_index = edge_mate
    data["coedge", "to_face", "face"].edge_index = edge_coedge_face
    data["face", "to_coedge", "coedge"].edge_index = edge_face_coedge
    data["coedge", "to_edge", "edge"].edge_index = edge_coedge_edge
    data["edge", "to_coedge", "coedge"].edge_index = edge_edge_coedge
    data["face", "to_edge", "edge"].edge_index = edge_face_edge
    data["edge", "to_face", "face"].edge_index = edge_edge_face

    if label is not None:
        data.y = torch.tensor([int(label)], dtype=torch.long)
    return data


def compute_feature_stats(npz_paths: Iterable[Path]) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Compute mean and std (per feature dimension) across all node features in
    the dataset.

    Each file is loaded once and only running sums are kept, so memory use
    does not grow with the number of files. Sums are accumulated in float64
    because the variance is formed as E[x^2] - E[x]^2, which loses precision
    quickly in float32.

    Returns a dict keyed by ``"coedge"``, ``"face"`` and ``"edge"``; each
    value is ``{"mean": float32 array, "std": float32 array}``. The variance
    is floored at 1e-12 so ``std`` is never zero.

    Raises:
        RuntimeError: if no rows were observed for one of the node types.
    """
    totals = {"coedge": 0, "face": 0, "edge": 0}
    sum_vec: Dict[str, np.ndarray | None] = {"coedge": None, "face": None, "edge": None}
    sum_sq: Dict[str, np.ndarray | None] = {"coedge": None, "face": None, "edge": None}

    for path in npz_paths:
        graph = load_coedge_arrays(path)
        for key, x in (
            ("coedge", graph["coedge_x"]),
            ("face", graph["face_x"]),
            ("edge", graph["edge_x"]),
        ):
            if sum_vec[key] is None:
                sum_vec[key] = np.zeros(x.shape[1], dtype=np.float64)
                sum_sq[key] = np.zeros(x.shape[1], dtype=np.float64)
            # Promote before reducing so both the sum and the squares are
            # computed in double precision, not just the accumulator.
            x64 = np.asarray(x, dtype=np.float64)
            sum_vec[key] += x64.sum(axis=0)
            sum_sq[key] += (x64 * x64).sum(axis=0)
            totals[key] += x.shape[0]

    out: Dict[str, Dict[str, np.ndarray]] = {}
    for key in ("coedge", "face", "edge"):
        if sum_vec[key] is None or totals[key] == 0:
            raise RuntimeError(f"Cannot compute feature stats: no {key} features observed.")
        mean = sum_vec[key] / totals[key]
        var = sum_sq[key] / totals[key] - mean * mean
        var = np.maximum(var, 1e-12)
        std = np.sqrt(var)
        out[key] = {"mean": mean.astype(np.float32), "std": std.astype(np.float32)}
    return out