| | import torch |
| | from torch.autograd import gradcheck |
| |
|
| | import kornia |
| | import kornia.testing as utils |
| |
|
| |
|
class TestBatchedForward:
    """Tests for ``kornia.utils.memory.batched_forward``.

    Each test compares the micro-batched forward pass against a single
    full-batch call of the same module. The inputs and the module are both
    created on the test ``device``. ``batched_forward`` is given ``device``
    and presumably moves the module there (``nn.Module.to`` is in-place), so
    this keeps the reference call ``sift(patches)`` from mixing devices.
    """

    def test_runbatch(self, device):
        """34 inputs with micro-batch 32 exercise the split path (32 + 2)."""
        patches = torch.rand(34, 1, 32, 32, device=device)
        sift = kornia.feature.SIFTDescriptor(32).to(device)
        desc_batched = kornia.utils.memory.batched_forward(sift, patches, device, 32)
        desc = sift(patches)
        # allclose broadcasts, so check shapes explicitly before values.
        assert desc_batched.shape == desc.shape
        assert torch.allclose(desc, desc_batched)

    def test_runbatch_exact_multiple(self, device):
        """Input length that is an exact multiple of the micro-batch size (32 + 32)."""
        patches = torch.rand(64, 1, 32, 32, device=device)
        sift = kornia.feature.SIFTDescriptor(32).to(device)
        desc_batched = kornia.utils.memory.batched_forward(sift, patches, device, 32)
        desc = sift(patches)
        assert desc_batched.shape == desc.shape
        assert torch.allclose(desc, desc_batched)

    def test_runone(self, device):
        """16 inputs with micro-batch 32 fit in a single micro-batch."""
        patches = torch.rand(16, 1, 32, 32, device=device)
        sift = kornia.feature.SIFTDescriptor(32).to(device)
        desc_batched = kornia.utils.memory.batched_forward(sift, patches, device, 32)
        desc = sift(patches)
        assert desc_batched.shape == desc.shape
        assert torch.allclose(desc, desc_batched)

    def test_gradcheck(self, device):
        """Gradients must flow through the micro-batched forward pass.

        A batch of 3 with micro-batch size 2 forces a split with a trailing
        partial batch of 1.
        """
        batch_size, channels, height, width = 3, 2, 5, 4
        img = torch.rand(batch_size, channels, height, width, device=device)
        # Converts to the precision / requires_grad setup gradcheck needs.
        img = utils.tensor_to_gradcheck_var(img)
        assert gradcheck(
            kornia.utils.memory.batched_forward,
            (kornia.feature.BlobHessian(), img, device, 2),
            raise_exception=True,
            nondet_tol=1e-4,
        )
| |
|