# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Simple implementation of Phi model."""
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.phi.configuration_phi import PhiConfig
from diffnext.models.flash_attention import apply_rotary_emb
def maybe_apply_ckpt(function, x, enable=False) -> torch.Tensor:
    """Call ``function(x)``, wrapping it in gradient checkpointing when useful.

    Checkpointing is used only if ``enable`` is truthy and the probe tensor
    requires grad. The probe is ``x`` itself, or ``x[0]`` when ``x`` is a
    tuple or list.

    Args:
        function: Callable taking ``x`` as its single argument.
        x: A tensor, or a tuple/list whose first item is a tensor.
        enable: Whether checkpointing is requested.

    Returns:
        Whatever ``function(x)`` returns.
    """
    if not enable:
        return function(x)
    probe = x[0] if isinstance(x, (tuple, list)) else x
    if not probe.requires_grad:
        # Nothing to recompute for backward; call directly.
        return function(x)
    return torch.utils.checkpoint.checkpoint(function, x, use_reentrant=False)
class PhiRotaryEmbedding(nn.Module):
    """Rotary embedding layer.

    Precomputes ``cos``/``sin`` tables for positions ``[0, max_seqlen_cached)``
    and hands out :class:`PEFunc` objects bound to a window of those tables.
    """

    class PEFunc(object):
        """Apply RoPE weight to Q/K tensor."""

        def __init__(self, weight):
            # ``weight`` is any iterable yielding exactly two items: (cos, sin).
            self.cos, self.sin = weight

        def __call__(self, x: torch.Tensor) -> torch.Tensor:
            # Convert the tables to the dtype/device of ``x`` and keep the
            # converted copies so later calls on similar inputs reuse them.
            self.cos, self.sin = self.cos.to(x), self.sin.to(x)
            return apply_rotary_emb(x, self.cos, self.sin, inplace=True)

    @staticmethod
    def from_config(config):
        """Build the layer from a config.

        Reads ``hidden_size``, ``num_attention_heads``, ``partial_rotary_factor``,
        ``max_position_embeddings`` and ``rope_theta`` from ``config``.
        """
        head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = int(config.partial_rotary_factor * head_dim)
        return PhiRotaryEmbedding(rotary_dim, config.max_position_embeddings, config.rope_theta)

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        """Create the inverse-frequency buffer and the initial cos/sin cache.

        Args:
            dim: Rotary dimension; the tables have ``dim`` columns.
            max_position_embeddings: Number of positions cached up front.
            base: Base of the geometric frequency progression.
        """
        super().__init__()
        self.dim, self.base = dim, base
        self.max_position_embeddings = max_position_embeddings
        freq = self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)
        self.register_buffer("inv_freq", freq.reciprocal_(), persistent=False)
        self.set_cos_sin_cache(max_position_embeddings, dtype=torch.get_default_dtype())

    def set_cos_sin_cache(self, seqlen, dtype):
        """(Re)build the ``cos``/``sin`` buffers for positions ``[0, seqlen)``.

        The tables are computed in float32 on the device of ``inv_freq`` and
        then cast to ``dtype``. Both buffers are non-persistent.
        """
        self.max_seqlen_cached, device = seqlen, self.inv_freq.device
        t = torch.arange(self.max_seqlen_cached, device=device, dtype=torch.int64)
        freq = torch.outer(t.float(), self.inv_freq.float())
        # Frequencies are duplicated along the last axis.
        emb = torch.cat((freq, freq), dim=-1)
        self.register_buffer("cos", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin", emb.sin().to(dtype), persistent=False)

    def get_func(self, pos=0, seqlen=1) -> PEFunc:
        """Return a :class:`PEFunc` for positions ``[pos, pos + seqlen)``.

        If the requested window runs past the cached tables, the cache is
        rebuilt to cover it first. Slicing alone would silently return fewer
        than ``seqlen`` rows.

        Args:
            pos: First position of the window.
            seqlen: Number of positions in the window.
        """
        end = pos + seqlen
        if end > self.max_seqlen_cached:
            self.set_cos_sin_cache(end, dtype=self.cos.dtype)
        # Only the first half of the duplicated columns is handed out.
        return self.PEFunc(_[pos:end].chunk(2, -1)[0] for _ in (self.cos, self.sin))
class PhiMLP(nn.Module):
    """Phi feed-forward block: ``fc2(activation(fc1(x)))``."""

    def __init__(self, config: PhiConfig):
        """Build both projections from ``config``.

        Reads ``hidden_act``, ``hidden_size`` and ``intermediate_size``.
        """
        super().__init__()
        # Read by the decoder layer to decide whether to checkpoint this block.
        self.gradient_checkpointing = False
        self.activation = ACT2FN[config.hidden_act]
        self.config = config
        self.hidden_size = config.hidden_size
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x) -> torch.Tensor:
        """Expand, activate, then project back to the hidden size."""
        expanded = self.fc1(x)
        activated = self.activation(expanded)
        return self.fc2(activated)
class PhiAttention(nn.Module):
    """Base Phi attention: owns the Q/K/V/output projections and shared helpers."""

    def __init__(self, config: PhiConfig, layer_idx=None):
        """Create the projections and reset per-call state.

        Reads ``hidden_size``, ``num_attention_heads`` and
        ``num_key_value_heads`` from ``config``.
        """
        super().__init__()
        num_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        self.layer_idx = layer_idx
        self.config = config
        self.is_causal = True
        self.gradient_checkpointing = False
        self.num_key_value_groups = num_heads // num_kv_heads
        self.head_dim = config.hidden_size // num_heads
        self.q_proj = nn.Linear(config.hidden_size, num_heads * self.head_dim)
        self.k_proj = nn.Linear(config.hidden_size, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(config.hidden_size, num_kv_heads * self.head_dim)
        self.dense = nn.Linear(num_heads * self.head_dim, config.hidden_size)
        # Per-call state, all unset here.
        self.attn_mask = None
        self.past_key_value = None
        self.pe_func = None
        self.flex_attn = None

    def forward_qkv(self, x) -> torch.Tensor:
        """Project ``x`` to Q, K, V and split the last axis into heads.

        ``x`` may be a tensor, or a tuple/list ``(tensor, norm, ...)``. In the
        latter case ``norm(tensor)`` is computed first (pre-norm).

        Returns:
            A list ``[q, k, v]`` whose last axis is unflattened to
            ``(-1, head_dim)``.
        """
        if isinstance(x, (tuple, list)):  # PreNorm.
            x = x[1](x[0])
        head_shape = (-1, self.head_dim)
        projections = (self.q_proj, self.k_proj, self.v_proj)
        return [proj(x).unflatten(-1, head_shape) for proj in projections]

    def repeat_kv(self, x) -> torch.Tensor:
        """Repeat dim 1 of a 4-D tensor ``num_key_value_groups`` times.

        Presumably ``x`` is laid out as (batch, kv_heads, seq, head_dim).
        TODO confirm against callers.
        """
        expanded = x.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1)
        return expanded.flatten(1, 2)
class PhiSdpaAttention(PhiAttention):
    """Phi SDPA attention."""

    def forward(self, x) -> torch.Tensor:
        """Run attention over ``x`` using the state set on this module.

        Reads ``pe_func``, ``attn_mask``, ``past_key_value`` and ``flex_attn``
        from the instance. ``past_key_value`` is cleared before returning.

        Args:
            x: A tensor, or any tuple/list form accepted by ``forward_qkv``.

        Returns:
            The attention output after the ``dense`` projection.
        """
        q, k, v = maybe_apply_ckpt(self.forward_qkv, x, self.gradient_checkpointing)
        q, k = [self.pe_func(_) for _ in (q, k)]
        # Swap dims 1 and 2 so the head axis precedes the sequence axis.
        q, k, v = [_.transpose(1, 2) for _ in (q, k, v)]
        cache = self.past_key_value
        if cache is not None and getattr(cache, "is_frozen", False):
            # Frozen cache: prepend the cached K/V along the sequence axis
            # without writing the new entries back.
            k, v = [torch.cat(_, -2) for _ in zip(cache[self.layer_idx], (k, v))]
        elif cache is not None:  # Fallback to legacy NTP caching.
            k, v = cache.update(k, v, self.layer_idx)
        self.past_key_value = None  # Release cache reference.
        if self.flex_attn and self.flex_attn.offsets:
            return self.dense(self.flex_attn(q, k, v).transpose(1, 2).flatten(2))
        # Take the query length from ``q`` rather than ``x``: ``x`` may be a
        # (tensor, norm) tuple, which has no ``.size``.
        # NOTE(review): with cached keys and more than one new query,
        # ``is_causal=True`` depends on how SDPA aligns the causal mask for
        # non-square shapes. Confirm against the PyTorch documentation.
        is_causal = self.is_causal and self.attn_mask is None and q.size(-2) > 1
        sdpa_args = {"is_causal": is_causal, "enable_gqa": True}
        o = nn.functional.scaled_dot_product_attention(q, k, v, self.attn_mask, **sdpa_args)
        return self.dense(o.transpose(1, 2).flatten(2))
class PhiDecoderLayer(nn.Module):
    """Phi decoder layer: attention and MLP share one normalized input."""

    def __init__(self, config: PhiConfig, layer_idx: int):
        """Build the attention, MLP, layer norm and dropout sub-modules.

        Reads ``hidden_size``, ``layer_norm_eps`` and ``resid_pdrop`` from
        ``config``.
        """
        super().__init__()
        self.self_attn = PhiSdpaAttention(config, layer_idx)
        self.mlp = PhiMLP(config)
        self.gradient_checkpointing = False
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Constructed here but not applied in ``forward``.
        self.dropout = nn.Dropout(config.resid_pdrop, inplace=True)
        self.mlp_checkpointing = False

    def forward(self, x) -> torch.Tensor:
        """Return ``attn(norm(x)) + mlp(norm(x)) + x``.

        Both additions are in place on the attention output.
        """
        residual = x
        normed = self.input_layernorm(x)
        attn_out = self.self_attn(normed)
        mlp_out = maybe_apply_ckpt(self.mlp, normed, self.mlp.gradient_checkpointing)
        return attn_out.add_(mlp_out).add_(residual)
class PhiPreTrainedModel(PreTrainedModel):
    """Common base for the Phi models: class-level flags and weight init."""

    config_class = PhiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PhiDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize ``nn.Linear`` and ``nn.Embedding`` modules.

        Weights are drawn from ``N(0, initializer_range)``. Linear biases and
        the embedding row at ``padding_idx`` are zeroed. Other module types
        are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            pad = module.padding_idx
            if pad is not None:
                module.weight.data[pad].zero_()
class PhiModel(PhiPreTrainedModel):
    """Phi transformer model."""

    def __init__(self, config: PhiConfig):
        """Build the embeddings, decoder layers, final norm and rotary embedding.

        Reads ``pad_token_id``, ``vocab_size``, ``hidden_size``,
        ``num_hidden_layers`` and ``layer_norm_eps`` from ``config``.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [PhiDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Same order as the original tuple assignment: build the rotary module,
        # run ``post_init``, and only then attach the module.
        rotary_emb = PhiRotaryEmbedding.from_config(config)
        self.post_init()
        self.rotary_emb = rotary_emb

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        past_key_values: torch.Tensor = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Args:
            input_ids: Token ids. When given, they take precedence over
                ``inputs_embeds``.
            attention_mask: Mask handed unchanged to every attention layer.
            inputs_embeds: Precomputed embeddings, used only when
                ``input_ids`` is None.
            past_key_values: Optional cache object handed to every attention
                layer. Its ``get_seq_length()`` gives the default RoPE offset.
            **kwargs: Only ``rope_pos`` is read. It overrides the RoPE offset,
                and a tensor value selects ``self.flex_rope`` over
                ``self.rotary_emb``.

        Returns:
            ``BaseModelOutputWithPast`` with the final-normed hidden states
            and the same ``past_key_values`` object that was passed in.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify either input_ids or inputs_embeds.")
        x = inputs_embeds if input_ids is None else self.embed_tokens(input_ids)
        pe_pos = kwargs.get("rope_pos", past_key_values.get_seq_length() if past_key_values else 0)
        # NOTE(review): ``flex_rope`` is not created in ``__init__``. It is
        # presumably attached by external code before a tensor ``rope_pos``
        # is used. Verify against callers.
        pe_embedder = self.flex_rope if isinstance(pe_pos, torch.Tensor) else self.rotary_emb
        pe_func = pe_embedder.get_func(pe_pos, x.size(1))
        for layer in self.layers:
            # Per-call state is passed through attributes on the attention
            # module rather than through call arguments.
            layer.self_attn.pe_func = pe_func
            layer.self_attn.attn_mask = attention_mask
            layer.self_attn.past_key_value = past_key_values
            x = maybe_apply_ckpt(layer.__call__, x, layer.gradient_checkpointing)
        x = self.final_layernorm(x)
        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values)
class PhiEncoderModel(PhiPreTrainedModel):
    """Phi encoder model: returns the backbone output; ``lm_head`` is unused in ``forward``."""

    def __init__(self, config):
        """Build the backbone and an ``lm_head`` linear layer."""
        super().__init__(config)
        self.model = PhiModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        vocab_size = config.vocab_size
        self.post_init()
        self.vocab_size = vocab_size

    def forward(self, input_ids, attention_mask=None, **kwargs) -> BaseModelOutputWithPast:
        """Delegate to the backbone and return its output unchanged."""
        backbone_output = self.model(input_ids, attention_mask, **kwargs)
        return backbone_output
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
    """Phi causal language model."""

    def __init__(self, config):
        """Build the backbone and the LM head.

        ``lm_shift`` starts at 0. A non-zero value makes ``forward`` skip the
        first ``lm_shift`` rows of the head weight.
        """
        super().__init__(config)
        self.model = PhiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        lm_shift = 0
        self.post_init()
        self.lm_shift = lm_shift

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the backbone token embedding."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone token embedding."""
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone."""
        self.model = decoder

    def get_decoder(self) -> PhiModel:
        """Return the backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        logits_to_keep=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone and project hidden states to logits.

        Args:
            input_ids: Token ids, forwarded to the backbone.
            attention_mask: Forwarded to the backbone.
            inputs_embeds: Forwarded to the backbone.
            logits_to_keep: None keeps every position. An int ``k`` keeps the
                slice ``[-k:]`` along dim 1 (``0`` therefore keeps
                everything). Any other value is used directly as the dim-1
                index.
            **kwargs: Forwarded to the backbone.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and the backbone's
            ``past_key_values``.
        """
        outputs = self.model(input_ids, attention_mask, inputs_embeds, **kwargs)
        keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        hidden = outputs[0] if keep is None else outputs[0][:, keep]
        # With a non-zero ``lm_shift`` the first rows of the head weight are
        # skipped, so logit column ``i`` comes from weight row ``i + lm_shift``.
        head_w = self.lm_head.weight[self.lm_shift :] if self.lm_shift else self.lm_head.weight
        # NOTE(review): only the weight is used; ``lm_head.bias`` is not added
        # to the logits. Confirm this is intended.
        logits = nn.functional.linear(hidden, head_w)
        return CausalLMOutputWithPast(logits=logits, past_key_values=outputs.past_key_values)

    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        ``attention_mask`` is dropped from ``kwargs``. ``input_ids`` is
        trimmed to the tokens not yet covered by the cache. While the cache is
        empty and ``inputs_embeds`` is given, the embeddings are forwarded and
        ``input_ids`` is set to None so that the backbone, which prefers
        ``input_ids`` when both are present, actually consumes them.

        Returns:
            A dict of keyword arguments for ``forward``.
        """
        past_key_values = kwargs.get("past_key_values", None)
        kwargs.pop("attention_mask", None)
        past_pos = past_key_values.get_seq_length() if past_key_values else 0
        inputs = {"input_ids": input_ids[:, past_pos:] if past_pos else input_ids, **kwargs}
        if inputs_embeds is not None and not past_pos:
            inputs["inputs_embeds"] = inputs_embeds
            # Without this, the backbone would embed ``input_ids`` and ignore
            # the provided embeddings.
            inputs["input_ids"] = None
        return inputs
|