| from typing import List |
| import torch |
|
|
| from ._ops import ops |
|
|
|
|
def w8_a16_gemm(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """Run the weight-int8 / activation-fp16 GEMM kernel.

    Thin wrapper that delegates directly to the compiled ``ops.w8_a16_gemm``
    extension; all shape/dtype validation happens inside the kernel.
    NOTE(review): presumably ``weight`` is int8 and ``scale`` dequantizes it
    per-channel — confirm against the kernel's C++ signature.
    """
    result = ops.w8_a16_gemm(input, weight, scale)
    return result
|
|
|
|
def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Out-parameter variant of :func:`w8_a16_gemm`.

    Delegates to the compiled ``ops.w8_a16_gemm_`` kernel, passing an
    explicit ``output`` buffer and the GEMM problem sizes ``m``/``n``/``k``.
    NOTE(review): trailing underscore suggests in-place semantics (writes
    into ``output``) per PyTorch convention — verify in the kernel source.
    """
    result = ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
    return result
|
|
|
|
def preprocess_weights(
    origin_weight: torch.Tensor,
    is_int4: bool,
) -> torch.Tensor:
    """Reformat a quantized weight tensor for the GEMM kernels.

    Pure pass-through to the compiled ``ops.preprocess_weights``; the
    ``is_int4`` flag selects between the int4 and int8 layouts.
    NOTE(review): likely performs a kernel-specific memory re-layout of the
    weights — confirm the exact transformation in the extension code.
    """
    processed = ops.preprocess_weights(origin_weight, is_int4)
    return processed
|
|
|
|
def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Quantize a weight tensor via the compiled ``ops.quant_weights`` kernel.

    Returns whatever list of tensors the extension produces; when
    ``return_unprocessed_quantized_tensor`` is set the raw (pre-layout)
    quantized tensor is presumably included as well — TODO confirm against
    the kernel's documentation.
    """
    args = (origin_weight, quant_type, return_unprocessed_quantized_tensor)
    return ops.quant_weights(*args)
|
|