import re
import torch


vocabulary = {}
token_vocabulary = {}


# Build the token -> id mapping from the vocabulary file: one token per line,
# with the 0-based line number as its id. Tokens may be wrapped in single or
# double quotes in the file.
with open('cl100k_base_vocab_list.txt', 'r', encoding='utf-8') as file:
    for line_count, line in enumerate(file):
        line = line.rstrip('\n')
        # Strip surrounding quotes, if present, before storing the token.
        if (line.startswith("'") and line.endswith("'")) or (line.startswith('"') and line.endswith('"')):
            line = line[1:-1]
        vocabulary[line] = line_count

# Reverse mapping (id -> token) used for decoding.
token_vocabulary = {v: k for k, v in vocabulary.items()}
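

# Illustration only (hypothetical file contents, not the real cl100k list):
# if the first lines of 'cl100k_base_vocab_list.txt' were
#
#   '<UNK>'
#   '<EOS>'
#   'hello'
#
# the loop above would produce vocabulary == {'<UNK>': 0, '<EOS>': 1,
# 'hello': 2} and token_vocabulary == {0: '<UNK>', 1: '<EOS>', 2: 'hello'}.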


def get_vocabulary():
    return vocabulary


def get_token_vocabulary():
    return token_vocabulary


def tokenize_sequence(sentence):
    """Encode a sentence into token ids via greedy longest-match lookup."""
    tokenized_seq = []
    # Keep leading whitespace attached to words (GPT-style tokens such as
    # ' world'), and capture standalone punctuation/symbol runs.
    regex = r'(\s+\w+|\S+)'
    # re.split with a capturing group returns the matches themselves; drop
    # the empty strings that appear between adjacent matches.
    words = [w for w in re.split(regex, sentence) if w]
    for word in words:
        if word in vocabulary:
            tokenized_seq.append(vocabulary[word])
        else:
            # Greedy longest-prefix matching over the word's characters.
            i = 0
            while i < len(word):
                subword_len = 1
                for j in range(len(word), i, -1):
                    subword = word[i:j]
                    if subword in vocabulary:
                        tokenized_seq.append(vocabulary[subword])
                        subword_len = len(subword)
                        break
                    if j - i == 1:
                        # Not even the single character is in the vocabulary.
                        tokenized_seq.append(vocabulary.get('<UNK>'))
                        break
                i += subword_len
    tokenized_seq.append(vocabulary.get('<EOS>'))
    return tokenized_seq
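

# Usage sketch (hypothetical vocabulary contents -- actual ids depend on the
# file): if 'hello' is absent but 'hel' and 'lo' are present, the greedy
# longest-match loop above emits them in order:
#
#   tokenize_sequence('hello')
#   # -> [vocabulary['hel'], vocabulary['lo'], vocabulary.get('<EOS>')]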


def detokenize_sequence(tokenized_seq):
    """Decode a sequence of token ids back into a string."""
    return ''.join(token_vocabulary[token] for token in tokenized_seq)


def pad_to_length(seq, length):
    """Right-pad a token id sequence with zeros into a fixed-length tensor."""
    # NOTE: assumes id 0 is an acceptable padding value, as in the original.
    padded_seq = torch.full((length,), fill_value=0, dtype=torch.long)
    # Truncate so sequences longer than `length` don't raise on assignment.
    seq = seq[:length]
    padded_seq[:len(seq)] = torch.tensor(seq, dtype=torch.long)
    return padded_seq
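

# Minimal smoke-test sketch, not part of the module's API. It assumes the
# vocabulary file is present and contains '<UNK>' and '<EOS>' entries; the
# sample sentence and pad length are arbitrary.
if __name__ == '__main__':
    ids = tokenize_sequence('hello world')
    print('token ids:', ids)
    print('decoded  :', detokenize_sequence(ids))
    print('padded   :', pad_to_length(ids, 16))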