| |
| |
|
|
| |
| #include "utils.h" |
| #include "gemm/gemm.h" |
| #include "quantized_utils.h" |
| #include "bnb_quantized.h" |
|
|
| /* Explicitly instantiates the kernel template name<type, blocksize, quant_type> for one parameter combination. */ |
| /* The [[host_name]] string is built by stringifying the arguments: "<name>_<type>_bs_<blocksize>_qt_<quant_type>". */ |
| /* NOTE(review): host code presumably looks kernels up by this exact string, so argument spellings matter - confirm. */ |
|
|
| #define instantiate_bnb_kernel(name, type, blocksize, quant_type) \ |
| template [[host_name( \ |
| #name "_" #type "_bs_" #blocksize "_qt_" #quant_type \ |
| )]] [[kernel]] decltype(name<type, blocksize, quant_type>) \ |
| name<type, blocksize, quant_type>; |
|
|
| /* Instantiates all four kernels (quantize, dequantize, qmv, qmm_t) for a single (type, blocksize, quant_type) combination. */ |
|
|
| #define instantiate_bnb_all_kernels(type, blocksize, quant_type) \ |
| instantiate_bnb_kernel(bnb_quantize_blockwise, type, blocksize, quant_type) \ |
| instantiate_bnb_kernel(bnb_dequantize_blockwise, type, blocksize, quant_type) \ |
| instantiate_bnb_kernel(bnb_qmv, type, blocksize, quant_type) \ |
| instantiate_bnb_kernel(bnb_qmm_t, type, blocksize, quant_type) |
|
|
| /* Instantiates every kernel for quant_type 1 and 2. NOTE(review): the meaning of 1 and 2 is defined outside this file (presumably bnb_quantized.h) - confirm. */ |
|
|
| #define instantiate_bnb_quant_types(type, blocksize) \ |
| instantiate_bnb_all_kernels(type, blocksize, 1) \ |
| instantiate_bnb_all_kernels(type, blocksize, 2) |
|
|
| /* Instantiates every kernel and quant_type for one element type at block sizes 64, 128, 256 and 512. */ |
|
|
| #define instantiate_bnb_blocksizes(type) \ |
| instantiate_bnb_quant_types(type, 64) \ |
| instantiate_bnb_quant_types(type, 128) \ |
| instantiate_bnb_quant_types(type, 256) \ |
| instantiate_bnb_quant_types(type, 512) |
|
|
| /* Emit the kernels for half, bfloat16_t and float: 3 types x 4 block sizes x 2 quant types x 4 kernels = 96 instantiations. The type names are macro-expanded before stringification, so they must not be object-like macros. */ |
|
|
| instantiate_bnb_blocksizes(half) |
| instantiate_bnb_blocksizes(bfloat16_t) |
| instantiate_bnb_blocksizes(float) |
|
|
| |
|
|