"""Extend3D: tiled image-to-3D generation built on the Trellis pipeline.

Extends Trellis to larger scenes by sampling sparse structure and structured
latents (SLAT) over an overlapping grid of patches, blending local patch
predictions with dilated global predictions, and optionally refining the
flow-matching velocity with per-step Adam optimization against geometric
(point-cloud) and perceptual (LPIPS/SSIM) losses.
"""
import os
import json
import numpy as np
from PIL import Image
from typing import List
from tqdm import tqdm, trange

# Must be set before spconv is imported (transitively via trellis).
os.environ['SPCONV_ALGO'] = 'native'

import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

from trellis.pipelines.base import Pipeline
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.models import SparseStructureFlowModel, SparseStructureEncoder, SparseStructureDecoder
from trellis.modules.sparse.basic import sparse_cat, sparse_unbind, SparseTensor
from trellis.utils import render_utils
from trellis.representations.mesh import MeshExtractResult
from trellis.representations.mesh.utils_cube import sparse_cube2verts

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Provides PointmapInfo, PointmapInfoMoGe, pointcloud_to_voxel, schedule,
# get_views, diffuse, dilated_sampling, sparse_structure_loss (and possibly
# more) — TODO: replace the star import with explicit names.
from utils import *


class Extend3D(Pipeline):
    """Tiled Trellis image-to-3D pipeline with patch-wise flow sampling."""

    # -----------------------------------------------------------------------
    # Construction
    # -----------------------------------------------------------------------
    def __init__(self, ckpt_path: str, device: str = 'cpu'):
        """Load the base Trellis pipeline plus auxiliary models.

        Args:
            ckpt_path: HuggingFace repo id (or local path understood by
                ``TrellisImageTo3DPipeline.from_pretrained``) of the checkpoint.
            device: torch device string the models are moved to.
        """
        super().__init__()
        # Load the base Trellis pipeline
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(ckpt_path)
        self.pipeline.to(device)
        self.models = self.pipeline.models

        # Replace the sparse-structure encoder with a higher-capacity checkpoint
        config_path = hf_hub_download(repo_id=ckpt_path, filename='ckpts/ss_enc_conv3d_16l8_fp16.json')
        model_path = hf_hub_download(repo_id=ckpt_path, filename='ckpts/ss_enc_conv3d_16l8_fp16.safetensors')
        with open(config_path, 'r') as f:
            model_config = json.load(f)
        state_dict = load_file(model_path)
        encoder = SparseStructureEncoder(**model_config['args'])
        encoder.load_state_dict(state_dict)
        self.models['sparse_structure_encoder'] = encoder.to(device)

        # Perceptual metrics used for SLAT optimization loss (frozen, no gradients needed)
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True, net_type='squeeze').to(device)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
        self.lpips.requires_grad_(False)
        self.ssim.requires_grad_(False)

        # SLAT normalization constants (frozen; gradients must not flow through them)
        self.std = torch.tensor(self.pipeline.slat_normalization['std'])[None].to(device)
        self.mean = torch.tensor(self.pipeline.slat_normalization['mean'])[None].to(device)
        self.std.requires_grad_(False)
        self.mean.requires_grad_(False)

    # -----------------------------------------------------------------------
    # Device management
    # -----------------------------------------------------------------------
    def to(self, device) -> "Extend3D":
        """Move every owned model/metric/constant to *device*; returns self."""
        self.pipeline.to(device)
        self.models['sparse_structure_encoder'] = self.models['sparse_structure_encoder'].to(device)
        self.lpips = self.lpips.to(device)
        self.ssim = self.ssim.to(device)
        self.std = self.std.to(device)
        self.mean = self.mean.to(device)
        return self

    def cuda(self) -> "Extend3D":
        return self.to(torch.device('cuda'))

    def cpu(self) -> "Extend3D":
        return self.to(torch.device('cpu'))

    @staticmethod
    def from_pretrained(ckpt_path: str, device: str = 'cpu') -> "Extend3D":
        """Alternate constructor mirroring the Trellis factory API."""
        return Extend3D(ckpt_path, device=device)

    # -----------------------------------------------------------------------
    # Preprocessing
    # -----------------------------------------------------------------------
    @staticmethod
    def preprocess(image: Image.Image) -> Image.Image:
        """Resize a patch to the 1024x1024 input expected by the conditioner."""
        return image.resize((1024, 1024), Image.Resampling.LANCZOS)

    # -----------------------------------------------------------------------
    # Conditioning
    # -----------------------------------------------------------------------
    @torch.no_grad()
    def get_cond(
        self,
        image: Image.Image,
        pointmap_info: PointmapInfo = None,
        width: int = 2,
        length: int = 2,
        div: int = 2,
    ) -> List[List[dict]]:
        """Compute per-patch image conditioning for the flow model.

        Returns a ``width*div-1`` x ``length*div-1``-style nested list (shape
        determined by ``PointmapInfo.divide_image``) of conditioning dicts.
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)
        patches = pointmap_info.divide_image(width, length, div)
        return [
            [self.pipeline.get_cond([self.preprocess(patch)]) for patch in row]
            for row in patches
        ]

    # -----------------------------------------------------------------------
    # Stage 1: Sparse structure sampling
    # -----------------------------------------------------------------------
    def sample_sparse_structure(
        self,
        image: Image.Image,
        pointmap_info: PointmapInfo = None,
        optim: bool = True,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        iterations: int = 3,
        steps: int = 25,
        rescale_t: float = 3.0,
        t_noise: float = 0.6,
        t_start: float = 0.8,
        cfg_strength: float = 7.5,
        alpha: float = 5.0,
        batch_size: int = 1,
        progress_callback=None,
    ) -> torch.Tensor:
        """
        Sample occupied voxel coordinates via iterative flow-matching.

        Each iteration re-noises the encoded point-cloud voxelization to
        ``t_noise``, then runs the flow sampler blending overlapping local
        patch predictions with dilated global predictions (cosine-scheduled
        weight), optionally refining the velocity against a point-cloud loss.

        Returns:
            coords: int32 tensor of shape [N, 4] (batch, y, x, z).
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)
        flow_model: SparseStructureFlowModel = self.models['sparse_structure_flow_model']
        encoder: SparseStructureEncoder = self.models['sparse_structure_encoder']
        decoder: SparseStructureDecoder = self.models['sparse_structure_decoder']
        sampler = self.pipeline.sparse_structure_sampler
        cfg_interval = self.pipeline.sparse_structure_sampler_params['cfg_interval']
        # Decoder is only a frozen loss backbone during velocity optimization.
        for p in decoder.parameters():
            p.requires_grad_(False)
        sigma_min = sampler.sigma_min
        reso = flow_model.resolution

        # Build point cloud from the pointmap info; depth is rescaled so the
        # scene fits the (width x length)-tiled latent grid.
        pc = torch.tensor(pointmap_info.point_cloud(), dtype=torch.float32)
        pc[:, 2] *= max(width, length)

        # Encode initial voxelization of the point cloud (4x the latent reso).
        voxel = pointcloud_to_voxel(pc, (4 * reso * length, 4 * reso * width, 4 * reso))
        voxel = voxel.permute(0, 1, 3, 2, 4).float().to(self.device)
        encoded_voxel = encoder(voxel)
        pc = pc.to(self.device)

        _, t_pairs = schedule(steps, rescale_t, start=t_start)
        views = get_views(width, length, reso, div)

        # Latent tensor and accumulation buffers (reused across steps).
        latent = torch.randn(1, flow_model.in_channels, reso * width, reso * length, reso,
                             device=self.device)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        global_cond = self.get_cond(image, pointmap_info, 1, 1, 1)[0][0]
        cond = self.get_cond(image, pointmap_info, width, length, div)

        total_steps = iterations * len(t_pairs)
        global_step = 0
        iter_range = trange(iterations, position=0) if progress_callback is None else range(iterations)
        for it in iter_range:
            # Noise the (re-)encoded voxel latent to t_noise each iteration.
            latent = diffuse(encoded_voxel, torch.tensor(t_noise, device=self.device), sigma_min)
            latent = latent.detach()
            step_iter = (tqdm(t_pairs, desc="Sparse Structure Sampling", position=1)
                         if progress_callback is None else t_pairs)
            for t, t_prev in step_iter:
                # Cosine blend weight: global structure dominates early (high t),
                # local detail dominates late.
                cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (1 - torch.tensor(t))))
                c = cosine_factor ** alpha
                with torch.no_grad():
                    # --- 1. Overlapping patch-wise flow ---
                    count.zero_()
                    value.zero_()
                    local_latents, patch_conds, patch_neg_conds, patch_views = [], [], [], []
                    for view in views:
                        i, j, y0, y1, x0, x1 = view
                        patch_views.append(view)
                        local_latents.append(latent[:, :, y0:y1, x0:x1, :].contiguous())
                        patch_cond = cond[i][j]
                        patch_conds.append(patch_cond['cond'])
                        patch_neg_conds.append(patch_cond['neg_cond'])
                    for start in range(0, len(local_latents), batch_size):
                        end = min(start + batch_size, len(local_latents))
                        out = sampler.sample_once(
                            flow_model,
                            torch.cat(local_latents[start:end], dim=0),
                            t, t_prev,
                            cond=torch.cat(patch_conds[start:end], dim=0),
                            neg_cond=torch.cat(patch_neg_conds[start:end], dim=0),
                            cfg_strength=cfg_strength,
                            cfg_interval=cfg_interval,
                        )
                        for view, pred_v in zip(patch_views[start:end], out.pred_v):
                            _, _, y0, y1, x0, x1 = view
                            count[:, :, y0:y1, x0:x1, :] += 1
                            value[:, :, y0:y1, x0:x1, :] += pred_v
                    local_pred_v = torch.where(count > 0, value / count, latent)

                    # --- 2. Dilated sampling (global structure) ---
                    count.zero_()
                    value.zero_()
                    dilated_samples = dilated_sampling(reso, width, length)
                    dilated_latents = []
                    dilated_conds = []
                    dilated_neg_conds = []
                    for sample in dilated_samples:
                        sample_latent = (latent[:, :, sample[:, 0], sample[:, 1], :]
                                         .view(1, flow_model.in_channels, reso, reso, reso))
                        dilated_latents.append(sample_latent)
                        dilated_conds.append(global_cond['cond'])
                        dilated_neg_conds.append(global_cond['neg_cond'])
                    for start in range(0, len(dilated_latents), batch_size):
                        end = min(start + batch_size, len(dilated_latents))
                        out = sampler.sample_once(
                            flow_model,
                            torch.cat(dilated_latents[start:end], dim=0),
                            t, t_prev,
                            cond=torch.cat(dilated_conds[start:end], dim=0),
                            neg_cond=torch.cat(dilated_neg_conds[start:end], dim=0),
                            cfg_strength=cfg_strength,
                            cfg_interval=cfg_interval,
                        )
                        for sample, pred_v in zip(dilated_samples[start:end], out.pred_v):
                            count[:, :, sample[:, 0], sample[:, 1], :] += 1
                            value[:, :, sample[:, 0], sample[:, 1], :] += pred_v.view(
                                1, flow_model.in_channels, reso * reso, reso
                            )
                    global_pred_v = torch.where(count > 0, value / count, latent)

                    # Blend local and global velocity predictions.
                    v = local_pred_v * (1 - c) + global_pred_v * c

                v = v.detach()
                if optim and t < 0.7:
                    # Optimize v as a leaf variable against the point-cloud loss.
                    v.requires_grad_()
                    v.retain_grad()
                    optimizer = torch.optim.Adam([v], lr=0.1)
                    for _ in range(20):
                        optimizer.zero_grad()
                        pred_latent = (1 - sigma_min) * latent - (sigma_min + (1 - sigma_min) * t) * v
                        decoded_latent = decoder(pred_latent)
                        loss = sparse_structure_loss(pc, decoded_latent.permute(0, 1, 3, 2, 4))
                        loss.backward()
                        optimizer.step()

                # Euler step; detach so no graph survives across steps.
                latent = (latent - (t - t_prev) * v).detach()

                if progress_callback is not None:
                    global_step += 1
                    progress_callback(
                        global_step / total_steps,
                        f"Sparse Structure: iter {it + 1}/{iterations}, step {global_step}/{total_steps}",
                    )

            # Re-encode the decoded (binarized) voxel for the next iteration.
            voxel = (decoder(latent) > 0).float()
            encoded_voxel = encoder(voxel)

        coords = torch.argwhere(decoder(latent) > 0)[:, [0, 2, 3, 4]].int()
        return coords

    # -----------------------------------------------------------------------
    # Stage 2: Structured latent (SLAT) sampling
    # -----------------------------------------------------------------------
    def sample_slat(
        self,
        image: Image.Image,
        coords: torch.Tensor,
        pointmap_info: PointmapInfo = None,
        optim: bool = True,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        steps: int = 25,
        rescale_t: float = 3.0,
        cfg_strength: float = 3.0,
        batch_size: int = 1,
        progress_callback=None,
    ) -> SparseTensor:
        """
        Sample per-voxel latent features (SLAT) via flow-matching over
        overlapping patches, optionally refining the velocity with an
        LPIPS-minus-SSIM rendering loss against the input image.

        Returns:
            slat: SparseTensor with denormalized latent features.
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)

        # Prepare reference image tensor for the perceptual optimization loss.
        resized_image = image.resize((512, 512))
        tensor_image = (torch.from_numpy(np.array(resized_image))
                        .permute(2, 0, 1).float() / 255.0).to(self.device)
        intrinsic = torch.tensor(pointmap_info.camera_intrinsic(), dtype=torch.float32).to(self.device)
        extrinsic = torch.tensor(pointmap_info.camera_extrinsic(), dtype=torch.float32).to(self.device)

        flow_model = self.models['slat_flow_model']
        sampler = self.pipeline.slat_sampler
        cfg_interval = self.pipeline.slat_sampler_params['cfg_interval']
        cond = self.get_cond(image, pointmap_info, width, length, div)
        sigma_min = sampler.sigma_min
        reso = flow_model.resolution

        latent_feats = torch.randn(coords.shape[0], flow_model.in_channels, device=self.device)

        # Pre-compute where each voxel coordinate falls in the overlapping
        # patch grid; empty patches are skipped for the whole run.
        views = get_views(width, length, reso, div)
        valid_views = []
        patch_indices = []
        for i, j, y0, y1, x0, x1 in views:
            idx = torch.where(
                (coords[:, 1] >= y0) & (coords[:, 1] < y1) &
                (coords[:, 2] >= x0) & (coords[:, 2] < x1)
            )[0]
            if len(idx) > 0:
                valid_views.append((i, j, y0, y1, x0, x1))
                patch_indices.append(idx)

        count = torch.zeros(coords.shape[0], flow_model.in_channels, device=self.device)
        value = torch.zeros(coords.shape[0], flow_model.in_channels, device=self.device)

        _, t_pairs = schedule(steps, rescale_t)
        total_steps = len(t_pairs)
        step_iter = (tqdm(t_pairs, desc="Structured Latent Sampling")
                     if progress_callback is None else t_pairs)
        for slat_step, (t, t_prev) in enumerate(step_iter, start=1):
            with torch.no_grad():
                count.zero_()
                value.zero_()
                patch_latents = []
                patch_conds = []
                for view, patch_index in zip(valid_views, patch_indices):
                    i, j, y0, y1, x0, x1 = view
                    patch_conds.append(cond[i][j])
                    # Shift coords into the patch-local frame.
                    patch_coords_local = coords[patch_index].clone()
                    patch_coords_local[:, 1] -= y0
                    patch_coords_local[:, 2] -= x0
                    patch_latents.append(SparseTensor(
                        feats=latent_feats[patch_index].contiguous(),
                        coords=patch_coords_local,
                    ))
                for start in range(0, len(patch_latents), batch_size):
                    end = min(start + batch_size, len(patch_latents))
                    conds_chunk = patch_conds[start:end]
                    batched_cond = {
                        k: torch.cat([d[k] for d in conds_chunk], dim=0)
                        for k in conds_chunk[0].keys()
                    }
                    outs = sampler.sample_once(
                        flow_model,
                        sparse_cat(patch_latents[start:end]),
                        t, t_prev,
                        cfg_strength=cfg_strength,
                        cfg_interval=cfg_interval,
                        **batched_cond,
                    )
                    # Scatter batched predictions back onto the global voxel set.
                    for out, pidx in zip(sparse_unbind(outs.pred_v, dim=0),
                                         patch_indices[start:end]):
                        count[pidx, :] += 1
                        value[pidx, :] += out.feats

            # Average overlapping predictions; fall back to the latent itself
            # for voxels outside every patch.
            v_feats = torch.where(count > 0, value / count, latent_feats).detach()
            if optim and t < 0.8:
                # Optimize v_feats as a leaf variable against the render loss.
                v_feats.requires_grad_()
                optimizer = torch.optim.Adam([v_feats], lr=0.3)
                for _ in range(20):
                    optimizer.zero_grad()
                    pred_feats = (1 - sigma_min) * latent_feats - (sigma_min + (1 - sigma_min) * t) * v_feats
                    pred_slat = SparseTensor(feats=pred_feats, coords=coords) * self.std + self.mean
                    rendered = render_utils.render_frames_torch(
                        self.decode_slat(pred_slat, width, length, formats=['gaussian'])['gaussian'][0],
                        [extrinsic], [intrinsic],
                        {'resolution': 512, 'bg_color': (0, 0, 0)},
                        verbose=False,
                    )['color'][0].permute(2, 1, 0)
                    # LPIPS is minimized, SSIM is maximized (hence subtracted).
                    loss = (self.lpips(rendered.unsqueeze(0), tensor_image.unsqueeze(0))
                            - self.ssim(rendered.unsqueeze(0), tensor_image.unsqueeze(0)))
                    loss.backward()
                    optimizer.step()

            # Euler step; detach to free the computation graph.
            latent_feats = (latent_feats - (t - t_prev) * v_feats).detach()

            if progress_callback is not None:
                progress_callback(slat_step / total_steps,
                                  f"SLAT Sampling: step {slat_step}/{total_steps}")

        slat = SparseTensor(feats=latent_feats, coords=coords)
        return slat * self.std + self.mean

    # -----------------------------------------------------------------------
    # Stage 3: Decode SLAT -> Gaussians and/or mesh
    # -----------------------------------------------------------------------
    def decode_slat(
        self,
        slat: SparseTensor,
        width: int,
        length: int,
        formats: tuple = ('gaussian', 'mesh'),  # immutable default (was a mutable list)
    ) -> dict:
        """Decode a structured latent into Gaussian splats and/or a triangle mesh.

        Args:
            slat: denormalized SLAT (feats [N, C], coords [N, 4]).
            width, length: tile counts along y and x.
            formats: any subset of {'gaussian', 'mesh'}; lists still accepted.

        Returns:
            dict with keys present in *formats*: 'mesh' -> list[MeshExtractResult],
            'gaussian' -> list of merged Gaussian representations.
        """
        ret = {}
        feats = slat.feats
        coords = slat.coords
        reso = self.models['slat_flow_model'].resolution
        scale = max(width, length)

        # -------------------------------------------------------------------
        # Mesh decoding
        # -------------------------------------------------------------------
        if 'mesh' in formats:
            mesh_decoder = self.pipeline.models['slat_decoder_mesh']
            sf2m = mesh_decoder.mesh_extractor  # SparseFeatures2Mesh
            # Global high-res grid dimensions (4x upsampling from SLAT resolution)
            up_res = mesh_decoder.resolution * 4
            res_y, res_x, res_z = width * up_res, length * up_res, up_res

            # Accumulate high-res sparse features across overlapping patches
            # with cosine blending (weight 1 at patch center, 0 at edges).
            C = sf2m.feats_channels
            global_sum = torch.zeros(res_y, res_x, res_z, C, device=self.device)
            global_count = torch.zeros(res_y, res_x, res_z, 1, device=self.device)
            for _, _, y_start, y_end, x_start, x_end in get_views(width, length, reso, 4):
                patch_index = torch.where(
                    (coords[:, 1] >= y_start) & (coords[:, 1] < y_end) &
                    (coords[:, 2] >= x_start) & (coords[:, 2] < x_end)
                )[0]
                if len(patch_index) == 0:
                    continue
                patch_coords = coords[patch_index].clone()
                patch_coords[:, 1] -= y_start
                patch_coords[:, 2] -= x_start
                patch_latent = SparseTensor(
                    feats=feats[patch_index].contiguous(),
                    coords=patch_coords,
                )
                patch_hr = mesh_decoder.forward_features(patch_latent)

                # Cosine spatial weight over the patch's y/x extent.
                hr_coords = patch_hr.coords[:, 1:].clone()  # [N, 3]
                patch_size = float(4 * reso)
                cos_w = (torch.cos(torch.pi * (hr_coords[:, 0].float() / patch_size - 0.5)) *
                         torch.cos(torch.pi * (hr_coords[:, 1].float() / patch_size - 0.5))
                         ).unsqueeze(1)  # [N, 1]

                # Shift to global coordinates.
                hr_coords[:, 0] = (hr_coords[:, 0] + 4 * y_start).clamp(0, res_y - 1)
                hr_coords[:, 1] = (hr_coords[:, 1] + 4 * x_start).clamp(0, res_x - 1)
                hr_coords[:, 2] = hr_coords[:, 2].clamp(0, res_z - 1)
                gy, gx, gz = hr_coords[:, 0], hr_coords[:, 1], hr_coords[:, 2]
                global_sum[gy, gx, gz] += patch_hr.feats * cos_w
                global_count[gy, gx, gz] += cos_w

            # Average overlapping regions.
            occupied = global_count[..., 0] > 0
            global_sum[occupied] /= global_count[occupied]

            if occupied.any():
                occ_coords = torch.argwhere(occupied)
                occ_feats = global_sum[occ_coords[:, 0], occ_coords[:, 1], occ_coords[:, 2]]

                # Extract per-cube SDF, deformation, color, and FlexiCubes weights.
                sdf = sf2m.get_layout(occ_feats, 'sdf') + sf2m.sdf_bias  # [N, 8, 1]
                deform = sf2m.get_layout(occ_feats, 'deform')            # [N, 8, 3]
                color = sf2m.get_layout(occ_feats, 'color')              # [N, 8, 6] or None
                weights = sf2m.get_layout(occ_feats, 'weights')          # [N, 21]
                v_attrs_cat = (torch.cat([sdf, deform, color], dim=-1) if sf2m.use_color
                               else torch.cat([sdf, deform], dim=-1))

                # Merge cube corners into unique vertices.
                v_pos, v_attrs, _ = sparse_cube2verts(occ_coords, v_attrs_cat, training=False)

                # Build flat dense vertex attribute array for the global grid.
                res_vy, res_vx, res_vz = res_y + 1, res_x + 1, res_z + 1
                v_attrs_d = torch.zeros(res_vy * res_vx * res_vz, v_attrs.shape[-1],
                                        device=self.device)
                v_attrs_d[:, 0] = 1.0  # SDF default: outside surface
                vert_ids = v_pos[:, 0] * res_vx * res_vz + v_pos[:, 1] * res_vz + v_pos[:, 2]
                v_attrs_d[vert_ids] = v_attrs
                sdf_d = v_attrs_d[:, 0]
                deform_d = v_attrs_d[:, 1:4]
                colors_d = v_attrs_d[:, 4:] if sf2m.use_color else None

                # Build flat dense cube weight array.
                weights_d = torch.zeros(res_y * res_x * res_z, weights.shape[-1],
                                        device=self.device)
                cube_ids = occ_coords[:, 0] * res_x * res_z + occ_coords[:, 1] * res_z + occ_coords[:, 2]
                weights_d[cube_ids] = weights

                # Regular vertex position grid [V, 3], normalized to world space.
                ay, ax, az = (torch.arange(r, device=self.device, dtype=torch.float)
                              for r in (res_vy, res_vx, res_vz))
                gy, gx, gz = torch.meshgrid(ay, ax, az, indexing='ij')
                reg_v = torch.stack([gy.flatten(), gx.flatten(), gz.flatten()], dim=1)
                # Normalize to Gaussian world coordinate convention:
                #   y, x : [-0.5, 0.5] (centered)
                #   z    : [0, 1/scale] (not centered)
                norm_val = scale * up_res
                norm_t = torch.tensor([norm_val, norm_val, norm_val],
                                      device=self.device, dtype=torch.float)
                offset_t = torch.tensor([0.5, 0.5, 0.0], device=self.device, dtype=torch.float)
                x_nx3 = reg_v / norm_t - offset_t + (1 - 1e-8) / (norm_t * 2) * torch.tanh(deform_d)

                # Global cube -> 8 corner vertex index table [C_total, 8].
                cy, cx, cz = (torch.arange(r, device=self.device) for r in (res_y, res_x, res_z))
                gy, gx, gz = torch.meshgrid(cy, cx, cz, indexing='ij')
                cc = torch.tensor(
                    [[0,0,0],[1,0,0],[0,1,0],[1,1,0],[0,0,1],[1,0,1],[0,1,1],[1,1,1]],
                    dtype=torch.long, device=self.device,
                )
                reg_c = ((gy.flatten().unsqueeze(1) + cc[:, 0]) * res_vx * res_vz +
                         (gx.flatten().unsqueeze(1) + cc[:, 1]) * res_vz +
                         (gz.flatten().unsqueeze(1) + cc[:, 2]))  # [C, 8]

                # Single FlexiCubes call on the full global SDF.
                vertices, faces, _, colors = sf2m.mesh_extractor(
                    voxelgrid_vertices=x_nx3,
                    scalar_field=sdf_d,
                    cube_idx=reg_c,
                    resolution=[res_y, res_x, res_z],
                    beta=weights_d[:, :12],
                    alpha=weights_d[:, 12:20],
                    gamma_f=weights_d[:, 20],
                    voxelgrid_colors=colors_d,
                    training=False,
                )
                ret['mesh'] = [MeshExtractResult(
                    vertices=vertices,
                    faces=faces,
                    vertex_attrs=colors,
                    res=max(res_y, res_x, res_z),
                )]
            else:
                ret['mesh'] = []

        # -------------------------------------------------------------------
        # Gaussian decoding
        # -------------------------------------------------------------------
        if 'gaussian' in formats:
            gs_decoder = self.pipeline.models['slat_decoder_gs']

            # Decode each (non-overlapping) tile and collect Gaussian lists
            # per batch element.
            all_patch_lists = None
            for i in range(width):
                for j in range(length):
                    y0, y1 = i * reso, (i + 1) * reso
                    x0, x1 = j * reso, (j + 1) * reso
                    patch_index = torch.where(
                        (coords[:, 1] >= y0) & (coords[:, 1] < y1) &
                        (coords[:, 2] >= x0) & (coords[:, 2] < x1)
                    )[0]
                    if len(patch_index) == 0:
                        continue
                    patch_coords = coords[patch_index].clone()
                    patch_coords[:, 1] -= y0
                    patch_coords[:, 2] -= x0
                    patch_latent = SparseTensor(
                        feats=feats[patch_index].contiguous(),
                        coords=patch_coords,
                    )
                    patch_gaussians = gs_decoder(patch_latent)
                    # Translate Gaussians to their world-space tile position.
                    offset = torch.tensor([[i + 0.5, j + 0.5, 0.5]], device=self.device)
                    for g in patch_gaussians:
                        g._xyz = g._xyz + offset
                    if all_patch_lists is None:
                        all_patch_lists = [[g] for g in patch_gaussians]
                    else:
                        for k, g in enumerate(patch_gaussians):
                            all_patch_lists[k].append(g)

            # Fix: previously crashed with TypeError when no tile contained
            # any voxels (all_patch_lists stayed None); now yields [].
            if all_patch_lists is None:
                all_patch_lists = []

            # Concatenate all patches into a single Gaussian set per batch element.
            merged_gaussians = []
            for gs_list in all_patch_lists:
                g0 = gs_list[0]
                if len(gs_list) > 1:
                    g0._features_dc = torch.cat([g._features_dc for g in gs_list], dim=0)
                    g0._opacity = torch.cat([g._opacity for g in gs_list], dim=0)
                    g0._rotation = torch.cat([g._rotation for g in gs_list], dim=0)
                    g0._scaling = torch.cat([g._scaling for g in gs_list], dim=0)
                    g0._xyz = torch.cat([g._xyz for g in gs_list], dim=0)
                merged_gaussians.append(g0)

            # Filter Gaussians with overly large kernels (outliers).
            for g in merged_gaussians:
                scale_norm = torch.sum(g.get_scaling ** 2, dim=1) ** 0.5
                keep = torch.where(scale_norm < 0.03)[0]
                g._features_dc = g._features_dc[keep]
                g._opacity = g._opacity[keep]
                g._rotation = g._rotation[keep]
                g._scaling = g._scaling[keep]
                g._xyz = g._xyz[keep]

            # Normalize to world-space coordinate convention.
            eps = 1e-4
            center_offset = torch.tensor([[0.5, 0.5, 0.0]], device=self.device)
            for g in merged_gaussians:
                g.from_xyz(g.get_xyz / scale)
                g._xyz -= center_offset
                g.mininum_kernel_size /= scale
                g.from_scaling(torch.max(
                    g.get_scaling / scale,
                    torch.tensor(g.mininum_kernel_size * (1 + eps), device=self.device),
                ))
            ret['gaussian'] = merged_gaussians

        return ret

    # -----------------------------------------------------------------------
    # Full pipeline
    # -----------------------------------------------------------------------
    def run(
        self,
        image: Image.Image,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        ss_optim: bool = True,
        ss_iterations: int = 3,
        ss_steps: int = 25,
        ss_rescale_t: float = 3.0,
        ss_t_noise: float = 0.6,
        ss_t_start: float = 0.8,
        ss_cfg_strength: float = 7.5,
        ss_alpha: float = 5.0,
        ss_batch_size: int = 1,
        slat_optim: bool = True,
        slat_steps: int = 25,
        slat_rescale_t: float = 3.0,
        slat_cfg_strength: float = 3.0,
        slat_batch_size: int = 1,
        formats: tuple = ('gaussian', 'mesh'),  # immutable default (was a mutable list)
        return_pointmap: bool = False,
        progress_callback=None,
    ) -> dict:
        """Run the full Extend3D pipeline: SS sampling -> SLAT sampling -> decode.

        Returns the decoded dict (see :meth:`decode_slat`); when
        *return_pointmap* is True, returns ``(decoded, pointmap_info)``.
        """
        pointmap_info = PointmapInfoMoGe(image, device=self.device)
        coords = self.sample_sparse_structure(
            image, pointmap_info, ss_optim, width, length, div,
            iterations=ss_iterations,
            steps=ss_steps,
            rescale_t=ss_rescale_t,
            t_noise=ss_t_noise,
            t_start=ss_t_start,
            cfg_strength=ss_cfg_strength,
            alpha=ss_alpha,
            batch_size=ss_batch_size,
            progress_callback=progress_callback,
        ).detach()
        slat = self.sample_slat(
            image, coords, pointmap_info, slat_optim, width, length, div,
            steps=slat_steps,
            rescale_t=slat_rescale_t,
            cfg_strength=slat_cfg_strength,
            batch_size=slat_batch_size,
            progress_callback=progress_callback,
        )
        with torch.no_grad():
            decoded = self.decode_slat(slat, width, length, formats=formats)
        if return_pointmap:
            return decoded, pointmap_info
        return decoded