File size: 10,803 Bytes
dc71d7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""
Utility helpers for loading BRep extractor-processed STEP data as PyG graphs.
"""
from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, Tuple

import numpy as np
import torch
from torch_geometric.data import HeteroData

# Integer class id for each category sub-directory of the STEP tree.
# Insertion order matters: it is the order in which classes are scanned.
LABELS: Dict[str, int] = dict(pipe=0, elbow=1, tjoint=2, random=3)
# Glob patterns for STEP files: both common extensions, lower- and upper-case.
STEP_EXTS = tuple(f"*.{suffix}" for suffix in ("step", "stp", "STEP", "STP"))


def build_label_map(step_root: Path) -> Dict[str, int]:
    """
    Map every STEP file stem under ``step_root`` to its integer class label.

    ``step_root`` is expected to hold one sub-directory per key of ``LABELS``
    (``pipe/``, ``elbow/``, ...). Only files directly inside each class
    directory whose names match a pattern in ``STEP_EXTS`` are considered.
    Class directories that do not exist are skipped.

    NOTE(review): the map is keyed by file stem only. A stem present in two
    class directories keeps the label of the class scanned last (``LABELS``
    order). Confirm that stems are unique across classes.

    Raises:
        RuntimeError: if no STEP file was found for any class.
    """
    stem_to_label: Dict[str, int] = {}
    for class_name, class_id in LABELS.items():
        class_dir = step_root / class_name
        if class_dir.exists():
            stem_to_label.update(
                (path.stem, class_id)
                for pattern in STEP_EXTS
                for path in class_dir.glob(pattern)
            )
    if stem_to_label:
        return stem_to_label
    raise RuntimeError(f"No STEP files found under {step_root} for any of {tuple(LABELS)}")


def _flatten(arr: np.ndarray) -> np.ndarray:
    return np.asarray(arr, dtype=np.float32).reshape(arr.shape[0], -1)

def _face_grid_stats(face_grids: np.ndarray) -> np.ndarray:
    """
    Summarize face point grids into compact stats per face.
    Returns [F, 10]: xyz_mean (3), xyz_std (3), nrm_mean (3), mask_frac (1).
    """
    face_grids = np.asarray(face_grids, dtype=np.float32)
    f = face_grids.shape[0]
    xyz = face_grids[:, 0:3, :, :].reshape(f, 3, -1)
    nrm = face_grids[:, 3:6, :, :].reshape(f, 3, -1)
    msk = face_grids[:, 6, :, :].reshape(f, -1)

    mask = (msk > 0.5).astype(np.float32)
    mask_frac = mask.mean(axis=1, keepdims=True)
    w = mask / (mask.sum(axis=1, keepdims=True) + 1e-6)

    xyz_mean = (xyz * w[:, None, :]).sum(axis=2)
    xyz_var = (w[:, None, :] * (xyz - xyz_mean[:, :, None]) ** 2).sum(axis=2)
    xyz_std = np.sqrt(np.maximum(xyz_var, 1e-12))
    nrm_mean = (nrm * w[:, None, :]).sum(axis=2)
    return np.concatenate([xyz_mean, xyz_std, nrm_mean, mask_frac], axis=1)

def compute_global_geom_features(data) -> np.ndarray:
    """
    Describe the overall shape of a part from its sampled surface/curve points.

    Points are gathered from two sources:
      * ``data["face_point_grids"]``: channels 0-2 are xyz; only samples
        whose channel-6 mask exceeds 0.5 are kept.
      * ``data["coedge_point_grids"]``: channels 0-2, all samples.

    Non-finite points are dropped. The cloud is then centred and scaled to
    unit RMS radius, and the eigenvalues ``l1 >= l2 >= l3`` of its
    covariance are computed.

    Returns a float32 array of length 5:
    ``[l1/sum, l2/sum, l3/sum, sqrt(l2 + l3), sqrt(l3)]``. These are the PCA
    eigenvalue ratios, followed by the RMS distance to the best-fit line and
    to the best-fit plane, in normalized units.

    All zeros are returned when fewer than 3 usable points exist or when an
    intermediate result is not finite.
    """
    failure = np.zeros(5, dtype=np.float32)

    clouds = []
    face_grids = np.asarray(data["face_point_grids"], dtype=np.float32)
    if face_grids.size:
        face_valid = face_grids[:, 6, :, :].reshape(-1) > 0.5
        if face_valid.any():
            face_xyz = face_grids[:, 0:3, :, :].transpose(0, 2, 3, 1).reshape(-1, 3)
            clouds.append(face_xyz[face_valid])

    coedge_grids = np.asarray(data["coedge_point_grids"], dtype=np.float32)
    if coedge_grids.size:
        clouds.append(coedge_grids[:, 0:3, :].transpose(0, 2, 1).reshape(-1, 3))

    # Filtering can only remove rows, so a single size check after it also
    # covers clouds that were too small to begin with.
    cloud = np.concatenate(clouds, axis=0) if clouds else np.empty((0, 3), dtype=np.float32)
    cloud = cloud[np.isfinite(cloud).all(axis=1)]
    if cloud.shape[0] < 3:
        return failure

    offsets = cloud - cloud.mean(axis=0, keepdims=True)
    rms_radius = np.sqrt(np.mean(np.sum(offsets ** 2, axis=1)))
    offsets = offsets / (rms_radius + 1e-6)
    cov = (offsets.T @ offsets) / max(1, offsets.shape[0])
    if not np.isfinite(cov).all():
        return failure

    eigvals = np.maximum(np.sort(np.linalg.eigvalsh(cov))[::-1], 0.0)
    total = eigvals.sum()
    if not (np.isfinite(total) and total > 0.0):
        return failure

    ratios = eigvals / total
    line_rmse = np.sqrt(max(eigvals[1] + eigvals[2], 0.0))
    plane_rmse = np.sqrt(max(eigvals[2], 0.0))
    feats = np.array([*ratios, line_rmse, plane_rmse], dtype=np.float32)
    return feats if np.isfinite(feats).all() else failure

def load_coedge_arrays(npz_path: Path) -> Dict[str, np.ndarray]:
    """
    Load node features and topology arrays from a BRep extractor ``.npz``.

    Returns a dict with:
      * ``coedge_x``: ``coedge_features``, scale factor, reverse flag, the
        flattened ``coedge_point_grids`` and the flattened ``coedge_lcs``,
        concatenated column-wise (float32).
      * ``face_x``: ``face_features`` plus the 10 point-grid summary columns
        produced by ``_face_grid_stats`` (float32).
      * ``edge_x``: ``edge_features`` as float32.
      * ``next`` / ``mate``: coedge index arrays (int64).
      * ``coedge_face`` / ``coedge_edge``: per-coedge face / edge indices
        (int64), read from the ``face`` / ``edge`` members.
      * ``global_x``: the 5 global geometry descriptors.

    Every array is materialized before the archive is closed.
    """
    with np.load(npz_path) as data:
        # NpzFile loads a member from the archive on every data[key] access.
        # The two point-grid arrays feed both the node features and the
        # global descriptors, so each is read once and shared.
        coedge_grids = np.asarray(data["coedge_point_grids"], dtype=np.float32)
        face_grids = np.asarray(data["face_point_grids"], dtype=np.float32)

        coedge_x = np.concatenate(
            [
                _flatten(data["coedge_features"]),
                np.asarray(data["coedge_scale_factors"], dtype=np.float32)[:, None],
                np.asarray(data["coedge_reverse_flags"], dtype=np.float32)[:, None],
                _flatten(coedge_grids),          # [N, 12*U]
                _flatten(data["coedge_lcs"]),    # [N, 16]
            ],
            axis=1,
        )
        face_x = np.concatenate(
            [
                np.asarray(data["face_features"], dtype=np.float32),  # [F, 7]
                _face_grid_stats(face_grids),
            ],
            axis=1,
        )
        edge_x = np.asarray(data["edge_features"], dtype=np.float32)  # [E, 10]

        face_idx = np.asarray(data["face"], dtype=np.int64)
        edge_idx = np.asarray(data["edge"], dtype=np.int64)
        next_index = np.asarray(data["next"], dtype=np.int64)
        mate_index = np.asarray(data["mate"], dtype=np.int64)

        global_features = compute_global_geom_features(
            {"face_point_grids": face_grids, "coedge_point_grids": coedge_grids}
        )

    return {
        "coedge_x": coedge_x,
        "face_x": face_x,
        "edge_x": edge_x,
        "next": next_index,
        "mate": mate_index,
        "coedge_face": face_idx,
        "coedge_edge": edge_idx,
        "global_x": global_features,
    }


def make_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Return a symmetric ``[2, E']`` long tensor of undirected edges.

    Each ``(source[i], target[i])`` pair is emitted in both directions, and
    duplicate pairs are then removed. The surviving columns are ordered
    lexicographically by (row 0, row 1), as ``np.unique(..., axis=0)``
    yields them.
    """
    forward = np.stack([source, target], axis=1)
    backward = np.stack([target, source], axis=1)
    undirected = np.unique(np.concatenate((forward, backward), axis=0), axis=0)
    return torch.tensor(undirected.T, dtype=torch.long)

def make_directed_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Return the ``[2, E]`` long tensor ``[source; target]``.

    Pairs are kept as given: no reverse direction is added and duplicates
    are not removed.
    """
    rows = (source, target)
    return torch.tensor(np.stack(rows, axis=0), dtype=torch.long)

def make_bipartite_edge_index(source: np.ndarray, target: np.ndarray) -> torch.Tensor:
    """
    Return a ``[2, E']`` long tensor of distinct directed ``source -> target`` pairs.

    Unlike ``make_edge_index``, no reverse edges are added. Duplicate pairs
    are collapsed, and the result is ordered lexicographically by
    (source, target).
    """
    distinct_pairs = np.unique(np.stack([source, target], axis=1), axis=0)
    edge_rows = distinct_pairs.T
    return torch.tensor(edge_rows, dtype=torch.long)

def make_heterodata(
    coedge_x: np.ndarray,
    face_x: np.ndarray,
    edge_x: np.ndarray,
    next_index: np.ndarray,
    mate_index: np.ndarray,
    coedge_face: np.ndarray,
    coedge_edge: np.ndarray,
    global_features: np.ndarray,
    label: int | None,
    norm_stats: Dict[str, Dict[str, np.ndarray | torch.Tensor]] | None = None,
) -> HeteroData:
    """
    Assemble a PyG ``HeteroData`` graph from the arrays of one B-rep model.

    Node stores:
      * ``coedge``, ``face`` and ``edge`` get float32 ``x`` tensors. When
        ``norm_stats[<node type>]`` provides both ``"mean"`` and ``"std"``,
        features are standardized as ``(x - mean) / std``; otherwise they
        are stored unchanged.
      * ``global`` gets a single-row ``x`` built from ``global_features``.

    Edge stores (each a ``[2, E]`` long tensor):
      * ``coedge-next-coedge`` from ``next_index``, and ``coedge-prev-coedge``
        as its reverse.
      * ``coedge-mate-coedge`` from ``mate_index``, made symmetric and
        deduplicated.
      * ``coedge<->face`` and ``coedge<->edge``: one edge per coedge in each
        direction, from ``coedge_face`` / ``coedge_edge``.
      * ``face<->edge``: distinct face/edge pairs that share a coedge.

    ``y`` is set to a one-element long tensor when ``label`` is not None.
    """

    def _to_node_features(
        values: np.ndarray, stats: Dict[str, np.ndarray | torch.Tensor] | None
    ) -> torch.Tensor:
        # Standardize with the given mean/std; pass the features through
        # untouched when either statistic is unavailable.
        features = torch.tensor(values, dtype=torch.float32)
        mean = None if stats is None else stats.get("mean")
        std = None if stats is None else stats.get("std")
        if mean is None or std is None:
            return features
        mean_t = torch.as_tensor(mean, dtype=torch.float32)
        std_t = torch.as_tensor(std, dtype=torch.float32)
        return (features - mean_t) / std_t

    stats_by_type = norm_stats if norm_stats else {}
    node_features = {
        node_type: _to_node_features(values, stats_by_type.get(node_type))
        for node_type, values in (
            ("coedge", coedge_x),
            ("face", face_x),
            ("edge", edge_x),
        )
    }

    coedge_ids = np.arange(coedge_x.shape[0], dtype=np.int64)
    # Listed in the order the stores are created on the graph; HeteroData
    # keeps insertion order, which defines its edge-type ordering.
    relations = [
        (("coedge", "next", "coedge"), make_directed_edge_index(coedge_ids, next_index)),
        (("coedge", "prev", "coedge"), make_directed_edge_index(next_index, coedge_ids)),
        (("coedge", "mate", "coedge"), make_edge_index(coedge_ids, mate_index)),
        (("coedge", "to_face", "face"), make_directed_edge_index(coedge_ids, coedge_face)),
        (("face", "to_coedge", "coedge"), make_directed_edge_index(coedge_face, coedge_ids)),
        (("coedge", "to_edge", "edge"), make_directed_edge_index(coedge_ids, coedge_edge)),
        (("edge", "to_coedge", "coedge"), make_directed_edge_index(coedge_edge, coedge_ids)),
        (("face", "to_edge", "edge"), make_bipartite_edge_index(coedge_face, coedge_edge)),
        (("edge", "to_face", "face"), make_bipartite_edge_index(coedge_edge, coedge_face)),
    ]

    graph = HeteroData()
    for node_type in ("coedge", "face", "edge"):
        graph[node_type].x = node_features[node_type]
    graph["global"].x = torch.tensor(global_features, dtype=torch.float32).view(1, -1)
    for edge_type, edge_index in relations:
        graph[edge_type].edge_index = edge_index
    if label is not None:
        graph.y = torch.tensor([int(label)], dtype=torch.long)
    return graph


def compute_feature_stats(npz_paths: Iterable[Path]) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Compute per-dimension mean and std of the node features across a dataset.

    Each path is loaded with ``load_coedge_arrays``. The rows of its
    ``coedge_x``, ``face_x`` and ``edge_x`` arrays are pooled per node type.

    Sums and sums of squares are accumulated entirely in float64, including
    the per-file reduction. Reducing the float32 arrays in float32 first
    would lose precision, and ``E[x^2] - E[x]^2`` amplifies that loss
    through cancellation. Variances are floored at 1e-12, so every std is
    at least 1e-6.

    Returns:
        ``{"coedge" | "face" | "edge": {"mean": float32[D], "std": float32[D]}}``,
        the nested layout that ``make_heterodata`` reads from ``norm_stats``.

    Raises:
        RuntimeError: if no features of some node type were observed
            (e.g. ``npz_paths`` is empty).
    """
    node_types = ("coedge", "face", "edge")
    totals = dict.fromkeys(node_types, 0)
    sum_vec: Dict[str, np.ndarray | None] = dict.fromkeys(node_types)
    sum_sq: Dict[str, np.ndarray | None] = dict.fromkeys(node_types)

    for path in npz_paths:
        graph = load_coedge_arrays(path)
        for key, feats in (("coedge", graph["coedge_x"]), ("face", graph["face_x"]), ("edge", graph["edge_x"])):
            x = np.asarray(feats, dtype=np.float64)
            if sum_vec[key] is None:
                sum_vec[key] = np.zeros(x.shape[1], dtype=np.float64)
                sum_sq[key] = np.zeros(x.shape[1], dtype=np.float64)
            sum_vec[key] += x.sum(axis=0)
            sum_sq[key] += np.square(x).sum(axis=0)
            totals[key] += x.shape[0]

    out: Dict[str, Dict[str, np.ndarray]] = {}
    for key in node_types:
        if sum_vec[key] is None or totals[key] == 0:
            raise RuntimeError(f"Cannot compute feature stats: no {key} features observed.")
        mean = sum_vec[key] / totals[key]
        var = sum_sq[key] / totals[key] - mean * mean
        var = np.maximum(var, 1e-12)
        std = np.sqrt(var)
        out[key] = {"mean": mean.astype(np.float32), "std": std.astype(np.float32)}
    return out