# Extend3D / example.py
# Seungwoo-Yoon
# initial commit for HF space
# a68e3ed
from extend3d import Extend3D
from trellis.utils import render_utils, postprocessing_utils
import imageio
import os
import argparse
from PIL import Image
def main(args):
    """Run Extend3D on one image and write the results to ``args.output_dir``.

    Builds an ``Extend3D`` pipeline from the ``'microsoft/TRELLIS-image-large'``
    checkpoint, moves it to CUDA, runs it on the RGB version of
    ``args.image_path`` with the sampler settings taken from ``args``, and
    writes three files into ``args.output_dir`` (created if missing):

    * ``sample.ply`` -- the first Gaussian result, via ``save_ply``;
    * ``sample.mp4`` -- the ``'color'`` frames returned by
      ``render_utils.render_video`` (``r=1.6``, ``resolution=1024``),
      written at 30 fps;
    * ``sample.glb`` -- the first Gaussian and mesh results combined by
      ``postprocessing_utils.to_glb`` (``simplify=0.9``,
      ``texture_size=1024``).

    Args:
        args: Parsed command-line namespace. Must provide ``image_path``,
            ``output_dir``, ``width``, ``length``, ``div``,
            ``skip_ss_optim``, ``skip_slat_optim`` and every ``ss_*`` /
            ``slat_*`` attribute read below.
    """
    # Create the output directory before loading and running the pipeline so
    # an unwritable path is reported immediately rather than after generation.
    os.makedirs(args.output_dir, exist_ok=True)

    pipeline = Extend3D.from_pretrained('microsoft/TRELLIS-image-large')
    pipeline = pipeline.cuda()

    # convert() returns a new, loaded image, so the source file can be closed
    # deterministically as soon as the conversion is done.
    with Image.open(args.image_path) as source_image:
        image = source_image.convert('RGB')

    output = pipeline.run(
        image=image,
        width=args.width,
        length=args.length,
        div=args.div,
        # The CLI exposes "skip" switches; the pipeline expects "enable" flags.
        ss_optim=not args.skip_ss_optim,
        ss_iterations=args.ss_iterations,
        ss_steps=args.ss_steps,
        ss_rescale_t=args.ss_rescale_t,
        ss_t_noise=args.ss_t_noise,
        ss_t_start=args.ss_t_start,
        ss_cfg_strength=args.ss_cfg_strength,
        ss_alpha=args.ss_alpha,
        ss_batch_size=args.ss_batch_size,
        slat_optim=not args.skip_slat_optim,
        slat_steps=args.slat_steps,
        slat_rescale_t=args.slat_rescale_t,
        slat_cfg_strength=args.slat_cfg_strength,
        slat_batch_size=args.slat_batch_size,
        formats=['gaussian', 'mesh'],
    )

    # Only the first result of each requested format is exported.
    gaussian = output['gaussian'][0]

    gaussian.save_ply(os.path.join(args.output_dir, 'sample.ply'))

    video = render_utils.render_video(gaussian, r=1.6, resolution=1024)['color']
    imageio.mimsave(os.path.join(args.output_dir, 'sample.mp4'), video, fps=30)

    glb = postprocessing_utils.to_glb(
        gaussian,
        output['mesh'][0],
        simplify=0.9,
        texture_size=1024,
    )
    glb.export(os.path.join(args.output_dir, 'sample.glb'))
def build_parser():
    """Return the command-line parser for this example script.

    ``main`` uses ``--image-path`` to load the input image and
    ``--output_dir`` as the destination folder. Every other option is passed
    to ``pipeline.run`` under the same name. The exceptions are
    ``--skip-ss-optim`` and ``--skip-slat-optim``, which are inverted into
    ``ss_optim`` and ``slat_optim``.

    Flag spelling is mixed on purpose: hyphens for ``--image-path`` and the
    ``--skip-*`` switches, underscores elsewhere. It is kept unchanged so
    existing command lines keep working.
    """
    parser = argparse.ArgumentParser(
        description='Run Extend3D on a single image and save sample.ply, '
                    'sample.mp4 and sample.glb to the output directory.',
        # Show each option's default value in --help.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--image-path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--width', type=int, default=2,
                        help='passed to pipeline.run as width')
    parser.add_argument('--length', type=int, default=2,
                        help='passed to pipeline.run as length')
    parser.add_argument('--div', type=int, default=4,
                        help='passed to pipeline.run as div')

    ss_group = parser.add_argument_group(
        'ss options', 'Forwarded to pipeline.run as the matching ss_* keyword.')
    ss_group.add_argument('--skip-ss-optim', action='store_true',
                          help='call pipeline.run with ss_optim=False')
    ss_group.add_argument('--ss_iterations', type=int, default=3,
                          help='passed to pipeline.run as ss_iterations')
    ss_group.add_argument('--ss_steps', type=int, default=25,
                          help='passed to pipeline.run as ss_steps')
    ss_group.add_argument('--ss_rescale_t', type=float, default=5.0,
                          help='passed to pipeline.run as ss_rescale_t')
    ss_group.add_argument('--ss_t_noise', type=float, default=0.6,
                          help='passed to pipeline.run as ss_t_noise')
    ss_group.add_argument('--ss_t_start', type=float, default=0.8,
                          help='passed to pipeline.run as ss_t_start')
    ss_group.add_argument('--ss_cfg_strength', type=float, default=7.5,
                          help='passed to pipeline.run as ss_cfg_strength')
    ss_group.add_argument('--ss_alpha', type=float, default=5.0,
                          help='passed to pipeline.run as ss_alpha')
    ss_group.add_argument('--ss_batch_size', type=int, default=1,
                          help='passed to pipeline.run as ss_batch_size')

    slat_group = parser.add_argument_group(
        'slat options', 'Forwarded to pipeline.run as the matching slat_* keyword.')
    slat_group.add_argument('--skip-slat-optim', action='store_true',
                            help='call pipeline.run with slat_optim=False')
    slat_group.add_argument('--slat_steps', type=int, default=25,
                            help='passed to pipeline.run as slat_steps')
    slat_group.add_argument('--slat_rescale_t', type=float, default=3.0,
                            help='passed to pipeline.run as slat_rescale_t')
    slat_group.add_argument('--slat_cfg_strength', type=float, default=3.0,
                            help='passed to pipeline.run as slat_cfg_strength')
    slat_group.add_argument('--slat_batch_size', type=int, default=1,
                            help='passed to pipeline.run as slat_batch_size')

    parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save the output files')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())