"""Train an Andromeda Transformer on enwik8 (byte-level language modeling)."""

import gzip
import random

import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader, Dataset

from Andromeda.core.transformer import Decoder, AndromedaEmbedding, Transformer
from Andromeda.core.autoregressive_wrapper import AutoregressiveWrapper

# Training hyperparameters
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 1024


def cycle(loader):
    """Yield batches from the loader forever."""
    while True:
        for data in loader:
            yield data


def decode_token(token):
    """Map a byte value back to a character; control bytes (< 32) render as spaces."""
    return str(chr(max(32, token)))


def decode_tokens(tokens):
    """Decode a sequence of byte values into a string."""
    return ''.join(map(decode_token, tokens))
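# Example: decode_tokens([72, 105, 33]) -> 'Hi!' (elements of a 1-D byte tensor decode the same way).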
# Model: a decoder-only Transformer built from Andromeda's core modules.
model = Transformer(
    num_tokens=50432,
    max_seq_len=8192,
    use_abs_pos_emb=False,
    embedding_provider=AndromedaEmbedding(),
    attn_layers=Decoder(
        dim=2560,
        depth=32,
        dim_head=128,
        heads=24,
        alibi_pos_bias=True,
        alibi_num_heads=12,
        rotary_xpos=True,
        attn_flash=True,
        attn_one_kv_head=True,  # multi-query attention: one shared key/value head
        qk_norm=True,
        attn_qk_norm=True,
        attn_qk_norm_dim_scale=True
    )
)
# Wrap the model for next-token-prediction loss and autoregressive sampling.
model = AutoregressiveWrapper(model)
model.cuda()  # assumes a CUDA-capable GPU is available
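# Optional sanity check: report the total parameter count before training starts.
n_params = sum(p.numel() for p in model.parameters())
print(f'model parameters: {n_params / 1e9:.2f}B')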
# Data: enwik8. Read 95M bytes; the first 90M are used for training, the remaining 5M for validation.
with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    train_x, valid_x = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
    """Samples random (seq_len + 1)-byte windows; the extra byte supplies the shifted target."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # index is ignored: every access draws a fresh random window.
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        # Moving to GPU here is fine with the default single-process DataLoader.
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, drop_last=True))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE, drop_last=True))
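# Each next(train_loader) call yields a LongTensor of shape (BATCH_SIZE, SEQ_LEN + 1):
# SEQ_LEN input bytes plus one extra byte for the autoregressive wrapper's shifted target.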
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

        # Save a checkpoint alongside each validation pass.
        torch.save(model.state_dict(), f"./model_{i}.pth")
    if i % GENERATE_EVERY == 0:
        model.eval()
        # Prime generation with a random validation sample (minus its final target byte).
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
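# Usage sketch (paths and script name are illustrative): place enwik8.gz under ./data/
# and run this file on a CUDA machine, e.g. `python train.py`. A checkpoint written above
# can later be restored with model.load_state_dict(torch.load('./model_900.pth')), where
# the numeric suffix is whichever step produced it.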