import torch import torch.nn.functional as F from torch.nn import Parameter from vllm.model_executor.layers.quantization import ( register_quantization_config, ) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.parameter import ModelWeightParameter @register_quantization_config("quartet2") class QuartetIIConfig(QuantizationConfig): def get_name(self) -> str: return "quartet2" def get_supported_act_dtypes(self) -> list: return [torch.bfloat16] @classmethod def get_min_capability(cls) -> int: return 100 # Blackwell (SM 10.0) @staticmethod def get_config_filenames() -> list[str]: return [] @classmethod def from_config(cls, config: dict) -> "QuartetIIConfig": return cls() def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> QuantizeMethodBase | None: if isinstance(layer, LinearBase): return QuartetIILinearMethod(self) return None class QuartetIILinearMethod(LinearMethodBase): def __init__(self, config: QuartetIIConfig): self.config = config def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=params_dtype, ), input_dim=1, output_dim=0, weight_loader=extra_weight_attrs.get("weight_loader"), ) layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from quartet2.quant import quant_fp4, NVFP4QuantMode from quartet2.linear import abs_max weight = layer.weight.data device = weight.device out_features = weight.shape[0] w_remainder = out_features % 128 if w_remainder != 0: w_pad = 128 - w_remainder weight = F.pad(weight, (0, 0, 0, w_pad)) else: w_pad = 0 mode = NVFP4QuantMode.FOUR_SIX weight_amax = abs_max(weight) wq = quant_fp4(weight, amax=weight_amax, scale_override=1.0, mode=mode) layer.weight_fp4 = wq.fp4 layer.weight_micro_scales = wq.micro_scales layer.weight_tensor_scale = wq.tensor_scale layer.w_pad = w_pad def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: from quartet2.quant import quant_fp4, NVFP4QuantMode from quartet2.linear import abs_max, _fp4_mm orig_shape = x.shape out_features = layer.weight.shape[0] flat_x = x.reshape(-1, x.shape[-1]) num_rows = flat_x.shape[0] remainder = num_rows % 128 if remainder != 0: pad_rows = 128 - remainder flat_x = F.pad(flat_x, (0, 0, 0, pad_rows)) else: pad_rows = 0 input_amax = abs_max(flat_x) input_fp4 = quant_fp4( flat_x, amax=input_amax, scale_override=1.0, mode=NVFP4QuantMode.FOUR_SIX, ) alpha = input_fp4.tensor_scale * layer.weight_tensor_scale output = _fp4_mm( input_fp4.fp4, layer.weight_fp4, input_fp4.micro_scales, layer.weight_micro_scales, alpha, ) if pad_rows > 0: output = output[:num_rows] if layer.w_pad > 0: output = output[:, :out_features] output = output.reshape(*orig_shape[:-1], output.shape[-1]) if bias is not None: output = output + bias return output