Instructions to use quantispect/QuantiSpect-V1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Ising Decoding
How to use quantispect/QuantiSpect-V1 with Ising Decoding:
# No code snippets available yet for this library. # To use this model, check the repository files and the library's documentation. # Want to help? PRs adding snippets are welcome at: # https://github.com/huggingface/huggingface.js
- Notebooks
- Google Colab
- Kaggle
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import hydra, sys, torch, os, json, numpy as np | |
| from omegaconf import DictConfig, OmegaConf | |
| from training.train import main as train_main | |
| from model.factory import ModelFactory | |
| from data.factory import DatapipeFactory | |
| from hydra.utils import to_absolute_path | |
| from workflows.config_validator import ( | |
| apply_public_defaults_and_model, | |
| validate_public_config, | |
| ) | |
| from training.distributed import DistributedManager | |
| from torch.utils.data import DataLoader | |
| def _ensure_inference_io_channels(cfg): | |
| # 1) Ensure out_channels matches the model’s heads (4: z_data, x_data, syn_x, syn_z) | |
| if not getattr(cfg.model, "out_channels", None) or cfg.model.out_channels == 0: | |
| cfg.model.out_channels = 4 | |
| # 2) Infer input_channels from a single inference sample if not set | |
| if not getattr(cfg.model, "input_channels", None) or cfg.model.input_channels == 0: | |
| ds = DatapipeFactory.create_datapipe_inference(cfg) | |
| tmp = DataLoader(ds, batch_size=1) | |
| sample = next(iter(tmp)) | |
| cfg.model.input_channels = int(sample["trainX"].shape[1]) | |
| # 3) Keep num_filters consistent with out_channels | |
| if hasattr(cfg.model, "num_filters"): | |
| filters = list(cfg.model.num_filters) | |
| if filters and filters[-1] != cfg.model.out_channels: | |
| print( | |
| f"[run] Adjusting model.num_filters[-1] {filters[-1]} -> {cfg.model.out_channels}" | |
| ) | |
| filters[-1] = cfg.model.out_channels | |
| cfg.model.num_filters = filters | |
def run(cfg: "DictConfig") -> None:
    """Validate the public config, merge hidden defaults, and dispatch by code family.

    Early-access public release: the user-facing config surface is validated
    BEFORE hidden defaults are merged, so injected fields fail fast.

    Side effects: sets the PyTorch CUDA matmul / cuDNN TF32 flags from the
    merged config.

    Raises:
        ValueError: if ``cfg.code`` is not ``"surface"`` or ``"surface_partition"``.
    """
    model_spec = validate_public_config(cfg)
    cfg = apply_public_defaults_and_model(cfg, model_spec)

    # TF32 toggles are read after the merge above, so merged defaults apply.
    torch.backends.cuda.matmul.allow_tf32 = cfg.enable_matmul_tf32
    torch.backends.cudnn.allow_tf32 = cfg.enable_cudnn_tf32

    if cfg.code in ("surface", "surface_partition"):
        run_surface(cfg)
    else:
        # Previously an unsupported code fell through and returned without doing anything.
        raise ValueError(
            f"Unsupported code {cfg.code!r}; supported codes: 'surface', 'surface_partition'."
        )
def run_surface(cfg: "DictConfig"):
    """Run the workflow selected by ``cfg.workflow.task`` for surface-code configs.

    Handled tasks:
      - ``train``: delegates to the training entry point.
      - ``inference``: loads the model and runs inference.
      - ``decoder_ablation``: loads the model and runs the decoder ablation study.
      - ``data``: iterates the training dataloader and prints each batch's
        syndrome shape.

    ``threshold`` (renamed to ``inference``), ``sampling``, ``visualize`` and
    any unrecognized task are rejected.

    Raises:
        ValueError: for renamed, unsupported, or unknown tasks.
    """
    task = cfg.workflow.task

    if task == "train":
        train_main(cfg)
    elif task == "threshold":
        raise ValueError(
            "workflow.task='threshold' has been renamed to workflow.task='inference'. "
            "Please update your config/env var to WORKFLOW=inference."
        )
    elif task == "inference":
        from evaluation.inference import run_inference

        DistributedManager.initialize()
        dist = DistributedManager()
        model = _load_model(cfg, dist)
        run_inference(model, dist.device, dist, cfg)
    elif task == "data":
        DistributedManager.initialize()
        dist = DistributedManager()
        train_loader, _ = DatapipeFactory.create_dataloader(cfg, dist.world_size, dist.rank)
        for j, dl in enumerate(train_loader):
            print(f"Batch {j}: syndrome_shape: {dl['syndrome'].shape}")
    elif task == "decoder_ablation":
        from evaluation.failure_analysis import decoder_ablation_study

        DistributedManager.initialize()
        dist = DistributedManager()
        model = _load_model(cfg, dist)
        decoder_ablation_study(model, dist.device, dist, cfg)
    elif task in ("sampling", "visualize"):
        raise ValueError(
            f"workflow.task={task!r} is not supported in the early-access public release. "
            "Supported workflows: train, inference, decoder_ablation."
        )
    else:
        # Previously an unknown task fell through and returned without doing anything.
        raise ValueError(
            f"Unknown workflow.task={task!r}. "
            "Supported workflows: train, inference, decoder_ablation."
        )
def find_best_model(path, *, rank: int = 0):
    """Return the path of the checkpoint to load from directory ``path``.

    Selection rules:

    * ``PreDecoderModelMemory_*.pt`` files are epoch-numbered checkpoints.
      The epoch is parsed from the third dot-separated field; for example,
      ``PreDecoderModelMemory_v1.0.12.pt`` parses to epoch 12.0. The highest
      epoch wins. Ties go to the lexicographically first filename, so the
      result does not depend on directory listing order.
    * If no epoch-numbered checkpoint is selected, any other ``.pt`` files
      (e.g. ``Ising-Decoder-SurfaceCode-1-Fast.pt``) are used instead. Of
      those, the lexicographically last one is chosen.

    Args:
        path: Directory to search.
        rank: Process rank; only rank 0 prints progress and warnings.

    Returns:
        ``os.path.join(path, <selected filename>)``.

    Raises:
        FileNotFoundError: if ``path`` is not a directory or contains no usable
            ``.pt`` file.
    """
    if rank == 0:
        print(f"Searching for best model in: {path}")
    if not os.path.isdir(path):
        raise FileNotFoundError(f"Model directory does not exist: {path}")

    max_value = -1  # Start with -1 to include epoch 0
    best_file = None
    model_files = []
    # Named .pt files without epoch numbers (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt)
    named_pt_files = []
    # Sorted: os.listdir order is arbitrary, and with the strict `>` below the
    # first file seen wins an epoch tie.
    for filename in sorted(os.listdir(path)):
        if not filename.endswith(".pt"):
            continue
        if not filename.startswith("PreDecoderModelMemory_"):
            named_pt_files.append(filename)
            continue
        try:
            value = float(filename.split(".")[2])  # Gets epoch number
        except (IndexError, ValueError) as e:
            if rank == 0:
                print(f"Warning: could not parse epoch from filename {filename}: {e}")
            continue
        model_files.append((filename, value))
        if value > max_value:
            max_value = value
            best_file = filename

    # Fall back to named .pt files when no epoch-numbered checkpoints are present.
    # named_pt_files is already sorted because the listing was iterated in order.
    if best_file is None and named_pt_files:
        best_file = named_pt_files[-1]
        model_files = [(f, None) for f in named_pt_files]

    if rank == 0:
        print(f"Found {len(model_files)} model file(s):")
        for filename, epoch in sorted(model_files, key=lambda x: (x[1] is None, x[1] or 0)):
            marker = "*" if filename == best_file else " "
            epoch_str = str(epoch) if epoch is not None else "n/a"
            print(f" [{marker}] {filename} (epoch {epoch_str})")

    if best_file is None:
        raise FileNotFoundError(
            f"No valid model checkpoint files found in {path}\n"
            "Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
            "PreDecoderModelMemory_*.pt).\n"
            "Hint: download the pretrained weights and place them in this directory, "
            "or set model_checkpoint_file in your config to an explicit path."
        )

    best_model_path = os.path.join(path, best_file)
    if rank == 0:
        epoch_str = str(max_value) if max_value >= 0 else "n/a"
        print(f"Selected best model: {best_file} (epoch {epoch_str})")
    return best_model_path
| def _resolve_dir(path: str) -> str: | |
| """Return an absolute version of path, resolving relative paths from the repo root.""" | |
| if os.path.isabs(path): | |
| return path | |
| repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| return os.path.join(repo_root, path) | |
| def _load_state_dict_from_pt(model_path: str, device) -> dict: | |
| """Load a state dict from a .pt checkpoint, handling multiple saved formats. | |
| Supports: | |
| - bare state dict (keys are layer names) | |
| - {"model_state_dict": ...} | |
| - {"state_dict": ...} | |
| Also strips the DDP "module." prefix if present. | |
| """ | |
| raw = torch.load(model_path, map_location=device, weights_only=False) | |
| if isinstance(raw, dict): | |
| if "model_state_dict" in raw: | |
| state_dict = raw["model_state_dict"] | |
| elif "state_dict" in raw: | |
| state_dict = raw["state_dict"] | |
| else: | |
| state_dict = raw | |
| else: | |
| raise ValueError(f"Unexpected checkpoint format: expected a dict, got {type(raw).__name__}") | |
| return { | |
| (k[len("module."):] if k.startswith("module.") else k): v for k, v in state_dict.items() | |
| } | |
def _build_model(cfg, dist):
    """Create the model from ``cfg``, move it to ``dist.device``, and cast to fp16 if enabled."""
    model = ModelFactory.create_model(cfg).to(dist.device)
    if cfg.enable_fp16:
        model = model.half()
        if dist.rank == 0:
            print("Model converted to float16 for fp16 inference")
    return model


def _load_weights(model, model_path, dist):
    """Load the ``.pt`` state dict at ``model_path`` into ``model``, log its size, and return it."""
    state_dict = _load_state_dict_from_pt(model_path, dist.device)
    model.load_state_dict(state_dict)
    if dist.rank == 0:
        param_count = sum(p.numel() for p in model.parameters())
        print(f"Model loaded ({param_count:,} parameters)")
    return model


def _load_model_from_safetensors(cfg, dist, safetensors_path):
    """Load a model from a SafeTensors file, taking model_id from the file's metadata."""
    from export.safetensors_utils import load_safetensors

    if dist.rank == 0:
        print(f"Loading model from SafeTensors: {safetensors_path}")
    # Auto-detect model_id from SafeTensors metadata (don't override with config)
    model, metadata = load_safetensors(
        safetensors_path,
        model_id=None,
        device=str(dist.device),
    )
    if dist.rank == 0:
        loaded_model_id = metadata.get("model_id", "unknown")
        dtype = metadata.get("quant_format", "fp32")
        receptive_field = metadata.get("receptive_field", "unknown")
        param_count = sum(p.numel() for p in model.parameters())
        print(f" model_id: {loaded_model_id} (from SafeTensors metadata)")
        print(f" receptive_field: {receptive_field}")
        print(f" dtype: {dtype}")
        print(f" parameters: {param_count:,}")
        # Warn if config model_id doesn't match file metadata
        config_model_id = getattr(cfg, "model_id", None)
        if config_model_id is not None and str(config_model_id) != str(loaded_model_id):
            print(
                f" Warning: config model_id={config_model_id} differs from "
                f"file model_id={loaded_model_id}; using {loaded_model_id}"
            )
    # Applied on every rank (not only rank 0) so all processes agree on precision.
    if metadata.get("quant_format") == "fp16":
        cfg.enable_fp16 = True
    return model


def _resolve_checkpoint_path(cfg, dist):
    """Select the ``.pt`` file to load from the configured checkpoint directory.

    The base directory is ``cfg.model_checkpoint_dir`` when set (inference
    configs), otherwise ``{cfg.output}/models`` (training configs).

    ``cfg.test.use_model_checkpoint == -1`` selects the best model via
    ``find_best_model``. -1 is also the default, used as well when ``cfg.test``
    is absent. It searches ``<base>/best_model`` first and falls back to
    ``<base>`` for older runs. Any other value selects that specific epoch
    number.

    Raises:
        FileNotFoundError: if no matching checkpoint file exists.
    """
    model_checkpoint_dir = getattr(cfg, "model_checkpoint_dir", None)
    base_dir = model_checkpoint_dir if model_checkpoint_dir else f"{cfg.output}/models"
    # Inference-only configs may not define a `test` section at all.
    use_checkpoint = getattr(getattr(cfg, "test", None), "use_model_checkpoint", -1)

    if use_checkpoint == -1:
        model_dir = _resolve_dir(os.path.join(base_dir, "best_model"))
        if dist.rank == 0:
            print(f"Loading best model from: {model_dir}")
        # Fallback: older runs may not have a best_model/ folder
        if not os.path.isdir(model_dir):
            fallback_dir = _resolve_dir(base_dir)
            if dist.rank == 0:
                print(f"best_model/ not found; falling back to: {fallback_dir}")
            model_dir = fallback_dir
        return find_best_model(model_dir, rank=dist.rank)

    checkpoint_dir = _resolve_dir(base_dir)
    if dist.rank == 0:
        print(f"Loading checkpoint {use_checkpoint} from: {checkpoint_dir}")
    # Prefer any PreDecoderModelMemory_* file ending with .0.{use_checkpoint}.pt
    target_suffix = f".0.{use_checkpoint}.pt"
    try:
        # Sorted so the pick is deterministic when several files match.
        candidates = sorted(os.listdir(checkpoint_dir))
    except OSError:
        # Missing/unreadable directory: fall through to the default filename;
        # its absence is reported by the existence check below.
        candidates = []
    checkpoint_filename = next(
        (
            f
            for f in candidates
            if f.startswith("PreDecoderModelMemory_") and f.endswith(target_suffix)
        ),
        f"PreDecoderModelMemory_v1.0.{use_checkpoint}.pt",
    )
    model_path = os.path.join(checkpoint_dir, checkpoint_filename)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")
    return model_path


def _load_model(cfg, dist):
    """Build the model for the current workflow task and load its weights.

    Weight sources, in priority order:
      1. The ``PREDECODER_SAFETENSORS_CHECKPOINT`` environment variable: a
         SafeTensors file loaded via ``load_safetensors`` with ``model_id=None``.
      2. ``cfg.model_checkpoint_file``: an explicit ``.pt`` path. Relative
         paths are resolved with ``_resolve_dir``.
      3. A checkpoint directory, ``cfg.model_checkpoint_dir`` or
         ``{cfg.output}/models`` (see ``_resolve_checkpoint_path``).

    Calls ``_ensure_inference_io_channels(cfg)`` first, which may mutate
    ``cfg.model``. May set ``cfg.enable_fp16 = True`` when the SafeTensors
    metadata reports ``quant_format == "fp16"``.

    Returns:
        The loaded model.

    Raises:
        FileNotFoundError: if the selected checkpoint file does not exist.
    """
    if dist.rank == 0:
        print(f"Loading model for task: {cfg.workflow.task}")
    _ensure_inference_io_channels(cfg)

    # SafeTensors path: load fp16/fp32 model from SafeTensors file
    safetensors_path = os.environ.get("PREDECODER_SAFETENSORS_CHECKPOINT", "").strip()
    if safetensors_path:
        return _load_model_from_safetensors(cfg, dist, safetensors_path)

    # Direct file path override (for named pretrained models without epoch numbers)
    model_checkpoint_file = getattr(cfg, "model_checkpoint_file", None)
    if model_checkpoint_file:
        model_checkpoint_file = _resolve_dir(str(model_checkpoint_file))
        if not os.path.exists(model_checkpoint_file):
            raise FileNotFoundError(f"Checkpoint not found: {model_checkpoint_file}")
        if dist.rank == 0:
            print(f"Loading model from: {model_checkpoint_file}")
        return _load_weights(_build_model(cfg, dist), model_checkpoint_file, dist)

    model = _build_model(cfg, dist)
    model_path = _resolve_checkpoint_path(cfg, dist)
    if dist.rank == 0:
        print(f"Loading model parameters from: {model_path}")
    return _load_weights(model, model_path, dist)
# Script entry point.
if __name__ == "__main__":
    # NOTE(review): run() is declared as run(cfg) with no default, and no
    # decorator (e.g. @hydra.main) is visible on it in this file, so this call
    # would raise TypeError as written. Confirm how the config is meant to be
    # supplied (hydra is imported at the top of the file but not used here).
    run()