| | from typing import Dict, List, Optional, NoReturn |
| | import torch |
| | import lightning.pytorch as pl |
| | from torch.utils.data import DataLoader |
| | from data.audiotext_dataset import AudioTextDataset |
| |
|
| |
|
class DataModule(pl.LightningDataModule):
    r"""Lightning data module wrapping a single training dataset.

    To get one batch of data:

    .. code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break
    """

    def __init__(
        self,
        train_dataset: object,
        batch_size: int,
        num_workers: int,
    ):
        r"""
        Args:
            train_dataset: Dataset object; presumably yields dicts with
                'text', 'waveform', and 'modality' keys consumed by
                ``collate_fn`` — TODO confirm against the dataset class.
            batch_size: number of samples per training batch.
            num_workers: number of DataLoader worker processes.
        """
        super().__init__()
        self._train_dataset = train_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Module-level collate_fn groups samples by modality and stacks waveforms.
        self.collate_fn = collate_fn

    def prepare_data(self) -> None:
        # Data download/preparation is handled outside this module.
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device.

        Fixed: annotated ``-> None`` instead of ``-> NoReturn`` — this
        method returns normally (``NoReturn`` asserts it never returns).
        """
        self.train_dataset = self._train_dataset

    def train_dataloader(self) -> DataLoader:
        r"""Get train loader."""
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            shuffle=True,
        )

    def val_dataloader(self):
        # No validation set is configured for this module.
        pass

    def test_dataloader(self):
        # No test set is configured for this module.
        pass

    def teardown(self, stage: Optional[str] = None) -> None:
        # Fixed: Lightning's Trainer calls teardown(stage=...); the original
        # zero-argument signature would raise TypeError when invoked as a
        # hook. The None default keeps any direct zero-arg calls working.
        pass
| |
|
| |
|
def collate_fn(list_data_dict):
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {
                'text': 'a sound of dog',
                'waveform': (1, samples),
                'modality': 'audio_text'
            }
            ...
        ]
    Returns:
        data_dict: e.g.
            'audio_text': {
                'text': ['a sound of dog', ...]
                'waveform': (batch_size, 1, samples)
            }
    """
    # Keep only audio-text samples; other modalities are dropped.
    at_items = [
        data_dict for data_dict in list_data_dict
        if data_dict['modality'] == 'audio_text'
    ]

    at_data_dict = {}

    if at_items:
        # Fixed: the original comprehension reused the name `at_data_dict`
        # as its loop variable, shadowing the result dict it was assigning
        # into — correct only by accident of comprehension scoping. Also
        # removed the no-op identity copy for the 'text' key.
        for key in at_items[0].keys():
            values = [item[key] for item in at_items]
            if key == 'waveform':
                # Stack (1, samples) tensors into (batch_size, 1, samples).
                values = torch.stack(values)
            at_data_dict[key] = values

    return {'audio_text': at_data_dict}