File size: 3,005 Bytes
a68e3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from extend3d import Extend3D
from trellis.utils import render_utils, postprocessing_utils

import imageio
import os
import argparse
from PIL import Image

def main(args):
    pipeline = Extend3D.from_pretrained('microsoft/TRELLIS-image-large')
    pipeline = pipeline.cuda()

    image = Image.open(args.image_path).convert('RGB')

    output = pipeline.run(
        image=image,
        width=args.width,
        length=args.length,
        div=args.div,

        ss_optim=not args.skip_ss_optim,
        ss_iterations=args.ss_iterations,
        ss_steps=args.ss_steps,
        ss_rescale_t=args.ss_rescale_t,
        ss_t_noise=args.ss_t_noise,
        ss_t_start=args.ss_t_start,
        ss_cfg_strength=args.ss_cfg_strength,
        ss_alpha=args.ss_alpha,
        ss_batch_size=args.ss_batch_size,

        slat_optim=not args.skip_slat_optim,
        slat_steps=args.slat_steps,
        slat_rescale_t=args.slat_rescale_t,
        slat_cfg_strength=args.slat_cfg_strength,
        slat_batch_size=args.slat_batch_size,

        formats=['gaussian', 'mesh'])

    os.makedirs(args.output_dir, exist_ok=True)
    
    output['gaussian'][0].save_ply(os.path.join(args.output_dir, 'sample.ply'))

    video = render_utils.render_video(output['gaussian'][0], r=1.6, resolution=1024)['color']
    imageio.mimsave(os.path.join(args.output_dir, 'sample.mp4'), video, fps=30)

    glb = postprocessing_utils.to_glb(
        output['gaussian'][0],
        output['mesh'][0],
        simplify=0.9,
        texture_size=1024,
    )
    glb.export(os.path.join(args.output_dir, 'sample.glb'))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    
    parser.add_argument('--image-path', type=str, required=True, help='Path to the input image')

    parser.add_argument('--width', type=int, default=2)
    parser.add_argument('--length', type=int, default=2)
    parser.add_argument('--div', type=int, default=4)

    parser.add_argument('--skip-ss-optim', action='store_true')
    parser.add_argument('--ss_iterations', type=int, default=3)
    parser.add_argument('--ss_steps', type=int, default=25)
    parser.add_argument('--ss_rescale_t', type=float, default=5.0)
    parser.add_argument('--ss_t_noise', type=float, default=0.6)
    parser.add_argument('--ss_t_start', type=float, default=0.8)
    parser.add_argument('--ss_cfg_strength', type=float, default=7.5)
    parser.add_argument('--ss_alpha', type=float, default=5.0)
    parser.add_argument('--ss_batch_size', type=int, default=1)

    parser.add_argument('--skip-slat-optim', action='store_true')
    parser.add_argument('--slat_steps', type=int, default=25)
    parser.add_argument('--slat_rescale_t', type=float, default=3.0)
    parser.add_argument('--slat_cfg_strength', type=float, default=3.0)
    parser.add_argument('--slat_batch_size', type=int, default=1)

    parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save the output files')

    args = parser.parse_args()

    main(args)