File size: 2,826 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Answer verification for SQLEnv using type-aware comparisons."""

from __future__ import annotations

import re


def verify_answer(
    predicted: str,
    gold: str,
    answer_type: str | None = None,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Check a submitted answer against the gold answer.

    Both answers are coerced to text (``None`` becomes the empty string).
    A blank submission is always rejected. Otherwise the comparison is
    chosen by ``answer_type``: ``"integer"``, ``"float"`` and ``"list"``
    use their dedicated comparators; ``"string"`` and every other value
    (including ``None``) use normalized string comparison.

    ``gold_rows`` is only forwarded to the list comparator.
    """
    submitted = str(predicted) if predicted is not None else ""
    expected = str(gold) if gold is not None else ""

    # An empty or whitespace-only submission can never be correct.
    if submitted.strip() == "":
        return False

    if answer_type == "integer":
        return _compare_integer(submitted, expected)
    if answer_type == "float":
        return _compare_float(submitted, expected)
    if answer_type == "list":
        return _compare_list(submitted, expected, gold_rows)

    # "string" and any unrecognized/missing type share the same comparison.
    return _compare_string(submitted, expected)


def _normalize_value(value: str) -> str:
    """Normalize strings for case-insensitive, whitespace-stable comparison."""
    text = "" if value is None else str(value)
    return " ".join(text.strip().lower().split())


def _compare_integer(predicted: str, gold: str) -> bool:
    """Compare integer values after coercing with ``int(float(x))``."""
    try:
        return int(float(predicted)) == int(float(gold))
    except (TypeError, ValueError):
        return False


def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool:
    """Compare float values using a relative tolerance."""
    try:
        predicted_value = float(predicted)
        gold_value = float(gold)
    except (TypeError, ValueError):
        return False

    if gold_value == 0.0:
        return abs(predicted_value - gold_value) <= 1e-9

    return abs(predicted_value - gold_value) <= tolerance * abs(gold_value)


def _compare_string(predicted: str, gold: str) -> bool:
    """Return True when both strings are equal after normalization."""
    submitted = _normalize_value(predicted)
    expected = _normalize_value(gold)
    return submitted == expected


def _parse_list_values(raw: str) -> set[str]:
    """Parse comma/newline/pipe-separated values into a normalized set."""
    tokens = re.split(r"\s*(?:,|\n|\|)\s*", raw)
    normalized = {_normalize_value(token) for token in tokens if token.strip()}
    return normalized


def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Compare a list-style answer to the gold answer, ignoring order.

    The submission is always parsed from text. When ``gold_rows`` is
    provided, the expected set is built from every non-blank cell of every
    row (each cell stringified and normalized) and ``gold`` is ignored;
    otherwise ``gold`` is parsed the same way as the submission. Because both
    sides are sets, duplicates do not affect the result.
    """
    submitted = _parse_list_values(predicted)

    if gold_rows is None:
        return submitted == _parse_list_values(gold)

    expected = set()
    for row in gold_rows:
        for cell in row:
            cell_text = str(cell)
            if cell_text.strip():
                expected.add(_normalize_value(cell_text))

    return submitted == expected