| @@ -21,7 +21,7 @@ |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Inference-only Deepseek model.""" |
| -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
| +from typing import Any, Dict, Iterable, List, Optional, Tuple |
| |
| import torch |
| from torch import nn |
| @@ -29,18 +29,19 @@ |
| |
| from vllm.attention import Attention, AttentionMetadata |
| from vllm.config import CacheConfig |
| -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, |
| +from vllm.distributed import (get_tensor_model_parallel_rank, |
| get_tensor_model_parallel_world_size, |
| tensor_model_parallel_all_reduce) |
| from vllm.model_executor.layers.activation import SiluAndMul |
| -from vllm.model_executor.layers.fused_moe import fused_moe |
| +from vllm.model_executor.layers.fused_moe import FusedMoE |
| from vllm.model_executor.layers.layernorm import RMSNorm |
| from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, |
| QKVParallelLinear, |
| ReplicatedLinear, |
| RowParallelLinear) |
| from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| -from vllm.model_executor.layers.quantization import QuantizationConfig |
| +from vllm.model_executor.layers.quantization.base_config import ( |
| + QuantizationConfig) |
| from vllm.model_executor.layers.rotary_embedding import get_rope |
| from vllm.model_executor.layers.sampler import Sampler, SamplerOutput |
| from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| @@ -49,10 +50,6 @@ |
| from vllm.model_executor.sampling_metadata import SamplingMetadata |
| from vllm.sequence import IntermediateTensors |
| |
| -from .interfaces import SupportsPP |
| -from .utils import (is_pp_missing_parameter, |
| - make_empty_intermediate_tensors_factory, make_layers) |
| - |
| |
| class DeepseekMLP(nn.Module): |
| |
| @@ -91,6 +88,7 @@ |
| def __init__( |
| self, |
| config: PretrainedConfig, |
| + layer_idx: int, |
| quant_config: Optional[QuantizationConfig] = None, |
| ): |
| super().__init__() |
| @@ -104,15 +102,17 @@ |
| f"Tensor parallel size {self.tp_size} is greater than " |
| f"the number of experts {self.n_routed_experts}.") |
| |
| - self.experts = nn.ModuleList([ |
| - DeepseekMLP(hidden_size=config.hidden_size, |
| - intermediate_size=config.moe_intermediate_size, |
| - hidden_act=config.hidden_act, |
| - quant_config=quant_config, |
| - reduce_results=False) |
| - for idx in range(self.n_routed_experts) |
| - ]) |
| - self.pack_params() |
| + self.experts = FusedMoE( |
| + num_experts=config.n_routed_experts, |
| + top_k=config.num_experts_per_tok, |
| + hidden_size=config.hidden_size, |
| + intermediate_size=config.moe_intermediate_size, |
| + reduce_results=False, |
| + renormalize=config.norm_topk_prob, |
| + quant_config=quant_config, |
| + use_grouped_topk=False, |
| + prefix=f"model.layers.{layer_idx}.mlp.experts" |
| + ) |
| |
| self.gate = ReplicatedLinear(config.hidden_size, |
| self.n_routed_experts, |
| @@ -130,25 +130,6 @@ |
| reduce_results=False, |
| ) |
| |
| - def pack_params(self): |
| - w1 = [] |
| - w2 = [] |
| - for expert in self.experts: |
| - w1.append(expert.gate_up_proj.weight) |
| - w2.append(expert.down_proj.weight) |
| - self.w1 = torch._utils._flatten_dense_tensors(w1) |
| - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) |
| - for data, param in zip(w1s, w1): |
| - param.data = data |
| - self.w1 = self.w1.view(len(w1), *w1s[0].shape) |
| - |
| - self.w2 = torch._utils._flatten_dense_tensors(w2) |
| - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) |
| - for data, param in zip(w2s, w2): |
| - param.data = data |
| - |
| - self.w2 = self.w2.view(len(w2), *w2s[0].shape) |
| - |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| num_tokens, hidden_dim = hidden_states.shape |
| hidden_states = hidden_states.view(-1, hidden_dim) |
| @@ -156,18 +137,14 @@ |
| shared_output = self.shared_experts(hidden_states) |
| # router_logits: (num_tokens, n_experts) |
| router_logits, _ = self.gate(hidden_states) |
| - final_hidden_states = fused_moe(hidden_states, |
| - self.w1, |
| - self.w2, |
| - router_logits, |
| - self.top_k, |
| - renormalize=self.config.norm_topk_prob, |
| - inplace=True) |
| + final_hidden_states = self.experts(hidden_states=hidden_states, |
| + router_logits=router_logits) |
| |
| - if self.config.n_shared_experts is not None: |
| + if shared_output is not None: |
| final_hidden_states = final_hidden_states + shared_output |
| - final_hidden_states = tensor_model_parallel_all_reduce( |
| - final_hidden_states) |
| + if self.tp_size > 1: |
| + final_hidden_states = tensor_model_parallel_all_reduce( |
| + final_hidden_states) |
| |
| return final_hidden_states.view(num_tokens, hidden_dim) |
| |
| @@ -179,6 +156,7 @@ |
| hidden_size: int, |
| num_heads: int, |
| num_kv_heads: int, |
| + head_dim: Optional[int], |
| rope_theta: float = 10000, |
| rope_scaling: Optional[Dict[str, Any]] = None, |
| max_position_embeddings: int = 8192, |
| @@ -201,7 +179,8 @@ |
| # the KV heads across multiple tensor parallel GPUs. |
| assert tp_size % self.total_num_kv_heads == 0 |
| self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) |
| - self.head_dim = hidden_size // self.total_num_heads |
| + self.head_dim = (head_dim if head_dim is not None |
| + else hidden_size // self.total_num_heads) |
| self.q_size = self.num_heads * self.head_dim |
| self.kv_size = self.num_kv_heads * self.head_dim |
| self.scaling = self.head_dim**-0.5 |
| @@ -268,10 +247,12 @@ |
| rope_scaling = getattr(config, "rope_scaling", None) |
| max_position_embeddings = getattr(config, "max_position_embeddings", |
| 8192) |
| + head_dim = getattr(config, "head_dim", None) |
| self.self_attn = DeepseekAttention( |
| hidden_size=self.hidden_size, |
| num_heads=config.num_attention_heads, |
| num_kv_heads=config.num_key_value_heads, |
| + head_dim=head_dim, |
| rope_theta=rope_theta, |
| rope_scaling=rope_scaling, |
| max_position_embeddings=max_position_embeddings, |
| @@ -281,7 +262,7 @@ |
| if (config.n_routed_experts is not None |
| and layer_idx >= config.first_k_dense_replace |
| and layer_idx % config.moe_layer_freq == 0): |
| - self.mlp = DeepseekMoE(config=config, quant_config=quant_config) |
| + self.mlp = DeepseekMoE(config=config, quant_config=quant_config, layer_idx=layer_idx) |
| else: |
| self.mlp = DeepseekMLP( |
| hidden_size=config.hidden_size, |
| @@ -332,7 +313,6 @@ |
| config: PretrainedConfig, |
| cache_config: Optional[CacheConfig] = None, |
| quant_config: Optional[QuantizationConfig] = None, |
| - prefix: str = "", |
| ) -> None: |
| super().__init__() |
| self.padding_idx = config.pad_token_id |
| @@ -342,17 +322,14 @@ |
| config.vocab_size, |
| config.hidden_size, |
| ) |
| - self.start_layer, self.end_layer, self.layers = make_layers( |
| - config.num_hidden_layers, |
| - lambda prefix: DeepseekDecoderLayer(config, |
| - int(prefix.split(".")[-1]), |
| - cache_config, |
| - quant_config=quant_config), |
| - prefix=f"{prefix}.layers") |
| + self.layers = nn.ModuleList([ |
| + DeepseekDecoderLayer(config, |
| + layer_idx, |
| + cache_config, |
| + quant_config=quant_config) |
| + for layer_idx in range(config.num_hidden_layers) |
| + ]) |
| self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| - self.make_empty_intermediate_tensors = ( |
| - make_empty_intermediate_tensors_factory( |
| - ["hidden_states", "residual"], config.hidden_size)) |
| |
| def forward( |
| self, |
| @@ -360,29 +337,19 @@ |
| positions: torch.Tensor, |
| kv_caches: List[torch.Tensor], |
| attn_metadata: AttentionMetadata, |
| - intermediate_tensors: Optional[IntermediateTensors], |
| - ) -> Union[torch.Tensor, IntermediateTensors]: |
| - if get_pp_group().is_first_rank: |
| - hidden_states = self.embed_tokens(input_ids) |
| - residual = None |
| - else: |
| - hidden_states = intermediate_tensors["hidden_states"] |
| - residual = intermediate_tensors["residual"] |
| - for i in range(self.start_layer, self.end_layer): |
| + ) -> torch.Tensor: |
| + hidden_states = self.embed_tokens(input_ids) |
| + residual = None |
| + for i in range(len(self.layers)): |
| layer = self.layers[i] |
| hidden_states, residual = layer(positions, hidden_states, |
| - kv_caches[i - self.start_layer], |
| - attn_metadata, residual) |
| - if not get_pp_group().is_last_rank: |
| - return IntermediateTensors({ |
| - "hidden_states": hidden_states, |
| - "residual": residual |
| - }) |
| + kv_caches[i], attn_metadata, |
| + residual) |
| hidden_states, _ = self.norm(hidden_states, residual) |
| return hidden_states |
| |
| |
| -class DeepseekForCausalLM(nn.Module, SupportsPP): |
| +class DeepseekForCausalLM(nn.Module): |
| |
| def __init__( |
| self, |
| @@ -401,8 +368,6 @@ |
| self.lm_head.weight = self.model.embed_tokens.weight |
| self.logits_processor = LogitsProcessor(config.vocab_size) |
| self.sampler = Sampler() |
| - self.make_empty_intermediate_tensors = ( |
| - self.model.make_empty_intermediate_tensors) |
| |
| def forward( |
| self, |
| @@ -411,9 +376,9 @@ |
| kv_caches: List[torch.Tensor], |
| attn_metadata: AttentionMetadata, |
| intermediate_tensors: Optional[IntermediateTensors] = None, |
| - ) -> Union[torch.Tensor, IntermediateTensors]: |
| + ) -> torch.Tensor: |
| hidden_states = self.model(input_ids, positions, kv_caches, |
| - attn_metadata, intermediate_tensors) |
| + attn_metadata) |
| return hidden_states |
| |
| def compute_logits( |
| @@ -443,6 +408,15 @@ |
| ("gate_up_proj", "up_proj", 1), |
| ] |
| |
| + # Params for weights, fp8 weight scales, fp8 activation scales |
| + # (param_name, weight_name, expert_id, shard_id) |
| + expert_params_mapping = FusedMoE.make_expert_params_mapping( |
| + ckpt_gate_proj_name="gate_proj", |
| + ckpt_down_proj_name="down_proj", |
| + ckpt_up_proj_name="up_proj", |
| + num_experts=self.config.n_routed_experts, |
| + ) |
| + |
| params_dict = dict(self.named_parameters()) |
| for name, loaded_weight in weights: |
| if "rotary_emb.inv_freq" in name: |
| @@ -450,31 +424,41 @@ |
| for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| if weight_name not in name: |
| continue |
| + if ("mlp.experts." in name) and name not in params_dict: |
| + continue |
| name = name.replace(weight_name, param_name) |
| # Skip loading extra bias for GPTQ models. |
| if name.endswith(".bias") and name not in params_dict: |
| continue |
| - # Skip experts that are not assigned to this worker. |
| - if (("mlp.experts." in name or "mlp.shared_experts." in name) |
| - and name not in params_dict): |
| - continue |
| - if is_pp_missing_parameter(name, self): |
| - continue |
| param = params_dict[name] |
| weight_loader = param.weight_loader |
| weight_loader(param, loaded_weight, shard_id) |
| break |
| else: |
| - # Skip loading extra bias for GPTQ models. |
| - if name.endswith(".bias") and name not in params_dict: |
| - continue |
| - # Skip experts that are not assigned to this worker. |
| - if (("mlp.experts." in name or "mlp.shared_experts." in name) |
| - and name not in params_dict): |
| - continue |
| - if is_pp_missing_parameter(name, self): |
| - continue |
| - param = params_dict[name] |
| - weight_loader = getattr(param, "weight_loader", |
| - default_weight_loader) |
| - weight_loader(param, loaded_weight) |
| + for mapping in expert_params_mapping: |
| + param_name, weight_name, expert_id, shard_id = mapping |
| + if weight_name not in name: |
| + continue |
| + name = name.replace(weight_name, param_name) |
| + param = params_dict[name] |
| + weight_loader = param.weight_loader |
| + weight_loader( |
| + param, |
| + loaded_weight, |
| + name, |
| + shard_id=shard_id, |
| + expert_id=expert_id, |
| + ) |
| + break |
| + else: |
| + # Skip loading extra bias for GPTQ models. |
| + if name.endswith(".bias") and name not in params_dict: |
| + continue |
| + # Skip experts that are not assigned to this worker. |
| + if ("mlp.experts." in name or "mlp.shared_experts." |
| + in name) and name not in params_dict: |
| + continue |
| + param = params_dict[name] |
| + weight_loader = getattr(param, "weight_loader", |
| + default_weight_loader) |
| + weight_loader(param, loaded_weight) |
| \ No newline at end of file |
| |
| |
| @@ -245,7 +245,7 @@ |
| config = self.quant_config.target_scheme_map["Linear"].get("weights") |
| self.num_bits = config.num_bits |
| self.packed_factor = 32 // config.num_bits |
| - self.strategy = config.strategy.value |
| + self.strategy = config.strategy |
| self.group_size = config.group_size |
| assert config.symmetric, ( |
| "Only symmetric quantization is supported for MoE") |