| |
| """Compare how v3 and v4 replay pipelines read multi_choice actions. |
| |
| v3 source: |
| - EpisodeDatasetResolver.get_step("multi_choice", step) |
| |
| v4-noresolver source: |
| - scripts.dataset_replay._build_action_sequence(..., "multi_choice") |
| - then _parse_oracle_command() in replay loop |
| """ |
|
|
| import argparse |
| import importlib.util |
| import json |
| import re |
| import sys |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
| import h5py |
| import numpy as np |
|
|
# Make the repository root and its src/ directory importable so this script
# can be run directly from the source tree (no package install required).
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
SRC_ROOT = REPO_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))
|
|
|
|
def _load_episode_dataset_resolver_cls():
    """Load EpisodeDatasetResolver straight from its source file.

    Importing by file path (rather than as a package module) avoids pulling in
    the package __init__ and whatever extra dependencies it drags along.

    Raises:
        RuntimeError: if the module cannot be loaded or lacks the class.
    """
    resolver_path = SRC_ROOT / "robomme" / "env_record_wrapper" / "episode_dataset_resolver.py"
    spec = importlib.util.spec_from_file_location(
        "episode_dataset_resolver_direct",
        resolver_path,
    )
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load resolver module from {resolver_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    cls = getattr(mod, "EpisodeDatasetResolver", None)
    if cls is None:
        raise RuntimeError(f"EpisodeDatasetResolver not found in {resolver_path}")
    return cls
|
|
|
|
# Resolve the class once at import time; every v3 read below uses this object.
EpisodeDatasetResolver = _load_episode_dataset_resolver_cls()

# CLI fallbacks used when the corresponding flags are omitted.
DEFAULT_ENV_ID = "PatternLock"
DEFAULT_DATASET_ROOT = "/data/hongzefu/data_0226-test"
|
|
|
|
| def _parse_oracle_command_v4(choice_action: Optional[Any]) -> Optional[dict[str, Any]]: |
| """Exact validation logic used in evaluate_dataset_replay-parallelv4-noresolver.py.""" |
| if not isinstance(choice_action, dict): |
| return None |
| choice = choice_action.get("choice") |
| if not isinstance(choice, str) or not choice.strip(): |
| return None |
| point = choice_action.get("point") |
| if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2: |
| return None |
| return choice_action |
|
|
|
|
def _is_video_demo_v4(ts: h5py.Group) -> bool:
    """Return True when the timestep's info group flags it as a video-demo frame.

    Missing "info" group or missing flag counts as False.
    """
    info = ts.get("info")
    if info is None or "is_video_demo" not in info:
        return False
    flag = np.asarray(info["is_video_demo"][()]).ravel()[0]
    return bool(flag)
|
|
|
|
def _is_subgoal_boundary_v4(ts: h5py.Group) -> bool:
    """Return True when the timestep's info group marks a subgoal boundary.

    Missing "info" group or missing flag counts as False.
    """
    info = ts.get("info")
    if info is None or "is_subgoal_boundary" not in info:
        return False
    flag = np.asarray(info["is_subgoal_boundary"][()]).ravel()[0]
    return bool(flag)
|
|
|
|
| def _decode_h5_str_v4(raw: Any) -> str: |
| if isinstance(raw, np.ndarray): |
| raw = raw.flatten()[0] |
| if isinstance(raw, (bytes, np.bytes_)): |
| raw = raw.decode("utf-8") |
| return raw |
|
|
|
|
def _build_multi_choice_sequence_v4(episode_data: h5py.Group) -> list[Any]:
    """Collect JSON-parsed choice_action payloads in timestep order.

    Mirrors dataset_replay._build_action_sequence(..., "multi_choice") without
    importing its cv2/imageio/torch dependencies. Timesteps that are video-demo
    frames, lack an action group, are not subgoal boundaries, have no
    choice_action entry, or fail JSON parsing are skipped.
    """
    ordered_keys = sorted(
        (name for name in episode_data.keys() if name.startswith("timestep_")),
        key=lambda name: int(name.split("_")[1]),
    )

    commands: list[Any] = []
    for name in ordered_keys:
        ts = episode_data[name]
        if _is_video_demo_v4(ts):
            continue
        action_grp = ts.get("action")
        if action_grp is None or not _is_subgoal_boundary_v4(ts):
            continue
        if "choice_action" not in action_grp:
            continue
        raw = _decode_h5_str_v4(action_grp["choice_action"][()])
        try:
            commands.append(json.loads(raw))
        except (TypeError, ValueError, json.JSONDecodeError):
            # Malformed payloads are dropped, matching the replay pipeline.
            continue
    return commands
|
|
|
|
| def _resolve_h5_path(env_id: str, dataset_root: Optional[str], h5_path: Optional[str]) -> Path: |
| if h5_path: |
| return Path(h5_path) |
| if not dataset_root: |
| raise ValueError("Either --h5_path or --dataset_root must be provided") |
| return Path(dataset_root) / f"record_dataset_{env_id}.h5" |
|
|
|
|
def _episode_indices(data: h5py.File) -> list[int]:
    """Return the sorted numeric ids of all top-level episode_<n> groups."""
    indices = []
    for key in data.keys():
        match = re.match(r"episode_(\d+)$", key)
        if match:
            indices.append(int(match.group(1)))
    return sorted(indices)
|
|
|
|
| def _parse_episode_filter(raw: Optional[str], all_eps: list[int]) -> list[int]: |
| if not raw: |
| return all_eps |
|
|
| selected: set[int] = set() |
| for token in [x.strip() for x in raw.split(",") if x.strip()]: |
| if "-" in token: |
| lo_s, hi_s = token.split("-", 1) |
| lo = int(lo_s) |
| hi = int(hi_s) |
| if lo > hi: |
| lo, hi = hi, lo |
| selected.update(range(lo, hi + 1)) |
| else: |
| selected.add(int(token)) |
|
|
| return [ep for ep in all_eps if ep in selected] |
|
|
|
|
| def _canonical_command(cmd: Any) -> str: |
| """Stable string form for diffing and readable output.""" |
| try: |
| return json.dumps(cmd, ensure_ascii=False, sort_keys=True) |
| except TypeError: |
| if isinstance(cmd, dict): |
| safe = { |
| str(k): (v.tolist() if isinstance(v, np.ndarray) else v) |
| for k, v in cmd.items() |
| } |
| return json.dumps(safe, ensure_ascii=False, sort_keys=True) |
| return repr(cmd) |
|
|
|
|
def _read_v4_commands(episode_group: h5py.Group) -> tuple[list[Any], list[dict[str, Any]], int]:
    """Read the v4 pipeline's view of an episode.

    Returns:
        (raw command list, commands passing v4 validation, count rejected by
        validation).
    """
    raw_list = _build_multi_choice_sequence_v4(episode_group)
    parsed_list: list[dict[str, Any]] = []
    skipped = 0

    for candidate in raw_list:
        validated = _parse_oracle_command_v4(candidate)
        if validated is not None:
            parsed_list.append(validated)
        else:
            skipped += 1

    return raw_list, parsed_list, skipped
|
|
|
|
def _read_v3_commands(env_id: str, episode: int, dataset_ref: str) -> list[dict[str, Any]]:
    """Drain multi_choice steps through EpisodeDatasetResolver until exhausted.

    Only dict-valued steps are kept; the step counter advances regardless, so
    non-dict entries are silently skipped (matching the v3 replay loop).
    """
    commands: list[dict[str, Any]] = []
    with EpisodeDatasetResolver(
        env_id=env_id,
        episode=episode,
        dataset_directory=dataset_ref,
    ) as resolver:
        step = 0
        while (cmd := resolver.get_step("multi_choice", step)) is not None:
            if isinstance(cmd, dict):
                commands.append(cmd)
            step += 1
    return commands
|
|
|
|
def compare_episode(
    env_id: str,
    episode: int,
    episode_group: h5py.Group,
    dataset_ref: str,
    max_show: int,
) -> None:
    """Print a side-by-side comparison of v4 and v3 multi_choice reads for one episode.

    Output: count summary, SAME/DIFFERENT verdict with up to max_show diff
    rows, then the first max_show entries of each sequence.
    """
    v4_raw, v4_effective, v4_skipped = _read_v4_commands(episode_group)
    v3_resolver = _read_v3_commands(env_id=env_id, episode=episode, dataset_ref=dataset_ref)

    print(f"\n=== episode_{episode} ===")
    summary = (
        "counts: "
        f"v4_raw={len(v4_raw)}, "
        f"v4_effective={len(v4_effective)} (skipped_by_parse={v4_skipped}), "
        f"v3_resolver={len(v3_resolver)}"
    )
    print(summary)

    # Compare canonicalized strings so dict key order / ndarrays don't matter.
    v4_effective_c = [_canonical_command(x) for x in v4_effective]
    v3_c = [_canonical_command(x) for x in v3_resolver]

    if v4_effective_c == v3_c:
        print("effective sequence compare: SAME")
    else:
        print("effective sequence compare: DIFFERENT")
        longest = max(len(v4_effective_c), len(v3_c))
        shown = 0
        for idx in range(longest):
            left = v4_effective_c[idx] if idx < len(v4_effective_c) else "<MISSING>"
            right = v3_c[idx] if idx < len(v3_c) else "<MISSING>"
            if left == right:
                continue
            print(f"  idx={idx}")
            print(f"    v4_effective: {left}")
            print(f"    v3_resolver : {right}")
            shown += 1
            if shown >= max_show:
                # remaining counts positions not yet examined, not differences.
                remaining = longest - idx - 1
                if remaining > 0:
                    print(f"  ... more differences omitted ({remaining} remaining positions)")
                break

    for label, seq in (
        ("v4_raw", v4_raw),
        ("v4_effective", v4_effective),
        ("v3_resolver", v3_resolver),
    ):
        print(f"sample {label} (first {max_show}):")
        for i, item in enumerate(seq[:max_show]):
            print(f"  [{i}] {_canonical_command(item)}")
|
|
|
|
def main() -> None:
    """CLI entry point: compare v3-resolver vs v4-noresolver multi_choice reads.

    Resolves the h5 file from the CLI flags, selects episodes, and prints a
    per-episode comparison report.

    Raises:
        FileNotFoundError: when the resolved h5 file does not exist.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compare multi_choice read results between "
            "evaluate_dataset_replay-parallelv3 and parallelv4-noresolver."
        )
    )
    parser.add_argument(
        "--env_id",
        type=str,
        default=DEFAULT_ENV_ID,
        help=f"Task/env id. Default: {DEFAULT_ENV_ID}",
    )
    parser.add_argument(
        "--dataset_root",
        type=str,
        default=DEFAULT_DATASET_ROOT,
        help=(
            "Directory that contains record_dataset_<env_id>.h5. "
            f"Default: {DEFAULT_DATASET_ROOT}"
        ),
    )
    parser.add_argument(
        "--h5_path",
        type=str,
        default=None,
        help="Direct path to .h5 file (overrides --dataset_root)",
    )
    parser.add_argument(
        "--episodes",
        type=str,
        # Fix: default was the int 0, contradicting type=str; it only meant
        # "all episodes" by accident of 0 being falsy. None states the intent.
        default=None,
        help="Episode filter, e.g. '0,3,8-10'. Default: all episodes in h5",
    )
    parser.add_argument(
        "--max_show",
        type=int,
        default=50,
        help="Max number of diff/sample rows per episode",
    )
    args = parser.parse_args()

    h5_file = _resolve_h5_path(args.env_id, args.dataset_root, args.h5_path)
    if not h5_file.exists():
        raise FileNotFoundError(f"h5 file not found: {h5_file}")

    # The v3 resolver accepts either the .h5 file itself or its directory.
    dataset_ref = str(h5_file) if h5_file.suffix == ".h5" else str(h5_file.parent)

    print(f"env_id={args.env_id}")
    print(f"h5={h5_file}")

    with h5py.File(h5_file, "r") as data:
        all_eps = _episode_indices(data)
        selected_eps = _parse_episode_filter(args.episodes, all_eps)

        if not selected_eps:
            print("No episodes selected.")
            return

        print(f"episodes={selected_eps}")
        for ep in selected_eps:
            key = f"episode_{ep}"
            if key not in data:
                print(f"\n=== episode_{ep} ===")
                print("missing in h5, skip")
                continue
            compare_episode(
                env_id=args.env_id,
                episode=ep,
                episode_group=data[key],
                dataset_ref=dataset_ref,
                max_show=args.max_show,
            )
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|