"""Integration tests for type-aware answer verification in SQLEnvironment."""

import json
import sqlite3
from contextlib import closing

import pytest

from sql_env.models import QuestionRecord, SQLAction
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer

# Database id shared by the fixture (directory/file name) and the question
# records (``database_name``); keeping one constant stops them drifting apart.
_DB_ID = "integration_db"


@pytest.fixture
def env(tmp_path):
    """Build an SQLEnvironment backed by a small on-disk SQLite database.

    Layout: ``<tmp_path>/databases/integration_db/integration_db.sqlite``.

    Tables:

    * ``employees`` -- three rows (Alice and Bob in Engineering, Cara in
      Sales) with salaries 99.5, 100.0 and 100.5, so ``COUNT(*)`` is 3 and
      ``AVG(salary)`` is exactly 100.0.
    * ``departments`` -- a single ``name`` column holding "Alice" and "Bob",
      used as a two-element list answer.

    The questions file contains one placeholder question; individual tests
    swap in their own question via ``_set_single_question``.
    """
    db_root = tmp_path / "databases"
    db_dir = db_root / _DB_ID
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{_DB_ID}.sqlite"

    # closing() releases the database handle even if a statement fails; the
    # inner ``with connection`` commits on success / rolls back on error but
    # does not close the connection on its own.
    with closing(sqlite3.connect(db_path)) as connection:
        with connection:
            connection.execute(
                "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary REAL)"
            )
            connection.execute("CREATE TABLE departments (name TEXT)")
            connection.executemany(
                "INSERT INTO employees (id, name, dept, salary) VALUES (?, ?, ?, ?)",
                [
                    (1, "Alice", "Engineering", 99.5),
                    (2, "Bob", "Engineering", 100.0),
                    (3, "Cara", "Sales", 100.5),
                ],
            )
            connection.executemany(
                "INSERT INTO departments (name) VALUES (?)",
                [("Alice",), ("Bob",)],
            )

    questions_path = tmp_path / "questions.json"
    questions_path.write_text(
        json.dumps(
            [
                {
                    "question": "Placeholder",
                    "db_id": _DB_ID,
                    "query": "SELECT 1",
                }
            ]
        ),
        encoding="utf-8",
    )

    return SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_root),
        tokenizer=MockTokenizer(),
    )


def _set_single_question(env: SQLEnvironment, *, sql: str, answer_type: str | None) -> None:
    """Replace the environment's question pool with a single question.

    Args:
        env: Environment whose ``questions`` list is overwritten.
        sql: Gold SQL for the question. ``gold_answer`` is left empty.
        answer_type: Declared answer type, or ``None`` to simulate a record
            that carries no answer type at all.
    """
    record = QuestionRecord(
        question_id="q-0",
        question_text="Integration check",
        database_name=_DB_ID,
        gold_sql=sql,
        gold_answer="",
        # A concrete type is passed at construction time and, for the
        # "missing type" case, cleared afterwards.
        # NOTE(review): presumably the constructor does not accept None here -- confirm.
        answer_type="string" if answer_type is None else answer_type,
        difficulty="easy",
        tables_involved=[],
    )
    if answer_type is None:
        record.answer_type = None
    env.questions = [record]


def _submit_answer(env: SQLEnvironment, answer: str):
    """Reset ``env`` with a fixed seed and submit ``answer`` as an ANSWER action.

    Returns:
        The observation produced by the ANSWER step.
    """
    env.reset(seed=1)
    return env.step(SQLAction(action_type="ANSWER", argument=answer))


def test_integer_answer_flow(env):
    """A float-formatted answer equal to the integer count (3) earns full reward."""
    _set_single_question(
        env,
        sql="SELECT COUNT(*) FROM employees",
        answer_type="integer",
    )
    observation = _submit_answer(env, "3.0")
    assert observation.done is True
    assert observation.reward == 1.0


def test_float_answer_flow(env):
    """The average salary, (99.5 + 100.0 + 100.5) / 3 == 100.0, earns full reward."""
    _set_single_question(
        env,
        sql="SELECT AVG(salary) FROM employees",
        answer_type="float",
    )
    observation = _submit_answer(env, "100.0")
    assert observation.done is True
    assert observation.reward == 1.0


def test_string_answer_flow(env):
    """A padded, lower-cased answer matching "Engineering" earns full reward."""
    _set_single_question(
        env,
        sql="SELECT dept FROM employees WHERE id = 1",
        answer_type="string",
    )
    observation = _submit_answer(env, " engineering ")
    assert observation.done is True
    assert observation.reward == 1.0


def test_list_answer_flow(env):
    """A comma-separated list in a different order than the gold rows earns full reward."""
    _set_single_question(
        env,
        sql="SELECT name FROM departments ORDER BY name",
        answer_type="list",
    )
    # Gold rows are ["Alice", "Bob"]; the answer lists them in reverse.
    observation = _submit_answer(env, "Bob, Alice")
    assert observation.done is True
    assert observation.reward == 1.0


def test_fallback_when_answer_type_missing(env):
    """With no declared answer type, a case-differing match still earns full reward."""
    _set_single_question(
        env,
        sql="SELECT dept FROM employees WHERE id = 1",
        answer_type=None,
    )
    observation = _submit_answer(env, "engineering")
    assert observation.done is True
    assert observation.reward == 1.0


def test_type_coercion_failure_returns_zero_reward(env):
    """A non-numeric answer to an integer question ends the episode with zero reward."""
    _set_single_question(
        env,
        sql="SELECT COUNT(*) FROM employees",
        answer_type="integer",
    )
    observation = _submit_answer(env, "not-a-number")
    assert observation.done is True
    assert observation.reward == 0.0