import gradio as gr
from gradio_client import Client, handle_file
import spaces
from concurrent.futures import ThreadPoolExecutor
import os

# Runtime flags are exported before cv2, torch and trellis2 are imported below.
# NOTE(review): presumably those libraries read these variables at import time,
# so the position of this section relative to the imports matters — confirm
# before reordering.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
# The autotune cache is stored next to this file.
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
from datetime import datetime
import shutil
import cv2
from typing import *
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel

# Patch postprocess module with local fix for cumesh.fill_holes() bug
import importlib.util
# A vendored copy at ./o-voxel/o_voxel/postprocess.py, when present, replaces
# the installed o_voxel.postprocess both as an attribute of the package and in
# sys.modules, so later `o_voxel.postprocess.*` calls resolve to the local file.
_local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
if os.path.exists(_local_postprocess):
    import sys
    _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
    _mod = importlib.util.module_from_spec(_spec)
    _spec.loader.exec_module(_mod)
    o_voxel.postprocess = _mod
    sys.modules['o_voxel.postprocess'] = _mod

# Upper bound of the seed slider and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Root of the per-session scratch directories (one sub-directory per session hash).
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# Preview render modes. `render_key` selects the matching entry of the render
# output in image_to_3d; `icon` is the button image loaded at startup.
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
# Number of preview views rendered per mode (passed as `nviews` to the renderer).
STEPS = 8
DEFAULT_MODE = 3 DEFAULT_STEP = 3 css = """ /* Overwrite Gradio Default Style */ .stepper-wrapper { padding: 0; } .stepper-container { padding: 0; align-items: center; } .step-button { flex-direction: row; } .step-connector { transform: none; } .step-number { width: 16px; height: 16px; } .step-label { position: relative; bottom: 0; } .wrap.center.full { inset: 0; height: 100%; } .wrap.center.full.translucent { background: var(--block-background-fill); } .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; } /* Previewer */ .previewer-container { position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; width: 100%; height: 722px; margin: 0 auto; padding: 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; } .previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; } .previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; } .previewer-container .tips-text p { font-size: 14px; line-height: 1.2; } .tips-icon:hover + .tips-text { display: block; opacity: 100%; } /* Row 1: Display Modes */ .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; } .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid var(--neutral-600, #555); object-fit: cover; } .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); } 
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); } /* Row 2: Display Image */ .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; } .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; } .previewer-container .previewer-main-image.visible { display: block; } /* Row 3: Custom HTML Slider */ .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; } .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; } .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: var(--neutral-700, #404040); border-radius: 5px; } .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; } .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); } /* Overwrite Previewer Block Style */ .gradio-container .padded:has(.previewer-container) { padding: 0 !important; } .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; } """ head = """ """ empty_html = f"""
""" def image_to_base64(image): buffered = io.BytesIO() image = image.convert("RGB") image.save(buffered, format="jpeg", quality=85) img_str = base64.b64encode(buffered.getvalue()).decode() return f"data:image/jpeg;base64,{img_str}" def start_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) os.makedirs(user_dir, exist_ok=True) def end_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) if os.path.exists(user_dir): shutil.rmtree(user_dir) def remove_background(input: Image.Image) -> Image.Image: try: with tempfile.NamedTemporaryFile(suffix='.png') as f: input = input.convert('RGB') input.save(f.name) output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0] output = Image.open(output) return output except Exception as e: raise gr.Error(f"Background removal failed: {e}. Please upload images with transparent backgrounds (RGBA), or try again later.") def preprocess_image(input: Image.Image) -> Image.Image: """ Preprocess the input image. 
""" # if has alpha channel, use it directly; otherwise, remove background has_alpha = False if input.mode == 'RGBA': alpha = np.array(input)[:, :, 3] if not np.all(alpha == 255): has_alpha = True max_size = max(input.size) scale = min(1, 1024 / max_size) if scale < 1: input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) if has_alpha: output = input else: output = remove_background(input) output_np = np.array(output) alpha = output_np[:, :, 3] bbox = np.argwhere(alpha > 0.8 * 255) if bbox.size == 0: # No visible pixels, center the image in a square size = max(output.size) square = Image.new('RGB', (size, size), (0, 0, 0)) output_rgb = output.convert('RGB') if output.mode == 'RGBA' else output square.paste(output_rgb, ((size - output.width) // 2, (size - output.height) // 2)) return square bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) size = int(size * 1) bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 output = output.crop(bbox) # type: ignore output_np = np.array(output).astype(np.float32) rgb = output_np[:, :, :3] alpha = output_np[:, :, 3:4] / 255.0 # Keep full RGB for visible pixels, zero out transparent background mask = (alpha > 0.05).astype(np.float32) rgb = rgb * mask output = Image.fromarray(rgb.astype(np.uint8)) return output def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict: shape_slat, tex_slat, res = latents return { 'shape_slat_feats': shape_slat.feats.cpu().numpy(), 'tex_slat_feats': tex_slat.feats.cpu().numpy(), 'coords': shape_slat.coords.cpu().numpy(), 'res': res, } def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]: shape_slat = SparseTensor( feats=torch.from_numpy(state['shape_slat_feats']).cuda(), coords=torch.from_numpy(state['coords']).cuda(), ) tex_slat = 
shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) return shape_slat, tex_slat, state['res'] def get_seed(randomize_seed, seed): """ Get the random seed. """ return np.random.randint(0, MAX_SEED) if randomize_seed else seed def prepare_multi_example() -> List[str]: """ Prepare multi-image examples. Returns list of image paths. Shows only the first view as representative thumbnail. """ multi_case = sorted(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")])) examples = [] for case in multi_case: first_img = f'assets/example_multi_image/{case}_1.png' if os.path.exists(first_img): examples.append(first_img) return examples def load_multi_example(image) -> List[Image.Image]: """Load all views for a multi-image case by matching the input image.""" if image is None: return [] # Convert to PIL Image if needed if isinstance(image, np.ndarray): image = Image.fromarray(image) # Convert to RGB for consistent comparison input_rgb = np.array(image.convert('RGB')) # Find matching case by comparing with first images example_dir = "assets/example_multi_image" case_names = sorted(set([f.rsplit('_', 1)[0] for f in os.listdir(example_dir) if f.endswith('.png')])) for case_name in case_names: first_img_path = f'{example_dir}/{case_name}_1.png' if os.path.exists(first_img_path): first_img = Image.open(first_img_path).convert('RGB') first_rgb = np.array(first_img) # Compare images (check if same shape and content) if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb): # Found match, load all views (without preprocessing - will be done on Generate) images = [] for i in range(1, 7): img_path = f'{example_dir}/{case_name}_{i}.png' if os.path.exists(img_path): img = Image.open(img_path).convert('RGBA') images.append(img) if images: return images # No match found, return the single image return [image.convert('RGBA') if image.mode != 'RGBA' else image] def split_image(image: Image.Image) -> List[Image.Image]: """ Split a 
concatenated image into multiple views. """ image = np.array(image) alpha = image[..., 3] alpha = np.any(alpha > 0, axis=0) start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist() end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist() images = [] for s, e in zip(start_pos, end_pos): images.append(Image.fromarray(image[:, s:e+1])) return [preprocess_image(image) for image in images] @spaces.GPU(duration=120) def image_to_3d( multiimages, seed, resolution, ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, multiimage_algo, tex_multiimage_algo, req: gr.Request, progress=gr.Progress(track_tqdm=True), ): if not multiimages: raise gr.Error("Please upload images or select an example first.") # Preprocess images (background removal for images without alpha) images = [image[0] for image in multiimages] processed_images = [preprocess_image(img) for img in images] # --- Sampling --- outputs, latents = pipeline.run_multi_image( processed_images, seed=seed, preprocess_image=False, sparse_structure_sampler_params={ "steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength, "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t, }, shape_slat_sampler_params={ "steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength, "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t, }, tex_slat_sampler_params={ "steps": tex_slat_sampling_steps, "guidance_strength": tex_slat_guidance_strength, "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t, }, pipeline_type={ "512": "512", "1024": "1024_cascade", "1536": "1536_cascade", }[resolution], return_latent=True, mode=multiimage_algo, tex_mode=tex_multiimage_algo, ) mesh = outputs[0] mesh.simplify(16777216) # 
nvdiffrast limit images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap) state = pack_state(latents) torch.cuda.empty_cache() # --- HTML Construction --- def encode_preview_image(args): m_idx, s_idx, render_key = args img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx])) return (m_idx, s_idx, img_base64) encode_tasks = [ (m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS) ] with ThreadPoolExecutor(max_workers=8) as executor: encoded_results = list(executor.map(encode_preview_image, encode_tasks)) encoded_map = {(m, s): b64 for m, s, b64 in encoded_results} images_html = "" for m_idx, mode in enumerate(MODES): for s_idx in range(STEPS): unique_id = f"view-m{m_idx}-s{s_idx}" is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP) vis_class = "visible" if is_visible else "" img_base64 = encoded_map[(m_idx, s_idx)] images_html += f""" """ btns_html = "" for idx, mode in enumerate(MODES): active_class = "active" if idx == DEFAULT_MODE else "" btns_html += f""" """ full_html = f"""
💡Tips

Render Mode - Click on the circular buttons to switch between different render modes.

View Angle - Drag the slider to change the view angle.

{images_html}
{btns_html}
""" return state, full_html @spaces.GPU(duration=120) def extract_glb( state, decimation_target, texture_size, req: gr.Request, progress=gr.Progress(track_tqdm=True), ): """ Extract a GLB file from the 3D model. Args: state (dict): The state of the generated 3D model. decimation_target (int): The target face count for decimation. texture_size (int): The texture resolution. Returns: Tuple[str, str]: The path to the extracted GLB file (for Model3D and DownloadButton). """ user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shape_slat, tex_slat, res = unpack_state(state) mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] mesh.simplify(16777216) # nvdiffrast limit glb = o_voxel.postprocess.to_glb( vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs, coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout, grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], decimation_target=decimation_target, texture_size=texture_size, remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True, ) now = datetime.now() timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" os.makedirs(user_dir, exist_ok=True) glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') glb.export(glb_path, extension_webp=True) torch.cuda.empty_cache() return glb_path, glb_path with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate"), css=css, head=head) as demo: gr.HTML("""
OpsiClear

Multi-View to 3D with TRELLIS.2

""") with gr.Row(): with gr.Column(scale=1, min_width=360): multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3, interactive=True) remove_img_btn = gr.Button("Remove Selected Image", size="sm", variant="secondary") resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024") seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000) texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) with gr.Accordion(label="Advanced Settings", open=False): gr.Markdown("Stage 1: Sparse Structure Generation") with gr.Row(): ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01) ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) gr.Markdown("Stage 2: Shape Generation") with gr.Row(): shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) gr.Markdown("Stage 3: Material Generation") with gr.Row(): tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) multiimage_algo = 
gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic") tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion") with gr.Column(scale=10): preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True) with gr.Row(): generate_btn = gr.Button("Generate", variant="primary") extract_btn = gr.Button("Extract GLB") glb_output = gr.Model3D(label="Extracted GLB", height=600, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0)) download_btn = gr.DownloadButton(label="Download GLB") with gr.Accordion(label="Examples", open=True): example_image = gr.Image(visible=False) # Hidden component for examples examples_multi = gr.Examples( examples=prepare_multi_example(), inputs=[example_image], fn=load_multi_example, outputs=[multiimage_prompt], run_on_click=True, cache_examples=False, examples_per_page=50, ) output_buf = gr.State() selected_img_idx = gr.State(value=None) # Handlers demo.load(start_session) demo.unload(end_session) def on_gallery_select(evt: gr.SelectData): return evt.index def remove_selected_image(images, idx): if images is None or idx is None or not images: return images, None images = list(images) if idx < len(images): images.pop(idx) return images, None multiimage_prompt.select(on_gallery_select, outputs=[selected_img_idx]) remove_img_btn.click( remove_selected_image, inputs=[multiimage_prompt, selected_img_idx], outputs=[multiimage_prompt, selected_img_idx], ) generate_btn.click( get_seed, inputs=[randomize_seed, seed], outputs=[seed], ).then( image_to_3d, inputs=[ multiimage_prompt, seed, resolution, ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, multiimage_algo, 
tex_multiimage_algo ], outputs=[output_buf, preview_output], ) extract_btn.click( extract_glb, inputs=[output_buf, decimation_target, texture_size], outputs=[glb_output, download_btn], ) # Launch the Gradio app if __name__ == "__main__": os.makedirs(TMP_DIR, exist_ok=True) # Construct ui components btn_img_base64_strs = {} for i in range(len(MODES)): icon = Image.open(MODES[i]['icon']) MODES[i]['icon_base64'] = image_to_base64(icon) rmbg_client = Client("briaai/BRIA-RMBG-2.0") pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B') pipeline.rembg_model = None pipeline.low_vram = False pipeline.cuda() envmap = { 'forest': EnvMap(torch.tensor( cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda' )), 'sunset': EnvMap(torch.tensor( cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda' )), 'courtyard': EnvMap(torch.tensor( cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda' )), } demo.launch()