|
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ['XLMRoberta', 'xlm_roberta_large']
|
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention with separate Q/K/V and output projections.

    Args:
        dim: model width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head has width
            ``dim // num_heads``.
        dropout: probability used for attention-weight dropout (only while
            the module is in training mode) and for dropout applied to the
            output projection.
        eps: stored on the module but not read by this layer.
    """

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim, self.num_heads = dim, num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # Submodule names and registration order determine the state_dict
        # keys (and the order of default random init), so keep them as-is.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor, batch, length):
        """Reshape a [B, L, C] projection into [B, num_heads, L, head_dim]."""
        heads = tensor.reshape(batch, length, self.num_heads, self.head_dim)
        return heads.transpose(1, 2)

    def forward(self, x, mask):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: input of shape [B, L, C] with C == ``dim``.
            mask: forwarded unchanged as ``attn_mask`` to
                ``F.scaled_dot_product_attention`` (a bool mask or an
                additive float bias broadcastable to [B, num_heads, L, L]).

        Returns:
            Tensor of shape [B, L, C].
        """
        batch, length, channels = x.size()
        query, key, value = (
            self._split_heads(projection(x), batch, length)
            for projection in (self.q, self.k, self.v))

        # Dropout on the attention probabilities only while training.
        attn_p = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(query, key, value, mask, attn_p)
        merged = attended.transpose(1, 2).reshape(batch, length, channels)
        return self.dropout(self.o(merged))
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Transformer encoder layer: self-attention followed by a 4x-wide MLP.

    With ``post_norm=True`` each sub-layer computes ``norm(x + sublayer(x))``;
    with ``post_norm=False`` it computes ``x + sublayer(norm(x))``.

    Args:
        dim: model width.
        num_heads: number of attention heads.
        post_norm: chooses post-norm (truthy) or pre-norm (falsy) residuals.
        dropout: dropout probability for the attention layer and the MLP
            output.
        eps: LayerNorm epsilon; also forwarded to ``SelfAttention``.
    """

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # Attribute names and registration order define the state_dict keys
        # (attn.*, norm1.*, ffn.0.*, ffn.2.*, norm2.*); keep them unchanged.
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        hidden = 4 * dim
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        """Run the attention and MLP sub-layers; ``mask`` goes to attention."""
        if not self.post_norm:
            # Pre-norm: normalise each sub-layer's input, keep a raw residual.
            x = x + self.attn(self.norm1(x), mask)
            return x + self.ffn(self.norm2(x))

        # Post-norm: add the residual first, then normalise.
        x = self.norm1(x + self.attn(x, mask))
        return self.norm2(x + self.ffn(x))
|
|
|
|
|
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.

    Maps token ids of shape [B, L] to hidden states of shape [B, L, dim].

    Args:
        vocab_size: number of token embeddings.
        max_seq_len: number of position embeddings.
        type_size: number of token-type embeddings (only type 0 is used).
        pad_id: padding token id; also the padding index of the token and
            position embeddings.
        dim: model width.
        num_heads: attention heads per layer.
        num_layers: number of ``AttentionBlock`` layers.
        post_norm: if truthy, ``self.norm`` is applied to the embeddings and
            blocks use post-norm; otherwise blocks use pre-norm and
            ``self.norm`` is applied after the last block.
        dropout: dropout probability used throughout.
        eps: LayerNorm epsilon.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # Embedding tables (names/order are state_dict keys; keep them).
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # Encoder stack.
        self.blocks = nn.ModuleList(
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers))

        # Shared LayerNorm: on embeddings (post-norm) or after the stack
        # (pre-norm); see forward().
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """Encode a batch of token ids.

        Args:
            ids: LongTensor of shape [B, L]; entries equal to ``pad_id`` are
                treated as padding.

        Returns:
            Hidden states of shape [B, L, dim].
        """
        batch, seq_len = ids.shape
        not_pad = ids.ne(self.pad_id).long()

        # RoBERTa-style position ids: non-padding tokens get pad_id + 1,
        # pad_id + 2, ... in order; padding tokens get pad_id.
        position_ids = self.pad_id + not_pad.cumsum(dim=1) * not_pad
        token_type_ids = torch.zeros_like(ids)
        x = (self.token_embedding(ids)
             + self.type_embedding(token_type_ids)
             + self.pos_embedding(position_ids))
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # Additive attention bias of shape [B, 1, 1, L]: 0.0 for real keys,
        # torch.finfo(x.dtype).min for padding keys.
        key_is_real = not_pad.view(batch, 1, 1, seq_len).gt(0)
        bias = torch.where(key_is_real, 0.0, torch.finfo(x.dtype).min)
        for layer in self.blocks:
            x = layer(x, bias)

        if not self.post_norm:
            x = self.norm(x)
        return x
|
|
|
|
|
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Builds an ``XLMRoberta`` with the XLM-R large configuration (24 layers,
    width 1024, 16 heads, 250002-token vocabulary, post-norm). No weights
    are loaded: parameters keep their default initialisation.

    Args:
        pretrained: not supported by this factory and ignored. A
            ``UserWarning`` is emitted when it is truthy so callers do not
            mistake the randomly initialised model for a pretrained one.
        return_tokenizer: not supported and ignored (only the model is
            returned). A ``UserWarning`` is emitted when it is truthy.
        device: device on which parameters are created, applied through
            ``torch.device`` used as a context manager.
        **kwargs: overrides for any ``XLMRoberta`` constructor argument.

    Returns:
        The constructed ``XLMRoberta`` module.
    """
    import warnings

    # These flags used to be dropped silently; surface that they have no
    # effect instead of failing, to stay compatible with existing callers.
    if pretrained:
        warnings.warn(
            'xlm_roberta_large: pretrained=True is not supported and is '
            'ignored; the returned model is randomly initialised.',
            stacklevel=2)
    if return_tokenizer:
        warnings.warn(
            'xlm_roberta_large: return_tokenizer=True is not supported and '
            'is ignored; only the model is returned.',
            stacklevel=2)

    # Default XLM-R large hyper-parameters; kwargs override any of them.
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # Allocate parameters directly on the requested device.
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
|
|
|