| import dataclasses |
| import enum |
| import logging |
| import socket |
|
|
| import tyro |
|
|
| from openpi.policies import policy as _policy |
| from openpi.policies import policy_config as _policy_config |
| from openpi.serving import websocket_policy_server |
| from openpi.training import config as _config |
|
|
|
|
class EnvMode(enum.Enum):
    """Target environments that a policy can be served for.

    Each member's value is its own name in lowercase.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
|
|
|
|
@dataclasses.dataclass
class Checkpoint:
    """Policy specification that points at a trained checkpoint."""

    # Name of the training config; resolved with `_config.get_config` when the policy is created.
    config: str
    # Location of the checkpoint, passed on to `create_trained_policy`
    # (the defaults in this file use `gs://` URIs).
    dir: str
|
|
|
|
@dataclasses.dataclass
class Default:
    """Marker type: serve the default policy registered for the selected environment."""
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command-line options accepted by the serve_policy script."""

    # Environment to serve for; only consulted when `policy` is `Default`.
    env: EnvMode = EnvMode.ALOHA_SIM

    # Forwarded as `default_prompt` when the policy is created.
    # NOTE(review): presumably used when a request carries no prompt -- confirm
    # against `create_trained_policy`.
    default_prompt: str | None = None

    # TCP port the websocket server is started on.
    port: int = 8000
    # When set, the policy is wrapped in a `PolicyRecorder` writing to "policy_records".
    record: bool = False

    # How to obtain the policy: an explicit `Checkpoint`, or `Default` to use the
    # checkpoint registered for `env`.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
|
|
|
|
| |
# Default checkpoint for each environment, written as (environment, config name,
# checkpoint location) rows and expanded into `Checkpoint` instances.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    env_mode: Checkpoint(config=config_name, dir=checkpoint_dir)
    for env_mode, config_name, checkpoint_dir in (
        (EnvMode.ALOHA, "pi0_aloha", "gs://openpi-assets/checkpoints/pi0_base"),
        (EnvMode.ALOHA_SIM, "pi0_aloha_sim", "gs://openpi-assets/checkpoints/pi0_aloha_sim"),
        (EnvMode.DROID, "pi0_fast_droid", "gs://openpi-assets/checkpoints/pi0_fast_droid"),
        (EnvMode.LIBERO, "pi0_fast_libero", "gs://openpi-assets/checkpoints/pi0_fast_libero"),
    )
}
|
|
|
|
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Build the policy registered as the default for ``env``.

    Args:
        env: Environment whose entry in ``DEFAULT_CHECKPOINT`` should be loaded.
        default_prompt: Forwarded unchanged to ``create_trained_policy``.

    Returns:
        The policy created from the environment's default checkpoint.

    Raises:
        ValueError: If ``env`` has no entry in ``DEFAULT_CHECKPOINT``.
    """
    checkpoint = DEFAULT_CHECKPOINT.get(env)
    if not checkpoint:
        raise ValueError(f"Unsupported environment mode: {env}")

    train_config = _config.get_config(checkpoint.config)
    return _policy_config.create_trained_policy(train_config, checkpoint.dir, default_prompt=default_prompt)
|
|
|
|
| def create_policy(args: Args) -> _policy.Policy: |
| """Create a policy from the given arguments.""" |
| match args.policy: |
| case Checkpoint(): |
| return _policy_config.create_trained_policy( |
| _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt |
| ) |
| case Default(): |
| return create_default_policy(args.env, default_prompt=args.default_prompt) |
|
|
|
|
def main(args: Args) -> None:
    """Create the requested policy and serve it over a websocket until stopped.

    Args:
        args: Parsed script arguments (policy selection, port, recording flag).
    """
    policy = create_policy(args)
    # Read the metadata from the policy returned by `create_policy`, before it is
    # possibly replaced by the recorder wrapper below.
    policy_metadata = policy.metadata

    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except OSError:
        # The address is only used for the log line below; a hostname that does not
        # resolve (socket.gaierror is an OSError) must not prevent the server from starting.
        local_ip = "unknown"
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata=policy_metadata,
    )
    server.serve_forever()
|
|
|
|
if __name__ == "__main__":
    # Configure logging before anything else runs; `force=True` replaces any handlers
    # already attached to the root logger.
    logging.basicConfig(force=True, level=logging.INFO)
    # Parse the command line into `Args` with tyro and start serving.
    main(tyro.cli(Args))
|
|