mansaripo committed on
Commit
1a13422
·
verified ·
1 Parent(s): d0aacb0

Update vllm_plugin/quartet2_quant.py

Browse files
Files changed (1) hide show
  1. vllm_plugin/quartet2_quant.py +28 -32
vllm_plugin/quartet2_quant.py CHANGED
@@ -72,15 +72,28 @@ class QuartetIILinearMethod(LinearMethodBase):
72
  layer.register_parameter("weight", weight)
73
 
74
  def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
75
- from scipy.linalg import hadamard as scipy_hadamard
76
- device = layer.weight.device
77
- had_np = scipy_hadamard(128) * 128 ** -0.5
78
- layer.had = torch.tensor(
79
- had_np, dtype=torch.bfloat16, device=device, requires_grad=False,
80
- )
81
- layer.scratch_amax = torch.empty(
82
- (), dtype=torch.uint32, device=device,
83
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def apply(
86
  self,
@@ -91,9 +104,8 @@ class QuartetIILinearMethod(LinearMethodBase):
91
  from quartet2.quant import quant_fp4, NVFP4QuantMode
92
  from quartet2.linear import abs_max, _fp4_mm
93
 
94
- weight = layer.weight
95
  orig_shape = x.shape
96
- out_features = weight.shape[0]
97
  flat_x = x.reshape(-1, x.shape[-1])
98
 
99
  num_rows = flat_x.shape[0]
@@ -104,38 +116,22 @@ class QuartetIILinearMethod(LinearMethodBase):
104
  else:
105
  pad_rows = 0
106
 
107
- w_remainder = out_features % 128
108
- if w_remainder != 0:
109
- w_pad = 128 - w_remainder
110
- weight = F.pad(weight, (0, 0, 0, w_pad))
111
- else:
112
- w_pad = 0
113
-
114
  input_amax = abs_max(flat_x)
115
- weight_amax = abs_max(weight)
116
-
117
- mode = NVFP4QuantMode.FOUR_SIX
118
- scale_override = 1.0
119
-
120
  input_fp4 = quant_fp4(
121
  flat_x, amax=input_amax,
122
- scale_override=scale_override, mode=mode,
123
- )
124
- weight_fp4 = quant_fp4(
125
- weight, amax=weight_amax,
126
- scale_override=scale_override, mode=mode,
127
  )
128
 
129
- alpha = input_fp4.tensor_scale * weight_fp4.tensor_scale
130
  output = _fp4_mm(
131
- input_fp4.fp4, weight_fp4.fp4,
132
- input_fp4.micro_scales, weight_fp4.micro_scales,
133
  alpha,
134
  )
135
 
136
  if pad_rows > 0:
137
  output = output[:num_rows]
138
- if w_pad > 0:
139
  output = output[:, :out_features]
140
 
141
  output = output.reshape(*orig_shape[:-1], output.shape[-1])
 
72
  layer.register_parameter("weight", weight)
73
 
74
  def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
75
+ from quartet2.quant import quant_fp4, NVFP4QuantMode
76
+ from quartet2.linear import abs_max
77
+
78
+ weight = layer.weight.data
79
+ device = weight.device
80
+ out_features = weight.shape[0]
81
+
82
+ w_remainder = out_features % 128
83
+ if w_remainder != 0:
84
+ w_pad = 128 - w_remainder
85
+ weight = F.pad(weight, (0, 0, 0, w_pad))
86
+ else:
87
+ w_pad = 0
88
+
89
+ mode = NVFP4QuantMode.FOUR_SIX
90
+ weight_amax = abs_max(weight)
91
+ wq = quant_fp4(weight, amax=weight_amax, scale_override=1.0, mode=mode)
92
+
93
+ layer.weight_fp4 = wq.fp4
94
+ layer.weight_micro_scales = wq.micro_scales
95
+ layer.weight_tensor_scale = wq.tensor_scale
96
+ layer.w_pad = w_pad
97
 
98
  def apply(
99
  self,
 
104
  from quartet2.quant import quant_fp4, NVFP4QuantMode
105
  from quartet2.linear import abs_max, _fp4_mm
106
 
 
107
  orig_shape = x.shape
108
+ out_features = layer.weight.shape[0]
109
  flat_x = x.reshape(-1, x.shape[-1])
110
 
111
  num_rows = flat_x.shape[0]
 
116
  else:
117
  pad_rows = 0
118
 
 
 
 
 
 
 
 
119
  input_amax = abs_max(flat_x)
 
 
 
 
 
120
  input_fp4 = quant_fp4(
121
  flat_x, amax=input_amax,
122
+ scale_override=1.0, mode=NVFP4QuantMode.FOUR_SIX,
 
 
 
 
123
  )
124
 
125
+ alpha = input_fp4.tensor_scale * layer.weight_tensor_scale
126
  output = _fp4_mm(
127
+ input_fp4.fp4, layer.weight_fp4,
128
+ input_fp4.micro_scales, layer.weight_micro_scales,
129
  alpha,
130
  )
131
 
132
  if pad_rows > 0:
133
  output = output[:num_rows]
134
+ if layer.w_pad > 0:
135
  output = output[:, :out_features]
136
 
137
  output = output.reshape(*orig_shape[:-1], output.shape[-1])