# sql_env/server/reward.py
# (uploaded by hjerpe via huggingface_hub, commit 5dd1bb4)
"""Reward helpers for SQLEnv dense shaping."""
from __future__ import annotations

import bisect
import hashlib
import math

try:
    from sql_env.models import EpisodeContext
except ImportError:  # pragma: no cover - Docker fallback import path
    from models import EpisodeContext  # type: ignore[no-redef]
# --- Layer 1: operational shaping magnitudes -------------------------------
_EXEC_OK_REWARD: float = 0.02  # bonus when an action executes without error
_NEW_INFO_REWARD: float = 0.01  # bonus for each first-seen successful QUERY
_NEW_INFO_CAP: float = 0.10  # ceiling on the accumulated new-info bonus
_REPEAT_PENALTY: float = 0.01  # charged when identical QUERY SQL is re-run
_STEP_COST: float = 0.005  # charged on every step

# --- Layer 2: progress shaping ---------------------------------------------
# Blend weights for the three similarity components; they sum to 1.0, so the
# weighted raw progress stays within [0, 1].
_LAYER2_CARDINALITY_WEIGHT: float = 0.25
_LAYER2_VALUE_OVERLAP_WEIGHT: float = 0.50
_LAYER2_NUMERIC_RANGE_WEIGHT: float = 0.25
# Multiplier applied to the increase in binned progress over the best so far.
_LAYER2_IMPROVEMENT_SCALE: float = 0.15

# --- Bounds on the cumulative step-shaping total ---------------------------
_STEP_REWARD_FLOOR: float = -0.2
_STEP_REWARD_CAP: float = 0.5
def compute_step_reward(
    ctx: EpisodeContext,
    action_type: str,
    sql: str,
    rows: list[tuple] | None,
    error: str | None,
) -> float:
    """Return this step's dense shaping reward after cumulative clamping.

    The raw step reward is the Layer 1 operational signal. For a QUERY that
    produced rows without error, the Layer 2 progress signal is added on top.
    The running total stored in ``ctx.cumulative_step_reward`` is then clamped
    to ``[_STEP_REWARD_FLOOR, _STEP_REWARD_CAP]`` (``[-0.2, 0.5]``). Only the
    portion of the raw reward that survives the clamp is returned.

    Side effects: updates ``ctx.cumulative_step_reward``. The layer helpers
    also update their own bookkeeping fields on ``ctx``.
    """
    raw_delta = _layer1_operational(ctx, action_type, sql, rows, error)

    produced_rows = action_type.upper() == "QUERY" and rows is not None and error is None
    if produced_rows:
        raw_delta += _layer2_progress(ctx, rows)

    previous_total = ctx.cumulative_step_reward
    # Clamp the running total, not the individual step, so the episode-wide
    # shaping budget is bounded.
    bounded_total = min(
        _STEP_REWARD_CAP,
        max(_STEP_REWARD_FLOOR, previous_total + raw_delta),
    )
    ctx.cumulative_step_reward = bounded_total
    return bounded_total - previous_total
def _layer1_operational(
    ctx: EpisodeContext,
    action_type: str,
    sql: str,
    rows: list[tuple] | None,
    error: str | None,
) -> float:
    """Return the Layer 1 (operational) reward for one action.

    Components:
      - ``-0.005`` step cost, always charged.
      - ``-0.01`` if this QUERY's exact SQL text was already recorded as seen.
        Otherwise ``+0.02`` when ``error is None``, for any action type.
      - ``+0.01`` new-info bonus for a first-seen QUERY that succeeded and
        returned rows. The bonus is limited so ``ctx.cumulative_new_info_reward``
        never exceeds ``0.10``.

    Only successful, row-returning, first-seen queries have their SHA-256
    hash recorded in ``ctx.query_hashes``. A failing query is therefore never
    marked as seen, and re-running it does not incur the repeat penalty.
    """
    is_query = action_type.upper() == "QUERY"
    digest: str | None = (
        hashlib.sha256(sql.encode("utf-8")).hexdigest() if is_query and sql else None
    )
    repeated = digest is not None and digest in ctx.query_hashes
    succeeded = error is None

    total = -_STEP_COST
    if repeated:
        total -= _REPEAT_PENALTY
    elif succeeded:
        total += _EXEC_OK_REWARD

    earns_new_info = (
        is_query and succeeded and rows is not None and digest is not None and not repeated
    )
    if not earns_new_info:
        return total

    ctx.query_hashes.add(digest)
    if ctx.cumulative_new_info_reward < _NEW_INFO_CAP:
        # Grant at most what is left under the cap.
        bonus = min(_NEW_INFO_REWARD, _NEW_INFO_CAP - ctx.cumulative_new_info_reward)
        ctx.cumulative_new_info_reward += bonus
        total += bonus
    return total
def _cardinality_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Return how closely the two row counts agree, as a value in [0.0, 1.0].

    The score is ``1 - |n_pred - n_gold| / max(n_pred, n_gold, 1)``. Two empty
    results score 1.0.
    """
    n_pred, n_gold = len(pred_rows), len(gold_rows)
    gap = abs(n_pred - n_gold)
    scale = max(n_pred, n_gold, 1)  # the 1 guards against division by zero
    similarity = 1.0 - gap / scale
    return max(0.0, min(1.0, similarity))
def _value_overlap_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Return the Jaccard similarity of the distinct cell values in both results.

    Every cell is flattened and compared by its ``str()`` form, so ``1`` and
    ``1.0`` count as different values. Row position and column are ignored.
    Returns 0.0 when neither side has any cells.
    """

    def _as_strings(records: list[tuple]) -> set[str]:
        return {str(value) for record in records for value in record}

    predicted = _as_strings(pred_rows)
    expected = _as_strings(gold_rows)
    total = len(predicted | expected)
    if total == 0:
        return 0.0
    return len(predicted & expected) / total
def _nearest_distance(sorted_values: list[float], target: float) -> float:
    """Return the smallest ``|v - target|`` over a non-empty ascending list.

    Only the neighbours around the insertion point of *target* can be nearest,
    so a binary search replaces a linear scan.
    """
    index = bisect.bisect_left(sorted_values, target)
    best = math.inf
    for neighbour in sorted_values[max(index - 1, 0) : index + 1]:
        # Equal values are an exact hit. This also covers matching
        # infinities, where ``inf - inf`` would otherwise produce NaN.
        distance = 0.0 if neighbour == target else abs(neighbour - target)
        best = min(best, distance)
    return best


def _numeric_range_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Compute log-distance proximity for numeric cell values, in [0.0, 1.0].

    For each numeric gold cell, find the closest numeric predicted cell at
    distance ``d`` and score it ``1 / (1 + log1p(d))``. The result is the mean
    over gold cells.

    Only ``int``/``float`` cells count. ``bool`` is excluded, and NaN cells are
    ignored because they cannot be near anything. Returns 1.0 (neutral) when
    gold has no numeric cells, and 0.0 when gold has numeric cells but the
    prediction has none.
    """

    def _numeric_cells(rows: list[tuple]) -> list[float]:
        return [
            float(cell)
            for row in rows
            for cell in row
            if isinstance(cell, (int, float))
            and not isinstance(cell, bool)
            and not math.isnan(cell)
        ]

    gold_numerics = _numeric_cells(gold_rows)
    if not gold_numerics:
        return 1.0
    pred_numerics = sorted(_numeric_cells(pred_rows))
    if not pred_numerics:
        return 0.0

    total = 0.0
    for gold_value in gold_numerics:
        closest_distance = _nearest_distance(pred_numerics, gold_value)
        # An infinite distance yields 1 / (1 + inf) == 0.0, i.e. no credit.
        total += 1.0 / (1.0 + math.log1p(closest_distance))
    return total / len(gold_numerics)
def _bin_progress(raw_score: float) -> float:
    """Bin raw progress to one of {0.0, 0.25, 0.5, 0.75, 1.0}.

    The score is clamped to [0, 1] and then rounded to the nearest quarter,
    with thresholds at 0.125, 0.375, 0.625 and 0.875. A value exactly on a
    threshold goes to the upper bin. NaN maps to 0.0 (no progress).
    """
    # NaN compares False against everything, so ``min(1.0, nan)`` returns 1.0.
    # Without this guard, a NaN score would be binned as full progress.
    if math.isnan(raw_score):
        return 0.0
    clamped_score = max(0.0, min(1.0, raw_score))
    if clamped_score < 0.125:
        return 0.0
    if clamped_score < 0.375:
        return 0.25
    if clamped_score < 0.625:
        return 0.5
    if clamped_score < 0.875:
        return 0.75
    return 1.0
def _layer2_progress(ctx: EpisodeContext, rows: list[tuple]) -> float:
    """Return the Layer 2 reward, paid only when binned progress improves.

    Raw progress is a weighted blend of three similarities between *rows* and
    ``ctx.gold_rows``: cardinality, value overlap and numeric range. It is
    binned to quarters and compared against ``ctx.best_progress``. Only a
    strict improvement is rewarded, scaled by ``_LAYER2_IMPROVEMENT_SCALE``,
    and it raises ``ctx.best_progress`` as a side effect. Returns 0.0 when no
    gold rows are available.
    """
    gold = ctx.gold_rows
    if not gold:
        return 0.0

    weighted_score = (
        _LAYER2_CARDINALITY_WEIGHT * _cardinality_score(rows, gold)
        + _LAYER2_VALUE_OVERLAP_WEIGHT * _value_overlap_score(rows, gold)
        + _LAYER2_NUMERIC_RANGE_WEIGHT * _numeric_range_score(rows, gold)
    )
    current = _bin_progress(weighted_score)

    previous_best = ctx.best_progress
    if current <= previous_best:
        # No new high-water mark, so no reward. This stops an agent from
        # farming reward by oscillating between results.
        return 0.0
    ctx.best_progress = current
    return (current - previous_best) * _LAYER2_IMPROVEMENT_SCALE