| |
|
|
class CharTokenizer:
    """Character-level tokenizer mapping each distinct character to an integer id.

    Ids are assigned by :meth:`fit` in sorted character order, starting at 0.
    Because id 0 belongs to a real character (the smallest one seen), the
    default padding value of :meth:`encode` is indistinguishable from that
    character; pass a ``pad_id`` outside the vocabulary (e.g. ``-1``) when
    padding must be dropped again by :meth:`decode`.
    """

    def __init__(self):
        # Empty set before fit(); a sorted list of characters afterwards.
        self.chars = set()
        self.char2idx = {}
        self.idx2char = {}

    def fit(self, texts):
        """Build the vocabulary from an iterable of strings.

        May be called more than once: characters from earlier calls are kept
        and ids are reassigned over the combined, sorted character set.

        Args:
            texts: Iterable of strings whose characters form the vocabulary.
        """
        # Accumulate into a fresh set: self.chars is a list after a previous
        # fit(), and list has no update(), which used to crash a second call.
        vocab = set(self.chars)
        for text in texts:
            vocab.update(text)
        self.chars = sorted(vocab)
        self.char2idx = {char: idx for idx, char in enumerate(self.chars)}
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}

    def encode(self, text, max_length=None, pad_id=0):
        """Convert *text* to a list of ids.

        Characters not in the vocabulary are silently dropped.

        Args:
            text: String to encode.
            max_length: If not None, truncate or right-pad the result to
                exactly this many ids. ``0`` yields an empty list.
            pad_id: Id used for padding; defaults to 0 for backward
                compatibility (note 0 is also a real character's id).

        Returns:
            List of integer ids.

        Raises:
            ValueError: If *max_length* is negative.
        """
        encoded = [self.char2idx[char] for char in text if char in self.char2idx]
        # Compare against None so that max_length=0 is honoured rather than
        # treated as "no limit".
        if max_length is not None:
            if max_length < 0:
                raise ValueError(f"max_length must be non-negative, got {max_length}")
            encoded = encoded[:max_length] + [pad_id] * (max_length - len(encoded))
        return encoded

    def decode(self, tokens):
        """Convert ids back to a string, skipping ids not in the vocabulary.

        Args:
            tokens: Iterable of integer ids.

        Returns:
            The decoded string.
        """
        return ''.join([self.idx2char[token] for token in tokens if token in self.idx2char])
|
|