| |
|
| | """
|
| | Verification script to demonstrate all implemented functionality.
|
| | Run this to see layers.py and packing.py in action!
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
| | from bitlinear.packing import (
|
| | pack_ternary_base3,
|
| | unpack_ternary_base3,
|
| | estimate_memory_savings,
|
| | )
|
| |
|
| |
|
| | def demo_bitlinear():
|
| | """Demonstrate BitLinear layer."""
|
| | print("=" * 70)
|
| | print("1. BitLinear Layer Demo")
|
| | print("=" * 70)
|
| |
|
| |
|
| | layer = BitLinear(512, 256, bias=True)
|
| | print(f"β Created BitLinear(512 β 256)")
|
| | print(f" - W_ternary shape: {layer.W_ternary.shape}")
|
| | print(f" - Gamma shape: {layer.gamma.shape}")
|
| | print(f" - Unique weight values: {sorted(layer.W_ternary.unique().tolist())}")
|
| |
|
| |
|
| | x = torch.randn(16, 512)
|
| | y = layer(x)
|
| | print(f"\nβ Forward pass: {x.shape} β {y.shape}")
|
| |
|
| |
|
| | linear = nn.Linear(512, 256)
|
| | bitlinear = BitLinear.from_linear(linear)
|
| | print(f"β Converted nn.Linear to BitLinear")
|
| | print()
|
| |
|
| |
|
| | def demo_multi_ternary():
|
| | """Demonstrate MultiTernaryLinear layer."""
|
| | print("=" * 70)
|
| | print("2. MultiTernaryLinear Layer Demo")
|
| | print("=" * 70)
|
| |
|
| |
|
| | for k in [1, 2, 4]:
|
| | layer = MultiTernaryLinear(256, 128, k=k, bias=True)
|
| | print(f"β MultiTernaryLinear(256 β 128, k={k})")
|
| | print(f" - W_ternary shape: {layer.W_ternary.shape}")
|
| | print(f" - Gammas shape: {layer.gammas.shape}")
|
| |
|
| |
|
| | print("\nβ Approximation quality test:")
|
| | linear = nn.Linear(128, 128)
|
| | x = torch.randn(8, 128)
|
| | dense_out = linear(x)
|
| |
|
| | errors = []
|
| | for k in [1, 2, 4]:
|
| | multi = MultiTernaryLinear.from_linear(linear, k=k)
|
| | ternary_out = multi(x)
|
| | error = torch.norm(dense_out - ternary_out).item()
|
| | errors.append(error)
|
| | print(f" - k={k}: reconstruction error = {error:.4f}")
|
| |
|
| | print(f" - Error decreases with k: {errors[0] > errors[1] > errors[2]}")
|
| | print()
|
| |
|
| |
|
| | def demo_model_conversion():
|
| | """Demonstrate model conversion utility."""
|
| | print("=" * 70)
|
| | print("3. Model Conversion Utility Demo")
|
| | print("=" * 70)
|
| |
|
| |
|
| | class SimpleModel(nn.Module):
|
| | def __init__(self):
|
| | super().__init__()
|
| | self.fc1 = nn.Linear(128, 256)
|
| | self.relu = nn.ReLU()
|
| | self.fc2 = nn.Linear(256, 128)
|
| | self.fc3 = nn.Linear(128, 10)
|
| |
|
| | def forward(self, x):
|
| | x = self.relu(self.fc1(x))
|
| | x = self.relu(self.fc2(x))
|
| | return self.fc3(x)
|
| |
|
| | model = SimpleModel()
|
| |
|
| |
|
| | linear_count = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
|
| | print(f"β Original model: {linear_count} Linear layers")
|
| |
|
| |
|
| | model_converted = convert_linear_to_bitlinear(model, inplace=False)
|
| | bitlinear_count = sum(1 for m in model_converted.modules() if isinstance(m, BitLinear))
|
| | print(f"β Converted model: {bitlinear_count} BitLinear layers")
|
| |
|
| |
|
| | x = torch.randn(4, 128)
|
| | y = model_converted(x)
|
| | print(f"β Forward pass works: {x.shape} β {y.shape}")
|
| | print()
|
| |
|
| |
|
| | def demo_packing():
|
| | """Demonstrate base-3 packing."""
|
| | print("=" * 70)
|
| | print("4. Base-3 Packing Demo")
|
| | print("=" * 70)
|
| |
|
| |
|
| | W = torch.tensor([
|
| | [-1, 0, 1, -1, 0],
|
| | [1, 1, -1, 0, 1],
|
| | [0, -1, 1, 1, -1],
|
| | ], dtype=torch.float32)
|
| |
|
| | print(f"β Original ternary weights shape: {W.shape}")
|
| | print(f" - Float32 memory: {W.numel() * 4} bytes")
|
| |
|
| |
|
| | packed, original_shape = pack_ternary_base3(W)
|
| | print(f"\nβ Packed into uint8 tensor")
|
| | print(f" - Packed shape: {packed.shape}")
|
| | print(f" - Packed memory: {packed.numel()} bytes")
|
| | print(f" - Compression: {W.numel() * 4 / packed.numel():.2f}x")
|
| |
|
| |
|
| | W_unpacked = unpack_ternary_base3(packed, original_shape)
|
| | print(f"\nβ Unpacked back to ternary")
|
| | print(f" - Unpacked shape: {W_unpacked.shape}")
|
| | print(f" - Perfect round-trip: {torch.allclose(W, W_unpacked)}")
|
| | print()
|
| |
|
| |
|
| | def demo_memory_estimation():
|
| | """Demonstrate memory savings estimation."""
|
| | print("=" * 70)
|
| | print("5. Memory Savings Estimation")
|
| | print("=" * 70)
|
| |
|
| | configs = [
|
| | (768, 3072, 1, "Single Transformer FFN layer"),
|
| | (768, 3072, 12, "BERT-base (12 layers)"),
|
| | (1024, 4096, 24, "BERT-large (24 layers)"),
|
| | ]
|
| |
|
| | for in_dim, out_dim, num_layers, description in configs:
|
| | stats = estimate_memory_savings(in_dim, out_dim, num_layers)
|
| | print(f"\nβ {description}")
|
| | print(f" Configuration: {in_dim} β {out_dim} Γ {num_layers} layers")
|
| | print(f" Float32 memory: {stats['float32_bytes'] / 1e6:.2f} MB")
|
| | print(f" Packed memory: {stats['packed_bytes'] / 1e6:.2f} MB")
|
| | print(f" Savings: {stats['savings_bytes'] / 1e6:.2f} MB")
|
| | print(f" Compression: {stats['compression_ratio']:.2f}x")
|
| | print()
|
| |
|
| |
|
| | def main():
|
| | """Run all demos."""
|
| | print("\n" + "=" * 70)
|
| | print(" BitLinear Implementation Verification")
|
| | print(" All functionality implemented and working!")
|
| | print("=" * 70)
|
| | print()
|
| |
|
| | demo_bitlinear()
|
| | demo_multi_ternary()
|
| | demo_model_conversion()
|
| | demo_packing()
|
| | demo_memory_estimation()
|
| |
|
| | print("=" * 70)
|
| | print(" β All implementations verified!")
|
| | print(" β Ready for C++/CUDA optimization")
|
| | print("=" * 70)
|
| | print()
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | main()
|
| |
|