Spaces:
Running
Running
| """State manager for CodeReviewEnv episodes.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional, Set | |
| from env.models import CodeReviewAction, GroundTruthBug, ReviewComment | |
@dataclass
class StateManager:
    """Track the full episode state for a single task run.

    The class is a dataclass: ``task_id`` is the only required constructor
    argument and every mutable container gets a fresh per-instance default.

    Attributes:
        task_id: Identifier of the task this episode belongs to.
        step_number: Number of the step about to be recorded. Starts at 1
            and is incremented at the end of every ``record_action`` call.
        comments: Review comments added so far, in insertion order.
        correctly_identified_bug_lines: Line numbers credited as found bugs.
        false_positives: Count of actions recorded as false positives.
        red_herring_flags: Count of actions recorded as red-herring flags.
        cumulative_reward: Unclamped sum of all recorded step rewards.
        done: True once a terminal operation has been recorded.
        last_error: Error passed to the most recent ``record_action`` call
            (reset to None when that call carried no error).
        calibration_events: One dict per recorded ``add_comment`` action.
        explanation_depths: Explanation depth label keyed by bug line.
        injection_resistance: Result of the last
            ``compute_injection_resistance`` call; None until computed or
            when there were no injected lines.
    """

    task_id: str
    step_number: int = 1
    comments: List[ReviewComment] = field(default_factory=list)
    correctly_identified_bug_lines: Set[int] = field(default_factory=set)
    false_positives: int = 0
    red_herring_flags: int = 0
    cumulative_reward: float = 0.0
    done: bool = False
    last_error: Optional[str] = None

    # Upgrade 1: Calibration tracking
    calibration_events: List[dict] = field(default_factory=list)

    # Upgrade 2: Explanation depth tracking per found bug
    explanation_depths: Dict[int, str] = field(default_factory=dict)

    # Upgrade 3: Injection resistance tracking
    injection_resistance: Optional[bool] = None

    def record_action(
        self,
        action: CodeReviewAction,
        reward: float,
        *,
        new_comment: Optional[ReviewComment] = None,
        correctly_identified_bug_line: Optional[int] = None,
        is_false_positive: bool = False,
        is_red_herring_flag: bool = False,
        error: Optional[str] = None,
        confidence_modifier: float = 0.0,
        explanation_depth: Optional[str] = None,
    ) -> None:
        """Record an action outcome into state.

        Args:
            action: The action applied. Only ``action.operation`` and the
                optional ``action.confidence`` attribute are read here.
            reward: Scalar reward returned for the step.
            new_comment: If action added a comment, the created ReviewComment.
            correctly_identified_bug_line: Bug line number credited as found
                (if any).
            is_false_positive: Whether the action counted as a false positive.
            is_red_herring_flag: Whether the action flagged a red herring.
            error: Error message (if any); replaces the previous
                ``last_error`` even when None.
            confidence_modifier: Upgrade 1 — calibration modifier applied.
            explanation_depth: Upgrade 2 — depth of explanation for this bug.
                Ignored unless ``correctly_identified_bug_line`` is given.
        """
        if new_comment is not None:
            self.comments.append(new_comment)

        if correctly_identified_bug_line is not None:
            self.correctly_identified_bug_lines.add(correctly_identified_bug_line)
            # Depth is keyed by bug line, so it can only be stored when a
            # line was actually credited.
            if explanation_depth is not None:
                self.explanation_depths[correctly_identified_bug_line] = explanation_depth

        if is_false_positive:
            self.false_positives += 1
        if is_red_herring_flag:
            self.red_herring_flags += 1

        # Upgrade 1: Record calibration event for add_comment actions
        if action.operation == "add_comment":
            was_correct = (
                correctly_identified_bug_line is not None
                and not is_false_positive
                and not is_red_herring_flag
            )
            self.calibration_events.append({
                "step": self.step_number,
                # Actions without a confidence attribute are recorded as None.
                "confidence": getattr(action, "confidence", None),
                "was_correct": was_correct,
                "modifier_applied": confidence_modifier,
            })

        self.cumulative_reward += reward
        self.last_error = error
        # Incremented after the calibration event so the event carries the
        # number of the step that was just recorded.
        self.step_number += 1

        # Terminal operations end the episode.
        if action.operation in {"done", "approve", "request_changes"}:
            self.done = True

    def get_correctly_found_bugs(self, ground_truth: List[GroundTruthBug]) -> List[GroundTruthBug]:
        """Return the list of ground-truth bugs correctly found so far.

        Args:
            ground_truth: All bugs for the current task.

        Returns:
            Subset of ground_truth whose line_number has been credited as
            found, excluding red herrings, ordered by ascending line number.
        """
        # If several entries share a line_number, the last one wins.
        by_line: Dict[int, GroundTruthBug] = {b.line_number: b for b in ground_truth}
        candidates = (
            by_line.get(line) for line in sorted(self.correctly_identified_bug_lines)
        )
        return [
            bug for bug in candidates
            if bug is not None and not bug.is_red_herring
        ]

    def get_false_positive_count(self) -> int:
        """Return the number of false positives recorded so far.

        Red-herring flags are included in the total.
        """
        return self.false_positives + self.red_herring_flags

    def compute_injection_resistance(self, ground_truth: List[GroundTruthBug], injected_lines: List[int]) -> None:
        """Upgrade 3: Determine if the agent resisted adversarial injections.

        Stores the result in ``self.injection_resistance``: None when there
        are no injected lines, otherwise True only if every injected line has
        been credited as a found bug.

        Args:
            ground_truth: All ground truth bugs. Not read by this method;
                kept in the signature for callers.
            injected_lines: Line numbers of real bugs that have adversarial
                comments above them.
        """
        if not injected_lines:
            self.injection_resistance = None
            return

        # Check if the agent found the real bugs on injected lines
        self.injection_resistance = all(
            line in self.correctly_identified_bug_lines for line in injected_lines
        )

    def get_explanation_depth_distribution(self) -> Dict[str, int]:
        """Upgrade 2: Return distribution of explanation depths.

        Returns:
            Counts for the labels "deep", "technical", "shallow" and
            "missing". Recorded depths with any other label are not counted.
        """
        dist = {"deep": 0, "technical": 0, "shallow": 0, "missing": 0}
        for depth in self.explanation_depths.values():
            if depth in dist:
                dist[depth] += 1
        return dist

    def to_dict(self) -> dict:
        """Serialize current state to a plain dictionary for the /state endpoint.

        ``running_score`` is the cumulative reward clamped to
        [0.001, 0.999]; ``false_positives`` is the combined count returned by
        ``get_false_positive_count`` (red-herring flags included).
        """
        return {
            "task_id": self.task_id,
            "step_number": self.step_number,
            "comments": [c.model_dump() for c in self.comments],
            "running_score": max(0.001, min(0.999, self.cumulative_reward)),
            "bugs_found": len(self.correctly_identified_bug_lines),
            "false_positives": self.get_false_positive_count(),
            "red_herring_flags": self.red_herring_flags,
            "done": self.done,
            "last_error": self.last_error,
            "calibration_events": self.calibration_events,
            "explanation_depth_distribution": self.get_explanation_depth_distribution(),
            "injection_resistance": self.injection_resistance,
        }