| """Convert SHARP PyTorch model to Core ML .mlmodel format. |
| |
| This script converts the SHARP (Sharp Monocular View Synthesis) model |
| from PyTorch (.pt) to Core ML (.mlmodel) format for deployment on Apple devices. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import coremltools as ct |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
|
|
| |
| from sharp.models import PredictorParams, create_predictor |
| from sharp.models.predictor import RGBGaussianPredictor |
| from sharp.utils import io |
|
|
# Module-level logger; configured by the script entry point.
LOGGER = logging.getLogger(__name__)


# Default pretrained SHARP checkpoint, downloaded when no local path is given.
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"


# Canonical names of the five Core ML output tensors, in the same order as
# the tuple returned by SharpModelTraceable.forward.
OUTPUT_NAMES = [
    "mean_vectors_3d_positions",
    "singular_values_scales",
    "quaternions_rotations",
    "colors_rgb_linear",
    "opacities_alpha_channel",
]


# Human-readable descriptions attached to each output in the Core ML spec.
OUTPUT_DESCRIPTIONS = {
    "mean_vectors_3d_positions": (
        "3D positions of Gaussian splats in normalized device coordinates (NDC). "
        "Shape: (1, N, 3), where N is the number of Gaussians."
    ),
    "singular_values_scales": (
        "Scale factors for each Gaussian along its principal axes. "
        "Represents size and anisotropy. Shape: (1, N, 3)."
    ),
    "quaternions_rotations": (
        "Rotation of each Gaussian as a unit quaternion [w, x, y, z]. "
        "Used to orient the ellipsoid. Shape: (1, N, 4)."
    ),
    "colors_rgb_linear": (
        "RGB color values in linear RGB space (not gamma-corrected). "
        "Shape: (1, N, 3), with range [0, 1]."
    ),
    "opacities_alpha_channel": (
        "Opacity value per Gaussian (alpha channel), used for blending. "
        "Shape: (1, N), where values are in [0, 1]."
    ),
}
|
|
|
|
@dataclass
class ToleranceConfig:
    """Tolerance configuration for validation.

    Each field maps an output name (or quaternion-angle statistic key) to
    its maximum allowed difference. Any field left as None is filled with
    defaults in __post_init__, so callers may override just a subset.
    """

    # Max absolute per-element difference per output, random-noise input.
    random_tolerances: dict[str, float] | None = None

    # Max absolute per-element difference per output, real-image input
    # (looser than random: real images exercise larger dynamic ranges).
    image_tolerances: dict[str, float] | None = None

    # Quaternion angular tolerances (degrees) keyed by statistic name
    # ('mean', 'p99', 'p99_9', 'max'), for random / image inputs.
    angular_tolerances_random: dict[str, float] | None = None
    angular_tolerances_image: dict[str, float] | None = None

    def __post_init__(self) -> None:
        """Fill any unset tolerance dicts with their defaults."""
        if self.random_tolerances is None:
            self.random_tolerances = {
                "mean_vectors_3d_positions": 0.001,
                "singular_values_scales": 0.0001,
                "quaternions_rotations": 2.0,
                "colors_rgb_linear": 0.002,
                "opacities_alpha_channel": 0.005,
            }

        if self.image_tolerances is None:
            self.image_tolerances = {
                "mean_vectors_3d_positions": 3.5,
                "singular_values_scales": 0.035,
                "quaternions_rotations": 5.0,
                "colors_rgb_linear": 0.01,
                "opacities_alpha_channel": 0.05,
            }

        if self.angular_tolerances_random is None:
            self.angular_tolerances_random = {
                "mean": 0.01,
                "p99": 0.1,
                "p99_9": 1.0,
                "max": 5.0,
            }

        if self.angular_tolerances_image is None:
            self.angular_tolerances_image = {
                "mean": 0.2,
                "p99": 2.0,
                "p99_9": 5.0,
                "max": 25.0,
            }
|
|
|
|
class SharpModelTraceable(nn.Module):
    """Fully traceable version of SHARP for Core ML conversion.

    This version removes all dynamic control flow and makes the model
    fully traceable with torch.jit.trace.
    """

    def __init__(self, predictor: RGBGaussianPredictor):
        """Initialize the traceable wrapper.

        Args:
            predictor: The SHARP RGBGaussianPredictor model.
        """
        super().__init__()

        # Re-register the predictor's submodules directly on the wrapper so
        # tracing sees a flat module, bypassing the predictor's own forward.
        self.init_model = predictor.init_model
        self.feature_model = predictor.feature_model
        self.monodepth_model = predictor.monodepth_model
        self.prediction_head = predictor.prediction_head
        self.gaussian_composer = predictor.gaussian_composer
        self.depth_alignment = predictor.depth_alignment

        # Debug bookkeeping, populated only in eager mode (never while
        # scripting/tracing); read by callers for diagnostics logging.
        self.last_global_scale = None
        self.last_monodepth_min = None

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference with traceable forward pass.

        Args:
            image: Input image tensor of shape (1, 3, H, W) in range [0, 1].
            disparity_factor: Disparity factor tensor of shape (1,).

        Returns:
            Tuple of 5 tensors representing 3D Gaussians.
        """
        # 1) Monocular depth: predict per-pixel disparity from the image.
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Broadcast the per-batch scalar to (B, 1, 1, 1) for division below.
        disparity_factor_expanded = disparity_factor[:, None, None, None]

        # Depth = factor / disparity; the clamp guards against division by
        # (near-)zero and caps extreme disparity values.
        disparity_clamped = monodepth_disparity.clamp(min=1e-4, max=1e4)
        monodepth = disparity_factor_expanded / disparity_clamped

        # 2) Align depth using decoder features. The second argument is None
        # (the scale-map estimator is removed before conversion).
        monodepth, _ = self.depth_alignment(monodepth, None, monodepth_output.decoder_features)

        # Eager-mode-only debug capture; .item() would break tracing.
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            self.last_monodepth_min = monodepth.flatten().min().item()

        # 3) Initialize Gaussian base values from image + depth.
        init_output = self.init_model(image, monodepth)

        # Eager-mode-only debug capture of the optional global scale.
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if init_output.global_scale is not None:
                self.last_global_scale = init_output.global_scale.item()

        # 4) Extract image features conditioned on monodepth encodings.
        image_features = self.feature_model(
            init_output.feature_input,
            encodings=monodepth_output.output_features
        )

        # 5) Predict per-Gaussian delta values from the features.
        delta_values = self.prediction_head(image_features)

        # 6) Compose final Gaussians from base values + predicted deltas.
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )

        quaternions = gaussians.quaternions

        # Normalize quaternions to unit length; clamp guards the sqrt.
        quat_norm_sq = torch.sum(quaternions * quaternions, dim=-1, keepdim=True)
        quat_norm = torch.sqrt(torch.clamp(quat_norm_sq, min=1e-12))
        quaternions_normalized = quaternions / quat_norm

        # Canonicalize sign: q and -q encode the same rotation, so flip each
        # quaternion so its largest-magnitude component is positive. argmax +
        # one-hot scatter is used (instead of boolean indexing) to remain
        # traceable with static shapes.
        abs_quat = torch.abs(quaternions_normalized)
        max_idx = torch.argmax(abs_quat, dim=-1, keepdim=True)

        one_hot = torch.zeros_like(quaternions_normalized)
        one_hot.scatter_(-1, max_idx, 1.0)

        # Dot with the one-hot mask extracts the signed value at argmax.
        max_component_sign = torch.sum(quaternions_normalized * one_hot, dim=-1, keepdim=True)

        # Flip whole quaternions whose dominant component is negative.
        quaternions = torch.where(max_component_sign < 0, -quaternions_normalized, quaternions_normalized).float()

        return (
            gaussians.mean_vectors,
            gaussians.singular_values,
            quaternions,
            gaussians.colors,
            gaussians.opacities,
        )
|
|
|
|
def load_sharp_model(checkpoint_path: Path | None = None) -> RGBGaussianPredictor:
    """Load SHARP model from checkpoint.

    Args:
        checkpoint_path: Path to the .pt checkpoint file.
            If None, downloads the default model.

    Returns:
        The loaded RGBGaussianPredictor model in eval mode.
    """
    # Resolve weights either from disk or from the published default URL.
    if checkpoint_path is not None:
        LOGGER.info("Loading checkpoint from %s", checkpoint_path)
        weights = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
    else:
        LOGGER.info("Downloading default model from %s", DEFAULT_MODEL_URL)
        weights = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)

    # Build the predictor with default params and restore the weights.
    model = create_predictor(PredictorParams())
    model.load_state_dict(weights)
    model.eval()
    return model
|
|
|
|
def convert_to_coreml(
    predictor: RGBGaussianPredictor,
    output_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
    compute_precision: ct.precision = ct.precision.FLOAT16,
    compute_units: ct.ComputeUnit = ct.ComputeUnit.ALL,
    minimum_deployment_target: ct.target | None = None,
) -> ct.models.MLModel:
    """Convert SHARP model to Core ML format.

    Args:
        predictor: The SHARP RGBGaussianPredictor model.
        output_path: Path to save the .mlmodel file.
        input_shape: Input image shape (height, width). Default is (1536, 1536).
        compute_precision: Precision for compute (FLOAT16 or FLOAT32).
        compute_units: Target compute units (ALL, CPU_AND_GPU, CPU_ONLY, etc.).
        minimum_deployment_target: Minimum iOS/macOS deployment target.

    Returns:
        The converted Core ML model.
    """
    LOGGER.info("Preparing model for Core ML conversion...")

    # Drop the scale-map estimator before wrapping: SharpModelTraceable
    # passes None for it, and removing it keeps the graph traceable.
    predictor.depth_alignment.scale_map_estimator = None

    model_wrapper = SharpModelTraceable(predictor)
    model_wrapper.eval()

    # A few warm-up passes stabilize any lazy initialization before tracing.
    LOGGER.info("Pre-warming model for better tracing...")
    with torch.no_grad():
        for _ in range(3):
            warm_image = torch.randn(1, 3, input_shape[0], input_shape[1])
            warm_disparity = torch.tensor([1.0])
            _ = model_wrapper(warm_image, warm_disparity)

    height, width = input_shape
    # Deterministic example inputs so tracing is reproducible.
    torch.manual_seed(42)
    example_image = torch.randn(1, 3, height, width)
    example_disparity_factor = torch.tensor([1.0])

    # Prefer scripting (preserves control flow); fall back to tracing.
    LOGGER.info("Attempting torch.jit.script for better tracing...")
    try:
        with torch.no_grad():
            traced_model = torch.jit.script(model_wrapper)
        LOGGER.info("torch.jit.script succeeded, using scripted model")
    except Exception as e:
        LOGGER.warning("torch.jit.script failed: %s", e)
        LOGGER.info("Falling back to torch.jit.trace...")
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model_wrapper,
                (example_image, example_disparity_factor),
                strict=False,
                check_trace=False,
            )

    LOGGER.info("Converting traced model to Core ML...")

    inputs = [
        ct.TensorType(
            name="image",
            shape=(1, 3, height, width),
            dtype=np.float32,
        ),
        ct.TensorType(
            name="disparity_factor",
            shape=(1,),
            dtype=np.float32,
        ),
    ]

    # Reuse the module-level canonical output names rather than re-declaring
    # them here, so the converter and validators always agree.
    outputs = [ct.TensorType(name=name, dtype=np.float32) for name in OUTPUT_NAMES]

    conversion_kwargs: dict[str, Any] = {
        "inputs": inputs,
        "outputs": outputs,
        "convert_to": "mlprogram",
        "compute_precision": compute_precision,
        "compute_units": compute_units,
    }

    if minimum_deployment_target is not None:
        conversion_kwargs["minimum_deployment_target"] = minimum_deployment_target

    mlmodel = ct.convert(
        traced_model,
        **conversion_kwargs,
    )

    # Model-card metadata shown in Xcode / coremltools.
    mlmodel.author = "Apple Inc."
    mlmodel.license = "See LICENSE_MODEL in ml-sharp repository"
    mlmodel.short_description = (
        "SHARP: Sharp Monocular View Synthesis - Predicts 3D Gaussian splats from a single image"
    )
    mlmodel.version = "1.0.0"

    # Rename outputs in the spec (conversion may assign generated names) and
    # attach the shared human-readable descriptions.
    spec = mlmodel.get_spec()
    for i, name in enumerate(OUTPUT_NAMES):
        if i < len(spec.description.output):
            output = spec.description.output[i]
            output.name = name
            output.shortDescription = OUTPUT_DESCRIPTIONS[name]

    LOGGER.info("Output names after update: %s", [o.name for o in spec.description.output])

    LOGGER.info("Saving Core ML model to %s", output_path)
    mlmodel.save(str(output_path))

    return mlmodel
|
|
|
|
| class QuaternionValidator: |
| """Validator for quaternion comparisons with configurable tolerances and outlier analysis.""" |
|
|
| DEFAULT_ANGULAR_TOLERANCES = { |
| "mean": 0.01, |
| "p99": 0.5, |
| "p99_9": 2.0, |
| "max": 15.0, |
| } |
|
|
| def __init__( |
| self, |
| angular_tolerances: dict[str, float] | None = None, |
| enable_outlier_analysis: bool = True, |
| outlier_thresholds: list[float] | None = None, |
| ): |
| """Initialize validator with tolerances. |
| |
| Args: |
| angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees. |
| enable_outlier_analysis: Whether to perform detailed outlier analysis. |
| outlier_thresholds: List of angle thresholds for outlier counting. |
| """ |
| self.angular_tolerances = angular_tolerances or self.DEFAULT_ANGULAR_TOLERANCES.copy() |
| self.enable_outlier_analysis = enable_outlier_analysis |
| self.outlier_thresholds = outlier_thresholds or [5.0, 10.0, 15.0] |
|
|
| @staticmethod |
| def canonicalize_quaternion(q: np.ndarray) -> np.ndarray: |
| """Canonicalize quaternion to ensure consistent representation. |
| |
| Ensures the quaternion with the largest absolute component is positive. |
| This handles the sign ambiguity where q and -q represent the same rotation. |
| |
| Args: |
| q: Quaternion array of shape (..., 4) |
| |
| Returns: |
| Canonicalized quaternion array. |
| """ |
| abs_q = np.abs(q) |
| max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True) |
| selector = np.zeros_like(q) |
| np.put_along_axis(selector, max_component_idx, 1.0, axis=-1) |
| max_component_sign = np.sum(q * selector, axis=-1, keepdims=True) |
| return np.where(max_component_sign < 0, -q, q) |
|
|
| @staticmethod |
| def compute_angular_differences( |
| quats1: np.ndarray, quats2: np.ndarray |
| ) -> tuple[np.ndarray, dict[str, float]]: |
| """Compute angular differences between two sets of quaternions. |
| |
| Args: |
| quats1: First set of quaternions shape (N, 4) |
| quats2: Second set of quaternions shape (N, 4) |
| |
| Returns: |
| Tuple of (angular_differences in degrees, statistics dict) |
| """ |
| |
| norm1 = np.linalg.norm(quats1, axis=-1, keepdims=True) |
| norm2 = np.linalg.norm(quats2, axis=-1, keepdims=True) |
| quats1_norm = quats1 / np.clip(norm1, 1e-12, None) |
| quats2_norm = quats2 / np.clip(norm2, 1e-12, None) |
|
|
| |
| quats1_canon = QuaternionValidator.canonicalize_quaternion(quats1_norm) |
| quats2_canon = QuaternionValidator.canonicalize_quaternion(quats2_norm) |
|
|
| |
| dot_products = np.sum(quats1_canon * quats2_canon, axis=-1) |
| dot_products_flipped = np.sum(quats1_canon * (-quats2_canon), axis=-1) |
|
|
| |
| dot_products = np.maximum(np.abs(dot_products), np.abs(dot_products_flipped)) |
| dot_products = np.clip(dot_products, 0.0, 1.0) |
|
|
| |
| angular_diff_rad = 2.0 * np.arccos(dot_products) |
| angular_diff_deg = np.degrees(angular_diff_rad) |
|
|
| |
| stats = { |
| "mean": float(np.mean(angular_diff_deg)), |
| "std": float(np.std(angular_diff_deg)), |
| "min": float(np.min(angular_diff_deg)), |
| "max": float(np.max(angular_diff_deg)), |
| "p50": float(np.percentile(angular_diff_deg, 50)), |
| "p90": float(np.percentile(angular_diff_deg, 90)), |
| "p99": float(np.percentile(angular_diff_deg, 99)), |
| "p99_9": float(np.percentile(angular_diff_deg, 99.9)), |
| } |
|
|
| return angular_diff_deg, stats |
|
|
| def analyze_outliers( |
| self, angular_diff_deg: np.ndarray |
| ) -> dict[str, dict[str, int | float]]: |
| """Analyze outliers in angular differences. |
| |
| Args: |
| angular_diff_deg: Array of angular differences in degrees. |
| |
| Returns: |
| Dict with outlier statistics for each threshold. |
| """ |
| if not self.enable_outlier_analysis: |
| return {} |
|
|
| outlier_stats = {} |
| total = len(angular_diff_deg) |
|
|
| for threshold in self.outlier_thresholds: |
| count = int(np.sum(angular_diff_deg > threshold)) |
| outlier_stats[f">{threshold}°"] = { |
| "count": count, |
| "percentage": (count / total) * 100.0 if total > 0 else 0.0, |
| } |
|
|
| return outlier_stats |
|
|
| def validate( |
| self, |
| pt_quaternions: np.ndarray, |
| coreml_quaternions: np.ndarray, |
| image_name: str = "Unknown", |
| ) -> dict: |
| """Validate Core ML quaternions against PyTorch quaternions. |
| |
| Args: |
| pt_quaternions: PyTorch quaternion outputs. |
| coreml_quaternions: Core ML quaternion outputs. |
| image_name: Name of the image being validated. |
| |
| Returns: |
| Dict with validation results including status, stats, and outliers. |
| """ |
| angular_diff_deg, stats = self.compute_angular_differences( |
| pt_quaternions, coreml_quaternions |
| ) |
| outlier_stats = self.analyze_outliers(angular_diff_deg) |
|
|
| |
| passed = True |
| failure_reasons = [] |
|
|
| for key, tolerance in self.angular_tolerances.items(): |
| if key in stats and stats[key] > tolerance: |
| passed = False |
| failure_reasons.append( |
| f"{key} angular {stats[key]:.4f}° > tolerance {tolerance:.4f}°" |
| ) |
|
|
| return { |
| "image": image_name, |
| "passed": passed, |
| "failure_reasons": failure_reasons, |
| "stats": stats, |
| "outliers": outlier_stats, |
| "num_gaussians": len(angular_diff_deg), |
| } |
|
|
|
|
def find_coreml_output_key(name: str, coreml_outputs: dict) -> str:
    """Find matching Core ML output key for a given output name.

    Resolution order: exact name match, then a case-insensitive substring
    match on the leading word of `name`, then positional fallback using the
    index of `name` in OUTPUT_NAMES.

    Args:
        name: The expected output name
        coreml_outputs: Dictionary of Core ML outputs

    Returns:
        The matching key from coreml_outputs
    """
    # 1) Exact match.
    if name in coreml_outputs:
        return name

    # 2) Substring match on the leading word (hoisted out of the loop --
    # it does not depend on the key being examined).
    base_name = name.split("_")[0]
    for key in coreml_outputs:
        if base_name in key.lower():
            return key

    # 3) Positional fallback: assume outputs appear in canonical order.
    output_index = OUTPUT_NAMES.index(name) if name in OUTPUT_NAMES else 0
    return list(coreml_outputs.keys())[output_index]
|
|
|
|
def run_inference_pair(
    pytorch_model: RGBGaussianPredictor,
    mlmodel: ct.models.MLModel,
    image_tensor: torch.Tensor,
    disparity_factor: float = 1.0,
    log_internals: bool = False,
) -> tuple[list[np.ndarray], dict[str, np.ndarray]]:
    """Run inference on both PyTorch and Core ML models.

    Args:
        pytorch_model: The PyTorch model
        mlmodel: The Core ML model
        image_tensor: Input image tensor
        disparity_factor: Disparity factor value
        log_internals: Whether to log internal values for debugging

    Returns:
        Tuple of (pytorch_outputs, coreml_outputs)
    """
    # PyTorch side: wrap in the traceable module and run in eval mode.
    wrapper = SharpModelTraceable(pytorch_model)
    wrapper.eval()

    image_tensor = image_tensor.float()
    disparity_tensor = torch.tensor([disparity_factor], dtype=torch.float32)

    with torch.no_grad():
        torch_results = wrapper(image_tensor, disparity_tensor)

    # Optionally surface eager-mode internals captured by the wrapper.
    if log_internals:
        scale = getattr(wrapper, "last_global_scale", None)
        if scale is not None:
            LOGGER.info(f"PyTorch global_scale: {scale:.6f}")
        depth_min = getattr(wrapper, "last_monodepth_min", None)
        if depth_min is not None:
            LOGGER.info(f"PyTorch monodepth_min: {depth_min:.6f}")

    torch_results_np = [tensor.numpy() for tensor in torch_results]

    # Core ML side: feed the identical input as numpy arrays.
    coreml_results = mlmodel.predict(
        {
            "image": image_tensor.numpy(),
            "disparity_factor": np.array([disparity_factor], dtype=np.float32),
        }
    )

    return torch_results_np, coreml_results
|
|
|
|
def compare_outputs(
    pt_outputs: list[np.ndarray],
    coreml_outputs: dict[str, np.ndarray],
    tolerances: dict[str, float],
    quat_validator: QuaternionValidator,
    image_name: str = "Unknown",
) -> list[dict]:
    """Compare PyTorch and Core ML outputs.

    Args:
        pt_outputs: List of PyTorch outputs
        coreml_outputs: Dictionary of Core ML outputs
        tolerances: Tolerance values per output type
        quat_validator: QuaternionValidator instance
        image_name: Name of the image being validated

    Returns:
        List of validation result dictionaries
    """
    results = []

    for index, output_name in enumerate(OUTPUT_NAMES):
        reference = pt_outputs[index]
        candidate = coreml_outputs[find_coreml_output_key(output_name, coreml_outputs)]

        entry = {"output": output_name, "passed": True, "failure_reason": ""}

        if output_name == "quaternions_rotations":
            # Quaternions are compared by rotation angle (sign-invariant).
            report = quat_validator.validate(reference, candidate, image_name=image_name)
            stats = report["stats"]
            entry["max_diff"] = f"{stats['max']:.6f}"
            entry["mean_diff"] = f"{stats['mean']:.6f}"
            entry["p99_diff"] = f"{stats['p99']:.6f}"
            entry["passed"] = report["passed"]
            entry["failure_reason"] = (
                "; ".join(report["failure_reasons"]) if report["failure_reasons"] else ""
            )
        else:
            # All other outputs: element-wise absolute difference.
            abs_diff = np.abs(reference - candidate)
            limit = tolerances.get(output_name, 0.01)
            worst = np.max(abs_diff)

            entry["max_diff"] = f"{worst:.6f}"
            entry["mean_diff"] = f"{np.mean(abs_diff):.6f}"
            entry["p99_diff"] = f"{np.percentile(abs_diff, 99):.6f}"

            if worst > limit:
                entry["passed"] = False
                entry["failure_reason"] = f"max diff {worst:.6f} > tolerance {limit:.6f}"

        results.append(entry)

    return results
|
|
|
|
def format_validation_table(
    validation_results: list[dict],
    image_name: str,
    include_image_column: bool = False,
) -> str:
    """Format validation results as a markdown table.

    Args:
        validation_results: List of validation result dicts with keys:
            output, max_diff, mean_diff, p99_diff, passed, etc.
        image_name: Name of the image being validated.
        include_image_column: Whether to include the image name as a column.

    Returns:
        Formatted markdown table as a string.
    """
    # Only the header and the optional leading image cell differ between the
    # two layouts, so a single row-building loop serves both.
    if include_image_column:
        lines = [
            "| Image | Output | Max Diff | Mean Diff | P99 Diff | Status |",
            "|-------|--------|----------|-----------|----------|--------|",
        ]
    else:
        lines = [
            "| Output | Max Diff | Mean Diff | P99 Diff | Status |",
            "|--------|----------|-----------|----------|--------|",
        ]

    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        cells = [output_name, result["max_diff"], result["mean_diff"], result["p99_diff"], status]
        if include_image_column:
            cells.insert(0, image_name)
        lines.append("| " + " | ".join(cells) + " |")

    return "\n".join(lines)
|
|
|
|
def validate_coreml_model(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    input_shape: tuple[int, int] = (1536, 1536),
    tolerance: float = 0.01,
    angular_tolerances: dict[str, float] | None = None,
) -> bool:
    """Validate Core ML model outputs against PyTorch model.

    Runs both models on the same seeded random input and compares each of
    the five outputs: quaternions by rotation angle (sign-invariant), all
    others by element-wise absolute difference.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        input_shape: Input image shape (height, width).
        tolerance: Maximum allowed difference between outputs.
        angular_tolerances: Dict with keys 'mean', 'p99', 'p99_9', 'max' for angular diffs in degrees.

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("Validating Core ML model against PyTorch...")

    height, width = input_shape

    # Deterministic random input so failures are reproducible.
    np.random.seed(42)
    torch.manual_seed(42)

    test_image_np = np.random.rand(1, 3, height, width).astype(np.float32)
    test_disparity = np.array([1.0], dtype=np.float32)

    test_image_pt = torch.from_numpy(test_image_np)
    test_disparity_pt = torch.from_numpy(test_disparity)

    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image_pt, test_disparity_pt)

    coreml_inputs = {
        "image": test_image_np,
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)

    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")
    LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")

    # Per-output absolute tolerances for random input (kept in sync with
    # ToleranceConfig.random_tolerances).
    tolerances = {
        "mean_vectors_3d_positions": 0.001,
        "singular_values_scales": 0.0001,
        "quaternions_rotations": 2.0,
        "colors_rgb_linear": 0.002,
        "opacities_alpha_channel": 0.005,
    }

    if angular_tolerances is None:
        angular_tolerances = {
            "mean": 0.01,
            "p99": 0.1,
            "p99_9": 1.0,
            "max": 5.0,
        }

    quat_validator = QuaternionValidator(angular_tolerances=angular_tolerances)

    all_passed = True

    # Depth (Z) is the most conversion-sensitive quantity, so log its
    # statistics up front for a quick diagnostic signal.
    LOGGER.info("=== Depth/Position Statistics ===")
    pt_positions = pt_outputs[0].numpy()
    # Use the shared key-resolution helper rather than an ad-hoc list
    # comprehension (which raised IndexError when no key matched).
    coreml_key = find_coreml_output_key("mean_vectors_3d_positions", coreml_outputs)
    coreml_positions = coreml_outputs[coreml_key]

    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")

    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================")

    validation_results = []

    # Compare each output against its canonical name (module constant, not a
    # local duplicate) using the shared key-resolution helper.
    for i, name in enumerate(OUTPUT_NAMES):
        pt_output = pt_outputs[i].numpy()
        coreml_output = coreml_outputs[find_coreml_output_key(name, coreml_outputs)]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Angular comparison handles the q/-q sign ambiguity.
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name="Random")

            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "p99_9_diff": f"{quat_result['stats']['p99_9']:.6f}",
                "max_angular": f"{quat_result['stats']['max']:.4f}",
                "mean_angular": f"{quat_result['stats']['mean']:.4f}",
                "p99_angular": f"{quat_result['stats']['p99']:.4f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
                "quat_stats": quat_result["stats"],
                "outliers": quat_result["outliers"],
            })
            if not quat_result["passed"]:
                all_passed = False
        else:
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, tolerance)
            max_diff = np.max(diff)  # hoisted: used three times below
            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}"
            })
            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Emit the per-output summary as a markdown table in the log.
    LOGGER.info("\n### Validation Results\n")
    LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | P99.9 Diff | Angular Diff (°) | Status |")
    LOGGER.info("|--------|----------|-----------|----------|------------|------------------|--------|")

    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        if "max_angular" in result:
            angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
            p99_9 = result.get("p99_9_diff", "-")
            LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {p99_9} | {angular_info} | {status} |")
        else:
            LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | - | - | {status} |")
    LOGGER.info("")

    # Detailed outlier breakdown for the quaternion output, when present.
    for result in validation_results:
        if "outliers" in result and result["outliers"]:
            LOGGER.info("### Quaternion Outlier Analysis\n")
            LOGGER.info("| Threshold | Count | Percentage |")
            LOGGER.info("|-----------|-------|------------|")
            for threshold, data in result["outliers"].items():
                LOGGER.info(f"| {threshold} | {data['count']} | {data['percentage']:.4f}% |")
            LOGGER.info("")

    return all_passed
|
|
|
|
def load_and_preprocess_image(
    image_path: Path,
    target_size: tuple[int, int] = (1536, 1536),
) -> tuple[torch.Tensor, float, tuple[int, int]]:
    """Load and preprocess an input image for SHARP inference.

    Args:
        image_path: Path to the input image file.
        target_size: Target (height, width) for resizing.

    Returns:
        Tuple of (preprocessed image tensor, focal_length_px, original_size)
        - Preprocessed image tensor of shape (1, 3, H, W) in range [0, 1]
        - Focal length in pixels (from EXIF or default)
        - Original image size (width, height)
    """
    LOGGER.info(f"Loading image from {image_path}")

    rgb_array, original_size, f_px = io.load_rgb(image_path)
    LOGGER.info(f"Original image size: {original_size}, focal length: {f_px:.2f}px")

    # HWC uint8 -> CHW float in [0, 1].
    tensor = torch.from_numpy(rgb_array).float() / 255.0
    tensor = tensor.permute(2, 0, 1)
    source_height, source_width = rgb_array.shape[:2]

    # Resize only when the source dimensions differ from the target.
    target_height, target_width = target_size
    if (source_width, source_height) != (target_width, target_height):
        LOGGER.info(f"Resizing to {target_size[1]}x{target_size[0]}")
        import torch.nn.functional as F
        tensor = F.interpolate(
            tensor.unsqueeze(0),
            size=(target_height, target_width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)

    # Add the batch dimension expected by the model.
    tensor = tensor.unsqueeze(0)

    LOGGER.info(f"Preprocessed image shape: {tensor.shape}, range: [{tensor.min():.4f}, {tensor.max():.4f}]")

    return tensor, f_px, (source_width, source_height)
|
|
|
|
def validate_with_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Validate Core ML model outputs against PyTorch model using a real input image.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape (height, width).

    Returns:
        True if validation passes, False otherwise.
    """
    LOGGER.info("=" * 60)
    LOGGER.info("Validating Core ML model against PyTorch with real image")
    LOGGER.info("=" * 60)

    # BUG FIX: load_and_preprocess_image returns (tensor, f_px, original_size).
    # The original code bound the whole tuple to `test_image`, which broke the
    # forward pass and `test_image.numpy()` below. Unpack it properly.
    test_image, _f_px, _original_size = load_and_preprocess_image(image_path, input_shape)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run the PyTorch reference through the same traceable wrapper used for
    # conversion, so both backends execute an identical graph.
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))

    LOGGER.info(f"PyTorch outputs shapes: {[o.shape for o in pt_outputs]}")

    # Run Core ML inference on the identical inputs.
    test_image_np = test_image.numpy()
    coreml_inputs = {
        "image": test_image_np,
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)

    LOGGER.info(f"Core ML outputs keys: {list(coreml_outputs.keys())}")

    # Consistency: reuse the module-level output ordering instead of a
    # duplicated hardcoded list.
    output_names = OUTPUT_NAMES

    # Per-output absolute tolerances. Positions/rotations tolerate more float
    # drift than colors and opacities.
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "quaternions_rotations": 5.0,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
    }

    # Angular tolerances (degrees) used for the quaternion output.
    angular_tolerances = {
        "mean": 0.1,
        "p99": 1.0,
        "max": 15.0,
    }

    all_passed = True

    LOGGER.info("\n=== Input Image Statistics ===")
    LOGGER.info(f"Image path: {image_path}")
    LOGGER.info(f"Image shape: {test_image.shape}")
    LOGGER.info(f"Image range: [{test_image.min():.4f}, {test_image.max():.4f}]")
    LOGGER.info(f"Image mean: {test_image.mean(dim=[1,2,3]).tolist()}")
    LOGGER.info("=" * 30)

    # Log Z (depth) statistics for the positions output before the per-output
    # comparison loop, to aid debugging of depth-related regressions.
    pt_positions = pt_outputs[0].numpy()
    coreml_key = [k for k in coreml_outputs.keys() if "mean_vectors" in k][0]
    coreml_positions = coreml_outputs[coreml_key]

    LOGGER.info("\n=== Depth/Position Statistics ===")
    LOGGER.info(f"PyTorch positions - Z range: [{pt_positions[..., 2].min():.4f}, {pt_positions[..., 2].max():.4f}], mean: {pt_positions[..., 2].mean():.4f}, std: {pt_positions[..., 2].std():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{coreml_positions[..., 2].min():.4f}, {coreml_positions[..., 2].max():.4f}], mean: {coreml_positions[..., 2].mean():.4f}, std: {coreml_positions[..., 2].std():.4f}")

    z_diff = np.abs(pt_positions[..., 2] - coreml_positions[..., 2])
    LOGGER.info(f"Z-coordinate difference - max: {z_diff.max():.6f}, mean: {z_diff.mean():.6f}, std: {z_diff.std():.6f}")
    LOGGER.info("=================================\n")

    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Resolve the Core ML output key: exact match first, then a fuzzy
        # match on the leading token, finally fall back to positional order.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            base_name = name.split('_')[0]  # hoisted: loop-invariant
            for key in coreml_outputs:
                if base_name in key.lower():
                    coreml_key = key
                    break
            if coreml_key is None:
                coreml_key = list(coreml_outputs.keys())[i]

        coreml_output = coreml_outputs[coreml_key]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Quaternions are compared by rotation angle: normalize, fix the
            # sign ambiguity (q and -q encode the same rotation), then measure
            # the angular difference of the represented rotations.
            pt_quat_norm = np.linalg.norm(pt_output, axis=-1, keepdims=True)
            pt_output_normalized = pt_output / np.clip(pt_quat_norm, 1e-12, None)

            coreml_quat_norm = np.linalg.norm(coreml_output, axis=-1, keepdims=True)
            coreml_output_normalized = coreml_output / np.clip(coreml_quat_norm, 1e-12, None)

            def canonicalize_quaternion(q):
                # Force the largest-magnitude component to be non-negative so
                # q and -q collapse to the same representative.
                abs_q = np.abs(q)
                max_component_idx = np.argmax(abs_q, axis=-1, keepdims=True)
                selector = np.zeros_like(q)
                np.put_along_axis(selector, max_component_idx, 1, axis=-1)
                max_component_sign = np.sum(q * selector, axis=-1, keepdims=True)
                return np.where(max_component_sign < 0, -q, q)

            pt_output_canonical = canonicalize_quaternion(pt_output_normalized)
            coreml_output_canonical = canonicalize_quaternion(coreml_output_normalized)

            diff = np.abs(pt_output_canonical - coreml_output_canonical)
            dot_products = np.sum(pt_output_canonical * coreml_output_canonical, axis=-1)
            dot_products_flipped = np.sum(pt_output_canonical * (-coreml_output_canonical), axis=-1)

            # Take the sign/flip combination with the larger |dot|, i.e. the
            # smaller angle between the two rotations.
            dot_products = np.where(
                np.abs(dot_products) > np.abs(dot_products_flipped),
                np.abs(dot_products),
                np.abs(dot_products_flipped)
            )
            dot_products = np.clip(dot_products, 0.0, 1.0)
            angular_diff_rad = 2 * np.arccos(dot_products)
            angular_diff_deg = np.degrees(angular_diff_rad)
            max_angular = np.max(angular_diff_deg)
            mean_angular = np.mean(angular_diff_deg)
            p99_angular = np.percentile(angular_diff_deg, 99)

            quat_passed = True
            failure_reasons = []

            if mean_angular > angular_tolerances["mean"]:
                quat_passed = False
                failure_reasons.append(f"mean angular {mean_angular:.4f}° > {angular_tolerances['mean']:.4f}°")
            if p99_angular > angular_tolerances["p99"]:
                quat_passed = False
                failure_reasons.append(f"p99 angular {p99_angular:.4f}° > {angular_tolerances['p99']:.4f}°")
            if max_angular > angular_tolerances["max"]:
                quat_passed = False
                failure_reasons.append(f"max angular {max_angular:.4f}° > {angular_tolerances['max']:.4f}°")

            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "max_angular": f"{max_angular:.4f}",
                "mean_angular": f"{mean_angular:.4f}",
                "p99_angular": f"{p99_angular:.4f}",
                "passed": quat_passed,
                "failure_reason": "; ".join(failure_reasons) if failure_reasons else ""
            })
            if not quat_passed:
                all_passed = False
        else:
            # All other outputs use a plain element-wise absolute tolerance.
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            result.update({
                "max_diff": f"{np.max(diff):.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
                "tolerance": f"{output_tolerance:.6f}"
            })
            if np.max(diff) > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {np.max(diff):.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Render a markdown-style summary table into the log.
    LOGGER.info("\n### Image Validation Results\n")
    LOGGER.info("| Output | Max Diff | Mean Diff | P99 Diff | Angular Diff (°) | Status |")
    LOGGER.info("|--------|----------|-----------|----------|------------------|--------|")

    for result in validation_results:
        output_name = result["output"].replace("_", " ").title()
        if "max_angular" in result:
            angular_info = f"{result['max_angular']} / {result['mean_angular']} / {result['p99_angular']}"
        else:
            angular_info = "-"
        status = "✅ PASS" if result["passed"] else "❌ FAIL"
        LOGGER.info(f"| {output_name} | {result['max_diff']} | {result['mean_diff']} | {result['p99_diff']} | {angular_info} | {status} |")
    LOGGER.info("")

    return all_passed
|
|
|
|
def validate_with_image_set(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_paths: list[Path],
    input_shape: tuple[int, int] = (1536, 1536),
) -> bool:
    """Validate Core ML model against PyTorch using multiple input images.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_paths: List of paths to input images for validation.
        input_shape: Expected input image shape (height, width).

    Returns:
        True if all validations pass, False otherwise.
    """
    LOGGER.info("=" * 60)
    LOGGER.info(f"Validating Core ML model with {len(image_paths)} images")
    LOGGER.info("=" * 60)

    # One validator shared across all images so every image is judged
    # against the same angular tolerances (degrees).
    shared_quat_validator = QuaternionValidator(
        angular_tolerances={
            "mean": 0.2,
            "p99": 2.0,
            "p99_9": 5.0,
            "max": 25.0,
        }
    )

    overall_ok = True
    combined_results: list[dict] = []

    for path in image_paths:
        # A missing file counts as a failure but does not abort the run.
        if not path.exists():
            LOGGER.error(f"Input image not found: {path}")
            overall_ok = False
            continue

        LOGGER.info(f"\n--- Validating with {path.name} ---")

        per_image_results = validate_with_single_image_detailed(
            mlmodel, pytorch_model, path, input_shape, shared_quat_validator
        )

        # Tag each result with its source image and collect for the summary.
        for entry in per_image_results:
            entry["image"] = path.name
            combined_results.append(entry)

        overall_ok = overall_ok and all(entry["passed"] for entry in per_image_results)

    LOGGER.info("\n" + "=" * 60)
    LOGGER.info("### Multi-Image Validation Summary")
    LOGGER.info("=" * 60 + "\n")

    if combined_results:
        LOGGER.info(format_validation_table(combined_results, "", include_image_column=True))
        LOGGER.info("")

    return overall_ok
|
|
|
|
def validate_with_single_image_detailed(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> list[dict]:
    """Validate with a single image and return detailed results.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance.

    Returns:
        List of validation result dictionaries.
    """
    image_tensor, focal_px, (width_px, height_px) = load_and_preprocess_image(
        image_path, input_shape
    )

    # Disparity factor derives from the camera focal length in pixels,
    # expressed relative to the original image width.
    disp_factor = focal_px / width_px
    LOGGER.info(f"Using disparity_factor = {disp_factor:.6f} (f_px={focal_px:.2f} / width={width_px})")

    pt_outputs, coreml_outputs = run_inference_pair(
        pytorch_model, mlmodel, image_tensor,
        disparity_factor=disp_factor,
        log_internals=True
    )

    # --- diagnostic logging for the positions output ---
    ref_positions = pt_outputs[0]
    positions_key = find_coreml_output_key("mean_vectors_3d_positions", coreml_outputs)
    ml_positions = coreml_outputs[positions_key]

    LOGGER.info(f"=== Depth/Position Statistics ({image_path.name}) ===")
    LOGGER.info(f"PyTorch positions - Z range: [{ref_positions[..., 2].min():.4f}, {ref_positions[..., 2].max():.4f}], mean: {ref_positions[..., 2].mean():.4f}")
    LOGGER.info(f"CoreML positions - Z range: [{ml_positions[..., 2].min():.4f}, {ml_positions[..., 2].max():.4f}], mean: {ml_positions[..., 2].mean():.4f}")

    abs_pos_err = np.abs(ref_positions - ml_positions)
    LOGGER.info(f"Position difference (X,Y,Z) - max: [{abs_pos_err[..., 0].max():.6f}, {abs_pos_err[..., 1].max():.6f}, {abs_pos_err[..., 2].max():.6f}]")
    LOGGER.info(f"Position difference (X,Y,Z) - mean: [{abs_pos_err[..., 0].mean():.6f}, {abs_pos_err[..., 1].mean():.6f}, {abs_pos_err[..., 2].mean():.6f}]")

    # Relative Z error (depth), guarding against division by ~0 depths.
    abs_z_err = np.abs(ref_positions[..., 2] - ml_positions[..., 2])
    rel_z_err = abs_z_err / np.clip(ref_positions[..., 2], 1e-6, None)
    LOGGER.info(f"Z relative error - mean: {rel_z_err.mean()*100:.4f}%, max: {rel_z_err.max()*100:.4f}%")

    # --- diagnostic logging for the scales output ---
    ref_scales = pt_outputs[1]
    scales_key = find_coreml_output_key("singular_values_scales", coreml_outputs)
    ml_scales = coreml_outputs[scales_key]
    rel_scales_err = np.abs(ref_scales - ml_scales) / np.clip(ref_scales, 1e-6, None)
    LOGGER.info(f"Scales relative error - mean: {rel_scales_err.mean()*100:.4f}%, max: {rel_scales_err.max()*100:.4f}%")

    # --- actual pass/fail comparison ---
    tolerance_config = ToleranceConfig()

    if quat_validator is None:
        quat_validator = QuaternionValidator(
            angular_tolerances=tolerance_config.angular_tolerances_image
        )

    return compare_outputs(
        pt_outputs,
        coreml_outputs,
        tolerance_config.image_tolerances,
        quat_validator,
        image_name=image_path.name
    )
|
|
|
|
def validate_with_single_image(
    mlmodel: ct.models.MLModel,
    pytorch_model: RGBGaussianPredictor,
    image_path: Path,
    input_shape: tuple[int, int],
    quat_validator: QuaternionValidator | None = None,
) -> bool:
    """Validate with a single image using the new QuaternionValidator.

    Args:
        mlmodel: The Core ML model to validate.
        pytorch_model: The original PyTorch model.
        image_path: Path to the input image file.
        input_shape: Expected input image shape.
        quat_validator: Optional QuaternionValidator instance.

    Returns:
        True if validation passes, False otherwise.
    """
    # BUG FIX: load_and_preprocess_image returns (tensor, f_px, original_size).
    # The original code bound the whole tuple to `test_image`, which broke the
    # forward pass and `test_image.numpy()` below. Unpack it properly.
    test_image, _f_px, _original_size = load_and_preprocess_image(image_path, input_shape)
    test_disparity = np.array([1.0], dtype=np.float32)

    # Run the PyTorch reference through the same traceable wrapper used for
    # conversion, so both backends execute an identical graph.
    traceable_wrapper = SharpModelTraceable(pytorch_model)
    traceable_wrapper.eval()

    with torch.no_grad():
        pt_outputs = traceable_wrapper(test_image, torch.from_numpy(test_disparity))

    # Run Core ML inference on the identical inputs.
    test_image_np = test_image.numpy()
    coreml_inputs = {
        "image": test_image_np,
        "disparity_factor": test_disparity,
    }
    coreml_outputs = mlmodel.predict(coreml_inputs)

    # Consistency: reuse the module-level output ordering instead of a
    # duplicated hardcoded list.
    output_names = OUTPUT_NAMES

    # Per-output absolute tolerances (non-quaternion outputs only; quaternions
    # are validated angularly by quat_validator).
    tolerances = {
        "mean_vectors_3d_positions": 1.2,
        "singular_values_scales": 0.01,
        "colors_rgb_linear": 0.01,
        "opacities_alpha_channel": 0.05,
        "quaternions_rotations": 5.0,
    }

    if quat_validator is None:
        quat_validator = QuaternionValidator()

    LOGGER.info(f"Image: {image_path.name}, shape: {test_image.shape}, range: [{test_image.min():.4f}, {test_image.max():.4f}]")

    all_passed = True
    validation_results = []

    for i, name in enumerate(output_names):
        pt_output = pt_outputs[i].numpy()

        # Resolve the Core ML output key: exact match first, then a fuzzy
        # match on the leading token, finally fall back to positional order.
        coreml_key = None
        if name in coreml_outputs:
            coreml_key = name
        else:
            base_name = name.split('_')[0]  # hoisted: loop-invariant
            for key in coreml_outputs:
                if base_name in key.lower():
                    coreml_key = key
                    break
            if coreml_key is None:
                coreml_key = list(coreml_outputs.keys())[i]

        coreml_output = coreml_outputs[coreml_key]
        result = {"output": name, "passed": True, "failure_reason": ""}

        if name == "quaternions_rotations":
            # Delegate quaternion comparison (sign-invariant, angular) to the
            # shared validator.
            quat_result = quat_validator.validate(pt_output, coreml_output, image_name=image_path.name)

            result.update({
                "max_diff": f"{quat_result['stats']['max']:.6f}",
                "mean_diff": f"{quat_result['stats']['mean']:.6f}",
                "p99_diff": f"{quat_result['stats']['p99']:.6f}",
                "passed": quat_result["passed"],
                "failure_reason": "; ".join(quat_result["failure_reasons"]) if quat_result["failure_reasons"] else "",
            })

            if not quat_result["passed"]:
                all_passed = False
        else:
            # All other outputs use a plain element-wise absolute tolerance.
            diff = np.abs(pt_output - coreml_output)
            output_tolerance = tolerances.get(name, 0.01)
            max_diff = np.max(diff)

            result.update({
                "max_diff": f"{max_diff:.6f}",
                "mean_diff": f"{np.mean(diff):.6f}",
                "p99_diff": f"{np.percentile(diff, 99):.6f}",
            })

            if max_diff > output_tolerance:
                result["passed"] = False
                result["failure_reason"] = f"max diff {max_diff:.6f} > tolerance {output_tolerance:.6f}"
                all_passed = False

        validation_results.append(result)

    # Render the per-image summary table into the log.
    LOGGER.info(f"\n### Validation Results: {image_path.name}\n")
    table = format_validation_table(validation_results, image_path.name, include_image_column=False)
    LOGGER.info(table)
    LOGGER.info("")

    return all_passed
|
|
|
|
def main() -> int:
    """Main conversion script.

    Parses CLI arguments, converts the SHARP PyTorch model to Core ML, and
    optionally validates the converted model against the PyTorch reference.

    Returns:
        Process exit code: 0 on success, 1 if validation fails.
    """
    parser = argparse.ArgumentParser(
        description="Convert SHARP PyTorch model to Core ML format"
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=Path,
        default=None,
        help="Path to PyTorch checkpoint. Downloads default if not provided.",
    )
    parser.add_argument(
        "-o", "--output",
        type=Path,
        default=Path("sharp.mlpackage"),
        help="Output path for Core ML model (default: sharp.mlpackage)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1536,
        help="Input image height (default: 1536)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1536,
        help="Input image width (default: 1536)",
    )
    parser.add_argument(
        "--precision",
        choices=["float16", "float32"],
        default="float32",
        help="Compute precision (default: float32)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate Core ML model against PyTorch",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--input-image",
        type=Path,
        default=None,
        action="append",
        help="Path to input image for validation (can be specified multiple times, requires --validate)",
    )
    parser.add_argument(
        "--tolerance-mean",
        type=float,
        default=None,
        help="Custom mean angular tolerance in degrees (default: 0.01 for random, 0.1 for images)",
    )
    parser.add_argument(
        "--tolerance-p99",
        type=float,
        default=None,
        help="Custom P99 angular tolerance in degrees (default: 0.5 for random, 1.0 for images)",
    )
    parser.add_argument(
        "--tolerance-max",
        type=float,
        default=None,
        help="Custom max angular tolerance in degrees (default: 15.0)",
    )

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    LOGGER.info("Loading SHARP model...")
    predictor = load_sharp_model(args.checkpoint)

    input_shape = (args.height, args.width)
    precision = ct.precision.FLOAT16 if args.precision == "float16" else ct.precision.FLOAT32

    LOGGER.info("Converting using direct tracing...")
    mlmodel = convert_to_coreml(
        predictor,
        args.output,
        input_shape=input_shape,
        compute_precision=precision,
    )

    LOGGER.info(f"Core ML model saved to {args.output}")

    if args.validate:
        if args.input_image:
            # Image-based validation (one or more --input-image flags).
            validation_passed = validate_with_image_set(mlmodel, predictor, args.input_image, input_shape)
        else:
            # Random-input validation, with optional custom angular tolerances.
            # BUG FIX: compare against None rather than relying on truthiness,
            # so an explicit 0.0 tolerance on the command line is honored
            # instead of silently falling back to the default.
            custom = (args.tolerance_mean, args.tolerance_p99, args.tolerance_max)
            angular_tolerances = None
            if any(t is not None for t in custom):
                angular_tolerances = {
                    "mean": args.tolerance_mean if args.tolerance_mean is not None else 0.01,
                    "p99": args.tolerance_p99 if args.tolerance_p99 is not None else 0.5,
                    "p99_9": 2.0,
                    "max": args.tolerance_max if args.tolerance_max is not None else 15.0,
                }
            validation_passed = validate_coreml_model(mlmodel, predictor, input_shape, angular_tolerances=angular_tolerances)

        if validation_passed:
            LOGGER.info("✓ Validation passed!")
        else:
            LOGGER.error("✗ Validation failed!")
            return 1

    LOGGER.info("Conversion complete!")
    return 0
|
|
|
|
if __name__ == "__main__":
    # BUG FIX: the original called exit(main()) twice; the second call was
    # unreachable dead code. SystemExit avoids relying on the `site`-installed
    # `exit` builtin, which is absent under `python -S`.
    raise SystemExit(main())
|
|