| """Integration tests for type-aware answer verification in SQLEnvironment.""" |
|
|
| import json |
| import sqlite3 |
|
|
| import pytest |
|
|
| from sql_env.models import QuestionRecord, SQLAction |
| from sql_env.server.sql_environment import SQLEnvironment |
| from sql_env.server.test_sql_env import MockTokenizer |
|
|
|
|
@pytest.fixture
def env(tmp_path):
    """Build an SQLEnvironment backed by a small on-disk SQLite database.

    Lays out ``<tmp_path>/databases/integration_db/integration_db.sqlite`` with:

    * ``employees`` — three rows (Alice/Engineering/99.5, Bob/Engineering/100.0,
      Cara/Sales/100.5);
    * ``departments`` — two rows (``"Alice"``, ``"Bob"``).

    It also writes a one-entry placeholder ``questions.json``. The tests in this
    module replace ``env.questions`` with the question they actually exercise.
    """
    db_id = "integration_db"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary REAL)"
        )
        cursor.execute("CREATE TABLE departments (name TEXT)")
        cursor.executemany(
            "INSERT INTO employees (id, name, dept, salary) VALUES (?, ?, ?, ?)",
            [
                (1, "Alice", "Engineering", 99.5),
                (2, "Bob", "Engineering", 100.0),
                (3, "Cara", "Sales", 100.5),
            ],
        )
        # NOTE(review): these "department" rows hold person names; the list
        # answer test expects exactly these two values, so change both together.
        cursor.executemany(
            "INSERT INTO departments (name) VALUES (?)",
            [("Alice",), ("Bob",)],
        )
        connection.commit()
    finally:
        # Release the database file even if a statement above fails, so the
        # connection does not leak (and cannot block tmp_path cleanup).
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions_path.write_text(
        json.dumps(
            [
                {
                    "question": "Placeholder",
                    "db_id": db_id,
                    "query": "SELECT 1",
                }
            ]
        ),
        encoding="utf-8",
    )

    return SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_root),
        tokenizer=MockTokenizer(),
    )
|
|
|
|
def _set_single_question(env: SQLEnvironment, *, sql: str, answer_type: str | None) -> None:
    """Replace the environment's question pool with one question.

    Args:
        env: Environment whose ``questions`` list is overwritten.
        sql: Gold SQL query for the question.
        answer_type: Expected answer type. ``None`` means the record ends up
            with ``answer_type = None``.
    """
    # NOTE(review): the record is built with a placeholder "string" type and
    # only nulled out afterwards, presumably because QuestionRecord's
    # constructor does not accept None here — confirm.
    placeholder_type = "string" if answer_type is None else answer_type
    env.questions = [
        QuestionRecord(
            question_id="q-0",
            question_text="Integration check",
            database_name="integration_db",
            gold_sql=sql,
            gold_answer="",
            answer_type=placeholder_type,
            difficulty="easy",
            tables_involved=[],
        )
    ]
    # Null the type on the stored record, after assignment, so the change
    # applies to whatever object env.questions now holds.
    if answer_type is not None:
        return
    env.questions[0].answer_type = None
|
|
|
|
def test_integer_answer_flow(env):
    """An integer-typed answer accepts a float-formatted submission ("3.0")."""
    _set_single_question(env, sql="SELECT COUNT(*) FROM employees", answer_type="integer")
    env.reset(seed=1)

    submission = SQLAction(action_type="ANSWER", argument="3.0")
    result = env.step(submission)

    assert result.done is True
    assert result.reward == 1.0
|
|
|
|
def test_float_answer_flow(env):
    """A float-typed answer matches the average salary (99.5, 100.0, 100.5 -> 100.0)."""
    _set_single_question(env, sql="SELECT AVG(salary) FROM employees", answer_type="float")
    env.reset(seed=1)

    submission = SQLAction(action_type="ANSWER", argument="100.0")
    result = env.step(submission)

    assert result.done is True
    assert result.reward == 1.0
|
|
|
|
def test_string_answer_flow(env):
    """A string answer matches despite different case and surrounding whitespace."""
    _set_single_question(
        env, sql="SELECT dept FROM employees WHERE id = 1", answer_type="string"
    )
    env.reset(seed=1)

    # Gold value is "Engineering"; the submission is lower-cased and padded.
    submission = SQLAction(action_type="ANSWER", argument=" engineering ")
    result = env.step(submission)

    assert result.done is True
    assert result.reward == 1.0
|
|
|
|
def test_list_answer_flow(env):
    """A list answer matches even when submitted in a different order."""
    _set_single_question(
        env, sql="SELECT name FROM departments ORDER BY name", answer_type="list"
    )
    env.reset(seed=1)

    # The gold query yields Alice, Bob (ORDER BY name); submit them reversed.
    submission = SQLAction(action_type="ANSWER", argument="Bob, Alice")
    result = env.step(submission)

    assert result.done is True
    assert result.reward == 1.0
|
|
|
|
def test_fallback_when_answer_type_missing(env):
    """Verification still awards full reward when the question has no answer_type."""
    _set_single_question(env, sql="SELECT dept FROM employees WHERE id = 1", answer_type=None)
    env.reset(seed=1)

    # Gold value is "Engineering"; the submission differs only in case.
    submission = SQLAction(action_type="ANSWER", argument="engineering")
    result = env.step(submission)

    assert result.done is True
    assert result.reward == 1.0
|
|
|
|
def test_type_coercion_failure_returns_zero_reward(env):
    """A non-numeric answer to an integer question ends the episode with zero reward."""
    _set_single_question(env, sql="SELECT COUNT(*) FROM employees", answer_type="integer")
    env.reset(seed=1)

    submission = SQLAction(action_type="ANSWER", argument="not-a-number")
    result = env.step(submission)

    assert result.done is True
    assert result.reward == 0.0
|
|