| import abc |
| from collections.abc import Sequence |
| import dataclasses |
| import enum |
| import logging |
| import pathlib |
| from typing import Generic, TypeVar |
|
|
| import augmax |
| from flax import nnx |
| from flax import struct |
| from flax import traverse_util |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import orbax.checkpoint as ocp |
|
|
| from openpi.shared import image_tools |
| import openpi.shared.array_typing as at |
|
|
| logger = logging.getLogger("openpi") |
|
|
| ArrayT = TypeVar("ArrayT", at.Array, jax.ShapeDtypeStruct) |
|
|
|
|
class ModelType(enum.Enum):
    """Supported model types.

    The string values serve as stable identifiers for the two model families
    defined in this package: pi0 and pi0-FAST.
    """

    PI0 = "pi0"
    PI0_FAST = "pi0_fast"
|
|
|
|
| |
# The image views a model expects, in canonical order. `preprocess_observation`
# requires all of these keys to be present in `Observation.images`.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)

# Default (height, width) target used by `preprocess_observation` when resizing
# images (cf. the "*b h w c" image layout declared on `Observation`).
IMAGE_RESOLUTION = (224, 224)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
@at.typecheck
@struct.dataclass
class Observation(Generic[ArrayT]):
    """Holds observations, i.e., inputs to the model.

    See `Observation.from_dict` to see the expected dictionary form. This is the format
    that should be produced by the data transforms.
    """

    # Images keyed by view name (see `IMAGE_KEYS`), float in [-1, 1].
    images: dict[str, at.Float[ArrayT, "*b h w c"]]
    # Per-view validity masks; a False entry marks an image to be ignored.
    image_masks: dict[str, at.Bool[ArrayT, "*b"]]
    # Low-dimensional state vector.
    state: at.Float[ArrayT, "*b s"]

    # Tokenized prompt and its mask; must be provided together (see `from_dict`).
    tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
    tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None

    # Optional per-token masks; semantics are defined by the consuming model
    # (NOTE(review): presumably autoregressive-attention and loss masks for the
    # FAST variant — confirm against the model implementations).
    token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
    token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None

    @classmethod
    def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
        """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.

        uint8 images are converted to float32 in [-1, 1]; other dtypes pass through
        unchanged. The input dict is left unmodified.

        Raises:
            ValueError: if only one of tokenized_prompt / tokenized_prompt_mask is present.
        """
        if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
            raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
        # Build a fresh images dict instead of mutating the caller's nested dict
        # in place (the original wrote converted arrays back into data["image"]).
        images = {
            key: (image.astype(np.float32) / 255.0 * 2.0 - 1.0 if image.dtype == np.uint8 else image)
            for key, image in data["image"].items()
        }
        return cls(
            images=images,
            image_masks=data["image_mask"],
            state=data["state"],
            tokenized_prompt=data.get("tokenized_prompt"),
            tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
            token_ar_mask=data.get("token_ar_mask"),
            token_loss_mask=data.get("token_loss_mask"),
        )

    def to_dict(self) -> at.PyTree[ArrayT]:
        """Convert the Observation to a nested dict (inverse of `from_dict`'s key mapping)."""
        result = dataclasses.asdict(self)
        # Restore the external key names used by `from_dict`.
        result["image"] = result.pop("images")
        result["image_mask"] = result.pop("image_masks")
        return result
|
|
|
|
| |
| |
# Actions: float array whose trailing axes are (action horizon, action dim) —
# cf. `BaseModelConfig.action_horizon` / `action_dim`.
Actions = at.Float[ArrayT, "*b ah ad"]
|
|
|
|
def preprocess_observation(
    rng: at.KeyArrayLike | None,
    observation: Observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> Observation:
    """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
    filling in a default image mask (if necessary).

    Args:
        rng: PRNG key used for augmentations. Required when train=True; ignored otherwise.
        observation: The observation to preprocess.
        train: Whether to apply training-time image augmentations.
        image_keys: The image views that must be present in `observation.images`.
        image_resolution: Target (height, width) to resize images to.

    Returns:
        A new Observation with processed images and filled-in image masks; all other
        fields are passed through unchanged.

    Raises:
        ValueError: if a required image key is missing, or if train=True and rng is None.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
    # Fail fast with a clear message instead of crashing inside jax.random.split.
    if train and rng is None:
        raise ValueError("rng must be provided when train=True.")

    batch_shape = observation.state.shape[:-1]

    out_images = {}
    for key in image_keys:
        image = observation.images[key]
        # NOTE(review): shape[1:3] assumes a single leading batch dim ([b, h, w, c]),
        # even though the Observation annotations allow "*b" — confirm callers.
        if image.shape[1:3] != image_resolution:
            # Lazy %-formatting so the message is only built when the log is emitted.
            logger.info("Resizing image %s from %s to %s", key, image.shape[1:3], image_resolution)
            image = image_tools.resize_with_pad(image, *image_resolution)

        if train:
            # Augmentations are applied in [0, 1] range; convert from [-1, 1] first.
            image = image / 2.0 + 0.5

            transforms = []
            if "wrist" not in key:
                # Geometric augmentations are only applied to non-wrist views.
                height, width = image.shape[1:3]
                transforms += [
                    augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
                    augmax.Resize(width, height),
                    augmax.Rotate((-5, 5)),
                ]
            transforms += [
                augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
            ]
            # One independent sub-key per batch element.
            sub_rngs = jax.random.split(rng, image.shape[0])
            image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)

            # Convert back to [-1, 1].
            image = image * 2.0 - 1.0

        out_images[key] = image

    # Fill in a default all-valid mask for any image that doesn't have one.
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
        else:
            out_masks[key] = jnp.asarray(observation.image_masks[key])

    return Observation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
|
|
|
|
@dataclasses.dataclass(frozen=True)
class BaseModelConfig(abc.ABC):
    """Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
    method to create the corresponding model.
    """

    # Action space dimensionality.
    action_dim: int
    # Number of action steps in a predicted action sequence.
    action_horizon: int
    # Maximum length of the tokenized prompt.
    max_token_len: int

    @property
    @abc.abstractmethod
    def model_type(self) -> ModelType:
        """The model type."""

    @abc.abstractmethod
    def create(self, rng: at.KeyArrayLike) -> "BaseModel":
        """Create a new model, initializing parameters."""

    def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
        """Create a model with the given parameters.

        Args:
            params: Unstructured parameter PyTree to load into the model.
            remove_extra_params: If True, entries of `params` that are not part of the
                model's own state are dropped before loading.
        """
        # Build the model skeleton without allocating real parameter arrays.
        model = nnx.eval_shape(self.create, jax.random.key(0))
        graphdef, state = nnx.split(model)
        if remove_extra_params:
            params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
        # Structure and shapes must match exactly; dtypes are allowed to differ.
        at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
        state.replace_by_pure_dict(params)
        return nnx.merge(graphdef, state)

    @abc.abstractmethod
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
        """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""

    def fake_obs(self, batch_size: int = 1) -> Observation:
        """Returns an all-ones Observation with the shapes/dtypes given by `inputs_spec`."""
        observation_spec, _ = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)

    def fake_act(self, batch_size: int = 1) -> Actions:
        """Returns an all-ones Actions array with the shape/dtype given by `inputs_spec`."""
        _, action_spec = self.inputs_spec(batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
|
|
|
|
@dataclasses.dataclass
class BaseModel(nnx.Module, abc.ABC):
    """Base class for all model implementations. Specific models should inherit from this class. They should call
    super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
    """

    # Shared model attributes; see `BaseModelConfig` for their meanings.
    action_dim: int
    action_horizon: int
    max_token_len: int

    # Computes a loss of shape [*batch, action_horizon] for a batch of
    # (observation, actions) pairs.
    @abc.abstractmethod
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: Observation,
        actions: Actions,
        *,
        train: bool = False,
    ) -> at.Float[at.Array, "*b ah"]: ...

    # Samples an Actions array for the given observation.
    @abc.abstractmethod
    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation) -> Actions: ...
|
|
|
|
def restore_params(
    params_path: pathlib.Path | str,
    *,
    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Restores unstructured params PyTree from a checkpoint.

    This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
    well as pre-trained checkpoints released for openpi.

    Args:
        params_path: The local path to the checkpoint directory.
        restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
        dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
        sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.

    Returns:
        The restored params.

    Raises:
        FileNotFoundError: if `params_path` does not exist.
    """
    params_path = pathlib.Path(params_path).resolve()
    if not params_path.exists():
        raise FileNotFoundError(f"Model params not found at: {params_path}")

    # Default to fully replicating the params across all available devices.
    if restore_type is jax.Array and sharding is None:
        mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    with ocp.PyTreeCheckpointer() as ckptr:
        # Use the checkpoint's own metadata as the restore template so only the
        # "params" subtree is restored, ignoring any sibling entries.
        metadata = ckptr.metadata(params_path)
        item = {"params": metadata["params"]}

        # Apply the same sharding/type/dtype restore args to every leaf.
        params = ckptr.restore(
            params_path,
            ocp.args.PyTreeRestore(
                item=item,
                restore_args=jax.tree.map(
                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
                ),
            ),
        )["params"]

    # If every leaf path ends in a "value" key, strip that level so the result is
    # a plain params PyTree. (NOTE(review): this pattern presumably comes from
    # checkpoints saved from an nnx State — confirm against training/checkpoints.py.)
    flat_params = traverse_util.flatten_dict(params)
    if all(kp[-1] == "value" for kp in flat_params):
        flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
    return traverse_util.unflatten_dict(flat_params)
|
|