| """
|
| Performance benchmarking for BitLinear vs nn.Linear.
|
|
|
| This script benchmarks forward pass time for various layer sizes and batch sizes,
|
| comparing BitLinear (Python implementation) with standard nn.Linear.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import time
|
| from bitlinear import BitLinear, MultiTernaryLinear
|
| import sys
|
|
|
|
|
def benchmark_forward_pass(layer, x, n_warmup=10, n_runs=100):

    """

    Benchmark the average forward-pass time of a layer.

    Uses ``time.perf_counter()`` rather than ``time.time()``: perf_counter is

    monotonic and has the highest available resolution, so short timed

    regions are not distorted by wall-clock adjustments or coarse timer

    granularity.

    Args:

        layer: PyTorch module to benchmark

        x: Input tensor

        n_warmup: Number of warmup iterations (amortizes one-time costs

            such as lazy allocations before the timed region)

        n_runs: Number of benchmark iterations; must be positive

    Returns:

        Average time per forward pass in milliseconds

    Raises:

        ValueError: If ``n_runs`` is not positive (would divide by zero).

    """

    if n_runs <= 0:
        raise ValueError("n_runs must be positive")

    # Warmup runs: not timed, executed under no_grad like the measured runs.
    with torch.no_grad():

        for _ in range(n_warmup):

            _ = layer(x)

    start_time = time.perf_counter()

    with torch.no_grad():

        for _ in range(n_runs):

            _ = layer(x)

    end_time = time.perf_counter()

    # Convert total elapsed seconds to milliseconds per single forward pass.
    avg_time_ms = (end_time - start_time) / n_runs * 1000

    return avg_time_ms
|
|
|
|
|
def run_benchmarks(n_warmup=10, n_runs=100):

    """Run comprehensive CPU benchmarks of BitLinear vs nn.Linear.

    Compares nn.Linear against BitLinear and MultiTernaryLinear (k=2) across

    several square layer sizes and batch configurations, printing

    per-configuration timings, a markdown summary table, and average speedups.

    Args:

        n_warmup: Warmup iterations per layer (default 10, matching the

            previously hard-coded value). Passed through to

            ``benchmark_forward_pass`` so the printed header stays accurate.

        n_runs: Timed iterations per layer (default 100).

    """

    print("=" * 100)

    print("BitLinear Performance Benchmarks")

    print("=" * 100)

    print(f"\nPyTorch version: {torch.__version__}")

    print("Device: CPU")

    # Report the actual counts used, not duplicated literals.
    print(f"Number of warmup runs: {n_warmup}")

    print(f"Number of benchmark runs: {n_runs}")

    # Square layer sizes typical of transformer projection layers.
    layer_sizes = [

        (512, 512),

        (1024, 1024),

        (2048, 2048),

        (4096, 4096),

    ]

    # (batch_size, seq_len) pairs: single token through larger batches.
    batch_configs = [

        (1, 1),

        (16, 128),

        (32, 128),

        (64, 128),

    ]

    results = []

    for in_features, out_features in layer_sizes:

        print(f"\n{'=' * 100}")

        print(f"Layer Size: {in_features} → {out_features}")

        print(f"{'=' * 100}")

        for batch_size, seq_len in batch_configs:

            print(f"\nBatch: {batch_size}, Seq Length: {seq_len}")

            print("-" * 100)

            x = torch.randn(batch_size, seq_len, in_features)

            # Derive both quantized layers from the same nn.Linear weights
            # so all three variants compute over identical parameters.
            linear = nn.Linear(in_features, out_features)

            bitlinear = BitLinear.from_linear(linear)

            multi_ternary = MultiTernaryLinear.from_linear(linear, k=2)

            time_linear = benchmark_forward_pass(linear, x, n_warmup, n_runs)

            time_bitlinear = benchmark_forward_pass(bitlinear, x, n_warmup, n_runs)

            time_multi = benchmark_forward_pass(multi_ternary, x, n_warmup, n_runs)

            # Speedup > 1 means the quantized layer is faster than nn.Linear.
            speedup_bit = time_linear / time_bitlinear

            speedup_multi = time_linear / time_multi

            print(f"nn.Linear: {time_linear:8.3f} ms")

            print(f"BitLinear: {time_bitlinear:8.3f} ms (speedup: {speedup_bit:5.2f}x)")

            print(f"MultiTernaryLinear: {time_multi:8.3f} ms (speedup: {speedup_multi:5.2f}x)")

            results.append({

                'in_features': in_features,

                'out_features': out_features,

                'batch_size': batch_size,

                'seq_len': seq_len,

                'time_linear': time_linear,

                'time_bitlinear': time_bitlinear,

                'time_multi': time_multi,

                'speedup_bit': speedup_bit,

                'speedup_multi': speedup_multi,

            })

    print(f"\n\n{'=' * 100}")

    print("Summary Table (Markdown Format)")

    print(f"{'=' * 100}\n")

    print("| Layer Size | Batch | Seq Len | nn.Linear (ms) | BitLinear (ms) | Speedup | Multi-Ternary (ms) | Speedup |")

    print("|------------|-------|---------|----------------|----------------|---------|--------------------|---------| ")

    for r in results:

        print(f"| {r['in_features']}×{r['out_features']:<4} | {r['batch_size']:5} | {r['seq_len']:7} | "

              f"{r['time_linear']:14.3f} | {r['time_bitlinear']:14.3f} | {r['speedup_bit']:7.2f} | "

              f"{r['time_multi']:18.3f} | {r['speedup_multi']:7.2f} |")

    print(f"\n{'=' * 100}")

    print("Summary Statistics")

    print(f"{'=' * 100}\n")

    avg_speedup_bit = sum(r['speedup_bit'] for r in results) / len(results)

    avg_speedup_multi = sum(r['speedup_multi'] for r in results) / len(results)

    print(f"Average BitLinear speedup: {avg_speedup_bit:.2f}x")

    print(f"Average Multi-Ternary speedup: {avg_speedup_multi:.2f}x")

    if avg_speedup_bit < 1.0:

        print(f"\nNote: BitLinear is slower than nn.Linear by {1/avg_speedup_bit:.2f}x on average.")

        print("This is expected for the Python implementation. C++/CUDA extensions would be faster.")

    else:

        print(f"\nNote: BitLinear is faster than nn.Linear by {avg_speedup_bit:.2f}x on average!")

    print(f"\n{'=' * 100}")

    print("Benchmark Complete!")

    print(f"{'=' * 100}")
|
# Run the full benchmark suite only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":

    run_benchmarks()
|
|
|