import os
import struct
import sys
| |
|
| | |
# Hard upper bound on the number of tokens the vocabulary can hold.
MAX_VOCAB_SIZE = 32000
# Fixed on-disk slot width per token; stored words are truncated to
# MAX_WORD_LEN - 1 characters so a NUL terminator always fits.
MAX_WORD_LEN = 16
| |
|
def ERROR(message, *args):
    """Print a printf-style error message to stderr and exit with status 1.

    Args:
        message: printf-style format string.
        *args: values interpolated into *message*.  When omitted, the
            message is written verbatim, so a literal '%' in the text
            no longer raises ValueError.
    """
    sys.stderr.write(message % args if args else message)
    sys.exit(1)
| |
|
def INFO(message, *args):
    """Print a printf-style informational message to stdout.

    Args:
        message: printf-style format string.
        *args: values interpolated into *message*.  When omitted, the
            message is printed verbatim, so a literal '%' in the text
            no longer raises ValueError.
    """
    print(message % args if args else message)
| |
|
class Tokenizer:
    """Fixed-capacity word tokenizer backed by a flat vocabulary table.

    Holds up to MAX_VOCAB_SIZE words, each truncated to at most
    MAX_WORD_LEN - 1 characters.  encode_word() performs a binary
    search, so lookups are only correct when vocab[:vocab_size] is
    sorted; add_word() and load_vocab() do NOT sort for you — the
    on-disk tokenizer file is assumed to already be in sorted order.
    """

    def __init__(self, fname=None):
        """Create an empty tokenizer, optionally loading a binary vocab file.

        Args:
            fname: optional path to a binary tokenizer file in the
                save_tokenizer() format.
        """
        self.vocab_size = 0
        self.vocab = [''] * MAX_VOCAB_SIZE

        if fname:
            self.load_tokenizer(fname)

        INFO("vocabulary size: %d (%d max)", self.vocab_size, MAX_VOCAB_SIZE)
        INFO("max token length: %d", MAX_WORD_LEN)

        structure_size = self.vocab_size * MAX_WORD_LEN
        INFO("size of structure: %d bytes", structure_size)

    def add_word(self, word):
        """Append a word to the vocabulary, truncating to MAX_WORD_LEN - 1 chars.

        Note: appending does not keep the table sorted, which
        encode_word() requires — callers must add words in order.

        Returns:
            int: the new word's ID, or -1 if the vocabulary is full.
        """
        if self.vocab_size >= MAX_VOCAB_SIZE:
            return -1

        if len(word) >= MAX_WORD_LEN:
            word = word[:MAX_WORD_LEN - 1]
        self.vocab[self.vocab_size] = word
        self.vocab_size += 1
        return self.vocab_size - 1

    def encode_word(self, word):
        """Return the ID of *word* via binary search, or -1 if absent.

        Requires vocab[:vocab_size] to be sorted ascending.
        """
        left = 0
        right = self.vocab_size - 1

        while left <= right:
            mid = left + (right - left) // 2
            cmp = self._compare(word, self.vocab[mid])

            if cmp == 0:
                return mid
            elif cmp < 0:
                right = mid - 1
            else:
                left = mid + 1

        return -1

    def encode_stream(self, stream):
        """Greedily encode the longest vocabulary prefix of *stream*.

        The matched characters are consumed (deleted) from the front of
        *stream*; when nothing matches, the stream is left untouched.

        Args:
            stream (list of str): characters still to be encoded.

        Returns:
            int: ID of the longest matching prefix, or -1 if none matches.
        """
        word = ''
        token_id = -1
        consumed = 0  # characters belonging to the longest match so far

        for i in range(min(MAX_WORD_LEN, len(stream))):
            word += stream[i]
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                consumed = i + 1

        del stream[:consumed]

        return token_id

    def encode_file(self, fd):
        """Greedily encode the longest vocabulary prefix read from *fd*.

        Reads up to MAX_WORD_LEN bytes one at a time, then seeks the
        file back so only the bytes of the longest match stay consumed.

        Fix over the previous version: the rewind is now computed from
        the number of bytes actually read (EOF may cut the read short)
        and from the byte offset of the match, not from the decoded
        character count — the two differ when errors='ignore' drops
        undecodable bytes.

        Args:
            fd: seekable binary file object opened for reading.

        Returns:
            int: ID of the matched prefix, or -1 (with fd restored to
            its starting position) if nothing matches.
        """
        word = ''
        token_id = -1
        bytes_read = 0     # total bytes pulled from fd this call
        bytes_matched = 0  # bytes belonging to the longest match so far

        for _ in range(MAX_WORD_LEN):
            c = fd.read(1)
            if not c:
                break
            bytes_read += 1
            # Undecodable bytes are dropped from the candidate word but
            # still count against the file position.
            word += c.decode('utf-8', errors='ignore')
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                bytes_matched = bytes_read

        # Rewind exactly the unmatched tail that was actually read.
        to_seek = bytes_read - bytes_matched
        if to_seek > 0:
            fd.seek(-to_seek, os.SEEK_CUR)

        return token_id

    def decode(self, id):
        """Decode an ID back into its corresponding word.

        The parameter keeps the name ``id`` (shadowing the builtin) so
        existing keyword callers are not broken.

        Returns:
            str or None: the word, or None when the ID is out of range.
        """
        if 0 <= id < self.vocab_size:
            return self.vocab[id]
        return None

    def decode_file(self, fd):
        """Read one 4-byte ID from *fd* and decode it into a word.

        NOTE(review): 'i' unpacks in native byte order — assumes the file
        was written on a machine with the same endianness; confirm before
        sharing files across platforms.

        Args:
            fd: binary file object to read from.

        Returns:
            str or None: the decoded word; exits the process on EOF.
        """
        data = fd.read(4)
        if len(data) < 4:
            ERROR("read EOF from file\n")

        token_id = struct.unpack('i', data)[0]
        return self.decode(token_id)

    def save_vocab(self, fname):
        """Save the vocabulary to a text file, one word per line.

        Exits the process via ERROR() on I/O failure.
        """
        try:
            with open(fname, 'w', encoding='utf-8') as f:
                max_len = 0
                for i in range(self.vocab_size):
                    word = self.vocab[i]
                    f.write(word + '\n')
                    if len(word) > max_len:
                        max_len = len(word)
            INFO("wrote %d tokens to file \"%s\"\nMax token length was %d",
                 self.vocab_size, fname, max_len)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_vocab(self, fname):
        """Load the vocabulary from a text file, one word per line.

        Blank lines are skipped.  Words are added in file order; the
        file must already be sorted for encode_word() to work.  Exits
        the process via ERROR() on I/O failure.
        """
        try:
            with open(fname, 'r', encoding='utf-8') as f:
                for line in f:
                    word = line.strip()
                    if word:
                        self.add_word(word)
        except IOError as e:
            ERROR("failed to open \"%s\": %s\n", fname, str(e))

    def save_tokenizer(self, fname):
        """Save the vocabulary to a binary file of fixed-width slots.

        The file is always MAX_VOCAB_SIZE * MAX_WORD_LEN bytes: each
        slot is the UTF-8 word padded with NULs, unused slots are all
        NULs.  NOTE(review): byte-level truncation can split a
        multi-byte UTF-8 sequence; load_tokenizer() tolerates this via
        errors='ignore', but the last character may be lost.
        """
        try:
            with open(fname, 'wb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    if i < self.vocab_size:
                        word = self.vocab[i].encode('utf-8')
                        if len(word) >= MAX_WORD_LEN:
                            word = word[:MAX_WORD_LEN - 1]
                        word += b'\0' * (MAX_WORD_LEN - len(word))
                    else:
                        word = b'\0' * MAX_WORD_LEN
                    f.write(word)
            INFO("wrote %d bytes (%d tokens) to \"%s\"",
                 MAX_VOCAB_SIZE * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_tokenizer(self, fname):
        """Load the vocabulary from a binary file written by save_tokenizer().

        Reads fixed-width MAX_WORD_LEN slots; NUL-padded tails are
        stripped and all-NUL (empty) slots are skipped.  Exits the
        process via ERROR() on I/O failure.
        """
        try:
            with open(fname, 'rb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    bytes_word = f.read(MAX_WORD_LEN)
                    if not bytes_word or len(bytes_word) < MAX_WORD_LEN:
                        break

                    word = bytes_word.split(b'\0', 1)[0].decode('utf-8', errors='ignore')
                    if word:
                        self.vocab[i] = word
                        self.vocab_size += 1
            INFO("read %d bytes (%d tokens) from \"%s\"",
                 self.vocab_size * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to read from \"%s\": %s\n", fname, str(e))

    @staticmethod
    def _compare(a, b):
        """Compare two strings like C's strcmp: -1, 0, or 1."""
        return (a > b) - (a < b)
| |
|