| |
|
| | import math |
| | from functools import partial |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange, reduce, repeat |
| | from torch.utils.checkpoint import checkpoint |
| | from transformers import AutoModel, PreTrainedModel |
| |
|
| | from .config import LUARConfig |
| |
|
| | |
| | |
| |
|
def exists(val):
    """Tell whether ``val`` carries a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
| |
|
def summarize_qkv_chunk(
    q, k, v,
    mask
):
    """Dot-Product Attention statistics for a chunk of queries, keys, and values.

    Nothing is normalised here: the function returns the pieces of a softmax
    (shifted exponentials, their sum and the shift itself) so that a caller
    can merge several key chunks into one exact softmax. No scaling is applied
    to the scores in this function.

    Args:
        q: Query chunk, indexed as ``(b, h, i, d)`` by the einsum below.
        k: Key chunk, indexed as ``(b, h, j, d)``.
        v: Value chunk, indexed as ``(b, h, j, d)``.
        mask: ``None`` or an additive mask that broadcasts against the
            ``(b, h, i, j)`` score tensor (it is added in place).

    Returns:
        A tuple ``(exp_sum, weighted_value, weight_max)`` where

        * ``exp_sum`` is ``(b, h, i)``: sum over keys of ``exp(score - max)``;
        * ``weighted_value`` is ``(b, h, i, d)``: ``exp(score - max) @ v``;
        * ``weight_max`` is ``(b, h, i)``: the per-query maximum score that was
          subtracted. It is ``-inf`` for a query whose keys are all masked
          with ``-inf``; in that case ``exp_sum`` and ``weighted_value`` are
          exactly zero instead of NaN.
    """
    weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)

    if mask is not None:
        # In-place add keeps the dtype of the scores, as before.
        weight += mask

    # The max is only a numerical shift, so it is kept out of the graph.
    weight_max = weight.amax(dim = -1, keepdim = True).detach()

    # A row whose keys are all masked with -inf has max == -inf, and
    # (-inf) - (-inf) is NaN. Shift such rows by 0 instead so that their
    # exponentials are exactly 0. The true max is still returned so the
    # caller's renormalisation ignores this chunk.
    fully_masked = weight_max == float('-inf')
    safe_max = weight_max.masked_fill(fully_masked, 0.0)

    exp_weight = (weight - safe_max).exp()
    weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    # squeeze(-1) drops the keepdim axis (same as einops '... 1 -> ...').
    return exp_weight.sum(dim = -1), weighted_value, weight_max.squeeze(-1)
| |
|
def checkpointed_summarize_qkv_chunk(*args, **kwargs):
    """Run ``summarize_qkv_chunk`` through ``torch.utils.checkpoint.checkpoint``.

    Every positional and keyword argument is forwarded to ``checkpoint`` right
    after the wrapped function, so calling this is the same as calling
    ``checkpoint(summarize_qkv_chunk, *args, **kwargs)``.
    """
    return checkpoint(summarize_qkv_chunk, *args, **kwargs)
| |
|
def memory_efficient_attention(
    q, k, v,
    mask = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8
):
    """Scaled dot-product attention evaluated bucket by bucket.

    Queries are split along dim ``-2`` into buckets of ``q_bucket_size`` and
    keys/values into buckets of ``k_bucket_size``. Each (query bucket, key
    bucket) pair is summarised by ``summarize_qkv_chunk`` and the summaries of
    all key buckets are merged using a shared running maximum, so the full
    score matrix is never materialised at once.

    Args:
        q, k, v: Query, key and value tensors; the last dim is the feature
            dim (used for the ``d ** -0.5`` scale) and dim ``-2`` is the
            sequence dim that gets bucketed.
        mask: Optional mask handed to ``summarize_qkv_chunk``. It is split
            only along its last (key) dim, and the same slice is reused for
            every query bucket.
        q_bucket_size: Maximum number of queries processed together.
        k_bucket_size: Maximum number of keys/values processed together.
        eps: Added to the softmax denominator before dividing.

    Returns:
        The attention output, query buckets concatenated back along dim ``-2``.
    """
    q = q * (q.shape[-1] ** -0.5)

    # Only pay for gradient checkpointing when a backward pass can happen.
    if q.requires_grad or k.requires_grad or v.requires_grad:
        summarize_qkv_fn = checkpointed_summarize_qkv_chunk
    else:
        summarize_qkv_fn = summarize_qkv_chunk

    query_buckets = q.split(q_bucket_size, dim = -2)
    key_buckets = k.split(k_bucket_size, dim = -2)
    value_buckets = v.split(k_bucket_size, dim = -2)
    if mask is not None:
        mask_buckets = mask.split(k_bucket_size, dim = -1)
    else:
        mask_buckets = (None,) * len(key_buckets)

    outputs = []
    for query_bucket in query_buckets:
        # One (exp_sum, weighted_value, max) summary per key bucket.
        summaries = [
            summarize_qkv_fn(query_bucket, key_bucket, value_bucket, mask_bucket)
            for key_bucket, value_bucket, mask_bucket
            in zip(key_buckets, value_buckets, mask_buckets)
        ]

        # Stack along a new trailing "key bucket" axis.
        exp_sums = torch.stack([summary[0] for summary in summaries], dim = -1)
        values = torch.stack([summary[1] for summary in summaries], dim = -1)
        maxes = torch.stack([summary[2] for summary in summaries], dim = -1)

        # Each bucket was shifted by its own max; bring them all onto the
        # largest one before adding them up.
        shared_max = maxes.amax(dim = -1, keepdim = True)
        rescale = (maxes - shared_max).exp().detach()

        denominator = (exp_sums * rescale).sum(dim = -1)
        numerator = (values * rescale.unsqueeze(-2)).sum(dim = -1)

        outputs.append(numerator / (denominator.unsqueeze(-1) + eps))

    return torch.cat(outputs, dim=-2)
| |
|
class SelfAttention(nn.Module):
    """Implements Dot-Product Self-Attention as used in "Attention is all You Need".

    The module has no learnable parameters. It offers two code paths:

    * the default path builds the full attention matrix in one go;
    * the memory-efficient path splits the feature dim into ``num_heads``
      heads and delegates to ``memory_efficient_attention``.

    NOTE(review): the two paths differ in head splitting and in the scaling
    dimension, so they presumably do not give identical outputs for the same
    input. Confirm this is intended before switching a trained model from one
    to the other.

    Args:
        memory_efficient_attention: Use the bucketed implementation when true.
        q_bucket_size: Query bucket size for the bucketed implementation.
        k_bucket_size: Key/value bucket size for the bucketed implementation.
        num_heads: Number of heads the feature dim is split into on the
            bucketed path. Defaults to 12, the value that used to be
            hard-coded; the feature dim must be divisible by it.
    """
    def __init__(
        self,
        memory_efficient_attention=False,
        q_bucket_size=512,
        k_bucket_size=1024,
        num_heads=12,
    ):
        super().__init__()
        self.use_memory_efficient_attention = memory_efficient_attention
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
        self.num_heads = num_heads

    def forward(self, k, q, v):
        """Attend over ``v``; note the argument order is ``(k, q, v)``.

        Returns a tensor with the same leading dims as the inputs and the
        feature dim of ``v``.
        """
        if self.use_memory_efficient_attention:
            # (b, n, h*d) -> (b, h, n, d) for the bucketed kernel.
            q, k, v = (
                rearrange(t, 'b n (h d) -> b h n d', h = self.num_heads)
                for t in (q, k, v)
            )

            out = memory_efficient_attention(
                q, k, v,
                q_bucket_size=self.q_bucket_size,
                k_bucket_size=self.k_bucket_size
            )
            # Merge the heads back into the feature dim.
            return rearrange(out, 'b h n d -> b n (h d)')

        d_k = q.size(-1)
        scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, v)
| |
|
class LUAR(PreTrainedModel):
    """Defines the LUAR model.

    The model is a pretrained transformer backbone, a parameter-free attention
    layer over the documents of an episode, and a final linear projection to
    ``config.embedding_size``.
    """
    config_class = LUARConfig

    def __init__(self, config):
        super().__init__(config)
        # Sets self.transformer and self.hidden_size, which are needed below.
        self.create_transformer(revision=config.upstream_transformer_revision)
        self.attn_fn = SelfAttention(
            memory_efficient_attention=config.use_memory_efficient_attention,
            q_bucket_size=config.q_bucket_size,
            k_bucket_size=config.k_bucket_size,
        )
        self.linear = nn.Linear(self.hidden_size, config.embedding_size)

    def create_transformer(self, revision: Optional[str] = None):
        """Creates the Transformer backbone.

        Loads the pretrained checkpoint (pinned to ``revision`` when one is
        given) and records its hidden size, head count and per-head dim.
        """
        extra_kwargs = {}
        if revision:
            extra_kwargs["revision"] = revision
        self.transformer = AutoModel.from_pretrained(
            "sentence-transformers/paraphrase-distilroberta-base-v1",
            **extra_kwargs,
        )
        backbone_config = self.transformer.config
        self.hidden_size = backbone_config.hidden_size
        self.num_attention_heads = backbone_config.num_attention_heads
        self.dim_head = self.hidden_size // self.num_attention_heads

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean Pooling as described in the SBERT paper.

        Averages ``token_embeddings`` (b, l, d) over the positions selected by
        ``attention_mask`` (b, l) and returns a (b, d) tensor.
        """
        # Broadcast the mask over the feature dim and match the embedding type.
        mask = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size)
        mask = mask.type(token_embeddings.type())

        masked_sum = reduce(token_embeddings * mask, 'b l d -> b d', 'sum')
        # Clamp so that an all-zero mask does not divide by zero.
        token_count = torch.clamp(reduce(mask, 'b l d -> b d', 'sum'), min=1e-9)
        return masked_sum / token_count

    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
        """Computes the Author Embedding.

        Args:
            input_ids: Token ids laid out as (batch, episode, length).
            attention_mask: Mask with the same (batch, episode, length) layout.
            output_attentions: Also return the backbone attentions when true.
            document_batch_size: When positive, the flattened documents are
                fed to the backbone in slices of this size instead of all at
                once.

        Returns:
            The (batch, embedding_size) episode embeddings, or a tuple of the
            embeddings and the backbone attentions when ``output_attentions``
            is true.
        """
        B, E, _ = attention_mask.shape

        # Every document of every episode becomes one row for the backbone.
        input_ids = rearrange(input_ids, 'b e l -> (b e) l')
        attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')

        attentions = None
        if document_batch_size > 0:
            hidden_parts = []
            attention_parts = []
            for start in range(0, len(input_ids), document_batch_size):
                stop = start + document_batch_size
                part = self.transformer(
                    input_ids=input_ids[start:stop],
                    attention_mask=attention_mask[start:stop],
                    return_dict=True,
                    output_hidden_states=False,
                    output_attentions=output_attentions,
                )
                hidden_parts.append(part["last_hidden_state"])
                if output_attentions:
                    attention_parts.append(part["attentions"])

            last_hidden_state = torch.cat(hidden_parts, dim=0)
            if output_attentions:
                # Re-join the slices layer by layer along the document dim.
                attentions = tuple(
                    torch.cat([part[layer] for part in attention_parts], dim=0)
                    for layer in range(len(attention_parts[0]))
                )
        else:
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=False,
                output_attentions=output_attentions,
            )
            last_hidden_state = outputs['last_hidden_state']
            if output_attentions:
                attentions = outputs["attentions"]

        # One pooled vector per document, regrouped by episode.
        comment_embeddings = self.mean_pooling(last_hidden_state, attention_mask)
        comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)

        # Attend across the documents of an episode, then max-pool over them.
        attended = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings)
        pooled = reduce(attended, 'b e l -> b l', 'max')

        episode_embeddings = self.linear(pooled)

        if output_attentions:
            return episode_embeddings, attentions

        return episode_embeddings

    def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
        """Calculates a fixed-length feature vector for a batch of episode samples.

        Thin wrapper around ``get_episode_embeddings``; see it for arguments
        and return value.
        """
        return self.get_episode_embeddings(
            input_ids, attention_mask, output_attentions, document_batch_size
        )
| |
|