| """Contains `sharp predict` CLI implementation. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from pathlib import Path |
|
|
| import click |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torch.utils.data |
|
|
| from sharp.models import ( |
| PredictorParams, |
| RGBGaussianPredictor, |
| create_predictor, |
| ) |
| from sharp.utils import io |
| from sharp.utils import logging as logging_utils |
| from sharp.utils.gaussians import ( |
| Gaussians3D, |
| SceneMetaData, |
| save_ply, |
| unproject_gaussians, |
| ) |
|
|
| from .render import render_gaussians |
|
|
# Logger used by the `sharp predict` command and the helpers in this module.
LOGGER: logging.Logger = logging.getLogger(name=__name__)

# Checkpoint downloaded (via torch.hub) when no --checkpoint-path is supplied.
DEFAULT_MODEL_URL: str = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
|
|
|
|
@click.command()
@click.option(
    "-i",
    "--input-path",
    type=click.Path(path_type=Path, exists=True),
    help="Path to an image or to a directory containing images.",
    required=True,
)
@click.option(
    "-o",
    "--output-path",
    type=click.Path(path_type=Path, file_okay=False),
    help="Path to save the predicted Gaussians and renderings.",
    required=True,
)
@click.option(
    "-c",
    "--checkpoint-path",
    # exists=True turns a missing checkpoint into a clear usage error instead of
    # a torch.load traceback after the device has already been set up.
    type=click.Path(path_type=Path, dir_okay=False, exists=True),
    default=None,
    help="Path to the .pt checkpoint. If not provided, downloads the default model automatically.",
    required=False,
)
@click.option(
    "--render/--no-render",
    "with_rendering",
    is_flag=True,
    default=False,
    help="Whether to render trajectory for checkpoint.",
)
@click.option(
    "--device",
    type=str,
    default="default",
    help="Device to run on. ['cpu', 'mps', 'cuda']",
)
@click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
def predict_cli(
    input_path: Path,
    output_path: Path,
    checkpoint_path: Path | None,
    with_rendering: bool,
    device: str,
    verbose: bool,
) -> None:
    """Predict Gaussians from input images.

    \f
    Everything after the form feed above is hidden from ``--help`` by click.

    For every supported image found at ``input_path`` (the file itself, or a
    recursive search when it is a directory), runs :func:`predict_image`, writes
    ``<stem>.ply`` into ``output_path`` and, when rendering is enabled on a CUDA
    device, renders a trajectory to ``<stem>.mp4`` in the same directory.

    Args:
        input_path: Image file, or directory searched recursively for images.
        output_path: Output directory; created (with parents) if missing.
        checkpoint_path: ``.pt`` state dict to load, or ``None`` to download
            ``DEFAULT_MODEL_URL``.
        with_rendering: Also render a trajectory video per image (CUDA only).
        device: ``"default"`` to auto-select cuda, then mps, then cpu; otherwise
            any string accepted by ``torch.device``.
        verbose: Enable debug-level logging.
    """
    logging_utils.configure(logging.DEBUG if verbose else logging.INFO)

    image_paths = _collect_image_paths(input_path)
    if not image_paths:
        LOGGER.info("No valid images found. Input was %s.", input_path)
        return
    LOGGER.info("Processing %d valid image files.", len(image_paths))

    device = _resolve_device(device)
    LOGGER.info("Using device %s", device)
    # Parse once: rejects an invalid device string before the (possibly large)
    # checkpoint download, and lets "cuda:1" be recognised as a CUDA device.
    torch_device = torch.device(device)

    # NOTE(review): for non-default CUDA indices this assumes render_gaussians
    # follows the device the Gaussians live on — confirm.
    if with_rendering and torch_device.type != "cuda":
        LOGGER.warning("Can only run rendering with gsplat on CUDA. Rendering is disabled.")
        with_rendering = False

    state_dict = _load_state_dict(checkpoint_path)
    gaussian_predictor = create_predictor(PredictorParams())
    gaussian_predictor.load_state_dict(state_dict)
    gaussian_predictor.eval()
    gaussian_predictor.to(torch_device)

    output_path.mkdir(exist_ok=True, parents=True)

    # NOTE(review): outputs are named by stem only, so images with the same stem
    # in different sub-directories overwrite each other's .ply/.mp4 files.
    for image_path in image_paths:
        LOGGER.info("Processing %s", image_path)
        image, _, f_px = io.load_rgb(image_path)
        height, width = image.shape[:2]
        gaussians = predict_image(gaussian_predictor, image, f_px, torch_device)

        ply_path = output_path / f"{image_path.stem}.ply"
        LOGGER.info("Saving 3DGS to %s", ply_path)
        save_ply(gaussians, f_px, (height, width), ply_path)

        if with_rendering:
            # Append the suffix to the stem instead of using with_suffix():
            # with_suffix would cut a stem like "scan.v2" down to "scan".
            video_path = output_path / f"{image_path.stem}.mp4"
            LOGGER.info("Rendering trajectory to %s", video_path)
            metadata = SceneMetaData(float(f_px), (width, height), "linearRGB")
            render_gaussians(gaussians, metadata, video_path)


def _collect_image_paths(input_path: Path) -> list[Path]:
    """Return the supported image files at ``input_path``, sorted.

    A file is accepted only if its suffix is a supported extension; a directory
    is searched recursively (``**/*<ext>``) for every supported extension.
    """
    extensions = io.get_supported_image_extensions()
    if input_path.is_file():
        return [input_path] if input_path.suffix in extensions else []
    # The set drops paths matched by more than one pattern; sorting makes the
    # processing order independent of filesystem enumeration order.
    return sorted({path for ext in extensions for path in input_path.glob(f"**/*{ext}")})


def _resolve_device(device: str) -> str:
    """Map ``"default"`` to the best available backend; pass other values through."""
    if device != "default":
        return device
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _load_state_dict(checkpoint_path: Path | None) -> dict:
    """Load the predictor state dict from ``checkpoint_path`` or the default URL.

    Tensors are mapped to CPU so that checkpoints saved on a GPU also load on
    CPU/MPS-only machines; the caller moves the model to its device afterwards.
    """
    if checkpoint_path is None:
        LOGGER.info("No checkpoint provided. Downloading default model from %s", DEFAULT_MODEL_URL)
        return torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, map_location="cpu", progress=True
        )
    LOGGER.info("Loading checkpoint from %s", checkpoint_path)
    return torch.load(checkpoint_path, map_location="cpu", weights_only=True)
|
|
|
|
@torch.no_grad()
def predict_image(
    predictor: RGBGaussianPredictor,
    image: np.ndarray,
    f_px: float,
    device: torch.device,
) -> Gaussians3D:
    """Run the Gaussian predictor on one image and unproject its output.

    Preprocessing and postprocessing steps:
    1. The image is turned into a channels-first float tensor divided by 255.
    2. It is resized bilinearly (``align_corners=True``) to the fixed 1536x1536
       network resolution.
    3. The predictor is called with it and a disparity factor ``f_px / width``.
    4. The raw prediction is passed to ``unproject_gaussians`` together with
       pinhole intrinsics rescaled to the network resolution and a 4x4 identity
       matrix.

    Args:
        predictor: Gaussian predictor network.
        image: Image as an (H, W, C) array (it is permuted channels-first).
            Presumably 8-bit RGB since it is divided by 255 — confirm against
            ``io.load_rgb``.
        f_px: Focal length, in pixels, of the original-resolution image.
        device: Device on which every tensor is created.

    Returns:
        The Gaussians returned by ``unproject_gaussians``.
    """
    # Network resolution, stored as (width, height): index 0 scales x, index 1 y.
    internal_shape = (1536, 1536)
    net_width, net_height = internal_shape

    LOGGER.info("Running preprocessing.")
    chw = torch.from_numpy(image.copy()).float().to(device).permute(2, 0, 1)
    chw = chw / 255.0
    height, width = chw.shape[1:]
    disparity_factor = torch.tensor([f_px / width], dtype=torch.float32, device=device)

    # F.interpolate takes size as (height, width).
    network_input = F.interpolate(
        chw.unsqueeze(0),
        size=(net_height, net_width),
        mode="bilinear",
        align_corners=True,
    )

    LOGGER.info("Running inference.")
    gaussians_ndc = predictor(network_input, disparity_factor)

    LOGGER.info("Running postprocessing.")
    # Pinhole intrinsics of the original image, principal point at its centre.
    pinhole = torch.tensor(
        [
            [f_px, 0, width / 2, 0],
            [0, f_px, height / 2, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ],
        dtype=torch.float32,
        device=device,
    )
    # Rescale the x row and the y row to the network resolution, in place.
    pinhole[0] *= net_width / width
    pinhole[1] *= net_height / height

    # NOTE(review): the identity is presumably the camera pose/extrinsics
    # argument of unproject_gaussians — confirm.
    identity = torch.eye(4, device=device)
    return unproject_gaussians(gaussians_ndc, identity, pinhole, internal_shape)
|
|