| | |
| | |
| |
|
| | import os |
| | import glob |
| | import json |
| | import re |
| | import matplotlib.pyplot as plt |
| |
|
| | |
# Directory scanned for valid_score_{in,ood}_<step>.json files; main() also
# writes the resulting valid_accuracy.png plot into this directory.
BASE_DIR = "/pfs/lichenyi/work/evaluation"
| |
|
def collect_accuracies(base_dir: str):
    """Collect validation accuracies from per-checkpoint score files.

    Scans ``base_dir`` (non-recursively) for files named exactly
    ``valid_score_in_<step>.json`` or ``valid_score_ood_<step>.json`` and
    reads ``summary.accuracy`` from each one.

    A file is skipped when:
    - its basename does not fully match that pattern;
    - its JSON top level is not an object;
    - its ``summary`` is missing or not an object;
    - its ``summary.accuracy`` is missing or null.

    Args:
        base_dir: Directory containing the score files. It may contain glob
            metacharacters; they are escaped before globbing.

    Returns:
        A tuple ``(in_acc, ood_acc)`` of dicts. Each maps the integer step
        parsed from the filename to the accuracy value read from the file,
        for the in-domain and out-of-domain splits respectively.
    """
    # Escape base_dir so characters like "[" in a run name are taken
    # literally instead of being interpreted as a glob character class.
    pattern = os.path.join(glob.escape(base_dir), "valid_score_*.json")
    files = glob.glob(pattern)

    in_acc = {}
    ood_acc = {}

    # Used with fullmatch(): match() only anchors at the start and would also
    # accept names such as "valid_score_in_10.json.json".
    regex = re.compile(r"valid_score_(in|ood)_(\d+)\.json")

    for path in sorted(files):
        m = regex.fullmatch(os.path.basename(path))
        if not m:
            continue

        split = m.group(1)
        step = int(m.group(2))

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Guard against `"summary": null` or a non-object top level; both
        # would otherwise raise AttributeError on the chained `.get` calls.
        summary = data.get("summary") if isinstance(data, dict) else None
        if not isinstance(summary, dict):
            continue

        acc = summary.get("accuracy")
        if acc is None:
            continue

        if split == "in":
            in_acc[step] = acc
        else:
            ood_acc[step] = acc

    return in_acc, ood_acc
| |
|
| |
|
def plot_accuracies(in_acc, ood_acc, out_path="valid_accuracy.png"):
    """Plot in-domain vs. out-of-domain accuracy per step and save the image.

    Args:
        in_acc: dict[int, float] mapping step -> in-domain accuracy; may be
            empty, in which case that series is not drawn.
        ood_acc: dict[int, float] mapping step -> out-of-domain accuracy; may
            be empty, in which case that series is not drawn.
        out_path: Destination path for the saved figure (written at 300 dpi).

    Each series is drawn in ascending step order. The figure is always closed
    before returning, including when saving fails.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        if in_acc:
            steps_in = sorted(in_acc)
            vals_in = [in_acc[s] for s in steps_in]
            ax.plot(steps_in, vals_in, marker="o", label="in (ID)")

        if ood_acc:
            steps_ood = sorted(ood_acc)
            vals_ood = [ood_acc[s] for s in steps_ood]
            ax.plot(steps_ood, vals_ood, marker="s", linestyle="--", label="ood (OOD)")

        ax.set_xlabel("checkpoint / step")
        ax.set_ylabel("accuracy")
        ax.set_title("Validation Accuracy (in vs ood)")
        ax.grid(True, linestyle=":")
        # legend() with no labelled artists only emits a warning; skip it
        # when neither series was drawn.
        if in_acc or ood_acc:
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # without this, repeated calls accumulate open figures.
        plt.close(fig)
| | |
| | |
| |
|
| |
|
def main():
    """Script entry point.

    Reads the in/ood accuracies found under ``BASE_DIR`` and prints both
    mappings. It then saves the comparison plot as ``valid_accuracy.png``
    inside ``BASE_DIR``.
    """
    in_acc, ood_acc = collect_accuracies(BASE_DIR)

    report = (
        ("in-domain checkpoints and accuracies:", in_acc),
        ("ood checkpoints and accuracies:", ood_acc),
    )
    for heading, accuracies in report:
        print(heading, accuracies)

    figure_file = os.path.join(BASE_DIR, "valid_accuracy.png")
    plot_accuracies(in_acc, ood_acc, out_path=figure_file)
| |
|
| |
|
# Run the collection + plotting only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|