Update amplify to extract last embedding after last layer norm.
Browse files- amplify.py +4 -1
amplify.py
CHANGED
|
@@ -276,7 +276,10 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
|
|
| 276 |
attentions.append(attn)
|
| 277 |
|
| 278 |
# Classification head with layer norm
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
# Return logits or the output of the last hidden layer
|
| 282 |
return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
|
|
|
|
| 276 |
attentions.append(attn)
|
| 277 |
|
| 278 |
# Classification head with layer norm
|
| 279 |
+
x_normalized = self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x
|
| 280 |
+
if output_hidden_states:
|
| 281 |
+
hidden_states[-1]=x_normalized
|
| 282 |
+
logits = self.decoder(x_normalized)
|
| 283 |
|
| 284 |
# Return logits or the output of the last hidden layer
|
| 285 |
return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
|