| """ |
| 3D Voxel Shape Classifier — Complete Geometric Primitive Vocabulary |
| 5×5×5 binary voxel grid → rigid cascade → curvature analysis → classify |
| |
| 38 shape classes covering: |
| - Rigid 0D-3D: points, lines, joints, triangles, quads, polyhedra, prisms |
| - Curved 1D: arcs, helices |
| - Curved 2D: circles, ellipses, discs |
| - Curved 3D solid: sphere, hemisphere, cylinder, cone, capsule, torus |
| - Curved 3D hollow: shell, tube |
| - Curved 3D open: bowl (concave), saddle (hyperbolic) |
| |
| Curvature types: none, convex, concave, cylindrical, conical, toroidal, hyperbolic, helical |
| """ |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
| import math |
| from itertools import combinations |
|
|
|
|
| |
|
|
class SwiGLU(nn.Module):
    """Gated linear layer: ``out = (x @ W1) * SiLU(x @ W2)``.

    Two independent linear projections of the input are computed. The first
    is the value branch; the second is passed through SiLU
    (``SiLU(z) = z * sigmoid(z)``, also called Swish) and multiplies the
    value branch element-wise as a gate.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Value projection is created before the gate projection; this order
        # fixes the parameter names (w1.*, w2.*) and the initialisation order.
        self.w1 = nn.Linear(in_dim, out_dim)
        self.w2 = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return the gated projection of ``x``."""
        value = self.w1(x)
        gate = F.silu(self.w2(x))
        return value * gate
|
|
|
|
| |
|
|
# Catalog of every shape class. Each entry records the intrinsic dimension
# ("dim"), whether the shape is curved ("curved") and its curvature type
# ("curvature"). Insertion order defines the integer class labels below.
SHAPE_CATALOG = {
    shape: {"dim": dim, "curved": curved, "curvature": curvature}
    for shape, dim, curved, curvature in (
        # 0D
        ("point", 0, False, "none"),
        # 1D straight segments
        ("line_x", 1, False, "none"),
        ("line_y", 1, False, "none"),
        ("line_z", 1, False, "none"),
        ("line_diag", 1, False, "none"),
        # 1D joints / multi-point lines
        ("cross", 1, False, "none"),
        ("l_shape", 1, False, "none"),
        ("collinear", 1, False, "none"),
        # 2D triangles
        ("triangle_xy", 2, False, "none"),
        ("triangle_xz", 2, False, "none"),
        ("triangle_3d", 2, False, "none"),
        # 2D quads
        ("square_xy", 2, False, "none"),
        ("square_xz", 2, False, "none"),
        ("rectangle", 2, False, "none"),
        ("coplanar", 2, False, "none"),
        # 2D filled
        ("plane", 2, False, "none"),
        # 3D simplices / pyramids
        ("tetrahedron", 3, False, "none"),
        ("pyramid", 3, False, "none"),
        ("pentachoron", 3, False, "none"),
        # 3D boxes, prisms, polyhedra
        ("cube", 3, False, "none"),
        ("cuboid", 3, False, "none"),
        ("triangular_prism", 3, False, "none"),
        ("octahedron", 3, False, "none"),
        # Curved 1D
        ("arc", 1, True, "convex"),
        ("helix", 1, True, "helical"),
        # Curved 2D outlines
        ("circle", 2, True, "convex"),
        ("ellipse", 2, True, "convex"),
        # Curved 2D filled
        ("disc", 2, True, "convex"),
        # Curved 3D solids
        ("sphere", 3, True, "convex"),
        ("hemisphere", 3, True, "convex"),
        ("cylinder", 3, True, "cylindrical"),
        ("cone", 3, True, "conical"),
        ("capsule", 3, True, "convex"),
        ("torus", 3, True, "toroidal"),
        # Curved 3D hollow
        ("shell", 3, True, "convex"),
        ("tube", 3, True, "cylindrical"),
        # Curved 3D open surfaces
        ("bowl", 3, True, "concave"),
        ("saddle", 3, True, "hyperbolic"),
    )
}

# Class label bookkeeping derived from the catalog's insertion order.
CLASS_NAMES = list(SHAPE_CATALOG)
NUM_CLASSES = len(CLASS_NAMES)
CLASS_TO_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}

# Curvature vocabulary; list position is the integer curvature label.
CURVATURE_TYPES = [
    "none",
    "convex",
    "concave",
    "cylindrical",
    "conical",
    "toroidal",
    "hyperbolic",
    "helical",
]
NUM_CURVATURES = len(CURVATURE_TYPES)
CURV_TO_IDX = {curv: idx for idx, curv in enumerate(CURVATURE_TYPES)}

# Edge length of the cubic voxel grid.
GS = 5
|
|
|
|
| |
|
|
def cayley_menger_det(points: np.ndarray) -> float:
    """Return the Cayley-Menger determinant of ``points``.

    The matrix is the table of squared pairwise Euclidean distances bordered
    by a leading row and column of ones, with a zero in the top-left corner.

    Args:
        points: Sequence of ``n`` coordinate arrays (one per point).

    Returns:
        Determinant of the ``(n + 1) x (n + 1)`` bordered matrix.
    """
    count = len(points)
    # Start from all ones so the border is already in place, then zero the
    # corner and overwrite the interior with squared distances.
    bordered = np.ones((count + 1, count + 1))
    bordered[0, 0] = 0.0
    for row, p in enumerate(points, start=1):
        for col, q in enumerate(points, start=1):
            bordered[row, col] = np.sum((p - q) ** 2)
    return np.linalg.det(bordered)
|
|
|
|
def simplex_volume(points: np.ndarray) -> float:
    """Return the content (length / area / volume) of the simplex on ``points``.

    Uses the Cayley-Menger relation for ``k`` vertices:
    ``V^2 = (-1)^k * det(CM) / (2^(k-1) * ((k-1)!)^2)``.
    Fewer than two points yield ``0.0``. A negative squared volume is clamped
    to zero before the square root.
    """
    n_vertices = len(points)
    if n_vertices < 2:
        return 0.0
    order = n_vertices - 1
    scale = (2 ** order) * (math.factorial(order) ** 2)
    squared = ((-1) ** n_vertices) * cayley_menger_det(points) / scale
    return np.sqrt(max(0, squared))
|
|
|
|
def effective_volume(points: np.ndarray) -> float:
    """Return a size measure for a point set based on its largest simplex.

    * fewer than 2 points: ``0.0``
    * exactly 2 points: Euclidean distance between them
    * exactly 3 points: area of the triangle they span
    * 4 or more points: largest tetrahedron volume over all 4-subsets

    Only the first 8 points are considered when enumerating subsets, which
    bounds the number of determinant evaluations.

    Args:
        points: Array of point coordinates, one point per row.

    Returns:
        The measure described above (``0`` when every candidate simplex is
        degenerate).
    """
    k = len(points)
    if k < 2:
        return 0.0
    if k == 2:
        return np.linalg.norm(points[0] - points[1])
    # Only the simplex order that is actually returned is enumerated: the
    # triangle scan is needed for k == 3 only, tetrahedra for k >= 4.
    order = 3 if k == 3 else 4
    best = 0
    for idx in combinations(range(min(k, 8)), order):
        best = max(best, simplex_volume(points[list(idx)]))
    return best
|
|
|
|
| |
|
|
class ShapeGenerator:
    """Seeded random generator of labelled voxel shapes on a GS x GS x GS grid.

    Every sample is a dict produced by :meth:`_build` holding the binary
    occupancy grid, the class label and auxiliary geometric targets. All
    randomness comes from ``self.rng``, so a given seed and call sequence
    reproduces the same samples.
    """

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

    def generate(self, n_samples: int) -> list:
        """Return ``n_samples`` shuffled samples, roughly balanced per class.

        Each class gets ``n_samples // NUM_CLASSES`` samples (with at most
        five attempts per requested sample); any shortfall is filled with
        randomly chosen classes.
        """
        samples = []
        per_class = n_samples // NUM_CLASSES
        for name in CLASS_NAMES:
            count = 0
            attempts = 0
            while count < per_class and attempts < per_class * 5:
                s = self._make(name)
                attempts += 1
                if s is not None:
                    samples.append(s)
                    count += 1
        while len(samples) < n_samples:
            # rng.choice returns a numpy string scalar; convert so that
            # "class_name" is a plain str in every sample.
            name = str(self.rng.choice(CLASS_NAMES))
            s = self._make(name)
            if s is not None:
                samples.append(s)
        self.rng.shuffle(samples)
        return samples[:n_samples]

    def _make(self, name: str) -> Optional[dict]:
        """Generate one sample of class ``name``; ``None`` if generation failed."""
        info = SHAPE_CATALOG[name]
        if info["curved"]:
            voxels = self._curved(name)
        else:
            voxels = self._rigid(name)
        if voxels is None:
            return None
        # Force coordinates onto the grid and drop duplicate voxels.
        voxels = np.clip(voxels, 0, GS - 1).astype(int)
        voxels = np.unique(voxels, axis=0)
        if len(voxels) < 1:
            return None
        return self._build(name, info, voxels)

    # ------------------------------------------------------------------
    # Rigid (non-curved) shapes
    # ------------------------------------------------------------------

    def _rigid(self, name):
        """Return voxel coordinates for a rigid class, or ``None``.

        Returns ``None`` when a random point helper gives up or when
        ``name`` is not a rigid class handled here. Coordinates may still
        need the clipping / de-duplication done in :meth:`_make`.
        """
        rng = self.rng

        if name == "point":
            return rng.randint(0, GS, size=(1, 3))

        elif name == "line_x":
            y, z = rng.randint(0, GS, size=2)
            x1, x2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x1, y, z], [x2, y, z]])

        elif name == "line_y":
            x, z = rng.randint(0, GS, size=2)
            y1, y2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y1, z], [x, y2, z]])

        elif name == "line_z":
            x, y = rng.randint(0, GS, size=2)
            z1, z2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y, z1], [x, y, z2]])

        elif name == "line_diag":
            p1 = rng.randint(0, 3, size=3)
            step = rng.randint(1, 3)
            # Every component is -1 or +1, so the direction is never
            # axis-aligned before clipping.
            direction = rng.choice([-1, 1], size=3)
            p2 = np.clip(p1 + step * direction, 0, GS - 1)
            if np.array_equal(p1, p2):
                # Clipping collapsed the segment; fall back to a fixed offset.
                p2 = np.clip(p1 + np.array([1, 1, 0]), 0, GS - 1)
            return np.array([p1, p2])

        elif name == "cross":
            # Centre voxel plus one arm in each direction along two axes.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            length = rng.randint(1, 3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            pts = [[cx, cy, cz]]
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis1] = np.clip(p[axis1] + sign * length, 0, GS - 1)
                pts.append(list(p))
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis2] = np.clip(p[axis2] + sign * length, 0, GS - 1)
                pts.append(list(p))
            return np.array(pts)

        elif name == "l_shape":
            # Two voxel runs leaving a shared corner along different axes.
            corner = rng.randint(1, GS - 1, size=3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            len1 = rng.randint(1, 3)
            len2 = rng.randint(1, 3)
            dir1 = rng.choice([-1, 1])
            dir2 = rng.choice([-1, 1])
            pts = [list(corner)]
            for i in range(1, len1 + 1):
                p = list(corner)
                p[axis1] = np.clip(p[axis1] + dir1 * i, 0, GS - 1)
                pts.append(p)
            for i in range(1, len2 + 1):
                p = list(corner)
                p[axis2] = np.clip(p[axis2] + dir2 * i, 0, GS - 1)
                pts.append(p)
            return np.array(pts)

        elif name == "collinear":
            # Three distinct positions along one axis, other two coords fixed.
            axis = rng.randint(3)
            fixed = rng.randint(0, GS, size=2)
            vals = sorted(rng.choice(GS, 3, replace=False))
            pts = np.zeros((3, 3), dtype=int)
            for i, v in enumerate(vals):
                pts[i, axis] = v
                pts[i, (axis + 1) % 3] = fixed[0]
                pts[i, (axis + 2) % 3] = fixed[1]
            return pts

        elif name == "triangle_xy":
            z = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is None:
                return None
            return np.column_stack([pts, np.full(3, z)])

        elif name == "triangle_xz":
            y = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is None:
                return None
            return np.column_stack([pts[:, 0], np.full(3, y), pts[:, 1]])

        elif name == "triangle_3d":
            return self._rand_pts_3d(3, min_dist=1)

        elif name == "square_xy":
            z = rng.randint(0, GS)
            x1, y1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y1, z], [x1 + s, y1, z],
                            [x1, y1 + s, z], [x1 + s, y1 + s, z]])
            return np.clip(pts, 0, GS - 1)

        elif name == "square_xz":
            y = rng.randint(0, GS)
            x1, z1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y, z1], [x1 + s, y, z1],
                            [x1, y, z1 + s], [x1 + s, y, z1 + s]])
            return np.clip(pts, 0, GS - 1)

        elif name == "rectangle":
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a1, a2 = rng.randint(0, 3), rng.randint(0, 3)
            w, h = rng.randint(1, 4), rng.randint(1, 3)
            # Avoid equal side lengths before clipping.
            if w == h:
                w = min(GS - 1, w + 1)
            c = np.array([[a1, a2], [a1 + w, a2], [a1, a2 + h], [a1 + w, a2 + h]])
            c = np.clip(c, 0, GS - 1)
            if axis == 0:
                return np.column_stack([np.full(4, val), c])
            elif axis == 1:
                return np.column_stack([c[:, 0], np.full(4, val), c[:, 1]])
            else:
                return np.column_stack([c, np.full(4, val)])

        elif name == "coplanar":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            # Flatten one randomly chosen coordinate column to a single value.
            # Python evaluates the right-hand side first, so its randint is
            # drawn before the one selecting the target column.
            pts[:, rng.randint(3)] = pts[0, rng.randint(3)]
            return pts

        elif name == "plane":
            # Filled axis-aligned rectangle of voxels at a fixed coordinate.
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a_start = rng.randint(0, 2)
            b_start = rng.randint(0, 2)
            a_size = rng.randint(2, GS - a_start + 1)
            b_size = rng.randint(2, GS - b_start + 1)
            pts = []
            for a in range(a_start, min(GS, a_start + a_size)):
                for b in range(b_start, min(GS, b_start + b_size)):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[(axis + 1) % 3] = a
                    p[(axis + 2) % 3] = b
                    pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tetrahedron":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            centered = pts - pts.mean(axis=0)
            _, s, _ = np.linalg.svd(centered.astype(float))
            # A small last singular value means the four points are nearly
            # coplanar; perturb one random coordinate in that case.
            if s[-1] < 0.5:
                pts[rng.randint(4), rng.randint(3)] = (pts[0, 0] + 2) % GS
            return pts

        elif name == "pyramid":
            # Square base at height z_base plus one apex above it.
            z_base = rng.randint(0, 3)
            x1, y1 = rng.randint(0, 3), rng.randint(0, 3)
            s = rng.randint(1, 3)
            base = np.array([[x1, y1, z_base], [x1 + s, y1, z_base],
                             [x1, y1 + s, z_base], [x1 + s, y1 + s, z_base]])
            apex = np.array([[x1 + s // 2, y1 + s // 2, z_base + rng.randint(1, 3)]])
            return np.clip(np.vstack([base, apex]), 0, GS - 1)

        elif name == "pentachoron":
            # Five distinct random grid points.
            return self._rand_pts_3d(5, min_dist=1)

        elif name == "cube":
            # Eight corners of an axis-aligned box with equal side lengths.
            x1, y1, z1 = rng.randint(0, 3, size=3)
            s = rng.randint(1, 3)
            pts = []
            for dx in [0, s]:
                for dy in [0, s]:
                    for dz in [0, s]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "cuboid":
            x1, y1, z1 = rng.randint(0, 2, size=3)
            sx, sy, sz = rng.randint(1, 4, size=3)
            # Break the tie when all three sides are equal (that is a cube).
            if sx == sy == sz:
                sx = min(GS - 1, sx + 1)
            pts = []
            for dx in [0, sx]:
                for dy in [0, sy]:
                    for dz in [0, sz]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "triangular_prism":
            # A 2D triangle repeated on consecutive layers along one axis.
            axis = rng.randint(3)
            ext_start = rng.randint(0, 3)
            ext_len = rng.randint(1, 3)
            tri = self._rand_pts_2d(3, min_dist=1)
            if tri is None:
                return None
            pts = []
            for e in range(ext_start, min(GS, ext_start + ext_len + 1)):
                for t in tri:
                    p = [0, 0, 0]
                    p[axis] = e
                    p[(axis + 1) % 3] = t[0]
                    p[(axis + 2) % 3] = t[1]
                    pts.append(p)
            return np.clip(np.array(pts), 0, GS - 1) if len(pts) >= 6 else None

        elif name == "octahedron":
            # Six vertices at +/- s from the centre along each axis.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            s = rng.randint(1, 3)
            pts = [[cx, cy, cz + s], [cx, cy, cz - s],
                   [cx + s, cy, cz], [cx - s, cy, cz],
                   [cx, cy + s, cz], [cx, cy - s, cz]]
            return np.clip(np.array(pts), 0, GS - 1)

        return None

    # ------------------------------------------------------------------
    # Curved shapes
    # ------------------------------------------------------------------

    def _curved(self, name):
        """Return voxel coordinates for a curved class, or ``None``.

        A continuous centre ``(cx, cy, cz)`` in ``[1, 3)`` is drawn first for
        every call. Curves are sampled parametrically and rounded to the
        grid; solids and surfaces are rasterised by testing every voxel.
        ``None`` is returned when too few voxels survive or when ``name`` is
        not handled here.
        """
        rng = self.rng
        cx, cy, cz = rng.uniform(1.0, 3.0, size=3)

        if name == "arc":
            r = rng.uniform(1.2, 2.2)
            plane = rng.choice(["xy", "xz", "yz"])
            start = rng.uniform(0, 2 * np.pi)
            span = rng.uniform(np.pi * 0.4, np.pi * 1.2)
            n = rng.randint(6, 12)
            angles = np.linspace(start, start + span, n)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
                else:
                    pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 3 else None

        elif name == "helix":
            # Circle in the plane orthogonal to `axis`, advancing along
            # `axis` linearly with the angle.
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            pitch = rng.uniform(0.3, 0.8)
            n = rng.randint(15, 30)
            t = np.linspace(0, 2 * np.pi * rng.uniform(1.0, 2.5), n)
            pts = []
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            start_h = rng.uniform(0, 1.0)
            for ti in t:
                p = [0.0, 0.0, 0.0]
                p[axes[0]] = center[axes[0]] + r * np.cos(ti)
                p[axes[1]] = center[axes[1]] + r * np.sin(ti)
                p[axis] = start_h + pitch * ti
                pts.append(p)
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "circle":
            r = rng.uniform(1.0, 2.0)
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
                else:
                    pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "ellipse":
            rx, ry = rng.uniform(0.8, 2.0), rng.uniform(0.8, 2.0)
            # Stretch one radius when the two are nearly equal.
            if abs(rx - ry) < 0.3:
                rx *= 1.4
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = []
            for a in angles:
                if plane == "xy":
                    pts.append([cx + rx * np.cos(a), cy + ry * np.sin(a), cz])
                elif plane == "xz":
                    pts.append([cx + rx * np.cos(a), cy, cz + ry * np.sin(a)])
                else:
                    pts.append([cx, cy + rx * np.cos(a), cz + ry * np.sin(a)])
            pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
            return pts if len(pts) >= 5 else None

        elif name == "disc":
            # Filled circle in the plane `axis == val`.
            r = rng.uniform(1.0, 2.2)
            axis = rng.randint(3)
            val = round(rng.uniform(0.5, 3.5))
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[axes[0]] = x
                    p[axes[1]] = y
                    dist = np.sqrt((x - center[axes[0]])**2 + (y - center[axes[1]])**2)
                    if dist <= r:
                        pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "sphere":
            r = rng.uniform(1.0, 2.2)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        if (x - cx)**2 + (y - cy)**2 + (z - cz)**2 <= r**2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "hemisphere":
            # Sphere voxels on the non-negative side of the centre along
            # `cut_axis`.
            r = rng.uniform(1.0, 2.2)
            cut_axis = rng.randint(3)
            center = [cx, cy, cz]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if (x - cx)**2 + (y - cy)**2 + (z - cz)**2 <= r**2:
                            if p[cut_axis] >= center[cut_axis]:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 3 else None

        elif name == "cylinder":
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a])**2 for a in axes)
                        if dist_sq <= r**2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "cone":
            r_base = rng.uniform(1.0, 2.0)
            axis = rng.randint(3)
            height = rng.randint(2, 5)
            base_pos = rng.randint(0, GS - height + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        along = p[axis] - base_pos
                        if along < 0 or along >= height:
                            continue
                        # Radius shrinks linearly from r_base at the base
                        # layer to (almost) zero at the last layer.
                        t = along / (height - 1 + 1e-6)
                        r_at = r_base * (1.0 - t)
                        dist_sq = sum((p[a] - center[a])**2 for a in axes)
                        if dist_sq <= r_at**2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "capsule":
            # Cylindrical body between body_start and body_end with a
            # spherical cap of the same radius at each end.
            r = rng.uniform(0.8, 1.5)
            axis = rng.randint(3)
            body_len = rng.randint(1, 3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            body_start = round(center[axis] - body_len / 2)
            body_end = body_start + body_len
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        radial_sq = sum((p[a] - center[a])**2 for a in axes)
                        along = p[axis]
                        if body_start <= along <= body_end and radial_sq <= r**2:
                            pts.append(p)
                        elif along < body_start:
                            cap_center = list(center)
                            cap_center[axis] = body_start
                            dist_sq = sum((p[i] - cap_center[i])**2 for i in range(3))
                            if dist_sq <= r**2:
                                pts.append(p)
                        elif along > body_end:
                            cap_center = list(center)
                            cap_center[axis] = body_end
                            dist_sq = sum((p[i] - cap_center[i])**2 for i in range(3))
                            if dist_sq <= r**2:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 5 else None

        elif name == "torus":
            # R: ring radius in the plane orthogonal to `axis`; r: tube radius.
            R = rng.uniform(1.2, 2.0)
            r = rng.uniform(0.5, 0.9)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            ring_axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_in_plane = np.sqrt(
                            sum((p[a] - center[a])**2 for a in ring_axes))
                        dist_from_ring = np.sqrt(
                            (dist_in_plane - R)**2 + (p[axis] - center[axis])**2)
                        if dist_from_ring <= r:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "shell":
            # Voxels between an inner and an outer sphere radius.
            r_out = rng.uniform(1.5, 2.3)
            r_in = r_out - rng.uniform(0.4, 0.8)
            if r_in < 0.3:
                r_in = 0.3
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        d_sq = (x - cx)**2 + (y - cy)**2 + (z - cz)**2
                        if r_in**2 <= d_sq <= r_out**2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tube":
            # Cylinder wall: radial distance between r_in and r_out.
            r_out = rng.uniform(1.0, 2.0)
            r_in = r_out - rng.uniform(0.3, 0.7)
            if r_in < 0.2:
                r_in = 0.2
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a])**2 for a in axes)
                        if r_in**2 <= dist_sq <= r_out**2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "bowl":
            # Thin layer around the paraboloid
            # h = center[axis] + dist_planar**2 / r, limited to dist_planar <= r.
            r = rng.uniform(1.2, 2.2)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            thickness = 0.6
            k = 1.0 / (r + 1e-6)  # loop-invariant curvature coefficient
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_planar = np.sqrt(
                            sum((p[a] - center[a])**2 for a in axes))
                        if dist_planar > r:
                            continue
                        expected_h = center[axis] + k * dist_planar**2
                        actual_h = p[axis]
                        if abs(actual_h - expected_h) <= thickness:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "saddle":
            # Thin layer around h = center[axis] + k * (da**2 - db**2),
            # limited to planar distance <= 2 from the centre.
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            k = rng.uniform(0.3, 0.8)
            thickness = 0.7
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        da = p[axes[0]] - center[axes[0]]
                        db = p[axes[1]] - center[axes[1]]
                        expected_h = center[axis] + k * (da**2 - db**2)
                        if abs(p[axis] - expected_h) <= thickness:
                            dist_sq = da**2 + db**2
                            if dist_sq <= 4.0:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        return None

    # ------------------------------------------------------------------
    # Random point helpers
    # ------------------------------------------------------------------

    def _rand_pts_2d(self, n, min_dist=0):
        """Return ``n`` distinct random 2D grid points, or ``None``.

        Up to 50 candidate sets are drawn; the first one satisfying
        :meth:`_check_dist` (when ``min_dist > 0``) is returned.
        """
        for _ in range(50):
            pts = set()
            while len(pts) < n:
                pts.add((self.rng.randint(0, GS), self.rng.randint(0, GS)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _rand_pts_3d(self, n, min_dist=0):
        """Return ``n`` distinct random 3D grid points, or ``None``.

        Up to 100 candidate sets are drawn; the first one satisfying
        :meth:`_check_dist` (when ``min_dist > 0``) is returned.
        """
        for _ in range(100):
            pts = set()
            while len(pts) < n:
                pts.add(tuple(self.rng.randint(0, GS, size=3)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _check_dist(self, pts, min_dist):
        """Return True when every pair is at least ``min_dist`` apart (L1 norm)."""
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                if np.sum(np.abs(pts[i] - pts[j])) < min_dist:
                    return False
        return True

    # ------------------------------------------------------------------
    # Sample assembly
    # ------------------------------------------------------------------

    def _build(self, name, info, voxels):
        """Assemble the sample dict for ``voxels`` of class ``name``.

        Args:
            name: Class name (key of ``SHAPE_CATALOG``).
            info: Catalog entry for ``name``.
            voxels: Integer array of unique in-grid voxel coordinates,
                one ``(x, y, z)`` row per voxel.

        Returns:
            Dict with the occupancy grid, labels and geometric targets.
        """
        n = len(voxels)
        # Geometric statistics use at most the first six voxels.
        sub = voxels[:6].astype(float)
        cm_det = cayley_menger_det(sub)
        volume = effective_volume(sub)

        # Per-dimension flags: 0D always set, 1D when there are at least two
        # voxels, 2D / 3D according to the catalog dimension.
        dim_conf = np.zeros(4, dtype=np.float32)
        dim_conf[0] = 1.0
        if n >= 2:
            dim_conf[1] = 1.0
        if info["dim"] >= 2:
            dim_conf[2] = 1.0
        if info["dim"] >= 3:
            dim_conf[3] = 1.0

        # Mark all occupied voxels in one indexed assignment.
        grid = np.zeros((GS, GS, GS), dtype=np.float32)
        grid[voxels[:, 0], voxels[:, 1], voxels[:, 2]] = 1.0

        return {
            "grid": grid, "label": CLASS_TO_IDX[name], "class_name": name,
            "n_points": n, "n_occupied": int(grid.sum()),
            "cm_det": float(cm_det), "volume": float(volume),
            "peak_dim": info["dim"], "dim_confidence": dim_conf,
            "is_curved": info["curved"], "curvature": CURV_TO_IDX[info["curvature"]],
        }
|
|
|
|
| |
|
|
def _generate_chunk(args):
    """Worker function for parallel shape generation.

    ``args`` is a ``(class_assignments, seed, start_idx)`` tuple. One sample
    is produced per class index using a generator seeded with ``seed``; each
    class gets up to ten attempts, after which a "cube" sample is generated
    in its place. ``start_idx`` is accepted but not used.
    """
    class_ids, seed, _start_idx = args
    generator = ShapeGenerator(seed=seed)
    produced = []
    for class_id in class_ids:
        class_name = CLASS_NAMES[class_id]
        sample = None
        for _ in range(10):
            sample = generator._make(class_name)
            if sample is not None:
                break
        if sample is None:
            # Every attempt failed: substitute a cube sample.
            sample = generator._make("cube")
        if sample is not None:
            produced.append(sample)
    return produced
|
|
|
|
def generate_parallel(n_samples, seed=42, n_workers=8):
    """Pre-generate all samples using multiprocessing.

    Class indices are assigned up front (``n_samples // NUM_CLASSES`` per
    class, remainder drawn at random), shuffled, split into up to
    ``n_workers`` contiguous chunks and handed to ``_generate_chunk``. Chunk
    ``i`` uses seed ``seed + i * 1000000``.

    Args:
        n_samples: Number of class assignments to generate samples for.
        seed: Base seed for the assignment shuffle and the worker seeds.
        n_workers: Maximum number of worker processes / chunks.

    Returns:
        Shuffled list of sample dicts collected from all workers.
    """
    import multiprocessing as mp
    import time

    per_class = n_samples // NUM_CLASSES
    class_assignments = [
        ci for ci in range(NUM_CLASSES) for _ in range(per_class)
    ]
    rng = np.random.RandomState(seed)
    while len(class_assignments) < n_samples:
        class_assignments.append(rng.randint(0, NUM_CLASSES))
    rng.shuffle(class_assignments)
    class_assignments = class_assignments[:n_samples]

    # Ceil-divide so that at most n_workers chunks cover every assignment.
    chunk_size = (n_samples + n_workers - 1) // n_workers
    chunks = []
    for i in range(n_workers):
        start = i * chunk_size
        if start >= n_samples:
            break
        end = min(start + chunk_size, n_samples)
        chunks.append((class_assignments[start:end], seed + i * 1000000, start))

    print(f"Generating {n_samples} shapes across {len(chunks)} workers...")
    t0 = time.perf_counter()
    if chunks:
        # One process per chunk is enough; pool.map keeps chunk order, so the
        # result does not depend on the pool size.
        with mp.Pool(len(chunks)) as pool:
            results = pool.map(_generate_chunk, chunks)
    else:
        # Nothing to do: avoid starting worker processes at all.
        results = []
    samples = []
    for r in results:
        samples.extend(r)
    rng.shuffle(samples)
    dt = time.perf_counter() - t0
    # Guard the rate against a zero elapsed time.
    rate = len(samples) / max(dt, 1e-9)
    print(f"Generated {len(samples)} samples in {dt:.1f}s ({rate:.0f} samples/s)")
    return samples
|
|
|
|
class ShapeDataset(torch.utils.data.Dataset):
    """In-memory dataset that turns a list of sample dicts into tensors.

    Each sample dict must provide the keys "grid", "label",
    "dim_confidence", "peak_dim", "volume", "cm_det", "is_curved" and
    "curvature". All fields are converted to tensors once, at construction.
    """

    def __init__(self, samples):
        def stacked(key):
            # Array-valued field: stack along a new leading sample axis.
            return torch.tensor(np.stack([s[key] for s in samples]), dtype=torch.float32)

        def column(key, dtype):
            # Scalar-valued field: one entry per sample.
            return torch.tensor([s[key] for s in samples], dtype=dtype)

        self.grids = stacked("grid")
        self.labels = column("label", torch.long)
        self.dim_conf = stacked("dim_confidence")
        self.peak_dim = column("peak_dim", torch.long)
        self.volume = column("volume", torch.float32)
        self.cm_det = column("cm_det", torch.float32)
        self.is_curved = column("is_curved", torch.float32)
        self.curvature = column("curvature", torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the 8-tuple (grid, label, dim_conf, peak_dim, volume,
        cm_det, is_curved, curvature) for sample ``idx``."""
        fields = (
            self.grids, self.labels, self.dim_conf, self.peak_dim,
            self.volume, self.cm_det, self.is_curved, self.curvature,
        )
        return tuple(field[idx] for field in fields)
|
|
|
|
|
|
# Import-time banner: reports the number of shape classes and the grid edge length.
print(f'Loaded {NUM_CLASSES} shape classes, GS={GS}')