# FastSplatStyler / app.py
# incrl's picture
# Update app.py
# 93da8e0 verified
import gradio as gr
import torch
import os
import tempfile
import shutil
from pathlib import Path
from time import time
# ── Core style-transfer logic (adapted from styletransfer_splat.py) ──────────
import pointCloudToMesh as ply2M
import utils
import graph_io as gio
from clusters import *
import splat_mesh_helpers as splt
import clusters as cl
from torch_geometric.data import Data
from scipy.interpolate import NearestNDInterpolator
from graph_networks.LinearStyleTransfer_vgg import encoder, decoder
from graph_networks.LinearStyleTransfer_matrix import TransformLayer
from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer
from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4
# ── Example assets ───────────────────────────────────────────────────────────
# Each row is [splat path, style image path], in the same order as the
# (splat, style) inputs the examples are bound to in the UI. Paths are
# relative to the working directory; rows whose files are missing are
# skipped when the UI is built.
_EXAMPLE_SPLAT = "example-broche-rose-gold.splat"
EXAMPLE_SPLATS = [
    [_EXAMPLE_SPLAT, "style_ims/style2.jpg"],
    [_EXAMPLE_SPLAT, "style_ims/style6.jpg"],
]
# ── Style-transfer function called by Gradio ─────────────────────────────────
def run_style_transfer(
    splat_file,
    style_image,
    threshold: float,
    sampling_rate: float,
    device_choice: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Stylize a 3D Gaussian Splat with a style image and save the result.

    The pipeline loads the splat, optionally super-samples it, builds a
    surface graph with clusters, runs the encoder / transform / decoder
    networks, maps the stylized colors back onto the original splat points
    with nearest-neighbor interpolation, and writes the output file into a
    fresh temporary directory.

    Args:
        splat_file: Path to the splat file, or an object whose ``name``
            attribute holds that path.
        style_image: Style image reference handed to ``utils.loadImage``.
        threshold: Forwarded to ``splt.splat_unpacker_with_threshold`` and to
            ``ply2M.Estimate_Normals``.
        sampling_rate: When greater than 1, the splat is super-sampled to
            ``int(num_points * sampling_rate)`` samples; otherwise the
            original points are used directly.
        device_choice: ``"cpu"`` or a CUDA device index such as ``"0"``.
        progress: Gradio progress tracker used to report stage progress.

    Returns:
        A ``(output_path, log_text)`` tuple: the path of the saved stylized
        splat and the newline-joined log messages.

    Raises:
        gr.Error: If an input is missing, the requested CUDA device is not
            available, or the loaded splat contains no points.
    """
    if splat_file is None:
        raise gr.Error("Please upload a 3D Gaussian Splat file (.ply or .splat).")
    if style_image is None:
        raise gr.Error("Please upload a style image.")

    # Resolve the device up front so that a bad choice produces a readable
    # message instead of an error from deep inside the pipeline.
    if device_choice == "cpu":
        device = "cpu"
    else:
        cuda_index = str(device_choice)
        if not cuda_index.isdecimal() or int(cuda_index) >= torch.cuda.device_count():
            raise gr.Error(
                f"CUDA device {device_choice!r} is not available; please select another device."
            )
        device = f"cuda:{device_choice}"

    # ── Parameters ────────────────────────────────────────────────────────────
    n = 25
    ratio = 0.25
    depth = 3
    style_shape = (512, 512)

    logs = []

    def log(msg):
        # Collect messages for the UI log box and echo them to stdout.
        logs.append(msg)
        print(msg)

    # ── 1. Load splat ─────────────────────────────────────────────────────────
    progress(0.05, desc="Loading splat…")
    splat_path = splat_file.name if hasattr(splat_file, "name") else splat_file
    log(f"Loading splat: {splat_path}")

    pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = \
        splt.splat_unpacker_with_threshold(n, splat_path, threshold)

    if pos3D_Original.shape[0] == 0:
        raise gr.Error(
            "The splat contains no points after loading; try adjusting the opacity threshold."
        )

    # ── 2. Gaussian super-sampling ────────────────────────────────────────────
    progress(0.15, desc="Super-sampling…")
    t0 = time()
    if sampling_rate > 1:
        GaussianSamples = int(pos3D_Original.shape[0] * sampling_rate)
        # Clones are passed so the sampler receives copies of the originals,
        # which are still needed for interpolation and saving below.
        pos3D, colors = splt.splat_GaussianSuperSampler(
            pos3D_Original.clone(), colors_Original.clone(),
            opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(),
            GaussianSamples,
        )
    else:
        pos3D, colors = pos3D_Original, colors_Original
    log(f"Nodes in graph: {pos3D.shape[0]} ({time()-t0:.1f}s)")

    # ── 3. Graph construction ─────────────────────────────────────────────────
    progress(0.30, desc="Building surface graph…")
    t0 = time()
    style_ref = utils.loadImage(style_image, shape=style_shape)

    normalsNP = ply2M.Estimate_Normals(pos3D, threshold)
    normals = torch.from_numpy(normalsNP)

    # Unit-length reference direction along (1, 1, 1).
    up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
    up_vector = up_vector / torch.linalg.norm(up_vector, dim=1)

    pos3D = pos3D.to(device)
    colors = colors.to(device)
    normals = normals.to(device)
    up_vector = up_vector.to(device)

    # NOTE(review): `gh` is not imported explicitly in this file; presumably it
    # is brought into scope by `from clusters import *` — confirm.
    edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)
    edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)

    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
        pos3D, normals, edge_index, selections, interps,
        ratio=ratio, up_vector=up_vector, depth=depth, device=device,
    )
    log(f"Graph built ({time()-t0:.1f}s)")

    # ── 4. Load networks ──────────────────────────────────────────────────────
    progress(0.50, desc="Loading networks…")
    t0 = time()

    # Reference networks receive the pretrained checkpoints; their weights are
    # then copied into the graph-based counterparts.
    enc_ref = encoder4()
    dec_ref = decoder4()
    matrix_ref = MulLayer("r41")

    enc_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/vgg_r41.pth", map_location=device))
    dec_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/dec_r41.pth", map_location=device))
    matrix_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/r41.pth", map_location=device))

    enc = encoder(padding_mode="replicate")
    dec = decoder(padding_mode="replicate")
    matrix = TransformLayer()

    with torch.no_grad():
        enc.copy_weights(enc_ref)
        dec.copy_weights(dec_ref)
        matrix.copy_weights(matrix_ref)

    content = Data(
        x=colors, clusters=clusters,
        edge_indexes=edge_indexes,
        selections_list=selections_list,
        interps_list=interps_list,
    ).to(device)

    style, _ = gio.image2Graph(style_ref, depth=3, device=device)

    enc = enc.to(device)
    dec = dec.to(device)
    matrix = matrix.to(device)
    log(f"Networks loaded ({time()-t0:.1f}s)")

    # ── 5. Style transfer ─────────────────────────────────────────────────────
    progress(0.70, desc="Running style transfer…")
    t0 = time()
    with torch.no_grad():
        cF = enc(content)
        sF = enc(style)
        feature, _ = matrix(
            cF["r41"], sF["r41"],
            content.edge_indexes[3], content.selections_list[3],
            style.edge_indexes[3], style.selections_list[3],
            content.interps_list[3] if hasattr(content, "interps_list") else None,
        )
        result = dec(feature, content).clamp(0, 1)
        # Overwrite the first three color channels with the stylized output.
        colors[:, 0:3] = result
    log(f"Stylization done ({time()-t0:.1f}s)")

    # ── 6. Interpolate back to original resolution ────────────────────────────
    progress(0.88, desc="Interpolating back to original splat…")
    t0 = time()
    # Each original point takes the color of its nearest stylized graph node.
    interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
    results_OriginalNP = interp2(pos3D_Original)
    results_Original = torch.from_numpy(results_OriginalNP).to(torch.float32)
    colors_and_opacity_Original = torch.cat(
        (results_Original, opacity_Original.unsqueeze(1)), dim=1
    )
    log(f"Interpolation done ({time()-t0:.1f}s)")

    # ── 7. Save output ────────────────────────────────────────────────────────
    progress(0.95, desc="Saving output splat…")
    suffix = ".splat" if fileType == "splat" else ".ply"
    # The temporary directory is intentionally not removed here: the returned
    # path must still exist after this function returns.
    out_dir = tempfile.mkdtemp()
    out_path = os.path.join(out_dir, f"stylized{suffix}")

    splt.splat_save(
        pos3D_Original.numpy(),
        scales_Original.numpy(),
        rots_Original.numpy(),
        colors_and_opacity_Original.numpy(),
        out_path,
        fileType,
    )

    progress(1.0, desc="Done!")
    return out_path, "\n".join(logs)
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def build_ui():
    """Build the Gradio Blocks interface for the style-transfer app.

    The layout has an input column (splat file, style image, advanced
    settings, run button) and an output column (downloadable result and a
    progress log). Example rows from ``EXAMPLE_SPLATS`` are offered only when
    both of their files exist on disk.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller is responsible for
        launching it.
    """
    # CUDA device indices first (so the first GPU is the default), CPU last.
    available_devices = (
        [str(i) for i in range(torch.cuda.device_count())] + ["cpu"]
        if torch.cuda.is_available()
        else ["cpu"]
    )

    with gr.Blocks(
        title="3DGS Style Transfer",
        theme=gr.themes.Soft(primary_hue="violet"),
        css="""
        #title { text-align: center; }
        #subtitle { text-align: center; color: #666; margin-bottom: 1rem; }
        .panel { border-radius: 12px; }
        #run-btn { font-size: 1.1rem; }
        """,
    ) as demo:
        gr.Markdown("# 🎨 3D Gaussian Splat Style Transfer", elem_id="title")
        gr.Markdown(
            "Official implementation of Optimization-Free Style Transfer for 3D Gaussian Splats.  \n"
            "Upload a 3DGS scene and a style image — the app will repaint the splat "
            "with the artistic style of the image and give you a stylized splat to download. "
            "After downloading, you can view your splat with an [online viewer](https://antimatter15.com/splat/).",
            elem_id="subtitle",
        )

        with gr.Row():
            # ── Left column: inputs ───────────────────────────────────────────
            with gr.Column(scale=1, elem_classes="panel"):
                gr.Markdown("### 📂 Inputs")
                splat_input = gr.File(
                    label="3D Gaussian Splat (.ply or .splat)",
                    file_types=[".ply", ".splat"],
                    type="filepath",
                )
                style_input = gr.Image(
                    label="Style Image",
                    type="filepath",
                    height=240,
                )
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    threshold_slider = gr.Slider(
                        minimum=90.0, maximum=100.0, value=99.8, step=0.1,
                        label="Opacity threshold (percentile)",
                        info="Points below this opacity percentile are removed.",
                    )
                    sampling_slider = gr.Slider(
                        minimum=0.5, maximum=3.0, value=1.5, step=0.1,
                        label="Gaussian super-sampling rate",
                        info="Values > 1 add extra samples; 1.0 = no super-sampling.",
                    )
                    device_radio = gr.Radio(
                        choices=available_devices,
                        value=available_devices[0],
                        label="Device",
                    )
                run_btn = gr.Button("🚀 Run Style Transfer", variant="primary", elem_id="run-btn")

            # ── Right column: outputs ─────────────────────────────────────────
            with gr.Column(scale=1, elem_classes="panel"):
                gr.Markdown("### 📥 Output")
                output_file = gr.File(
                    label="Download Stylized Splat",
                    interactive=False,
                )
                log_box = gr.Textbox(
                    label="Progress log",
                    lines=12,
                    max_lines=20,
                    interactive=False,
                    placeholder="Logs will appear here once processing starts…",
                )

        # ── Examples ─────────────────────────────────────────────────────────
        # Only offer examples whose splat and style image both exist on disk.
        valid_examples = [
            row for row in EXAMPLE_SPLATS
            if os.path.exists(row[0]) and os.path.exists(row[1])
        ]
        if valid_examples:
            gr.Markdown("### 🖼️ Examples")
            gr.Examples(
                examples=valid_examples,
                inputs=[splat_input, style_input],
                label="Click an example to load it",
            )

        # ── Event wiring ──────────────────────────────────────────────────────
        run_btn.click(
            fn=run_style_transfer,
            inputs=[splat_input, style_input, threshold_slider, sampling_slider, device_radio],
            outputs=[output_file, log_box],
        )

    return demo
if __name__ == "__main__":
    # Script entry point: assemble the interface, then serve it with
    # share=False.
    demo = build_ui()
    demo.launch(
        share=False,
    )