"""
Example: Using BitLinear as a drop-in replacement for nn.Linear in a Transformer.

This example demonstrates:

1. Creating a simple Transformer block with standard nn.Linear
2. Converting it to use BitLinear layers
3. Running forward passes to verify compatibility
4. Comparing memory usage and output similarity
"""
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional
|
| |
|
| | from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
| |
|
| |
|
class TransformerBlock(nn.Module):
    """
    Minimal pre-norm Transformer block used for the BitLinear demo.

    Composed of:
      - multi-head self-attention built from four linear projections
      - a two-layer position-wise feed-forward network
      - layer normalization, dropout, and residual connections
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead  # per-head feature size

        # Attention projections — these nn.Linear layers are what the
        # demo later converts to BitLinear.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run one Transformer block over *x*.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask; positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        # ---- self-attention sub-layer (pre-norm + residual) ----
        shortcut = x
        normed = self.norm1(x)

        bsz, length, _ = normed.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, S, D] -> [B, H, S, d_k]
            return t.view(bsz, length, self.nhead, self.d_k).transpose(1, 2)

        q = split_heads(self.q_proj(normed))
        k = split_heads(self.k_proj(normed))
        v = split_heads(self.v_proj(normed))

        # Scaled dot-product attention.
        scores = q @ k.transpose(-2, -1) / (self.d_k ** 0.5)
        if mask is not None:
            # Large negative value so softmax drives masked weights to ~0.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        context = weights @ v

        # Merge heads back: [B, H, S, d_k] -> [B, S, D]
        context = context.transpose(1, 2).contiguous().view(
            bsz, length, self.d_model
        )
        attended = self.dropout1(self.out_proj(context))
        x = shortcut + attended

        # ---- feed-forward sub-layer (pre-norm + residual) ----
        shortcut = x
        x = shortcut + self.dropout2(self.ffn(self.norm2(x)))

        return x
|
| |
|
| |
|
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters in *model*."""
    trainable = (p for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in trainable)
|
| |
|
| |
|
def estimate_memory_mb(model: nn.Module) -> float:
    """Estimate memory used by *model*'s parameters, in MB (mebibytes)."""
    n_bytes = 0
    for param in model.parameters():
        n_bytes += param.numel() * param.element_size()
    return n_bytes / (1024 * 1024)
|
| |
|
| |
|
def compare_outputs(
    output1: torch.Tensor,
    output2: torch.Tensor,
) -> dict:
    """
    Compute similarity metrics between two tensors of the same shape.

    Returns:
        Dict with keys "mse", "cosine_similarity", and "relative_error"
        (diff norm relative to the norm of *output1*).
    """
    flat1 = output1.flatten()
    flat2 = output2.flatten()
    diff_norm = torch.norm(output1 - output2)

    return {
        "mse": F.mse_loss(output1, output2).item(),
        "cosine_similarity": F.cosine_similarity(flat1, flat2, dim=0).item(),
        "relative_error": (diff_norm / torch.norm(output1)).item(),
    }
|
| |
|
| |
|
def main():
    """Main example demonstrating BitLinear usage in Transformer."""

    banner = "=" * 80
    rule = "-" * 80

    print(banner)
    print("BitLinear Transformer Example")
    print(banner)

    # Demo configuration.
    batch_size = 32
    seq_len = 128
    d_model = 512
    nhead = 8
    dim_feedforward = 2048

    # Random input batch shared by both models.
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\nInput shape: {x.shape}")

    # --- 1. Baseline model with standard nn.Linear layers ---
    print("\n" + rule)
    print("1. Standard Transformer with nn.Linear")
    print(rule)

    model_standard = TransformerBlock(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    )

    print(f"Parameters: {count_parameters(model_standard):,}")
    print(f"Memory: {estimate_memory_mb(model_standard):.2f} MB")

    with torch.no_grad():
        output_standard = model_standard(x)
    print(f"Output shape: {output_standard.shape}")

    # --- 2. Same model converted to BitLinear ---
    print("\n" + rule)
    print("2. Transformer with BitLinear")
    print(rule)

    # inplace=False keeps the original model intact for comparison.
    model_bitlinear = convert_linear_to_bitlinear(model_standard, inplace=False)

    print(f"Parameters: {count_parameters(model_bitlinear):,}")
    print(f"Memory: {estimate_memory_mb(model_bitlinear):.2f} MB")

    with torch.no_grad():
        output_bitlinear = model_bitlinear(x)
    print(f"Output shape: {output_bitlinear.shape}")

    # --- 3. How close are the two models' outputs? ---
    print("\n" + rule)
    print("3. Output Comparison")
    print(rule)

    metrics = compare_outputs(output_standard, output_bitlinear)
    print(f"MSE: {metrics['mse']:.6f}")
    print(f"Cosine similarity: {metrics['cosine_similarity']:.6f}")
    print(f"Relative error: {metrics['relative_error']:.6f}")

    # --- 4. Parameter-memory comparison ---
    print("\n" + rule)
    print("4. Memory Savings")
    print(rule)

    mem_standard = estimate_memory_mb(model_standard)
    mem_bitlinear = estimate_memory_mb(model_bitlinear)
    savings = (mem_standard - mem_bitlinear) / mem_standard * 100

    print(f"Standard model: {mem_standard:.2f} MB")
    print(f"BitLinear model: {mem_bitlinear:.2f} MB")
    print(f"Memory savings: {savings:.1f}%")
    print(f"Compression ratio: {mem_standard / mem_bitlinear:.1f}x")

    # --- 5. Which layers were converted? ---
    print("\n" + rule)
    print("5. Conversion Details")
    print(rule)

    def _count_layers(model: nn.Module, layer_type: type) -> int:
        # Count modules of the given type anywhere in the model tree.
        return sum(1 for m in model.modules() if isinstance(m, layer_type))

    print(f"Original Linear layers: {_count_layers(model_standard, nn.Linear)}")
    print(f"Converted BitLinear layers: {_count_layers(model_bitlinear, BitLinear)}")

    print("\n" + banner)
    print("Example complete!")
    print(banner)
    print("\nKey Takeaways:")
    print("- BitLinear is a drop-in replacement for nn.Linear")
    print("- Significant memory savings (~20x for weights)")
    print("- Output similarity is high (cosine sim > 0.99 typically)")
    print("- Slight accuracy trade-off due to ternary quantization")
|
| |
|
| |
|
# Entry point when executed as a script.
if __name__ == "__main__":
    main()
|
| |
|