File size: 9,346 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""Unit tests for evaluation package random policy and evaluate()."""

import json
import sqlite3

import pytest

from sql_env.evaluation import RandomPolicy, evaluate
from sql_env.models import SQLAction, SQLObservation
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer


def _build_sql_environment(tmp_path, *, db_id: str) -> SQLEnvironment:
    """Create a small on-disk SQLite fixture and wrap it in an ``SQLEnvironment``.

    Writes ``databases/<db_id>/<db_id>.sqlite`` (one ``employees`` table with
    three rows) and a ``questions.json`` holding a single question under
    ``tmp_path``.

    Args:
        tmp_path: Base directory as a path object supporting ``/`` (callers in
            this file pass pytest's ``tmp_path`` fixture or a child of it).
        db_id: Name used for the database directory, the SQLite file, and the
            ``db_id`` field of the generated question.

    Returns:
        An ``SQLEnvironment`` built from the generated question file, the
        ``databases`` root directory, and a ``MockTokenizer``.
    """
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    try:
        connection.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        connection.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [
                (1, "Alice", "engineering"),
                (2, "Bob", "engineering"),
                (3, "Cara", "sales"),
            ],
        )
        connection.commit()
    finally:
        # Close even if seeding fails so the database handle is never leaked.
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions_path.write_text(
        json.dumps(
            [
                {
                    "question": "How many employees are there?",
                    "db_id": db_id,
                    "query": "SELECT COUNT(*) FROM employees",
                }
            ]
        ),
        encoding="utf-8",
    )

    return SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_root),
        tokenizer=MockTokenizer(),
    )


def _build_observation(*, budget_remaining: int, result: str = "") -> SQLObservation:
    """Return a non-terminal, step-0 observation with the given budget and result."""
    fields = {
        "question": "How many rows?",
        "schema_info": "Available tables:\n- employees\n- departments",
        "result": result,
        "error": "",
        "step_count": 0,
        "budget_remaining": budget_remaining,
        "action_history": [],
        "done": False,
        "reward": None,
    }
    return SQLObservation(**fields)


def _terminal_observation(*, reward: float) -> SQLObservation:
    """Return a finished observation (``done=True``, no budget left) with ``reward``."""
    final_observation = SQLObservation(
        # Episode outcome.
        done=True,
        reward=reward,
        step_count=1,
        budget_remaining=0,
        # Static content shared with the non-terminal fixture.
        question="How many rows?",
        schema_info="Available tables:\n- employees\n- departments",
        result="",
        error="",
        action_history=[],
    )
    return final_observation


class _FixedPolicy:
    """Policy stub that issues the same ``SELECT 1`` query on every call."""

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return a constant QUERY action; ``observation`` is not consulted."""
        fixed_action = SQLAction(action_type="QUERY", argument="SELECT 1")
        return fixed_action


class _RaisingPolicy:
    """Policy stub that raises during one chosen episode and queries otherwise."""

    def __init__(self, fail_on_episode: int) -> None:
        """Remember which zero-based episode index should fail."""
        self._episode_index = -1
        self._fail_on_episode = fail_on_episode

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return a ``SELECT 1`` query, or raise ``RuntimeError`` on the failing episode.

        An observation whose ``step_count`` is 0 is treated as the start of a
        new episode and advances the internal episode counter.
        """
        starts_new_episode = observation.step_count == 0
        if starts_new_episode:
            self._episode_index += 1

        if self._episode_index != self._fail_on_episode:
            return SQLAction(action_type="QUERY", argument="SELECT 1")
        raise RuntimeError("policy failed")


class _SeedTrackingEnv:
    """Environment stub that records reset seeds and ends each episode in one step."""

    def __init__(self, rewards: list[float]) -> None:
        """Store per-episode rewards; ``rewards[i]`` is returned for episode ``i``."""
        self.reset_seeds: list[int | None] = []
        self._episode_index = -1
        self._rewards = rewards

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        """Log ``seed``, advance to the next episode, and return a fresh observation."""
        self._episode_index += 1
        self.reset_seeds.append(seed)
        return _build_observation(budget_remaining=2)

    def step(self, action: SQLAction) -> SQLObservation:
        """Ignore ``action`` and return a terminal observation with this episode's reward."""
        del action
        episode_reward = self._rewards[self._episode_index]
        return _terminal_observation(reward=episode_reward)


class _FlakyEnv(_SeedTrackingEnv):
    """Seed-tracking environment whose ``step`` raises during one chosen episode."""

    def __init__(self, rewards: list[float], fail_on_episode: int) -> None:
        """Store rewards and the zero-based index of the episode that should fail."""
        super().__init__(rewards)
        self._fail_on_episode = fail_on_episode

    def step(self, action: SQLAction) -> SQLObservation:
        """Raise ``RuntimeError`` on the failing episode; otherwise defer to the parent."""
        is_failing_episode = self._episode_index == self._fail_on_episode
        if not is_failing_episode:
            return super().step(action)
        raise RuntimeError("step failed")


def test_random_policy_explores_when_budget_gt_one() -> None:
    """With ample budget the policy picks one of the exploratory action types."""
    exploratory_types = {"DESCRIBE", "SAMPLE", "QUERY"}

    chosen = RandomPolicy(seed=42).select_action(
        _build_observation(budget_remaining=10)
    )

    assert chosen.action_type in exploratory_types


def test_random_policy_answers_when_budget_eq_one() -> None:
    """With a single step of budget left the policy must submit an ANSWER."""
    chosen = RandomPolicy(seed=42).select_action(
        _build_observation(budget_remaining=1)
    )

    assert chosen.action_type == "ANSWER"


def test_random_policy_returns_sql_action() -> None:
    """The value returned by ``select_action`` is an ``SQLAction`` instance."""
    chosen = RandomPolicy(seed=7).select_action(
        _build_observation(budget_remaining=10)
    )

    assert isinstance(chosen, SQLAction)


def test_random_policy_deterministic_with_seed() -> None:
    """Two policies built with the same seed emit identical action sequences."""
    observation = _build_observation(budget_remaining=10)
    # Both policies are created before any action is drawn.
    policy_a = RandomPolicy(seed=123)
    policy_b = RandomPolicy(seed=123)

    def draw_actions(policy: RandomPolicy) -> list[SQLAction]:
        return [policy.select_action(observation) for _ in range(25)]

    actions_a = draw_actions(policy_a)
    actions_b = draw_actions(policy_b)

    assert actions_a == actions_b


def test_random_policy_explores_all_action_types() -> None:
    """Over 200 draws every exploratory action type appears, and nothing else."""
    policy = RandomPolicy(seed=1)
    observation = _build_observation(budget_remaining=10)

    seen_types: set[str] = set()
    for _ in range(200):
        seen_types.add(policy.select_action(observation).action_type)

    assert seen_types == {"DESCRIBE", "SAMPLE", "QUERY"}


def test_random_policy_uses_result_rows_for_answer_candidates() -> None:
    """On the final step the answer is drawn from cells or rows of the last result."""
    acceptable_answers = {
        "engineering",
        "25",
        "sales",
        "10",
        "engineering | 25",
        "sales | 10",
    }
    policy = RandomPolicy(seed=0)
    final_step_observation = _build_observation(
        budget_remaining=1,
        result="1. engineering | 25\n2. sales | 10",
    )

    chosen = policy.select_action(final_step_observation)

    assert chosen.action_type == "ANSWER"
    assert chosen.argument in acceptable_answers


def test_evaluate_happy_path() -> None:
    """Three one-step episodes with rewards 1, 0, 1 yield 2/3 success and reward."""
    env = _SeedTrackingEnv([1.0, 0.0, 1.0])
    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert result.n_completed == 3
    assert len(result.episodes) == 3
    # Averages come from floating-point division, so compare with a tolerance
    # rather than requiring bit-exact equality with 2 / 3.
    assert result.success_rate == pytest.approx(2 / 3)
    assert result.avg_reward == pytest.approx(2 / 3)
    assert result.avg_steps == 1.0


def test_evaluate_zero_episodes_returns_zero_values() -> None:
    """Requesting zero episodes returns an all-zero result and never resets the env."""
    env = _SeedTrackingEnv([])
    result = evaluate(env, _FixedPolicy(), n_episodes=0)

    # Build the expected value from the result's own type so this test does
    # not need to import the result class.
    result_type = type(result)
    expected = result_type(
        success_rate=0.0,
        avg_reward=0.0,
        avg_steps=0.0,
        n_episodes=0,
        n_completed=0,
        episodes=[],
    )

    assert result == expected
    assert env.reset_seeds == []


def test_evaluate_negative_episodes_raises() -> None:
    """A negative ``n_episodes`` is rejected with a ``ValueError`` and exact message."""
    env = _SeedTrackingEnv([])

    # pytest.raises fails the test if no ValueError is raised, replacing the
    # hand-written try/except/else scaffold.
    with pytest.raises(ValueError) as exc_info:
        evaluate(env, _FixedPolicy(), n_episodes=-1)

    assert str(exc_info.value) == "n_episodes must be >= 0"


def test_evaluate_uses_seed_plus_episode_index() -> None:
    """Episode ``i`` resets the environment with ``seed + i``."""
    base_seed = 100
    episode_count = 3
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    evaluate(env, _FixedPolicy(), n_episodes=episode_count, seed=base_seed)

    expected_seeds = [base_seed + offset for offset in range(episode_count)]
    assert env.reset_seeds == expected_seeds


def test_evaluate_records_episode_errors_and_continues() -> None:
    """A failing ``step`` is recorded on its episode and later episodes still run."""
    flaky_env = _FlakyEnv([1.0, 1.0, 1.0], fail_on_episode=1)

    result = evaluate(flaky_env, _FixedPolicy(), n_episodes=3)
    episodes = result.episodes

    assert result.n_episodes == 3
    assert len(episodes) == 3
    assert result.n_completed == 2
    failed_episode, following_episode = episodes[1], episodes[2]
    assert failed_episode.error == "step failed"
    assert following_episode.error is None


def test_evaluate_averages_exclude_failed_episodes() -> None:
    """Aggregates are computed over the two completed episodes only."""
    # Episode 1 fails, leaving rewards 1.0 and 0.0 from episodes 0 and 2.
    flaky_env = _FlakyEnv([1.0, 0.0, 0.0], fail_on_episode=1)

    result = evaluate(flaky_env, _FixedPolicy(), n_episodes=3)

    assert result.n_completed == 2
    assert result.avg_reward == 0.5
    assert result.avg_steps == 1.0
    assert result.success_rate == 0.5


def test_evaluate_policy_exception_recorded() -> None:
    """An exception raised by the policy is stored on the affected episode."""
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])
    failing_policy = _RaisingPolicy(fail_on_episode=1)

    result = evaluate(env, failing_policy, n_episodes=3)

    assert result.n_completed == 2
    assert result.episodes[1].error == "policy failed"


def test_evaluate_progress_callback_receives_episode_progress() -> None:
    """The progress callback is invoked once per episode with ``(current, total)``."""
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])
    calls: list[tuple[int, int]] = []

    def record_progress(current, total):
        return calls.append((current, total))

    evaluate(
        env,
        _FixedPolicy(),
        n_episodes=3,
        progress_callback=record_progress,
    )

    assert calls == [(1, 3), (2, 3), (3, 3)]


def test_evaluate_integration_with_sql_environment(tmp_path) -> None:
    """Run ten episodes on a real ``SQLEnvironment`` and cross-check the aggregates."""
    env = _build_sql_environment(tmp_path, db_id="integration_eval")

    result = evaluate(env, RandomPolicy(seed=42), n_episodes=10, seed=0)
    episodes = result.episodes

    assert result.n_episodes == 10
    assert result.n_completed == 10
    assert len(episodes) == 10

    # Recompute the aggregates from the per-episode records.
    correct_count = sum(int(episode.correct) for episode in episodes)
    assert result.success_rate == correct_count / 10
    reward_total = sum(episode.total_reward for episode in episodes)
    assert result.avg_reward == pytest.approx(reward_total / 10)


def test_evaluate_integration_is_deterministic_with_seeds(tmp_path) -> None:
    """Identical seeds on two independently built environments give equal results."""
    # Build both environments first, then evaluate them in order.
    environments = [
        _build_sql_environment(tmp_path / run_name, db_id="integration_eval")
        for run_name in ("run_a", "run_b")
    ]

    outcomes = [
        evaluate(env, RandomPolicy(seed=42), n_episodes=10, seed=0)
        for env in environments
    ]

    assert outcomes[0] == outcomes[1]