import torch
import numpy as np
import open3d as o3d
from PIL import Image, ImageOps
from sklearn.neighbors import NearestNeighbors
from typing import List, Tuple, Union
from abc import ABC, abstractmethod


# ---------------------------------------------------------------------------
# Point cloud utilities
# ---------------------------------------------------------------------------


def align_ground_to_z(
    pc: torch.Tensor,
    voxel_size: float = 0.05,
    distance_threshold: float = 0.005,
    ransac_n: int = 3,
    num_iterations: int = 10000,
    return_transform: bool = False,
) -> Union[
    Tuple[torch.Tensor, o3d.geometry.PointCloud],
    Tuple[torch.Tensor, o3d.geometry.PointCloud, np.ndarray],
]:
    """
    Detect the dominant plane in a point cloud, align its normal with +Z,
    and return the leveled cloud (and optionally the rotation matrix).

    Args:
        pc: Input point cloud of shape [N, 3], dtype float32.
        voxel_size: Voxel size for downsampling before RANSAC; set 0 to skip.
        distance_threshold: Max distance for a point to be counted as a
            RANSAC inlier.
        ransac_n: Number of points sampled per RANSAC trial.
        num_iterations: Maximum number of RANSAC iterations.
        return_transform: If True, also return the 3x3 rotation matrix.

    Returns:
        aligned_pc: Rotated point cloud, same shape and device as `pc`.
        aligned_pcd: Rotated Open3D point cloud.
        R (optional): 3x3 rotation matrix mapping the detected plane normal
            to [0, 0, 1]. Only returned when `return_transform=True`.

    Raises:
        ValueError: If `pc` is not an Nx3 tensor.
        RuntimeError: If RANSAC fails to find a valid dominant plane.
    """
    if pc.ndim != 2 or pc.shape[1] != 3:
        raise ValueError(f"Expected pc of shape [N, 3], got {tuple(pc.shape)}")

    device = pc.device
    xyz = pc.detach().cpu().numpy()

    # Build (and optionally downsample) an Open3D point cloud for RANSAC.
    # Downsampling speeds up plane segmentation without hurting accuracy.
    pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))
    if voxel_size > 0:
        pcd = pcd.voxel_down_sample(voxel_size)

    # Plane segmentation via RANSAC
    plane_model, inliers = pcd.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )
    if len(inliers) < ransac_n:
        raise RuntimeError("RANSAC failed to find a dominant plane.")

    a, b, c, _ = plane_model
    normal = np.array([a, b, c], dtype=np.float64)
    normal /= np.linalg.norm(normal)

    # Compute rotation from detected normal to +Z via axis-angle.
    target = np.array([0.0, 0.0, 1.0], dtype=np.float64)
    # Clip to the valid arccos domain: round-off in the normalization can
    # push the dot product marginally outside [-1, 1] and yield NaN.
    dot = float(np.clip(np.dot(normal, target), -1.0, 1.0))

    if np.allclose(dot, 1.0, atol=1e-6):
        # Already aligned
        R = np.eye(3)
    elif np.allclose(dot, -1.0, atol=1e-6):
        # 180° flip about any axis orthogonal to the normal
        ortho = np.array(
            [0.0, 1.0, 0.0] if abs(normal[0]) > 0.9 else [1.0, 0.0, 0.0]
        )
        axis = np.cross(normal, ortho)
        axis /= np.linalg.norm(axis)
        R = o3d.geometry.get_rotation_matrix_from_axis_angle(axis * np.pi)
    else:
        axis = np.cross(normal, target)
        axis /= np.linalg.norm(axis)
        angle = np.arccos(dot)
        R = o3d.geometry.get_rotation_matrix_from_axis_angle(axis * angle)

    # Apply rotation to the full (non-downsampled) cloud
    full_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))
    full_pcd.rotate(R, center=(0, 0, 0))

    aligned_np = np.asarray(full_pcd.points, dtype=np.float32)
    aligned_pc = torch.from_numpy(aligned_np).to(device)

    if return_transform:
        return aligned_pc, full_pcd, R
    return aligned_pc, full_pcd


def remove_outliers_statistical(
    pts: np.ndarray,
    nb_neighbors: int = 20,
    std_ratio: float = 2.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Remove outliers from a point cloud via Statistical Outlier Removal.

    For each point, the mean distance to its `nb_neighbors` nearest neighbors
    is computed. Points whose mean distance exceeds (mu + std_ratio * sigma)
    are discarded. If the cloud has fewer than `nb_neighbors + 1` points, the
    neighbor count is clamped so the filter still runs instead of raising.

    Args:
        pts: Input point cloud of shape [N, 3].
        nb_neighbors: Number of nearest neighbors used per point.
        std_ratio: Threshold multiplier for the outlier distance cutoff.

    Returns:
        clean_pts: Inlier points of shape [M, 3] where M <= N.
        mask: Boolean mask of shape [N] that is True for inlier points.
    """
    n_pts = pts.shape[0]
    if n_pts < 2:
        # No neighbor distances can be computed; keep everything as-is.
        return pts, np.ones(n_pts, dtype=bool)

    # Clamp so sklearn never receives n_neighbors > N (which raises).
    k = min(nb_neighbors, n_pts - 1)

    # k+1 neighbors so we can skip the zero-distance self-match
    nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm="auto").fit(pts)
    dists, _ = nbrs.kneighbors(pts)

    avg_dists = dists[:, 1:].mean(axis=1)  # exclude self (distance = 0)
    mu = avg_dists.mean()
    sigma = avg_dists.std()

    mask = avg_dists <= (mu + std_ratio * sigma)
    return pts[mask], mask


# ---------------------------------------------------------------------------
# Image utilities
# ---------------------------------------------------------------------------


def crop_and_resize_foreground(img: Image.Image, padding: float = 0.05) -> Image.Image:
    """
    Crop the foreground of an image and resize it back to the original
    dimensions.

    Steps:
      1. Find the tightest bounding box around all non-black pixels.
      2. Crop to that bounding box.
      3. Add a black border of `padding` x original dimensions on all sides.
      4. Scale the padded crop uniformly to fit within the original canvas
         and center-paste it onto a black background.

    Args:
        img: Input PIL image.
        padding: Border width as a fraction of the original image dimensions.

    Returns:
        A new RGB PIL image of the same size as `img`.
    """
    rgb = img.convert("RGB")
    w, h = rgb.size

    # "Foreground" = any pixel whose grayscale value is non-zero.
    mask = rgb.convert("L").point(lambda x: 0 if x == 0 else 255, mode="1")
    bbox = mask.getbbox()
    if bbox is None:
        # Entirely black image: nothing to crop, return an unchanged copy.
        return img.copy()

    crop = rgb.crop(bbox)

    pad_x = int(padding * w)
    pad_y = int(padding * h)
    padded = ImageOps.expand(crop, border=(pad_x, pad_y, pad_x, pad_y), fill=(0, 0, 0))

    pw, ph = padded.size
    scale = min(w / pw, h / ph)
    new_w = max(1, int(pw * scale))
    new_h = max(1, int(ph * scale))
    resized_fg = padded.resize((new_w, new_h), Image.LANCZOS)

    canvas = Image.new("RGB", (w, h), (0, 0, 0))
    offset_x = (w - new_w) // 2
    offset_y = (h - new_h) // 2
    canvas.paste(resized_fg, (offset_x, offset_y))
    return canvas


# ---------------------------------------------------------------------------
# Pointmap interface
# ---------------------------------------------------------------------------


class PointmapInfo(ABC):
    """Abstract base class for depth-based point cloud extraction."""

    @abstractmethod
    def __init__(self, image: Image.Image, device):
        pass

    @abstractmethod
    def point_cloud(self) -> torch.Tensor:
        pass

    @abstractmethod
    def camera_intrinsic(self) -> np.ndarray:
        pass

    @abstractmethod
    def camera_extrinsic(self) -> np.ndarray:
        pass

    @abstractmethod
    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        pass