# Extend3D/utils/depth_utils.py
# Author: Seungwoo-Yoon — initial commit for HF space (commit a68e3ed)
from abc import ABC, abstractmethod
from typing import List, Tuple, Union

import numpy as np
import open3d as o3d
import torch
from PIL import Image, ImageOps
from scipy.spatial import cKDTree
from sklearn.neighbors import NearestNeighbors
# ---------------------------------------------------------------------------
# Point cloud utilities
# ---------------------------------------------------------------------------
def align_ground_to_z(
    pc: torch.Tensor,
    voxel_size: float = 0.05,
    distance_threshold: float = 0.005,
    ransac_n: int = 3,
    num_iterations: int = 10000,
    return_transform: bool = False,
) -> Union[
    Tuple[torch.Tensor, o3d.geometry.PointCloud],
    Tuple[torch.Tensor, o3d.geometry.PointCloud, np.ndarray],
]:
    """
    Level a point cloud by rotating its dominant RANSAC plane normal onto +Z.

    The plane is estimated on an (optionally voxel-downsampled) copy of the
    cloud, but the resulting rotation is applied to the full-resolution cloud.

    Args:
        pc: Point cloud tensor of shape [N, 3].
        voxel_size: Downsampling voxel size used only for plane fitting;
            pass 0 to fit on the full cloud.
        distance_threshold: RANSAC inlier distance to the candidate plane.
        ransac_n: Points drawn per RANSAC hypothesis.
        num_iterations: RANSAC iteration budget.
        return_transform: When True, additionally return the 3x3 rotation.

    Returns:
        leveled_pc: Rotated cloud (same shape/device as `pc`).
        leveled_cloud: Rotated Open3D point cloud.
        rot (optional): 3x3 rotation taking the plane normal to [0, 0, 1];
            only present when `return_transform=True`.

    Raises:
        ValueError: If `pc` is not an Nx3 tensor.
        RuntimeError: If RANSAC cannot produce a dominant plane.
    """
    if pc.ndim != 2 or pc.shape[1] != 3:
        raise ValueError(f"Expected pc of shape [N, 3], got {tuple(pc.shape)}")

    orig_device = pc.device
    points_np = pc.detach().cpu().numpy()

    # Plane fitting runs on a (possibly downsampled) working copy.
    working_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_np))
    if voxel_size > 0:
        working_cloud = working_cloud.voxel_down_sample(voxel_size)

    plane_coeffs, inlier_idx = working_cloud.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )
    if len(inlier_idx) < ransac_n:
        raise RuntimeError("RANSAC failed to find a dominant plane.")

    nx, ny, nz, _ = plane_coeffs
    n_vec = np.array([nx, ny, nz], dtype=np.float64)
    n_vec /= np.linalg.norm(n_vec)

    # Rotation that carries the unit plane normal onto +Z, via axis-angle.
    z_axis = np.array([0.0, 0.0, 1.0], dtype=np.float64)
    cos_angle = np.dot(n_vec, z_axis)
    if np.allclose(cos_angle, 1.0, atol=1e-6):
        # Normal already points along +Z: identity rotation.
        rot = np.eye(3)
    elif np.allclose(cos_angle, -1.0, atol=1e-6):
        # Antiparallel case: cross product with +Z vanishes, so flip 180°
        # about any axis perpendicular to the normal.
        helper = np.array([0.0, 1.0, 0.0] if abs(n_vec[0]) > 0.9 else [1.0, 0.0, 0.0])
        flip_axis = np.cross(n_vec, helper)
        flip_axis /= np.linalg.norm(flip_axis)
        rot = o3d.geometry.get_rotation_matrix_from_axis_angle(flip_axis * np.pi)
    else:
        rot_axis = np.cross(n_vec, z_axis)
        rot_axis /= np.linalg.norm(rot_axis)
        rot = o3d.geometry.get_rotation_matrix_from_axis_angle(rot_axis * np.arccos(cos_angle))

    # Rotate the full-resolution cloud about the origin.
    leveled_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_np))
    leveled_cloud.rotate(rot, center=(0, 0, 0))
    leveled_pc = torch.from_numpy(
        np.asarray(leveled_cloud.points, dtype=np.float32)
    ).to(orig_device)

    if return_transform:
        return leveled_pc, leveled_cloud, rot
    return leveled_pc, leveled_cloud
def remove_outliers_statistical(
    pts: np.ndarray,
    nb_neighbors: int = 20,
    std_ratio: float = 2.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Remove outliers from a point cloud via Statistical Outlier Removal.

    For each point, the mean distance to its `nb_neighbors` nearest neighbors is
    computed. Points whose mean distance exceeds (mu + std_ratio * sigma) are
    discarded. Neighbor search uses a scipy KD-tree (identical Euclidean
    distances to the previous sklearn implementation).

    Args:
        pts: Input point cloud of shape [N, 3].
        nb_neighbors: Number of nearest neighbors used per point; automatically
            clamped when the cloud has fewer than nb_neighbors + 1 points
            (previously this raised from the neighbor search).
        std_ratio: Threshold multiplier for the outlier distance cutoff.

    Returns:
        clean_pts: Inlier points of shape [M, 3] where M <= N.
        mask: Boolean mask of shape [N] that is True for inlier points.
    """
    n_pts = pts.shape[0]
    # Degenerate inputs: nothing to filter.
    if n_pts == 0:
        return pts, np.zeros(0, dtype=bool)
    # k+1 neighbors so we can skip the zero-distance self-match; clamp so we
    # never request more neighbors than there are points.
    k = min(nb_neighbors + 1, n_pts)
    if k < 2:
        # A single point has no neighbors; keep it.
        return pts, np.ones(n_pts, dtype=bool)
    tree = cKDTree(pts)
    dists, _ = tree.query(pts, k=k)
    avg_dists = dists[:, 1:].mean(axis=1)  # exclude self (distance = 0)
    mu = avg_dists.mean()
    sigma = avg_dists.std()
    mask = avg_dists <= (mu + std_ratio * sigma)
    return pts[mask], mask
# ---------------------------------------------------------------------------
# Image utilities
# ---------------------------------------------------------------------------
def crop_and_resize_foreground(img: Image.Image, padding: float = 0.05) -> Image.Image:
    """
    Re-center the non-black foreground of an image on a fresh black canvas.

    The tightest bounding box around non-black pixels is cropped out, padded
    with a black border of `padding` x the original dimensions, scaled
    uniformly so it fits inside the original canvas, and pasted centered onto
    a black background of the original size.

    Args:
        img: Input PIL image.
        padding: Border width as a fraction of the original image dimensions.

    Returns:
        A new RGB PIL image with the same dimensions as `img`. If the image
        contains no non-black pixels, an unmodified copy is returned.
    """
    rgb_img = img.convert("RGB")
    orig_w, orig_h = rgb_img.size

    # Binary mask: luminance zero -> background, anything else -> foreground.
    fg_mask = rgb_img.convert("L").point(lambda v: 255 if v != 0 else 0, mode="1")
    box = fg_mask.getbbox()
    if box is None:
        # Entirely black image: nothing to re-center.
        return img.copy()

    foreground = rgb_img.crop(box)
    border_x = int(padding * orig_w)
    border_y = int(padding * orig_h)
    foreground = ImageOps.expand(
        foreground, border=(border_x, border_y, border_x, border_y), fill=(0, 0, 0)
    )

    # Uniform scale so the padded crop fits inside the original canvas.
    pad_w, pad_h = foreground.size
    factor = min(orig_w / pad_w, orig_h / pad_h)
    target_w = max(1, int(pad_w * factor))
    target_h = max(1, int(pad_h * factor))
    foreground = foreground.resize((target_w, target_h), Image.LANCZOS)

    result = Image.new("RGB", (orig_w, orig_h), (0, 0, 0))
    result.paste(foreground, ((orig_w - target_w) // 2, (orig_h - target_h) // 2))
    return result
# ---------------------------------------------------------------------------
# Pointmap interface
# ---------------------------------------------------------------------------
class PointmapInfo(ABC):
    """Abstract base class for depth-based point cloud extraction.

    Concrete subclasses wrap a single image and expose the 3D geometry
    recovered from it (point cloud plus camera parameters).
    """
    @abstractmethod
    def __init__(self, image: Image.Image, device):
        """Initialize from a PIL image; `device` presumably selects the
        torch device for derived tensors — confirm in implementations."""
        pass
    @abstractmethod
    def point_cloud(self) -> torch.Tensor:
        """Return the extracted point cloud as a tensor (likely [N, 3] —
        TODO confirm against implementations)."""
        pass
    @abstractmethod
    def camera_intrinsic(self) -> np.ndarray:
        """Return the camera intrinsic matrix (presumably 3x3 — confirm)."""
        pass
    @abstractmethod
    def camera_extrinsic(self) -> np.ndarray:
        """Return the camera extrinsic (world-to-camera or camera-to-world
        pose; convention not visible here — verify in implementations)."""
        pass
    @abstractmethod
    def divide_image(self, width: int, length: int, div: int) -> List[List[Image.Image]]:
        """Split the image into a grid of tiles; `div` presumably controls
        the number of subdivisions — confirm semantics in implementations."""
        pass