| | import json |
| | import sys |
| | import random |
| | from collections import defaultdict |
| |
|
def collect_dataset_info(file_path):
    """Collect each dataset's line numbers and the order datasets first appear in.

    A record's dataset name is the part of its ``custom_id`` before the first
    ``-``. Lines that are not valid JSON, are not JSON objects, lack a
    ``custom_id``, or whose ``custom_id`` is not a string are reported on
    stderr and skipped; whitespace-only lines are ignored silently.

    Args:
        file_path: Path to a UTF-8 JSONL file (one JSON object per line).

    Returns:
        A tuple ``(dataset_lines, order)``. ``dataset_lines`` is a
        ``defaultdict(list)`` mapping each dataset name to the 1-based line
        numbers of its records, in file order. ``order`` lists the dataset
        names in order of first appearance.
    """
    dataset_lines = defaultdict(list)
    order = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                continue  # blank lines carry no record

            try:
                data = json.loads(stripped)
            except json.JSONDecodeError:
                print(f"Error: Invalid JSON at line {line_num}", file=sys.stderr)
                continue

            # A valid JSON line may still be an array/number/string, which has
            # no 'custom_id' key; indexing it would raise TypeError.
            if not isinstance(data, dict) or 'custom_id' not in data:
                print(f"Error: Missing 'custom_id' at line {line_num}", file=sys.stderr)
                continue

            custom_id = data['custom_id']
            if not isinstance(custom_id, str):
                print(f"Error: Invalid custom_id format at line {line_num}", file=sys.stderr)
                continue

            dataset = custom_id.split('-', 1)[0]
            # Membership tests do not trigger defaultdict's factory, so this
            # detects the first occurrence without a separate "seen" set.
            if dataset not in dataset_lines:
                order.append(dataset)
            dataset_lines[dataset].append(line_num)

    return dataset_lines, order
| |
|
def _fail(message):
    """Print *message* to stderr and exit the process with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _largest_remainder_allocation(total, weights):
    """Split the integer *total* into shares proportional to *weights*.

    Uses the largest-remainder (Hamilton) method: each share starts at
    ``floor(total * w / sum(weights))`` and the leftover units go to the
    entries with the largest fractional parts, ties broken by position.
    All arithmetic is integer, so the result is exact; when
    ``0 <= total <= sum(weights)`` the shares sum to *total* and no share
    exceeds its weight.

    Args:
        total: Non-negative number of units to distribute.
        weights: Non-negative integer weights, one per recipient.

    Returns:
        A list of integer shares parallel to *weights* (all zeros when the
        weights sum to 0).
    """
    weight_sum = sum(weights)
    if weight_sum == 0:
        return [0] * len(weights)
    # divmod yields the exact floor and the exact fractional numerator,
    # avoiding float rounding that could overshoot a floor or misorder ties.
    parts = [divmod(total * w, weight_sum) for w in weights]
    shares = [floor for floor, _ in parts]
    leftover = total - sum(shares)
    by_remainder = sorted(range(len(parts)), key=lambda i: (-parts[i][1], i))
    for i in by_remainder[:leftover]:
        shares[i] += 1
    return shares


def _write_selected_lines(input_file, output_file, line_numbers):
    """Copy the lines with the given 1-based numbers from input to output, in file order."""
    wanted = set(line_numbers)
    kept = []
    with open(input_file, 'r', encoding='utf-8') as infile:
        for line_num, line in enumerate(infile, 1):
            if line_num in wanted:
                kept.append(line)
                if len(kept) == len(wanted):
                    break
    # The input is fully read and closed before the output is opened, so using
    # the same path for both cannot truncate the input mid-read.
    with open(output_file, 'w', encoding='utf-8') as outfile:
        outfile.writelines(kept)


def main():
    """Sample N records from a JSONL file, stratified by dataset.

    Usage: python sample_datasets.py <input.jsonl> <output.jsonl> <N>

    Each dataset reported by ``collect_dataset_info`` receives 5 records; the
    remaining N - 5*k records are split across datasets in proportion to how
    many records each has beyond those 5 (largest-remainder method). Records
    are drawn with ``random.sample`` and written in their original file order.
    Validation failures print to stderr and exit with status 1.
    """
    per_dataset_min = 5

    if len(sys.argv) != 4:
        _fail("Usage: python sample_datasets.py <input.jsonl> <output.jsonl> <N>")

    input_file, output_file, n_text = sys.argv[1:]
    try:
        target = int(n_text)
    except ValueError:
        _fail("Error: N must be an integer.")

    dataset_info, dataset_order = collect_dataset_info(input_file)
    k = len(dataset_info)
    if k == 0:
        _fail("Error: No datasets found in the input file.")

    for dataset in dataset_order:
        if len(dataset_info[dataset]) < per_dataset_min:
            _fail(f"Error: Dataset '{dataset}' has fewer than {per_dataset_min} samples.")

    total_samples = sum(len(lines) for lines in dataset_info.values())
    min_samples = per_dataset_min * k
    if not min_samples <= target <= total_samples:
        _fail(f"Error: N must be between {min_samples} and {total_samples}.")

    # Every dataset gets the guaranteed minimum; the rest is distributed in
    # proportion to each dataset's spare records. target <= total_samples
    # implies the rest never exceeds the total spare, so no dataset is
    # allocated more records than it has.
    spare = [len(dataset_info[d]) - per_dataset_min for d in dataset_order]
    extra = _largest_remainder_allocation(target - min_samples, spare)
    sample_counts = {d: per_dataset_min + e for d, e in zip(dataset_order, extra)}

    print("\nSampling Distribution:")
    for dataset in dataset_order:
        print(f" - {dataset}: {sample_counts[dataset]} samples")
    total_sampled = sum(sample_counts.values())
    print(f"Total samples: {total_sampled} (target: {target})")
    if total_sampled != target:  # defensive; exact allocation guarantees equality
        _fail(f"Error: Total sampled count mismatch ({total_sampled} vs {target})")

    selected_lines = []
    for dataset in dataset_order:
        selected_lines.extend(random.sample(dataset_info[dataset], sample_counts[dataset]))

    _write_selected_lines(input_file, output_file, selected_lines)
    print(f"\nSuccessfully sampled {target} records to {output_file}.")
| |
|
if __name__ == "__main__":
    # Run the sampler only when executed as a script, not when imported.
    main()