import json
import os
from collections import Counter

import jieba

# Load domain-specific medical terms so jieba keeps them as single tokens.
abs_path = os.path.dirname(os.path.abspath(__file__))
jieba.load_userdict(os.path.join(abs_path, '../../utils/key_technical_words.txt'))
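# key_technical_words.txt follows jieba's user-dictionary format: one term per
# line, optionally followed by a frequency and a POS tag, e.g.
#   钙化灶 10 n
# (the entry above is illustrative, not taken from the actual file).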


class Tokenizer(object):
    """Word-level tokenizer for Chinese imaging-report findings.

    Token ids start at 1; id 0 is reserved as the shared <start>/<end> (and
    padding) id, and out-of-vocabulary tokens map to <unk>.
    """

    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        with open(self.ann_path, 'r', encoding='utf-8-sig') as f:
            self.ann = json.load(f)
        self.dict_pth = args.dict_pth
        self.token2idx, self.idx2token = self.create_vocabulary()
    def create_vocabulary(self):
        # A blank dict_pth (the sentinel ' ' used by the original args) means
        # the vocabulary is built from the annotations; otherwise it is loaded
        # from a serialized [token2idx, idx2token] pair.
        if self.dict_pth.strip():
            with open(self.dict_pth, 'r', encoding='utf-8-sig') as f:
                word_dict = json.load(f)
            # JSON turns the integer keys of idx2token into strings; cast back.
            word_dict[1] = {int(k): v for k, v in word_dict[1].items()}
            return word_dict[0], word_dict[1]
        else:
            total_tokens = []
            for split in ['train', 'test', 'val']:
                for example in self.ann[split]:
                    total_tokens.extend(jieba.lcut(example['finding']))
            counter = Counter(total_tokens)
            # Keep tokens that reach the frequency threshold; everything else
            # falls back to <unk> at encoding time.
            vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
            token2idx, idx2token = {}, {}
            for idx, token in enumerate(vocab):
                # Ids start at 1; id 0 is reserved for <start>/<end>/padding.
                token2idx[token] = idx + 1
                idx2token[idx + 1] = token
            dict_out = os.path.join(abs_path, '../../utils/breast_dict.txt')
            with open(dict_out, 'w', encoding='utf-8-sig') as f:
                f.write(json.dumps([token2idx, idx2token]))
            return token2idx, idx2token
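    # A minimal sketch of the on-disk format written above (illustrative
    # tokens and ids, not taken from the real breast_dict.txt):
    #   [{"乳腺": 1, "结节": 2, "<unk>": 3},
    #    {"1": "乳腺", "2": "结节", "3": "<unk>"}]
    # The idx2token keys come back as JSON strings, which is why the loading
    # branch above casts them to int.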

    def get_token_by_id(self, idx):
        return self.idx2token[idx]

    def get_id_by_token(self, token):
        # Unknown tokens map to the <unk> id.
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        # Number of real tokens (ids 1..N); the reserved id 0 is not counted.
        return len(self.token2idx)

    def __call__(self, report):
        # Segment the report and wrap it with the shared <start>/<end> id 0.
        tokens = jieba.lcut(report)
        ids = [self.get_id_by_token(token) for token in tokens]
        ids = [0] + ids + [0]
        return ids
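    # Illustrative round trip (hypothetical ids; the real ones depend on the
    # corpus): suppose jieba segments '乳腺结节' into ['乳腺', '结节']; then
    # tokenizer('乳腺结节') -> [0, 1, 2, 0], with 0 marking <start>/<end>.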

    def decode(self, ids):
        # Rebuild a space-joined string, stopping at the first 0
        # (<start>/<end>/padding), so callers should pass ids without the
        # leading start id.
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt
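    # For example (again with hypothetical ids), decode([1, 2, 0, 1]) would
    # return '乳腺 结节': decoding stops at the first 0, and a leading 0 in
    # the input yields an empty string.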

    def decode_list(self, ids):
        # Like decode(), but returns the token list and keeps going past 0s,
        # inserting a '<start/end>' placeholder for each reserved id.
        txt = []
        for idx in ids:
            if idx > 0:
                txt.append(self.idx2token[idx])
            else:
                txt.append('<start/end>')
        return txt

    def decode_batch(self, ids_batch):
        # Decode a batch of id sequences into strings.
        return [self.decode(ids) for ids in ids_batch]

    def decode_batch_list(self, ids_batch):
        # Decode a batch of id sequences into token lists.
        return [self.decode_list(ids) for ids in ids_batch]
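

# A minimal usage sketch, assuming the args object carries ann_path,
# threshold, and dict_pth as in the training script; the annotation path and
# threshold value below are placeholders, not taken from the original code.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        ann_path=os.path.join(abs_path, '../../data/annotation.json'),  # hypothetical path
        threshold=3,    # drop tokens seen fewer than 3 times
        dict_pth=' ',   # blank sentinel: build the vocabulary from scratch
    )
    tokenizer = Tokenizer(args)
    ids = tokenizer('双侧乳腺呈致密型')
    print(ids)
    print(tokenizer.decode(ids[1:]))  # skip the leading start id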