| """ |
| Lloyd-Max optimal scalar quantizer for the Beta distribution |
| arising from random rotation of unit-norm vectors. |
| |
| After rotating a d-dimensional unit vector by a random orthogonal matrix, |
| each coordinate follows: f(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1-x^2)^((d-3)/2) |
| supported on [-1, 1]. |
| |
| For practical dimensions (d >= 64), this is well-approximated by N(0, 1/d). |
| We solve the Lloyd-Max conditions (continuous 1-D k-means) to find optimal centroids. |
| """ |
|
|
| import torch |
| import math |
| from scipy import integrate, special |
|
|
|
|
def beta_pdf(x: float, d: int) -> float:
    """PDF of a single coordinate after random rotation of a d-dim unit vector.

    Evaluates
        f(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
    for |x| < 1 and returns 0.0 everywhere else.

    Args:
        x: point at which to evaluate the density.
        d: vector dimension; must be at least 2 for points inside the support.

    Returns:
        The density value at ``x``.

    Raises:
        ValueError: if ``|x| < 1`` and ``d < 2`` (Gamma((d-1)/2) has a pole at d = 1).
    """
    if abs(x) >= 1.0:
        return 0.0
    if d < 2:
        raise ValueError(f"beta_pdf requires d >= 2, got {d}")
    # Work with log-Gamma: math.gamma overflows for arguments above ~171.6
    # (i.e. d >= 344), whereas the *ratio* of the two Gammas is only ~sqrt(d/2).
    log_gamma_ratio = math.lgamma(d / 2) - math.lgamma((d - 1) / 2)
    coeff = math.exp(log_gamma_ratio) / math.sqrt(math.pi)
    return coeff * (1 - x * x) ** ((d - 3) / 2)
|
|
|
|
def gaussian_approx_pdf(x: float, d: int) -> float:
    """Density of the zero-mean Gaussian with variance 1/d, evaluated at ``x``.

    This is the N(0, 1/d) stand-in for the exact coordinate distribution,
    described in this file as accurate for d >= 64.

    Args:
        x: point at which to evaluate the density.
        d: vector dimension; the variance used is ``1 / d``.

    Returns:
        The Gaussian density value at ``x``.
    """
    variance = 1.0 / d
    normaliser = 1.0 / math.sqrt(2 * math.pi * variance)
    exponent = -x * x / (2 * variance)
    return normaliser * math.exp(exponent)
|
|
|
|
def solve_lloyd_max(d: int, bits: int, use_exact: bool = False, max_iter: int = 200, tol: float = 1e-10):
    """
    Solve Lloyd-Max optimal quantizer for the coordinate distribution.

    Alternates the two Lloyd-Max conditions until the centroids stop moving:
    each boundary is the midpoint of two neighbouring centroids, and each
    centroid is the conditional mean of the PDF over its cell (both integrals
    evaluated with ``scipy.integrate.quad``).

    Args:
        d: vector dimension (must be >= 1)
        bits: number of quantization bits (must be >= 0)
        use_exact: if True, use exact Beta PDF; if False, use Gaussian approx
        max_iter: maximum Lloyd-Max iterations
        tol: convergence tolerance on the largest centroid shift per iteration

    Returns:
        centroids: sorted float32 tensor of 2^bits optimal centroids
        boundaries: sorted float32 tensor of 2^bits - 1 boundaries between centroids

    Raises:
        ValueError: if ``d`` < 1 or ``bits`` < 0.
    """
    if d < 1:
        raise ValueError(f"d must be a positive integer, got {d}")
    if bits < 0:
        raise ValueError(f"bits must be non-negative, got {bits}")

    n_levels = 2 ** bits
    pdf = (lambda x: beta_pdf(x, d)) if use_exact else (lambda x: gaussian_approx_pdf(x, d))
    sigma = 1.0 / math.sqrt(d)

    # Centroids start uniformly inside +/-3.5 sigma; the two outermost cells are
    # integrated out to 3x that range so the tails are effectively fully covered.
    hi = 3.5 * sigma
    tail = 3.0 * hi
    if use_exact:
        # The exact PDF is supported on [-1, 1]. Without clamping, small d puts
        # initial centroids in zero-mass cells (which then never move) and makes
        # quad integrate across the edge of the support.
        hi = min(hi, 1.0)
        tail = min(tail, 1.0)
    lo = -hi

    centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

    def first_moment(x):
        # Integrand of the conditional-mean numerator.
        return x * pdf(x)

    for _ in range(max_iter):
        # Nearest-neighbour condition: boundaries are midpoints of adjacent centroids.
        boundaries = [(left + right) / 2.0 for left, right in zip(centroids, centroids[1:])]
        edges = [-tail, *boundaries, tail]

        # Centroid condition: each centroid becomes the conditional mean of its cell.
        new_centroids = []
        for old, a, b in zip(centroids, edges, edges[1:]):
            numerator, _num_err = integrate.quad(first_moment, a, b)
            denominator, _den_err = integrate.quad(pdf, a, b)
            if denominator > 1e-15:
                new_centroids.append(numerator / denominator)
            else:
                # Cell carries (numerically) no probability mass: keep the old value
                # rather than divide by ~0.
                new_centroids.append(old)

        max_shift = max(abs(new - old) for new, old in zip(new_centroids, centroids))
        centroids = new_centroids

        if max_shift < tol:
            break

    # Recompute so the returned boundaries match the final centroids.
    boundaries = [(left + right) / 2.0 for left, right in zip(centroids, centroids[1:])]

    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(boundaries, dtype=torch.float32),
    )
|
|
|
|
def compute_expected_distortion(d: int, bits: int, centroids: torch.Tensor, boundaries: torch.Tensor, use_exact: bool = False) -> float:
    """Compute the expected MSE distortion per coordinate for the given quantizer.

    Sums, over every quantization cell, the integral of ``(x - c)^2 * pdf(x)``
    where ``c`` is the cell's centroid. Interior cell edges come from
    ``boundaries``; the two outer edges sit at +/-10.5 sigma (sigma = 1/sqrt(d)),
    clamped to the [-1, 1] support when the exact PDF is used.

    Args:
        d: vector dimension
        bits: number of quantization bits (unused; kept for interface compatibility)
        centroids: 1-D tensor of centroids, in ascending order
        boundaries: 1-D tensor of the ``len(centroids) - 1`` cell boundaries
        use_exact: if True, use exact Beta PDF; if False, use Gaussian approx

    Returns:
        The total expected squared error as a Python float.

    Raises:
        ValueError: if ``boundaries`` does not have exactly one entry fewer
            than ``centroids``.
    """
    levels = centroids.tolist()
    cuts = boundaries.tolist()
    # A mismatch would otherwise either raise IndexError or silently skip cells.
    if len(cuts) != max(len(levels) - 1, 0):
        raise ValueError(
            f"expected {max(len(levels) - 1, 0)} boundaries for {len(levels)} centroids, got {len(cuts)}"
        )

    pdf = (lambda x: beta_pdf(x, d)) if use_exact else (lambda x: gaussian_approx_pdf(x, d))
    sigma = 1.0 / math.sqrt(d)

    tail = 3.5 * sigma * 3
    if use_exact:
        # Exact PDF vanishes outside [-1, 1]; do not make quad integrate across
        # the edge of the support (inaccurate for small d).
        tail = min(tail, 1.0)
    edges = [-tail, *cuts, tail]

    def weighted_sq_error(x, c):
        # Squared reconstruction error at x, weighted by the density.
        return (x - c) ** 2 * pdf(x)

    total_distortion = 0.0
    for c, a, b in zip(levels, edges, edges[1:]):
        cell_distortion, _err = integrate.quad(weighted_sq_error, a, b, args=(c,))
        total_distortion += cell_distortion

    return total_distortion
|
|
|
|
class LloydMaxCodebook:
    """Lloyd-Max codebook built once at construction for a fixed dimension and bit-width.

    Attributes:
        d: vector dimension passed to the constructor.
        bits: number of quantization bits passed to the constructor.
        n_levels: ``2 ** bits``.
        centroids, boundaries: tensors returned by ``solve_lloyd_max``.
        distortion: value returned by ``compute_expected_distortion``.
    """

    def __init__(self, d: int, bits: int, use_exact: bool = False):
        self.d = d
        self.bits = bits
        self.n_levels = 2 ** bits
        solved = solve_lloyd_max(d, bits, use_exact)
        self.centroids, self.boundaries = solved
        self.distortion = compute_expected_distortion(
            d, bits, self.centroids, self.boundaries, use_exact
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return, for every element of ``x``, the index of the closest centroid."""
        levels = self.centroids.to(x.device)
        # Trailing axis enumerates centroids, so argmin over it picks the nearest one.
        distances = torch.abs(x[..., None] - levels)
        return torch.argmin(distances, dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up the centroid value for each index in ``indices``."""
        levels = self.centroids.to(indices.device)
        return levels[indices]

    def __repr__(self):
        head = f"LloydMaxCodebook(d={self.d}, bits={self.bits}, "
        tail = f"levels={self.n_levels}, distortion_per_coord={self.distortion:.6f})"
        return head + tail