| from collections.abc import Iterator, Sequence |
| import multiprocessing |
| import os |
| import typing |
| from typing import Protocol, SupportsIndex, TypeVar |
|
|
| import jax |
| import jax.numpy as jnp |
| import lerobot.common.datasets.lerobot_dataset as lerobot_dataset |
| import numpy as np |
| import torch |
|
|
| import openpi.models.model as _model |
| import openpi.training.config as _config |
| from openpi.training.droid_rlds_dataset import DroidRldsDataset |
| import openpi.transforms as _transforms |
|
|
| T_co = TypeVar("T_co", covariant=True) |
|
|
|
|
| class Dataset(Protocol[T_co]): |
| """Interface for a dataset with random access.""" |
|
|
| def __getitem__(self, index: SupportsIndex) -> T_co: |
| raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") |
|
|
| def __len__(self) -> int: |
| raise NotImplementedError("Subclasses of Dataset should implement __len__.") |
|
|
|
|
| class IterableDataset(Protocol[T_co]): |
| """Interface for an iterable dataset.""" |
|
|
| def __iter__(self) -> Iterator[T_co]: |
| raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.") |
|
|
| def __len__(self) -> int: |
| raise NotImplementedError("Subclasses of Dataset should implement __len__.") |
|
|
|
|
| class DataLoader(Protocol[T_co]): |
| """Interface for a data loader.""" |
|
|
| def data_config(self) -> _config.DataConfig: |
| """Get the data config for this data loader.""" |
| raise NotImplementedError("Subclasses of DataLoader should implement data_config.") |
|
|
| def __iter__(self) -> Iterator[T_co]: |
| raise NotImplementedError("Subclasses of DataLoader should implement __iter__.") |
|
|
|
|
| class TransformedDataset(Dataset[T_co]): |
| def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]): |
| self._dataset = dataset |
| self._transform = _transforms.compose(transforms) |
|
|
| def __getitem__(self, index: SupportsIndex) -> T_co: |
| return self._transform(self._dataset[index]) |
|
|
| def __len__(self) -> int: |
| return len(self._dataset) |
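
# Illustrative sketch only (not part of this module's API): wrapping a raw dataset with
# transforms. `base_dataset` and `norm_stats` are hypothetical placeholders for a real
# dataset and its precomputed normalization statistics.
#
#     dataset = TransformedDataset(
#         base_dataset,
#         [_transforms.Normalize(norm_stats, use_quantiles=False)],
#     )
#     item = dataset[0]  # the composed transform runs on access, one item at a time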
|
|
|
|
| class IterableTransformedDataset(IterableDataset[T_co]): |
| def __init__( |
| self, |
| dataset: IterableDataset, |
| transforms: Sequence[_transforms.DataTransformFn], |
| *, |
| is_batched: bool = False, |
| ): |
| self._dataset = dataset |
| self._transform = _transforms.compose(transforms) |
| self._is_batched = is_batched |
|
|
| def __iter__(self): |
| for sample in self._dataset: |
            if self._is_batched:
                # The underlying dataset yields batched samples (a dict of arrays with a
                # leading batch dimension), while the transforms operate on single samples.
                batch_size = next(v.shape[0] for v in sample.values())
                # Split the batch into individual samples.
                individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)]
                # Apply the composed transform to each sample.
                transformed = [self._transform(s) for s in individual_samples]
                # Stack the transformed samples back into a batch.
                yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)
| else: |
| yield self._transform(sample) |
|
|
| def __len__(self) -> int: |
| return len(self._dataset) |
|
|
|
|
| class FakeDataset(Dataset): |
| def __init__(self, model_config: _model.BaseModelConfig, num_samples: int): |
| self._num_samples = num_samples |
| self._observation_spec, self._action_spec = model_config.inputs_spec() |
|
|
| def __getitem__(self, index: SupportsIndex) -> dict: |
| rng = jax.random.key(index.__index__()) |
|
|
| def make_from_spec(spec: jax.ShapeDtypeStruct): |
| nonlocal rng |
| rng, data_rng = jax.random.split(rng) |
| |
            # Remove the leading (batch) dimension from the spec shape; samples are generated unbatched.
            shape = spec.shape[1:]
| if spec.dtype == jnp.float32: |
| return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0) |
| if spec.dtype == jnp.int32: |
| return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048) |
| return jnp.zeros(shape=shape, dtype=spec.dtype) |
|
|
| observation = jax.tree.map(make_from_spec, self._observation_spec) |
| action = jax.tree.map(make_from_spec, self._action_spec) |
|
|
| return { |
| **observation.to_dict(), |
| "actions": action, |
| } |
|
|
| def __len__(self) -> int: |
| return self._num_samples |
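
# Illustrative sketch only: FakeDataset is what `create_torch_dataset` returns when the
# repo id is "fake", which is useful for smoke-testing a training setup without real data.
# `model_config` below stands for any `_model.BaseModelConfig` instance.
#
#     fake = FakeDataset(model_config, num_samples=16)
#     sample = fake[0]  # dict of observation fields plus an "actions" entry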
|
|
|
|
| def create_torch_dataset( |
| data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig |
| ) -> Dataset: |
| """Create a dataset for training.""" |
| repo_id = data_config.repo_id |
| if repo_id is None: |
| raise ValueError("Repo ID is not set. Cannot create dataset.") |
| if repo_id == "fake": |
| return FakeDataset(model_config, num_samples=1024) |
|
|
| dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id) |
| dataset = lerobot_dataset.LeRobotDataset( |
        repo_id,
| delta_timestamps={ |
| key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys |
| }, |
| ) |
|
|
| if data_config.prompt_from_task: |
| dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)]) |
|
|
| return dataset |
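
# Worked example for the `delta_timestamps` built above (illustrative numbers): with
# action_horizon=4 and a dataset recorded at fps=10, every action sequence key maps to
# [0.0, 0.1, 0.2, 0.3], i.e. the current step plus the next three, spaced 1/fps apart.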
|
|
|
|
| def create_rlds_dataset( |
| data_config: _config.DataConfig, |
| action_horizon: int, |
| batch_size: int, |
| *, |
| shuffle: bool = False, |
) -> IterableDataset:
    """Create an RLDS-backed iterable dataset for training (currently the DROID dataset)."""
    return DroidRldsDataset(
| data_dir=data_config.rlds_data_dir, |
| batch_size=batch_size, |
| shuffle=shuffle, |
| action_chunk_size=action_horizon, |
| action_space=data_config.action_space, |
| ) |
|
|
|
|
| def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset: |
| """Transform the dataset by applying the data transforms.""" |
| norm_stats = {} |
| if data_config.repo_id != "fake" and not skip_norm_stats: |
| if data_config.norm_stats is None: |
| raise ValueError( |
| "Normalization stats not found. " |
| "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`." |
| ) |
| norm_stats = data_config.norm_stats |
|
|
| return TransformedDataset( |
| dataset, |
| [ |
| *data_config.repack_transforms.inputs, |
| *data_config.data_transforms.inputs, |
| _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), |
| *data_config.model_transforms.inputs, |
| ], |
| ) |
|
|
|
|
| def transform_iterable_dataset( |
| dataset: IterableDataset, |
| data_config: _config.DataConfig, |
| *, |
| skip_norm_stats: bool = False, |
| is_batched: bool = False, |
| ) -> IterableDataset: |
| """Transform the dataset by applying the data transforms.""" |
| norm_stats = {} |
| if data_config.repo_id != "fake" and not skip_norm_stats: |
| if data_config.norm_stats is None: |
| raise ValueError( |
| "Normalization stats not found. " |
| "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`." |
| ) |
| norm_stats = data_config.norm_stats |
|
|
| return IterableTransformedDataset( |
| dataset, |
| [ |
| *data_config.repack_transforms.inputs, |
| *data_config.data_transforms.inputs, |
| _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), |
| *data_config.model_transforms.inputs, |
| ], |
| is_batched=is_batched, |
| ) |
|
|
|
|
| def create_data_loader( |
| config: _config.TrainConfig, |
| *, |
| sharding: jax.sharding.Sharding | None = None, |
| shuffle: bool = False, |
| num_batches: int | None = None, |
| skip_norm_stats: bool = False, |
| ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: |
| """Create a data loader for training.""" |
| data_config = config.data.create(config.assets_dirs, config.model) |
|
|
| if data_config.rlds_data_dir is not None: |
| return create_rlds_data_loader( |
| data_config, |
| action_horizon=config.model.action_horizon, |
| batch_size=config.batch_size, |
| sharding=sharding, |
| shuffle=shuffle, |
| num_batches=num_batches, |
| skip_norm_stats=skip_norm_stats, |
| ) |
| return create_torch_data_loader( |
| data_config, |
| model_config=config.model, |
| action_horizon=config.model.action_horizon, |
| batch_size=config.batch_size, |
| sharding=sharding, |
| shuffle=shuffle, |
| num_batches=num_batches, |
| num_workers=config.num_workers, |
| seed=config.seed, |
| skip_norm_stats=skip_norm_stats, |
| ) |
|
|
|
|
| def create_torch_data_loader( |
| data_config: _config.DataConfig, |
| model_config: _model.BaseModelConfig, |
| action_horizon: int, |
| batch_size: int, |
| *, |
| sharding: jax.sharding.Sharding | None = None, |
| skip_norm_stats: bool = False, |
| shuffle: bool = False, |
| num_batches: int | None = None, |
| num_workers: int = 0, |
| seed: int = 0, |
| ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: |
| """Create a data loader for training. |
| |
| Args: |
        data_config: The data configuration.
        model_config: The model configuration, used to generate fake data when the repo id is "fake".
        action_horizon: The number of action steps included in each sample's action sequence.
        batch_size: The global batch size, split evenly across processes.
        sharding: The sharding to use for the data loader. If None, defaults to
            data-parallel sharding across all local devices.
| skip_norm_stats: Whether to skip data normalization. |
| shuffle: Whether to shuffle the data. |
| num_batches: Determines the number of batches to return. If the number exceeds the |
| number of batches in the dataset, the data loader will loop over the dataset. |
| If not provided, will iterate over the dataset indefinitely. |
| num_workers: The number of worker processes to use. If zero, the data loader will |
| execute in the main process. |
| seed: The seed to use for shuffling the data. |
| """ |
| dataset = create_torch_dataset(data_config, action_horizon, model_config) |
| dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats) |
|
|
| data_loader = TorchDataLoader( |
| dataset, |
| local_batch_size=batch_size // jax.process_count(), |
| sharding=sharding, |
| shuffle=shuffle, |
| num_batches=num_batches, |
| num_workers=num_workers, |
| seed=seed, |
| ) |
|
|
| return DataLoaderImpl(data_config, data_loader) |
|
|
|
|
| def create_rlds_data_loader( |
| data_config: _config.DataConfig, |
| action_horizon: int, |
| batch_size: int, |
| *, |
| sharding: jax.sharding.Sharding | None = None, |
| skip_norm_stats: bool = False, |
| shuffle: bool = False, |
| num_batches: int | None = None, |
| ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: |
| """Create an RLDS data loader for training. |
| |
| Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md |
| |
| Args: |
| data_config: The data configuration. |
        action_horizon: The number of action steps included in each action chunk.
        batch_size: The batch size.
        sharding: The sharding to use for the data loader. If None, defaults to
            data-parallel sharding across all local devices.
| skip_norm_stats: Whether to skip data normalization. |
| shuffle: Whether to shuffle the data. |
| num_batches: Determines the number of batches to return. If the number exceeds the |
| number of batches in the dataset, the data loader will loop over the dataset. |
| If not provided, will iterate over the dataset indefinitely. |
| """ |
| dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle) |
| dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True) |
|
|
| data_loader = RLDSDataLoader( |
| dataset, |
| sharding=sharding, |
| num_batches=num_batches, |
| ) |
|
|
| return DataLoaderImpl(data_config, data_loader) |
|
|
|
|
| class TorchDataLoader: |
| def __init__( |
| self, |
| dataset, |
| local_batch_size: int, |
| *, |
| sharding: jax.sharding.Sharding | None = None, |
| shuffle: bool = False, |
| num_batches: int | None = None, |
| num_workers: int = 0, |
| seed: int = 0, |
| ): |
| """Create a PyTorch data loader. |
| |
| Args: |
| dataset: The dataset to load. |
| local_batch_size: The local batch size for each process. |
| sharding: The sharding to use for the data loader. |
| shuffle: Whether to shuffle the data. |
| num_batches: If provided, determines the number of returned batches. If the |
| number is larger than the number of batches in the dataset, the data loader |
| will loop over the dataset. If not provided, will iterate over the dataset |
| indefinitely. |
| num_workers: The number of worker processes to use. If zero, the data loader will |
| execute in the main process. |
| seed: The seed to use for shuffling the data. |
| """ |
| if jax.process_count() > 1: |
| raise NotImplementedError("Data loading with multiple processes is not supported.") |
|
|
| if len(dataset) < local_batch_size: |
| raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).") |
|
|
| if sharding is None: |
| |
            # Default to data-parallel sharding across all local devices.
            sharding = jax.sharding.NamedSharding(
| jax.sharding.Mesh(jax.devices(), ("B",)), |
| jax.sharding.PartitionSpec("B"), |
| ) |
|
|
| self._sharding = sharding |
| self._num_batches = num_batches |
|
|
| mp_context = None |
| if num_workers > 0: |
| mp_context = multiprocessing.get_context("spawn") |
|
|
| generator = torch.Generator() |
| generator.manual_seed(seed) |
| self._data_loader = torch.utils.data.DataLoader( |
| typing.cast(torch.utils.data.Dataset, dataset), |
| batch_size=local_batch_size, |
| shuffle=shuffle, |
| num_workers=num_workers, |
| multiprocessing_context=mp_context, |
| persistent_workers=num_workers > 0, |
| collate_fn=_collate_fn, |
| worker_init_fn=_worker_init_fn, |
| drop_last=True, |
| generator=generator, |
| ) |
|
|
| @property |
| def torch_loader(self) -> torch.utils.data.DataLoader: |
| return self._data_loader |
|
|
| def __iter__(self): |
| num_items = 0 |
| while True: |
| data_iter = iter(self._data_loader) |
| while True: |
| if self._num_batches is not None and num_items >= self._num_batches: |
| return |
| try: |
| batch = next(data_iter) |
| except StopIteration: |
| break |
| num_items += 1 |
                # Convert the host-local numpy batch into a global jax.Array sharded across local devices.
                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
|
|
|
|
| def _collate_fn(items): |
| """Collate the batch elements into batched numpy arrays.""" |
| |
| |
    # Convert to numpy before stacking, since items may contain torch tensors, and stack
    # matching leaves along a new leading batch axis.
    return jax.tree.map(lambda *x: np.stack(np.asarray(x), axis=0), *items)
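
# Worked example for `_collate_fn` (illustrative shapes): given two items shaped
# {"state": (7,), "actions": (8, 7)}, the collated batch is shaped
# {"state": (2, 7), "actions": (2, 8, 7)}; every leaf gains a leading batch axis.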
|
|
|
|
| def _worker_init_fn(worker_id: int) -> None: |
| """Tell JAX inside the worker process not to preallocate the GPU memory.""" |
| |
| |
| os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" |
| os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" |
|
|
|
|
| class RLDSDataLoader: |
| """Shallow wrapper around the DROID data loader to make it compatible with openpi. |
| |
| All batching already happens in the DROID dataset, so we don't need to do anything here. |
| """ |
|
|
| def __init__( |
| self, |
| dataset: DroidRldsDataset, |
| *, |
| sharding: jax.sharding.Sharding | None = None, |
| num_batches: int | None = None, |
| ): |
| self._dataset = dataset |
| self._num_batches = num_batches |
|
|
| if jax.process_count() > 1: |
| raise NotImplementedError("Data loading with multiple processes is not supported.") |
|
|
| if sharding is None: |
| |
            # Default to data-parallel sharding across all local devices.
            sharding = jax.sharding.NamedSharding(
| jax.sharding.Mesh(jax.devices(), ("B",)), |
| jax.sharding.PartitionSpec("B"), |
| ) |
|
|
        self._sharding = sharding
|
|
| def __iter__(self): |
| num_items = 0 |
| while True: |
| data_iter = iter(self._dataset) |
| while True: |
| if self._num_batches is not None and num_items >= self._num_batches: |
| return |
| try: |
| batch = next(data_iter) |
| except StopIteration: |
| break |
| num_items += 1 |
| yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) |
|
|
|
|
| class DataLoaderImpl(DataLoader): |
| def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader): |
| self._data_config = data_config |
| self._data_loader = data_loader |
|
|
| def data_config(self) -> _config.DataConfig: |
| return self._data_config |
|
|
| def __iter__(self): |
| for batch in self._data_loader: |
| yield _model.Observation.from_dict(batch), batch["actions"] |
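
# Illustrative sketch only: `DataLoaderImpl` is what the `create_*_data_loader` helpers
# return, so callers can read the resolved data config and iterate over model-ready batches.
# `loader` below is a hypothetical instance.
#
#     data_config = loader.data_config()
#     observation, actions = next(iter(loader))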
|
|