"""Unit tests for type-aware answer verification helpers."""

import sqlite3

from sql_env.models import EpisodeContext, QuestionRecord
from sql_env.server.verifier import (
    _compare_float,
    _compare_integer,
    _compare_list,
    _compare_string,
    verify_answer,
)


# ---------------------------------------------------------------------------
# Fixture builders
# ---------------------------------------------------------------------------


def _build_question_record() -> QuestionRecord:
    """Return a fixed ``QuestionRecord`` used as a placeholder in context tests."""
    record_fields = {
        "question_id": "q-1",
        "question_text": "How many students?",
        "database_name": "test_db",
        "gold_sql": "SELECT 1",
        "gold_answer": "1",
        "answer_type": "integer",
        "difficulty": "easy",
        "tables_involved": ["students"],
    }
    return QuestionRecord(**record_fields)


def _build_episode_context(gold_rows: list[tuple] | None = None) -> EpisodeContext:
    """Return an ``EpisodeContext`` backed by a fresh in-memory SQLite database.

    The caller owns the returned context's ``db_connection`` and is
    responsible for closing it.
    """
    connection = sqlite3.connect(":memory:")
    question = _build_question_record()
    return EpisodeContext(
        episode_id="ep-1",
        db_connection=connection,
        question_record=question,
        gold_rows=gold_rows,
    )


# ---------------------------------------------------------------------------
# verify_answer: dispatch on answer_type
# ---------------------------------------------------------------------------


def test_verify_integer_exact_match() -> None:
    """Identical integer strings verify as correct."""
    outcome = verify_answer(predicted="42", gold="42", answer_type="integer")
    assert outcome is True


def test_verify_float_within_tolerance() -> None:
    """A float prediction close to the gold value verifies as correct."""
    outcome = verify_answer(predicted="3.14", gold="3.15", answer_type="float")
    assert outcome is True


def test_verify_string_case_insensitive() -> None:
    """String answers differing only in letter case verify as correct."""
    outcome = verify_answer(predicted="Alice", gold="alice", answer_type="string")
    assert outcome is True


def test_verify_list_order_insensitive() -> None:
    """List answers with the same items in a different order verify as correct."""
    outcome = verify_answer(predicted="a, b", gold="b, a", answer_type="list")
    assert outcome is True


def test_verify_none_type_falls_back_to_string() -> None:
    """With ``answer_type=None`` a space-padded prediction still matches the gold."""
    outcome = verify_answer(predicted=" hello ", gold="hello", answer_type=None)
    assert outcome is True


def test_verify_unknown_type_falls_back_to_string() -> None:
    """An unrecognised ``answer_type`` still accepts an identical string."""
    outcome = verify_answer(predicted="foo", gold="foo", answer_type="table")
    assert outcome is True


def test_verify_empty_predicted_returns_false() -> None:
    """A whitespace-only prediction is rejected."""
    outcome = verify_answer(predicted=" ", gold="42", answer_type="integer")
    assert outcome is False


def test_verify_none_like_predicted_returns_false() -> None:
    """An empty-string prediction is rejected even without an answer type."""
    outcome = verify_answer(predicted="", gold="42", answer_type=None)
    assert outcome is False


# ---------------------------------------------------------------------------
# _compare_integer
# ---------------------------------------------------------------------------


def test_compare_integer_exact_match() -> None:
    """Identical integer strings compare equal."""
    outcome = _compare_integer("25", "25")
    assert outcome is True


def test_compare_integer_from_float_string() -> None:
    """A prediction written as ``"25.0"`` matches the integer gold ``"25"``."""
    outcome = _compare_integer("25.0", "25")
    assert outcome is True


def test_compare_integer_mismatch() -> None:
    """Different integers do not compare equal."""
    outcome = _compare_integer("24", "25")
    assert outcome is False


def test_compare_integer_non_numeric_returns_false() -> None:
    """A non-numeric prediction yields ``False``."""
    outcome = _compare_integer("abc", "25")
    assert outcome is False


def test_compare_integer_whitespace_only_returns_false() -> None:
    """A whitespace-only prediction yields ``False``."""
    outcome = _compare_integer(" ", "25")
    assert outcome is False


def test_compare_integer_float_truncation() -> None:
    """A prediction of ``"25.9"`` is accepted against the gold ``"25"``."""
    outcome = _compare_integer("25.9", "25")
    assert outcome is True


# ---------------------------------------------------------------------------
# _compare_float
# ---------------------------------------------------------------------------


def test_compare_float_exact_match() -> None:
    """Identical float strings compare equal."""
    outcome = _compare_float("3.14", "3.14")
    assert outcome is True


def test_compare_float_within_1pct_tolerance() -> None:
    """A prediction 0.5% above the gold value is accepted."""
    outcome = _compare_float("100.5", "100.0")
    assert outcome is True


def test_compare_float_outside_1pct_tolerance() -> None:
    """A prediction 2% above the gold value is rejected."""
    outcome = _compare_float("102.0", "100.0")
    assert outcome is False


def test_compare_float_boundary_exactly_1pct() -> None:
    """A prediction exactly 1% above the gold value is accepted."""
    outcome = _compare_float("101.0", "100.0")
    assert outcome is True


def test_compare_float_just_over_1pct() -> None:
    """A prediction 1.01% above the gold value is rejected."""
    outcome = _compare_float("101.01", "100.0")
    assert outcome is False


def test_compare_float_gold_zero_uses_absolute_tolerance() -> None:
    """With a gold value of zero, a tiny non-zero prediction is accepted."""
    outcome = _compare_float("0.0000000001", "0")
    assert outcome is True


def test_compare_float_gold_zero_fails_large_diff() -> None:
    """With a gold value of zero, a prediction of 0.001 is rejected."""
    outcome = _compare_float("0.001", "0")
    assert outcome is False


def test_compare_float_non_numeric_returns_false() -> None:
    """A non-numeric prediction yields ``False``."""
    outcome = _compare_float("abc", "3.14")
    assert outcome is False


# ---------------------------------------------------------------------------
# _compare_string
# ---------------------------------------------------------------------------


def test_compare_string_case_insensitive() -> None:
    """Strings differing only in letter case compare equal."""
    outcome = _compare_string("ALICE", "alice")
    assert outcome is True


def test_compare_string_whitespace_normalized() -> None:
    """Leading and trailing spaces on the prediction do not affect the match."""
    outcome = _compare_string(" Alice Bob ", "Alice Bob")
    assert outcome is True


def test_compare_string_mismatch() -> None:
    """Different strings do not compare equal."""
    outcome = _compare_string("Alice", "Bob")
    assert outcome is False


# ---------------------------------------------------------------------------
# _compare_list
# ---------------------------------------------------------------------------


def test_compare_list_same_order() -> None:
    """Identical comma-separated lists compare equal."""
    outcome = _compare_list("a, b, c", "a, b, c")
    assert outcome is True


def test_compare_list_different_order() -> None:
    """Lists with the same items in a different order compare equal."""
    outcome = _compare_list("c, a, b", "a, b, c")
    assert outcome is True


def test_compare_list_mismatch() -> None:
    """Lists that differ in one item do not compare equal."""
    outcome = _compare_list("a, b, d", "a, b, c")
    assert outcome is False


def test_compare_list_with_gold_rows() -> None:
    """When ``gold_rows`` is given, the prediction matches despite an unrelated gold string."""
    gold_rows = [("a",), ("b",)]
    outcome = _compare_list("a, b", "ignored", gold_rows=gold_rows)
    assert outcome is True


def test_compare_list_gold_rows_none_fallback() -> None:
    """With ``gold_rows=None`` the gold string is what the prediction is matched against."""
    outcome = _compare_list("a, b", "a, b", gold_rows=None)
    assert outcome is True


def test_compare_list_whitespace_and_case_normalized() -> None:
    """Spacing around items and letter case do not affect the list match."""
    outcome = _compare_list(" Alice , Bob ", "alice,bob")
    assert outcome is True


# ---------------------------------------------------------------------------
# EpisodeContext.gold_rows
# ---------------------------------------------------------------------------


def test_episode_context_gold_rows_default() -> None:
    """A context built without ``gold_rows`` exposes ``None``."""
    episode = _build_episode_context()
    try:
        assert episode.gold_rows is None
    finally:
        # Always release the in-memory database, even if the assertion fails.
        episode.db_connection.close()


def test_episode_context_gold_rows_set() -> None:
    """A context built with rows exposes those same rows."""
    episode = _build_episode_context(gold_rows=[(1,), (2,)])
    try:
        assert episode.gold_rows == [(1,), (2,)]
    finally:
        # Always release the in-memory database, even if the assertion fails.
        episode.db_connection.close()


def test_episode_context_gold_rows_empty_list() -> None:
    """A context built with an empty list exposes an empty list."""
    episode = _build_episode_context(gold_rows=[])
    try:
        assert episode.gold_rows == []
    finally:
        # Always release the in-memory database, even if the assertion fails.
        episode.db_connection.close()