| import abc |
| import torch |
| from lib.utils import utils |
|
|
class RaySampler(metaclass=abc.ABCMeta):
    """Abstract base for objects that choose depths (z values) along camera rays.

    Concrete samplers receive a default ``[near, far]`` depth range here and
    must implement :meth:`get_z_vals`.
    """

    def __init__(self, near, far):
        # Default depth range, read by subclasses when building their samples.
        self.near, self.far = near, far

    @abc.abstractmethod
    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return the sampled depths along each ray; subclasses must override."""
|
|
class UniformSampler(RaySampler):
    """Samples uniformly in the range [near, far].

    Depths are evenly spaced between the per-ray near and far bounds. When the
    passed ``model`` is in training mode, each depth is jittered uniformly
    within its bin (stratified sampling).
    """

    def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
        """
        Args:
            scene_bounding_sphere: radius of the scene bounding sphere. It gives
                the default far bound (``2 * scene_bounding_sphere``) and the
                sphere intersected when ``take_sphere_intersection`` is set.
            near: near depth bound shared by all rays.
            N_samples: number of depths produced per ray.
            take_sphere_intersection: if True, each ray's far bound comes from
                ``utils.get_sphere_intersections`` instead of ``self.far``.
            far: explicit far bound; the sentinel ``-1`` selects
                ``2 * scene_bounding_sphere``.
        """
        super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far)
        self.N_samples = N_samples
        self.scene_bounding_sphere = scene_bounding_sphere
        self.take_sphere_intersection = take_sphere_intersection

    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return ``N_samples`` depths per ray.

        Args:
            ray_dirs: ray directions; ``ray_dirs.shape[0]`` is the ray count R
                (presumably shape (R, 3) — TODO confirm against callers).
            cam_loc: camera origins; only used for the sphere intersection.
            model: only ``model.training`` is read, to enable jittering.

        Returns:
            Tensor of shape (R, N_samples) on ``ray_dirs.device``.
        """
        # Follow the input's device instead of a hard-coded ``.cuda()`` so the
        # sampler also works on CPU and on non-default GPUs.
        device = ray_dirs.device
        num_rays = ray_dirs.shape[0]

        near = self.near * torch.ones(num_rays, 1, device=device)
        if self.take_sphere_intersection:
            sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
            # Second column of the intersection result is used as the far bound.
            far = sphere_intersections[:, 1:]
        else:
            far = self.far * torch.ones(num_rays, 1, device=device)

        # Build on CPU and move, matching the values the original produced.
        t_vals = torch.linspace(0., 1., steps=self.N_samples).to(device)
        z_vals = near * (1. - t_vals) + far * t_vals

        if model.training:
            # Stratified jitter: resample each depth uniformly inside the bin
            # bounded by the midpoints to its neighbours.
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # Draw from the CPU generator (as before) so seeded runs reproduce.
            t_rand = torch.rand(z_vals.shape).to(device)
            z_vals = lower + (upper - lower) * t_rand

        return z_vals
|
|
class ErrorBoundSampler(RaySampler):
    """Error-bounded depth sampler for SDF-based volume rendering (VolSDF style).

    Starting from ``N_samples_eval`` uniform depths per ray, each iteration:

    1. evaluates the SDF at the newest depths (without gradients);
    2. bisects a per-ray ``beta`` so that the opacity error bound returned by
       :meth:`get_error_bound` is at most ``eps``;
    3. if some ray's ``beta`` still exceeds the model's ``beta`` and the
       iteration budget remains, adds ``N_samples_eval`` depths drawn from the
       error bound and repeats. Otherwise it draws the final ``N_samples``
       depths from the rendering weights and stops.
    """

    # Background sample count used when ``inverse_sphere_bg`` is enabled and
    # ``N_samples_inverse_sphere`` is not a positive number.
    _DEFAULT_N_SAMPLES_INVERSE_SPHERE = 32

    def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
                 eps, beta_iters, max_total_iters,
                 inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
        """
        Args:
            scene_bounding_sphere: radius of the scene bounding sphere; the far
                bound is ``2 * scene_bounding_sphere``.
            near: near depth bound.
            N_samples: number of importance-sampled depths returned per ray.
            N_samples_eval: number of uniform starting depths, and of depths
                added per refinement iteration.
            N_samples_extra: how many of the accumulated evaluation depths are
                merged into the output (0 disables).
            eps: target upper bound on the opacity error.
            beta_iters: bisection steps per refinement iteration.
            max_total_iters: maximum number of refinement iterations.
            inverse_sphere_bg: if True, the far bound comes from the bounding
                sphere intersection and background depths are also produced.
            N_samples_inverse_sphere: background depths per ray. A non-positive
                value (the default, 0) falls back to 32.
            add_tiny: constant added to the error-bound PDF before normalizing.
        """
        super().__init__(near, 2.0 * scene_bounding_sphere)
        self.N_samples = N_samples
        self.N_samples_eval = N_samples_eval
        self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval,
                                              take_sphere_intersection=inverse_sphere_bg)
        self.N_samples_extra = N_samples_extra

        self.eps = eps
        self.beta_iters = beta_iters
        self.max_total_iters = max_total_iters
        self.scene_bounding_sphere = scene_bounding_sphere
        self.add_tiny = add_tiny

        self.inverse_sphere_bg = inverse_sphere_bg
        if inverse_sphere_bg:
            # Honor an explicit positive count instead of silently discarding
            # the argument; 32 remains the default.
            if N_samples_inverse_sphere is None or N_samples_inverse_sphere <= 0:
                N_samples_inverse_sphere = self._DEFAULT_N_SAMPLES_INVERSE_SPHERE
            self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)

    def get_z_vals(self, ray_dirs, cam_loc, model, cond, smpl_tfs, eval_mode, smpl_verts):
        """Return error-bounded depth samples per ray.

        Args:
            ray_dirs: ray directions; ``ray_dirs.shape[0]`` is the ray count R
                (presumably shape (R, 3) — TODO confirm against callers).
            cam_loc: camera origins, broadcast against ``ray_dirs`` when
                forming 3D points.
            model: must provide ``density`` (with ``get_beta()`` and a
                ``beta=`` call), an ``implicit_network`` module,
                ``sdf_func_with_smpl_deformer`` and ``training``.
            cond, smpl_tfs, smpl_verts: forwarded unchanged to
                ``model.sdf_func_with_smpl_deformer``.
            eval_mode: accepted for interface compatibility; not used here.

        Returns:
            ``(z_vals, z_samples_eik)``.

            * ``z_vals`` is the sorted union of the importance samples, the
              near/far bounds and ``N_samples_extra`` evaluation depths.
            * If ``inverse_sphere_bg`` is set, ``z_vals`` is instead the tuple
              ``(z_vals, background_z_vals / scene_bounding_sphere)``.
            * ``z_samples_eik`` is one depth per ray, shape (R, 1), picked at
              random from ``z_vals``.
        """
        # Assumed to be a scalar / single-element tensor (it is unsqueezed and
        # repeated per ray below) — TODO confirm against the density module.
        beta0 = model.density.get_beta().detach()

        # Start with uniform samples on each ray.
        z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
        device = z_vals.device
        samples, samples_idx = z_vals, None

        # Initial beta: the bound implied by the uniform spacing alone.
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
        beta = torch.sqrt(bound)

        total_iters, not_converge = 0, True

        while not_converge and total_iters < self.max_total_iters:
            samples_sdf = self._query_sdf(model, cam_loc, ray_dirs, samples, cond, smpl_tfs, smpl_verts)

            if samples_idx is not None:
                # Merge SDF values of the previous depths with those of the new
                # samples, in the sorted order of the merged z_vals.
                sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
                                       samples_sdf.reshape(-1, samples.shape[1])], -1)
                sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
            else:
                sdf = samples_sdf

            d = sdf.reshape(z_vals.shape)
            dists = z_vals[:, 1:] - z_vals[:, :-1]
            d_star = self._compute_d_star(dists, d)

            # Per-ray bisection: the smallest beta (>= beta0) whose error bound
            # is <= eps. Rays already satisfying it at beta0 keep beta0.
            curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
            beta[curr_error <= self.eps] = beta0
            beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
            for _ in range(self.beta_iters):
                beta_mid = (beta_min + beta_max) / 2.
                curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
                beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
                beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
            beta = beta_max

            # Volume-rendering weights under the chosen per-ray beta; the last
            # interval is treated as (effectively) infinite.
            density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
            dists = torch.cat([dists, torch.tensor([1e10], device=device).unsqueeze(0).repeat(dists.shape[0], 1)], -1)
            free_energy = dists * density
            shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=device), free_energy[:, :-1]], dim=-1)
            alpha = 1 - torch.exp(-free_energy)
            transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
            weights = alpha * transmittance

            total_iters += 1
            not_converge = beta.max() > beta0
            refine = not_converge and total_iters < self.max_total_iters

            if refine:
                # Sample more points proportional to the current error bound.
                n_new = self.N_samples_eval
                error_per_section = (torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:, :-1] ** 2.)
                                     / (4 * beta.unsqueeze(-1) ** 2))
                error_integral = torch.cumsum(error_per_section, dim=-1)
                bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * transmittance[:, :-1]
                pdf = bound_opacity + self.add_tiny
            else:
                # Final sample set used in the volume rendering integral.
                n_new = self.N_samples
                pdf = weights[..., :-1] + 1e-5

            pdf = pdf / torch.sum(pdf, -1, keepdim=True)
            cdf = torch.cumsum(pdf, -1)
            cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

            # Evenly spaced quantiles while refining or at eval time; random
            # quantiles only for the final samples during training.
            samples = self._sample_from_cdf(z_vals, cdf, n_new, deterministic=refine or not model.training)

            if refine:
                z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)

        z_samples = samples

        near = self.near * torch.ones(ray_dirs.shape[0], 1, device=device)
        far = self.far * torch.ones(ray_dirs.shape[0], 1, device=device)
        if self.inverse_sphere_bg:
            far = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:, 1:]

        if self.N_samples_extra > 0:
            if model.training:
                sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
            else:
                sampling_idx = torch.linspace(0, z_vals.shape[1] - 1, self.N_samples_extra).long()
            z_vals_extra = torch.cat([near, far, z_vals[:, sampling_idx]], -1)
        else:
            z_vals_extra = torch.cat([near, far], -1)

        z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)

        # One random depth per ray, gathered from the final z_vals. Drawn on CPU
        # then moved, as before, to keep seeded runs reproducible.
        idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).to(device)
        z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))

        if self.inverse_sphere_bg:
            z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
            z_vals_inverse_sphere = z_vals_inverse_sphere * (1. / self.scene_bounding_sphere)
            z_vals = (z_vals, z_vals_inverse_sphere)

        return z_vals, z_samples_eik

    @staticmethod
    def _query_sdf(model, cam_loc, ray_dirs, samples, cond, smpl_tfs, smpl_verts):
        """Evaluate the SDF (no gradients) at ``cam_loc + samples * ray_dirs``.

        The implicit network is put in eval mode for the query. Its previous
        train/eval mode is restored afterwards, even if the query raises.
        """
        points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
        points_flat = points.reshape(-1, 3)

        network = model.implicit_network
        was_training = network.training
        network.eval()
        try:
            with torch.no_grad():
                return model.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_verts=smpl_verts)[0]
        finally:
            network.train(was_training)

    @staticmethod
    def _compute_d_star(dists, d):
        """Per-interval d* term used by the error bound.

        Treat ``|d|`` at the two interval endpoints (``b``, ``c``) and the
        interval length (``a``) as triangle sides:

        * If the triangle is right/obtuse at an endpoint, use the smaller
          endpoint value.
        * Otherwise, when ``b + c > a``, use the triangle height ``2 * area / a``
          (Heron's formula).
        * In all other cases the value is 0.

        The result is zeroed wherever the two endpoint SDF values are not both
        nonzero with the same sign.
        """
        a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
        first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
        second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
        d_star = torch.zeros(dists.shape[0], dists.shape[1], device=dists.device)
        d_star[first_cond] = b[first_cond]
        d_star[second_cond] = c[second_cond]
        s = (a + b + c) / 2.0
        area_before_sqrt = s * (s - a) * (s - b) * (s - c)
        mask = ~first_cond & ~second_cond & (b + c - a > 0)
        d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
        d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star
        return d_star

    @staticmethod
    def _sample_from_cdf(bins, cdf, n_samples, deterministic):
        """Inverse-transform sample ``n_samples`` depths per ray.

        Args:
            bins: (R, K) depths at which ``cdf`` is defined.
            cdf: (R, K) non-decreasing CDF values with a leading 0.
            n_samples: number of depths to draw per ray.
            deterministic: if True use evenly spaced quantiles in [0, 1],
                otherwise uniform random quantiles.

        Returns:
            (R, n_samples) depths, linearly interpolated between bins.
        """
        # Generated on CPU then moved, matching the original RNG/numerics.
        if deterministic:
            u = torch.linspace(0., 1., steps=n_samples).to(cdf.device).unsqueeze(0).repeat(cdf.shape[0], 1)
        else:
            u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(cdf.device)
        u = u.contiguous()

        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.max(torch.zeros_like(inds - 1), inds - 1)
        above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
        inds_g = torch.stack([below, above], -1)

        matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
        cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
        bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

        # Guard against (near-)empty CDF segments before interpolating.
        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
        """Per-ray maximum of the opacity error bound for a given ``beta``.

        Args:
            beta: scalar or per-ray (R, 1) beta passed to ``model.density``.
            model: provides ``density(sdf, beta=...)``.
            sdf: SDF values, reshaped to ``z_vals.shape``.
            z_vals: (R, K) depths.
            dists: (R, K-1) interval lengths.
            d_star: (R, K-1) per-interval d* values.

        Returns:
            (R,) tensor: for each ray, the maximum over intervals of
            ``(clamp(exp(cumulative error), max=1e6) - 1)`` times the estimated
            transmittance before that interval.
        """
        density = model.density(sdf.reshape(z_vals.shape), beta=beta)
        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=dists.device),
                                         dists * density[:, :-1]], dim=-1)
        integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
        error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
        error_integral = torch.cumsum(error_per_section, dim=-1)
        bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])

        return bound_opacity.max(-1)[0]
|
|
|
|
|
|