| | import os
|
| | import tqdm
|
| | import torch
|
| | import datetime
|
| | import itertools
|
| |
|
| | from multiprocessing import Pool
|
| | from collections import OrderedDict, defaultdict
|
| |
|
| |
|
def print_message(*s, condition=True):
    """Format *s* as a single timestamped line; print it when `condition`
    is truthy. Always returns the formatted message string."""
    text = ' '.join(str(piece) for piece in s)
    stamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    msg = "[{}] {}".format(stamp, text)

    if condition:
        print(msg, flush=True)

    return msg
|
| |
|
| |
|
def timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD_HH.MM.SS'."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
|
| |
|
| |
|
def file_tqdm(file):
    """Yield lines from an already-open text file while displaying a
    MiB-scale progress bar (total taken from the file's on-disk size).

    NOTE(review): progress advances by len(line) characters, which only
    approximates bytes for multi-byte encodings -- confirm if exactness
    matters.
    """
    print(f"#> Reading {file.name}")

    with tqdm.tqdm(total=os.path.getsize(file.name) / 1024.0 / 1024.0, unit="MiB") as pbar:
        for line in file:
            yield line
            pbar.update(len(line) / 1024.0 / 1024.0)
        # Fix: removed the redundant pbar.close() here -- the `with`
        # statement already closes the bar on exit.
|
| |
|
| |
|
def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer, arguments=None):
    """Serialize the model/optimizer state plus the training position
    (epoch and minibatch indices) to `path` via torch.save."""
    print(f"#> Saving a checkpoint to {path} ..")

    # Unwrap a (Distributed)DataParallel wrapper so keys are un-prefixed.
    if hasattr(model, 'module'):
        model = model.module

    checkpoint = {
        'epoch': epoch_idx,
        'batch': mb_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'arguments': arguments,
    }

    torch.save(checkpoint, path)
|
| |
|
| |
|
def load_checkpoint(path, model, optimizer=None, do_print=True):
    """Load a checkpoint (local path or http(s) URL) into `model`.

    Strips any 'module.' prefix (left by DataParallel) from parameter
    names, optionally restores the optimizer state, and returns the raw
    checkpoint dict. Falls back to strict=False loading when the state
    dict does not match exactly.
    """
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    if path.startswith("http:") or path.startswith("https:"):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        checkpoint = torch.load(path, map_location='cpu')

    # Normalize keys that were saved from a DataParallel-wrapped model.
    state_dict = checkpoint['model_state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v

    checkpoint['model_state_dict'] = new_state_dict

    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except RuntimeError:
        # Fix: was a bare `except:` (also caught KeyboardInterrupt/SystemExit).
        # load_state_dict raises RuntimeError on key/shape mismatch.
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        print_message("#> checkpoint['epoch'] =", checkpoint['epoch'])
        print_message("#> checkpoint['batch'] =", checkpoint['batch'])

    return checkpoint
|
| |
|
| |
|
def create_directory(path):
    """Create `path` (including parents) unless it already exists; log
    which of the two happened."""
    already_there = os.path.exists(path)

    print('\n')
    if already_there:
        print_message("#> Note: Output directory", path, 'already exists\n\n')
    else:
        print_message("#> Creating directory", path, '\n\n')
        os.makedirs(path)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def f7(seq):
    """Deduplicate `seq`, preserving the order of first occurrences.

    Source: https://stackoverflow.com/a/480227/1493011
    """
    seen = set()
    uniques = []
    for item in seq:
        if item not in seen:
            seen.add(item)
            uniques.append(item)
    return uniques
|
| |
|
| |
|
def batch(group, bsize, provide_offset=False):
    """Yield consecutive slices of `group`, each of length up to `bsize`.

    When `provide_offset` is True, yield (start_offset, slice) pairs
    instead of bare slices.
    """
    start, total = 0, len(group)
    while start < total:
        chunk = group[start: start + bsize]
        yield (start, chunk) if provide_offset else chunk
        start += len(chunk)
|
| |
|
| |
|
class dotdict(dict):
    """
    dot.notation access to dictionary attributes.
    Credit: derek73 @ https://stackoverflow.com/questions/2352181

    Note: missing attributes raise KeyError (not AttributeError), since
    attribute access delegates straight to item access.
    """

    def __getattr__(self, name):
        return dict.__getitem__(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
|
| |
|
| |
|
def flatten(L):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return list(itertools.chain.from_iterable(L))
|
| |
|
| |
|
def zipstar(L, lazy=False):
    """
    A much faster A, B, C = zip(*[(a, b, c), (a, b, c), ...])
    May return lists or tuples.
    """
    if len(L) == 0:
        return L

    width = len(L[0])

    # For narrow rows, transposing by index beats zip(*...) overhead.
    if width < 100:
        return [[row[col] for row in L] for col in range(width)]

    columns = zip(*L)
    return columns if lazy else list(columns)
|
| |
|
| |
|
def zip_first(L1, L2):
    """Zip L1 with L2, asserting no truncation relative to L1.

    When L1 is a list/tuple, guards against zip() silently dropping
    elements because L2 is shorter. Non-sequence L1 (e.g. an
    itertools.cycle) is not length-checked.

    Fix: replaced `type(L1) in [tuple, list]` with isinstance, which is
    idiomatic and also covers list/tuple subclasses.
    """
    length = len(L1) if isinstance(L1, (tuple, list)) else None

    L3 = list(zip(L1, L2))

    assert length in [None, len(L3)], "zip_first() failure: length differs!"

    return L3
|
| |
|
| |
|
def int_or_float(val):
    """Parse a numeric string: float when it contains a '.', else int."""
    return float(val) if '.' in val else int(val)
|
| |
|
def load_ranking(path, types=None, lazy=False):
    """Load ranked lists from `path`.

    First tries to interpret the file as a torch-saved list of tensors
    and transposes it with zipstar; on failure, falls back to parsing
    the file as TSV, converting each column with the corresponding
    callable from `types` (default: int_or_float for every column).
    """
    print_message(f"#> Loading the ranked lists from {path} ..")

    try:
        lists = torch.load(path)
        lists = zipstar([l.tolist() for l in tqdm.tqdm(lists)], lazy=lazy)
    except Exception:
        # Fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit during large loads.
        if types is None:
            types = itertools.cycle([int_or_float])

        with open(path) as f:
            lists = [[typ(x) for typ, x in zip_first(types, line.strip().split('\t'))]
                     for line in file_tqdm(f)]

    return lists
|
| |
|
| |
|
def save_ranking(ranking, path):
    """Transpose `ranking` into per-column tensors, torch.save them to
    `path`, and return the tensor list."""
    columns = [torch.tensor(column) for column in zipstar(ranking)]
    torch.save(columns, path)
    return columns
|
| |
|
| |
|
def groupby_first_item(lst):
    """Group items by their first element.

    Returns a defaultdict {first: [rest, ...]}, where a single trailing
    element is unwrapped from its list.
    """
    groups = defaultdict(list)

    for key, *remainder in lst:
        value = remainder[0] if len(remainder) == 1 else remainder
        groups[key].append(value)

    return groups
|
| |
|
| |
|
def process_grouped_by_first_item(lst):
    """
    Requires items in list to already be grouped by first item.

    Generator: yields (key, items) for each group as soon as the next
    group's first item is encountered.

    NOTE(review): the final group is accumulated in `groups` but never
    yielded -- the trailing `return groups` only sets the generator's
    StopIteration.value, which ordinary `for` loops discard. Callers may
    rely on this off-by-one; confirm before changing.
    """

    groups = defaultdict(list)

    started = False
    last_group = None

    for first, *rest in lst:
        # Unwrap single-element remainders, mirroring groupby_first_item.
        rest = rest[0] if len(rest) == 1 else rest

        if started and first != last_group:
            # Previous group just ended: emit it, then enforce the
            # "already grouped" precondition (a key must not reappear).
            yield (last_group, groups[last_group])
            assert first not in groups, f"{first} seen earlier --- violates precondition."

        groups[first].append(rest)

        last_group = first
        started = True

    return groups
|
| |
|
| |
|
def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks
    Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
    Source: https://docs.python.org/3/library/itertools.html#itertools-recipes
    """
    # One shared iterator repeated n times: zip_longest pulls n items per round.
    iterators = (iter(iterable),) * n
    return itertools.zip_longest(*iterators, fillvalue=fillvalue)
|
| |
|
| |
|
| |
|
class NullContextManager(object):
    """A do-nothing context manager that hands back `dummy_resource`
    on entry and suppresses nothing on exit."""

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        return self.dummy_resource

    def __exit__(self, *args):
        pass
|
| |
|
| |
|
def load_batch_backgrounds(args, qids):
    """For each qid, join its background passages with ' [SEP] '.

    Integer background ids are indexed into args.collection; any other
    id type is looked up in args.collectionX (missing keys yield '').
    Returns None when args.qid2backgrounds is None.
    """
    if args.qid2backgrounds is None:
        return None

    qbackgrounds = []

    for qid in qids:
        background_ids = args.qid2backgrounds[qid]

        if len(background_ids) and type(background_ids[0]) == int:
            passages = [args.collection[pid] for pid in background_ids]
        else:
            passages = [args.collectionX.get(pid, '') for pid in background_ids]

        qbackgrounds.append(' [SEP] '.join(passages))

    return qbackgrounds
|
| |
|