Update vllm_plugin/quartet2_quant.py
Browse files- vllm_plugin/quartet2_quant.py +28 -32
vllm_plugin/quartet2_quant.py
CHANGED
|
@@ -72,15 +72,28 @@ class QuartetIILinearMethod(LinearMethodBase):
|
|
| 72 |
layer.register_parameter("weight", weight)
|
| 73 |
|
| 74 |
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
| 75 |
-
from
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
def apply(
|
| 86 |
self,
|
|
@@ -91,9 +104,8 @@ class QuartetIILinearMethod(LinearMethodBase):
|
|
| 91 |
from quartet2.quant import quant_fp4, NVFP4QuantMode
|
| 92 |
from quartet2.linear import abs_max, _fp4_mm
|
| 93 |
|
| 94 |
-
weight = layer.weight
|
| 95 |
orig_shape = x.shape
|
| 96 |
-
out_features = weight.shape[0]
|
| 97 |
flat_x = x.reshape(-1, x.shape[-1])
|
| 98 |
|
| 99 |
num_rows = flat_x.shape[0]
|
|
@@ -104,38 +116,22 @@ class QuartetIILinearMethod(LinearMethodBase):
|
|
| 104 |
else:
|
| 105 |
pad_rows = 0
|
| 106 |
|
| 107 |
-
w_remainder = out_features % 128
|
| 108 |
-
if w_remainder != 0:
|
| 109 |
-
w_pad = 128 - w_remainder
|
| 110 |
-
weight = F.pad(weight, (0, 0, 0, w_pad))
|
| 111 |
-
else:
|
| 112 |
-
w_pad = 0
|
| 113 |
-
|
| 114 |
input_amax = abs_max(flat_x)
|
| 115 |
-
weight_amax = abs_max(weight)
|
| 116 |
-
|
| 117 |
-
mode = NVFP4QuantMode.FOUR_SIX
|
| 118 |
-
scale_override = 1.0
|
| 119 |
-
|
| 120 |
input_fp4 = quant_fp4(
|
| 121 |
flat_x, amax=input_amax,
|
| 122 |
-
scale_override=
|
| 123 |
-
)
|
| 124 |
-
weight_fp4 = quant_fp4(
|
| 125 |
-
weight, amax=weight_amax,
|
| 126 |
-
scale_override=scale_override, mode=mode,
|
| 127 |
)
|
| 128 |
|
| 129 |
-
alpha = input_fp4.tensor_scale *
|
| 130 |
output = _fp4_mm(
|
| 131 |
-
input_fp4.fp4,
|
| 132 |
-
input_fp4.micro_scales,
|
| 133 |
alpha,
|
| 134 |
)
|
| 135 |
|
| 136 |
if pad_rows > 0:
|
| 137 |
output = output[:num_rows]
|
| 138 |
-
if w_pad > 0:
|
| 139 |
output = output[:, :out_features]
|
| 140 |
|
| 141 |
output = output.reshape(*orig_shape[:-1], output.shape[-1])
|
|
|
|
| 72 |
layer.register_parameter("weight", weight)
|
| 73 |
|
| 74 |
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
| 75 |
+
from quartet2.quant import quant_fp4, NVFP4QuantMode
|
| 76 |
+
from quartet2.linear import abs_max
|
| 77 |
+
|
| 78 |
+
weight = layer.weight.data
|
| 79 |
+
device = weight.device
|
| 80 |
+
out_features = weight.shape[0]
|
| 81 |
+
|
| 82 |
+
w_remainder = out_features % 128
|
| 83 |
+
if w_remainder != 0:
|
| 84 |
+
w_pad = 128 - w_remainder
|
| 85 |
+
weight = F.pad(weight, (0, 0, 0, w_pad))
|
| 86 |
+
else:
|
| 87 |
+
w_pad = 0
|
| 88 |
+
|
| 89 |
+
mode = NVFP4QuantMode.FOUR_SIX
|
| 90 |
+
weight_amax = abs_max(weight)
|
| 91 |
+
wq = quant_fp4(weight, amax=weight_amax, scale_override=1.0, mode=mode)
|
| 92 |
+
|
| 93 |
+
layer.weight_fp4 = wq.fp4
|
| 94 |
+
layer.weight_micro_scales = wq.micro_scales
|
| 95 |
+
layer.weight_tensor_scale = wq.tensor_scale
|
| 96 |
+
layer.w_pad = w_pad
|
| 97 |
|
| 98 |
def apply(
|
| 99 |
self,
|
|
|
|
| 104 |
from quartet2.quant import quant_fp4, NVFP4QuantMode
|
| 105 |
from quartet2.linear import abs_max, _fp4_mm
|
| 106 |
|
|
|
|
| 107 |
orig_shape = x.shape
|
| 108 |
+
out_features = layer.weight.shape[0]
|
| 109 |
flat_x = x.reshape(-1, x.shape[-1])
|
| 110 |
|
| 111 |
num_rows = flat_x.shape[0]
|
|
|
|
| 116 |
else:
|
| 117 |
pad_rows = 0
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
input_amax = abs_max(flat_x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
input_fp4 = quant_fp4(
|
| 121 |
flat_x, amax=input_amax,
|
| 122 |
+
scale_override=1.0, mode=NVFP4QuantMode.FOUR_SIX,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
+
alpha = input_fp4.tensor_scale * layer.weight_tensor_scale
|
| 126 |
output = _fp4_mm(
|
| 127 |
+
input_fp4.fp4, layer.weight_fp4,
|
| 128 |
+
input_fp4.micro_scales, layer.weight_micro_scales,
|
| 129 |
alpha,
|
| 130 |
)
|
| 131 |
|
| 132 |
if pad_rows > 0:
|
| 133 |
output = output[:num_rows]
|
| 134 |
+
if layer.w_pad > 0:
|
| 135 |
output = output[:, :out_features]
|
| 136 |
|
| 137 |
output = output.reshape(*orig_shape[:-1], output.shape[-1])
|