| """Contains linear algebra related utility functions. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn.functional as F |
| from scipy.spatial.transform import Rotation |
|
|
|
|
def rotation_matrices_from_quaternions(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert a batch of quaternions into rotation matrices.

    Quaternions are stored scalar-first, i.e. (w, x, y, z). Each quaternion is
    normalized before conversion, so non-unit inputs are accepted.

    Args:
        quaternions: Tensor of shape (..., 4) with the quaternions to convert.

    Returns:
        Tensor of shape (..., 3, 3) with the rotation matrices corresponding to
        the normalized quaternions, with the same dtype and device as the input.
        A zero quaternion cannot be normalized and yields NaN entries.

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f"quaternions have invalid shape {quaternions.shape}")

    quaternions = quaternions / torch.linalg.norm(quaternions, dim=-1, keepdim=True)
    w, x, y, z = quaternions.unbind(dim=-1)

    # Element-wise expansion of (w^2 - |v|^2) I + 2 v v^T + 2 w [v]_x with
    # v = (x, y, z). Building it from the input components keeps the input's
    # dtype/device instead of mixing in default-dtype identity matrices.
    ww, xx, yy, zz = w * w, x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    rows = (
        (ww + xx - yy - zz, 2 * (xy - wz), 2 * (xz + wy)),
        (2 * (xy + wz), ww - xx + yy - zz, 2 * (yz - wx)),
        (2 * (xz - wy), 2 * (yz + wx), ww - xx - yy + zz),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)
|
|
|
|
| def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor: |
| """Convert batch of rotation matrices to quaternions. |
| |
| Args: |
| matrices: The matrices to convert to quaternions. |
| |
| Returns: |
| The quaternions corresponding to the rotation matrices. |
| |
| Note: this operation is not differentiable and will be performed on the CPU. |
| """ |
| if not matrices.shape[-2:] == (3, 3): |
| raise ValueError(f"matrices have invalid shape {matrices.shape}") |
| matrices_np = matrices.detach().cpu().numpy() |
| quaternions_np = Rotation.from_matrix(matrices_np.reshape(-1, 3, 3)).as_quat() |
| |
| quaternions_np = quaternions_np[:, [3, 0, 1, 2]] |
| quaternions_np = quaternions_np.reshape(matrices_np.shape[:-2] + (4,)) |
| return torch.as_tensor(quaternions_np, device=matrices.device, dtype=matrices.dtype) |
|
|
|
|
def get_cross_product_matrix(vectors: torch.Tensor) -> torch.Tensor:
    """Build the skew-symmetric cross-product matrix of each 3-vector.

    For a vector ``v`` the returned matrix ``[v]_x`` satisfies
    ``[v]_x @ w == cross(v, w)`` for any 3-vector ``w``.

    Args:
        vectors: Tensor of shape (..., 3).

    Returns:
        Tensor of shape (..., 3, 3) with the same dtype and device as ``vectors``.

    Raises:
        ValueError: If the last dimension of ``vectors`` is not 3.
    """
    if vectors.shape[-1] != 3:
        raise ValueError("Only 3-dimensional vectors are supported")

    # Assemble the matrix from the components directly so the result keeps the
    # input's dtype (a default-dtype identity basis would not).
    x, y, z = vectors.unbind(dim=-1)
    zero = torch.zeros_like(x)
    rows = (
        (zero, -z, y),
        (z, zero, -x),
        (-y, x, zero),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)
|
|
|
|
| def eyes( |
| dim: int, shape: tuple[int, ...], device: torch.device | str | None = None |
| ) -> torch.Tensor: |
| """Create batch of identity matrices.""" |
| return torch.eye(dim, device=device).broadcast_to(shape + (dim, dim)).clone() |
|
|
|
|
def quaternion_product(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute the Hamilton product ``q1 * q2`` of two batches of quaternions.

    Quaternions are stored scalar-first, (w, x, y, z). For q1 = (a, u) and
    q2 = (b, v) the product is (a*b - u.v, a*v + b*u + u x v); note that the
    product is not commutative.

    Args:
        q1: Left factor, with the quaternion in the last dimension (size 4).
        q2: Right factor, with the quaternion in the last dimension (size 4).
            Batch dimensions of ``q1`` and ``q2`` are broadcast together.

    Returns:
        The product, with the broadcast batch shape and a trailing dimension of 4.
    """
    real_1, vector_1 = q1[..., :1], q1[..., 1:]
    real_2, vector_2 = q2[..., :1], q2[..., 1:]
    # torch.cross needs inputs of equal rank, so broadcast the vector parts
    # explicitly; the element-wise ops below broadcast on their own.
    vector_1, vector_2 = torch.broadcast_tensors(vector_1, vector_2)

    real_out = real_1 * real_2 - (vector_1 * vector_2).sum(dim=-1, keepdim=True)
    # dim=-1 is essential: without it torch.cross uses the FIRST dimension of
    # size 3, which is wrong whenever a batch dimension happens to equal 3.
    vector_out = (
        real_1 * vector_2 + real_2 * vector_1 + torch.cross(vector_1, vector_2, dim=-1)
    )
    return torch.cat([real_out, vector_out], dim=-1)
|
|
|
|
def quaternion_conj(q: torch.Tensor) -> torch.Tensor:
    """Return the conjugate of each quaternion.

    The scalar part (first entry of the last dimension) is kept unchanged and
    the remaining vector entries are negated, e.g. (w, x, y, z) -> (w, -x, -y, -z).

    Args:
        q: Quaternions stored scalar-first along the last dimension.

    Returns:
        Tensor of the same shape as ``q`` holding the conjugates.
    """
    scalar, vector = q[..., :1], q[..., 1:]
    return torch.cat((scalar, vector.neg()), dim=-1)
|
|
|
|
| def project(u: torch.Tensor, basis: torch.Tensor) -> torch.Tensor: |
| """Project tensor u to unit basis a.""" |
| unit_u = F.normalize(u, dim=-1) |
| inner_prod = (unit_u * basis).sum(dim=-1, keepdim=True) |
| return inner_prod * u |
|
|