| """ |
| MultiStepDemonstrationWrapper: Wraps DemonstrationWrapper to provide waypoint step interface. |
| |
| Each step(action) receives action = waypoint_p(3) + rpy(3) + gripper_action(1), total 7 dimensions. |
Internally converts the RPY angles to a quaternion, then executes move_to_pose_with_screw and close_gripper/open_gripper
through planner_denseStep; for the PatternLock/RouteStick environments the close_gripper/open_gripper calls are always skipped.
| Returns obs as dictionary-of-lists, and reward/terminated/truncated as the last step value. |
| Caller must ensure scripts/ is in sys.path to import planner_fail_safe. |
| """ |
| import numpy as np |
| import sapien |
| import torch |
| import gymnasium as gym |
|
|
| from ..robomme_env.utils import planner_denseStep |
| from ..robomme_env.utils.rpy_util import rpy_xyz_to_quat_wxyz_torch |
| from ..robomme_env.utils.planner_fail_safe import ScrewPlanFailure |
|
|
# Retry budgets for reaching a waypoint: screw planning is attempted first,
# then RRT* as a fallback, each up to this many times per waypoint.
DATASET_SCREW_MAX_ATTEMPTS = 3
DATASET_RRT_MAX_ATTEMPTS = 3
|
|
|
|
class RRTPlanFailure(RuntimeError):
    """Signals that RRT* planning (move_to_pose_with_RRTStar) reported failure, i.e. returned -1."""
|
|
|
|
class MultiStepDemonstrationWrapper(gym.Wrapper):
    """Waypoint-level wrapper around DemonstrationWrapper.

    ``step(action)`` interprets ``action`` as a 7-dim vector
    ``(waypoint_p[3], rpy[3], gripper_action[1])``.  The RPY part is converted
    to a wxyz quaternion, the waypoint is reached via screw planning (with an
    RRT* fallback), and the gripper command is executed afterwards — except in
    the PatternLock/RouteStick environments, where gripper commands are
    skipped.  The return mirrors the gym step API: ``obs``/``info`` aggregate
    every internal env step, while reward/terminated/truncated come from the
    last internal step.
    """

    def __init__(self, env, gui_render=True, vis=True, **kwargs):
        """Wrap ``env`` and expose an unbounded 7-dim waypoint action space.

        Args:
            env: The environment (a DemonstrationWrapper) to wrap.
            gui_render: Stored on the instance; not read elsewhere in this file.
            vis: Forwarded to the lazily created motion-planning solver.
            **kwargs: Accepted for call-site compatibility; ignored here.
        """
        super().__init__(env)
        self._planner = None  # built lazily in _get_planner(); dropped on reset()
        self._gui_render = gui_render
        self._vis = vis
        # waypoint position (3) + RPY (3) + gripper action (1)
        self.action_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64
        )

    @staticmethod
    def _batch_to_steps(batch):
        """Split a batched step tuple into a list of per-step 5-tuples.

        ``batch`` is ``(obs, reward, terminated, truncated, info)`` where each
        entry is indexable along the step dimension; the reward must be a
        torch tensor (``numel()`` is used to count steps).
        """
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        n = int(reward_batch.numel())  # number of env steps in this batch
        steps = []
        obs_keys = list(obs_batch.keys())
        info_keys = list(info_batch.keys())
        for idx in range(n):
            obs = {k: obs_batch[k][idx] for k in obs_keys}
            info = {k: info_batch[k][idx] for k in info_keys}
            reward = reward_batch[idx]
            terminated = terminated_batch[idx]
            truncated = truncated_batch[idx]
            steps.append((obs, reward, terminated, truncated, info))
        return steps

    @staticmethod
    def _flatten_info_batch(info_batch):
        """Collapse non-empty list values in ``info_batch`` to their last
        element; all other values (including empty lists) pass through."""
        return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

    @staticmethod
    def _take_last_step_value(value):
        """Return the last element of a batched value.

        Handles torch tensors, numpy arrays, lists, and tuples; 0-d arrays
        and empty containers are returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    def _get_planner(self):
        """Return the cached motion-planning solver, creating it on first use.

        PatternLock/RouteStick get the stick solver; every other env id gets
        the arm solver.  The import is deferred so scripts/ only needs to be
        on sys.path once planning is actually requested.
        """
        if self._planner is not None:
            return self._planner
        from ..robomme_env.utils.planner_fail_safe import (
            FailAwarePandaArmMotionPlanningSolver,
            FailAwarePandaStickMotionPlanningSolver,
        )

        env_id = self.env.unwrapped.spec.id
        base_pose = self.env.unwrapped.agent.robot.pose
        if env_id in ("PatternLock", "RouteStick"):
            self._planner = FailAwarePandaStickMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            self._planner = FailAwarePandaArmMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=True,
                print_env_info=False,
            )
        return self._planner

    def _current_tcp_p(self):
        """Return the current TCP position as a flat numpy array."""
        current_pose = self.env.unwrapped.agent.tcp.pose
        p = current_pose.p
        if hasattr(p, "cpu"):
            p = p.cpu().numpy()  # torch tensor -> host numpy before flattening
        p = np.asarray(p).flatten()
        return p

    def _no_op_step(self):
        """Execute one step using current qpos + gripper, without moving arm, only to get observation.

        NOTE(review): assumes indices 0-6 are the arm joints and index 7 is
        the gripper joint (Panda-style layout) — confirm against the robot.
        """
        robot = self.env.unwrapped.agent.robot
        qpos = robot.get_qpos().cpu().numpy().flatten()
        arm = qpos[:7]
        gripper = float(qpos[7]) if len(qpos) > 7 else 0.0
        action = np.hstack([arm, gripper])
        return self.env.step(action)

    def step(self, action):
        """Execute one waypoint action and return last-step gym signals.

        Args:
            action: Array-like of at least 7 floats:
                ``[x, y, z, roll, pitch, yaw, gripper]``.  Gripper convention
                (from the branches below): ``-1`` closes, ``1`` opens, any
                other value leaves the gripper untouched; the gripper command
                is skipped entirely for PatternLock/RouteStick.

        Returns:
            ``(obs_batch, reward, terminated, truncated, info)`` where
            ``obs_batch``/``info`` aggregate all internal env steps and the
            scalar signals come from the last internal step.

        Raises:
            ValueError: If ``action`` has fewer than 7 elements.
            RRTPlanFailure: If both screw and RRT* planning exhaust their
                retry budgets.
        """
        action = np.asarray(action, dtype=np.float64).flatten()
        if action.size < 7:
            raise ValueError(f"action must have at least 7 elements, got {action.size}")
        waypoint_p = action[:3]
        rpy = action[3:6]
        gripper_action = float(action[6])

        # RPY (xyz order) -> wxyz quaternion for sapien.Pose.
        rpy_t = torch.as_tensor(rpy, dtype=torch.float64)
        waypoint_q = rpy_xyz_to_quat_wxyz_torch(rpy_t).numpy()

        pose = sapien.Pose(p=waypoint_p, q=waypoint_q)
        planner = self._get_planner()
        is_stick_env = self.env.unwrapped.spec.id in ("PatternLock", "RouteStick")

        # NOTE(review): `dist` is computed but never read afterwards — confirm
        # whether a distance-based early-out was intended here.
        current_p = self._current_tcp_p()
        dist = np.linalg.norm(current_p - waypoint_p)

        collected_steps = []

        # Phase 1: screw planning, up to DATASET_SCREW_MAX_ATTEMPTS tries.
        # Failure is signalled either by ScrewPlanFailure or an int -1 result.
        move_steps = -1
        for attempt in range(1, DATASET_SCREW_MAX_ATTEMPTS + 1):
            try:
                result = planner_denseStep._collect_dense_steps(
                    planner, lambda: planner.move_to_pose_with_screw(pose)
                )
            except ScrewPlanFailure as exc:
                print(f"[MultiStep] screw planning failed (attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS}): {exc}")
                continue

            if isinstance(result, int) and result == -1:
                print(f"[MultiStep] screw planning returned -1 (attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS})")
                continue

            move_steps = result
            break

        # Phase 2: RRT* fallback, only if every screw attempt failed.
        if move_steps == -1:
            print(f"[MultiStep] screw planning exhausted; fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)")
            for attempt in range(1, DATASET_RRT_MAX_ATTEMPTS + 1):
                try:
                    result = planner_denseStep._collect_dense_steps(
                        planner, lambda: planner.move_to_pose_with_RRTStar(pose)
                    )
                except Exception as exc:  # best-effort fallback: log and retry
                    print(f"[MultiStep] RRT* planning failed (attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS}): {exc}")
                    continue

                if isinstance(result, int) and result == -1:
                    print(f"[MultiStep] RRT* planning returned -1 (attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS})")
                    continue

                move_steps = result
                break

        if move_steps == -1:
            raise RRTPlanFailure("Both screw and RRTStar planning exhausted.")
        collected_steps.extend(move_steps)

        # Phase 3: gripper command (never issued for stick environments).
        if not is_stick_env:
            if gripper_action == -1:
                if hasattr(planner, "close_gripper"):
                    result = planner_denseStep.close_gripper(planner)
                    if result != -1:
                        collected_steps.extend(self._batch_to_steps(result))
            elif gripper_action == 1:
                if hasattr(planner, "open_gripper"):
                    result = planner_denseStep.open_gripper(planner)
                    if result != -1:
                        collected_steps.extend(self._batch_to_steps(result))

        # Re-batch all collected steps and emit last-step scalar signals.
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = planner_denseStep.to_step_batch(
            collected_steps
        )
        info_flat = self._flatten_info_batch(info_batch)
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            info_flat,
        )

    def reset(self, **kwargs):
        """Reset the wrapped env; drop the planner so it is rebuilt lazily."""
        self._planner = None
        return self.env.reset(**kwargs)

    def close(self):
        """Close the wrapped env; drop the cached planner reference."""
        self._planner = None
        return self.env.close()
|
|