Quentin Mace committed
Commit 0f22e6b · 1 Parent(s): 15bd321

initial pipeline

Files changed (1)
  1. data/pipeline_handler.py +232 -0
data/pipeline_handler.py ADDED
@@ -0,0 +1,232 @@
+import os
+import requests
+from typing import Dict, List, Optional
+
+import pandas as pd
+
+
+class PipelineHandler:
+    """Handler for ViDoRe v3 pipeline evaluation results from GitHub."""
+
+    def __init__(self):
+        self.pipeline_infos = {}
+        self.github_base_url = "https://raw.githubusercontent.com/illuin-tech/vidore-benchmark/vidore_v3_pipeline/results"
+        self.available_datasets = []
+        self.available_languages = ["overall"]  # Default languages available
+
+        # Setup GitHub authentication if token is available
+        self.github_token = os.environ.get("GITHUB_TOKEN")
+        self.headers = {}
+        if self.github_token:
+            self.headers["Authorization"] = f"token {self.github_token}"
+            print("GitHub token detected - using authenticated requests")
+
+    def get_pipeline_folders_from_github(self) -> List[str]:
+        """Get list of pipeline folders from GitHub API."""
+        api_url = "https://api.github.com/repos/illuin-tech/vidore-benchmark/contents/results?ref=vidore_v3_pipeline"
+
+        try:
+            response = requests.get(api_url, headers=self.headers)
+            response.raise_for_status()
+            contents = response.json()
+
+            # Filter for directories only
+            folders = [item["name"] for item in contents if item["type"] == "dir"]
+            return sorted(folders)
+        except Exception as e:
+            print(f"Error fetching pipeline folders from GitHub: {e}")
+            return []
+
+    def get_dataset_files_from_github(self, pipeline_name: str) -> List[str]:
+        """Get list of dataset JSON files for a specific pipeline from GitHub API."""
+        api_url = f"https://api.github.com/repos/illuin-tech/vidore-benchmark/contents/results/{pipeline_name}?ref=vidore_v3_pipeline"
+
+        try:
+            response = requests.get(api_url, headers=self.headers)
+            response.raise_for_status()
+            contents = response.json()
+
+            # Filter for JSON files that start with 'vidore_v3'
+            files = [
+                item["name"]
+                for item in contents
+                if item["type"] == "file"
+                and item["name"].startswith("vidore_v3")
+                and item["name"].endswith(".json")
+            ]
+            return sorted(files)
+        except Exception as e:
+            print(f"Error fetching dataset files from {pipeline_name}: {e}")
+            return []
+
+    def fetch_json_from_github(self, pipeline_name: str, filename: str) -> Optional[Dict]:
+        """Fetch a JSON file from GitHub raw content."""
+        url = f"{self.github_base_url}/{pipeline_name}/{filename}"
+
+        try:
+            response = requests.get(url, headers=self.headers)
+            response.raise_for_status()
+            return response.json()
+        except Exception as e:
+            print(f"Error fetching {filename} from {pipeline_name}: {e}")
+            return None
+
+    def get_pipeline_data(self):
+        """Fetch all pipeline data from GitHub."""
+        pipeline_folders = self.get_pipeline_folders_from_github()
+        datasets_set = set()
+        languages_set = set(["overall"])
+
+        for pipeline_name in pipeline_folders:
+            # Get all dataset files for this pipeline
+            dataset_files = self.get_dataset_files_from_github(pipeline_name)
+
+            if not dataset_files:
+                continue
+
+            pipeline_data = {}
+            for filename in dataset_files:
+                results = self.fetch_json_from_github(pipeline_name, filename)
+                if results:
+                    # Extract dataset name from filename (remove vidore_v3_ prefix and .json suffix)
+                    dataset_name = filename.replace("vidore_v3_", "").replace(".json", "")
+                    datasets_set.add(dataset_name)
+                    pipeline_data[dataset_name] = results
+
+                    # Collect available languages
+                    if "aggregated_metrics" in results and "by_language" in results["aggregated_metrics"]:
+                        languages_set.update(results["aggregated_metrics"]["by_language"].keys())
+
+            if pipeline_data:
+                self.pipeline_infos[pipeline_name] = pipeline_data
+
+        self.available_datasets = sorted(list(datasets_set))
+        self.available_languages = sorted(list(languages_set))
+
+    def calculate_cost_metric(self, pipeline_datasets: Dict) -> float:
+        """
+        Calculate a compute cost metric based on retrieval time across all datasets.
+        Returns cost in arbitrary units (could be refined based on actual compute costs).
+        """
+        total_time_s = 0
+
+        for dataset_name, dataset_data in pipeline_datasets.items():
+            if "aggregated_metrics" not in dataset_data:
+                continue
+
+            timing = dataset_data["aggregated_metrics"].get("timing", {})
+            total_time_ms = timing.get("total_retrieval_time_milliseconds", 0)
+            total_time_s += total_time_ms / 1000.0
+
+        # Simple cost model: assume $0.01 per second of compute (adjustable)
+        cost = total_time_s * 0.01
+
+        return round(cost, 4)
+
+    def extract_dataset_metrics(
+        self, pipeline_datasets: Dict, metric: str = "ndcg_cut_5", language: str = "overall"
+    ) -> Dict[str, float]:
+        """
+        Extract metrics for individual datasets from the aggregated results.
+
+        Args:
+            pipeline_datasets: Dictionary mapping dataset names to their data
+            metric: The metric to extract (e.g., 'ndcg_at_5')
+            language: The language to filter by ('overall' for all languages, or specific language)
+
+        Returns:
+            Dictionary mapping dataset names to metric values
+        """
+        # Map metric names from UI format to API format
+        metric_mapping = {
+            "ndcg_at_1": "ndcg_cut_5",  # Using cut_5 as approximation
+            "ndcg_at_5": "ndcg_cut_5",
+            "ndcg_at_10": "ndcg_cut_10",
+            "ndcg_at_100": "ndcg_cut_100",
+            "recall_at_1": "recall_5",
+            "recall_at_5": "recall_5",
+            "recall_at_10": "recall_10",
+            "recall_at_100": "recall_100",
+        }
+
+        actual_metric = metric_mapping.get(metric, metric)
+        dataset_metrics = {}
+
+        for dataset_name, dataset_data in pipeline_datasets.items():
+            if "aggregated_metrics" not in dataset_data:
+                continue
+
+            aggregated = dataset_data["aggregated_metrics"]
+
+            # Get metrics for the specified language
+            if language == "overall":
+                metrics_data = aggregated.get("overall", {})
+            else:
+                metrics_data = aggregated.get("by_language", {}).get(language, {})
+
+            if metrics_data:
+                # Format dataset name for display
+                display_name = dataset_name.replace("_", " ").title()
+                dataset_metrics[display_name] = metrics_data.get(actual_metric, 0.0)
+
+        return dataset_metrics
+
+    def render_df(self, metric: str = "ndcg_at_5", language: str = "overall") -> pd.DataFrame:
+        """
+        Render a DataFrame with pipeline results.
+
+        Args:
+            metric: The metric to display (e.g., 'ndcg_at_5')
+            language: The language to filter by ('overall' for all languages, or specific language)
+
+        Returns:
+            DataFrame with columns: Pipeline Name, Compute Cost, Timing metrics, Dataset metrics
+        """
+        pipeline_res = {}
+
+        for pipeline_name, pipeline_datasets in self.pipeline_infos.items():
+            row_data = {}
+
+            # Aggregate time metrics across all datasets
+            total_time_ms = 0
+            total_queries = 0
+
+            for dataset_name, dataset_data in pipeline_datasets.items():
+                if "aggregated_metrics" in dataset_data:
+                    timing = dataset_data["aggregated_metrics"].get("timing", {})
+                    total_time_ms += timing.get("total_retrieval_time_milliseconds", 0)
+                    total_queries += timing.get("num_queries", 0)
+
+            if total_queries > 0:
+                if total_time_ms > 0:
+                    row_data["Queries per Second"] = round(
+                        total_queries / (total_time_ms / 1000.0), 2
+                    )
+                else:
+                    row_data["Queries per Second"] = 0
+            else:
+                row_data["Queries per Second"] = -1
+
+            # Add dataset metrics
+            dataset_metrics = self.extract_dataset_metrics(pipeline_datasets, metric, language)
+            row_data.update(dataset_metrics)
+
+            # Calculate average across datasets if there are multiple
+            if dataset_metrics:
+                row_data["Average"] = round(sum(dataset_metrics.values()) / len(dataset_metrics), 4)
+
+            pipeline_res[pipeline_name] = row_data
+
+        if pipeline_res:
+            df = pd.DataFrame(pipeline_res).T
+            # Reorder columns to have Average right after timing metrics
+            cols = list(df.columns)
+            if "Average" in cols:
+                cols.remove("Average")
+                # Insert Average after Queries per Second
+                insert_pos = cols.index("Queries per Second") + 1 if "Queries per Second" in cols else 2
+                cols.insert(insert_pos, "Average")
+            df = df[cols]
+            return df
+
+        return pd.DataFrame()
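
For context, here is a minimal usage sketch of the handler; it is not part of the commit. It assumes the repository root is on sys.path (so the module imports as data.pipeline_handler) and that GITHUB_TOKEN is optional but helps avoid anonymous GitHub API rate limits.

# Hypothetical driver script, not part of this commit.
# Assumes the repo root is on sys.path so `data.pipeline_handler` is importable.
from data.pipeline_handler import PipelineHandler

handler = PipelineHandler()        # picks up GITHUB_TOKEN from the environment if set
handler.get_pipeline_data()        # fetches results/<pipeline>/vidore_v3_*.json from GitHub

print(handler.available_datasets)  # dataset names parsed from the JSON filenames
print(handler.available_languages) # "overall" plus any languages found under by_language

# One row per pipeline: Queries per Second, Average, then one column per dataset.
df = handler.render_df(metric="ndcg_at_5", language="overall")
print(df)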