import os
import json
import numpy as np
from PIL import Image
from typing import List


from tqdm import tqdm, trange


# Must be set before the trellis imports below: NOTE(review) — presumably
# spconv reads this env var at import time to select its sparse-convolution
# algorithm; confirm against the spconv version in use.
os.environ['SPCONV_ALGO'] = 'native'


import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure


from trellis.pipelines.base import Pipeline
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.models import SparseStructureFlowModel, SparseStructureEncoder, SparseStructureDecoder
from trellis.modules.sparse.basic import sparse_cat, sparse_unbind, SparseTensor
from trellis.utils import render_utils
from trellis.representations.mesh import MeshExtractResult
from trellis.representations.mesh.utils_cube import sparse_cube2verts


from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


# Project-local helpers: PointmapInfo, PointmapInfoMoGe, pointcloud_to_voxel,
# schedule, get_views, diffuse, dilated_sampling, sparse_structure_loss, ...
from utils import *
|
|
|
class Extend3D(Pipeline):
    """
    Image-to-3D pipeline built on top of TRELLIS for tiled/extended scenes.

    Wraps a pretrained ``TrellisImageTo3DPipeline`` and augments it with:
      * a manually loaded sparse-structure *encoder* (fetched separately from
        the hub; the base pipeline's model dict does not include it),
      * frozen LPIPS / SSIM metrics used as rendering losses during SLAT
        test-time optimization,
      * the SLAT per-channel normalization statistics as device tensors.
    """

    def __init__(self, ckpt_path: str, device: str = 'cpu'):
        """
        Args:
            ckpt_path: TRELLIS checkpoint identifier. NOTE(review): it is
                passed both to ``from_pretrained`` and to ``hf_hub_download``
                as a ``repo_id``, so it must be a HuggingFace repo id —
                confirm local paths are not expected here.
            device: torch device string all models are placed on.
        """
        super().__init__()

        # Base TRELLIS pipeline; its model dict is *shared* with this object,
        # so entries added below are also visible through self.pipeline.
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(ckpt_path)
        self.pipeline.to(device)
        self.models = self.pipeline.models

        # Download config + weights for the sparse-structure encoder and load
        # it by hand.
        config_path = hf_hub_download(repo_id=ckpt_path,
                                      filename='ckpts/ss_enc_conv3d_16l8_fp16.json')
        model_path = hf_hub_download(repo_id=ckpt_path,
                                     filename='ckpts/ss_enc_conv3d_16l8_fp16.safetensors')
        with open(config_path, 'r') as f:
            model_config = json.load(f)
        state_dict = load_file(model_path)

        encoder = SparseStructureEncoder(**model_config['args'])
        encoder.load_state_dict(state_dict)
        self.models['sparse_structure_encoder'] = encoder.to(device)

        # Perceptual / structural image losses for sample_slat's optimization;
        # frozen so gradients only flow into the latents being optimized.
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True, net_type='squeeze').to(device)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
        self.lpips.requires_grad_(False)
        self.ssim.requires_grad_(False)

        # SLAT normalization stats, shaped [1, C] for broadcasting over
        # per-voxel feature rows.
        self.std = torch.tensor(self.pipeline.slat_normalization['std'])[None].to(device)
        self.mean = torch.tensor(self.pipeline.slat_normalization['mean'])[None].to(device)
        self.std.requires_grad_(False)
        self.mean.requires_grad_(False)
|
| def to(self, device) -> "Extend3D": |
| self.pipeline.to(device) |
| self.models['sparse_structure_encoder'] = self.models['sparse_structure_encoder'].to(device) |
| self.lpips = self.lpips.to(device) |
| self.ssim = self.ssim.to(device) |
| self.std = self.std.to(device) |
| self.mean = self.mean.to(device) |
| return self |
|
|
| def cuda(self) -> "Extend3D": |
| return self.to(torch.device('cuda')) |
|
|
| def cpu(self) -> "Extend3D": |
| return self.to(torch.device('cpu')) |
|
|
| @staticmethod |
| def from_pretrained(ckpt_path: str, device: str = 'cpu') -> "Extend3D": |
| return Extend3D(ckpt_path, device=device) |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def preprocess(image: Image.Image) -> Image.Image: |
| return image.resize((1024, 1024), Image.Resampling.LANCZOS) |
|
|
| |
| |
| |
|
|
| @torch.no_grad() |
| def get_cond( |
| self, |
| image: Image.Image, |
| pointmap_info: PointmapInfo = None, |
| width: int = 2, |
| length: int = 2, |
| div: int = 2, |
| ) -> List[List[dict]]: |
| """Compute per-patch image conditioning for the flow model.""" |
| if pointmap_info is None: |
| pointmap_info = PointmapInfo(image, device=self.device) |
|
|
| patches = pointmap_info.divide_image(width, length, div) |
| return [ |
| [self.pipeline.get_cond([self.preprocess(patch)]) for patch in row] |
| for row in patches |
| ] |
|
|
| |
| |
| |
|
|
    def sample_sparse_structure(
        self,
        image: Image.Image,
        pointmap_info: PointmapInfo = None,
        optim: bool = True,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        iterations: int = 3,
        steps: int = 25,
        rescale_t: float = 3.0,
        t_noise: float = 0.6,
        t_start: float = 0.8,
        cfg_strength: float = 7.5,
        alpha: float = 5.0,
        batch_size: int = 1,
        progress_callback=None,
    ) -> torch.Tensor:
        """
        Sample occupied voxel coordinates via iterative flow-matching.

        Each outer iteration re-noises the encoded occupancy volume to time
        ``t_noise`` and denoises it over ``t_pairs``.  Each denoising step
        blends two velocity predictions: a patch-wise ("local") pass over the
        views from ``get_views`` and a dilated ("global") pass conditioned on
        the full image, mixed by a cosine-annealed weight ``c``.  When
        ``optim`` is on, the blended velocity is further refined by Adam
        against a point-cloud occupancy loss.

        Args:
            optim: enable the 20-step Adam refinement of the velocity (only
                applied for t < 0.7).
            width, length, div: tiling of the latent volume; div controls
                view overlap (semantics defined by get_views in utils).
            alpha: exponent sharpening the cosine local/global blend.
            progress_callback: optional (fraction, message) callback; when
                given, tqdm progress bars are suppressed.

        Returns:
            coords: int32 tensor of shape [N, 4] (batch, y, x, z).
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)

        flow_model: SparseStructureFlowModel = self.models['sparse_structure_flow_model']
        encoder: SparseStructureEncoder = self.models['sparse_structure_encoder']
        decoder: SparseStructureDecoder = self.models['sparse_structure_decoder']
        sampler = self.pipeline.sparse_structure_sampler
        cfg_interval = self.pipeline.sparse_structure_sampler_params['cfg_interval']

        # The decoder participates in the optimization graph below; freeze it
        # so only the velocity tensor `v` receives gradients.
        for p in decoder.parameters():
            p.requires_grad_(False)

        sigma_min = sampler.sigma_min
        reso = flow_model.resolution

        # Point cloud from the input pointmap; the z axis is scaled by the
        # larger tiling factor. NOTE(review): presumably to keep the scene
        # isotropic in the stretched latent volume — confirm.
        pc = torch.tensor(pointmap_info.point_cloud(), dtype=torch.float32)
        pc[:, 2] *= max(width, length)

        # Voxelize at 4x the latent resolution; permute swaps the two
        # horizontal axes to match the latent's (y, x) layout.
        voxel = pointcloud_to_voxel(pc, (4 * reso * length, 4 * reso * width, 4 * reso))
        voxel = voxel.permute(0, 1, 3, 2, 4).float().to(self.device)
        encoded_voxel = encoder(voxel)
        pc = pc.to(self.device)

        _, t_pairs = schedule(steps, rescale_t, start=t_start)
        views = get_views(width, length, reso, div)

        # This initial latent is only used to size `count`/`value`; it is
        # overwritten by `diffuse` at the start of every outer iteration.
        latent = torch.randn(1, flow_model.in_channels, reso * width, reso * length, reso,
                             device=self.device)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        # Conditioning: one global (whole image) dict and a grid of per-patch dicts.
        global_cond = self.get_cond(image, pointmap_info, 1, 1, 1)[0][0]
        cond = self.get_cond(image, pointmap_info, width, length, div)

        total_steps = iterations * len(t_pairs)
        global_step = 0

        # Nested tqdm bars (position 0/1) unless a callback handles progress.
        iter_range = trange(iterations, position=0) if progress_callback is None else range(iterations)
        for it in iter_range:
            # Re-noise the current occupancy encoding to time t_noise.
            latent = diffuse(encoded_voxel, torch.tensor(t_noise, device=self.device), sigma_min)
            latent = latent.detach()

            step_iter = (tqdm(t_pairs, desc="Sparse Structure Sampling", position=1)
                         if progress_callback is None else t_pairs)
            for t, t_prev in step_iter:
                # Blend weight: c -> 1 as t -> 0 (global prediction dominates late).
                cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (1 - torch.tensor(t))))
                c = cosine_factor ** alpha

                with torch.no_grad():
                    # --- local (patch-wise) velocity prediction ---
                    count.zero_()
                    value.zero_()

                    local_latents, patch_conds, patch_neg_conds, patch_views = [], [], [], []
                    for view in views:
                        i, j, y0, y1, x0, x1 = view
                        patch_views.append(view)
                        local_latents.append(latent[:, :, y0:y1, x0:x1, :].contiguous())
                        patch_cond = cond[i][j]
                        patch_conds.append(patch_cond['cond'])
                        patch_neg_conds.append(patch_cond['neg_cond'])

                    # Batch patches through the sampler to bound memory.
                    for start in range(0, len(local_latents), batch_size):
                        end = min(start + batch_size, len(local_latents))

                        out = sampler.sample_once(
                            flow_model,
                            torch.cat(local_latents[start:end], dim=0),
                            t, t_prev,
                            cond=torch.cat(patch_conds[start:end], dim=0),
                            neg_cond=torch.cat(patch_neg_conds[start:end], dim=0),
                            cfg_strength=cfg_strength,
                            cfg_interval=cfg_interval,
                        )

                        # Scatter-average overlapping patch predictions.
                        for view, pred_v in zip(patch_views[start:end], out.pred_v):
                            _, _, y0, y1, x0, x1 = view
                            count[:, :, y0:y1, x0:x1, :] += 1
                            value[:, :, y0:y1, x0:x1, :] += pred_v

                    # Fall back to the current latent where no patch covered a cell.
                    local_pred_v = torch.where(count > 0, value / count, latent)

                    # --- global (dilated) velocity prediction ---
                    count.zero_()
                    value.zero_()

                    dilated_samples = dilated_sampling(reso, width, length)
                    dilated_latents = []
                    dilated_conds = []
                    dilated_neg_conds = []

                    # Each dilated sample gathers a strided subset of the big
                    # volume into a reso^3 cube conditioned on the whole image.
                    for sample in dilated_samples:
                        sample_latent = (latent[:, :, sample[:, 0], sample[:, 1], :]
                                         .view(1, flow_model.in_channels, reso, reso, reso))
                        dilated_latents.append(sample_latent)
                        dilated_conds.append(global_cond['cond'])
                        dilated_neg_conds.append(global_cond['neg_cond'])

                    for start in range(0, len(dilated_latents), batch_size):
                        end = min(start + batch_size, len(dilated_latents))

                        out = sampler.sample_once(
                            flow_model,
                            torch.cat(dilated_latents[start:end], dim=0),
                            t, t_prev,
                            cond=torch.cat(dilated_conds[start:end], dim=0),
                            neg_cond=torch.cat(dilated_neg_conds[start:end], dim=0),
                            cfg_strength=cfg_strength,
                            cfg_interval=cfg_interval,
                        )

                        # Scatter predictions back through the same index sets.
                        for sample, pred_v in zip(dilated_samples[start:end], out.pred_v):
                            count[:, :, sample[:, 0], sample[:, 1], :] += 1
                            value[:, :, sample[:, 0], sample[:, 1], :] += pred_v.view(
                                1, flow_model.in_channels, reso * reso, reso
                            )

                    global_pred_v = torch.where(count > 0, value / count, latent)

                    # Cosine-annealed local/global blend.
                    v = local_pred_v * (1 - c) + global_pred_v * c
                    v = v.detach()

                # Optimization happens *outside* the no_grad block so the
                # decoder forward builds a graph w.r.t. v.
                v.requires_grad_()
                v.retain_grad()
                optimizer = torch.optim.Adam([v], lr=0.1)

                if optim and t < 0.7:
                    for _ in range(20):
                        optimizer.zero_grad()
                        # x0-estimate from the flow-matching parameterization.
                        pred_latent = (1 - sigma_min) * latent - (sigma_min + (1 - sigma_min) * t) * v
                        decoded_latent = decoder(pred_latent)
                        # Permute back to the point cloud's axis order for the loss.
                        loss = sparse_structure_loss(pc, decoded_latent.permute(0, 1, 3, 2, 4))
                        loss.backward()
                        optimizer.step()

                # Euler step of the flow ODE.
                latent = (latent - (t - t_prev) * v).detach()

                if progress_callback is not None:
                    global_step += 1
                    progress_callback(
                        global_step / total_steps,
                        f"Sparse Structure: iter {it + 1}/{iterations}, step {global_step}/{total_steps}",
                    )

            # Re-encode the binarized decode as the starting occupancy for the
            # next outer iteration.
            voxel = (decoder(latent) > 0).float()
            encoded_voxel = encoder(voxel)

        # Occupied cells; drop the channel axis, keep (batch, y, x, z).
        coords = torch.argwhere(decoder(latent) > 0)[:, [0, 2, 3, 4]].int()
        return coords
|
|
| |
| |
| |
|
|
    def sample_slat(
        self,
        image: Image.Image,
        coords: torch.Tensor,
        pointmap_info: PointmapInfo = None,
        optim: bool = True,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        steps: int = 25,
        rescale_t: float = 3.0,
        cfg_strength: float = 3.0,
        batch_size: int = 1,
        progress_callback=None,
    ) -> SparseTensor:
        """
        Sample per-voxel latent features (SLAT) via flow-matching.

        Per denoising step the sparse features are predicted patch-wise over
        the views from ``get_views`` and averaged where views overlap.  When
        ``optim`` is on (and t < 0.8), the velocity is refined by Adam against
        a rendering loss (LPIPS minus SSIM) comparing a Gaussian-splat render
        to the 512x512 input image.

        Args:
            coords: occupied voxel coords, int tensor [N, 4] (batch, y, x, z),
                as produced by sample_sparse_structure.
            progress_callback: optional (fraction, message) callback; when
                given, the tqdm bar is suppressed.

        Returns:
            slat: SparseTensor with denormalized latent features.
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)

        # Reference image as a [3, 512, 512] float tensor in [0, 1].
        resized_image = image.resize((512, 512))
        tensor_image = (torch.from_numpy(np.array(resized_image))
                        .permute(2, 0, 1).float() / 255.0).to(self.device)

        # Camera used to render the optimization view.
        intrinsic = torch.tensor(pointmap_info.camera_intrinsic(), dtype=torch.float32).to(self.device)
        extrinsic = torch.tensor(pointmap_info.camera_extrinsic(), dtype=torch.float32).to(self.device)

        flow_model = self.models['slat_flow_model']
        sampler = self.pipeline.slat_sampler
        cfg_interval = self.pipeline.slat_sampler_params['cfg_interval']
        cond = self.get_cond(image, pointmap_info, width, length, div)

        sigma_min = sampler.sigma_min
        reso = flow_model.resolution

        # One feature row per occupied voxel, initialized with noise.
        latent_feats = torch.randn(coords.shape[0], flow_model.in_channels, device=self.device)

        # Precompute, per view, the indices of voxels that fall inside it;
        # views containing no voxels are dropped up front.
        views = get_views(width, length, reso, div)
        valid_views = []
        patch_indices = []
        for i, j, y0, y1, x0, x1 in views:
            idx = torch.where(
                (coords[:, 1] >= y0) & (coords[:, 1] < y1) &
                (coords[:, 2] >= x0) & (coords[:, 2] < x1)
            )[0]
            if len(idx) > 0:
                valid_views.append((i, j, y0, y1, x0, x1))
                patch_indices.append(idx)

        # Per-voxel accumulators for averaging overlapping view predictions.
        count = torch.zeros(coords.shape[0], flow_model.in_channels, device=self.device)
        value = torch.zeros(coords.shape[0], flow_model.in_channels, device=self.device)

        _, t_pairs = schedule(steps, rescale_t)
        total_steps = len(t_pairs)

        step_iter = (tqdm(t_pairs, desc="Structured Latent Sampling")
                     if progress_callback is None else t_pairs)
        for slat_step, (t, t_prev) in enumerate(step_iter, start=1):
            with torch.no_grad():
                count.zero_()
                value.zero_()

                # Build one local SparseTensor per valid view, with
                # coordinates shifted into the view's frame.
                patch_latents = []
                patch_conds = []
                for view, patch_index in zip(valid_views, patch_indices):
                    i, j, y0, y1, x0, x1 = view
                    patch_conds.append(cond[i][j])

                    patch_coords_local = coords[patch_index].clone()
                    patch_coords_local[:, 1] -= y0
                    patch_coords_local[:, 2] -= x0
                    patch_latents.append(SparseTensor(
                        feats=latent_feats[patch_index].contiguous(),
                        coords=patch_coords_local,
                    ))

                # Batched prediction over views.
                for start in range(0, len(patch_latents), batch_size):
                    end = min(start + batch_size, len(patch_latents))

                    # Concatenate each conditioning key across the chunk.
                    conds_chunk = patch_conds[start:end]
                    batched_cond = {
                        k: torch.cat([d[k] for d in conds_chunk], dim=0)
                        for k in conds_chunk[0].keys()
                    }
                    outs = sampler.sample_once(
                        flow_model,
                        sparse_cat(patch_latents[start:end]),
                        t, t_prev,
                        cfg_strength=cfg_strength,
                        cfg_interval=cfg_interval,
                        **batched_cond,
                    )

                    # Scatter-average the per-view predictions back onto the
                    # global voxel set.
                    for out, pidx in zip(sparse_unbind(outs.pred_v, dim=0), patch_indices[start:end]):
                        count[pidx, :] += 1
                        value[pidx, :] += out.feats

                # Voxels not covered by any view keep their current features.
                v_feats = torch.where(count > 0, value / count, latent_feats).detach()

            # Refinement runs outside no_grad so rendering builds a graph
            # w.r.t. v_feats (note the higher lr than the sparse-structure
            # stage: 0.3 vs 0.1).
            v_feats.requires_grad_()
            optimizer = torch.optim.Adam([v_feats], lr=0.3)

            if optim and t < 0.8:
                for _ in range(20):
                    optimizer.zero_grad()

                    # x0-estimate, then denormalize into a decodable SLAT.
                    pred_feats = (1 - sigma_min) * latent_feats - (sigma_min + (1 - sigma_min) * t) * v_feats
                    pred_slat = SparseTensor(feats=pred_feats, coords=coords) * self.std + self.mean

                    # Render the Gaussian decode from the estimated camera.
                    # permute(2, 1, 0): NOTE(review) — appears to transpose the
                    # rendered HWC frame into CHW while also swapping H and W;
                    # confirm against render_frames_torch's output layout.
                    rendered = render_utils.render_frames_torch(
                        self.decode_slat(pred_slat, width, length, formats=['gaussian'])['gaussian'][0],
                        [extrinsic], [intrinsic],
                        {'resolution': 512, 'bg_color': (0, 0, 0)},
                        verbose=False,
                    )['color'][0].permute(2, 1, 0)

                    # Minimize LPIPS, maximize SSIM.
                    loss = (self.lpips(rendered.unsqueeze(0), tensor_image.unsqueeze(0))
                            - self.ssim(rendered.unsqueeze(0), tensor_image.unsqueeze(0)))
                    loss.backward()
                    optimizer.step()

            # Euler step of the flow ODE.
            latent_feats = (latent_feats - (t - t_prev) * v_feats).detach()

            if progress_callback is not None:
                progress_callback(slat_step / total_steps,
                                  f"SLAT Sampling: step {slat_step}/{total_steps}")

        # Denormalize before returning.
        slat = SparseTensor(feats=latent_feats, coords=coords)
        return slat * self.std + self.mean
|
|
| |
| |
| |
|
|
    def decode_slat(
        self,
        slat: SparseTensor,
        width: int,
        length: int,
        # NOTE(review): mutable default argument — harmless here because the
        # list is only read, but consider a tuple/None-sentinel.
        formats: list[str] = ['gaussian', 'mesh'],
    ) -> dict:
        """
        Decode a structured latent into Gaussian splats and/or a triangle mesh.

        The latent is decoded patch-by-patch (the decoders were trained on a
        single reso^3 tile); patch outputs are stitched back into the full
        width x length scene and rescaled into a unit-centered frame.

        Args:
            slat: denormalized SLAT (feats [N, C], coords [N, 4]).
            width, length: horizontal tiling of the latent volume.
            formats: any subset of {'gaussian', 'mesh'}.

        Returns:
            dict with keys from *formats*: 'mesh' -> [MeshExtractResult] (or
            [] when no cubes are occupied), 'gaussian' -> list of merged
            Gaussian representations.
        """
        ret = {}
        feats = slat.feats
        coords = slat.coords
        reso = self.models['slat_flow_model'].resolution
        scale = max(width, length)

        # ------------------------------------------------------------------
        # Mesh branch: decode per-patch high-res features, blend them on a
        # dense grid, then run the (FlexiCubes-style) mesh extractor once.
        # ------------------------------------------------------------------
        if 'mesh' in formats:
            mesh_decoder = self.pipeline.models['slat_decoder_mesh']
            sf2m = mesh_decoder.mesh_extractor

            # High-res grid: 4x the decoder resolution per tile.
            up_res = mesh_decoder.resolution * 4
            res_y, res_x, res_z = width * up_res, length * up_res, up_res

            # Dense accumulators for weighted feature blending.
            C = sf2m.feats_channels
            global_sum = torch.zeros(res_y, res_x, res_z, C, device=self.device)
            global_count = torch.zeros(res_y, res_x, res_z, 1, device=self.device)

            for _, _, y_start, y_end, x_start, x_end in get_views(width, length, reso, 4):
                patch_index = torch.where(
                    (coords[:, 1] >= y_start) & (coords[:, 1] < y_end) &
                    (coords[:, 2] >= x_start) & (coords[:, 2] < x_end)
                )[0]
                if len(patch_index) == 0:
                    continue

                # Shift coords into the patch frame and decode.
                patch_coords = coords[patch_index].clone()
                patch_coords[:, 1] -= y_start
                patch_coords[:, 2] -= x_start

                patch_latent = SparseTensor(
                    feats=feats[patch_index].contiguous(),
                    coords=patch_coords,
                )
                patch_hr = mesh_decoder.forward_features(patch_latent)

                # Cosine window over the patch's horizontal extent: weight
                # peaks at the patch center and falls to 0 at its borders, so
                # overlapping patches blend smoothly.
                hr_coords = patch_hr.coords[:, 1:].clone()
                patch_size = float(4 * reso)
                cos_w = (torch.cos(torch.pi * (hr_coords[:, 0].float() / patch_size - 0.5))
                         * torch.cos(torch.pi * (hr_coords[:, 1].float() / patch_size - 0.5))
                         ).unsqueeze(1)

                # Back to global grid coordinates (4x upsampled offsets).
                hr_coords[:, 0] = (hr_coords[:, 0] + 4 * y_start).clamp(0, res_y - 1)
                hr_coords[:, 1] = (hr_coords[:, 1] + 4 * x_start).clamp(0, res_x - 1)
                hr_coords[:, 2] = hr_coords[:, 2].clamp(0, res_z - 1)

                gy, gx, gz = hr_coords[:, 0], hr_coords[:, 1], hr_coords[:, 2]
                global_sum[gy, gx, gz] += patch_hr.feats * cos_w
                global_count[gy, gx, gz] += cos_w

            # Weighted average where any patch contributed.
            occupied = global_count[..., 0] > 0
            global_sum[occupied] /= global_count[occupied]

            if occupied.any():
                occ_coords = torch.argwhere(occupied)
                occ_feats = global_sum[occ_coords[:, 0], occ_coords[:, 1], occ_coords[:, 2]]

                # Split the blended features into the extractor's layouts.
                sdf = sf2m.get_layout(occ_feats, 'sdf') + sf2m.sdf_bias
                deform = sf2m.get_layout(occ_feats, 'deform')
                color = sf2m.get_layout(occ_feats, 'color')
                weights = sf2m.get_layout(occ_feats, 'weights')

                v_attrs_cat = (torch.cat([sdf, deform, color], dim=-1)
                               if sf2m.use_color else torch.cat([sdf, deform], dim=-1))

                # Per-cube attributes -> per-vertex attributes.
                v_pos, v_attrs, _ = sparse_cube2verts(occ_coords, v_attrs_cat, training=False)

                # Dense vertex grid (one more vertex than cubes per axis);
                # default SDF of 1.0 marks empty space as "outside".
                res_vy, res_vx, res_vz = res_y + 1, res_x + 1, res_z + 1
                v_attrs_d = torch.zeros(res_vy * res_vx * res_vz, v_attrs.shape[-1], device=self.device)
                v_attrs_d[:, 0] = 1.0

                vert_ids = v_pos[:, 0] * res_vx * res_vz + v_pos[:, 1] * res_vz + v_pos[:, 2]
                v_attrs_d[vert_ids] = v_attrs

                sdf_d = v_attrs_d[:, 0]
                deform_d = v_attrs_d[:, 1:4]
                colors_d = v_attrs_d[:, 4:] if sf2m.use_color else None

                # Dense per-cube extraction weights.
                weights_d = torch.zeros(res_y * res_x * res_z, weights.shape[-1], device=self.device)
                cube_ids = occ_coords[:, 0] * res_x * res_z + occ_coords[:, 1] * res_z + occ_coords[:, 2]
                weights_d[cube_ids] = weights

                # Regular vertex lattice, flattened to [V, 3].
                ay, ax, az = (torch.arange(r, device=self.device, dtype=torch.float)
                              for r in (res_vy, res_vx, res_vz))
                gy, gx, gz = torch.meshgrid(ay, ax, az, indexing='ij')
                reg_v = torch.stack([gy.flatten(), gx.flatten(), gz.flatten()], dim=1)

                # Normalize vertex positions into a frame centered on (0.5,
                # 0.5, 0) horizontally, then apply the bounded deformation
                # (tanh keeps each vertex inside half a cell).
                norm_val = scale * up_res
                norm_t = torch.tensor([norm_val, norm_val, norm_val], device=self.device, dtype=torch.float)
                offset_t = torch.tensor([0.5, 0.5, 0.0], device=self.device, dtype=torch.float)
                x_nx3 = reg_v / norm_t - offset_t + (1 - 1e-8) / (norm_t * 2) * torch.tanh(deform_d)

                # For every cube, the flat ids of its 8 corner vertices.
                cy, cx, cz = (torch.arange(r, device=self.device) for r in (res_y, res_x, res_z))
                gy, gx, gz = torch.meshgrid(cy, cx, cz, indexing='ij')
                cc = torch.tensor(
                    [[0,0,0],[1,0,0],[0,1,0],[1,1,0],[0,0,1],[1,0,1],[0,1,1],[1,1,1]],
                    dtype=torch.long, device=self.device,
                )
                reg_c = ((gy.flatten().unsqueeze(1) + cc[:, 0]) * res_vx * res_vz
                         + (gx.flatten().unsqueeze(1) + cc[:, 1]) * res_vz
                         + (gz.flatten().unsqueeze(1) + cc[:, 2]))

                # Single dense extraction over the whole scene.
                vertices, faces, _, colors = sf2m.mesh_extractor(
                    voxelgrid_vertices=x_nx3,
                    scalar_field=sdf_d,
                    cube_idx=reg_c,
                    resolution=[res_y, res_x, res_z],
                    beta=weights_d[:, :12],
                    alpha=weights_d[:, 12:20],
                    gamma_f=weights_d[:, 20],
                    voxelgrid_colors=colors_d,
                    training=False,
                )
                ret['mesh'] = [MeshExtractResult(
                    vertices=vertices,
                    faces=faces,
                    vertex_attrs=colors,
                    res=max(res_y, res_x, res_z),
                )]
            else:
                ret['mesh'] = []

        # ------------------------------------------------------------------
        # Gaussian branch: decode per-tile, translate each tile's splats into
        # scene space, merge, prune oversized splats, and rescale.
        # ------------------------------------------------------------------
        if 'gaussian' in formats:
            gs_decoder = self.pipeline.models['slat_decoder_gs']

            # all_patch_lists[k] collects, across tiles, the k-th Gaussian
            # object returned by the decoder (the decoder's batch dimension).
            # NOTE(review): stays None when no tile has voxels — the merge
            # loop below would then raise; confirm empty coords cannot reach here.
            all_patch_lists: list | None = None
            for i in range(width):
                for j in range(length):
                    y0, y1 = i * reso, (i + 1) * reso
                    x0, x1 = j * reso, (j + 1) * reso

                    patch_index = torch.where(
                        (coords[:, 1] >= y0) & (coords[:, 1] < y1) &
                        (coords[:, 2] >= x0) & (coords[:, 2] < x1)
                    )[0]
                    if len(patch_index) == 0:
                        continue

                    patch_coords = coords[patch_index].clone()
                    patch_coords[:, 1] -= y0
                    patch_coords[:, 2] -= x0

                    patch_latent = SparseTensor(
                        feats=feats[patch_index].contiguous(),
                        coords=patch_coords,
                    )
                    patch_gaussians = gs_decoder(patch_latent)

                    # Translate tile-local splats into scene coordinates
                    # (tile centers at half-integer offsets).
                    offset = torch.tensor([[i + 0.5, j + 0.5, 0.5]], device=self.device)
                    for g in patch_gaussians:
                        g._xyz = g._xyz + offset

                    if all_patch_lists is None:
                        all_patch_lists = [[g] for g in patch_gaussians]
                    else:
                        for k, g in enumerate(patch_gaussians):
                            all_patch_lists[k].append(g)

            # Merge tiles by concatenating all per-splat attribute tensors
            # into the first tile's object.
            merged_gaussians = []
            for gs_list in all_patch_lists:
                g0 = gs_list[0]
                if len(gs_list) > 1:
                    g0._features_dc = torch.cat([g._features_dc for g in gs_list], dim=0)
                    g0._opacity = torch.cat([g._opacity for g in gs_list], dim=0)
                    g0._rotation = torch.cat([g._rotation for g in gs_list], dim=0)
                    g0._scaling = torch.cat([g._scaling for g in gs_list], dim=0)
                    g0._xyz = torch.cat([g._xyz for g in gs_list], dim=0)
                merged_gaussians.append(g0)

            # Prune splats whose scale norm exceeds 0.03 (outlier blobs).
            for g in merged_gaussians:
                scale_norm = torch.sum(g.get_scaling ** 2, dim=1) ** 0.5
                keep = torch.where(scale_norm < 0.03)[0]
                g._features_dc = g._features_dc[keep]
                g._opacity = g._opacity[keep]
                g._rotation = g._rotation[keep]
                g._scaling = g._scaling[keep]
                g._xyz = g._xyz[keep]

            # Rescale the merged scene by 1/scale and recenter horizontally;
            # clamp scaling to just above the (rescaled) minimum kernel size.
            eps = 1e-4
            center_offset = torch.tensor([[0.5, 0.5, 0.0]], device=self.device)
            for g in merged_gaussians:
                g.from_xyz(g.get_xyz / scale)
                g._xyz -= center_offset
                g.mininum_kernel_size /= scale
                g.from_scaling(torch.max(
                    g.get_scaling / scale,
                    torch.tensor(g.mininum_kernel_size * (1 + eps), device=self.device),
                ))

            ret['gaussian'] = merged_gaussians

        return ret
|
|
| |
| |
| |
|
|
| def run( |
| self, |
| image: Image.Image, |
| |
| width: int = 2, |
| length: int = 2, |
| div: int = 2, |
| |
| ss_optim: bool = True, |
| ss_iterations: int = 3, |
| ss_steps: int = 25, |
| ss_rescale_t: float = 3.0, |
| ss_t_noise: float = 0.6, |
| ss_t_start: float = 0.8, |
| ss_cfg_strength: float = 7.5, |
| ss_alpha: float = 5.0, |
| ss_batch_size: int = 1, |
| |
| slat_optim: bool = True, |
| slat_steps: int = 25, |
| slat_rescale_t: float = 3.0, |
| slat_cfg_strength: float = 3.0, |
| slat_batch_size: int = 1, |
| |
| formats: list = ['gaussian', 'mesh'], |
| return_pointmap: bool = False, |
| progress_callback=None, |
| ) -> dict: |
| """Run the full Extend3D pipeline: SS sampling → SLAT sampling → decode.""" |
| pointmap_info = PointmapInfoMoGe(image, device=self.device) |
|
|
| coords = self.sample_sparse_structure( |
| image, pointmap_info, ss_optim, width, length, div, |
| iterations=ss_iterations, |
| steps=ss_steps, |
| rescale_t=ss_rescale_t, |
| t_noise=ss_t_noise, |
| t_start=ss_t_start, |
| cfg_strength=ss_cfg_strength, |
| alpha=ss_alpha, |
| batch_size=ss_batch_size, |
| progress_callback=progress_callback, |
| ).detach() |
|
|
| slat = self.sample_slat( |
| image, coords, pointmap_info, slat_optim, |
| width, length, div, |
| steps=slat_steps, |
| rescale_t=slat_rescale_t, |
| cfg_strength=slat_cfg_strength, |
| batch_size=slat_batch_size, |
| progress_callback=progress_callback, |
| ) |
|
|
| with torch.no_grad(): |
| decoded = self.decode_slat(slat, width, length, formats=formats) |
|
|
| if return_pointmap: |
| return decoded, pointmap_info |
| return decoded |
|
|