| | |
| |
|
| | import collections |
| | import os |
| |
|
| | import torch |
| | from torch.utils.data import get_worker_info |
| | from torch.utils.data._utils.collate import ( |
| | default_collate_err_msg_format, |
| | np_str_obj_array_pattern, |
| | ) |
| | from lightning_fabric.utilities.seed import pl_worker_init_function |
| |
|
| | def collate(batch): |
| | """Difference with PyTorch default_collate: it can stack other tensor-like objects. |
| | Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich |
| | https://github.com/cvg/pixloc |
| | Released under the Apache License 2.0 |
| | """ |
| | if not isinstance(batch, list): |
| | return batch |
| |
|
| | |
| | batch = [elem for elem in batch if elem is not None] |
| | elem = batch[0] |
| | elem_type = type(elem) |
| | if isinstance(elem, torch.Tensor): |
| | out = None |
| | if torch.utils.data.get_worker_info() is not None: |
| | |
| | |
| | numel = sum(x.numel() for x in batch) |
| | storage = elem.storage()._new_shared(numel, device=elem.device) |
| | out = elem.new(storage).resize_(len(batch), *list(elem.size())) |
| | return torch.stack(batch, 0, out=out) |
| | elif ( |
| | elem_type.__module__ == "numpy" |
| | and elem_type.__name__ != "str_" |
| | and elem_type.__name__ != "string_" |
| | ): |
| | if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": |
| | |
| | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: |
| | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) |
| |
|
| | return collate([torch.as_tensor(b) for b in batch]) |
| | elif elem.shape == (): |
| | return torch.as_tensor(batch) |
| | elif isinstance(elem, float): |
| | return torch.tensor(batch, dtype=torch.float64) |
| | elif isinstance(elem, int): |
| | return torch.tensor(batch) |
| | elif isinstance(elem, (str, bytes)): |
| | return batch |
| | elif isinstance(elem, collections.abc.Mapping): |
| | return {key: collate([d[key] for d in batch]) for key in elem} |
| | elif isinstance(elem, tuple) and hasattr(elem, "_fields"): |
| | return elem_type(*(collate(samples) for samples in zip(*batch))) |
| | elif isinstance(elem, collections.abc.Sequence): |
| | |
| | it = iter(batch) |
| | elem_size = len(next(it)) |
| | if not all(len(elem) == elem_size for elem in it): |
| | raise RuntimeError("each element in list of batch should be of equal size") |
| | transposed = zip(*batch) |
| | return [collate(samples) for samples in transposed] |
| | else: |
| | |
| | try: |
| | return torch.stack(batch, 0) |
| | except TypeError as e: |
| | if "expected Tensor as element" in str(e): |
| | return batch |
| | else: |
| | raise e |
| |
|
| |
|
def set_num_threads(nt):
    """Force numpy and other libraries to use a limited number of threads.

    Caps MKL (if installed) at ``nt`` threads, pins torch intra-op
    parallelism to a single thread, and exports the usual BLAS/OpenMP
    thread-count environment variables set to ``nt``.

    NOTE(review): environment variables only take effect for libraries
    initialised after this call — confirm it runs before heavy imports.

    Args:
        nt: maximum number of threads each library may use.
    """
    try:
        import mkl
    except ImportError:
        # MKL is optional; nothing to configure when it is absent.
        pass
    else:
        mkl.set_num_threads(nt)
    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"
    limit = str(nt)
    env_keys = (
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    )
    for key in env_keys:
        os.environ[key] = limit
| |
|
| |
|
def worker_init_fn(i):
    """DataLoader worker initializer: seed the worker and cap its threads.

    Seeds the worker via Lightning's per-worker seeding, then limits thread
    usage when the dataset config requests it via the ``num_threads`` key.

    Args:
        i: worker index supplied by the DataLoader (unused; the id is taken
            from ``get_worker_info()`` instead).
    """
    worker = get_worker_info()
    pl_worker_init_function(worker.id)
    requested = worker.dataset.cfg.get("num_threads")
    if requested is None:
        return
    set_num_threads(requested)