import torch
import numpy as np
from PIL import Image
from typing import List

import utils3d
from moge.model.v2 import MoGeModel

from utils.depth_utils import PointmapInfo, align_ground_to_z, crop_and_resize_foreground


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------

def moge_preprocess(image: Image.Image, device) -> torch.Tensor:
    """Convert a PIL image to a normalized float32 CHW tensor on `device`.

    Args:
        image: Input image; converted to RGB regardless of source mode.
        device: Torch device (string or ``torch.device``) for the result.

    Returns:
        Float32 tensor of shape (3, H, W) with values in [0, 1].
    """
    rgb = np.array(image.convert("RGB"))
    return torch.tensor(rgb / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1)


# ---------------------------------------------------------------------------
# MoGe-based pointmap
# ---------------------------------------------------------------------------

class PointmapInfoMoGe(PointmapInfo):
    """
    Concrete PointmapInfo implementation backed by the MoGe monocular depth
    estimator.

    The MoGe model is loaded once and cached as a class-level attribute, so
    subsequent instantiations reuse the same weights.
    """

    # Shared across all instances to avoid redundant weight loading.
    # NOTE(review): the cache keeps the device of the FIRST instantiation;
    # a later call with a different `device` reuses the old placement —
    # confirm whether per-device reloading is needed.
    moge_model: MoGeModel | None = None

    def __init__(self, image: Image.Image, device: str = 'cuda'):
        self._input_image = moge_preprocess(image, device)

        # Run MoGe inference (no gradients needed)
        with torch.no_grad():
            if PointmapInfoMoGe.moge_model is None:
                PointmapInfoMoGe.moge_model = (
                    MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
                )
            predictions = PointmapInfoMoGe.moge_model.infer(self._input_image)

        # Mask out depth edges to suppress discontinuity artifacts
        depth_edge_mask = utils3d.numpy.depth_edge(predictions['depth'].cpu().numpy(), rtol=0.04)
        mask = predictions['mask'] & torch.from_numpy(~depth_edge_mask).to(device)

        # Align the ground plane with the XY plane (+Z up)
        points = predictions['points']
        masked_points, _, R = align_ground_to_z(points[mask].reshape(-1, 3), return_transform=True)

        # Move arrays to CPU/numpy for coordinate normalization
        mask = mask.cpu().numpy()
        points = points.cpu().numpy()
        masked_points = masked_points.cpu().numpy()
        self.intrinsic = predictions['intrinsics'].cpu().numpy()

        # Normalize XY to [0, 1] and Z to a height relative to scene scale.
        # `scaling` is the larger XY extent so aspect ratio is preserved.
        mins = masked_points[:, :2].min(axis=0)
        maxs = masked_points[:, :2].max(axis=0)
        scaling = (maxs - mins).max()
        height = masked_points[:, 2].max() / scaling

        # Flip Z, center XY, and apply uniform scale
        masked_points[:, 2] *= -1
        masked_points[:, :2] = (masked_points[:, :2] - mins) / scaling + (1 - (maxs - mins) / scaling) / 2
        masked_points[:, 2] -= masked_points[:, 2].min()
        masked_points[:, 2] *= 1.0 / scaling

        # Build the camera extrinsic [R | t] from the alignment transform.
        # The translation undoes the normalization above (shift, scale,
        # Z-flip) so the camera maps world coordinates consistently.
        R = R.T
        R[:, 2] *= -1
        t = R @ np.array([*(mins / scaling - (1 - (maxs - mins) / scaling) / 2), -height])
        t += R @ np.array([0.5, 0.5, 0.0])

        # Permute axes from (y, x, z) to (x, y, z) convention for the
        # camera matrices. NOTE(review): `masked_points` itself is NOT
        # permuted here — `divide_image` indexes it as [..., 0]=x,
        # [..., 1]=y; verify the two conventions agree.
        P = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
        self.intrinsic = P @ self.intrinsic @ P.T
        R = P @ R @ P.T
        t = P @ t
        self.extrinsic = np.vstack((np.hstack((R, t.reshape(-1, 1))), [0, 0, 0, 1]))

        # Store the full pointmap (with masked region filled in) for patch
        # extraction; pixels outside `mask` keep their raw MoGe coordinates.
        self.pc = masked_points
        points[mask] = masked_points
        self._pointmap = points

    # -----------------------------------------------------------------------
    # PointmapInfo interface
    # -----------------------------------------------------------------------

    def point_cloud(self) -> np.ndarray:
        """Return the normalized, ground-aligned point cloud (N, 3)."""
        return self.pc

    def camera_intrinsic(self) -> np.ndarray:
        """Return the 3x3 camera intrinsic matrix."""
        return self.intrinsic

    def camera_extrinsic(self) -> np.ndarray:
        """Return the 4x4 camera extrinsic matrix [R | t; 0 0 0 1]."""
        return self.extrinsic

    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        """
        Slice the image into overlapping patches based on the normalized
        pointmap.

        Each patch covers a 1/width x 1/length window in normalized pointmap
        coordinates; consecutive patches are offset by 1/(width*div)
        (resp. 1/(length*div)), so `div` controls the overlap.

        Args:
            width: Number of tiles along the Y axis.
            length: Number of tiles along the X axis.
            div: Overlap subdivision factor (higher = more overlap).

        Returns:
            2D list of PIL images of shape
            [div*(width-1) + 1][div*(length-1) + 1].
        """
        # Convert the input tensor back to a uint8 HWC numpy array
        image_np = (self._input_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

        patches = []
        for i in range(width * div - div + 1):
            row = []
            for j in range(length * div - div + 1):
                # Compute normalized [0, 1] bounds for this patch
                y_start = i / (width * div)
                x_start = j / (length * div)
                y_end = y_start + 1.0 / width
                x_end = x_start + 1.0 / length

                # Mask pixels whose pointmap coordinates fall within this
                # patch; unmasked (raw MoGe) pixels fall outside [0, 1] XY
                # only if the raw coordinates do — TODO confirm.
                pm = self._pointmap
                in_patch = (
                    (y_start <= pm[:, :, 1]) & (pm[:, :, 1] < y_end)
                    & (x_start <= pm[:, :, 0]) & (pm[:, :, 0] < x_end)
                )[:, :, None]

                patch_np = np.where(in_patch, image_np, 0).astype(np.uint8)
                patch_img = crop_and_resize_foreground(Image.fromarray(patch_np))
                row.append(patch_img)
            patches.append(row)
        return patches