| | """
|
| | Divide a query set into two.
|
| | """
|
| |
|
| | import os
|
| | import math
|
| | import ujson
|
| | import random
|
| |
|
| | from argparse import ArgumentParser
|
| | from collections import OrderedDict
|
| | from colbert.utils.utils import print_message
|
| |
|
| |
|
def main(args):
    """
    Split a tab-separated query file into two disjoint, deterministic splits.

    Reads `qid<TAB>query` lines from `args.input` and writes two files next to it:
    `<args.input>.a` and `<args.input>.b`. The candidate split sizes are
    `len(queries) - args.holdout` and `args.holdout`; the larger of the two always
    goes to `.a` and the smaller to `.b`. Both outputs keep the queries in their
    original input order.

    The global `random` generator is re-seeded with a fixed value, so the same
    input and holdout always produce the same split.

    Args:
        args: Namespace with `input` (path of the query file) and `holdout` (int).

    Raises:
        ValueError: If a non-blank line does not contain exactly two tab-separated
            fields, if a qid appears twice, or if either split would be empty.
        FileExistsError: If either output file already exists.
    """
    random.seed(12345)

    # Load the queries, preserving their order in the input file.
    queries = OrderedDict()

    print_message(f"#> Loading queries from {args.input}..")
    with open(args.input) as f:
        for line_number, raw_line in enumerate(f, start=1):
            stripped = raw_line.strip()

            # Tolerate blank lines (e.g. a trailing empty line) instead of crashing.
            if not stripped:
                continue

            fields = stripped.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"{args.input}:{line_number}: expected 'qid<TAB>query', "
                    f"found {len(fields)} tab-separated field(s)"
                )

            qid, query = fields

            # Explicit raise rather than assert: asserts vanish under `python -O`,
            # which would let a duplicate silently overwrite the earlier query.
            if qid in queries:
                raise ValueError(f"{args.input}:{line_number}: duplicate qid {qid!r}")

            queries[qid] = query

    # Apply the splitting: the larger split is always 'a', the smaller always 'b'.
    size_a = len(queries) - args.holdout
    size_b = args.holdout
    size_a, size_b = max(size_a, size_b), min(size_a, size_b)

    if size_a <= 0 or size_b <= 0:
        raise ValueError(
            f"Cannot split {len(queries)} queries into non-empty splits of sizes "
            f"({size_a}, {size_b})"
        )

    print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.")

    keys = list(queries)

    # Same sampling call as before, so the chosen indices are unchanged for a given input.
    held_out_indices = set(random.sample(range(len(keys)), size_b))

    # Walking `keys` in order yields both splits already sorted by input position.
    sample_a = [key for idx, key in enumerate(keys) if idx not in held_out_indices]
    sample_b = [key for idx, key in enumerate(keys) if idx in held_out_indices]

    # Write the output.
    output_path_a = f'{args.input}.a'
    output_path_b = f'{args.input}.b'

    # Check both paths before writing anything, so a failure leaves no partial output.
    for output_path in (output_path_a, output_path_b):
        if os.path.exists(output_path):
            raise FileExistsError(output_path)

    print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...")

    for output_path, sample in [(output_path_a, sample_a), (output_path_b, sample_b)]:
        with open(output_path, 'w') as f:
            f.writelines('\t'.join([qid, queries[qid]]) + '\n' for qid in sample)
|
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse the two required options and run the split.
    cli = ArgumentParser(description="queries_split.")

    # argparse derives the attribute names `input` and `holdout` from the option strings.
    cli.add_argument('--input', required=True)
    cli.add_argument('--holdout', type=int, required=True)

    main(cli.parse_args())
|
| |
|