File size: 2,526 Bytes
54c1208 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 | from __future__ import annotations
from typing import Any
import torch
from scripts.image_process import prepare_image
def _call_pipeline(
    pipe: Any, image, seed: int, num_inference_steps: int, guidance_scale: float
):
    """Invoke ``pipe`` through whichever entry point accepts the standard kwargs.

    Entry points are tried in this order:

    1. ``pipe(...)`` itself, if ``pipe`` is callable.
    2. ``pipe.run``, ``pipe.generate``, ``pipe.infer`` and ``pipe.sample``.

    Each is called with the keyword arguments ``image``,
    ``num_inference_steps``, ``guidance_scale`` and ``generator``. A
    ``TypeError`` from a call is treated as "signature not supported", and the
    next entry point is tried.

    Args:
        pipe: Pipeline object of unknown concrete type.
        image: Preprocessed image, forwarded unchanged to the pipeline.
        seed: Seed for the (default-device) ``torch.Generator`` handed to the
            pipeline.
        num_inference_steps: Converted with ``int`` before forwarding.
        guidance_scale: Converted with ``float`` before forwarding.

    Returns:
        Whatever the first accepting entry point returns.

    Raises:
        TypeError, ValueError: If ``seed``, ``num_inference_steps`` or
            ``guidance_scale`` cannot be converted. These are raised before any
            pipeline call, so bad arguments are not mistaken for an interface
            mismatch.
        RuntimeError: If no entry point accepts the call. The last
            ``TypeError`` seen (if any) is attached as ``__cause__``, because it
            may have come from inside the pipeline rather than from a
            signature mismatch.
    """
    # Convert once, outside the try block below, so conversion errors surface
    # directly instead of being swallowed as "wrong signature".
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)

    entry_points = []
    if callable(pipe):
        entry_points.append(pipe)
    for method_name in ("run", "generate", "infer", "sample"):
        method = getattr(pipe, method_name, None)
        if method is not None:
            entry_points.append(method)

    last_error = None
    for entry_point in entry_points:
        # TripoSG uses torch.Generator instead of a seed parameter. Creating a
        # fresh generator per attempt keeps results reproducible even if a
        # rejected attempt drew random numbers before raising.
        generator = torch.Generator().manual_seed(seed)
        try:
            # Inference only: avoid building an autograd graph.
            with torch.no_grad():
                return entry_point(
                    image=image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                )
        except TypeError as err:
            last_error = err
    raise RuntimeError(
        "Unsupported TripoSG pipeline interface in scripts/inference_triposg.py"
    ) from last_error
def run_triposg(
    pipe: Any,
    image_input: str,
    rmbg_net: Any,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    faces: int = -1,
):
    """Generate a mesh from an image with a TripoSG-style pipeline.

    The steps are:

    1. Preprocess the image with ``prepare_image``, using a mid-grey
       background and ``rmbg_net``.
    2. Run the pipeline through ``_call_pipeline``.
    3. Unwrap the mesh from the pipeline result.
    4. Optionally decimate the mesh.

    Args:
        pipe: Pipeline object; ``_call_pipeline`` decides which entry point is
            used.
        image_input: Image reference forwarded to ``prepare_image``. It is
            annotated as ``str``; presumably a file path — confirm against
            ``prepare_image``.
        rmbg_net: Background-removal model forwarded to ``prepare_image``.
        seed: Random seed forwarded to ``_call_pipeline``.
        num_inference_steps: Step count forwarded to ``_call_pipeline``.
        guidance_scale: Guidance scale forwarded to ``_call_pipeline``.
        faces: Target face count for quadric decimation. Any non-positive
            value (default ``-1``), a ``bool``, or a non-integer disables
            simplification.

    Returns:
        The mesh, taken from the first of these that applies:

        - ``result.mesh`` if present;
        - the first element of ``result.meshes`` when it is a list/tuple, or
          ``result.meshes`` itself otherwise;
        - the raw pipeline result.

        It is decimated when ``faces`` > 0 and the object has
        ``simplify_quadric_decimation``. If decimation fails, a
        ``RuntimeWarning`` is emitted and the undecimated mesh is returned.

    Raises:
        RuntimeError: Propagated from ``_call_pipeline`` when no pipeline entry
            point accepts the arguments.
    """
    import numbers
    import warnings

    # Mid-grey (0.5 per channel) background as a float32 NumPy array, built via
    # torch, which this module already imports.
    bg_color = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32).numpy()
    img = prepare_image(image_input, bg_color=bg_color, rmbg_net=rmbg_net)
    result = _call_pipeline(
        pipe, img, int(seed), int(num_inference_steps), float(guidance_scale)
    )

    # Extract the mesh from the pipeline output object.
    # NOTE(review): confirm which attribute the installed TripoSG pipeline
    # output actually exposes. Results with neither ``mesh`` nor ``meshes``
    # are returned unchanged.
    if hasattr(result, "mesh"):
        mesh = result.mesh
    elif hasattr(result, "meshes"):
        meshes = result.meshes
        mesh = meshes[0] if isinstance(meshes, (list, tuple)) else meshes
    else:
        # Assume result is already a mesh.
        mesh = result

    # bool is an Integral subclass, so exclude it explicitly. Integral also
    # admits NumPy integer scalars, which a plain ``int`` check rejected.
    wants_simplify = (
        isinstance(faces, numbers.Integral)
        and not isinstance(faces, bool)
        and faces > 0
    )
    if wants_simplify and hasattr(mesh, "simplify_quadric_decimation"):
        try:
            mesh = mesh.simplify_quadric_decimation(face_count=int(faces))
        except Exception as err:
            # Best effort: keep the unsimplified mesh, but make the failure
            # visible instead of silently ignoring it.
            warnings.warn(
                f"Mesh simplification to {int(faces)} faces failed; "
                f"returning the unsimplified mesh ({err!r})",
                RuntimeWarning,
                stacklevel=2,
            )
    return mesh
|