| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Simple implementation of Phi model.""" |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
|
|
| from transformers.activations import ACT2FN |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.models.phi.configuration_phi import PhiConfig |
|
|
| from diffnext.models.flash_attention import apply_rotary_emb |
|
|
|
|
def maybe_apply_ckpt(function, x, enable=False) -> torch.Tensor:
    """Call ``function(x)``, optionally under non-reentrant gradient checkpointing.

    Checkpointing is used only when ``enable`` is true AND the probe tensor requires
    grad. The probe is ``x`` itself, or ``x[0]`` when ``x`` is a tuple/list.
    Otherwise ``function`` is simply called.

    Args:
        function: Callable taking ``x`` as its single argument.
        x: A tensor, or a tuple/list whose first element is a tensor.
        enable: Whether checkpointing is allowed.

    Returns:
        Whatever ``function(x)`` returns.
    """
    if not enable:
        return function(x)
    probe = x[0] if isinstance(x, (tuple, list)) else x
    if not probe.requires_grad:
        # Nothing to recompute for backward; skip the checkpoint overhead.
        return function(x)
    return torch.utils.checkpoint.checkpoint(function, x, use_reentrant=False)
|
|
|
|
class PhiRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by a lazily grown cos/sin cache."""

    class PEFunc(object):
        """Apply RoPE weight to Q/K tensor."""

        def __init__(self, weight):
            # ``weight`` must yield exactly two tensors: (cos, sin).
            self.cos, self.sin = weight

        def __call__(self, x: torch.Tensor) -> torch.Tensor:
            # Cast to x's dtype/device and keep the cast copies, so repeated calls
            # (K after Q, subsequent layers) reuse them instead of re-casting.
            self.cos, self.sin = self.cos.to(x), self.sin.to(x)
            # NOTE(review): inplace=True presumably rotates x in place; confirm in
            # diffnext.models.flash_attention.apply_rotary_emb.
            return apply_rotary_emb(x, self.cos, self.sin, inplace=True)

    @staticmethod
    def from_config(config):
        """Build the embedding for the rotary slice of each attention head.

        Only ``partial_rotary_factor * head_dim`` channels per head are rotated.
        """
        head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = int(config.partial_rotary_factor * head_dim)
        return PhiRotaryEmbedding(rotary_dim, config.max_position_embeddings, config.rope_theta)

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim, self.base = dim, base
        self.max_position_embeddings = max_position_embeddings
        # inv_freq[i] = 1 / base ** (2i / dim) for i in [0, dim / 2).
        freq = self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        self.register_buffer("inv_freq", freq.reciprocal_(), persistent=False)
        self.set_cos_sin_cache(max_position_embeddings, dtype=torch.get_default_dtype())

    def set_cos_sin_cache(self, seqlen, dtype):
        """(Re)build the ``cos``/``sin`` buffers for positions ``[0, seqlen)``."""
        self.max_seqlen_cached, device = seqlen, self.inv_freq.device
        t = torch.arange(self.max_seqlen_cached, device=device, dtype=torch.int64)
        freq = torch.outer(t.float(), self.inv_freq.float())
        emb = torch.cat((freq, freq), dim=-1)
        self.register_buffer("cos", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin", emb.sin().to(dtype), persistent=False)

    def get_func(self, pos=0, seqlen=1) -> PEFunc:
        """Return a ``PEFunc`` holding cos/sin for positions ``[pos, pos + seqlen)``.

        The cache is grown on demand when the requested range runs past it. Before this,
        slicing a short cache silently returned fewer rows than ``seqlen``.
        """
        end = pos + seqlen
        if end > self.max_seqlen_cached:
            # Grow geometrically so token-by-token decoding beyond the cached length does
            # not rebuild the tables on every step. Keep the current cache dtype.
            self.set_cos_sin_cache(max(end, 2 * self.max_seqlen_cached), self.cos.dtype)
        # The tables are [freq, freq] concatenated; only the first half is needed here.
        return self.PEFunc(_[pos:end].chunk(2, -1)[0] for _ in (self.cos, self.sin))
|
|
|
|
class PhiMLP(nn.Module):
    """Phi feed-forward block: ``fc2(activation(fc1(x)))``."""

    def __init__(self, config: PhiConfig):
        super().__init__()
        # Checkpointing flag for this MLP; consulted by the enclosing decoder layer.
        self.gradient_checkpointing = False
        self.activation = ACT2FN[config.hidden_act]
        self.config = config
        self.hidden_size = config.hidden_size
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x) -> torch.Tensor:
        """Project up, apply the configured activation, and project back down."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
|
|
|
|
class PhiAttention(nn.Module):
    """Phi attention base: Q/K/V/output projections plus per-call state slots.

    Subclasses implement ``forward``. The caller assigns ``attn_mask``,
    ``past_key_value``, ``pe_func`` and ``flex_attn`` before each call.
    """

    def __init__(self, config: PhiConfig, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        self.layer_idx = layer_idx
        self.config = config
        self.is_causal = True
        self.gradient_checkpointing = False
        # Number of query heads sharing each K/V head (grouped-query attention).
        self.num_key_value_groups = num_heads // num_kv_heads
        self.head_dim = hidden_size // num_heads
        q_width = num_heads * self.head_dim
        kv_width = num_kv_heads * self.head_dim
        # Projection order is kept (q, k, v, dense) so parameter registration is unchanged.
        self.q_proj = nn.Linear(hidden_size, q_width)
        self.k_proj = nn.Linear(hidden_size, kv_width)
        self.v_proj = nn.Linear(hidden_size, kv_width)
        self.dense = nn.Linear(q_width, hidden_size)
        self.attn_mask = None
        self.past_key_value = None
        self.pe_func = None
        self.flex_attn = None

    def forward_qkv(self, x) -> torch.Tensor:
        """Project to Q/K/V and split the last dim into ``(heads, head_dim)``.

        ``x`` may be a tensor, or a sequence whose item 1 is a callable applied to
        item 0 first (for example a normalization).
        """
        if isinstance(x, (tuple, list)):
            x = x[1](x[0])
        projected = (self.q_proj(x), self.k_proj(x), self.v_proj(x))
        return [t.unflatten(-1, (-1, self.head_dim)) for t in projected]

    def repeat_kv(self, x) -> torch.Tensor:
        """Repeat each K/V head ``num_key_value_groups`` times along dim 1."""
        # assumes x is (batch, kv_heads, seq, head_dim) — TODO confirm with callers
        grouped = x[:, :, None].expand(-1, -1, self.num_key_value_groups, -1, -1)
        return grouped.flatten(1, 2)
|
|
|
|
class PhiSdpaAttention(PhiAttention):
    """Phi attention computed with SDPA, or with an attached flex-attention callable."""

    def forward(self, x) -> torch.Tensor:
        """Attend over ``x`` and return the output projection.

        ``x`` is a hidden-state tensor or a ``(tensor, fn)`` pair accepted by
        ``forward_qkv``. Per-call state (``pe_func``, ``attn_mask``, ``past_key_value``,
        ``flex_attn``) is read from attributes set by the caller. ``past_key_value`` is
        cleared after use.
        """
        q, k, v = maybe_apply_ckpt(self.forward_qkv, x, self.gradient_checkpointing)
        q, k = [self.pe_func(_) for _ in (q, k)]
        # Move heads before sequence, as SDPA expects.
        # assumes forward_qkv output is (batch, seq, heads, head_dim) — TODO confirm
        q, k, v = [_.transpose(1, 2) for _ in (q, k, v)]
        if self.past_key_value is not None and getattr(self.past_key_value, "is_frozen", False):
            # Frozen cache: prepend cached K/V along the sequence axis without writing back.
            k, v = [torch.cat(_, -2) for _ in zip(self.past_key_value[self.layer_idx], (k, v))]
        elif self.past_key_value is not None:
            k, v = self.past_key_value.update(k, v, self.layer_idx)
        self.past_key_value = None
        if self.flex_attn and self.flex_attn.offsets:
            return self.dense(self.flex_attn(q, k, v).transpose(1, 2).flatten(2))
        # Take the query length from q rather than x.size(1): x may be a (tensor, fn) pair.
        attn_mask, q_len, kv_len = self.attn_mask, q.size(-2), k.size(-2)
        is_causal = self.is_causal and attn_mask is None and q_len > 1
        if is_causal and kv_len != q_len:
            # With cached keys, SDPA's is_causal would align the mask top-left and hide the
            # newest keys from the new queries. Use a bottom-right aligned mask instead:
            # query i sees all cached keys plus new keys up to its own position.
            attn_mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device)
            attn_mask, is_causal = attn_mask.tril(kv_len - q_len), False
        sdpa_args = {"is_causal": is_causal, "enable_gqa": True}
        o = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, **sdpa_args)
        return self.dense(o.transpose(1, 2).flatten(2))
|
|
|
|
class PhiDecoderLayer(nn.Module):
    """Phi decoder layer: attention and MLP run in parallel on one shared pre-norm.

    Output is ``attn(norm(x)) + mlp(norm(x)) + x``.
    """

    def __init__(self, config: PhiConfig, layer_idx: int):
        super().__init__()
        self.self_attn = PhiSdpaAttention(config, layer_idx)
        self.mlp = PhiMLP(config)
        self.gradient_checkpointing = False
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # NOTE(review): never applied in forward(); presumably meant as residual dropout
        # (config.resid_pdrop). Confirm that omitting it during training is intended.
        self.dropout = nn.Dropout(config.resid_pdrop, inplace=True)
        # NOTE(review): unused; forward() consults self.mlp.gradient_checkpointing instead.
        self.mlp_checkpointing = False

    def forward(self, x) -> torch.Tensor:
        """Apply one parallel attention + MLP residual block to ``x``."""
        residual = x
        normed = self.input_layernorm(x)
        # Attention runs first, then the (optionally checkpointed) MLP, on the same input.
        attn_out = self.self_attn(normed)
        mlp_out = maybe_apply_ckpt(self.mlp, normed, self.mlp.gradient_checkpointing)
        # Accumulate in place into the attention output buffer.
        return attn_out.add_(mlp_out).add_(residual)
|
|
|
|
class PhiPreTrainedModel(PreTrainedModel):
    """Base class hooking Phi models into Hugging Face ``PreTrainedModel`` machinery."""

    config_class = PhiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PhiDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    # NOTE(review): only an SDPA attention class is visible here; confirm flash-attn-2 support.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize weights of Linear/Embedding modules.

        Weights are drawn from N(0, initializer_range**2). Linear biases and the Embedding
        padding row are zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
        elif isinstance(module, nn.Embedding) and module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
|
|
|
|
class PhiModel(PhiPreTrainedModel):
    """Phi decoder stack: token embedding, decoder layers, and final LayerNorm."""

    def __init__(self, config: PhiConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        rotary_emb = PhiRotaryEmbedding.from_config(config)
        # post_init() runs before the rotary module is attached (original ordering).
        self.post_init()
        self.rotary_emb = rotary_emb

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        past_key_values: torch.Tensor = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        ``input_ids`` takes precedence over ``inputs_embeds``. The RoPE start position is
        ``kwargs["rope_pos"]`` when given, otherwise the cache length (0 without a cache).
        ``attention_mask`` is forwarded as-is to every attention layer.
        """
        if input_ids is None:
            x = inputs_embeds
        else:
            x = self.embed_tokens(input_ids)
        cached_len = past_key_values.get_seq_length() if past_key_values else 0
        pe_pos = kwargs.get("rope_pos", cached_len)
        if isinstance(pe_pos, torch.Tensor):
            # NOTE(review): flex_rope is not created in this class; presumably attached
            # externally before tensor-valued rope_pos is used.
            pe_embedder = self.flex_rope
        else:
            pe_embedder = self.rotary_emb
        pe_func = pe_embedder.get_func(pe_pos, x.size(1))
        for layer in self.layers:
            attn = layer.self_attn
            # Per-call state is handed to each attention module through attributes.
            attn.pe_func = pe_func
            attn.attn_mask = attention_mask
            attn.past_key_value = past_key_values
            x = maybe_apply_ckpt(layer.__call__, x, layer.gradient_checkpointing)
        hidden = self.final_layernorm(x)
        return BaseModelOutputWithPast(last_hidden_state=hidden, past_key_values=past_key_values)
|
|
|
|
class PhiEncoderModel(PhiPreTrainedModel):
    """Phi backbone used as an encoder: returns the backbone's hidden-state output."""

    def __init__(self, config):
        super().__init__(config)
        self.model = PhiModel(config)
        # NOTE(review): lm_head is created (and so appears in the state dict) but
        # forward() never uses it.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()
        self.vocab_size = config.vocab_size

    def forward(self, input_ids, attention_mask=None, **kwargs) -> BaseModelOutputWithPast:
        """Delegate to the wrapped ``PhiModel`` and return its output unchanged."""
        backbone = self.model
        return backbone(input_ids, attention_mask, **kwargs)
|
|
|
|
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
    """Phi causal language model: ``PhiModel`` backbone plus a linear vocabulary head."""

    def __init__(self, config):
        super().__init__(config)
        self.model = PhiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # post_init() runs before lm_shift is assigned (original ordering).
        self.post_init()
        # First lm_head row scored by forward(); 0 scores the whole vocabulary.
        self.lm_shift = 0

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self) -> PhiModel:
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        logits_to_keep=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone and score hidden states against ``lm_head``.

        Args:
            logits_to_keep: ``None`` keeps every position. An int N keeps the last N
                positions (0 keeps all, since ``slice(-0, None)`` is the full range). Any
                other value is used directly as an index on the sequence dimension.

        Rows of ``lm_head`` before ``lm_shift`` are excluded from the logits. The head's
        bias, when present, is sliced the same way and applied. It was previously dropped.
        """
        outputs = self.model(input_ids, attention_mask, inputs_embeds, **kwargs)
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        hidden = outputs[0] if keep is None else outputs[0][:, keep]
        head_w, head_b = self.lm_head.weight, self.lm_head.bias
        if self.lm_shift:
            head_w = head_w[self.lm_shift :]
            head_b = None if head_b is None else head_b[self.lm_shift :]
        logits = nn.functional.linear(hidden, head_w, head_b)
        return CausalLMOutputWithPast(logits=logits, past_key_values=outputs.past_key_values)

    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
        """Build model inputs for one generation step.

        ``attention_mask`` is discarded. Only tokens not yet in the cache are fed.
        When generating from ``inputs_embeds``:
        - The first (uncached) step feeds the embeddings with ``input_ids=None``, because
          ``PhiModel`` prefers ``input_ids`` whenever both are given.
        - Later steps feed the newest token(s). The embedded prefix is not part of
          ``input_ids``, so slicing by cache length would leave nothing.
        """
        kwargs.pop("attention_mask", None)
        past_key_values = kwargs.get("past_key_values", None)
        past_pos = past_key_values.get_seq_length() if past_key_values else 0
        if inputs_embeds is not None and not past_pos:
            return {"input_ids": None, **kwargs, "inputs_embeds": inputs_embeds}
        if past_pos:
            new_ids = input_ids[:, past_pos:]
            if inputs_embeds is not None and new_ids.size(1) == 0:
                cache_position = kwargs.get("cache_position", None)
                num_new = cache_position.numel() if cache_position is not None else 1
                new_ids = input_ids[:, -num_new:]
            input_ids = new_ids
        return {"input_ids": input_ids, **kwargs}
|
|