import gzip
import random

import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader, Dataset

from Andromeda.core.transformer import Decoder, AndromedaEmbedding, Transformer
from Andromeda.core.autoregressive_wrapper import AutoregressiveWrapper
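
# optional (not part of the original script): seed the RNGs for reproducible runs
# torch.manual_seed(42); np.random.seed(42); random.seed(42)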

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 1024

# helpers

def cycle(loader):
    # loop over a DataLoader forever
    while True:
        for data in loader:
            yield data


def decode_token(token):
    # map a byte-level token back to a printable character (control bytes render as space)
    return str(chr(max(32, token)))


def decode_tokens(tokens):
    # decode a sequence of byte-level tokens into a string
    return ''.join(map(decode_token, tokens))
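
# instantiate the decoder-only Transformer: 2560-dim, 32 layers, 24 heads,
# with ALiBi position bias on half the heads, xpos rotary embeddings,
# flash attention, and a single shared key/value head (multi-query attention)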

model = Transformer(
    num_tokens=50432,
    max_seq_len=8192,
    use_abs_pos_emb=False,
    embedding_provider=AndromedaEmbedding(),
    attn_layers=Decoder(
        dim=2560,
        depth=32,
        dim_head=128,
        heads=24,
        alibi_pos_bias=True,
        alibi_num_heads=12,
        rotary_xpos=True,
        attn_flash=True,
        attn_one_kv_head=True,
        qk_norm=True,
        attn_qk_norm=True,
        attn_qk_norm_dim_scale=True
    )
)

# wrap the model so that forward() computes the next-token language modeling loss
model = AutoregressiveWrapper(model)
model.cuda()
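
# prepare enwik8 data (expects ./data/enwik8.gz: the gzipped first 100MB of an
# English Wikipedia dump; the first 90MB is used for training, the next 5MB for validation)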

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    train_x, valid_x = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random contiguous chunk; the extra token provides the shifted targets
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, drop_last=True))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE, drop_last=True))
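
# cycle() turns each DataLoader into an endless stream, so the training loop
# below can simply call next() without tracking epochs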

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
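
# training loop: one optimizer step per iteration, with periodic validation,
# checkpointing, and sampling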

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over GRADIENT_ACCUMULATE_EVERY micro-batches
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

        # checkpoint; note this saves the AutoregressiveWrapper's state_dict,
        # not the bare Transformer's
        torch.save(model.state_dict(), f"./model_{i}.pth")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print('%s \n\n %s' % (prime, '*' * 100))

        # sample GENERATE_LENGTH new tokens conditioned on the prime
        # (the wrapper is assumed to handle the missing batch dimension of a 1-D prompt)
        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)