File size: 4,136 Bytes
954e44f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a13422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
954e44f
 
 
 
 
 
 
 
 
 
 
1a13422
954e44f
 
 
 
 
 
 
 
 
 
 
 
 
1a13422
954e44f
 
1a13422
954e44f
1a13422
 
954e44f
 
 
 
 
1a13422
d0aacb0
954e44f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

import torch
import torch.nn.functional as F
from torch.nn import Parameter

from vllm.model_executor.layers.quantization import (
    register_quantization_config,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.parameter import ModelWeightParameter


@register_quantization_config("quartet2")
class QuartetIIConfig(QuantizationConfig):
    """Quantization config for the Quartet-II NVFP4 scheme.

    Registered under the name ``"quartet2"``. Activations are expected in
    bfloat16, and the kernels target Blackwell-class GPUs (SM 10.0+).
    """

    def get_name(self) -> str:
        """Canonical scheme name used to select this config."""
        return "quartet2"

    def get_supported_act_dtypes(self) -> list:
        """Only bfloat16 activations are supported."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum CUDA compute capability: Blackwell (SM 10.0)."""
        return 100

    @staticmethod
    def get_config_filenames() -> list[str]:
        """No sidecar quantization-config files are required."""
        return []

    @classmethod
    def from_config(cls, config: dict) -> "QuartetIIConfig":
        """Build an instance; the checkpoint config carries no options."""
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> QuantizeMethodBase | None:
        """Attach the Quartet-II linear method to linear layers; skip others."""
        if not isinstance(layer, LinearBase):
            return None
        return QuartetIILinearMethod(self)


class QuartetIILinearMethod(LinearMethodBase):
    """Linear-layer method for Quartet-II NVFP4 quantization.

    Weights are loaded in ``params_dtype`` (bfloat16 per the config), then
    quantized once to NVFP4 (fp4 payload + micro scales + per-tensor scale)
    in :meth:`process_weights_after_loading`. At inference time activations
    are quantized on the fly and the matmul runs on the fp4 tensors.
    """

    # The fp4 kernels operate on 128-row tiles: both the weight's output
    # dimension and the flattened activation batch are padded to this.
    TILE = 128

    def __init__(self, config: QuartetIIConfig):
        self.config = config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the full-precision weight that the loader fills in.

        The weight stays in ``params_dtype`` until quantization happens in
        :meth:`process_weights_after_loading`.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=extra_weight_attrs.get("weight_loader"),
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Quantize the loaded weight to NVFP4 and release the bf16 copy."""
        from quartet2.quant import quant_fp4, NVFP4QuantMode
        from quartet2.linear import abs_max

        weight = layer.weight.data
        out_features = weight.shape[0]

        # Pad the output dimension up to a multiple of the tile size so the
        # fp4 kernels see a well-shaped operand; the extra rows are sliced
        # off the matmul output in apply().
        w_pad = -out_features % self.TILE
        if w_pad:
            weight = F.pad(weight, (0, 0, 0, w_pad))

        mode = NVFP4QuantMode.FOUR_SIX
        weight_amax = abs_max(weight)
        wq = quant_fp4(weight, amax=weight_amax, scale_override=1.0, mode=mode)

        layer.weight_fp4 = wq.fp4
        layer.weight_micro_scales = wq.micro_scales
        layer.weight_tensor_scale = wq.tensor_scale
        layer.w_pad = w_pad
        # Record the true (unpadded) output width, then drop the bf16
        # parameter: only the fp4 tensors are needed from here on, and
        # keeping both roughly doubles weight memory.
        layer.out_features = out_features
        del layer.weight

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to NVFP4 and multiply by the quantized weight.

        Returns a tensor shaped like ``x`` with the last dimension replaced
        by the layer's output width; ``bias`` (if given) is added at the end.
        """
        from quartet2.quant import quant_fp4, NVFP4QuantMode
        from quartet2.linear import abs_max, _fp4_mm

        orig_shape = x.shape
        flat_x = x.reshape(-1, x.shape[-1])

        # Pad the flattened batch to a multiple of the tile size, mirroring
        # the weight-row padding done at load time.
        num_rows = flat_x.shape[0]
        pad_rows = -num_rows % self.TILE
        if pad_rows:
            flat_x = F.pad(flat_x, (0, 0, 0, pad_rows))

        input_amax = abs_max(flat_x)
        input_fp4 = quant_fp4(
            flat_x, amax=input_amax,
            scale_override=1.0, mode=NVFP4QuantMode.FOUR_SIX,
        )

        # The two per-tensor scales fold multiplicatively into the matmul's
        # alpha factor.
        alpha = input_fp4.tensor_scale * layer.weight_tensor_scale
        output = _fp4_mm(
            input_fp4.fp4, layer.weight_fp4,
            input_fp4.micro_scales, layer.weight_micro_scales,
            alpha,
        )

        # Strip the row padding added above and the column padding added at
        # load time.
        if pad_rows:
            output = output[:num_rows]
        if layer.w_pad:
            output = output[:, :layer.out_features]

        output = output.reshape(*orig_shape[:-1], output.shape[-1])
        if bias is not None:
            output = output + bias
        return output