| import gymnasium as gym |
| import numpy as np |
| import torch |
|
|
| from robomme.robomme_env.utils.vqa_options import get_vqa_options |
| from mani_skill.examples.motionplanning.panda.motionplanner import ( |
| PandaArmMotionPlanningSolver, |
| ) |
| from mani_skill.examples.motionplanning.panda.motionplanner_stick import ( |
| PandaStickMotionPlanningSolver, |
| ) |
| from ..robomme_env.utils import planner_denseStep |
| from ..robomme_env.utils.oracle_action_matcher import ( |
| find_exact_label_option_index, |
| ) |
| from ..robomme_env.utils.choice_action_mapping import select_target_with_pixel |
| from ..logging_utils import logger |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
class OraclePlannerDemonstrationWrapper(gym.Wrapper):
    """
    Gym wrapper that drives a Robomme environment with oracle motion planning
    for demonstration/evaluation.

    ``step`` takes a ``command_dict`` holding a ``"choice"`` label and an
    optional pixel ``"point"``. It returns ``obs`` as a dict-of-lists and
    ``reward``/``terminated``/``truncated`` as last-step values.
    """

    # Sentinel returned by ``_try_planning`` when every attempt failed. A
    # dedicated object is used so that any value a planner returns (including
    # ``None``) still counts as a success, exactly as long as it is not int -1.
    _PLAN_FAILED = object()

    def __init__(self, env, env_id, gui_render=True):
        """
        Args:
            env: Environment to wrap.
            env_id: Environment identifier; selects the planner type in
                ``reset`` and is forwarded to ``get_vqa_options``.
            gui_render: Forwarded to the planner as its ``vis`` flag.
        """
        super().__init__(env)
        self.env_id = env_id
        self.gui_render = gui_render

        # The planner is (re)created on every ``reset``.
        self.planner = None
        self.language_goal = None

        # Options exposed to callers through ``info["available_multi_choices"]``.
        self.available_options = []
        self._oracle_screw_max_attempts = 3
        self._oracle_rrt_max_attempts = 3

        # Last valid front-camera data seen in obs/info; consumed when a pixel
        # point has to be mapped onto a candidate target.
        self._front_camera_intrinsic_cv = None
        self._front_camera_extrinsic_cv = None
        self._front_rgb_shape = None

        # Placeholder spaces: both are declared as empty Dict spaces.
        self.action_space = gym.spaces.Dict({})
        self.observation_space = gym.spaces.Dict({})

    def _try_planning(self, plan_fn, label, max_attempts, failure_exc, args, kwargs):
        """
        Call ``plan_fn(*args, **kwargs)`` up to ``max_attempts`` times.

        An attempt fails when ``plan_fn`` raises ``failure_exc`` or returns the
        integer ``-1``; each failure is logged at debug level using ``label``.

        Returns:
            The first successful result, or ``self._PLAN_FAILED`` when all
            attempts failed (or ``max_attempts`` is not positive).
        """
        for attempt in range(1, max_attempts + 1):
            try:
                result = plan_fn(*args, **kwargs)
            except failure_exc as exc:
                logger.debug(
                    f"[OraclePlannerWrapper] {label} planning failed "
                    f"(attempt {attempt}/{max_attempts}): {exc}"
                )
                continue

            if isinstance(result, int) and result == -1:
                logger.debug(
                    f"[OraclePlannerWrapper] {label} planning returned -1 "
                    f"(attempt {attempt}/{max_attempts})"
                )
                continue

            return result
        return self._PLAN_FAILED

    def _wrap_planner_with_screw_then_rrt_retry(self, planner, screw_failure_exc):
        """
        Replace ``planner.move_to_pose_with_screw`` with a retrying variant.

        The replacement first retries the original screw planner (up to
        ``self._oracle_screw_max_attempts`` times), then falls back to the
        planner's ``move_to_pose_with_RRTStar`` with the same arguments (up to
        ``self._oracle_rrt_max_attempts`` times).

        Args:
            planner: Planner object to patch in place.
            screw_failure_exc: Exception type (or tuple of types) that marks a
                failed screw attempt. Any ``Exception`` marks a failed RRT*
                attempt.

        Returns:
            The same ``planner`` object, patched.

        Raises:
            RuntimeError: Raised by the patched method (not by this call) when
                both the screw and RRT* attempts are exhausted.
        """
        # Capture the unpatched bound methods before overriding one of them.
        original_screw = planner.move_to_pose_with_screw
        original_rrt = planner.move_to_pose_with_RRTStar

        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            result = self._try_planning(
                original_screw,
                "screw",
                self._oracle_screw_max_attempts,
                screw_failure_exc,
                args,
                kwargs,
            )
            if result is not self._PLAN_FAILED:
                return result

            logger.debug(
                "[OraclePlannerWrapper] screw planning exhausted; "
                f"fallback to RRT* (max {self._oracle_rrt_max_attempts} attempts)"
            )
            # The fallback deliberately treats any exception as a failed attempt.
            result = self._try_planning(
                original_rrt,
                "RRT*",
                self._oracle_rrt_max_attempts,
                Exception,
                args,
                kwargs,
            )
            if result is not self._PLAN_FAILED:
                return result

            raise RuntimeError(
                "[OraclePlannerWrapper] screw->RRT* planning exhausted; "
                f"screw_attempts={self._oracle_screw_max_attempts}, "
                f"rrt_attempts={self._oracle_rrt_max_attempts}"
            )

        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry
        return planner

    def reset(self, **kwargs):
        """
        Build a fresh planner, reset the wrapped env and refresh the options.

        Returns:
            ``(obs, info)``. When ``info`` is a dict it additionally carries
            ``"available_multi_choices"``.
        """
        # Prefer the fail-aware planners; fall back to the base planners if the
        # module cannot be imported for any reason (deliberate best-effort).
        try:
            from ..robomme_env.utils.planner_fail_safe import (
                FailAwarePandaArmMotionPlanningSolver,
                FailAwarePandaStickMotionPlanningSolver,
                ScrewPlanFailure,
            )
        except Exception as exc:
            logger.debug(
                "[OraclePlannerWrapper] Warning: failed to import planner_fail_safe, "
                f"fallback to base planners: {exc}"
            )
            FailAwarePandaArmMotionPlanningSolver = PandaArmMotionPlanningSolver
            FailAwarePandaStickMotionPlanningSolver = PandaStickMotionPlanningSolver

            class ScrewPlanFailure(RuntimeError):
                """Placeholder exception type when fail-aware planner import is unavailable."""

        # NOTE(review): the planner is constructed (and reads the robot pose)
        # before ``self.env.reset`` is called — confirm this order is intended.
        if self.env_id in ("PatternLock", "RouteStick"):
            self.planner = FailAwarePandaStickMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self.gui_render,
                base_pose=self.env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            self.planner = FailAwarePandaArmMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self.gui_render,
                base_pose=self.env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )
        self._wrap_planner_with_screw_then_rrt_retry(
            self.planner,
            screw_failure_exc=ScrewPlanFailure,
        )

        ret = self.env.reset(**kwargs)
        # Accept both the ``(obs, info)`` tuple and a bare ``obs`` return.
        if isinstance(ret, tuple) and len(ret) == 2:
            obs, info = ret
        else:
            obs, info = ret, {}

        self._update_front_camera_cache(obs_like=obs, info_like=info)
        self._build_step_options()
        if isinstance(info, dict):
            info["available_multi_choices"] = self.available_options
        return obs, info

    @staticmethod
    def _flatten_info_batch(info_batch: dict) -> dict:
        """Return a copy of ``info_batch`` where each non-empty list is replaced by its last item."""
        return {
            key: value[-1] if isinstance(value, list) and value else value
            for key, value in info_batch.items()
        }

    @staticmethod
    def _take_last_step_value(value):
        """
        Return the last element of a tensor/array/list/tuple.

        Empty containers and 0-d tensors/arrays are returned unchanged, as is
        any value of another type. Tensors and arrays are flattened first, so
        the result is the last element in flattened order.
        """
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    @staticmethod
    def _to_numpy(value):
        """Convert ``value`` to a NumPy array (tensors are detached and moved to CPU); ``None`` stays ``None``."""
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _take_last_columnar(value):
        """Return the last item of a list (``None`` for an empty list); non-lists pass through."""
        if isinstance(value, list):
            return value[-1] if value else None
        return value

    @classmethod
    def _normalize_intrinsic_cv(cls, intrinsic_like):
        """
        Coerce ``intrinsic_like`` into a 3x3 float64 matrix.

        The input is flattened and its first 9 values are used. Returns
        ``None`` when the input is ``None``, has fewer than 9 values, or
        contains non-finite values.
        """
        intrinsic = cls._to_numpy(intrinsic_like)
        if intrinsic is None:
            return None
        intrinsic = intrinsic.reshape(-1)
        if intrinsic.size < 9:
            return None
        intrinsic = intrinsic[:9].reshape(3, 3)
        if not np.all(np.isfinite(intrinsic)):
            return None
        return intrinsic.astype(np.float64, copy=False)

    @classmethod
    def _normalize_extrinsic_cv(cls, extrinsic_like):
        """
        Coerce ``extrinsic_like`` into a 3x4 float64 matrix.

        The input is flattened and its first 12 values are used. Returns
        ``None`` when the input is ``None``, has fewer than 12 values, or
        contains non-finite values.
        """
        extrinsic = cls._to_numpy(extrinsic_like)
        if extrinsic is None:
            return None
        extrinsic = extrinsic.reshape(-1)
        if extrinsic.size < 12:
            return None
        extrinsic = extrinsic[:12].reshape(3, 4)
        if not np.all(np.isfinite(extrinsic)):
            return None
        return extrinsic.astype(np.float64, copy=False)

    def _update_front_camera_cache(self, obs_like=None, info_like=None):
        """
        Refresh the cached front-camera image shape, extrinsic and intrinsic.

        Reads ``"front_rgb_list"`` and ``"front_camera_extrinsic_list"`` from
        ``obs_like`` and ``"front_camera_intrinsic"`` from ``info_like``
        (non-dict inputs are treated as empty). A cached value is only
        overwritten when the new value is usable, so earlier valid data
        survives steps that do not provide it.
        """
        obs_dict = obs_like if isinstance(obs_like, dict) else {}
        info_dict = info_like if isinstance(info_like, dict) else {}

        front_rgb = self._take_last_columnar(obs_dict.get("front_rgb_list"))
        front_rgb_np = self._to_numpy(front_rgb)
        if front_rgb_np is not None and front_rgb_np.ndim >= 2:
            # Only the first two dimensions are kept
            # (presumably height and width — TODO confirm).
            self._front_rgb_shape = tuple(front_rgb_np.shape[:2])

        front_extrinsic = self._take_last_columnar(
            obs_dict.get("front_camera_extrinsic_list")
        )
        front_extrinsic_np = self._normalize_extrinsic_cv(front_extrinsic)
        if front_extrinsic_np is not None:
            self._front_camera_extrinsic_cv = front_extrinsic_np

        front_intrinsic = self._take_last_columnar(
            info_dict.get("front_camera_intrinsic")
        )
        front_intrinsic_np = self._normalize_intrinsic_cv(front_intrinsic)
        if front_intrinsic_np is not None:
            self._front_camera_intrinsic_cv = front_intrinsic_np

    @staticmethod
    def _empty_target():
        """Return a fresh target record with every field set to ``None``."""
        return {
            "obj": None,
            "name": None,
            "seg_id": None,
            "position": None,
            "match_distance": None,
            "selection_mode": None,
        }

    def _build_step_options(self):
        """
        Query the current options and refresh ``self.available_options``.

        Returns:
            ``(selected_target, solve_options)`` where ``selected_target`` is
            the fresh target record that was handed to ``get_vqa_options``
            (presumably shared by reference with the options' solvers —
            TODO confirm) and ``solve_options`` is its raw result.
        """
        selected_target = self._empty_target()
        solve_options = get_vqa_options(
            self.env, self.planner, selected_target, self.env_id
        )
        # NOTE(review): ``need_parameter`` is the truthiness of "available",
        # whereas ``step`` requires a point whenever the key is present.
        self.available_options = [
            {
                "label": opt.get("label"),
                "action": opt.get("action", "Unknown"),
                "need_parameter": bool(opt.get("available")),
            }
            for opt in solve_options
        ]
        return selected_target, solve_options

    def _resolve_command(self, command_dict, solve_options):
        """
        Translate a command dict into an option index and an optional pixel.

        Args:
            command_dict: Expected to be a dict with a non-blank string
                ``"choice"`` and, optionally, ``"point"`` given as ``[y, x]``.
            solve_options: Options to match the choice label against.

        Returns:
            ``(None, None)`` when the command is malformed or the label is not
            found; otherwise ``(index, pixel)`` where ``pixel`` is ``[x, y]``
            rounded to integers, or ``None`` when no usable point was given.
        """
        if not isinstance(command_dict, dict):
            return None, None
        if "choice" not in command_dict:
            return None, None

        target_choice = command_dict.get("choice")
        if not isinstance(target_choice, str):
            return None, None
        target_choice = target_choice.strip()
        if not target_choice:
            return None, None
        target_label = target_choice.lower()

        found_idx = find_exact_label_option_index(target_label, solve_options)
        if found_idx == -1:
            logger.debug(
                f"Error: Choice '{target_choice}' not found in current options by exact label match."
            )
            return None, None

        point = command_dict.get("point")
        if point is None:
            return found_idx, None
        if not isinstance(point, (list, tuple, np.ndarray)):
            return found_idx, None
        try:
            # ``len`` is inside the guard because a 0-d ndarray raises
            # TypeError here; it is treated like any other unusable point.
            if len(point) < 2:
                return found_idx, None
            y = float(point[0])
            x = float(point[1])
        except (TypeError, ValueError):
            return found_idx, None
        if not np.isfinite(x) or not np.isfinite(y):
            return found_idx, None

        # Input order is [y, x]; the returned pixel is [x, y].
        return found_idx, [int(np.rint(x)), int(np.rint(y))]

    def _apply_position_target(self, selected_target, option, target_pixel):
        """
        Fill ``selected_target`` in place with the candidate picked by pixel.

        Does nothing when ``target_pixel`` is ``None`` or when
        ``select_target_with_pixel`` returns ``None``.
        """
        if target_pixel is None:
            return

        best_cand = select_target_with_pixel(
            available=option.get("available"),
            pixel_like=target_pixel,
            intrinsic_cv=self._front_camera_intrinsic_cv,
            extrinsic_cv=self._front_camera_extrinsic_cv,
            image_shape=self._front_rgb_shape,
        )
        if best_cand is not None:
            selected_target.update(best_cand)

    def _execute_selected_option(self, option_idx, solve_options):
        """
        Run the chosen option's ``"solve"`` callable with dense collection.

        Returns:
            Whatever ``planner_denseStep._run_with_dense_collection`` returns.

        Raises:
            RuntimeError: If that result equals ``-1``.
        """
        option = solve_options[option_idx]
        logger.debug(f"Executing option: {option_idx + 1} - {option.get('action')}")

        result = planner_denseStep._run_with_dense_collection(
            self.planner,
            lambda: option.get("solve")(),
        )
        if result == -1:
            action_text = option.get("action", "Unknown")
            raise RuntimeError(
                f"Oracle solve failed after screw->RRT* retries for env '{self.env_id}', "
                f"action '{action_text}' (index {option_idx + 1})."
            )
        return result

    def _post_eval(self):
        """Run the environment's evaluation after an option and log the result."""
        # NOTE(review): the first call's return value is discarded; presumably
        # it is made for its side effects before the final evaluation — confirm.
        self.env.unwrapped.evaluate()
        evaluation = self.env.unwrapped.evaluate(solve_complete_eval=True)
        logger.debug(f"Evaluation result: {evaluation}")

    def _format_step_output(self, batch):
        """
        Convert a step batch into the 5-tuple returned by ``step``.

        ``obs`` is passed through unchanged; ``reward``/``terminated``/
        ``truncated`` are reduced to their last value; list-valued ``info``
        entries are reduced to their last item and
        ``"available_multi_choices"`` is attached.
        """
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        self._update_front_camera_cache(obs_like=obs_batch, info_like=info_batch)
        info_flat = self._flatten_info_batch(info_batch)
        info_flat["available_multi_choices"] = getattr(self, "available_options", [])
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            info_flat,
        )

    def step(self, action):
        """
        Execute one step: action is command_dict, must contain "choice", optional
        pixel `point=[y, x]` in front_rgb.
        Return last-step signals for reward/terminated/truncated while keeping obs as dict-of-lists.

        Raises:
            ValueError: If the chosen option has an ``"available"`` entry but
                no point was supplied, or the point matched no candidate.
            RuntimeError: If executing the chosen option fails.
        """
        # Options are rebuilt on every step so they reflect the current state.
        selected_target, solve_options = self._build_step_options()
        found_idx, target_pixel = self._resolve_command(action, solve_options)

        # Unresolvable command: return an empty batch without touching the env.
        if found_idx is None:
            return self._format_step_output(planner_denseStep.empty_step_batch())

        option = solve_options[found_idx]
        self._apply_position_target(
            selected_target=selected_target,
            option=option,
            target_pixel=target_pixel,
        )

        # An option that lists candidates needs a successfully matched target.
        requires_target = "available" in option
        if requires_target:
            if target_pixel is None:
                raise ValueError(
                    f"Multi-choice action '{option.get('action', 'Unknown')}' requires "
                    "a target pixel point=[y, x], but command did not provide it."
                )
            if selected_target.get("obj") is None:
                raise ValueError(
                    f"Multi-choice action '{option.get('action', 'Unknown')}' could not match "
                    f"any available candidate from point={target_pixel}."
                )

        batch = self._execute_selected_option(found_idx, solve_options)
        self._post_eval()
        return self._format_step_output(batch)
|
|