atharvasc27112001 commited on
Commit
638223b
·
verified ·
1 Parent(s): 3d51a28

Create focus_area_model.py

Browse files
Files changed (1) hide show
  1. focus_area_model.py +19 -0
focus_area_model.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
+ class LabelEmbCls(nn.Module):
6
+ def __init__(self, base: AutoModel, lbl_emb: torch.Tensor):
7
+ super().__init__()
8
+ self.bert = base
9
+ self.lbl_E = nn.Parameter(lbl_emb, requires_grad=False)
10
+ self.tau = nn.Parameter(torch.tensor(1.0))
11
+ def forward(self, input_ids, attention_mask, token_type_ids=None):
12
+ # get the [CLS] token embedding
13
+ cls = self.bert(
14
+ input_ids=input_ids,
15
+ attention_mask=attention_mask,
16
+ token_type_ids=token_type_ids
17
+ ).last_hidden_state[:, 0] # shape [batch, 768]
18
+ # compute dot-product / tau
19
+ return torch.matmul(cls, self.lbl_E.T) / self.tau