| |
| |
| |
| |
| |
|
|
| import unittest |
|
|
| import torch |
| from pytorch3d.renderer.blending import ( |
| BlendParams, |
| hard_rgb_blend, |
| sigmoid_alpha_blend, |
| softmax_rgb_blend, |
| ) |
| from pytorch3d.renderer.cameras import FoVPerspectiveCameras |
| from pytorch3d.renderer.mesh.rasterizer import Fragments |
| from pytorch3d.renderer.splatter_blend import SplatterBlender |
|
|
| from .common_testing import TestCaseMixin |
|
|
|
|
def sigmoid_blend_naive_loop(colors, fragments, blend_params):
    """
    Reference sigmoid alpha blending written with explicit Python loops.

    For every pixel the RGB value is copied from the first (k = 0) entry of
    ``colors`` and the alpha channel is
    ``1 - prod_k(1 - sigmoid(-dists_k / sigma))``, where the product runs over
    the face slots whose ``pix_to_face`` entry is non-negative.
    Slow by design; intended for test purposes only.

    Args:
        colors: tensor indexed as ``[n, h, w, k, channel]``.
        fragments: object exposing ``pix_to_face`` and ``dists`` tensors of
            shape (N, H, W, K).
        blend_params: object exposing ``sigma``.

    Returns:
        Tensor of shape (N, H, W, 4) with the blended RGBA values.
    """
    face_ids = fragments.pix_to_face
    distances = fragments.dists
    sigma = blend_params.sigma

    N, H, W, K = face_ids.shape
    rgba = torch.ones((N, H, W, 4), dtype=colors.dtype, device=face_ids.device)

    for n in range(N):
        for h in range(H):
            for w in range(W):
                # Running product of (1 - prob) over the valid face slots.
                transparency = 1.0
                for k in range(K):
                    if face_ids[n, h, w, k] < 0:
                        # Negative index: no face in this slot.
                        continue
                    coverage = torch.sigmoid(-distances[n, h, w, k] / sigma)
                    transparency *= 1.0 - coverage

                rgba[n, h, w, :3] = colors[n, h, w, 0, :]
                rgba[n, h, w, 3] = 1.0 - transparency

    return rgba
|
|
|
|
def sigmoid_alpha_blend_vectorized(colors, fragments, blend_params) -> torch.Tensor:
    """
    Tensorised sigmoid alpha blending.

    RGB is taken from the first (k = 0) entry of ``colors``; alpha is
    ``1 - prod_k(1 - prob_k)`` with ``prob_k = sigmoid(-dists_k / sigma)``
    zeroed out wherever ``pix_to_face`` is negative.

    Args:
        colors: tensor indexed as ``[..., k, channel]``.
        fragments: object exposing ``pix_to_face`` and ``dists`` tensors of
            shape (N, H, W, K).
        blend_params: object exposing ``sigma``.

    Returns:
        Tensor of shape (N, H, W, 4) with the blended RGBA values.
    """
    N, H, W, K = fragments.pix_to_face.shape

    # Slots with a negative face index contribute zero probability.
    is_valid = fragments.pix_to_face >= 0
    coverage = torch.sigmoid(-fragments.dists / blend_params.sigma) * is_valid
    alpha = 1.0 - torch.prod((1.0 - coverage), dim=-1)

    rgba = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
    rgba[..., :3] = colors[..., 0, :]
    rgba[..., 3] = alpha
    return rgba
|
|
|
|
def sigmoid_blend_naive_loop_backward(grad_images, images, fragments, blend_params):
    """
    Hand-written gradient of the naive sigmoid alpha blend w.r.t. ``dists``.

    With ``alpha_out = 1 - prod_k(1 - p_k)`` and ``p_k = sigmoid(-d_k / sigma)``,
    the derivative is ``d alpha_out / d d_k = (-1 / sigma) * p_k * prod_j(1 - p_j)``.
    The full product is recovered from the forward output as
    ``1 - images[..., 3]``. Only the alpha channel of ``grad_images`` is read.

    Args:
        grad_images: upstream gradient, indexed as ``[n, h, w, channel]``.
        images: forward output, indexed as ``[n, h, w, channel]``.
        fragments: object exposing ``pix_to_face`` and ``dists`` tensors of
            shape (N, H, W, K).
        blend_params: object exposing ``sigma``.

    Returns:
        Tensor of shape (N, H, W, K) with the gradient for each distance;
        slots with a negative face index stay zero.
    """
    face_ids = fragments.pix_to_face
    distances = fragments.dists
    sigma = blend_params.sigma

    N, H, W, K = face_ids.shape
    dist_grads = torch.zeros(
        (N, H, W, K), dtype=distances.dtype, device=face_ids.device
    )

    for n in range(N):
        for h in range(H):
            for w in range(W):
                # prod_k(1 - p_k), read back from the forward alpha.
                transparency = 1.0 - images[n, h, w, 3]
                upstream = grad_images[n, h, w, 3]

                for k in range(K):
                    if face_ids[n, h, w, k] < 0:
                        continue
                    coverage = torch.sigmoid(-distances[n, h, w, k] / sigma)
                    dist_grads[n, h, w, k] = (
                        upstream * (-1.0 / sigma) * coverage * transparency
                    )
    return dist_grads
|
|
|
|
def softmax_blend_naive(
    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100.0
):
    """
    Naive for loop based implementation of softmax blending.
    Only for test purposes.

    For each pixel, every valid face slot (``pix_to_face >= 0``) gets the weight
    ``prob_k * exp((zinv_k - zmax) / gamma)`` where
    ``prob_k = sigmoid(-dists_k / sigma)``,
    ``zinv_k = (zfar - zbuf_k) / (zfar - znear)`` and ``zmax`` is the largest
    ``zinv_k`` of the pixel (0 if there is no valid face). The background
    color enters with weight ``delta = exp((eps - zmax) / gamma)`` clamped to
    at least ``eps``. RGB is the weighted average of face colors and
    background; alpha is ``1 - prod_k(1 - prob_k)``.

    Args:
        colors: tensor indexed as ``[n, h, w, k, channel]``.
        fragments: object exposing ``pix_to_face``, ``dists`` and ``zbuf``
            tensors of shape (N, H, W, K).
        blend_params: object exposing ``sigma``, ``gamma`` and
            ``background_color`` (a tensor or a sequence of 3 numbers).
        znear: near clipping depth used to normalise ``zbuf``.
        zfar: far clipping depth used to normalise ``zbuf``.

    Returns:
        Tensor of shape (N, H, W, 4) with the blended RGBA values.
    """
    pix_to_face = fragments.pix_to_face
    dists = fragments.dists
    zbuf = fragments.zbuf
    sigma = blend_params.sigma
    gamma = blend_params.gamma

    N, H, W, K = pix_to_face.shape
    device = pix_to_face.device
    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)

    # Lower bound for the background weight, so the denominator is never 0.
    eps = 1e-10

    bk_color = blend_params.background_color
    if not torch.is_tensor(bk_color):
        bk_color = torch.tensor(bk_color, dtype=colors.dtype, device=device)

    for n in range(N):
        for h in range(H):
            for w in range(W):
                alpha = 1.0
                weights_k = torch.zeros(K, device=device)
                zmax = torch.tensor(0.0, device=device)

                # First pass: largest normalised inverse depth of the pixel.
                for k in range(K):
                    if pix_to_face[n, h, w, k] >= 0:
                        zinv = (zfar - zbuf[n, h, w, k]) / (zfar - znear)
                        if zinv > zmax:
                            zmax = zinv

                # Second pass: per-face weights. Subtracting zmax keeps the
                # exponent non-positive so exp() cannot overflow.
                for k in range(K):
                    if pix_to_face[n, h, w, k] >= 0:
                        zinv = (zfar - zbuf[n, h, w, k]) / (zfar - znear)
                        prob = torch.sigmoid(-dists[n, h, w, k] / sigma)
                        alpha *= 1.0 - prob
                        weights_k[k] = prob * torch.exp((zinv - zmax) / gamma)

                # Background weight, shifted by the same zmax as the faces.
                delta = torch.exp((eps - zmax) / gamma).clamp(min=eps)
                delta = delta.to(device)
                denom = weights_k.sum() + delta
                cols = (weights_k[..., None] * colors[n, h, w, :, :]).sum(dim=0)
                pixel_colors[n, h, w, :3] = cols + delta * bk_color
                pixel_colors[n, h, w, :3] /= denom
                pixel_colors[n, h, w, 3] = 1.0 - alpha

    return pixel_colors
|
|
|
|
class TestBlending(TestCaseMixin, unittest.TestCase):
    """Unit tests and benchmark factories for the blending functions."""

    def setUp(self) -> None:
        # Fixed seed so the randomly generated inputs are reproducible.
        torch.manual_seed(42)

    def _compare_impls(
        self, fn1, fn2, args1, args2, grad_var1=None, grad_var2=None, compare_grads=True
    ):
        """
        Run two implementations and check that they agree.

        Only the alpha channel (index 3 of the last dimension) of the two
        outputs is compared. When ``compare_grads`` is True, the same random
        upstream gradient is back-propagated through both outputs and the
        gradients accumulated on ``grad_var1`` / ``grad_var2`` are compared.

        Args:
            fn1, fn2: callables to compare.
            args1, args2: positional argument tuples for ``fn1`` / ``fn2``.
            grad_var1, grad_var2: leaf tensors (requiring grad) that feed
                ``fn1`` / ``fn2``; required when ``compare_grads`` is True.
            compare_grads: whether to also compare gradients.
        """
        out1 = fn1(*args1)
        out2 = fn2(*args2)
        self.assertClose(out1.cpu()[..., 3], out2.cpu()[..., 3], atol=1e-7)

        # Check gradients
        if not compare_grads:
            return

        grad_out = torch.randn_like(out1)
        (out1 * grad_out).sum().backward()
        # A tensor always has a ``grad`` attribute; what matters is that the
        # backward pass actually populated it.
        self.assertIsNotNone(grad_var1.grad)

        (out2 * grad_out).sum().backward()
        self.assertIsNotNone(grad_var2.grad)

        self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)

    def test_hard_rgb_blend(self):
        """Check RGB and alpha produced by hard_rgb_blend on random fragments."""
        N, H, W, K = 5, 10, 10, 20
        pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
        bary_coords = torch.ones((N, H, W, K, 3))
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=bary_coords,
            zbuf=pix_to_face,  # placeholder values
            dists=pix_to_face,  # placeholder values
        )
        colors = torch.randn((N, H, W, K, 3))
        blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
        images = hard_rgb_blend(colors, fragments, blend_params)

        # Pixels whose first face slot is valid must take that slot's color.
        is_foreground = pix_to_face[..., 0] >= 0
        self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])

        # All remaining pixels must have exactly the background color.
        for i in range(3):
            channel_color = blend_params.background_color[i]
            self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())

        # Alpha is 1 on foreground pixels and 0 elsewhere.
        self.assertClose(images[..., 3], (pix_to_face[..., 0] >= 0).float())

    def test_sigmoid_alpha_blend_manual_gradients(self):
        """Compare autograd gradients of the naive loop with the manual backward."""
        torch.manual_seed(231)
        F = 32  # number of faces in the mesh
        # The python loop version is really slow so only using small input sizes.
        N, S, K = 2, 3, 2
        device = torch.device("cuda")
        pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        dists = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists,
        )
        blend_params = BlendParams(sigma=1e-3)
        pix_cols = sigmoid_blend_naive_loop(colors, fragments, blend_params)
        grad_out = torch.randn_like(pix_cols)

        # Backward pass through autograd vs. the hand-written gradient.
        pix_cols.backward(grad_out)
        grad_dists = sigmoid_blend_naive_loop_backward(
            grad_out, pix_cols, fragments, blend_params
        )
        self.assertTrue(torch.allclose(dists.grad, grad_dists, atol=1e-7))

    def test_sigmoid_alpha_blend_python(self):
        """
        Test outputs of python tensorised function and python loop
        """
        torch.manual_seed(231)
        F = 32  # number of faces in the mesh
        # The python loop version is really slow so only using small input sizes.
        N, S, K = 1, 4, 1
        device = torch.device("cuda")
        pix_to_face = torch.randint(low=-1, high=F, size=(N, S, S, K), device=device)
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        # Two independent leaf copies so each implementation gets its own grad.
        dists1 = torch.randn(size=(N, S, S, K), device=device)
        dists2 = dists1.clone()
        dists1.requires_grad = True
        dists2.requires_grad = True

        fragments1 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists1,
        )
        fragments2 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists2,
        )

        blend_params = BlendParams(sigma=1e-2)
        args1 = (colors, fragments1, blend_params)
        args2 = (colors, fragments2, blend_params)

        self._compare_impls(
            sigmoid_alpha_blend,
            sigmoid_alpha_blend_vectorized,
            args1,
            args2,
            dists1,
            dists2,
            compare_grads=True,
        )

    def test_softmax_rgb_blend(self):
        """Compare softmax_rgb_blend with the naive loop implementation."""
        # Create dummy outputs of rasterization simulating a square in the
        # centre of the image surrounded by empty (-1) pixels.
        N, S, K = 1, 8, 2
        device = torch.device("cuda")
        pix_to_face = torch.full(
            (N, S, S, K), fill_value=-1, dtype=torch.int64, device=device
        )
        h = int(S / 2)
        pix_to_face_full = torch.randint(
            size=(N, h, h, K), low=0, high=100, device=device
        )
        s = int(S / 4)
        e = int(0.75 * S)
        pix_to_face[:, s:e, s:e, :] = pix_to_face_full
        empty = torch.tensor([], device=device)

        # Entries above 0.5 are negated, so the factor has mixed signs.
        random_sign_flip = torch.rand((N, S, S, K), device=device)
        random_sign_flip[random_sign_flip > 0.5] *= -1.0
        zbuf1 = torch.randn(size=(N, S, S, K), device=device)

        dists1 = torch.randn(size=(N, S, S, K), device=device) * random_sign_flip
        dists2 = dists1.clone()
        zbuf2 = zbuf1.clone()
        dists1.requires_grad = True
        dists2.requires_grad = True
        colors = torch.randn((N, S, S, K, 3), device=device)
        fragments1 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=zbuf1,
            dists=dists1,
        )
        fragments2 = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=zbuf2,
            dists=dists2,
        )

        blend_params = BlendParams(sigma=1e-3)
        args1 = (colors, fragments1, blend_params)
        args2 = (colors, fragments2, blend_params)
        self._compare_impls(
            softmax_rgb_blend,
            softmax_blend_naive,
            args1,
            args2,
            dists1,
            dists2,
            compare_grads=True,
        )

    @staticmethod
    def bm_sigmoid_alpha_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 100,
        device="cuda",
        backend: str = "pytorch",
    ):
        """
        Build a benchmark closure for sigmoid alpha blending.

        Args:
            num_meshes: batch size N of the random fragments.
            image_size: height and width S of the random fragments.
            faces_per_pixel: number of face slots K per pixel.
            device: device on which the inputs are created.
            backend: "pytorch" selects ``sigmoid_alpha_blend_vectorized``;
                any other value selects ``sigmoid_alpha_blend``.

        Returns:
            A zero-argument callable running one forward and backward pass.
        """
        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        F = 32  # num faces in the mesh
        pix_to_face = torch.randint(
            low=-1, high=F + 1, size=(N, S, S, K), device=device
        )
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=empty,  # dummy
            zbuf=empty,  # dummy
            dists=dists1,
        )
        blend_params = BlendParams(sigma=1e-3)

        blend_fn = (
            sigmoid_alpha_blend_vectorized
            if backend == "pytorch"
            else sigmoid_alpha_blend
        )

        # Synchronizing only makes sense (and only works) for CUDA devices.
        use_cuda = device.type == "cuda"
        if use_cuda:
            torch.cuda.synchronize(device)

        def fn():
            # test forward and backward pass
            images = blend_fn(colors, fragments, blend_params)
            images.sum().backward()
            if use_cuda:
                torch.cuda.synchronize(device)

        return fn

    @staticmethod
    def bm_softmax_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 100,
        device: str = "cpu",
        backend: str = "pytorch",
    ):
        """
        Build a benchmark closure for ``softmax_rgb_blend``.

        Args:
            num_meshes: batch size N of the random fragments.
            image_size: height and width S of the random fragments.
            faces_per_pixel: number of face slots K per pixel.
            device: device string on which the inputs are created.
            backend: accepted for signature parity with the other benchmarks;
                not used by this function.

        Returns:
            A zero-argument callable running one forward and backward pass.
        """
        if torch.cuda.is_available() and "cuda:" in device:
            # If a device other than the default is used, set the device explicitly.
            torch.cuda.set_device(device)

        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel
        F = 32  # num faces in the mesh
        pix_to_face = torch.randint(
            low=-1, high=F + 1, size=(N, S, S, K), device=device
        )
        colors = torch.randn((N, S, S, K, 3), device=device)
        empty = torch.tensor([], device=device)

        dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
        fragments = Fragments(
            pix_to_face=pix_to_face, bary_coords=empty, zbuf=zbuf, dists=dists1  # dummy
        )
        blend_params = BlendParams(sigma=1e-3)

        # The default device is the CPU, where torch.cuda.synchronize() would
        # raise on hosts without CUDA, so only synchronize CUDA devices.
        use_cuda = device.type == "cuda"
        if use_cuda:
            torch.cuda.synchronize(device)

        def fn():
            # test forward and backward pass
            images = softmax_rgb_blend(colors, fragments, blend_params)
            images.sum().backward()
            if use_cuda:
                torch.cuda.synchronize(device)

        return fn

    @staticmethod
    def bm_splatter_blending(
        num_meshes: int = 16,
        image_size: int = 128,
        faces_per_pixel: int = 2,
        use_jit: bool = False,
        device: str = "cpu",
        backend: str = "pytorch",
    ):
        """
        Build a benchmark closure for ``SplatterBlender``.

        Args:
            num_meshes: batch size N of the random inputs.
            image_size: height and width S of the random inputs.
            faces_per_pixel: number of face slots K per pixel.
            use_jit: accepted for signature compatibility; not used here.
            device: device string on which the inputs are created.
            backend: accepted for signature compatibility; not used here.

        Returns:
            A zero-argument callable running one forward and backward pass.
        """
        if torch.cuda.is_available() and "cuda:" in device:
            # If a device other than the default is used, set the device explicitly.
            torch.cuda.set_device(device)

        device = torch.device(device)
        torch.manual_seed(231)

        # Create dummy outputs of rasterization
        N, S, K = num_meshes, image_size, faces_per_pixel

        pixel_coords_camera = torch.randn(
            (N, S, S, K, 3), device=device, requires_grad=True
        )
        cameras = FoVPerspectiveCameras(device=device)
        colors = torch.randn((N, S, S, K, 3), device=device)
        # No pixel is marked as background.
        background_mask = torch.full((N, S, S, K), False, dtype=bool, device=device)
        blend_params = BlendParams(sigma=0.5)

        # The default device is the CPU, where torch.cuda.synchronize() would
        # raise on hosts without CUDA, so only synchronize CUDA devices.
        use_cuda = device.type == "cuda"
        if use_cuda:
            torch.cuda.synchronize(device)
        splatter_blender = SplatterBlender((N, S, S, K), colors.device)

        def fn():
            # test forward and backward pass
            images = splatter_blender(
                colors,
                pixel_coords_camera,
                cameras,
                background_mask,
                blend_params,
            )
            images.sum().backward()
            if use_cuda:
                torch.cuda.synchronize(device)

        return fn

    def test_blend_params(self):
        """Test color parameter of BlendParams().
        Assert passed value overrides default value.
        """
        bp_default = BlendParams()
        bp_new = BlendParams(background_color=(0.5, 0.5, 0.5))
        self.assertEqual(bp_new.background_color, (0.5, 0.5, 0.5))
        self.assertEqual(bp_default.background_color, (1.0, 1.0, 1.0))
|
|