# Standard library.
import os
import tempfile
import unittest

# Package under test.
from finetrainers.data import (
    InMemoryDistributedDataPreprocessor,
    PrecomputedDistributedDataPreprocessor,
    VideoCaptionFilePairDataset,
    initialize_preprocessor,
    wrap_iterable_dataset_for_preprocessing,
)
from finetrainers.data.precomputation import PRECOMPUTED_DATA_DIR
from finetrainers.utils import find_files

# Test-local helpers.
from .utils import create_dummy_directory_structure


class PreprocessorFastTests(unittest.TestCase):
    """Fast tests for the in-memory and precomputed distributed data preprocessors.

    Each test streams a small dummy video/caption dataset through the two
    processor functions defined on this class and checks the transformed
    samples together with the preprocessor's ``requires_data`` flag.
    """

    # Caption handed to the dummy-data helper, and the suffix appended by the
    # condition processor; their concatenation is what the tests expect back.
    _CAPTION = "a cat ruling the world"
    _CAPTION_SUFFIX = " surrounded by mystical aura"
    _EXPECTED_CAPTION = _CAPTION + _CAPTION_SUFFIX
    # Size the latent processor crops the last two video dimensions to.
    _CROP_SIZE = 16

    def setUp(self):
        """Create the temporary dataset directory and the wrapped dataset."""
        self.rank = 0
        self.world_size = 1
        self.num_items = 3
        self.processor_fn = {
            "latent": self._latent_processor_fn,
            "condition": self._condition_processor_fn,
        }
        self.save_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away: unittest skips tearDown when setUp
        # raises, but cleanups registered before the failure still run.
        self.addCleanup(self.save_dir.cleanup)

        # One "<i>.mp4" / "<i>.txt" pair per item, videos listed first.
        indices = range(self.num_items)
        directory_structure = [f"{i}.mp4" for i in indices] + [f"{i}.txt" for i in indices]
        # NOTE: the helper is handed the TemporaryDirectory object itself
        # (not its ``.name``), exactly as the original call did.
        create_dummy_directory_structure(
            directory_structure, self.save_dir, self.num_items, self._CAPTION, "mp4"
        )

        dataset = VideoCaptionFilePairDataset(self.save_dir.name, infinite=True)
        self.dataset = wrap_iterable_dataset_for_preprocessing(
            dataset,
            dataset_type="video",
            config={
                "video_resolution_buckets": [[2, 32, 32]],
                "reshape_mode": "bicubic",
            },
        )

    @staticmethod
    def _latent_processor_fn(**data):
        """Crop the last two dimensions of ``data["video"]`` and return ``data``."""
        crop = PreprocessorFastTests._CROP_SIZE
        data["video"] = data["video"][:, :, :crop, :crop]
        return data

    @staticmethod
    def _condition_processor_fn(**data):
        """Append the test suffix to ``data["caption"]`` and return ``data``."""
        data["caption"] = data["caption"] + PreprocessorFastTests._CAPTION_SUFFIX
        return data

    def _make_preprocessor(self, enable_precomputation):
        """Call ``initialize_preprocessor`` with this test case's settings."""
        return initialize_preprocessor(
            self.rank,
            self.world_size,
            self.num_items,
            self.processor_fn,
            self.save_dir.name,
            enable_precomputation=enable_precomputation,
        )

    def _check_consume(self, method_name, *, enable_precomputation, requires_data_after):
        """Run the shared consume scenario and verify its results.

        Args:
            method_name: Name of the preprocessor method to exercise
                (``"consume"`` or ``"consume_once"``).
            enable_precomputation: Forwarded to ``initialize_preprocessor``;
                when true, the files found under ``PRECOMPUTED_DATA_DIR``
                are counted as well.
            requires_data_after: Truthiness expected of
                ``preprocessor.requires_data`` once ``num_items`` samples
                have been drawn from both iterators.
        """
        data_iterator = iter(self.dataset)
        preprocessor = self._make_preprocessor(enable_precomputation)
        consume = getattr(preprocessor, method_name)

        # The condition pass caches its samples; the latent pass reuses and
        # then drops them, so both passes share a single data iterator.
        condition_iterator = consume(
            "condition", components={}, data_iterator=data_iterator, cache_samples=True
        )
        latent_iterator = consume(
            "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
        )

        if enable_precomputation:
            precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR)
            for prefix in ("condition", "latent"):
                matches = find_files(precomputed_data_dir, f"{prefix}-*")
                self.assertEqual(len(matches), self.num_items)

        self.assertFalse(preprocessor.requires_data)
        expected_hw = (self._CROP_SIZE, self._CROP_SIZE)
        for _ in range(self.num_items):
            condition_item = next(condition_iterator)
            latent_item = next(latent_iterator)
            self.assertIn("caption", condition_item)
            self.assertIn("video", latent_item)
            self.assertEqual(condition_item["caption"], self._EXPECTED_CAPTION)
            self.assertEqual(latent_item["video"].shape[-2:], expected_hw)

        if requires_data_after:
            self.assertTrue(preprocessor.requires_data)
        else:
            self.assertFalse(preprocessor.requires_data)

    def test_initialize_preprocessor(self):
        """``enable_precomputation`` selects which preprocessor class is built."""
        expected_types = (
            (False, InMemoryDistributedDataPreprocessor),
            (True, PrecomputedDistributedDataPreprocessor),
        )
        for enable_precomputation, expected_type in expected_types:
            preprocessor = self._make_preprocessor(enable_precomputation)
            self.assertIsInstance(preprocessor, expected_type)

    def test_in_memory_preprocessor_consume(self):
        """In-memory ``consume``: expects ``requires_data`` to be true after draining."""
        self._check_consume("consume", enable_precomputation=False, requires_data_after=True)

    def test_in_memory_preprocessor_consume_once(self):
        """In-memory ``consume_once``: expects ``requires_data`` to stay false."""
        self._check_consume("consume_once", enable_precomputation=False, requires_data_after=False)

    def test_precomputed_preprocessor_consume(self):
        """Precomputed ``consume``: expects ``requires_data`` to be true after draining."""
        self._check_consume("consume", enable_precomputation=True, requires_data_after=True)

    def test_precomputed_preprocessor_consume_once(self):
        """Precomputed ``consume_once``: expects ``requires_data`` to stay false."""
        self._check_consume("consume_once", enable_precomputation=True, requires_data_after=False)