| import torch |
| import pickle |
|
|
| |
# Boundary marker placed before and after every word, so the bigram table also
# records which characters begin and end a word.
TOKEN = "."
|
|
| |
# Load the training corpus: one entry per line, line terminators stripped.
# The context manager closes the handle as soon as the text is read instead of
# leaving it to the garbage collector.
# NOTE(review): the encoding is pinned so the vocabulary does not depend on the
# platform's locale default; this assumes the file is UTF-8 (or plain ASCII) —
# confirm against the actual data file.
with open('data/names.txt', 'r', encoding='utf-8') as corpus:
    words = corpus.read().splitlines()
|
|
| |
# Vocabulary: every distinct character that occurs in the corpus, plus the
# boundary token, in sorted order so that index assignment is deterministic.
vocab = sorted({symbol for entry in words for symbol in entry} | {TOKEN})
|
|
| |
# Square table of bigram counts, one row and one column per vocabulary entry.
# N[i, j] accumulates how often character j directly follows character i.
n = len(vocab)
N = torch.zeros(n, n, dtype=torch.int32)
|
|
| |
# Two-way lookup between characters and their position in the sorted vocabulary.
# char_to_int maps a character to its index; int_to_char is the exact inverse.
char_to_int = dict(zip(vocab, range(len(vocab))))
int_to_char = dict(zip(char_to_int.values(), char_to_int.keys()))
|
|
| |
# Count every adjacent character pair in every word, with the boundary token
# padded on both sides so word-initial and word-final characters are counted.
# NOTE: an empty entry (blank line in the corpus) contributes one
# TOKEN -> TOKEN bigram.
for word in words:
    # Translate the padded word to table indices once, then walk the pairs.
    indices = [char_to_int[symbol] for symbol in (TOKEN, *word, TOKEN)]
    for row, col in zip(indices, indices[1:]):
        N[row, col] += 1
|
|
| |
# Convert counts to conditional probabilities: each row is divided by its own
# total, so row i holds the distribution of the character that follows i.
# keepdim=True keeps the row sums as a column so the division broadcasts
# across each row rather than across columns.
# NOTE: a row whose counts are all zero would divide 0 by 0 and yield NaN.
P = N.float()
P.div_(P.sum(dim=1, keepdim=True))
|
|
| |
# Persist the model as one pickled list: [P, char_to_int, int_to_char].
# open() does not create directories, so 'model/' must already exist.
artifact = [P, char_to_int, int_to_char]
with open('model/bigrams.pkl', 'wb') as sink:
    pickle.dump(artifact, sink)