sample / brep_extractor_utils.py
Silly98's picture
Upload 2 files
dc71d7e verified
"""
Utility helpers for loading BRep extractor-processed STEP data as PyG graphs.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, Iterable, Tuple
import numpy as np
import torch
from torch_geometric.data import HeteroData
# Label mapping for the current project: class sub-directory name -> integer label.
LABELS: Dict[str, int] = {"pipe": 0, "elbow": 1, "tjoint": 2, "random": 3}
# Glob-style patterns for STEP files. build_label_map derives its accepted
# suffixes from these and compares them case-insensitively.
STEP_EXTS = ("*.step", "*.stp", "*.STEP", "*.STP")


def build_label_map(step_root: Path) -> Dict[str, int]:
    """
    Scan the STEP directory tree and map each file stem to its integer label.

    ``step_root`` is expected to contain one sub-directory per class in
    ``LABELS`` (``pipe/``, ``elbow/``, ...). Missing class directories are
    skipped. Only the top level of each class directory is scanned. Regular
    files whose suffix matches one of ``STEP_EXTS`` are included, and the
    comparison ignores case, so ``part.Step`` is picked up as well as
    ``part.STEP``.

    If the same stem occurs under two different class directories, the
    later class in ``LABELS`` wins (as before) and a ``UserWarning`` is
    emitted, because the mapping is keyed by stem alone.

    Args:
        step_root: Root directory, as a ``Path`` or ``str``.

    Returns:
        Mapping ``{file_stem: label}``.

    Raises:
        RuntimeError: If no STEP file is found in any class directory.
    """
    import warnings

    root = Path(step_root)
    # ".step" / ".stp", derived from STEP_EXTS so the two cannot drift apart.
    suffixes = {pattern.lstrip("*").lower() for pattern in STEP_EXTS}
    mapping: Dict[str, int] = {}
    for cls, label in LABELS.items():
        cls_dir = root / cls
        if not cls_dir.is_dir():
            continue
        # sorted() makes the scan (and any warning) order deterministic.
        for file in sorted(cls_dir.iterdir()):
            if not file.is_file() or file.suffix.lower() not in suffixes:
                continue
            previous = mapping.get(file.stem)
            if previous is not None and previous != label:
                warnings.warn(
                    f"STEP stem {file.stem!r} appears under more than one class "
                    f"directory (labels {previous} and {label}); using {label}.",
                    stacklevel=2,
                )
            mapping[file.stem] = label
    if not mapping:
        raise RuntimeError(f"No STEP files found under {step_root} for any of {tuple(LABELS)}")
    return mapping
def _flatten(arr: np.ndarray) -> np.ndarray:
return np.asarray(arr, dtype=np.float32).reshape(arr.shape[0], -1)
def _face_grid_stats(face_grids: np.ndarray) -> np.ndarray:
"""
Summarize face point grids into compact stats per face.
Returns [F, 10]: xyz_mean (3), xyz_std (3), nrm_mean (3), mask_frac (1).
"""
face_grids = np.asarray(face_grids, dtype=np.float32)
f = face_grids.shape[0]
xyz = face_grids[:, 0:3, :, :].reshape(f, 3, -1)
nrm = face_grids[:, 3:6, :, :].reshape(f, 3, -1)
msk = face_grids[:, 6, :, :].reshape(f, -1)
mask = (msk > 0.5).astype(np.float32)
mask_frac = mask.mean(axis=1, keepdims=True)
w = mask / (mask.sum(axis=1, keepdims=True) + 1e-6)
xyz_mean = (xyz * w[:, None, :]).sum(axis=2)
xyz_var = (w[:, None, :] * (xyz - xyz_mean[:, :, None]) ** 2).sum(axis=2)
xyz_std = np.sqrt(np.maximum(xyz_var, 1e-12))
nrm_mean = (nrm * w[:, None, :]).sum(axis=2)
return np.concatenate([xyz_mean, xyz_std, nrm_mean, mask_frac], axis=1)
def compute_global_geom_features(data) -> np.ndarray:
    """
    Derive a 5-element global shape descriptor from the sampled points.

    The point cloud is gathered from two sources:
      * ``data["face_point_grids"]``: channels 0-2 as xyz, keeping only
        samples whose channel 6 is > 0.5;
      * ``data["coedge_point_grids"]``: channels 0-2 as xyz, every sample.

    Rows with non-finite coordinates are dropped. The cloud is then centred,
    scaled to unit RMS radius, and its 3x3 covariance is eigen-decomposed.

    Returns [5] float32:
        pca_ev_ratio_1/2/3 -- eigenvalues (descending) divided by their sum;
        line_fit_rmse      -- sqrt of the sum of the two smallest eigenvalues;
        plane_fit_rmse     -- sqrt of the smallest eigenvalue.

    An all-zero vector is returned when fewer than 3 usable points exist, or
    when any intermediate quantity is non-finite or degenerate.
    """
    fallback = np.zeros(5, dtype=np.float32)

    clouds = []
    faces = np.asarray(data["face_point_grids"], dtype=np.float32)
    if faces.size:
        # [F, C, U, V] -> [F, U, V, 3] -> [F*U*V, 3]
        face_xyz = np.moveaxis(faces[:, 0:3, :, :], 1, -1).reshape(-1, 3)
        inside = faces[:, 6, :, :].reshape(-1) > 0.5
        if inside.any():
            clouds.append(face_xyz[inside])
    coedges = np.asarray(data["coedge_point_grids"], dtype=np.float32)
    if coedges.size:
        # [N, C, U] -> [N, U, 3] -> [N*U, 3]
        clouds.append(np.moveaxis(coedges[:, 0:3, :], 1, -1).reshape(-1, 3))

    if not clouds:
        return fallback
    cloud = np.concatenate(clouds, axis=0)
    if len(cloud) < 3:
        return fallback
    cloud = cloud[np.isfinite(cloud).all(axis=1)]
    if len(cloud) < 3:
        return fallback

    # Centre, then normalise to unit RMS distance from the centroid.
    centred = cloud - cloud.mean(axis=0, keepdims=True)
    rms = np.sqrt(np.mean(np.sum(centred ** 2, axis=1)))
    centred = centred / (rms + 1e-6)
    cov = (centred.T @ centred) / max(1, centred.shape[0])
    if not np.isfinite(cov).all():
        return fallback

    eigvals = np.maximum(np.sort(np.linalg.eigvalsh(cov))[::-1], 0.0)
    total = eigvals.sum()
    if not (np.isfinite(total) and total > 0.0):
        return fallback

    ratios = eigvals / total
    descriptor = np.array(
        [
            ratios[0],
            ratios[1],
            ratios[2],
            np.sqrt(max(eigvals[1] + eigvals[2], 0.0)),
            np.sqrt(max(eigvals[2], 0.0)),
        ],
        dtype=np.float32,
    )
    return descriptor if np.isfinite(descriptor).all() else fallback
def load_coedge_arrays(npz_path: Path) -> Dict[str, np.ndarray]:
    """
    Read one BRep-extractor ``.npz`` archive into feature and topology arrays.

    Returned keys:
        coedge_x   -- float32 columns, in this order:
                        * flattened ``coedge_features``;
                        * ``coedge_scale_factors`` (one column);
                        * ``coedge_reverse_flags`` (one column);
                        * flattened ``coedge_point_grids``;
                        * flattened ``coedge_lcs``.
        face_x     -- ``face_features`` followed by the 10 per-face grid
                      statistics from ``_face_grid_stats``.
        edge_x     -- ``edge_features`` as float32.
        next, mate -- int64 arrays read from ``next`` / ``mate``.
        coedge_face, coedge_edge -- int64 arrays read from ``face`` / ``edge``.
        global_x   -- descriptor from ``compute_global_geom_features``.
    """
    with np.load(npz_path) as archive:

        def floats(key: str) -> np.ndarray:
            # Array stored under ``key``, as float32.
            return np.asarray(archive[key], dtype=np.float32)

        def ints(key: str) -> np.ndarray:
            # Array stored under ``key``, as int64.
            return np.asarray(archive[key], dtype=np.int64)

        coedge_blocks = [
            _flatten(archive["coedge_features"]),
            floats("coedge_scale_factors")[:, None],
            floats("coedge_reverse_flags")[:, None],
            _flatten(archive["coedge_point_grids"]),  # [N, 12*U]
            _flatten(archive["coedge_lcs"]),  # [N, 16]
        ]
        coedge_face, coedge_edge = ints("face"), ints("edge")
        face_features, edge_x = floats("face_features"), floats("edge_features")
        grid_stats = _face_grid_stats(archive["face_point_grids"])
        coedge_x = np.concatenate(coedge_blocks, axis=1)
        face_x = np.concatenate([face_features, grid_stats], axis=1)
        # compute_global_geom_features reads from the archive, so it has to
        # run before the with-block closes it.
        arrays = {
            "coedge_x": coedge_x,
            "face_x": face_x,
            "edge_x": edge_x,
            "next": ints("next"),
            "mate": ints("mate"),
            "coedge_face": coedge_face,
            "coedge_edge": coedge_edge,
            "global_x": compute_global_geom_features(archive),
        }
    return arrays
def make_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build an undirected ``[2, E]`` long tensor from paired node indices.

    Each ``(source[i], target[i])`` pair is emitted in both directions.
    Duplicates are then removed, and the surviving pairs come out sorted
    lexicographically by ``(source, target)``, as ``np.unique(..., axis=0)``
    returns them.
    """
    forward = np.stack([source, target], axis=1)
    both_ways = np.vstack((forward, forward[:, ::-1]))
    unique_pairs = np.unique(both_ways, axis=0)
    return torch.tensor(unique_pairs.transpose(), dtype=torch.long)
def make_directed_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Return ``source`` stacked over ``target`` as a ``[2, E]`` long tensor.

    Column ``i`` is the directed edge ``source[i] -> target[i]``. The order
    and any duplicates are kept exactly as given.
    """
    rows = np.stack((source, target))
    return torch.tensor(rows, dtype=torch.long)
def make_bipartite_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Build a ``[2, E]`` long tensor of unique directed ``source -> target`` edges.

    Repeated pairs collapse to one. The result is sorted lexicographically
    by ``(source, target)``, as ``np.unique(..., axis=0)`` returns it. No
    reverse edges are added.
    """
    distinct = np.unique(np.stack((source, target), axis=1), axis=0)
    return torch.tensor(distinct.transpose(), dtype=torch.long)
def make_heterodata(
    coedge_x: np.ndarray,
    face_x: np.ndarray,
    edge_x: np.ndarray,
    next_index: np.ndarray,
    mate_index: np.ndarray,
    coedge_face: np.ndarray,
    coedge_edge: np.ndarray,
    global_features: np.ndarray,
    label: int | None,
    norm_stats: Dict[str, Dict[str, np.ndarray | torch.Tensor]] | None = None,
) -> HeteroData:
    """
    Create a PyG HeteroData graph for the coedge features/relations.

    Node types:
        coedge, face, edge -- features from ``coedge_x`` / ``face_x`` / ``edge_x``;
        global             -- a single node holding ``global_features``.

    Relations:
        coedge -next-> coedge and coedge -prev-> coedge (``next_index`` reversed);
        coedge -mate-> coedge, in both directions and deduplicated;
        coedge <-> face and coedge <-> edge, one edge per coedge in each direction;
        face <-> edge, deduplicated pairs derived from each coedge's face and edge.

    ``norm_stats`` may provide ``{"coedge"|"face"|"edge": {"mean": ..., "std": ...}}``.
    When both mean and std are present for a node type, its features become
    ``(x - mean) / std`` element-wise. Dimensions whose std is exactly zero
    are divided by 1 instead, so constant features are only centred rather
    than turned into inf/NaN.

    When ``label`` is not None it is stored as ``data.y`` with shape [1].
    """

    def _normalize(x_arr: np.ndarray, stats: Dict[str, np.ndarray | torch.Tensor] | None) -> torch.Tensor:
        # Copy into a fresh float32 tensor, then standardise if stats are complete.
        x_t = torch.tensor(x_arr, dtype=torch.float32)
        if stats is None:
            return x_t
        mean = stats.get("mean")
        std = stats.get("std")
        if mean is None or std is None:
            return x_t
        mean_t = torch.as_tensor(mean, dtype=torch.float32)
        std_t = torch.as_tensor(std, dtype=torch.float32)
        # Guard against division by zero for constant feature dimensions.
        std_t = torch.where(std_t == 0, torch.ones_like(std_t), std_t)
        return (x_t - mean_t) / std_t

    coedge_stats = norm_stats.get("coedge") if norm_stats else None
    face_stats = norm_stats.get("face") if norm_stats else None
    edge_stats = norm_stats.get("edge") if norm_stats else None
    x_coedge = _normalize(coedge_x, coedge_stats)
    x_face = _normalize(face_x, face_stats)
    x_edge = _normalize(edge_x, edge_stats)

    # Coedge i is node i, so arange gives the source side of per-coedge relations.
    idx = np.arange(coedge_x.shape[0], dtype=np.int64)
    edge_next = make_directed_edge_index(idx, next_index)
    edge_prev = make_directed_edge_index(next_index, idx)
    edge_mate = make_edge_index(idx, mate_index)
    edge_coedge_face = make_directed_edge_index(idx, coedge_face)
    edge_face_coedge = make_directed_edge_index(coedge_face, idx)
    edge_coedge_edge = make_directed_edge_index(idx, coedge_edge)
    edge_edge_coedge = make_directed_edge_index(coedge_edge, idx)
    edge_face_edge = make_bipartite_edge_index(coedge_face, coedge_edge)
    edge_edge_face = make_bipartite_edge_index(coedge_edge, coedge_face)

    data = HeteroData()
    data["coedge"].x = x_coedge
    data["face"].x = x_face
    data["edge"].x = x_edge
    data["global"].x = torch.tensor(global_features, dtype=torch.float32).view(1, -1)
    data["coedge", "next", "coedge"].edge_index = edge_next
    data["coedge", "prev", "coedge"].edge_index = edge_prev
    data["coedge", "mate", "coedge"].edge_index = edge_mate
    data["coedge", "to_face", "face"].edge_index = edge_coedge_face
    data["face", "to_coedge", "coedge"].edge_index = edge_face_coedge
    data["coedge", "to_edge", "edge"].edge_index = edge_coedge_edge
    data["edge", "to_coedge", "coedge"].edge_index = edge_edge_coedge
    data["face", "to_edge", "edge"].edge_index = edge_face_edge
    data["edge", "to_face", "face"].edge_index = edge_edge_face
    if label is not None:
        data.y = torch.tensor([int(label)], dtype=torch.long)
    return data
def compute_feature_stats(npz_paths: Iterable[Path]) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Compute per-dimension mean and std of coedge/face/edge node features.

    Each path is loaded with ``load_coedge_arrays``. Statistics are pooled
    over every node of every file using running sums of x and x**2, and all
    accumulation is done in float64. The variance is clamped to >= 1e-12,
    so every std is at least ~1e-6.

    Returns:
        ``{"coedge" | "face" | "edge": {"mean": float32[D], "std": float32[D]}}``,
        the layout that ``make_heterodata`` reads from ``norm_stats``.

    Raises:
        RuntimeError: If no nodes of some type were observed (e.g. no paths).
        ValueError: If a node type's feature width differs between files.
    """
    kinds = ("coedge", "face", "edge")
    totals = dict.fromkeys(kinds, 0)
    sum_vec: Dict[str, np.ndarray | None] = dict.fromkeys(kinds)
    sum_sq: Dict[str, np.ndarray | None] = dict.fromkeys(kinds)
    for path in npz_paths:
        graph = load_coedge_arrays(path)
        for key in kinds:
            x = graph[f"{key}_x"]
            if sum_vec[key] is None:
                sum_vec[key] = np.zeros(x.shape[1], dtype=np.float64)
                sum_sq[key] = np.zeros(x.shape[1], dtype=np.float64)
            elif x.shape[1] != sum_vec[key].shape[0]:
                raise ValueError(
                    f"{key} feature width {x.shape[1]} in {path} does not match "
                    f"previously seen width {sum_vec[key].shape[0]}"
                )
            # Promote before reducing: a float32 sum down axis 0 (and the sum
            # of squares in particular) loses precision as row counts grow.
            x64 = np.asarray(x, dtype=np.float64)
            sum_vec[key] += x64.sum(axis=0)
            sum_sq[key] += (x64 * x64).sum(axis=0)
            totals[key] += x.shape[0]
    out = {}
    for key in kinds:
        if sum_vec[key] is None or totals[key] == 0:
            raise RuntimeError(f"Cannot compute feature stats: no {key} features observed.")
        mean = sum_vec[key] / totals[key]
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding; clamp it.
        var = sum_sq[key] / totals[key] - mean * mean
        var = np.maximum(var, 1e-12)
        std = np.sqrt(var)
        out[key] = {"mean": mean.astype(np.float32), "std": std.astype(np.float32)}
    return out