# Source path: sql_env/tests/test_verifier.py
# Uploaded by hjerpe ("Upload folder using huggingface_hub").
# Commit 5dd1bb4 (verified).
"""Unit tests for type-aware answer verification helpers."""
import sqlite3
from sql_env.models import EpisodeContext, QuestionRecord
from sql_env.server.verifier import (
_compare_float,
_compare_integer,
_compare_list,
_compare_string,
verify_answer,
)
def _build_question_record() -> QuestionRecord:
    """Return a fixed ``QuestionRecord`` fixture for the tests in this module.

    The record describes question ``q-1`` ("How many students?") against
    ``test_db``, with an ``integer`` answer type and gold answer ``"1"``.
    """
    record_fields = {
        "question_id": "q-1",
        "question_text": "How many students?",
        "database_name": "test_db",
        "gold_sql": "SELECT 1",
        "gold_answer": "1",
        "answer_type": "integer",
        "difficulty": "easy",
        "tables_involved": ["students"],
    }
    return QuestionRecord(**record_fields)
def _build_episode_context(gold_rows: list[tuple] | None = None) -> EpisodeContext:
    """Return an ``EpisodeContext`` fixture backed by an in-memory SQLite database.

    Args:
        gold_rows: Value passed through unchanged as the context's
            ``gold_rows`` argument.

    Returns:
        A context for episode ``ep-1``. This helper never closes the
        connection it opens; callers close ``db_connection`` themselves.
    """
    connection = sqlite3.connect(":memory:")
    record = _build_question_record()
    return EpisodeContext(
        episode_id="ep-1",
        db_connection=connection,
        question_record=record,
        gold_rows=gold_rows,
    )
def test_verify_integer_exact_match() -> None:
    """``verify_answer`` accepts identical integer strings ("42" vs "42")."""
    outcome = verify_answer(predicted="42", gold="42", answer_type="integer")
    assert outcome is True
def test_verify_float_within_tolerance() -> None:
    """``verify_answer`` accepts "3.14" against gold "3.15" for float answers."""
    outcome = verify_answer(predicted="3.14", gold="3.15", answer_type="float")
    assert outcome is True
def test_verify_string_case_insensitive() -> None:
    """``verify_answer`` treats "Alice" and "alice" as equal string answers."""
    outcome = verify_answer(predicted="Alice", gold="alice", answer_type="string")
    assert outcome is True
def test_verify_list_order_insensitive() -> None:
    """``verify_answer`` accepts list answers whose items are in another order."""
    outcome = verify_answer(predicted="a, b", gold="b, a", answer_type="list")
    assert outcome is True
def test_verify_none_type_falls_back_to_string() -> None:
    """With ``answer_type=None``, a space-padded "hello" still matches "hello"."""
    outcome = verify_answer(predicted=" hello ", gold="hello", answer_type=None)
    assert outcome is True
def test_verify_unknown_type_falls_back_to_string() -> None:
    """An unrecognised ``answer_type`` ("table") still matches equal strings."""
    outcome = verify_answer(predicted="foo", gold="foo", answer_type="table")
    assert outcome is True
def test_verify_empty_predicted_returns_false() -> None:
    """A whitespace-only prediction is rejected for an integer answer."""
    outcome = verify_answer(predicted=" ", gold="42", answer_type="integer")
    assert outcome is False
def test_verify_none_like_predicted_returns_false() -> None:
    """An empty-string prediction is rejected when ``answer_type`` is None."""
    outcome = verify_answer(predicted="", gold="42", answer_type=None)
    assert outcome is False
def test_compare_integer_exact_match() -> None:
    """``_compare_integer`` matches identical integer strings."""
    matched = _compare_integer("25", "25")
    assert matched is True
def test_compare_integer_from_float_string() -> None:
    """``_compare_integer`` matches "25.0" against the integer string "25"."""
    matched = _compare_integer("25.0", "25")
    assert matched is True
def test_compare_integer_mismatch() -> None:
    """``_compare_integer`` rejects different integers ("24" vs "25")."""
    matched = _compare_integer("24", "25")
    assert matched is False
def test_compare_integer_non_numeric_returns_false() -> None:
    """``_compare_integer`` returns False for a non-numeric prediction."""
    matched = _compare_integer("abc", "25")
    assert matched is False
def test_compare_integer_whitespace_only_returns_false() -> None:
    """``_compare_integer`` returns False for a whitespace-only prediction."""
    matched = _compare_integer(" ", "25")
    assert matched is False
def test_compare_integer_float_truncation() -> None:
    """``_compare_integer`` matches "25.9" against "25".

    The test name indicates the fractional part is truncated, not rounded.
    """
    matched = _compare_integer("25.9", "25")
    assert matched is True
def test_compare_float_exact_match() -> None:
    """``_compare_float`` matches identical float strings."""
    matched = _compare_float("3.14", "3.14")
    assert matched is True
def test_compare_float_within_1pct_tolerance() -> None:
    """``_compare_float`` accepts 100.5 against gold 100.0 (0.5% apart)."""
    matched = _compare_float("100.5", "100.0")
    assert matched is True
def test_compare_float_outside_1pct_tolerance() -> None:
    """``_compare_float`` rejects 102.0 against gold 100.0 (2% apart)."""
    matched = _compare_float("102.0", "100.0")
    assert matched is False
def test_compare_float_boundary_exactly_1pct() -> None:
    """``_compare_float`` accepts 101.0 against gold 100.0 (exactly 1% apart)."""
    matched = _compare_float("101.0", "100.0")
    assert matched is True
def test_compare_float_just_over_1pct() -> None:
    """``_compare_float`` rejects 101.01 against gold 100.0 (just above 1%)."""
    matched = _compare_float("101.01", "100.0")
    assert matched is False
def test_compare_float_gold_zero_uses_absolute_tolerance() -> None:
    """``_compare_float`` accepts a tiny prediction (1e-10) when gold is "0"."""
    matched = _compare_float("0.0000000001", "0")
    assert matched is True
def test_compare_float_gold_zero_fails_large_diff() -> None:
    """``_compare_float`` rejects a prediction of 0.001 when gold is "0"."""
    matched = _compare_float("0.001", "0")
    assert matched is False
def test_compare_float_non_numeric_returns_false() -> None:
    """``_compare_float`` returns False for a non-numeric prediction."""
    matched = _compare_float("abc", "3.14")
    assert matched is False
def test_compare_string_case_insensitive() -> None:
    """``_compare_string`` matches "ALICE" against "alice"."""
    matched = _compare_string("ALICE", "alice")
    assert matched is True
def test_compare_string_whitespace_normalized() -> None:
    """``_compare_string`` matches a space-padded name against the bare name."""
    matched = _compare_string(" Alice Bob ", "Alice Bob")
    assert matched is True
def test_compare_string_mismatch() -> None:
    """``_compare_string`` rejects different strings ("Alice" vs "Bob")."""
    matched = _compare_string("Alice", "Bob")
    assert matched is False
def test_compare_list_same_order() -> None:
    """``_compare_list`` matches identical comma-separated lists."""
    matched = _compare_list("a, b, c", "a, b, c")
    assert matched is True
def test_compare_list_different_order() -> None:
    """``_compare_list`` matches lists holding the same items in another order."""
    matched = _compare_list("c, a, b", "a, b, c")
    assert matched is True
def test_compare_list_mismatch() -> None:
    """``_compare_list`` rejects lists that differ in one item ("d" vs "c")."""
    matched = _compare_list("a, b, d", "a, b, c")
    assert matched is False
def test_compare_list_with_gold_rows() -> None:
    """``_compare_list`` matches against ``gold_rows`` when they are supplied.

    The gold string is the placeholder "ignored", so a True result can only
    come from comparing the prediction with the supplied rows.
    """
    rows = [("a",), ("b",)]
    matched = _compare_list("a, b", "ignored", gold_rows=rows)
    assert matched is True
def test_compare_list_gold_rows_none_fallback() -> None:
    """With ``gold_rows=None``, ``_compare_list`` still matches equal lists."""
    matched = _compare_list("a, b", "a, b", gold_rows=None)
    assert matched is True
def test_compare_list_whitespace_and_case_normalized() -> None:
    """``_compare_list`` ignores item padding and letter case when matching."""
    matched = _compare_list(" Alice , Bob ", "alice,bob")
    assert matched is True
def test_episode_context_gold_rows_default() -> None:
    """A context built by the helper with no gold rows has ``gold_rows`` None."""
    episode = _build_episode_context()
    try:
        observed = episode.gold_rows
        assert observed is None
    finally:
        # Close the in-memory SQLite connection even if the assertion fails.
        episode.db_connection.close()
def test_episode_context_gold_rows_set() -> None:
    """Gold rows passed to the helper are readable back from the context."""
    episode = _build_episode_context(gold_rows=[(1,), (2,)])
    try:
        observed = episode.gold_rows
        assert observed == [(1,), (2,)]
    finally:
        # Close the in-memory SQLite connection even if the assertion fails.
        episode.db_connection.close()
def test_episode_context_gold_rows_empty_list() -> None:
    """An empty ``gold_rows`` list reads back as an empty list."""
    episode = _build_episode_context(gold_rows=[])
    try:
        observed = episode.gold_rows
        assert observed == []
    finally:
        # Close the in-memory SQLite connection even if the assertion fails.
        episode.db_connection.close()