| | import copy |
| | import json |
| | import random |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional, List |
| |
|
| |
|
@dataclass
class StepResult:
    """Transition record produced by one ``FlowDebugEnv.step()`` call."""

    # Observation dict describing the environment after the action was applied.
    obs: Dict[str, Any]
    # Scalar reward for the action (the environment emits 1.0, -0.1 or -0.2).
    reward: float
    # Episode-termination flag; once True the caller should reset().
    done: bool
    # Diagnostics such as the "result" label, "case_id" and, for invalid
    # actions, a "message".
    info: Dict[str, Any]
| |
|
| |
|
class FlowDebugEnv:
    """Minimal single-action debugging environment for OpenEnv.

    Each episode samples one case from ``cases`` (deep-copied) and exposes it
    as an observation dict. The only supported action is::

        {"action": "patch_step", "step": <step name>,
         "field": "inputs.expression", "value": <new expression>}

    which overwrites ``inputs.expression`` on the named step. The episode is
    solved when step, field and value all equal the case's ``gold_fix``
    (plain ``==`` comparison, so e.g. whitespace differences count as wrong).

    Every ``step()`` call consumes one attempt, including invalid actions.
    Rewards are ``1.0`` on success, ``-0.1`` for a wrong patch or an invalid
    action, and ``-0.2`` for a wrong (but applicable) patch on the last attempt.

    Each case must provide the keys ``case_id``, ``steps`` (a list of dicts
    with at least ``name``), ``error``, ``failed_step`` and ``gold_fix``
    (with ``step``, ``field`` and ``value``). These are the keys read below.
    """

    # Sentinel for _make_observation meaning "use the value stored in the case".
    _KEEP = object()

    def __init__(self, cases: List[Dict[str, Any]], max_attempts: int = 3, seed: Optional[int] = None):
        """Create the environment.

        Args:
            cases: Non-empty list of case dicts to sample episodes from. The
                list is kept by reference; ``reset()`` deep-copies the case it
                picks, so episodes never mutate these dicts.
            max_attempts: Number of ``step()`` calls allowed per episode (>= 1).
            seed: Optional seed for the case-sampling RNG.

        Raises:
            ValueError: If ``cases`` is empty or ``max_attempts`` is below 1.
        """
        if not cases:
            raise ValueError("cases must be a non-empty list of case dicts")
        if max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")
        self.cases = cases
        self.max_attempts = max_attempts
        self.rng = random.Random(seed)
        self.current_case: Optional[Dict[str, Any]] = None
        self.attempts_left = max_attempts
        # True once an episode has ended; step() then requires reset().
        self._done = False

    @classmethod
    def from_json(cls, cases_json_path: str, max_attempts: int = 3, seed: Optional[int] = None) -> "FlowDebugEnv":
        """Build an environment from a UTF-8 JSON file containing a list of cases.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            ValueError: If the top-level JSON value is not a list, or is empty.
        """
        with Path(cases_json_path).open("r", encoding="utf-8") as f:
            cases = json.load(f)
        if not isinstance(cases, list):
            raise ValueError(
                f"{cases_json_path}: expected a JSON list of cases, got {type(cases).__name__}"
            )
        return cls(cases=cases, max_attempts=max_attempts, seed=seed)

    def reset(self) -> Dict[str, Any]:
        """Start a new episode on a randomly chosen case and return its observation."""
        # Deep copy so patches made during the episode never leak back into
        # self.cases (and therefore into later episodes).
        self.current_case = copy.deepcopy(self.rng.choice(self.cases))
        self.attempts_left = self.max_attempts
        self._done = False
        return self._make_observation()

    def step(self, action: Dict[str, Any]) -> "StepResult":
        """Apply one agent action and return the resulting transition.

        Args:
            action: Dict with ``"action": "patch_step"`` plus ``"step"``,
                ``"field"`` and ``"value"`` entries. Missing entries read as None.

        Returns:
            A StepResult whose ``info["result"]`` is one of ``"success"``,
            ``"still_failed"``, ``"out_of_attempts"`` or ``"invalid_action"``.

        Raises:
            RuntimeError: If called before ``reset()``, or after the episode
                has already ended (a previous step returned ``done=True``).
        """
        if self.current_case is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is over; call reset() before step().")

        # Every call consumes an attempt, including invalid actions.
        self.attempts_left -= 1

        if action.get("action") != "patch_step":
            return self._invalid_action("Unsupported action type")

        step_name = action.get("step")
        field = action.get("field")
        value = action.get("value")

        if not self._apply_patch(step_name, field, value):
            return self._invalid_action("Patch failed (step/field not found)")

        case_id = self.current_case["case_id"]
        gold = self.current_case["gold_fix"]
        solved = step_name == gold["step"] and field == gold["field"] and value == gold["value"]

        if solved:
            self._mark_success()
            self._done = True
            obs = self._make_observation(run_status="Succeeded", error=None, failed_step=None)
            return StepResult(obs=obs, reward=1.0, done=True,
                              info={"result": "success", "case_id": case_id})

        if self.attempts_left <= 0:
            self._done = True
            return StepResult(obs=self._make_observation(), reward=-0.2, done=True,
                              info={"result": "out_of_attempts", "case_id": case_id})

        return StepResult(obs=self._make_observation(), reward=-0.1, done=False,
                          info={"result": "still_failed", "case_id": case_id})

    def _apply_patch(self, step_name: Any, field: Any, value: Any) -> bool:
        """Write ``value`` into ``inputs.expression`` of the first step named ``step_name``.

        Returns False and changes nothing when ``field`` is not
        ``"inputs.expression"`` or no step has that name. ``value`` is stored
        as given and is not type-checked.
        """
        if field != "inputs.expression":
            return False
        for step in self.current_case["steps"]:
            if step["name"] == step_name:
                step.setdefault("inputs", {})["expression"] = value
                return True
        return False

    def _mark_success(self) -> None:
        """Set every step of the current case to status ``"Succeeded"``."""
        for step in self.current_case["steps"]:
            step["status"] = "Succeeded"

    def _make_observation(self, run_status: str = "Failed", error: Any = _KEEP,
                          failed_step: Any = _KEEP) -> Dict[str, Any]:
        """Build a snapshot observation of the current case.

        ``error`` and ``failed_step`` default to the values stored in the
        current case. Pass an explicit value (including None) to override them.
        The returned structures are deep copies, so mutating an observation
        cannot alter environment state, and later patches do not rewrite
        observations that were already returned.
        """
        err_obj = self.current_case["error"] if error is self._KEEP else error
        failed = self.current_case["failed_step"] if failed_step is self._KEEP else failed_step
        return {
            "case_id": self.current_case["case_id"],
            "run_status": run_status,
            "failed_step": copy.deepcopy(failed),
            "error": copy.deepcopy(err_obj),
            "steps": copy.deepcopy(self.current_case["steps"]),
            "attempts_left": self.attempts_left,
        }

    def _invalid_action(self, msg: str) -> "StepResult":
        """Build the penalty transition for an action that could not be applied.

        The attempt has already been consumed by ``step()``. The episode ends
        here only if that was the last attempt.
        """
        done = self.attempts_left <= 0
        self._done = done
        return StepResult(obs=self._make_observation(), reward=-0.1, done=done,
                          info={"result": "invalid_action", "message": msg,
                                "case_id": self.current_case["case_id"]})
| |
|