| |
| |
| |
| |
| |
|
|
|
|
| import unittest |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from itertools import product |
|
|
| import numpy as np |
|
|
| import torch |
| from pytorch3d.implicitron.dataset.data_loader_map_provider import ( |
| DoublePoolBatchSampler, |
| ) |
|
|
| from pytorch3d.implicitron.dataset.dataset_base import DatasetBase |
| from pytorch3d.implicitron.dataset.frame_data import FrameData |
| from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler |
|
|
|
|
@dataclass
class MockFrameAnnotation:
    """
    Lightweight stand-in for a per-frame annotation record.

    Only ``frame_number`` is required; the other two fields fall back to
    fixed defaults so that a record can be built from a frame number alone.
    """

    # Number of the frame within its sequence.
    frame_number: int
    # Name of the sequence the frame belongs to.
    sequence_name: str = "sequence"
    # Timestamp associated with the frame.
    frame_timestamp: float = 0.0
|
|
|
|
class MockDataset(DatasetBase):
    """
    In-memory dataset stub made of ``num_seq`` sequences of 10 frames each.

    Dataset indices ``10 * i`` to ``10 * i + 9`` belong to sequence
    ``seq_{i}``. Every entry of ``frame_annots`` is a dict with a single
    ``"frame_annotation"`` key holding a ``MockFrameAnnotation``.
    """

    def __init__(self, num_seq, max_frame_gap=1):
        """
        Makes a gap of max_frame_gap frame numbers in the middle of each sequence

        Args:
            num_seq: number of sequences to generate.
            max_frame_gap: difference between the frame numbers of the 5th
                and the 6th frame of every sequence.
        """
        self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
        self._seq_to_idx = {
            f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
        }

        # Frame numbers within one sequence are [0, ..., 4, n, ..., n + 4]
        # with n = 4 + max_frame_gap, i.e. the jump after frame 4 is exactly
        # max_frame_gap.
        frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap))
        # The comprehension creates a separate annotation object per dataset
        # index, so the sequence names can be patched in individually below.
        self.frame_annots = [
            {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
        ]
        for seq_name, idx in self._seq_to_idx.items():
            for i in idx:
                self.frame_annots[i]["frame_annotation"].sequence_name = seq_name

    def get_frame_numbers_and_timestamps(self, idxs, subset_filter=None):
        """
        Look up frame numbers and timestamps for the given dataset indices.

        Args:
            idxs: iterable of indices into ``self.frame_annots``.
            subset_filter: not supported by this mock; must be None.

        Returns:
            A list of ``(frame_number, frame_timestamp)`` tuples, one per
            index, in the order of ``idxs``.
        """
        assert subset_filter is None
        out = []
        for idx in idxs:
            frame_annotation = self.frame_annots[idx]["frame_annotation"]
            out.append(
                (frame_annotation.frame_number, frame_annotation.frame_timestamp)
            )
        return out

    def __getitem__(self, index: int):
        """
        Build the ``FrameData`` for the frame stored at dataset index ``index``.

        Only the sequence name, a fixed category, the frame number and the
        frame timestamp are populated.
        """
        fa = self.frame_annots[index]["frame_annotation"]
        fd = FrameData(
            sequence_name=fa.sequence_name,
            sequence_category="default_category",
            frame_number=torch.LongTensor([fa.frame_number]),
            # The timestamp is a float; an integer (Long) tensor would
            # silently drop its fractional part, so keep it floating point.
            frame_timestamp=torch.tensor([fa.frame_timestamp], dtype=torch.float32),
        )
        return fd
|
|
|
|
class TestSceneBatchSampler(unittest.TestCase):
    """
    Checks of ``SceneBatchSampler`` run against ``MockDataset``.

    The tests rely on the mock's layout: a dataset index divided by 10 gives
    the index of the sequence the frame belongs to.
    """

    def setUp(self):
        # Reset NumPy's global RNG before every test method
        # (presumably the sampler draws from it -- TODO confirm).
        np.random.seed(42)
        self.dataset_overfit = MockDataset(1)

    def test_overfit(self):
        """A one-sequence dataset must yield full batches taken from sequence 0."""
        expected_batches, expected_size = 3, 10
        sampler = SceneBatchSampler(
            self.dataset_overfit,
            batch_size=expected_size,
            num_batches=expected_batches,
            images_per_seq_options=[10],
        )
        self.assertEqual(len(sampler), expected_batches)

        batches = iter(sampler)
        for _ in range(expected_batches):
            batch = next(batches)
            self.assertIsNotNone(batch)
            self.assertEqual(len(batch), expected_size)
            # every index has to fall into the first (and only) sequence
            self.assertFalse(any(idx // 10 != 0 for idx in batch))

        # after the announced number of batches the iterator must be exhausted
        with self.assertRaises(StopIteration):
            next(batches)

    def test_multiseq(self):
        """Run the multi-sequence check over the whole grid of sampler settings."""
        settings_grid = product(
            [[10], [2], [3], [2, 3, 4]],  # images_per_seq_options
            [True, False],  # sample_consecutive_frames
            [0, 1, 3],  # consecutive_frames_max_gap
        )
        for ips_options, consecutive, max_gap in settings_grid:
            self._test_multiseq_flavour(ips_options, consecutive, max_gap)

    def test_multiseq_gaps(self):
        """
        Sample consecutive frames with an allowed gap of 1 from sequences
        whose frame numbers jump by 3 in the middle.
        """
        total_batches, full_batch = 16, 10
        dataset = MockDataset(5, max_frame_gap=3)
        for ips_options in ([10], [2], [3], [2, 3, 4]):
            debug_info = f" Images per sequence: {ips_options}."

            sampler = SceneBatchSampler(
                dataset,
                batch_size=full_batch,
                num_batches=total_batches,
                images_per_seq_options=ips_options,
                sample_consecutive_frames=True,
                consecutive_frames_max_gap=1,
            )
            self.assertEqual(len(sampler), total_batches, msg=debug_info)

            # When more than 5 images per sequence are requested the test
            # expects batches of only 5 frames (presumably because a run of
            # consecutive frames cannot cross the gap); otherwise the batch
            # must be full.
            expected_len = 5 if max(ips_options) > 5 else full_batch

            batches = iter(sampler)
            for _ in range(total_batches):
                batch = next(batches)
                self.assertIsNotNone(batch, "batch is None in" + debug_info)
                self.assertEqual(len(batch), expected_len, msg=debug_info)
                self._check_frames_are_consecutive(
                    batch, dataset.frame_annots, debug_info
                )

    def _test_multiseq_flavour(
        self,
        ips_options,
        sample_consecutive_frames,
        consecutive_frames_max_gap,
        num_batches=16,
        batch_size=10,
    ):
        """
        Check one sampler configuration on a five-sequence dataset.

        Args:
            ips_options: value passed as ``images_per_seq_options``.
            sample_consecutive_frames: value passed to the sampler; when
                truthy the frames of each batch are also checked for
                consecutiveness.
            consecutive_frames_max_gap: value passed to the sampler and used
                as the allowed gap in the consecutiveness check.
            num_batches: number of batches the sampler must produce.
            batch_size: number of indices expected in every batch.
        """
        debug_info = (
            f" Images per sequence: {ips_options}, "
            f"sample_consecutive_frames: {sample_consecutive_frames}, "
            f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
        )
        # The dataset's frame gap equals the sampler's allowed gap; a
        # non-positive allowed gap falls back to a dataset gap of 3.
        frame_gap = 3 if consecutive_frames_max_gap <= 0 else consecutive_frames_max_gap
        dataset = MockDataset(5, max_frame_gap=frame_gap)
        sampler = SceneBatchSampler(
            dataset,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=ips_options,
            sample_consecutive_frames=sample_consecutive_frames,
            consecutive_frames_max_gap=consecutive_frames_max_gap,
        )
        self.assertEqual(len(sampler), num_batches, msg=debug_info)

        batches = iter(sampler)
        seen_counts = set()
        for _ in range(num_batches):
            batch = next(batches)
            self.assertIsNotNone(batch, "batch is None in" + debug_info)
            self.assertEqual(len(batch), batch_size, msg=debug_info)

            # frames per sequence, then how often each per-sequence count occurs
            per_sequence = _count_by_quotient(batch, 10)
            freqs = _count_by_quotient(per_sequence.values(), 1)
            self.assertLessEqual(
                len(freqs),
                2,
                msg="We should have maximum of 2 different "
                "frequences of sequences in the batch." + debug_info,
            )
            if len(freqs) != 2:
                self.assertEqual(len(freqs), 1)
                most_seq_count = next(iter(freqs))
            else:
                # the larger count is the regular one; the smaller count may
                # belong to a single left-over sequence only
                most_seq_count, last_seq = max(freqs), min(freqs)
                self.assertEqual(
                    freqs[last_seq],
                    1,
                    msg="Only one odd sequence allowed." + debug_info,
                )

            self.assertIn(most_seq_count, ips_options)
            seen_counts.add(most_seq_count)

            if sample_consecutive_frames:
                self._check_frames_are_consecutive(
                    batch,
                    dataset.frame_annots,
                    debug_info,
                    max_gap=consecutive_frames_max_gap,
                )

        # every requested images-per-sequence option should have shown up
        self.assertTrue(
            seen_counts.issuperset(ips_options),
            "Some of the frequency options did not occur among "
            f"the {num_batches} batches (could be just bad luck)." + debug_info,
        )

        with self.assertRaises(StopIteration):
            next(batches)

    def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
        """
        Assert that neighbouring batch entries of one sequence are close enough.

        For a positive ``max_gap`` the frame numbers taken from ``annots``
        may differ by at most ``max_gap``; otherwise the raw dataset indices
        may differ by at most 1. Neighbours from different sequences are not
        compared.
        """
        for first, second in zip(batch, batch[1:]):
            if first // 10 != second // 10:
                continue  # different sequences

            if max_gap > 0:
                first = annots[first]["frame_annotation"].frame_number
                second = annots[second]["frame_annotation"].frame_number
                allowed = max_gap
            else:
                allowed = 1  # compare the raw dataset indices

            self.assertLessEqual(second - first, allowed, msg=debug_info)
|
|
|
|
| def _count_by_quotient(indices, divisor): |
| counter = defaultdict(int) |
| for i in indices: |
| counter[i // divisor] += 1 |
|
|
| return counter |
|
|
|
|
class TestRandomSampling(unittest.TestCase):
    """Checks of ``DoublePoolBatchSampler``."""

    def test_double_pool_batch_sampler(self):
        """
        Every batch must start with an index from the first pool, followed
        by indices from the second pool, for all tested combinations of
        ``replacement`` and ``num_batches``.
        """
        unknown_idxs = [2, 3, 4, 5, 8]
        known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
        for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
            with self.subTest(f"{replacement}, {num_batches}"):
                sampler = DoublePoolBatchSampler(
                    first_indices=unknown_idxs,
                    rest_indices=known_idxs,
                    batch_size=4,
                    replacement=replacement,
                    num_batches=num_batches,
                )
                # without an explicit num_batches, one batch per first-pool
                # index is expected
                expected_epoch_len = num_batches or len(unknown_idxs)
                # 4 batches cannot cover all 5 first-pool indices, so the
                # coverage check is skipped in that case
                expect_full_coverage = not replacement and num_batches != 4

                # iterate the same sampler for several epochs
                for _ in range(6):
                    epoch = list(sampler)
                    self.assertEqual(len(epoch), expected_epoch_len)
                    for batch in epoch:
                        self.assertEqual(len(batch), 4)
                        head, tail = batch[0], batch[1:]
                        self.assertIn(head, unknown_idxs)
                        for member in tail:
                            self.assertIn(member, known_idxs)
                    if expect_full_coverage:
                        leading = {batch[0] for batch in epoch}
                        self.assertEqual(leading, set(unknown_idxs))
|
|