| | """
|
| | Divide a document collection into N-word/token passage spans (with wrap-around for last passage).
|
| | """
|
| |
|
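
# A minimal example invocation (hypothetical input path; the output file name is
# derived automatically inside main() below):
#
#   python docs2passages.py --input /path/to/collection.tsv --format 'docid,text' \
#       --nwords 100 --overlap 0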

import os
import random

from multiprocessing import Pool
from argparse import ArgumentParser
from colbert.utils.utils import print_message

# Supported input formats: the tab-separated fields expected on each input line.
Format1 = 'docid,text'
Format2 = 'docid,text,title'
Format3 = 'docid,url,title,text'
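
# For illustration, a hypothetical Format3 input line (tabs shown as <TAB>):
#
#   42<TAB>http://example.com/page<TAB>Example Title<TAB>The full document text ...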


def process_page(inp):
    """
        Wraps around if we split: make sure the last passage isn't too short.
        This is meant to be similar to the DPR preprocessing.
    """

    (nwords, overlap, tokenizer), (title_idx, docid, title, url, content) = inp

    if tokenizer is None:
        words = content.split()
    else:
        words = tokenizer.tokenize(content)

    # Duplicate the word list so that the final passage can wrap around to the
    # start of the document instead of ending up as a short tail.
    words_ = (words + words) if len(words) > nwords else words
    passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)]
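
    # A tiny worked example of the split above, assuming nwords=3 and overlap=0:
    # for words ['a', 'b', 'c', 'd'], words_ is ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd'],
    # the offsets are 0 and 3, and passages is [['a', 'b', 'c'], ['d', 'a', 'b']]:
    # the last passage wraps around instead of being just ['d'].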

    assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words))

    if tokenizer is None:
        passages = [' '.join(psg) for psg in passages]
    else:
        # Undo WordPiece tokenization by merging '##'-prefixed subword pieces.
        passages = [' '.join(psg).replace(' ##', '') for psg in passages]
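
    # For example, WordPiece tokens like ['em', '##bed', '##ding', 'layer'] join
    # to 'em ##bed ##ding layer' and collapse to 'embedding layer'.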

    # Periodically log one document's passages as a sanity check.
    if title_idx % 100000 == 0:
        print("#> ", title_idx, '\t\t\t', title)

        for p in passages:
            print("$$$ ", '\t\t', p)
            print()

        print()
        print()
        print()

    return (docid, title, url, passages)


def main(args):
    random.seed(12345)
    print_message("#> Starting...")

    # The output file name encodes the configuration: 'w' for whitespace words or
    # 't' for WordPiece tokens, followed by the passage length and the overlap.
    letter = 'w' if not args.use_wordpiece else 't'
    output_path = f'{args.input}.{letter}{args.nwords}_{args.overlap}'
    assert not os.path.exists(output_path), f"{output_path} already exists!"
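
    # For instance (hypothetical path): --input /data/collection.tsv with
    # --use-wordpiece, --nwords 100, and --overlap 0 writes /data/collection.tsv.t100_0.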

    RawCollection = []
    Collection = []

    NumIllFormattedLines = 0

    with open(args.input) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (100*1000) == 0:
                print(line_idx, end=' ')

            title, url = None, None

            try:
                line = line.strip().split('\t')

                if args.format == Format1:
                    docid, doc = line
                elif args.format == Format2:
                    docid, doc, title = line
                elif args.format == Format3:
                    docid, url, title, doc = line

                RawCollection.append((line_idx, docid, title, url, doc))
            except ValueError:
                # A line with the wrong number of tab-separated fields fails
                # the tuple unpacking above.
                NumIllFormattedLines += 1

                if NumIllFormattedLines % 1000 == 0:
                    print(f'\n[{line_idx}] NumIllFormattedLines = {NumIllFormattedLines}\n')

    print()
    print_message("# of documents is", len(RawCollection), '\n')

    p = Pool(args.nthreads)

    print_message("#> Starting parallel processing...")

    tokenizer = None
    if args.use_wordpiece:
        from transformers import BertTokenizerFast
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')

    # Every worker receives the same (nwords, overlap, tokenizer) parameters
    # alongside one raw document.
    process_page_params = [(args.nwords, args.overlap, tokenizer)] * len(RawCollection)
    Collection = p.map(process_page, zip(process_page_params, RawCollection))
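
    # Each passage is written as its own line under a fresh sequential id; for a
    # hypothetical Format2 input, an output line would look like:
    #
    #   17<TAB>some passage text ...<TAB>Example Title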
| | print_message(f"#> Writing to {output_path} ...")
|
| | with open(output_path, 'w') as f:
|
| | line_idx = 1
|
| |
|
| | if args.format == Format1:
|
| | f.write('\t'.join(['id', 'text']) + '\n')
|
| | elif args.format == Format2:
|
| | f.write('\t'.join(['id', 'text', 'title']) + '\n')
|
| | elif args.format == Format3:
|
| | f.write('\t'.join(['id', 'text', 'title', 'docid']) + '\n')
|
| |
|
| | for docid, title, url, passages in Collection:
|
| | for passage in passages:
|
| | if args.format == Format1:
|
| | f.write('\t'.join([str(line_idx), passage]) + '\n')
|
| | elif args.format == Format2:
|
| | f.write('\t'.join([str(line_idx), passage, title]) + '\n')
|
| | elif args.format == Format3:
|
| | f.write('\t'.join([str(line_idx), passage, title, docid]) + '\n')
|
| |
|
| | line_idx += 1
|
| |
|
| |
|


if __name__ == "__main__":
    parser = ArgumentParser(description="docs2passages.")

    # Input arguments.
    parser.add_argument('--input', dest='input', required=True)
    parser.add_argument('--format', dest='format', required=True, choices=[Format1, Format2, Format3])

    # Splitting arguments.
    parser.add_argument('--use-wordpiece', dest='use_wordpiece', default=False, action='store_true')
    parser.add_argument('--nwords', dest='nwords', default=100, type=int)
    parser.add_argument('--overlap', dest='overlap', default=0, type=int)

    # Other arguments.
    parser.add_argument('--nthreads', dest='nthreads', default=28, type=int)

    args = parser.parse_args()
    assert args.nwords in range(50, 500)

    main(args)