File size: 5,205 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Unit tests for type-aware answer verification helpers."""

import sqlite3

from sql_env.models import EpisodeContext, QuestionRecord

from sql_env.server.verifier import (
    _compare_float,
    _compare_integer,
    _compare_list,
    _compare_string,
    verify_answer,
)


def _build_question_record() -> QuestionRecord:
    """Return a fixed integer-answer ``QuestionRecord`` used as a test fixture."""
    fields = {
        "question_id": "q-1",
        "question_text": "How many students?",
        "database_name": "test_db",
        "gold_sql": "SELECT 1",
        "gold_answer": "1",
        "answer_type": "integer",
        "difficulty": "easy",
        "tables_involved": ["students"],
    }
    return QuestionRecord(**fields)


def _build_episode_context(gold_rows: list[tuple] | None = None) -> EpisodeContext:
    """Build an ``EpisodeContext`` backed by a fresh in-memory SQLite database.

    Args:
        gold_rows: Value forwarded unchanged to the context's ``gold_rows``.

    Returns:
        A context holding an open connection; the caller must close
        ``db_connection`` when finished.
    """
    connection = sqlite3.connect(":memory:")
    record = _build_question_record()
    return EpisodeContext(
        episode_id="ep-1",
        db_connection=connection,
        question_record=record,
        gold_rows=gold_rows,
    )


def test_verify_integer_exact_match() -> None:
    """Identical integer strings should verify as a match."""
    outcome = verify_answer(predicted="42", gold="42", answer_type="integer")
    assert outcome is True


def test_verify_float_within_tolerance() -> None:
    """A float prediction close to gold (3.14 vs 3.15) should be accepted."""
    outcome = verify_answer(predicted="3.14", gold="3.15", answer_type="float")
    assert outcome is True


def test_verify_string_case_insensitive() -> None:
    """String answers differing only in letter case should match."""
    outcome = verify_answer(predicted="Alice", gold="alice", answer_type="string")
    assert outcome is True


def test_verify_list_order_insensitive() -> None:
    """List answers with the same items in a different order should match."""
    outcome = verify_answer(predicted="a, b", gold="b, a", answer_type="list")
    assert outcome is True


def test_verify_none_type_falls_back_to_string() -> None:
    """With ``answer_type=None``, a padded prediction should still match gold."""
    outcome = verify_answer(predicted=" hello ", gold="hello", answer_type=None)
    assert outcome is True


def test_verify_unknown_type_falls_back_to_string() -> None:
    """An unrecognised answer type ("table") should still match equal strings."""
    outcome = verify_answer(predicted="foo", gold="foo", answer_type="table")
    assert outcome is True


def test_verify_empty_predicted_returns_false() -> None:
    """A whitespace-only prediction should be rejected."""
    outcome = verify_answer(predicted="   ", gold="42", answer_type="integer")
    assert outcome is False


def test_verify_none_like_predicted_returns_false() -> None:
    """An empty-string prediction should be rejected even with no answer type."""
    outcome = verify_answer(predicted="", gold="42", answer_type=None)
    assert outcome is False


def test_compare_integer_exact_match() -> None:
    """Equal integer strings should compare as a match."""
    matched = _compare_integer("25", "25")
    assert matched is True


def test_compare_integer_from_float_string() -> None:
    """A prediction written as "25.0" should match the integer gold "25"."""
    matched = _compare_integer("25.0", "25")
    assert matched is True


def test_compare_integer_mismatch() -> None:
    """Different integer values should not match."""
    matched = _compare_integer("24", "25")
    assert matched is False


def test_compare_integer_non_numeric_returns_false() -> None:
    """A non-numeric prediction should yield False rather than raising."""
    matched = _compare_integer("abc", "25")
    assert matched is False


def test_compare_integer_whitespace_only_returns_false() -> None:
    """A whitespace-only prediction should yield False rather than raising."""
    matched = _compare_integer(" ", "25")
    assert matched is False


def test_compare_integer_float_truncation() -> None:
    """"25.9" should match gold "25": the fraction is dropped, not rounded."""
    matched = _compare_integer("25.9", "25")
    assert matched is True


def test_compare_float_exact_match() -> None:
    """Equal float strings should compare as a match."""
    matched = _compare_float("3.14", "3.14")
    assert matched is True


def test_compare_float_within_1pct_tolerance() -> None:
    """A prediction 0.5% above gold should be accepted."""
    matched = _compare_float("100.5", "100.0")
    assert matched is True


def test_compare_float_outside_1pct_tolerance() -> None:
    """A prediction 2% above gold should be rejected."""
    matched = _compare_float("102.0", "100.0")
    assert matched is False


def test_compare_float_boundary_exactly_1pct() -> None:
    """A prediction exactly 1% above gold should still be accepted."""
    matched = _compare_float("101.0", "100.0")
    assert matched is True


def test_compare_float_just_over_1pct() -> None:
    """A prediction marginally more than 1% above gold should be rejected."""
    matched = _compare_float("101.01", "100.0")
    assert matched is False


def test_compare_float_gold_zero_uses_absolute_tolerance() -> None:
    """With gold equal to zero, a tiny non-zero prediction should be accepted."""
    matched = _compare_float("0.0000000001", "0")
    assert matched is True


def test_compare_float_gold_zero_fails_large_diff() -> None:
    """With gold equal to zero, a prediction of 0.001 should be rejected."""
    matched = _compare_float("0.001", "0")
    assert matched is False


def test_compare_float_non_numeric_returns_false() -> None:
    """A non-numeric prediction should yield False rather than raising."""
    matched = _compare_float("abc", "3.14")
    assert matched is False


def test_compare_string_case_insensitive() -> None:
    """Strings differing only in letter case should match."""
    matched = _compare_string("ALICE", "alice")
    assert matched is True


def test_compare_string_whitespace_normalized() -> None:
    """Leading, trailing and repeated inner spaces should not prevent a match."""
    matched = _compare_string("  Alice  Bob ", "Alice Bob")
    assert matched is True


def test_compare_string_mismatch() -> None:
    """Different strings should not match."""
    matched = _compare_string("Alice", "Bob")
    assert matched is False


def test_compare_list_same_order() -> None:
    """Identical comma-separated lists should match."""
    matched = _compare_list("a, b, c", "a, b, c")
    assert matched is True


def test_compare_list_different_order() -> None:
    """Lists holding the same items in a different order should match."""
    matched = _compare_list("c, a, b", "a, b, c")
    assert matched is True


def test_compare_list_mismatch() -> None:
    """Lists that differ in one item should not match."""
    matched = _compare_list("a, b, d", "a, b, c")
    assert matched is False


def test_compare_list_with_gold_rows() -> None:
    """When ``gold_rows`` is given, the match should come from the rows.

    The gold string is deliberately a non-matching placeholder ("ignored").
    """
    rows = [("a",), ("b",)]
    matched = _compare_list("a, b", "ignored", gold_rows=rows)
    assert matched is True


def test_compare_list_gold_rows_none_fallback() -> None:
    """With ``gold_rows=None``, the gold string should be used for comparison."""
    matched = _compare_list("a, b", "a, b", gold_rows=None)
    assert matched is True


def test_compare_list_whitespace_and_case_normalized() -> None:
    """List items differing only in padding and letter case should match."""
    matched = _compare_list(" Alice ,  Bob ", "alice,bob")
    assert matched is True


def test_episode_context_gold_rows_default() -> None:
    """A context built without ``gold_rows`` should expose ``None``."""
    ctx = _build_episode_context()
    try:
        rows = ctx.gold_rows
        assert rows is None
    finally:
        # Always release the in-memory SQLite connection, even on failure.
        ctx.db_connection.close()


def test_episode_context_gold_rows_set() -> None:
    """Rows passed at construction should be readable back unchanged."""
    ctx = _build_episode_context(gold_rows=[(1,), (2,)])
    try:
        rows = ctx.gold_rows
        assert rows == [(1,), (2,)]
    finally:
        # Always release the in-memory SQLite connection, even on failure.
        ctx.db_connection.close()


def test_episode_context_gold_rows_empty_list() -> None:
    """An empty ``gold_rows`` list should be kept as an empty list."""
    ctx = _build_episode_context(gold_rows=[])
    try:
        rows = ctx.gold_rows
        assert rows == []
    finally:
        # Always release the in-memory SQLite connection, even on failure.
        ctx.db_connection.close()