| import torch |
| import triton |
| import triton.testing |
| import sys |
| from torch.utils.benchmark import Timer |
|
|
| sys.path.append("/models/blitz/crates/blitz-kernels/src/cuda") |
| from ghost_quant import ghost_quant_fp8_kernel |
|
|
def run_rigorous_quant():
    """Benchmark the Blitz ghost-quant FP8 kernel against a torch.compile baseline.

    Runs a single correctness pass (mean absolute difference of the decoded
    fp8 values vs. PyTorch's native e4m3 cast), then times both the custom
    Triton kernel and an Inductor-compiled reference with
    ``triton.testing.do_bench`` and prints the speedup.

    Raises:
        RuntimeError: if no CUDA device is available (the kernel and the
            benchmark both require one).
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA device required for this benchmark")

    N = 1024 * 1024 * 16  # 16M elements
    BLOCK_SIZE = 1024
    # Single source of truth for the launch grid (was duplicated inline before).
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    X = torch.randn(N, device="cuda", dtype=torch.float32)
    Y_blitz = torch.empty(N, device="cuda", dtype=torch.int8)
    seed = 42  # presumably seeds stochastic rounding inside the kernel — TODO confirm

    def ref_fn(x):
        # Reference: PyTorch's fp8 e4m3 cast, reinterpreted as raw int8 bits
        # so it matches the kernel's output buffer dtype.
        return x.to(torch.float8_e4m3fn).view(torch.int8)

    # Correctness pass: compare DECODED fp8 values, not raw bit patterns.
    # (Differencing the int8 bit patterns directly mixes sign/exponent bits
    # with magnitude and produces a meaningless metric.)
    ghost_quant_fp8_kernel[grid](X, Y_blitz, seed, N, BLOCK_SIZE=BLOCK_SIZE)
    y_ref = ref_fn(X)
    diff = (
        Y_blitz.view(torch.float8_e4m3fn).float()
        - y_ref.view(torch.float8_e4m3fn).float()
    ).abs().mean()
    print(f"Correctness (Mean Diff): {diff:.6f}")

    # do_bench handles warm-up, repetition, and device synchronization.
    ms_blitz = triton.testing.do_bench(
        lambda: ghost_quant_fp8_kernel[grid](X, Y_blitz, seed, N, BLOCK_SIZE=BLOCK_SIZE)
    )

    compiled_ref = torch.compile(ref_fn, mode="max-autotune")
    compiled_ref(X)  # warm-up: trigger compilation outside the timed region
    ms_inductor = triton.testing.do_bench(lambda: compiled_ref(X))

    print("--- RIGOROUS RECEIPT: GHOST QUANT (16M Tokens) ---")
    print(f"H200 Inductor Latency: {ms_inductor:.4f} ms")
    print(f"Blitz Artisan Latency: {ms_blitz:.4f} ms")
    print(f"REAL SPEEDUP: {ms_inductor/ms_blitz:.2f}x")
|
|
# Script entry point: run the correctness check and benchmark when executed directly.
if __name__ == "__main__":
    run_rigorous_quant()
|
|