| import torch | |
class PointInSpace:
    """Sample query points near a point cloud plus uniform points in a cube.

    Args:
        global_sigma: half-width of the axis-aligned cube
            ``[-global_sigma, global_sigma)`` from which the uniform ("global")
            samples are drawn.
        local_sigma: default standard deviation of the Gaussian noise added to
            each input point to produce the "local" samples.
    """

    def __init__(self, global_sigma=0.5, local_sigma=0.01):
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma

    def get_points(self, pc_input=None, local_sigma=None, global_ratio=0.125):
        """Sample one jittered point per input point plus a batch of uniform points.

        Args:
            pc_input (tensor): sampling centers, shape ``[B, N, D]``. Required;
                the ``None`` default exists only for signature compatibility.
            local_sigma (float, optional): standard deviation of the Gaussian
                jitter. Defaults to ``self.local_sigma`` when ``None``.
            global_ratio (float): number of uniform samples per batch element,
                as a fraction of ``N`` (truncated with ``int``).

        Returns:
            samples (tensor): shape ``[B, N + int(N * global_ratio), D]``. The
            first ``N`` points along dim 1 are ``pc_input`` plus Gaussian noise.
            The rest are uniform in ``[-global_sigma, global_sigma)`` per
            coordinate. Dtype and device match ``pc_input``.

        Raises:
            ValueError: if ``pc_input`` is ``None``.
        """
        if pc_input is None:
            raise ValueError("pc_input is required: expected a tensor of shape [B, N, D]")
        batch_size, sample_size, dim = pc_input.shape

        sigma = self.local_sigma if local_sigma is None else local_sigma
        sample_local = pc_input + torch.randn_like(pc_input) * sigma

        num_global = int(sample_size * global_ratio)
        # Match pc_input's dtype so the uniform samples are drawn at the same
        # precision and torch.cat does not promote or change the output dtype.
        sample_global = (
            torch.rand(
                batch_size,
                num_global,
                dim,
                device=pc_input.device,
                dtype=pc_input.dtype,
            )
            * (self.global_sigma * 2)
        ) - self.global_sigma

        return torch.cat([sample_local, sample_global], dim=1)