# Source: sample/face_labeler.py (uploaded by Silly98, commit 99bcebb "verified").
import sys
import os
import argparse
import pathlib
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Qt binding selection + pythonocc backend init.
# load_backend() must run before OCC.Display.qtDisplay is imported, so the
# order of the statements below matters.
try:
    from PyQt5 import QtCore, QtWidgets
    _qt_backend = "qt-pyqt5"
except ImportError:  # pragma: no cover
    from PySide2 import QtCore, QtWidgets
    _qt_backend = "qt-pyside2"
from OCC.Display.backend import load_backend

load_backend(_qt_backend)

from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopoDS import topods
from OCC.Core.Quantity import Quantity_Color, Quantity_TOC_RGB
from OCC.Display.qtDisplay import qtViewer3d

try:
    from OCC.Core.Aspect import Aspect_TOL_SOLID
    from OCC.Core.Prs3d import Prs3d_LineAspect
except Exception:  # pragma: no cover - optional OCC build
    Aspect_TOL_SOLID = None
    Prs3d_LineAspect = None
try:
    from OCC.Core.Graphic3d import Graphic3d_NOM_MATTE, Graphic3d_NOM_NEON
except Exception:  # pragma: no cover - optional OCC build
    Graphic3d_NOM_MATTE = None
    Graphic3d_NOM_NEON = None
try:
    from OCC.Core.Graphic3d import Graphic3d_TOSM_UNLIT
except Exception:  # pragma: no cover - optional OCC build
    Graphic3d_TOSM_UNLIT = None
# Feature classes. The list index is the integer label written to .seg files.
# NOTE(review): the order presumably has to match the label ids expected by the
# BRepMFR dataset/model that consumes the .seg files - confirm before reordering.
CLASS_NAMES = [
    "SOL",
    "EOS",
    "rect_slot",
    "tri_slot",
    "cir_slot",
    "rect_psg",
    "tri_psg",
    "hexa_psg",
    "hole",
    "rect_step",
    "tside_step",
    "slant_step",
    "rect_b_step",
    "tri_step",
    "cir_step",
    "rect_b_slot",
    "cir_b_slot",
    "u_b_slot",
    "rect_pkt",
    "key_pkt",
    "tri_pkt",
    "hexa_pkt",
    "o_ring",
    "b_hole",
    "chamfer",
    "fillet",
]
# Button/face colour per class: CLASS_COLORS_HEX[i] belongs to CLASS_NAMES[i].
CLASS_COLORS_HEX = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
    "#bcbd22",  # olive
    "#17becf",  # cyan
    "#393b79",  # dark blue
    "#637939",  # dark green
    "#8c6d31",  # dark mustard
    "#843c39",  # dark red
    "#7b4173",  # dark purple
    "#3182bd",  # blue alt
    "#e6550d",  # orange alt
    "#31a354",  # green alt
    "#756bb1",  # purple alt
    "#636363",  # dark gray
    "#6baed6",  # light blue
    "#fd8d3c",  # light orange
    "#74c476",  # light green
    "#9e9ac8",  # light purple
    "#a1d99b",  # pale green
    "#fdd0a2",  # pale orange
]
UNLABELED_COLOR_HEX = "#d0d0d0"  # faces that have no label yet
HIGHLIGHT_COLOR_HEX = "#FFD400"  # fixed for current selection
EDGE_COLOR_HEX = "#2b2b2b"  # face boundary lines

# The two tables are parallel and the UI indexes CLASS_COLORS_HEX by class
# index. Fail fast at import time instead of with an IndexError (or a colourless
# class) deep inside the GUI code.
if len(CLASS_COLORS_HEX) != len(CLASS_NAMES):
    raise ValueError(
        f"CLASS_COLORS_HEX has {len(CLASS_COLORS_HEX)} entries but CLASS_NAMES has "
        f"{len(CLASS_NAMES)}; every class needs exactly one colour"
    )
def hex_to_rgb01(color_hex: str) -> Tuple[float, float, float]:
    """Convert an ``#RRGGBB`` string to an ``(r, g, b)`` tuple of floats in [0, 1].

    Leading ``#`` characters are stripped; only the first six hex digits are read.
    """
    digits = color_hex.lstrip("#")
    red, green, blue = (
        int(digits[start:start + 2], 16) / 255.0 for start in (0, 2, 4)
    )
    return red, green, blue
def rgb01_to_quantity(rgb: Tuple[float, float, float]) -> Quantity_Color:
    """Wrap an ``(r, g, b)`` triple of 0-1 floats as an OCC ``Quantity_Color`` (RGB space)."""
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return Quantity_Color(red, green, blue, Quantity_TOC_RGB)
def text_color_for_bg(color_hex: str) -> str:
    """Return black or white text colour for a legible label on ``color_hex``.

    The background brightness is the Rec. 601 luma of its 0-1 channels.
    Backgrounds brighter than 0.6 get black text and all others get white.
    """
    red, green, blue = hex_to_rgb01(color_hex)
    is_light = 0.299 * red + 0.587 * green + 0.114 * blue > 0.6
    if is_light:
        return "#000000"
    return "#ffffff"
@dataclass
class FaceItem:
    """One face of the loaded model together with the object that draws it."""

    # Face produced by topods.Face(...) while exploring the STEP shape.
    face: object
    # AIS presentation returned by display.DisplayShape for that face.
    ais: object
class FaceLabeler(QtWidgets.QMainWindow):
    """Main window for assigning a ``CLASS_NAMES`` label to every face of a STEP model.

    Faces are enumerated with ``TopExp_Explorer`` and each one is displayed as
    its own AIS object. The user steps through them with the Prev/Next buttons,
    the Left/Right arrow keys, or A/D. Labels are assigned with the coloured
    class buttons. The result is exported as one label index per line to a
    ``.seg`` file, in face enumeration order.
    """

    def __init__(self, step_path: Optional[str] = None, output_dir: Optional[str] = None):
        """Create the window and optionally open *step_path* right away.

        Args:
            step_path: STEP file to load on start-up, or ``None``.
            output_dir: When set, "Export .seg" writes ``<step name>.seg`` into
                this directory without showing a file dialog.
        """
        super().__init__()
        self.setWindowTitle("BRepMFR Face Labeler")
        self.resize(1600, 1000)
        self.setMinimumSize(1200, 800)

        self.class_names = CLASS_NAMES
        self.class_colors_rgb = [hex_to_rgb01(c) for c in CLASS_COLORS_HEX]
        self.class_colors = [rgb01_to_quantity(c) for c in self.class_colors_rgb]
        self.unlabeled_color = rgb01_to_quantity(hex_to_rgb01(UNLABELED_COLOR_HEX))
        self.highlight_color = rgb01_to_quantity(hex_to_rgb01(HIGHLIGHT_COLOR_HEX))
        # Converted once; every face's boundary line aspect reuses it.
        self.edge_color = rgb01_to_quantity(hex_to_rgb01(EDGE_COLOR_HEX))

        # Parallel lists: face_items[i] is the i-th explored face and labels[i]
        # is its class index (None until the user assigns one).
        self.face_items: List[FaceItem] = []
        self.labels: List[Optional[int]] = []
        self.current_index: Optional[int] = None
        # Switched off by on_review once every face is labeled so the final
        # colouring is visible; switched back on by load_step.
        self.highlight_enabled = True
        self.step_path: Optional[str] = None
        self.output_dir: Optional[str] = output_dir

        self._build_ui()
        self.update_step_label()
        if step_path:
            self.load_step(step_path)

    def _build_ui(self) -> None:
        """Create the 3D viewer (left, stretching) and the control panel (right)."""
        central = QtWidgets.QWidget(self)
        root_layout = QtWidgets.QHBoxLayout(central)
        root_layout.setContentsMargins(8, 8, 8, 8)
        root_layout.setSpacing(8)

        self.viewer = qtViewer3d(central)
        self.viewer.InitDriver()
        self.display = self.viewer._display
        try:
            # Turn off OCC's automatic highlighting of detected objects so it
            # does not interfere with the face colours managed here.
            self.display.Context.SetAutomaticHilight(False)
        except Exception:
            pass  # optional; not every pythonocc build exposes this
        self._configure_viewer_visuals()
        root_layout.addWidget(self.viewer, 1)

        panel = QtWidgets.QWidget(central)
        panel_layout = QtWidgets.QVBoxLayout(panel)
        panel_layout.setContentsMargins(0, 0, 0, 0)
        panel_layout.setSpacing(6)
        root_layout.addWidget(panel, 0)

        self.btn_import_step = QtWidgets.QPushButton("Import STEP")
        self.btn_import_step.clicked.connect(self.on_import_step)
        panel_layout.addWidget(self.btn_import_step)

        self.btn_export_seg = QtWidgets.QPushButton("Export .seg")
        self.btn_export_seg.clicked.connect(self.on_export_seg)
        panel_layout.addWidget(self.btn_export_seg)

        self.btn_review = QtWidgets.QPushButton("Review")
        self.btn_review.clicked.connect(self.on_review)
        panel_layout.addWidget(self.btn_review)

        panel_layout.addSpacing(8)

        nav_layout = QtWidgets.QHBoxLayout()
        self.btn_prev = QtWidgets.QPushButton("<< Prev")
        self.btn_prev.clicked.connect(self.on_prev)
        self.btn_next = QtWidgets.QPushButton("Next >>")
        self.btn_next.clicked.connect(self.on_next)
        nav_layout.addWidget(self.btn_prev)
        nav_layout.addWidget(self.btn_next)
        panel_layout.addLayout(nav_layout)

        self.step_label = QtWidgets.QLabel("STEP: (none)")
        self.step_label.setWordWrap(True)
        panel_layout.addWidget(self.step_label)

        self.info_label = QtWidgets.QLabel("No STEP loaded")
        self.info_label.setWordWrap(True)
        panel_layout.addWidget(self.info_label)

        panel_layout.addSpacing(8)

        legend_label = QtWidgets.QLabel("Assign Label")
        legend_label.setStyleSheet("font-weight: bold;")
        panel_layout.addWidget(legend_label)

        grid = QtWidgets.QGridLayout()
        grid.setSpacing(6)
        for idx, name in enumerate(self.class_names):
            btn = QtWidgets.QPushButton(f"{idx}: {name}")
            bg = CLASS_COLORS_HEX[idx]
            fg = text_color_for_bg(bg)
            btn.setStyleSheet(f"background-color: {bg}; color: {fg};")
            # `i=idx` binds the current index (a plain closure would see only the
            # last one); `checked` absorbs the bool argument of `clicked`.
            btn.clicked.connect(lambda checked=False, i=idx: self.assign_label(i))
            grid.addWidget(btn, idx, 0)
        panel_layout.addLayout(grid)
        panel_layout.addStretch(1)

        self.setCentralWidget(central)

    def update_step_label(self) -> None:
        """Show the loaded STEP file name in the side panel and the window title."""
        if self.step_path:
            name = os.path.basename(self.step_path)
            self.step_label.setText(f"STEP: {name}")
            self.setWindowTitle(f"BRepMFR Face Labeler - {name}")
        else:
            self.step_label.setText("STEP: (none)")
            self.setWindowTitle("BRepMFR Face Labeler")

    def keyPressEvent(self, event) -> None:  # pragma: no cover - UI only
        """Right/D moves to the next face and Left/A to the previous one."""
        if event.key() in (QtCore.Qt.Key_Right, QtCore.Qt.Key_D):
            self.on_next()
            return
        if event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_A):
            self.on_prev()
            return
        super().keyPressEvent(event)

    def on_import_step(self) -> None:
        """Ask for a STEP file and load it."""
        path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open STEP", "", "STEP Files (*.stp *.step)"
        )
        if path:
            self.load_step(path)

    def on_export_seg(self) -> None:
        """Export the labels to a ``.seg`` file once every face is labeled.

        With ``output_dir`` configured, the file is named after the STEP file
        and written there directly. Otherwise the user chooses the path, and
        the dialog is pre-filled with the STEP path's ``.seg`` counterpart.
        """
        if not self.labels:
            QtWidgets.QMessageBox.warning(self, "Export", "No STEP loaded.")
            return
        if any(label is None for label in self.labels):
            QtWidgets.QMessageBox.warning(
                self,
                "Export",
                "Unlabeled faces remain. Label all faces before exporting.",
            )
            return
        if self.output_dir:
            if not self.step_path:
                QtWidgets.QMessageBox.warning(self, "Export", "No STEP loaded.")
                return
            output_dir = pathlib.Path(self.output_dir)
            try:
                output_dir.mkdir(parents=True, exist_ok=True)
            except OSError as exc:
                # An exception escaping a Qt slot aborts PyQt5 >= 5.5 and would
                # discard all unsaved labels, so report it instead.
                QtWidgets.QMessageBox.critical(
                    self, "Export", f"Cannot create output directory {output_dir}: {exc}"
                )
                return
            filename = pathlib.Path(self.step_path).with_suffix(".seg").name
            self.save_seg(str(output_dir / filename))
            return
        suggested = (
            str(pathlib.Path(self.step_path).with_suffix(".seg")) if self.step_path else ""
        )
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Export .seg", suggested, "SEG Files (*.seg)"
        )
        if path:
            self.save_seg(path)

    def on_review(self) -> None:
        """Show per-class counts and, if faces are unlabeled, offer to jump there.

        When every face is labeled, the current-face highlight is switched off
        (until the next STEP load) so the final colouring can be inspected.
        """
        if not self.labels:
            QtWidgets.QMessageBox.information(self, "Review", "No STEP loaded.")
            return
        counts = [0] * len(self.class_names)
        unlabeled = []
        for idx, label in enumerate(self.labels):
            if label is None:
                unlabeled.append(idx)
            else:
                counts[label] += 1
        lines = [
            f"Total faces: {len(self.labels)}",
            f"Unlabeled: {len(unlabeled)}",
            "",
        ]
        for idx, name in enumerate(self.class_names):
            lines.append(f"{idx} {name}: {counts[idx]}")
        if unlabeled:
            preview = ", ".join(str(i) for i in unlabeled[:20])
            if len(unlabeled) > 20:
                preview += ", ..."
            lines.append("")
            lines.append(f"Unlabeled indices: {preview}")
            lines.append("")
            lines.append("Jump to first unlabeled?")
            res = QtWidgets.QMessageBox.question(
                self,
                "Review",
                "\n".join(lines),
                QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
            )
            if res == QtWidgets.QMessageBox.Yes:
                self.set_current_index(unlabeled[0])
        else:
            self.highlight_enabled = False
            if self.current_index is not None:
                self.set_face_color(
                    self.current_index, self.get_base_color(self.current_index)
                )
            QtWidgets.QMessageBox.information(self, "Review", "\n".join(lines))

    def on_prev(self) -> None:
        """Move to the previous face; no-op on the first face or with no model."""
        if self.current_index is None:
            return
        if self.current_index <= 0:
            return
        self.set_current_index(self.current_index - 1)

    def on_next(self) -> None:
        """Move to the next face; no-op on the last face or with no model."""
        if self.current_index is None:
            return
        if self.current_index >= len(self.face_items) - 1:
            return
        self.set_current_index(self.current_index + 1)

    def load_step(self, path: str) -> None:
        """Load *path*, display each face separately and reset all labels.

        If the file cannot be read, a warning is shown and the current model
        and labels are kept unchanged.
        """
        reader = STEPControl_Reader()
        status = reader.ReadFile(path)
        if status != IFSelect_RetDone:
            QtWidgets.QMessageBox.warning(
                self, "Load STEP", f"Failed to read STEP file: {path}"
            )
            return
        reader.TransferRoots()
        shape = reader.OneShape()

        self.step_path = path
        self.update_step_label()
        self.display.EraseAll()
        try:
            # EraseAll only hides presentations; they stay registered in the
            # AIS context. Remove them so that every import does not keep all
            # previously loaded faces alive.
            self.display.Context.RemoveAll(False)
        except Exception:
            pass  # best effort, like the other optional OCC calls in this class
        self.face_items.clear()
        self.labels.clear()
        self.highlight_enabled = True
        self.current_index = None

        # The .seg line order is this exploration order.
        explorer = TopExp_Explorer(shape, TopAbs_FACE)
        while explorer.More():
            face = topods.Face(explorer.Current())
            ais = self.display.DisplayShape(face, update=False, color=self.unlabeled_color)
            if isinstance(ais, list):
                ais = ais[0]
            self._apply_face_material(ais)
            try:
                self.display.Context.SetDisplayMode(ais, 1, False)  # 1 = shaded
            except Exception:
                pass
            try:
                self.display.Context.Redisplay(ais, False)
            except Exception:
                pass
            self.face_items.append(FaceItem(face=face, ais=ais))
            self.labels.append(None)
            explorer.Next()

        if not self.face_items:
            QtWidgets.QMessageBox.warning(
                self, "Load STEP", "No faces found in STEP file."
            )
            self.display.Repaint()
            return
        self.display.FitAll()
        self.set_current_index(0)
        self.display.Repaint()

    def save_seg(self, path: str) -> bool:
        """Write one label index per line to *path*, in face order.

        Returns:
            True if the file was written, False if an I/O error occurred.
            Errors are shown in a dialog instead of being raised, because an
            exception escaping a Qt slot aborts PyQt5 >= 5.5 and would discard
            the unsaved labels.
        """
        try:
            with open(path, "w", encoding="utf-8") as handle:
                handle.writelines(f"{label}\n" for label in self.labels)
        except OSError as exc:
            QtWidgets.QMessageBox.critical(self, "Export", f"Failed to save {path}: {exc}")
            return False
        QtWidgets.QMessageBox.information(self, "Export", f"Saved: {path}")
        return True

    def assign_label(self, label_index: int) -> None:
        """Give the current face class *label_index* and refresh its colour and info."""
        if self.current_index is None:
            return
        self.labels[self.current_index] = label_index
        self.apply_current_highlight(self.current_index)
        self.update_info()

    def update_info(self) -> None:
        """Show the current face's position and label in the side panel."""
        if self.current_index is None:
            self.info_label.setText("No STEP loaded")
            return
        label = self.labels[self.current_index]
        label_text = "Unlabeled" if label is None else f"{label}: {self.class_names[label]}"
        self.info_label.setText(
            f"Face {self.current_index + 1}/{len(self.face_items)}\n"
            f"Label: {label_text}"
        )

    def get_base_color(self, index: int) -> Quantity_Color:
        """Return the class colour of face *index*, or the unlabeled grey."""
        label = self.labels[index]
        return self.unlabeled_color if label is None else self.class_colors[label]

    def get_highlight_color(self, index: int) -> Quantity_Color:
        """Return the current-face highlight colour (the same for every *index*)."""
        return self.highlight_color

    def set_current_index(self, index: int) -> None:
        """Make face *index* (clamped to the valid range) the current face."""
        if not self.face_items:
            return
        index = max(0, min(index, len(self.face_items) - 1))
        if self.current_index is not None:
            # Restore the previous face; the repaint done by the highlight
            # below covers both faces.
            self.set_face_color(
                self.current_index,
                self.get_base_color(self.current_index),
                repaint=False,
            )
        self.current_index = index
        self.apply_current_highlight(self.current_index)
        self.update_info()

    def apply_current_highlight(self, index: int) -> None:
        """Colour face *index* with the highlight, or its base colour if highlighting is off."""
        if self.highlight_enabled:
            self.set_face_color(index, self.get_highlight_color(index))
        else:
            self.set_face_color(index, self.get_base_color(index))

    def set_face_color(self, index: int, color: Quantity_Color, repaint: bool = True) -> None:
        """Recolour face *index* and optionally repaint the view.

        Args:
            index: Position in ``face_items``.
            color: New face colour.
            repaint: Set False to batch several recolourings into one repaint.
        """
        ais = self.face_items[index].ais
        # load_step already unwraps list results; the list branch is defensive.
        if isinstance(ais, list):
            for item in ais:
                self._set_ais_color(item, color)
        else:
            self._set_ais_color(ais, color)
        if repaint:
            self.display.Repaint()

    def _set_ais_color(self, ais, color: Quantity_Color) -> None:
        """Apply *color* to one AIS object and re-apply material and edge styling."""
        try:
            ais.SetColor(color)
        except Exception:
            self.display.Context.SetColor(ais, color, False)
        self._apply_face_material(ais)
        self.display.Context.Redisplay(ais, False)

    def _apply_face_material(self, ais) -> None:
        """Use NEON material (falling back to MATTE) and draw face boundaries."""
        # Prefer emissive material to keep colors stable regardless of lighting.
        applied = False
        if Graphic3d_NOM_NEON is not None:
            try:
                ais.SetMaterial(Graphic3d_NOM_NEON)
                applied = True
            except Exception:
                pass
        if not applied and Graphic3d_NOM_MATTE is not None:
            try:
                ais.SetMaterial(Graphic3d_NOM_MATTE)
            except Exception:
                pass
        self._apply_face_edges(ais)

    def _configure_viewer_visuals(self) -> None:
        """Switch the view to unlit shading when this OCC build supports it."""
        # Try unlit shading to avoid view-dependent brightening.
        if Graphic3d_TOSM_UNLIT is None:
            return
        try:
            self.display.View.SetShadingModel(Graphic3d_TOSM_UNLIT)
        except Exception:
            pass

    def _apply_face_edges(self, ais) -> None:
        """Draw the face boundary of *ais* as a solid 1.0-width line in EDGE_COLOR_HEX."""
        if Prs3d_LineAspect is None or Aspect_TOL_SOLID is None:
            return
        try:
            drawer = ais.Attributes()
            drawer.SetFaceBoundaryDraw(True)
            line_aspect = Prs3d_LineAspect(self.edge_color, Aspect_TOL_SOLID, 1.0)
            drawer.SetFaceBoundaryAspect(line_aspect)
        except Exception:
            pass  # optional styling; the face is still usable without it
def main() -> int:
    """Parse the command line, open the labeler window and run the Qt event loop.

    Returns:
        The exit code of the Qt application.
    """
    parser = argparse.ArgumentParser(description="BRepMFR Face Labeler")
    parser.add_argument("step_path", nargs="?", help="Optional STEP file to open")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory for .seg exports (auto-save with input filename)",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default=None,
        help="Alias of --output_dir",
    )
    args = parser.parse_args()
    if args.output_dir and args.output_folder and args.output_dir != args.output_folder:
        # parser.error prints usage and raises SystemExit, like other argparse errors.
        parser.error("--output_dir and --output_folder must match when both are provided")

    step_path = args.step_path or None
    if step_path and not os.path.exists(step_path):
        # Still start the GUI (the user can use "Import STEP"), but explain why
        # the requested file was not opened instead of ignoring it silently.
        print(f"warning: STEP file not found, ignoring: {step_path}", file=sys.stderr)
        step_path = None

    app = QtWidgets.QApplication(sys.argv)
    output_dir = args.output_folder or args.output_dir
    window = FaceLabeler(step_path=step_path, output_dir=output_dir)
    window.show()
    return app.exec_()
if __name__ == "__main__":
raise SystemExit(main())