| #include <torch/all.h> |
| #include <vector> |
|
|
// Preprocessor constant that expands to the integer literal 4. None of the
// declarations visible in this part of the file reference it.
// NOTE(review): the name suggests a row-count (M) threshold below which a
// small-batch fast path is taken -- confirm at the use site.
| #define SMALL_M_FAST_PATH 4 |
// Declaration only; the definition is not in this part of the file.
// Takes a weight tensor by const reference, a scalar type (`quant_type`), and
// a boolean flag, and returns a vector of tensors.
// NOTE(review): going by the names, this presumably quantizes `weight`
// symmetrically along its last axis into `quant_type`, and
// `return_unprocessed_quantized_tensor` presumably requests that the raw
// (not yet preprocessed) quantized tensor be included in the result. The
// number and order of the returned tensors are not visible here -- confirm
// against the definition.
| std::vector<torch::Tensor> |
| symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight, |
| at::ScalarType quant_type, |
| bool return_unprocessed_quantized_tensor); |
|
|
// Declaration only; the definition is not in this part of the file.
// Takes the original weight tensor by const reference plus a boolean
// `is_int4`, and returns a tensor.
// NOTE(review): presumably rearranges already-quantized weights into the
// layout the GEMM kernels expect, with `is_int4` selecting 4-bit instead of
// 8-bit weights -- expected dtype, shape and device of `ori_weight` are not
// visible here; confirm against the definition.
| torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight, |
| bool is_int4); |
|
|
// Declaration only; the definition is not in this part of the file.
// Takes `input`, `weight` and `scale` tensors by const reference and returns
// a tensor.
// NOTE(review): the name suggests a GEMM with 8-bit weights and 16-bit
// activations, with `scale` holding the weight dequantization scales. Tensor
// shapes, dtypes and device requirements are not visible here -- confirm
// against the definition.
| torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input, |
| torch::Tensor const &weight, |
| torch::Tensor const &scale); |
|
|
// Declaration only; the definition is not in this part of the file.
// Same leading parameters as w8_a16_gemm_forward_cuda (`input`, `weight`,
// `scale` by const reference), plus a non-const `output` tensor reference and
// three explicit 64-bit dimensions `m`, `n`, `k`. Also returns a tensor.
// NOTE(review): presumably writes the result into the caller-provided
// `output` and returns it, with m/n/k being the GEMM problem sizes -- which
// operand each dimension describes is not visible here; confirm against the
// definition.
// The top-level `const` on the by-value m/n/k parameters has no effect in a
// declaration; it only matters inside the definition.
| torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input, |
| torch::Tensor const &weight, |
| torch::Tensor const &scale, |
| torch::Tensor &output, |
| const int64_t m, |
| const int64_t n, |
| const int64_t k); |
|
|