| | import matplotlib.pyplot as plt |
| | import time |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from torchvision import datasets, transforms |
| | import numpy as np |
| | import tracemalloc |
| |
|
| | |
| | from Andromeda.model import Andromeda |
| | from Andromeda.utils.stable_adamw import StableAdamWUnfused |
| |
|
# Fix the CPU generator seed so the random benchmark inputs are reproducible.
torch.manual_seed(0)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The CUDA generator is only seeded when CUDA is actually in use.
if device.type == "cuda":
    torch.cuda.manual_seed(0)
| |
|
| |
|
class AndromedaModelTest:
    """Smoke tests for a token-level model: forward, backward and optimizer step.

    Args:
        model: Module under test. When omitted, ``Andromeda()`` is instantiated.
        optimizer: Optimizer bound to the model's parameters. When omitted,
            ``StableAdamWUnfused`` is built over ``model.parameters()``.
        num_tokens: Expected size of the last (vocabulary) output dimension.
    """

    def __init__(self, model=None, optimizer=None, num_tokens=64007):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model is None:
            # The model must be an instance; storing the bare class would make
            # ``self.model(x)`` invoke the constructor instead of a forward pass.
            # NOTE(review): assumes Andromeda() is constructible with its default
            # arguments and produces `num_tokens` logits -- confirm.
            model = Andromeda()
        self.model = model.to(self.device)

        if optimizer is None:
            # NOTE(review): assumes StableAdamWUnfused follows the torch.optim
            # convention of taking the parameter iterable first -- confirm.
            optimizer = StableAdamWUnfused(self.model.parameters())
        self.optimizer = optimizer

        self.num_tokens = num_tokens
        self.loss_function = torch.nn.CrossEntropyLoss()
        # Created directly on the selected device so CPU-only hosts do not crash.
        self.test_input = torch.randint(0, 256, (1, 1024), device=self.device)

    def _loss(self, output):
        """Return the cross-entropy of ``output`` against the input tokens."""
        # CrossEntropyLoss expects the class dimension second, so collapse the
        # logits to (positions, classes) and the targets to (positions,).
        logits = output.reshape(-1, output.size(-1))
        targets = self.test_input.reshape(-1)
        return self.loss_function(logits, targets)

    def test_forward_pass(self):
        """Assert the output has shape ``(*input.shape, num_tokens)``."""
        output = self.model(self.test_input)
        expected_shape = (*self.test_input.shape, self.num_tokens)
        assert output.shape == expected_shape, "Forward pass output shape mismatch"

    def test_backward_pass(self):
        """Assert that no populated gradient contains NaN or Inf after backward."""
        self.optimizer.zero_grad()
        loss = self._loss(self.model(self.test_input))
        loss.backward()

        for name, parameter in self.model.named_parameters():
            grad = parameter.grad  # attribute, not a method
            if grad is None:
                # Parameter did not take part in this forward pass.
                continue
            assert not torch.isnan(grad).any(), f"Gradient for {name} contains NaNs"
            assert not torch.isinf(grad).any(), f"Gradient for {name} contains Infs"

    def test_optimizer_step(self):
        """Assert every parameter that received a gradient changes after one step."""
        initial_params = [param.detach().clone() for param in self.model.parameters()]

        self.optimizer.zero_grad()
        loss = self._loss(self.model(self.test_input))
        loss.backward()
        self.optimizer.step()

        for initial_param, param in zip(initial_params, self.model.parameters()):
            if param.grad is None:
                # Without a gradient the optimizer has nothing to apply.
                continue
            assert not torch.equal(initial_param, param.detach()), "Model Parameters did not change after an optimizer step"
| |
|
| |
|
| |
|
| |
|
| |
|
class SpeedMetrics:
    """Wall-clock timings for forward, backward and end-to-end passes.

    Args:
        model: Module exposing ``decoder.forward`` and ``forward``; it is moved
            to CUDA when available, otherwise to the CPU.
        num_tokens: Exclusive upper bound of the random token ids.
        seq_len: Length of the random token sequence (batch size is 1).
    """

    def __init__(self, model, num_tokens=50304, seq_len=8192):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.num_tokens = num_tokens
        self.seq_len = seq_len

    def _random_tokens(self):
        """Return a ``(1, seq_len)`` tensor of random token ids on the device."""
        return torch.randint(
            0, self.num_tokens, (1, self.seq_len), device=self.device, dtype=torch.long
        )

    def _sync(self):
        """Wait for queued CUDA work; a no-op on CPU."""
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _timed(self, fn):
        """Return the seconds ``fn()`` takes, bracketed by device synchronisation."""
        # CUDA kernels are launched asynchronously; without synchronising the
        # timer would only capture the launch overhead.
        self._sync()
        start = time.perf_counter()
        fn()
        self._sync()
        return time.perf_counter() - start

    def forward_pass_time(self):
        """Return the seconds taken by one ``decoder.forward`` call."""
        tokens = self._random_tokens()
        return self._timed(lambda: self.model.decoder.forward(tokens)[0])

    def backward_pass_time(self):
        """Return the seconds taken by loss computation plus ``backward``.

        The forward pass that produces the logits is not part of the timing.
        """
        # NOTE(review): assumes element 0 of the decoder output is a logits
        # tensor whose last dimension is the vocabulary -- confirm.
        logits = self.model.decoder.forward(self._random_tokens())[0]
        targets = self._random_tokens().reshape(-1)

        def loss_and_backward():
            # Flatten to (positions, classes) vs (positions,) as
            # CrossEntropyLoss requires.
            loss = torch.nn.CrossEntropyLoss()(
                logits.reshape(-1, logits.size(-1)), targets
            )
            loss.backward()

        return self._timed(loss_and_backward)

    def end_to_end_latency(self):
        """Return the seconds taken by one full ``model.forward`` call."""
        tokens = self._random_tokens()
        return self._timed(lambda: self.model.forward(tokens))
| | |
| |
|
| |
|
class ScalabilityMetrics:
    """Measure how many dataset instances per second the model processes.

    Args:
        model: Object exposing ``forward(batch)``.
        dataset: Sized dataset handed to a ``DataLoader``.
        batch_size: Number of instances per batch.
    """

    def __init__(self, model, dataset, batch_size=32):
        self.model = model
        self.dataset = dataset
        self.dataloader = DataLoader(dataset, batch_size=batch_size)

    def throughput(self):
        """Return instances per second over one full pass of the dataset.

        Returns ``0.0`` for an empty dataset and ``inf`` if the clock reports
        no elapsed time, instead of raising ``ZeroDivisionError``.
        """
        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        for batch in self.dataloader:
            # NOTE(review): the batch is passed through unchanged; for datasets
            # yielding (sample, label) pairs it is a list of tensors -- confirm
            # that the model accepts that.
            self.model.forward(batch)
        elapsed = time.perf_counter() - start

        instance_count = len(self.dataset)
        if elapsed <= 0:
            return 0.0 if instance_count == 0 else float("inf")
        return instance_count / elapsed
| |
|
| |
|
class ConsistencyMetrics:
    """Check that repeated forward passes on one input give the same output."""

    def __init__(self, model):
        self.model = model

    def _device(self):
        """Return the device of the model's first parameter, or CPU if it has none."""
        parameters = getattr(self.model, "parameters", None)
        first = next(iter(parameters()), None) if callable(parameters) else None
        return first.device if first is not None else torch.device("cpu")

    def consistency_over_time(self, runs=10, num_tokens=50304, seq_len=8192):
        """Time ``runs`` forward passes and score how often the output repeats.

        Args:
            runs: Number of forward passes to execute (at least 1).
            num_tokens: Exclusive upper bound of the random token ids.
            seq_len: Length of the random token sequence (batch size is 1).

        Returns:
            ``(consistency_times, consistency_score)``: the per-run durations
            in seconds and the percentage of later runs whose output equals
            the first run's output (``100.0`` when there is nothing to compare).

        Raises:
            ValueError: If ``runs`` is smaller than 1.
        """
        if runs < 1:
            raise ValueError("runs must be at least 1")

        device = self._device()
        on_cuda = device.type == "cuda"
        # One fixed input for every run: comparing outputs of different random
        # inputs would say nothing about the model's consistency.
        tokens = torch.randint(0, num_tokens, (1, seq_len), device=device)

        consistency_times = []
        reference = None
        matches = 0
        for _ in range(runs):
            if on_cuda:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            outputs = self.model.forward(tokens)
            if on_cuda:
                torch.cuda.synchronize(device)
            consistency_times.append(time.perf_counter() - start_time)

            # Keep only the first output instead of all of them; comparing
            # tensors directly also works for CUDA outputs.
            outputs = outputs.detach()
            if reference is None:
                reference = outputs
            elif torch.equal(reference, outputs):
                matches += 1

        # Normalise by the number of comparisons made (runs - 1), so a fully
        # consistent model scores 100 rather than 90.
        comparisons = runs - 1
        consistency_score = matches / comparisons * 100 if comparisons else 100.0
        return consistency_times, consistency_score
| |
|
| |
|
class MemoryMetrics:
    """Report the memory used by a single forward pass."""

    def __init__(self, model):
        self.model = model

    def _device(self):
        """Return the device of the model's first parameter, or CPU if it has none."""
        parameters = getattr(self.model, "parameters", None)
        first = next(iter(parameters()), None) if callable(parameters) else None
        return first.device if first is not None else torch.device("cpu")

    def memory_footprint(self, num_tokens=50304, seq_len=8192):
        """Run one forward pass and return ``(current, peak)`` memory in bytes.

        For a CUDA model the numbers come from the CUDA allocator statistics;
        otherwise they come from ``tracemalloc``.

        Args:
            num_tokens: Exclusive upper bound of the random token ids.
            seq_len: Length of the random token sequence (batch size is 1).
        """
        device = self._device()

        if device.type == "cuda":
            # tracemalloc cannot observe device memory, so ask the CUDA
            # allocator directly.
            tokens = torch.randint(0, num_tokens, (1, seq_len), device=device)
            torch.cuda.synchronize(device)
            torch.cuda.reset_peak_memory_stats(device)
            self.model.forward(tokens)
            torch.cuda.synchronize(device)
            return (
                torch.cuda.memory_allocated(device),
                torch.cuda.max_memory_allocated(device),
            )

        # NOTE(review): tracemalloc only follows allocations made through
        # Python's allocators; tensor storage allocated natively may not be
        # reflected in these numbers.
        tracemalloc.start()
        try:
            self.model.forward(torch.randint(0, num_tokens, (1, seq_len)))
            current, peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing, even when the forward pass raises.
            tracemalloc.stop()
        return current, peak
| |
|
| |
|
class SequenceMetrics:
    """Measure how forward-pass time varies with the input sequence length."""

    def __init__(self, model):
        self.model = model

    def _device(self):
        """Return the device of the model's first parameter, or CPU if it has none."""
        parameters = getattr(self.model, "parameters", None)
        first = next(iter(parameters()), None) if callable(parameters) else None
        return first.device if first is not None else torch.device("cpu")

    def sequence_length_impact(self, seq_lengths=None, num_tokens=50304):
        """Time one forward pass per sequence length.

        Args:
            seq_lengths: Iterable of sequence lengths to try; defaults to
                ``[1024, 2048, 4096, 8192]``.
            num_tokens: Exclusive upper bound of the random token ids.

        Returns:
            ``(seq_lengths, seq_impact_times)``: the lengths as a list and the
            matching durations in seconds.
        """
        if seq_lengths is None:
            seq_lengths = [1024, 2048, 4096, 8192]
        else:
            seq_lengths = list(seq_lengths)

        device = self._device()
        on_cuda = device.type == "cuda"

        seq_impact_times = []
        for length in seq_lengths:
            # Build the input on the model's device before starting the clock.
            tokens = torch.randint(0, num_tokens, (1, length), device=device)
            if on_cuda:
                # CUDA work is asynchronous; synchronise so the timer covers it.
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()
            self.model.forward(tokens)
            if on_cuda:
                torch.cuda.synchronize(device)
            seq_impact_times.append(time.perf_counter() - start_time)
        return seq_lengths, seq_impact_times
| |
|
| |
|
| |
|
| |
|
class FlopsBenchmark:
    """Time the model on random float inputs and report an estimated TFLOPs/s.

    Args:
        model: Callable module; moved to CUDA when available, else to the CPU.
        bsz: Batch size of the random input.
        d_model: Size of the last input dimension.
        num_heads: Head count used in the FLOP estimate.
        sequence_lengths: Iterable of sequence lengths to benchmark; defaults
            to ``range(500, 32001, 500)``.
    """

    def __init__(self, model, bsz=32, d_model=1024, num_heads=8, sequence_lengths=None):
        self.bsz = bsz
        self.d_model = d_model
        self.num_heads = num_heads
        # Build the default per instance rather than sharing one mutable list
        # through the signature.
        if sequence_lengths is None:
            sequence_lengths = range(500, 32001, 500)
        self.sequence_lengths = list(sequence_lengths)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32
        self.model = model.to(self.device)

    def _sync(self):
        """Wait for queued CUDA work; a no-op on CPU, where synchronize() would fail."""
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def benchmark(self):
        """Run the model once per sequence length, print and return the results.

        Returns:
            ``(time_taken, tflops_per_s)``: two lists aligned with
            ``self.sequence_lengths``.
        """
        time_taken = []
        tflops_per_s = []
        head_dim = self.d_model // self.num_heads

        for seq_len in self.sequence_lengths:
            # NOTE(review): the input is a float tensor of shape
            # (bsz, seq_len, d_model); confirm the model accepts activations
            # rather than token ids.
            x = torch.randn(
                self.bsz, seq_len, self.d_model, device=self.device, dtype=self.dtype
            )
            self._sync()

            start = time.perf_counter()
            self.model(x)
            self._sync()
            elapsed = time.perf_counter() - start

            time_taken.append(elapsed)
            # NOTE(review): this estimate does not scale with bsz -- confirm
            # that it is meant to be per batch element.
            total_flops = 4 * seq_len ** 2 * head_dim * self.num_heads
            tflops_per_s.append(
                total_flops / elapsed / 1e12 if elapsed > 0 else float("inf")
            )

        for seq_len, elapsed, tflops in zip(self.sequence_lengths, time_taken, tflops_per_s):
            print(f"Sequence length: {seq_len}, Time elapsed: {elapsed} s, TFLOPs/s: {tflops}")

        return time_taken, tflops_per_s
| |
|
| |
|
| | |
# Synthetic dataset of 1000 samples, converted to tensors, used by the
# throughput benchmark.
test_dataset = datasets.FakeData(
    size=1000,
    transform=transforms.ToTensor(),
)

# Model instance shared by every benchmark below.
model = Andromeda(
    num_tokens=50304,
    dim=1024,
    dim_head=128,
    depth=24,
    heads=8,
    alibi_num_heads=4,
)
| |
|
| |
|
| | |
| | |
# Speed: forward, backward and end-to-end timings, taken in that order.
speed_metrics = SpeedMetrics(model)
forward_pass_time, backward_pass_time, end_to_end_latency = (
    speed_metrics.forward_pass_time(),
    speed_metrics.backward_pass_time(),
    speed_metrics.end_to_end_latency(),
)

# Scalability: instances processed per second over the synthetic dataset.
scalability_metrics = ScalabilityMetrics(model, test_dataset)
throughput = scalability_metrics.throughput()

# Consistency: per-run timings and the output-equality score.
consistency_metrics = ConsistencyMetrics(model)
consistency_result = consistency_metrics.consistency_over_time()
consistency_times, consistency_score = consistency_result

# Memory: current and peak usage around one forward pass.
memory_metrics = MemoryMetrics(model)
memory_result = memory_metrics.memory_footprint()
current, peak = memory_result

# Sequence length: forward-pass time for each tested length.
sequence_metrics = SequenceMetrics(model)
sequence_result = sequence_metrics.sequence_length_impact()
seq_lengths, seq_impact_times = sequence_result

# FLOPs: prints one line per benchmarked sequence length.
flops_benchmark = FlopsBenchmark(model)
flops_benchmark.benchmark()
| |
|
| | |
# One figure with three stacked panels: speed, sequence length, consistency.
fig, axs = plt.subplots(3)

# Panel 0: the three speed timings side by side.
axs[0].bar(
    ["Forward Pass Time", "Backward Pass Time", "End-to-End Latency"],
    [forward_pass_time, backward_pass_time, end_to_end_latency],
)

# Panel 1: forward-pass time for each sequence length.
axs[1].bar(seq_lengths, seq_impact_times)

# Panel 2: forward-pass time per repeated run, numbered from 1.
axs[2].plot(list(range(1, len(consistency_times) + 1)), consistency_times)

# Every panel shares the same y-axis label; titles and x labels differ.
panel_titles = ('Speed Metrics', 'Sequence Length Impact', 'Consistency Over Time')
panel_x_labels = ('Metrics', 'Sequence Length', 'Run Number')
for axis, panel_title, panel_x_label in zip(axs, panel_titles, panel_x_labels):
    axis.set_title(panel_title)
    axis.set_xlabel(panel_x_label)
    axis.set_ylabel('Time (seconds)')

plt.tight_layout()
plt.show()

# Scalar results that are not plotted.
print(f"Throughput: {throughput} instances/second")
print(f"Memory used: {current / 10**6}MB; Peak: {peak / 10**6}MB")
| |
|
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Run the model smoke tests in order: forward, backward, optimizer step.
    model_test = AndromedaModelTest()
    for check_name in ("test_forward_pass", "test_backward_pass", "test_optimizer_step"):
        getattr(model_test, check_name)()