| """
|
| Base-3 packing utilities for memory-efficient ternary weight storage.
|
|
|
| Ternary weights ({-1, 0, +1}) can be represented in base-3, allowing
|
| multiple ternary values to be packed into a single byte or integer.
|
| This provides significant memory savings over storing each value as a float32.
|
|
|
| Theoretical packing:
|
| - 1 ternary value requires log2(3) ≈ 1.58 bits
|
| - 5 ternary values fit in 1 byte (3^5 = 243 < 256)
|
| - Compression ratio: 32 bits (float) → ~1.6 bits (packed) = 20x compression
|
| """
|
|
|
| import torch
|
| from typing import Tuple
|
|
|
|
|
def pack_ternary_base3(W_ternary: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
    """
    Pack ternary weights into base-3 representation for memory efficiency.

    Packs multiple ternary values ({-1, 0, +1}) into uint8 storage using base-3
    encoding. This achieves near-optimal compression for ternary data.

    Encoding scheme:
        -1 → 0 (base 3)
         0 → 1 (base 3)
        +1 → 2 (base 3)

    Then pack 5 base-3 digits into one byte:
        packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81

    The largest possible byte is 2 * (1 + 3 + 9 + 27 + 81) = 242, so every
    packed value fits in uint8.

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}.
            Any shape is accepted (e.g. [out_features, in_features] or
            [k, out_features, in_features]); values are read in flattened
            (row-major) order.

    Returns:
        packed: 1-D uint8 tensor of ceil(numel / 5) bytes (5 values per byte)
        original_shape: Shape of original tensor for unpacking

    Raises:
        ValueError: If W_ternary contains values outside {-1, 0, +1}.

    Notes:
        - 5 ternary values per byte (3^5 = 243 < 256)
        - If numel is not divisible by 5, the last byte is padded with digit 0;
          the padding is discarded on unpacking using original_shape.
        - This is the primary memory optimization for ternary weights
    """
    original_shape = tuple(W_ternary.shape)
    flat = W_ternary.flatten()

    # Reject out-of-range values up front: anything outside {-1, 0, +1} would
    # overflow a base-3 digit (or wrap when cast to uint8) and silently corrupt
    # the other values packed into the same byte.
    is_ternary = (flat == -1) | (flat == 0) | (flat == 1)
    if not bool(is_ternary.all()):
        raise ValueError("pack_ternary_base3 expects values in {-1, 0, +1}")

    # Shift {-1, 0, +1} to base-3 digits {0, 1, 2}.
    base3 = (flat + 1).to(torch.uint8)

    # Pad to a multiple of 5 so the digits can be grouped into whole bytes.
    pad_size = (5 - base3.numel() % 5) % 5
    if pad_size > 0:
        base3 = torch.cat([base3, torch.zeros(pad_size, dtype=torch.uint8, device=base3.device)])

    base3 = base3.view(-1, 5)

    powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8, device=base3.device)
    # torch.sum promotes integer inputs to int64 unless a dtype is given, which
    # would cost 8 bytes per group; request uint8 explicitly (max 242, no overflow).
    packed = (base3 * powers_of_3).sum(dim=1, dtype=torch.uint8)

    return packed, original_shape
|
|
|
|
|
def unpack_ternary_base3(
    packed: torch.Tensor,
    original_shape: Tuple[int, ...],
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Unpack base-3 encoded ternary weights back to full representation.

    Reverses the base-3 packing: each element is split into 5 base-3 digits
    (least-significant first), digits are mapped back via d - 1, and any
    padding beyond prod(original_shape) values is discarded.

    Args:
        packed: Packed tensor (5 values per element), e.g. the uint8 output of
            pack_ternary_base3. Integer tensors of any shape are accepted; they
            are read in flattened (row-major) order.
        original_shape: Original shape of the ternary tensor
        dtype: dtype of the returned tensor (default float32). It must be able
            to represent -1, i.e. a floating-point or signed integer dtype.

    Returns:
        W_ternary: Tensor of shape original_shape with values in {-1, 0, +1}

    Raises:
        ValueError: If packed holds fewer than prod(original_shape) values.
    """
    flat = packed.reshape(-1)
    numel = torch.Size(original_shape).numel()

    capacity = flat.numel() * 5
    if capacity < numel:
        raise ValueError(
            f"packed tensor holds at most {capacity} values, "
            f"but original_shape {tuple(original_shape)} needs {numel}"
        )

    # Digit i of each element is (value // 3**i) % 3; broadcasting extracts all
    # five digits at once, giving a [num_packed, 5] tensor in packing order.
    powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=flat.dtype, device=flat.device)
    digits = (flat.unsqueeze(1) // powers_of_3) % 3

    # Row-major flatten restores the original value order; drop the padding.
    base3 = digits.reshape(-1)[:numel]

    # Map base-3 digits {0, 1, 2} back to {-1, 0, +1}.
    return (base3.to(dtype) - 1).view(original_shape)
|
|
|
|
|
def compute_compression_ratio(
    original_size: int,
    packed_size: int,
) -> float:
    """
    Return how many times smaller the packed weights are than the originals.

    Args:
        original_size: Size in bytes of original float32 weights
        packed_size: Size in bytes of packed ternary weights

    Returns:
        original_size / packed_size (e.g., 20.0 means 20x compression),
        or 0.0 when packed_size is not positive.

    Examples:
        >>> # 512 x 512 float32 weights = 512*512*4 bytes = 1,048,576 bytes
        >>> # Packed: 512*512 ternary values / 5 per byte ≈ 52,429 bytes
        >>> ratio = compute_compression_ratio(1048576, 52429)
        >>> print(f"Compression: {ratio:.1f}x")
        Compression: 20.0x
    """
    if packed_size > 0:
        return original_size / packed_size
    # A non-positive packed size has no meaningful ratio; report 0.0 instead
    # of dividing by zero.
    return 0.0
|
|
|
|
|
def estimate_memory_savings(
    in_features: int,
    out_features: int,
    num_layers: int = 1,
) -> dict:
    """
    Estimate memory savings from ternary packing for a given layer configuration.

    Each layer holds in_features * out_features weights. Dense storage is
    costed at 4 bytes (float32) per weight. Packed storage is costed at one
    byte per group of 5 weights, rounded up per layer, because every layer is
    padded on its own.

    Args:
        in_features: Input dimension
        out_features: Output dimension
        num_layers: Number of identically sized layers (for cumulative savings)

    Returns:
        Dictionary with memory statistics:
            - float32_bytes: Memory for float32 weights
            - packed_bytes: Memory for packed ternary weights
            - savings_bytes: Absolute memory saved
            - compression_ratio: Ratio of compression, as computed by
              compute_compression_ratio(float32_bytes, packed_bytes)

    Examples:
        >>> stats = estimate_memory_savings(768, 3072, num_layers=12)
        >>> print(f"Total savings: {stats['savings_bytes'] / 1e6:.1f} MB")
        Total savings: 107.6 MB
    """
    n_weights = in_features * out_features

    layer_dense = n_weights * 4
    # ceil(n_weights / 5): a partially filled final group still costs a byte.
    layer_packed = (n_weights + 4) // 5

    dense_total = layer_dense * num_layers
    packed_total = layer_packed * num_layers

    return {
        'float32_bytes': dense_total,
        'packed_bytes': packed_total,
        'savings_bytes': dense_total - packed_total,
        'compression_ratio': compute_compression_ratio(dense_total, packed_total),
    }
|
|
|
|
|
|
|
|
|
def pack_ternary_bitwise(W_ternary: torch.Tensor) -> torch.Tensor:
    """
    Alternative packing using 2 bits per ternary value.

    Simpler but less efficient than base-3 packing:
        -1 → 00
         0 → 01
        +1 → 10

    Four 2-bit codes are stored per uint8, least-significant bits first:
        packed_byte = c0 | (c1 << 2) | (c2 << 4) | (c3 << 6)

    This uses 2 bits per value (4 values per byte) instead of the 1.6 bits per
    value of base-3 packing (5 values per byte), i.e. 25% more storage, in
    exchange for decoding that needs only shifts and masks.

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}, any shape;
            values are read in flattened (row-major) order.

    Returns:
        packed: 1-D uint8 tensor of ceil(numel / 4) bytes. The shape is not
            stored; pass W_ternary.shape to unpack_ternary_bitwise.

    Raises:
        ValueError: If W_ternary contains values outside {-1, 0, +1}.
    """
    flat = W_ternary.flatten()

    # Out-of-range values would spill into the neighbouring 2-bit field.
    is_ternary = (flat == -1) | (flat == 0) | (flat == 1)
    if not bool(is_ternary.all()):
        raise ValueError("pack_ternary_bitwise expects values in {-1, 0, +1}")

    # Shift {-1, 0, +1} to codes {0b00, 0b01, 0b10}.
    codes = (flat + 1).to(torch.uint8)

    # Pad to a multiple of 4 codes; the padding is dropped on unpacking.
    pad_size = (4 - codes.numel() % 4) % 4
    if pad_size > 0:
        codes = torch.cat([codes, torch.zeros(pad_size, dtype=torch.uint8, device=codes.device)])

    codes = codes.view(-1, 4)

    # The 2-bit fields are disjoint, so multiply-and-sum equals shift-and-or.
    # The max byte is 2 * (1 + 4 + 16 + 64) = 170; sum in uint8 explicitly
    # because torch.sum would otherwise promote to int64.
    place_values = torch.tensor([1, 4, 16, 64], dtype=torch.uint8, device=codes.device)
    return (codes * place_values).sum(dim=1, dtype=torch.uint8)


def unpack_ternary_bitwise(
    packed: torch.Tensor,
    original_shape: Tuple[int, ...],
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Unpack 2-bit encoded ternary weights.

    Reverses pack_ternary_bitwise: each byte is split into four 2-bit codes
    (least-significant bits first) and mapped back via c - 1. Padding beyond
    prod(original_shape) values is discarded.

    Args:
        packed: Packed tensor (4 values per element), e.g. the uint8 output of
            pack_ternary_bitwise. Integer tensors of any shape are accepted;
            they are read in flattened (row-major) order.
        original_shape: Original shape of the ternary tensor
        dtype: dtype of the returned tensor (default float32). It must be able
            to represent -1, i.e. a floating-point or signed integer dtype.

    Returns:
        W_ternary: Tensor of shape original_shape with values in {-1, 0, +1}

    Raises:
        ValueError: If packed holds fewer than prod(original_shape) values, or
            if a used code is 0b11 (not a valid ternary code).
    """
    flat = packed.reshape(-1)
    numel = torch.Size(original_shape).numel()

    capacity = flat.numel() * 4
    if capacity < numel:
        raise ValueError(
            f"packed tensor holds at most {capacity} values, "
            f"but original_shape {tuple(original_shape)} needs {numel}"
        )

    # Code i of each element is (value // 4**i) % 4, i.e. (value >> 2i) & 0b11.
    place_values = torch.tensor([1, 4, 16, 64], dtype=flat.dtype, device=flat.device)
    codes = ((flat.unsqueeze(1) // place_values) % 4).reshape(-1)[:numel]

    if bool((codes == 3).any()):
        raise ValueError("packed data contains invalid 2-bit code 0b11")

    # Map codes {0, 1, 2} back to {-1, 0, +1}.
    return (codes.to(dtype) - 1).view(original_shape)
|
|
|