davidhd commited on
Commit
3c9d879
·
verified ·
1 Parent(s): 3c0b151

Update amplify to extract last embedding after last layer norm.

Browse files
Files changed (1) hide show
  1. amplify.py +4 -1
amplify.py CHANGED
@@ -276,7 +276,10 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
276
  attentions.append(attn)
277
 
278
  # Classification head with layer norm
279
- logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)
 
 
 
280
 
281
  # Return logits or the output of the last hidden layer
282
  return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
 
276
  attentions.append(attn)
277
 
278
  # Classification head with layer norm
279
+ x_normalized = self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x
280
+ if output_hidden_states:
281
+ hidden_states[-1]=x_normalized
282
+ logits = self.decoder(x_normalized)
283
 
284
  # Return logits or the output of the last hidden layer
285
  return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)