#pragma once

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>

#include "utils.h"


#ifndef MAX_VOCAB_SIZE
#define MAX_VOCAB_SIZE 32000
#endif

#ifndef MAX_WORD_LEN
#define MAX_WORD_LEN 16
#endif

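/*
 * STRUCTURE, ERROR, and INFO come from utils.h. For this header to compile,
 * STRUCTURE(name, ...) presumably expands to something like
 *
 *     typedef struct name name;
 *     struct name { __VA_ARGS__ };
 *
 * (an assumption -- check utils.h for the actual definition), so that the
 * function-pointer members below can refer to the typedef name.
 */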
STRUCTURE(tokenizer_t,
    int vocab_size;
    char vocab[MAX_VOCAB_SIZE][MAX_WORD_LEN];

    /* vtable-style interface, bound to the functions below in Tokenizer() */
    int (*add_word) (tokenizer_t *, char *);

    int (*encode_word) (tokenizer_t *, char *);
    int (*encode_stream) (tokenizer_t *, char **);
    int (*encode_file) (tokenizer_t *, int);

    char *(*decode) (tokenizer_t *, int);
    char *(*decode_file) (tokenizer_t *, int);

    void (*save_vocab) (tokenizer_t *, char *);
    void (*load_vocab) (tokenizer_t *, char *);

    void (*save_tokenizer) (tokenizer_t *, char *);
    void (*load_tokenizer) (tokenizer_t *, char *);
);


/* Read helpers. They return int rather than char so that EOF (-1) stays
 * distinguishable from the byte 0xFF, and the read() result is checked with
 * == so its ssize_t return is not poisoned by unsigned promotion. */
int rdchar(int fd) { unsigned char c; return read(fd, &c, sizeof(c)) == sizeof(c) ? c : EOF; }
int rdint(int fd) { int d; return read(fd, &d, sizeof(d)) == sizeof(d) ? d : EOF; }
void seekback(int fd, int n) { lseek(fd, -n, SEEK_CUR); }

/* Local string routines, renamed from strcpy/strcmp: redefining functions
 * declared in <string.h> is undefined behavior and clashes at link time. */
char *str_copy(char *dst, const char *src) { for (; (*dst++ = *src++); ); return dst; }
int str_cmp(const char *a, const char *b) {
    for (; *a && *a == *b; ++a, ++b);
    return *a - *b;
}


/* Appends word and returns its id, or -1 when the vocab is full. Words must
 * be added in sorted order: encode_word() binary-searches the vocab. */
int tokenizer_add_word(tokenizer_t *t, char *word) {
    if (t->vocab_size >= MAX_VOCAB_SIZE) return -1;
    str_copy(t->vocab[t->vocab_size], word);
    return t->vocab_size++;
}

/* Binary search over the sorted vocab; returns the token id or -1. */
int tokenizer_encode_word(tokenizer_t *t, char *word) {
    int left = 0,
        right = t->vocab_size - 1;

    for (; left <= right; ) {
        int mid = left + (right - left) / 2,
            cmp = str_cmp(word, t->vocab[mid]);

        if (cmp == 0) return mid;
        else if (cmp < 0) right = mid - 1;
        else left = mid + 1;
    }

    return -1;
}

/* Greedy longest match: grow a candidate prefix one character at a time,
 * remember the longest prefix present in the vocab, then advance *stream
 * past only the characters actually consumed. */
int tokenizer_encode_stream(tokenizer_t *t, char **stream) {
    char word[MAX_WORD_LEN] = {0};
    int id = -1, i = 0, j = 0;

    /* MAX_WORD_LEN - 1 leaves room for the terminating NUL */
    for (; (*stream)[i] && i < MAX_WORD_LEN - 1; ++i) {
        word[i] = (*stream)[i];

        int tmp = t->encode_word(t, word);
        if (tmp != -1) id = tmp, j = i + 1;
    }

    *stream += j;
    return id;
}

/* Same greedy longest match, reading from a file descriptor; bytes read past
 * the matched token are seeked back so the next call resumes there. */
int tokenizer_encode_file(tokenizer_t *t, int fd) {
    char word[MAX_WORD_LEN] = {0};
    int c, id = -1, i = 0, j = 0;

    for (; i < MAX_WORD_LEN - 1 && (c = rdchar(fd)) != EOF; ++i) {
        word[i] = c;

        int tmp = t->encode_word(t, word);
        if (tmp != -1) id = tmp, j = i + 1;
    }

    seekback(fd, i - j);    /* i chars were read, j were consumed */
    return id;
}
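
/*
 * Example (hypothetical vocab {"a", "an", "ant"}, added in that order):
 * encoding "anthill" tries the prefixes "a", "an", "ant", "anth", ... in
 * turn, remembers the id of the longest one found ("ant"), and consumes
 * exactly those three characters, leaving "hill" for the next call.
 * Repeated calls therefore perform greedy longest-match tokenization.
 */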

/* Maps an id back to its word; out-of-range ids yield NULL. */
char *tokenizer_decode(tokenizer_t *t, int id) {
    return id >= 0 && id < t->vocab_size ? t->vocab[id] : NULL;
}

char *tokenizer_decode_file(tokenizer_t *t, int fd) {
    int id = rdint(fd);
    if (id == EOF) ERROR("read EOF from file\n");

    return t->decode(t, id);
}

/* Writes one word per line; load_vocab() splits on '\n'. */
void tokenizer_save_vocab(tokenizer_t *t, char *fname) {
    int fd = open(fname, O_WRONLY | O_CREAT | O_TRUNC, 0644);
    if (fd < 0) ERROR("failed to open \"%s\"\n", fname);

    int max_len = 0;
    for (int i = 0; i < t->vocab_size; ++i) {
        char *str = t->vocab[i];

        int len = strlen(str),
            n = write(fd, str, len);

        max_len = len > max_len ? len : max_len;

        if (n != len) ERROR("failed to write to %s, only wrote %d bytes out of %d\n", fname, n, len);

        /* the separator load_vocab() expects */
        if (write(fd, "\n", 1) != 1) ERROR("failed to write separator to %s\n", fname);
    }

    printf("wrote %d tokens to file \"%s\"\nMax token length was %d\n", t->vocab_size, fname, max_len);
    close(fd);
}

void tokenizer_load_vocab(tokenizer_t *t, char *fname) {
    int fd = open(fname, O_RDONLY);
    if (fd < 0) ERROR("failed to open \"%s\"\n", fname);

    int c;
    char word[MAX_WORD_LEN];
    for (; (c = rdchar(fd)) != EOF; ) {
        /* collect one '\n'-terminated word, leaving room for the NUL */
        int i = 0;
        for (; i < MAX_WORD_LEN - 1 && c != EOF && c != '\n'; c = rdchar(fd))
            word[i++] = c;
        word[i] = '\0';

        t->add_word(t, word);
    }
}


/* Raw dump of the whole fixed-size vocab array (unused slots are zeroed). */
void tokenizer_save_tokenizer(tokenizer_t *t, char *fname) {
    int fd = open(fname, O_WRONLY | O_CREAT | O_TRUNC, 0644);
    if (fd < 0) ERROR("failed to open \"%s\"\n", fname);

    int n = write(fd, t->vocab, sizeof(t->vocab));
    if (n != sizeof(t->vocab)) ERROR("failed to write to %s, only wrote %d bytes out of %zu\n", fname, n, sizeof(t->vocab));

    printf("wrote %d bytes (%d tokens) to \"%s\"\n", n, t->vocab_size, fname);
    close(fd);
}

void tokenizer_load_tokenizer(tokenizer_t *t, char *fname) {
    int fd = open(fname, O_RDONLY);
    if (fd < 0) ERROR("failed to open \"%s\"\n", fname);

    int n = read(fd, t->vocab, sizeof(t->vocab));
    if (n != sizeof(t->vocab)) ERROR("failed to read from %s, only read %d bytes out of %zu\n", fname, n, sizeof(t->vocab));

    /* recover vocab_size by counting populated slots; n / MAX_WORD_LEN would
     * always give MAX_VOCAB_SIZE, since the dump is fixed-size */
    t->vocab_size = 0;
    for (int i = 0; i < MAX_VOCAB_SIZE && t->vocab[i][0]; ++i)
        ++t->vocab_size;

    printf("read %d bytes (%d tokens) from \"%s\"\n", n, t->vocab_size, fname);
    close(fd);
}


/* Single global instance: Tokenizer() re-initializes and returns it, so at
 * most one tokenizer is live at a time. */
tokenizer_t _tokenizer;
tokenizer_t *Tokenizer(char *fname) {
    _tokenizer = (tokenizer_t) {
        .vocab_size = 0,

        .add_word = tokenizer_add_word,

        .encode_word = tokenizer_encode_word,
        .encode_stream = tokenizer_encode_stream,
        .encode_file = tokenizer_encode_file,

        .decode = tokenizer_decode,
        .decode_file = tokenizer_decode_file,

        .save_vocab = tokenizer_save_vocab,
        .load_vocab = tokenizer_load_vocab,

        .save_tokenizer = tokenizer_save_tokenizer,
        .load_tokenizer = tokenizer_load_tokenizer
    };

    if (fname) _tokenizer.load_tokenizer(&_tokenizer, fname);

    INFO("vocabulary size: %d (%d max)\n", _tokenizer.vocab_size, MAX_VOCAB_SIZE);
    INFO("max token length: %d\n", MAX_WORD_LEN);
    INFO("size of structure: %zu bytes\n", sizeof(tokenizer_t));

    return &_tokenizer;
}
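
/* Minimal usage sketch (illustrative, not part of the library): it assumes
 * the utils.h macros behave as used above. Build with -DTOKENIZER_DEMO.
 * Words are added in sorted order, as encode_word's binary search requires. */
#ifdef TOKENIZER_DEMO
int main(void) {
    tokenizer_t *t = Tokenizer(NULL);

    /* sorted insertion keeps the binary search valid */
    t->add_word(t, "a");
    t->add_word(t, "an");
    t->add_word(t, "ant");

    char *stream = "anthill";
    int id = t->encode_stream(t, &stream);    /* greedy match: "ant" */
    printf("id %d -> \"%s\", rest \"%s\"\n", id, t->decode(t, id), stream);

    t->save_vocab(t, "vocab.txt");
    return 0;
}
#endif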