| """Reward helpers for SQLEnv dense shaping.""" |
|
|
| from __future__ import annotations |
|
|
| import hashlib |
| import math |
|
|
| try: |
| from sql_env.models import EpisodeContext |
| except ImportError: |
| from models import EpisodeContext |
|
|
|
|
# --- Layer 1: operational shaping -------------------------------------------
# Bonus for an action that executed without error (not granted on repeats).
_EXEC_OK_REWARD = 0.02
# Bonus for each first-seen successful QUERY, and the per-episode cap on the
# cumulative amount of that bonus.
_NEW_INFO_REWARD = 0.01
_NEW_INFO_CAP = 0.10
# Penalty for re-issuing a QUERY whose SQL hash was already recorded.
_REPEAT_PENALTY = 0.01
# Flat cost charged on every step.
_STEP_COST = 0.005

# --- Layer 2: progress shaping ----------------------------------------------
# Weights of the three similarity components; they sum to 1.0 so the blended
# raw progress stays in [0.0, 1.0].
_LAYER2_CARDINALITY_WEIGHT = 0.25
_LAYER2_VALUE_OVERLAP_WEIGHT = 0.50
_LAYER2_NUMERIC_RANGE_WEIGHT = 0.25
# Multiplier applied to an improvement in binned progress.
_LAYER2_IMPROVEMENT_SCALE = 0.15

# --- Bounds on the cumulative step reward of one episode --------------------
_STEP_REWARD_FLOOR = -0.2
_STEP_REWARD_CAP = 0.5
|
|
|
|
def compute_step_reward(
    ctx: EpisodeContext,
    action_type: str,
    sql: str,
    rows: list[tuple] | None,
    error: str | None,
) -> float:
    """Return the dense reward for one step, bounded at the episode level.

    The step reward is the Layer 1 operational signal, plus the Layer 2
    progress signal when the action is a QUERY that produced rows without an
    error. The running episode total is then clamped to
    ``[_STEP_REWARD_FLOOR, _STEP_REWARD_CAP]``.

    Args:
        ctx: Episode state. ``cumulative_step_reward`` is read and then
            overwritten with the clamped running total; the layer helpers
            update further fields.
        action_type: Action name, compared case-insensitively to ``"QUERY"``.
        sql: SQL text of the action, forwarded to Layer 1.
        rows: Result rows, or ``None`` when the action produced none.
        error: Error message, or ``None`` when execution succeeded.

    Returns:
        The change in the clamped running total caused by this step. This can
        be smaller in magnitude than the raw step reward once a bound is hit.
    """
    reward = _layer1_operational(ctx, action_type, sql, rows, error)

    query_succeeded = (
        action_type.upper() == "QUERY" and rows is not None and error is None
    )
    if query_succeeded:
        reward += _layer2_progress(ctx, rows)

    # Clamp the running total rather than the individual step, and report
    # only the part of the step that fit inside the bounds.
    previous_total = ctx.cumulative_step_reward
    bounded_total = min(
        _STEP_REWARD_CAP,
        max(_STEP_REWARD_FLOOR, previous_total + reward),
    )
    ctx.cumulative_step_reward = bounded_total
    return bounded_total - previous_total
|
|
|
|
def _layer1_operational(
    ctx: EpisodeContext,
    action_type: str,
    sql: str,
    rows: list[tuple] | None,
    error: str | None,
) -> float:
    """Return the Layer 1 (operational) reward for a single action.

    Signals, applied in this order:
    - ``-_STEP_COST`` on every call.
    - ``-_REPEAT_PENALTY`` when the action is a QUERY whose SQL hash is
      already in ``ctx.query_hashes``; otherwise ``+_EXEC_OK_REWARD`` when
      ``error is None``. A repeated query never earns the execution bonus.
    - ``+_NEW_INFO_REWARD`` for a first-seen successful QUERY that produced
      rows, limited so the episode-wide sum never exceeds ``_NEW_INFO_CAP``.

    Side effects: a first-seen successful QUERY has its hash added to
    ``ctx.query_hashes`` and may increase ``ctx.cumulative_new_info_reward``.
    """
    succeeded = error is None

    # Only non-empty QUERY SQL is fingerprinted. The raw text is hashed as-is,
    # so differently formatted but equivalent SQL counts as a different query.
    digest: str | None = None
    if action_type.upper() == "QUERY" and sql:
        digest = hashlib.sha256(sql.encode("utf-8")).hexdigest()
    seen_before = digest is not None and digest in ctx.query_hashes

    total = -_STEP_COST
    if seen_before:
        total -= _REPEAT_PENALTY
    elif succeeded:
        total += _EXEC_OK_REWARD

    # The new-info bonus needs a fingerprinted, first-seen, successful query
    # that actually returned a row set.
    if digest is None or seen_before or not succeeded or rows is None:
        return total

    ctx.query_hashes.add(digest)
    if ctx.cumulative_new_info_reward < _NEW_INFO_CAP:
        headroom = _NEW_INFO_CAP - ctx.cumulative_new_info_reward
        bonus = min(_NEW_INFO_REWARD, headroom)
        ctx.cumulative_new_info_reward += bonus
        total += bonus

    return total
|
|
|
|
def _cardinality_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Return how closely the predicted row count matches the gold row count.

    The score is ``1 - |pred - gold| / max(pred, gold, 1)``, clamped to
    ``[0.0, 1.0]``: 1.0 for equal counts (including two empty results) and
    0.0 when exactly one side is empty.
    """
    n_pred, n_gold = len(pred_rows), len(gold_rows)
    # The trailing 1 keeps the divisor non-zero when both results are empty.
    largest = max(n_pred, n_gold, 1)
    similarity = 1.0 - (abs(n_pred - n_gold) / largest)
    return max(0.0, min(1.0, similarity))
|
|
|
|
def _value_overlap_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Return the Jaccard similarity of the cell values in both results.

    Every cell is converted with ``str`` and each result is flattened to a
    set, so row/column position and duplicates are ignored, and values are
    compared by their string form (``1`` matches ``"1"`` but not ``1.0``).
    Returns 0.0 when neither result contains any cell.
    """

    def _cell_strings(rows: list[tuple]) -> set[str]:
        return {str(cell) for row in rows for cell in row}

    predicted = _cell_strings(pred_rows)
    expected = _cell_strings(gold_rows)

    combined = predicted | expected
    if combined:
        return len(predicted & expected) / len(combined)
    return 0.0
|
|
|
|
def _numeric_range_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Return log-distance proximity of predicted to gold numeric cells.

    For every numeric gold cell the closest numeric predicted cell is found
    and scored as ``1 / (1 + log1p(distance))``; the result is the mean of
    those scores, which lies in ``[0.0, 1.0]``.

    ``int`` and ``float`` cells count as numeric; ``bool`` cells do not. NaN
    cells are ignored on both sides because NaN has no defined distance to
    anything: left in, one NaN could turn the whole score into NaN or make
    the result depend on cell order. Infinite cells are kept and match only
    an infinity of the same sign.

    Args:
        pred_rows: Rows returned by the agent's query.
        gold_rows: Rows of the reference answer.

    Returns:
        1.0 when the gold rows hold no comparable numeric cell (nothing to
        match), 0.0 when the gold rows do but the predicted rows do not,
        otherwise the mean proximity score.
    """

    def _numeric_cells(rows: list[tuple]) -> list[float]:
        """Collect non-NaN int/float cells (excluding bool) as floats."""
        values: list[float] = []
        for row in rows:
            for cell in row:
                if isinstance(cell, bool) or not isinstance(cell, (int, float)):
                    continue
                number = float(cell)
                if not math.isnan(number):
                    values.append(number)
        return values

    pred_numerics = _numeric_cells(pred_rows)
    gold_numerics = _numeric_cells(gold_rows)

    if not gold_numerics:
        return 1.0
    if not pred_numerics:
        return 0.0

    total = 0.0
    for gold_value in gold_numerics:
        # The equality check makes equal infinities a perfect match;
        # ``abs(inf - inf)`` would otherwise be NaN.
        closest_distance = min(
            0.0 if pred_value == gold_value else abs(pred_value - gold_value)
            for pred_value in pred_numerics
        )
        # log1p(inf) is inf, so an unreachable gold value contributes 0.0.
        total += 1.0 / (1.0 + math.log1p(closest_distance))

    return total / len(gold_numerics)
|
|
|
|
def _bin_progress(raw_score: float) -> float:
    """Round a raw progress score to the nearest of {0.0, 0.25, 0.5, 0.75, 1.0}.

    The score is clamped to ``[0.0, 1.0]`` first, so out-of-range and
    infinite inputs land in the lowest or highest bin. Bin edges sit halfway
    between levels (0.125, 0.375, 0.625, 0.875), and a score exactly on an
    edge goes to the higher bin.

    A NaN score is binned to 0.0. Without that guard ``min(1.0, nan)``
    evaluates to 1.0, so an undefined score would be rewarded as maximum
    progress.

    Args:
        raw_score: Unbinned progress estimate.

    Returns:
        The bin level for ``raw_score``.
    """
    # NaN fails every ordering comparison, so it must be handled before the
    # clamp; treat "no measurable progress" as the lowest bin.
    if math.isnan(raw_score):
        return 0.0

    clamped_score = max(0.0, min(1.0, raw_score))
    if clamped_score < 0.125:
        return 0.0
    if clamped_score < 0.375:
        return 0.25
    if clamped_score < 0.625:
        return 0.5
    if clamped_score < 0.875:
        return 0.75
    return 1.0
|
|
|
|
def _layer2_progress(ctx: EpisodeContext, rows: list[tuple]) -> float:
    """Return the Layer 2 reward, paid only when binned progress improves.

    Progress is the weighted blend of the cardinality, value-overlap and
    numeric-range similarities between ``rows`` and ``ctx.gold_rows``, binned
    by ``_bin_progress``. A reward is paid only when that bin exceeds
    ``ctx.best_progress``; it equals the improvement times
    ``_LAYER2_IMPROVEMENT_SCALE``, and ``ctx.best_progress`` is raised to the
    new bin. Returns 0.0 when ``ctx.gold_rows`` is empty or unset, or when
    there is no improvement.
    """
    gold = ctx.gold_rows
    if not gold:
        return 0.0

    blended = (
        _LAYER2_CARDINALITY_WEIGHT * _cardinality_score(rows, gold)
        + _LAYER2_VALUE_OVERLAP_WEIGHT * _value_overlap_score(rows, gold)
        + _LAYER2_NUMERIC_RANGE_WEIGHT * _numeric_range_score(rows, gold)
    )
    level = _bin_progress(blended)

    # Matching or falling below the best level so far earns nothing, so the
    # same progress cannot be rewarded twice.
    previous_best = ctx.best_progress
    if level <= previous_best:
        return 0.0

    ctx.best_progress = level
    return (level - previous_best) * _LAYER2_IMPROVEMENT_SCALE
|
|