from abc import ABC
from abc import abstractmethod
import math
import torch
from torch_scatter import scatter_mean, scatter_add

import src.data.so3_utils as so3


class ICFM(ABC):
    """
    Abstract base class for all Independent-coupling CFM classes.
    Defines a common interface.
    Notation:
    - zt is the intermediate representation at time step t in [0, 1]
    - zs is the intermediate representation at an earlier time step s < t

    # TODO: add interpolation schedule (not necessarily linear)
    """
    def __init__(self, sigma):
        self.sigma = sigma

    @abstractmethod
    def sample_zt(self, z0, z1, t, *args, **kwargs):
        """ Sample the interpolant zt between z0 and z1 at time t. """
        pass

    @abstractmethod
    def sample_zt_given_zs(self, *args, **kwargs):
        """ Perform an update from time s to time t, typically using an
        explicit Euler step. """
        pass

    @abstractmethod
    def sample_z0(self, *args, **kwargs):
        """ Sample from the prior. """
        pass

    @abstractmethod
    def compute_loss(self, pred, z0, z1, *args, **kwargs):
        """ Compute the loss per sample. """
        pass
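
# Batching convention shared by the subclasses below: `batch_mask` is a
# LongTensor of length num_nodes mapping every node/row to its graph index,
# and per-graph quantities such as `t` and `s` have shape (num_graphs, 1), so
# `t[batch_mask]` broadcasts them to one value per node. The same index tensor
# is passed to scatter_mean/scatter_add when reducing per-node losses.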


class CoordICFM(ICFM):
    def __init__(self, sigma):
        self.dim = 3
        self.scale = 2.7  # constant factor used to rescale predictions and targets
        super().__init__(sigma)

    def sample_zt(self, z0, z1, t, batch_mask):
        # Linear interpolation between the prior sample z0 and the data z1.
        zt = t[batch_mask] * z1 + (1 - t)[batch_mask] * z0
        return zt

    def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
        """ Perform an explicit Euler step. """
        step_size = t - s
        zt = zs + step_size[batch_mask] * self.scale * pred
        return zt

    def sample_z0(self, com, batch_mask):
        """ Prior: standard normal centered at the given center of mass. """
        z0 = torch.randn((len(batch_mask), self.dim), device=batch_mask.device)
        z0 = z0 + com[batch_mask]
        return z0

    def reduce_loss(self, loss, batch_mask, reduce):
        assert reduce in {'mean', 'sum', 'none'}

        if reduce == 'mean':
            loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def compute_loss(self, pred, z0, z1, t, batch_mask, reduce='mean'):
        """ Compute loss per sample. """
        loss = torch.sum((pred - (z1 - z0) / self.scale) ** 2, dim=-1)
        return self.reduce_loss(loss, batch_mask, reduce)

    def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
        """ Make a best guess on the final state z1 given the current state and
        the network prediction. """
        # z1 - zt = (1 - t) * (z1 - z0), and pred approximates
        # (z1 - z0) / scale, so the scale factor must be reapplied here for
        # consistency with compute_loss and sample_zt_given_zs.
        z1 = zt + (1 - t)[batch_mask] * self.scale * pred
        return z1
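
# A minimal sampling sketch for CoordICFM (hypothetical, not part of this
# module): `model` stands for any network predicting the scaled velocity
# (z1 - z0) / scale, and a single molecule with n atoms is assumed.
#
#   icfm = CoordICFM(sigma=0.0)
#   batch_mask = torch.zeros(n, dtype=torch.long)         # all atoms in graph 0
#   z = icfm.sample_z0(com=torch.zeros(1, 3), batch_mask=batch_mask)
#   times = torch.linspace(0.0, 1.0, 101).view(-1, 1, 1)  # (steps + 1, 1, 1)
#   for s, t in zip(times[:-1], times[1:]):
#       pred = model(z, s)                                # (n, 3)
#       z = icfm.sample_zt_given_zs(z, pred, s=s, t=t, batch_mask=batch_mask)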


class TorusICFM(ICFM):
    """
    Following:
    Chen, Ricky TQ, and Yaron Lipman.
    "Riemannian flow matching on general geometries."
    arXiv preprint arXiv:2302.03660 (2023).
    """
    def __init__(self, sigma, dim, scheduler_args=None):
        super().__init__(sigma)
        self.dim = dim

        scheduler_args = scheduler_args or {}
        scheduler_args["type"] = scheduler_args.get("type", "linear")
        scheduler_args["learn_scaled"] = scheduler_args.get("learn_scaled", False)

        if scheduler_args["type"] == "linear":
            self.flow_scaling = lambda t: t
            self.velocity_scaling = lambda t: torch.ones_like(t)

        elif scheduler_args["type"] == "exponential":
            self.c = scheduler_args["c"]
            assert self.c > 0
            self.flow_scaling = lambda t: 1 - torch.exp(-self.c * t)
            self.velocity_scaling = lambda t: self.c * torch.exp(-self.c * t)

        elif scheduler_args["type"] == "polynomial":
            self.k = scheduler_args["k"]
            assert self.k > 0
            self.flow_scaling = lambda t: 1 - (1 - t)**self.k
            self.velocity_scaling = lambda t: self.k * (1 - t)**(self.k - 1)

        else:
            raise NotImplementedError(f"Scheduler {scheduler_args['type']} not implemented.")

        # Sanity check: the schedule kappa should run from 0 at t=0 to 1 at t=1.
        kappa_interval = self.flow_scaling(torch.tensor([0.0, 1.0]))
        if kappa_interval[0] != 0.0 or kappa_interval[1] != 1.0:
            print(f"Scheduler should satisfy kappa(0)=0 and kappa(1)=1. Found "
                  f"interval {kappa_interval.tolist()} instead.")

        self.learn_scaled = scheduler_args["learn_scaled"]
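
    # Example scheduler configurations (illustrative values, not defaults):
    #   TorusICFM(sigma, dim, {"type": "polynomial", "k": 2})
    #       kappa(t) = 1 - (1 - t)**2,  kappa'(t) = 2 * (1 - t)
    #   TorusICFM(sigma, dim, {"type": "exponential", "c": 5.0})
    #       kappa(t) = 1 - exp(-5 * t); here kappa(1) = 1 - exp(-5) ~ 0.993,
    #       so the sanity check above prints a warning for this scheduler.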

    @staticmethod
    def wrap(angle):
        """ Maps angles to the range [-pi, pi). """
        return ((angle + math.pi) % (2 * math.pi)) - math.pi

    def exponential_map(self, x, u):
        """
        :param x: point on the manifold
        :param u: point in the tangent space
        """
        return self.wrap(x + u)

    @staticmethod
    def logarithm_map(x, y):
        """
        :param x, y: points on the manifold
        """
        return torch.atan2(torch.sin(y - x), torch.cos(y - x))
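
    # Worked example: wrap(3*pi/2) = ((3*pi/2 + pi) % (2*pi)) - pi = -pi/2, and
    # logarithm_map(-3*pi/4, 3*pi/4) = atan2(sin(3*pi/2), cos(3*pi/2)) = -pi/2:
    # the log map returns the signed shortest angular difference, here taking
    # the short way across the wrap-around point at pi.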

    def sample_zt(self, z0, z1, t, batch_mask):
        """ Expressed in terms of exponential and logarithm maps. """
        zt_tangent = self.flow_scaling(t)[batch_mask] * self.logarithm_map(z0, z1)
        return self.exponential_map(z0, zt_tangent)

    def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
        """ Make a best guess on the final state z1 given the current state and
        the network prediction. """
        if self.learn_scaled:
            pred = pred / torch.clamp(self.velocity_scaling(t), min=1e-3)[batch_mask]

        z1_tangent = (1 - t)[batch_mask] * pred
        return self.exponential_map(zt, z1_tangent)

    def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
        """ Perform update, typically using an explicit Euler step. """
        step_size = t - s
        zt_tangent = step_size[batch_mask] * pred

        if not self.learn_scaled:
            zt_tangent = self.velocity_scaling(t)[batch_mask] * zt_tangent

        return self.exponential_map(zs, zt_tangent)

    def sample_z0(self, batch_mask):
        """ Prior: uniform over [-pi, pi)^dim. """
        z0 = torch.rand((len(batch_mask), self.dim), device=batch_mask.device)
        return 2 * math.pi * z0 - math.pi
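
    # Note on `learn_scaled`: when True, the network regresses the
    # time-rescaled target kappa'(t) * log(z0, z1) and its prediction is used
    # as-is at sampling time; when False, it regresses the raw log map and the
    # kappa'(t) factor is applied inside sample_zt_given_zs instead.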

    def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean'):
        """ Compute loss per sample. (zt is unused here but kept for a
        consistent interface.) """
        assert reduce in {'mean', 'sum', 'none'}

        # Mask out dimensions with undefined targets (NaNs in z1).
        mask = ~torch.isnan(z1)
        z1 = torch.nan_to_num(z1, nan=0.0)

        zt_dot = self.logarithm_map(z0, z1)
        if self.learn_scaled:
            zt_dot = self.velocity_scaling(t)[batch_mask] * zt_dot

        loss = mask * (pred - zt_dot) ** 2
        loss = torch.sum(loss, dim=-1)

        if reduce == 'mean':
            # Average over the valid (non-masked) dimensions per node.
            denom = mask.sum(dim=-1) + 1e-6
            loss = scatter_mean(loss / denom, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss


class SO3ICFM(ICFM):
    """
    All rotations are assumed to be in axis-angle format.
    Mostly following descriptions from the FoldFlow paper:
    https://openreview.net/forum?id=kJFIH23hXb

    See also:
    https://geomstats.github.io/_modules/geomstats/geometry/special_orthogonal.html#SpecialOrthogonal
    https://geomstats.github.io/_modules/geomstats/geometry/lie_group.html#LieGroup
    """
    def __init__(self, sigma):
        super().__init__(sigma)

    def exponential_map(self, base, tangent):
        """
        Args:
            base: base point (rotation vector) on the manifold
            tangent: point in the tangent space at the identity
        Returns:
            rotation vector on the manifold
        """
        return so3.compose_rotations(base, so3.exp(tangent))

    def logarithm_map(self, base, r):
        """
        Args:
            base: base point (rotation vector) on the manifold
            r: rotation vector on the manifold
        Returns:
            point in the tangent space at the identity
        """
        return so3.log(so3.compose_rotations(-base, r))

    def sample_zt(self, z0, z1, t, batch_mask):
        """
        Expressed in terms of exponential and logarithm maps.
        Corresponds to SLERP interpolation: R(t) = R1 exp( t * log(R1^T R2) )
        (see https://lucaballan.altervista.org/pdfs/IK.pdf, slide 16)
        """
        zt_tangent = t[batch_mask] * self.logarithm_map(z0, z1)
        return self.exponential_map(z0, zt_tangent)

    def get_z1_given_zt_and_pred(self, zt, pred, z0, t, batch_mask):
        """ Make a best guess on the final state z1 given the current state and
        the network prediction. """
        z1_tangent = (1 - t)[batch_mask] * pred
        return self.exponential_map(zt, z1_tangent)

    def sample_zt_given_zs(self, zs, pred, s, t, batch_mask):
        """ Perform update, typically using an explicit Euler step. """
        step_size = t - s
        zt_tangent = step_size[batch_mask] * pred
        return self.exponential_map(zs, zt_tangent)

    def sample_z0(self, batch_mask):
        """ Prior: uniform over SO(3). """
        return so3.random_uniform(n_samples=len(batch_mask), device=batch_mask.device)

    @staticmethod
    def d_R_squared_SO3(rot_vec_1, rot_vec_2):
        r"""
        Squared geodesic distance on SO(3),
            d(R1, R2) = sqrt(0.5) * ||log(R1^T R2)||_F,
        where R1, R2 are rotation matrices.

        Equivalently, if the difference between the two rotations is expressed
        as a rotation vector \omega_diff:
            d(R1, R2) = ||\omega_diff||_2
        Using the Frobenius norm definition ||A||_F^2 = trace(A^H A):
            d^2(R1, R2) = 1/2 * ||log(R1^T R2)||_F^2
                        = 1/2 * ||hat(\omega_diff)||_F^2
                        = 1/2 * tr( hat(\omega_diff)^T hat(\omega_diff) )
                        = 1/2 * 2 * ||\omega_diff||_2^2
        """
        diff_rot = so3.compose_rotations(-rot_vec_1, rot_vec_2)
        return diff_rot.square().sum(dim=-1)
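
    # Worked example: with rot_vec_1 = 0 and rot_vec_2 a rotation by angle
    # theta about a fixed axis (a rotation vector of norm theta), diff_rot
    # equals rot_vec_2 and the method returns theta**2, matching the geodesic
    # distance formula in the docstring.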

    def compute_loss(self, pred, z0, z1, zt, t, batch_mask, reduce='mean', eps=5e-2):
        """ Compute loss per sample. """
        assert reduce in {'mean', 'sum', 'none'}

        # Target velocity: the remaining rotation from zt to z1, rescaled by
        # the remaining time; clamping 1 - t avoids division blow-up near t=1.
        zt_dot = self.logarithm_map(zt, z1) / torch.clamp(1 - t, min=eps)[batch_mask]

        loss = torch.sum((pred - zt_dot)**2, dim=-1)

        if reduce == 'mean':
            loss = scatter_mean(loss, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss
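
# The *PredictFinal variants below change the regression target: instead of a
# velocity, the network predicts the remaining displacement to the final state
# (z1 - zt, or its tangent-space analogue on the torus). On the linear path,
# z1 - zs = (1 - s) * (z1 - z0), so an Euler step of size (t - s) in velocity
# corresponds to applying the predicted displacement with step size
# (t - s) / (1 - s), which is what sample_zt_given_zs does below.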


class CoordICFMPredictFinal(CoordICFM):
    def __init__(self, sigma):
        self.dim = 3
        super().__init__(sigma)

    def sample_zt_given_zs(self, zs, z1_minus_zs_pred, s, t, batch_mask):
        """ Perform an explicit Euler step. """
        # Rescale the step because the prediction approximates z1 - zs, whose
        # magnitude shrinks as s -> 1 (velocity = (z1 - zs) / (1 - s)).
        step_size = (t - s) / (1.0 - s)
        assert torch.all(step_size <= 1.0)

        zt = zs + step_size[batch_mask] * z1_minus_zs_pred
        return zt

    def compute_loss(self, z1_minus_zt_pred, z0, z1, t, batch_mask, reduce='mean'):
        """ Compute loss per sample. """
        assert reduce in {'mean', 'sum', 'none'}

        # Clamp t to keep the 1 / (1 - t)^2 weight from blowing up near t=1.
        t = torch.clamp(t, max=0.9)
        zt = self.sample_zt(z0, z1, t, batch_mask)

        # Since z1 - zt = (1 - t) * (z1 - z0), the 1 / (1 - t)^2 weight makes
        # this equivalent to the usual velocity-matching loss.
        loss = torch.sum((z1_minus_zt_pred + zt - z1) ** 2, dim=-1) \
            / torch.square(1 - t)[batch_mask].squeeze()

        if reduce == 'mean':
            loss = scatter_mean(loss / self.dim, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss

    def get_z1_given_zt_and_pred(self, zt, z1_minus_zt_pred, z0, t, batch_mask):
        return z1_minus_zt_pred + zt


class TorusICFMPredictFinal(TorusICFM):
    """
    Following:
    Chen, Ricky TQ, and Yaron Lipman.
    "Riemannian flow matching on general geometries."
    arXiv preprint arXiv:2302.03660 (2023).
    """
    def __init__(self, sigma, dim):
        super().__init__(sigma, dim)

    def get_z1_given_zt_and_pred(self, zt, z1_tangent_pred, z0, t, batch_mask):
        """ Make a best guess on the final state z1 given the current state and
        the network prediction. """
        return self.exponential_map(zt, z1_tangent_pred)

    def sample_zt_given_zs(self, zs, z1_tangent_pred, s, t, batch_mask):
        """ Perform update, typically using an explicit Euler step. """
        step_size = (t - s) / (1.0 - s)
        assert torch.all(step_size <= 1.0)

        zt_tangent = step_size[batch_mask] * z1_tangent_pred
        return self.exponential_map(zs, zt_tangent)

    def compute_loss(self, z1_tangent_pred, z0, z1, t, batch_mask, reduce='mean'):
        """ Compute loss per sample. """
        assert reduce in {'mean', 'sum', 'none'}

        zt = self.sample_zt(z0, z1, t, batch_mask)
        # Clamp t to keep the 1 / (1 - t)^2 weight from blowing up near t=1.
        t = torch.clamp(t, max=0.9)

        # Mask out dimensions with undefined targets (NaNs in z1).
        mask = ~torch.isnan(z1)
        z1 = torch.nan_to_num(z1, nan=0.0)

        loss = mask * (z1_tangent_pred - self.logarithm_map(zt, z1)) ** 2
        loss = torch.sum(loss, dim=-1) / torch.square(1 - t)[batch_mask].squeeze()

        if reduce == 'mean':
            # Average over the valid (non-masked) dimensions per node.
            denom = mask.sum(dim=-1) + 1e-6
            loss = scatter_mean(loss / denom, batch_mask, dim=0)
        elif reduce == 'sum':
            loss = scatter_add(loss, batch_mask, dim=0)

        return loss