"""Command-line entry point: run Extend3D on a single image and export results.

Loads the ``microsoft/TRELLIS-image-large`` weights into an ``Extend3D``
pipeline, runs it on ``--image-path`` and writes three files into
``--output-dir``:

* ``sample.ply`` -- the generated Gaussians (``save_ply``),
* ``sample.mp4`` -- a video rendered from those Gaussians,
* ``sample.glb`` -- a mesh exported via ``postprocessing_utils.to_glb``.
"""
from extend3d import Extend3D
from trellis.utils import render_utils, postprocessing_utils
import imageio
import os
import argparse
from PIL import Image


def main(args):
    """Run the Extend3D pipeline on one image and save its outputs.

    Args:
        args: Namespace produced by :func:`build_parser` (any object exposing
            the same attributes works). Each ``ss_*``/``slat_*`` attribute is
            forwarded to ``Extend3D.run`` under the same keyword; the two
            ``skip_*`` flags are inverted into ``ss_optim``/``slat_optim``.
    """
    # Create the destination first so an unwritable path fails immediately
    # rather than after the model load and generation have already run.
    os.makedirs(args.output_dir, exist_ok=True)

    pipeline = Extend3D.from_pretrained('microsoft/TRELLIS-image-large')
    pipeline = pipeline.cuda()

    # Image.open is lazy and keeps the underlying file open; convert() returns
    # a new in-memory image, so the source file can be closed right away.
    with Image.open(args.image_path) as source:
        image = source.convert('RGB')

    output = pipeline.run(
        image=image,
        width=args.width,
        length=args.length,
        div=args.div,
        ss_optim=not args.skip_ss_optim,
        ss_iterations=args.ss_iterations,
        ss_steps=args.ss_steps,
        ss_rescale_t=args.ss_rescale_t,
        ss_t_noise=args.ss_t_noise,
        ss_t_start=args.ss_t_start,
        ss_cfg_strength=args.ss_cfg_strength,
        ss_alpha=args.ss_alpha,
        ss_batch_size=args.ss_batch_size,
        slat_optim=not args.skip_slat_optim,
        slat_steps=args.slat_steps,
        slat_rescale_t=args.slat_rescale_t,
        slat_cfg_strength=args.slat_cfg_strength,
        slat_batch_size=args.slat_batch_size,
        formats=['gaussian', 'mesh'],
    )

    # Only the first entry of each requested format is exported.
    gaussian = output['gaussian'][0]
    mesh = output['mesh'][0]

    gaussian.save_ply(os.path.join(args.output_dir, 'sample.ply'))

    video = render_utils.render_video(gaussian, r=1.6, resolution=1024)['color']
    imageio.mimsave(os.path.join(args.output_dir, 'sample.mp4'), video, fps=30)

    glb = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=0.9,
        texture_size=1024,
    )
    glb.export(os.path.join(args.output_dir, 'sample.glb'))


def _add_option(parser, name, **kwargs):
    """Register ``--name`` under both its dashed and underscored spellings.

    ``dest`` is pinned to ``name`` so the parsed attribute is the same no
    matter which spelling the user types (e.g. ``--ss-steps`` and
    ``--ss_steps`` both set ``args.ss_steps``).
    """
    dashed = '--' + name.replace('_', '-')
    underscored = '--' + name
    flags = [dashed] if dashed == underscored else [dashed, underscored]
    parser.add_argument(*flags, dest=name, **kwargs)


def build_parser():
    """Return the command-line parser for this script.

    Every option accepts both dashed and underscored spellings so the
    interface is consistent; the original spellings keep working.
    """
    parser = argparse.ArgumentParser(
        description='Run Extend3D on a single image and export PLY/MP4/GLB outputs.',
    )
    _add_option(parser, 'image_path', type=str, required=True,
                help='Path to the input image')
    _add_option(parser, 'width', type=int, default=2)
    _add_option(parser, 'length', type=int, default=2)
    _add_option(parser, 'div', type=int, default=4)

    _add_option(parser, 'skip_ss_optim', action='store_true',
                help='Pass ss_optim=False to Extend3D.run')
    _add_option(parser, 'ss_iterations', type=int, default=3)
    _add_option(parser, 'ss_steps', type=int, default=25)
    _add_option(parser, 'ss_rescale_t', type=float, default=5.0)
    _add_option(parser, 'ss_t_noise', type=float, default=0.6)
    _add_option(parser, 'ss_t_start', type=float, default=0.8)
    _add_option(parser, 'ss_cfg_strength', type=float, default=7.5)
    _add_option(parser, 'ss_alpha', type=float, default=5.0)
    _add_option(parser, 'ss_batch_size', type=int, default=1)

    _add_option(parser, 'skip_slat_optim', action='store_true',
                help='Pass slat_optim=False to Extend3D.run')
    _add_option(parser, 'slat_steps', type=int, default=25)
    _add_option(parser, 'slat_rescale_t', type=float, default=3.0)
    _add_option(parser, 'slat_cfg_strength', type=float, default=3.0)
    _add_option(parser, 'slat_batch_size', type=int, default=1)

    _add_option(parser, 'output_dir', type=str, default='./output',
                help='Directory to save the output files')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())