| import torch
|
| import torch.nn as nn
|
| from typing import TYPE_CHECKING
|
| from torch.nn import functional as F
|
|
|
| from .layers import layer_norm, mlp
|
| from .config import TextConfig
|
|
|
|
|
| if TYPE_CHECKING:
|
| from .rope import RotaryEmbedding
|
|
|
|
|
def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    """Embed token ids by looking them up in the ``w.wte`` table.

    Args:
        input_ids: Integer token ids; any shape accepted by ``F.embedding``.
        w: Module exposing the embedding matrix as the ``wte`` attribute
            (rows indexed by token id).

    Returns:
        Tensor of shape ``input_ids.shape + (wte.shape[1],)``.
    """
    embedding_table = w.wte
    return F.embedding(input_ids, embedding_table)
|
|
|
def attn(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    rope: "RotaryEmbedding",
    kv_cache: "nn.Module | None",
    pos_ids: torch.Tensor,
):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        x: Input of shape ``(bsz, q_len, d_model)``.
        w: Module with a ``qkv`` projection producing ``3 * d_model``
            features and a ``proj`` output projection (``d_model -> d_model``).
        attn_mask: Mask forwarded unchanged to
            ``F.scaled_dot_product_attention`` (may be ``None``).
        n_heads: Number of attention heads; must evenly divide ``d_model``.
        rope: Object whose ``apply(t, pos_ids)`` is applied to both the
            queries and the keys before attention.
        kv_cache: Object whose ``update(pos_ids, k, v)`` returns the keys and
            values to attend over. Pass ``None`` to attend only over the keys
            and values computed from ``x``.
        pos_ids: Position ids passed to ``rope.apply`` and ``kv_cache.update``.

    Returns:
        Tensor of shape ``(bsz, q_len, d_model)``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """
    bsz, q_len, d_model = x.shape
    if d_model % n_heads != 0:
        # Without this check the .view below fails with an opaque shape error.
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )
    head_dim = d_model // n_heads

    # (bsz, q_len, 3 * d_model) -> (3, bsz, n_heads, q_len, head_dim), so that
    # q/k/v come out in the (batch, heads, seq, head_dim) layout SDPA expects.
    qkv = (
        w.qkv(x)
        .view(bsz, q_len, 3, n_heads, head_dim)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv.unbind(0)

    q = rope.apply(q, pos_ids)
    k = rope.apply(k, pos_ids)

    if kv_cache is not None:
        # The cache may prepend previously stored keys/values to k and v.
        k, v = kv_cache.update(pos_ids, k, v)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    # (bsz, n_heads, q_len, head_dim) -> (bsz, q_len, d_model)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    return w.proj(out)
|
|
|
def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    config: TextConfig,
    rope: "RotaryEmbedding",
    pos_ids: torch.Tensor,
):
    """Run ``x`` through every transformer block in ``w.blocks``.

    Each block uses a parallel residual layout. A single layer norm output
    feeds both the attention and the MLP branch, and both branch outputs are
    added to the residual stream: ``x = x + attn(ln(x)) + mlp(ln(x))``.

    Args:
        x: Hidden states; presumably ``(bsz, seq, dim)`` as required by
            ``attn`` (TODO confirm at call sites).
        w: Module whose ``blocks`` each provide ``ln``, ``attn``, ``mlp`` and
            a ``kv_cache`` attribute.
        attn_mask: Mask forwarded to every attention call.
        config: Text config; only ``n_heads`` is read here.
        rope: Rotary embedding forwarded to every attention call.
        pos_ids: Position ids forwarded to every attention call.

    Returns:
        The hidden states after the final block (no final norm is applied).
    """
    for block in w.blocks:
        l_in = layer_norm(x, block.ln)
        l_attn = attn(
            l_in,
            block.attn,
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            rope=rope,
            kv_cache=block.kv_cache,
            pos_ids=pos_ids,
        )
        # Parallel residual: MLP consumes the same normed input as attention.
        l_mlp = mlp(l_in, block.mlp)
        x = x + l_attn + l_mlp

    return x
|
|
|
|
|
def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
    """Compute next-token logits from the last position's hidden state.

    Only the final time step is used: it is normalized with ``w.post_ln`` and
    then projected by ``w.lm_head``.

    Args:
        hidden_BTC: Rank-3 hidden states, presumably (batch, time, channels)
            per the ``_BTC`` suffix.
        w: Module providing ``post_ln`` and ``lm_head``.

    Returns:
        Logits for the last position, one row per batch element.
    """
    final_step = hidden_BTC[:, -1, :]
    return w.lm_head(layer_norm(final_step, w.post_ln))
|
|
|
|
|
def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
    """Construct the (uninitialized-embedding) text transformer module tree.

    Layout:
        blocks[i].ln             LayerNorm(dim)
        blocks[i].attn.qkv       Linear(dim, 3 * dim)
        blocks[i].attn.proj      Linear(dim, dim)
        blocks[i].mlp.fc1        Linear(dim, ff_dim)
        blocks[i].mlp.fc2        Linear(ff_dim, dim)
        post_ln                  LayerNorm(dim)
        lm_head                  Linear(dim, vocab_size)
        wte                      Parameter(vocab_size, dim), created with
                                 torch.empty (contents are uninitialized)

    Args:
        config: Provides ``dim``, ``ff_dim``, ``n_layers`` and ``vocab_size``.
        dtype: Dtype used for every layer and for ``wte``.

    Returns:
        An ``nn.ModuleDict`` with the layout above.
    """
    qkv_dim = int(config.dim * 3)

    def make_block() -> nn.ModuleDict:
        # Submodules are created in the same order as a flat dict literal
        # would create them (ln, qkv, proj, fc1, fc2), so random init draws
        # happen in the same sequence.
        norm = nn.LayerNorm(config.dim, dtype=dtype)
        attention = nn.ModuleDict(
            {
                "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
                "proj": nn.Linear(config.dim, config.dim, dtype=dtype),
            }
        )
        feed_forward = nn.ModuleDict(
            {
                "fc1": nn.Linear(config.dim, config.ff_dim, dtype=dtype),
                "fc2": nn.Linear(config.ff_dim, config.dim, dtype=dtype),
            }
        )
        return nn.ModuleDict({"ln": norm, "attn": attention, "mlp": feed_forward})

    layers = nn.ModuleList([make_block() for _ in range(config.n_layers)])
    final_norm = nn.LayerNorm(config.dim, dtype=dtype)
    head = nn.Linear(config.dim, config.vocab_size, dtype=dtype)

    text = nn.ModuleDict({"blocks": layers, "post_ln": final_norm, "lm_head": head})
    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
    return text
|
|
|