File size: 5,205 Bytes
"""Unit tests for type-aware answer verification helpers."""
import sqlite3
from sql_env.models import EpisodeContext, QuestionRecord
from sql_env.server.verifier import (
_compare_float,
_compare_integer,
_compare_list,
_compare_string,
verify_answer,
)
def _build_question_record() -> QuestionRecord:
    """Build a ``QuestionRecord`` populated with fixed sample values for tests."""
    fields = {
        "question_id": "q-1",
        "question_text": "How many students?",
        "database_name": "test_db",
        "gold_sql": "SELECT 1",
        "gold_answer": "1",
        "answer_type": "integer",
        "difficulty": "easy",
        "tables_involved": ["students"],
    }
    return QuestionRecord(**fields)
def _build_episode_context(gold_rows: list[tuple] | None = None) -> EpisodeContext:
    """Build an ``EpisodeContext`` backed by an in-memory SQLite connection.

    Args:
        gold_rows: Value forwarded unchanged as the context's ``gold_rows``.

    Returns:
        A context with episode id ``"ep-1"`` and a sample question record.
        The caller is responsible for closing ``db_connection``.
    """
    # Open the connection before building the record, as the original call did.
    connection = sqlite3.connect(":memory:")
    record = _build_question_record()
    return EpisodeContext(
        episode_id="ep-1",
        db_connection=connection,
        question_record=record,
        gold_rows=gold_rows,
    )
def test_verify_integer_exact_match() -> None:
    """``verify_answer`` accepts identical integer strings."""
    outcome = verify_answer(predicted="42", gold="42", answer_type="integer")
    assert outcome is True
def test_verify_float_within_tolerance() -> None:
    """``verify_answer`` accepts 3.14 against a gold value of 3.15 for floats."""
    outcome = verify_answer(predicted="3.14", gold="3.15", answer_type="float")
    assert outcome is True
def test_verify_string_case_insensitive() -> None:
    """``verify_answer`` matches string answers that differ only in letter case."""
    outcome = verify_answer(predicted="Alice", gold="alice", answer_type="string")
    assert outcome is True
def test_verify_list_order_insensitive() -> None:
    """``verify_answer`` matches list answers whose items appear in another order."""
    outcome = verify_answer(predicted="a, b", gold="b, a", answer_type="list")
    assert outcome is True
def test_verify_none_type_falls_back_to_string() -> None:
    """With ``answer_type=None`` a space-padded prediction still matches the gold."""
    outcome = verify_answer(predicted=" hello ", gold="hello", answer_type=None)
    assert outcome is True
def test_verify_unknown_type_falls_back_to_string() -> None:
    """An unrecognized ``answer_type`` ("table") still matches identical strings."""
    outcome = verify_answer(predicted="foo", gold="foo", answer_type="table")
    assert outcome is True
def test_verify_empty_predicted_returns_false() -> None:
    """A whitespace-only prediction is rejected for an integer answer."""
    outcome = verify_answer(predicted=" ", gold="42", answer_type="integer")
    assert outcome is False
def test_verify_none_like_predicted_returns_false() -> None:
    """An empty-string prediction is rejected when ``answer_type`` is ``None``."""
    outcome = verify_answer(predicted="", gold="42", answer_type=None)
    assert outcome is False
def test_compare_integer_exact_match() -> None:
    """``_compare_integer`` accepts two identical integer strings."""
    matched = _compare_integer("25", "25")
    assert matched is True
def test_compare_integer_from_float_string() -> None:
    """``_compare_integer`` accepts "25.0" as equal to the gold "25"."""
    matched = _compare_integer("25.0", "25")
    assert matched is True
def test_compare_integer_mismatch() -> None:
    """``_compare_integer`` rejects differing integers."""
    matched = _compare_integer("24", "25")
    assert matched is False
def test_compare_integer_non_numeric_returns_false() -> None:
    """``_compare_integer`` returns False for a non-numeric prediction."""
    matched = _compare_integer("abc", "25")
    assert matched is False
def test_compare_integer_whitespace_only_returns_false() -> None:
    """``_compare_integer`` returns False for a whitespace-only prediction."""
    matched = _compare_integer(" ", "25")
    assert matched is False
def test_compare_integer_float_truncation() -> None:
    """``_compare_integer`` is expected to treat "25.9" as matching the gold "25"."""
    matched = _compare_integer("25.9", "25")
    assert matched is True
def test_compare_float_exact_match() -> None:
    """``_compare_float`` accepts two identical float strings."""
    matched = _compare_float("3.14", "3.14")
    assert matched is True
def test_compare_float_within_1pct_tolerance() -> None:
    """``_compare_float`` accepts 100.5 against a gold value of 100.0."""
    matched = _compare_float("100.5", "100.0")
    assert matched is True
def test_compare_float_outside_1pct_tolerance() -> None:
    """``_compare_float`` rejects 102.0 against a gold value of 100.0."""
    matched = _compare_float("102.0", "100.0")
    assert matched is False
def test_compare_float_boundary_exactly_1pct() -> None:
    """``_compare_float`` accepts 101.0 against a gold value of 100.0."""
    matched = _compare_float("101.0", "100.0")
    assert matched is True
def test_compare_float_just_over_1pct() -> None:
    """``_compare_float`` rejects 101.01 against a gold value of 100.0."""
    matched = _compare_float("101.01", "100.0")
    assert matched is False
def test_compare_float_gold_zero_uses_absolute_tolerance() -> None:
    """``_compare_float`` accepts a tiny non-zero prediction when the gold is 0."""
    matched = _compare_float("0.0000000001", "0")
    assert matched is True
def test_compare_float_gold_zero_fails_large_diff() -> None:
    """``_compare_float`` rejects 0.001 when the gold is 0."""
    matched = _compare_float("0.001", "0")
    assert matched is False
def test_compare_float_non_numeric_returns_false() -> None:
    """``_compare_float`` returns False for a non-numeric prediction."""
    matched = _compare_float("abc", "3.14")
    assert matched is False
def test_compare_string_case_insensitive() -> None:
    """``_compare_string`` matches strings that differ only in letter case."""
    matched = _compare_string("ALICE", "alice")
    assert matched is True
def test_compare_string_whitespace_normalized() -> None:
    """``_compare_string`` matches a space-padded prediction against the gold."""
    matched = _compare_string(" Alice Bob ", "Alice Bob")
    assert matched is True
def test_compare_string_mismatch() -> None:
    """``_compare_string`` rejects two different strings."""
    matched = _compare_string("Alice", "Bob")
    assert matched is False
def test_compare_list_same_order() -> None:
    """``_compare_list`` accepts identical comma-separated lists."""
    matched = _compare_list("a, b, c", "a, b, c")
    assert matched is True
def test_compare_list_different_order() -> None:
    """``_compare_list`` accepts the same items listed in a different order."""
    matched = _compare_list("c, a, b", "a, b, c")
    assert matched is True
def test_compare_list_mismatch() -> None:
    """``_compare_list`` rejects lists that differ in one item."""
    matched = _compare_list("a, b, d", "a, b, c")
    assert matched is False
def test_compare_list_with_gold_rows() -> None:
    """``_compare_list`` matches via ``gold_rows`` even with an unrelated gold string."""
    rows = [("a",), ("b",)]
    matched = _compare_list("a, b", "ignored", gold_rows=rows)
    assert matched is True
def test_compare_list_gold_rows_none_fallback() -> None:
    """With ``gold_rows=None``, ``_compare_list`` matches against the gold string."""
    matched = _compare_list("a, b", "a, b", gold_rows=None)
    assert matched is True
def test_compare_list_whitespace_and_case_normalized() -> None:
    """``_compare_list`` matches items that differ only in spacing and case."""
    matched = _compare_list(" Alice , Bob ", "alice,bob")
    assert matched is True
def test_episode_context_gold_rows_default() -> None:
    """A context built without ``gold_rows`` exposes ``None`` for that field."""
    ctx = _build_episode_context()
    try:
        rows = ctx.gold_rows
        assert rows is None
    finally:
        # Always release the in-memory SQLite connection, even on failure.
        ctx.db_connection.close()
def test_episode_context_gold_rows_set() -> None:
    """A context built with explicit ``gold_rows`` exposes those same rows."""
    ctx = _build_episode_context(gold_rows=[(1,), (2,)])
    try:
        rows = ctx.gold_rows
        assert rows == [(1,), (2,)]
    finally:
        # Always release the in-memory SQLite connection, even on failure.
        ctx.db_connection.close()
def test_episode_context_gold_rows_empty_list() -> None:
    """A context built with an empty ``gold_rows`` list keeps it as an empty list."""
    ctx = _build_episode_context(gold_rows=[])
    try:
        rows = ctx.gold_rows
        assert rows == []
    finally:
        # Always release the in-memory SQLite connection, even on failure.
        ctx.db_connection.close()
|