File size: 3,620 Bytes
22a6915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Compare pooled random split vs grouped LOPO for XGBoost."""

import os
import sys

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# Absolute path of the parent of this script's directory. It is prepended to
# sys.path (once) so the project-local imports that follow
# (data_preparation.*, models.*) can be resolved; presumably needed when the
# script is run directly rather than as a package module. This must stay
# above those imports.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from data_preparation.prepare_dataset import get_default_split_config, get_numpy_splits, load_per_person
from models.xgboost.config import build_xgb_classifier, XGB_BASE_PARAMS

# Identifier handed to get_numpy_splits / load_per_person to select the data.
MODEL_NAME = "face_orientation"
# Destination of the Markdown report produced by write_report().
OUT_PATH = os.path.join(
    _PROJECT_ROOT,
    "evaluation",
    "GROUPED_SPLIT_BENCHMARK.md",
)


def run_pooled_split():
    """Train and score XGBoost on the default pooled random split.

    The train/val/test partition comes from ``get_numpy_splits`` using the
    ratios and seed from ``get_default_split_config``; features are not
    scaled. The classifier is built with ``early_stopping_rounds=30`` and the
    validation split is supplied as its ``eval_set``.

    Returns:
        dict with float ``accuracy``, weighted ``f1`` and ``auc`` measured on
        the test split; hard labels use a 0.5 probability threshold.
    """
    ratios, split_seed = get_default_split_config()
    data, _, _, _ = get_numpy_splits(
        model_name=MODEL_NAME,
        split_ratios=ratios,
        seed=split_seed,
        scale=False,
    )

    clf = build_xgb_classifier(split_seed, verbosity=0, early_stopping_rounds=30)
    X_train, y_train = data["X_train"], data["y_train"]
    validation = (data["X_val"], data["y_val"])
    clf.fit(X_train, y_train, eval_set=[validation], verbose=False)

    # Probability of the positive class, then thresholded at 0.5.
    positive_prob = clf.predict_proba(data["X_test"])[:, 1]
    labels = (positive_prob >= 0.5).astype(int)
    truth = data["y_test"]

    metrics = {}
    metrics["accuracy"] = float(accuracy_score(truth, labels))
    metrics["f1"] = float(f1_score(truth, labels, average="weighted"))
    metrics["auc"] = float(roc_auc_score(truth, positive_prob))
    return metrics


def _fold_auc(y_true, probs):
    """Return ROC-AUC for one fold, or NaN when ``y_true`` holds a single class.

    ROC-AUC is undefined with only one class present (scikit-learn raises
    ``ValueError``), which can happen for an individual held-out person.
    """
    if np.unique(y_true).size < 2:
        return float("nan")
    return float(roc_auc_score(y_true, probs))


def run_grouped_lopo():
    """Evaluate XGBoost with leave-one-person-out (LOPO) cross-validation.

    Every person key from ``load_per_person`` is held out once as the test
    fold while a fresh classifier is trained on all remaining persons. Folds
    use no validation set and no early stopping. Hard labels use a 0.5
    probability threshold.

    Returns:
        dict with the unweighted mean over folds of ``accuracy`` and weighted
        ``f1``; ``auc`` averaged over the folds where ROC-AUC is defined
        (NaN if none); ``folds`` (number of persons) and ``auc_folds``
        (number of folds that contributed to ``auc``).

    Raises:
        ValueError: if fewer than two persons are available, because each
            fold needs at least one training person.
    """
    by_person, _, _ = load_per_person(MODEL_NAME)
    persons = sorted(by_person.keys())
    if len(persons) < 2:
        raise ValueError(
            f"Grouped LOPO needs at least 2 persons, got {len(persons)}"
        )
    scores = {"accuracy": [], "f1": [], "auc": []}

    _, seed = get_default_split_config()
    for held_out in persons:
        train_persons = [p for p in persons if p != held_out]
        train_x = np.concatenate([by_person[p][0] for p in train_persons], axis=0)
        train_y = np.concatenate([by_person[p][1] for p in train_persons], axis=0)
        test_x, test_y = by_person[held_out]

        model = build_xgb_classifier(seed, verbosity=0)
        model.fit(train_x, train_y, verbose=False)
        probs = model.predict_proba(test_x)[:, 1]
        preds = (probs >= 0.5).astype(int)

        scores["accuracy"].append(float(accuracy_score(test_y, preds)))
        scores["f1"].append(float(f1_score(test_y, preds, average="weighted")))
        # NaN for single-class folds instead of aborting the whole benchmark.
        scores["auc"].append(_fold_auc(test_y, probs))

    auc_values = [a for a in scores["auc"] if not np.isnan(a)]
    return {
        "accuracy": float(np.mean(scores["accuracy"])),
        "f1": float(np.mean(scores["f1"])),
        "auc": float(np.mean(auc_values)) if auc_values else float("nan"),
        "folds": len(persons),
        "auc_folds": len(auc_values),
    }


def write_report(pooled, grouped):
    """Write the pooled-vs-LOPO comparison as a Markdown table to ``OUT_PATH``.

    Args:
        pooled: metrics dict with ``accuracy``, ``f1`` and ``auc`` (as
            returned by ``run_pooled_split``).
        grouped: metrics dict with the same keys plus ``folds`` (as returned
            by ``run_grouped_lopo``).
    """
    lines = [
        "# Grouped vs pooled split benchmark",
        "",
        "This compares the same XGBoost config under two evaluation protocols.",
        "",
        f"Config: `{XGB_BASE_PARAMS}`",
        "",
        "| Protocol | Accuracy | F1 (weighted) | ROC-AUC |",
        "|----------|---------:|--------------:|--------:|",
        f"| Pooled random split (70/15/15) | {pooled['accuracy']:.4f} | {pooled['f1']:.4f} | {pooled['auc']:.4f} |",
        f"| Grouped LOPO ({grouped['folds']} folds) | {grouped['accuracy']:.4f} | {grouped['f1']:.4f} | {grouped['auc']:.4f} |",
        "",
        # The two protocols differ in training procedure, not just in how data is split.
        "Note: the pooled run early-stops (30 rounds) on its validation split, "
        "while LOPO folds are trained without a validation set or early stopping.",
        "",
        "Use grouped LOPO as the primary generalisation metric when reporting model quality.",
        "",
    ]

    # Create the parent directory if missing so open() does not raise FileNotFoundError.
    os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"[LOG] Wrote {OUT_PATH}")


def main():
    """Run both evaluation protocols, write the report, and print an F1 summary."""
    pooled_scores = run_pooled_split()
    grouped_scores = run_grouped_lopo()
    write_report(pooled_scores, grouped_scores)

    summary = "[DONE] pooled_f1={:.4f} grouped_f1={:.4f}".format(
        pooled_scores["f1"], grouped_scores["f1"]
    )
    print(summary)


if __name__ == "__main__":
    main()