| |
| ''' |
| @File : inference_cogview.py |
| @Time : 2021/10/09 19:41:58 |
| @Author : Ming Ding |
| @Contact : dm18@mails.tsinghua.edu.cn |
| ''' |
|
|
| |
| import os |
| import sys |
| import math |
| import random |
| import torch |
| import argparse |
| import stat |
|
|
| from SwissArmyTransformer import mpu, get_args, get_tokenizer |
| from SwissArmyTransformer.model import CachedAutoregressiveModel |
| from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy |
| from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence |
| from SwissArmyTransformer.generation.utils import timed_name, generate_continually |
| from SwissArmyTransformer.training import set_random_seed |
|
|
| import json |
|
|
def main(args):
    """Run interactive/continual text generation with a cached autoregressive model.

    Loads a pretrained ``CachedAutoregressiveModel`` from a hard-coded checkpoint
    directory, builds a ``BaseStrategy`` sampler, and then consumes prompts from
    ``args.input_source`` via ``generate_continually``, appending every generated
    answer to ``<output_path>/test_eval.txt``.

    Args:
        args: fully-populated SwissArmyTransformer namespace; fields read here
            include fp16, device, seed, temperature, top_k, with_id,
            max_sequence_length, batch_size, max_inference_batch_size,
            output_path and input_source.
    """
    # NOTE(review): hard-coded checkpoint location — consider promoting to a CLI arg.
    model_path = '/path/to/checkpoints/'
    # 2022/06/17: switched from load_checkpoint to from_pretrained.
    model, args = CachedAutoregressiveModel.from_pretrained(args, model_path)

    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    set_random_seed(args.seed)
    model.eval()

    tokenizer = get_tokenizer(args)

    end_tokens = [tokenizer.get_command('eos').Id]
    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,
                            end_tokens=end_tokens)

    def process(raw_text):
        # Per-line worker invoked by generate_continually for each input record.
        query_id = None
        if args.with_id:
            query_id, raw_text = raw_text.split('\t')
        # Each input line is a JSON object; the prompt is its "question" field.
        raw_text = json.loads(raw_text)
        question = raw_text["question"] + "答:"
        raw_text = question
        seq = tokenizer._encode(raw_text)
        if len(seq) != 0 and seq[0] == 20005:
            # Drop the tokenizer's implicit leading token (id 20005) so that
            # the explicit ENC token below is the true sequence start.
            seq = seq[1:]
        seq = [tokenizer.get_command('ENC').Id] + seq
        # BUGFIX: validate length BEFORE padding. The old code padded first and
        # relied on the negative-multiplier no-op to keep over-long sequences
        # unpadded; checking up front is equivalent but explicit.
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
        # -1 marks positions still to be filled by the sampler.
        seq += [-1] * (args.max_sequence_length - len(seq))

        # BUGFIX: replace the deprecated legacy constructor
        # torch.cuda.LongTensor(seq, device=...) with the supported factory.
        seq = torch.tensor(seq, dtype=torch.long, device=args.device)
        mbz = args.max_inference_batch_size
        assert args.batch_size < mbz or args.batch_size % mbz == 0
        output_list = []
        # Generate in micro-batches of at most max_inference_batch_size.
        for _ in range(max(args.batch_size // mbz, 1)):
            output = filling_sequence(model, seq.clone(),
                                      batch_size=min(args.batch_size, mbz),
                                      strategy=strategy,
                                      log_attention_weights=None
                                      )[0]
            if isinstance(output, torch.Tensor):
                output = list(output)
            output_list.extend(output)

        eos_id = tokenizer.get_command('eos').Id
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            # 'unfinished' = index of the first unfilled slot (or end of seq).
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if output[unfinished - 1] in end_tokens:
                unfinished -= 1
            # BUGFIX: the old code unconditionally called output.index(eos),
            # which raised an uncaught ValueError whenever no eos token was
            # generated (its preceding plain-truncation assignment was dead
            # code, always overwritten). Keep the eos-splicing behavior when
            # eos is present, and fall back to plain truncation otherwise.
            try:
                bog = output.index(eos_id)
                # Splice out the eos token itself, keep text on both sides.
                output_list[i] = output[1:bog] + output[bog + 1:unfinished]
            except ValueError:
                output_list[i] = output[1:unfinished]

        txts = [tokenizer.DecodeIds(tokens) for tokens in output_list]

        # NOTE(review): full_path is computed but never used — every answer is
        # appended to test_eval.txt below. Kept as-is pending confirmation that
        # per-query output files were intentionally abandoned.
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + '.txt')
        else:
            prefix = raw_text.replace('/', '')[:20]
            full_path = timed_name(prefix, '.txt', args.output_path)
        print(txts[0])
        test_eval_path = os.path.join(args.output_path, 'test_eval.txt')
        with open(test_eval_path, 'a', encoding='utf-8') as fout:
            fout.write(txts[0] + '\n')
        # BUGFIX (idiom): permission masks are combined with bitwise OR;
        # '+' only worked because the rwx masks are disjoint. a+rwx overall.
        os.chmod(test_eval_path, stat.S_IRWXO | stat.S_IRWXG | stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
|
|
|
|
if __name__ == "__main__":
    # Script-specific flags would be declared on py_parser; everything it does
    # not recognize is forwarded to SwissArmyTransformer's own argument parser.
    py_parser = argparse.ArgumentParser(add_help=False)

    known, remaining = py_parser.parse_known_args()
    args = get_args(remaining)
    # Merge the two namespaces (duplicate keys would raise, by design).
    args = argparse.Namespace(**vars(args), **vars(known))
    args.do_train = False  # inference only — never train from this entry point

    with torch.no_grad():
        main(args)
|
|
|
|