"""Reward helpers for SQLEnv dense shaping.""" from __future__ import annotations import hashlib import math try: from sql_env.models import EpisodeContext except ImportError: # pragma: no cover - Docker fallback import path from models import EpisodeContext # type: ignore[no-redef] _EXEC_OK_REWARD = 0.02 _NEW_INFO_REWARD = 0.01 _NEW_INFO_CAP = 0.10 _REPEAT_PENALTY = 0.01 _STEP_COST = 0.005 _LAYER2_CARDINALITY_WEIGHT = 0.25 _LAYER2_VALUE_OVERLAP_WEIGHT = 0.50 _LAYER2_NUMERIC_RANGE_WEIGHT = 0.25 _LAYER2_IMPROVEMENT_SCALE = 0.15 _STEP_REWARD_FLOOR = -0.2 _STEP_REWARD_CAP = 0.5 def compute_step_reward( ctx: EpisodeContext, action_type: str, sql: str, rows: list[tuple] | None, error: str | None, ) -> float: """Compute one dense step reward and clamp cumulative episode shaping. Combines Layer 1 operational shaping with Layer 2 progress shaping for successful QUERY actions, then clamps cumulative step reward to ``[-0.2, 0.5]`` and returns only the clamped delta for this step. """ step_reward = _layer1_operational(ctx, action_type, sql, rows, error) if action_type.upper() == "QUERY" and rows is not None and error is None: step_reward += _layer2_progress(ctx, rows) unclamped_total = ctx.cumulative_step_reward + step_reward clamped_total = min(_STEP_REWARD_CAP, max(_STEP_REWARD_FLOOR, unclamped_total)) clamped_delta = clamped_total - ctx.cumulative_step_reward ctx.cumulative_step_reward = clamped_total return clamped_delta def _layer1_operational( ctx: EpisodeContext, action_type: str, sql: str, rows: list[tuple] | None, error: str | None, ) -> float: """Compute Layer 1 operational reward signals. Layer 1 applies: - `+0.02` for successful execution (`error is None`) - `+0.01` new-info for first-seen successful QUERY (capped at 0.10 cumulative) - `-0.01` repeat penalty for repeated QUERY SQL - `-0.005` step cost on every call """ reward = -_STEP_COST is_query = action_type.upper() == "QUERY" query_hash: str | None = None is_repeat = False if is_query and sql: query_hash = hashlib.sha256(sql.encode("utf-8")).hexdigest() is_repeat = query_hash in ctx.query_hashes if is_repeat: reward -= _REPEAT_PENALTY elif error is None: reward += _EXEC_OK_REWARD if ( is_query and error is None and rows is not None and query_hash is not None and not is_repeat ): ctx.query_hashes.add(query_hash) if ctx.cumulative_new_info_reward < _NEW_INFO_CAP: remaining = _NEW_INFO_CAP - ctx.cumulative_new_info_reward delta = min(_NEW_INFO_REWARD, remaining) ctx.cumulative_new_info_reward += delta reward += delta return reward def _cardinality_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float: """Compute row-count similarity score in [0.0, 1.0].""" pred_count = len(pred_rows) gold_count = len(gold_rows) denominator = max(pred_count, gold_count, 1) score = 1.0 - (abs(pred_count - gold_count) / denominator) return max(0.0, min(1.0, score)) def _value_overlap_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float: """Compute Jaccard overlap of flattened cell values as strings.""" pred_values = {str(cell) for row in pred_rows for cell in row} gold_values = {str(cell) for row in gold_rows for cell in row} union = pred_values | gold_values if not union: return 0.0 intersection = pred_values & gold_values return len(intersection) / len(union) def _numeric_range_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float: """Compute log-distance proximity for numeric cell values.""" def _is_numeric(value: object) -> bool: return isinstance(value, (int, float)) and not isinstance(value, bool) pred_numerics = [float(cell) for row in pred_rows for cell in row if _is_numeric(cell)] gold_numerics = [float(cell) for row in gold_rows for cell in row if _is_numeric(cell)] if not gold_numerics: return 1.0 if not pred_numerics: return 0.0 total = 0.0 for gold_value in gold_numerics: closest_distance = min(abs(pred_value - gold_value) for pred_value in pred_numerics) total += 1.0 / (1.0 + math.log1p(closest_distance)) return total / len(gold_numerics) def _bin_progress(raw_score: float) -> float: """Bin raw progress to one of {0.0, 0.25, 0.5, 0.75, 1.0}.""" clamped_score = max(0.0, min(1.0, raw_score)) if clamped_score < 0.125: return 0.0 if clamped_score < 0.375: return 0.25 if clamped_score < 0.625: return 0.5 if clamped_score < 0.875: return 0.75 return 1.0 def _layer2_progress(ctx: EpisodeContext, rows: list[tuple]) -> float: """Compute Layer 2 progress reward with improvement-only gating.""" if not ctx.gold_rows: return 0.0 cardinality = _cardinality_score(rows, ctx.gold_rows) value_overlap = _value_overlap_score(rows, ctx.gold_rows) numeric_range = _numeric_range_score(rows, ctx.gold_rows) raw_progress = ( _LAYER2_CARDINALITY_WEIGHT * cardinality + _LAYER2_VALUE_OVERLAP_WEIGHT * value_overlap + _LAYER2_NUMERIC_RANGE_WEIGHT * numeric_range ) binned_progress = _bin_progress(raw_progress) if binned_progress <= ctx.best_progress: return 0.0 progress_delta = binned_progress - ctx.best_progress ctx.best_progress = binned_progress return progress_delta * _LAYER2_IMPROVEMENT_SCALE