File size: 4,138 Bytes
cb20efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# std lib
import os
from pathlib import Path

# 3rd party imports
import pandas as pd

# 3rd party imports (Hugging Face Hub client)
from huggingface_hub import snapshot_download

# Local folder for the GAIA dataset snapshot: a "data" directory next to this module.
DATA_DIR = Path(__file__).resolve().parent.joinpath("data")


def get_full_gaia_level1_data():
    """
    Download a snapshot of the ``gaia-benchmark/GAIA`` dataset repo into DATA_DIR.

    DATA_DIR is created first if needed. No allow/ignore patterns are passed to
    ``snapshot_download``, so the whole repository snapshot is requested, not
    only the level 1 files.

    An access token is taken from ``HF_FINAL_ASSIGNMENT_DRAFT`` first, then from
    ``HF_TOKEN``. When neither variable is set (or both are empty), no
    ``token`` argument is passed at all.
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    download_options = dict(
        repo_id="gaia-benchmark/GAIA",
        repo_type="dataset",
        local_dir=DATA_DIR,
    )

    # Only forward a token when one is actually configured.
    hf_token = os.getenv("HF_FINAL_ASSIGNMENT_DRAFT") or os.getenv("HF_TOKEN")
    if hf_token:
        download_options["token"] = hf_token

    snapshot_download(**download_options)


def get_file_from_gaia_level1_data(task_id: str, validation_dir=None):
    """
    Return the local path of the attachment file for a GAIA level 1 task.

    The search directory is walked recursively. A file matches when its name
    equals ``task_id`` exactly or starts with ``task_id`` followed by a dot
    (i.e. ``<task_id>.<extension>``).

    Args:
        task_id: GAIA task identifier.
        validation_dir: Directory to search (str or Path). Defaults to
            ``DATA_DIR / "2023" / "validation"``.

    Returns:
        The path (as ``str``) of the first matching file, or ``None`` when
        ``task_id`` is empty or no file matches.
    """
    # An empty ID would otherwise match arbitrary files (e.g. dotfiles).
    if not task_id:
        return None

    if validation_dir is None:
        validation_dir = DATA_DIR / "2023" / "validation"

    # Compare against the file's base name rather than any substring, so a
    # partial ID cannot match an unrelated file such as the metadata parquet.
    prefix = f"{task_id}."
    for root, _, files in os.walk(validation_dir):
        for name in files:
            if name == task_id or name.startswith(prefix):
                return os.path.join(root, name)
    return None


def get_question(task_id: str) -> str:
    """
    Return the question text for a GAIA level 1 task.

    The question is looked up in
    ``DATA_DIR / "2023" / "validation" / "metadata.level1.parquet"``.

    Args:
        task_id: GAIA task identifier.

    Returns:
        The ``Question`` value of the first row whose ``task_id`` matches.
        Returns ``""`` (after printing a message) when the metadata file is
        missing, no row matches, or the matched question is null.
    """
    metadata_file = DATA_DIR / "2023" / "validation" / "metadata.level1.parquet"
    if not metadata_file.exists():
        print(f"Metadata file not found: {metadata_file}")
        return ""

    # Only the lookup key and the answer column are needed.
    metadata_df = pd.read_parquet(metadata_file, columns=["task_id", "Question"])
    matches = metadata_df.loc[metadata_df["task_id"] == task_id, "Question"]
    if matches.empty:
        # Previously `.values[0]` raised IndexError here; report it like the
        # missing-file case instead.
        print(f"Task ID not found in metadata: {task_id}")
        return ""

    question = matches.iloc[0]
    return "" if pd.isna(question) else str(question)


def _validation_download_reason(validation_dir: Path, metadata_file: Path):
    """
    Return why the validation data must be downloaded, or None if it looks complete.
    """
    # Fast fail: missing (or non-directory) or empty validation directory.
    if not validation_dir.is_dir() or not any(validation_dir.iterdir()):
        return f"Validation data not found in {validation_dir}"

    # Metadata is required to know which attachment files are expected.
    if not metadata_file.exists():
        return (
            f"Metadata file not found: {metadata_file}. Cannot verify expected files"
        )

    try:
        # Read only the needed column for speed/memory.
        metadata_df = pd.read_parquet(metadata_file, columns=["file_name"])
        # Skip null/blank names: rows without an attachment must not count
        # as "missing", otherwise every call would trigger a re-download.
        expected_files = {
            str(name)
            for name in metadata_df["file_name"].dropna().unique()
            if str(name).strip()
        }
        present_files = {p.name for p in validation_dir.iterdir() if p.is_file()}
    except Exception as e:
        # Deliberately broad: any failure here just means "cannot verify",
        # which is handled by downloading again.
        return (
            f"Error reading metadata ({metadata_file}): {e}. "
            "Cannot verify expected files"
        )

    missing_files = expected_files - present_files
    if missing_files:
        return f"Missing {len(missing_files)} expected validation files"
    return None


def ensure_validation_data(base_dir: Path):
    """
    Ensure GAIA 2023 level1 validation files are present, downloading if needed.

    Checks ``base_dir / "data" / "2023" / "validation"`` against the file names
    listed in ``metadata.level1.parquet``. If the check fails, calls
    ``get_full_gaia_level1_data()``.

    Args:
        base_dir: Project directory containing ``data/`` (str or Path).

    Returns:
        ``(True, None)`` when the files were present or the download call
        succeeded; ``(False, error_message)`` when the download raised.
    """
    base_dir = Path(base_dir)
    validation_dir = base_dir / "data" / "2023" / "validation"
    metadata_file = validation_dir / "metadata.level1.parquet"

    reason = _validation_download_reason(validation_dir, metadata_file)
    if reason is None:
        print("All expected validation files are present. Skipping data download.")
        return True, None

    print(f"{reason}. Downloading full GAIA level 1 data...")
    try:
        get_full_gaia_level1_data()
    except Exception as e:
        error_message = f"Error downloading GAIA level 1 data: {e}"
        print(error_message)
        return False, error_message
    print("Data download completed.")

    # The download writes to the module-level DATA_DIR, which need not equal
    # base_dir / "data"; surface a still-incomplete directory instead of
    # silently reporting success.
    remaining = _validation_download_reason(validation_dir, metadata_file)
    if remaining is not None:
        print(f"Warning: validation data still incomplete after download: {remaining}")

    return True, None


if __name__ == "__main__":
    # Manual smoke test: print the local attachment path for one validation task.
    # The other helpers can be tried the same way, e.g. get_full_gaia_level1_data()
    # to fetch the dataset, or get_question(sample_task_id) to print its question.
    sample_task_id = "cca530fc-4052-43b2-b130-b30968d8aa44"
    sample_path = get_file_from_gaia_level1_data(sample_task_id)
    print(sample_path)