Update the llama_bidirectional_model.py docstring for improved clarity
Browse filesSigned-off-by: Oliver Holworthy <nvidia-oliver-holworthy@users.noreply.huggingface.co>
- llama_bidirectional_model.py +23 -21
llama_bidirectional_model.py
CHANGED
|
@@ -1,27 +1,29 @@
|
|
| 1 |
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
# SPDX-License-Identifier: Apache-2.0.
|
| 3 |
"""
|
| 4 |
-
Bidirectional Llama model for
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
"""
|
| 26 |
|
| 27 |
import inspect
|
|
|
|
| 1 |
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
# SPDX-License-Identifier: Apache-2.0.
|
| 3 |
"""
|
| 4 |
+
Bidirectional Llama model for cross-encoder reranking.
|
| 5 |
+
|
| 6 |
+
Modifies LlamaModel to use bidirectional (non-causal) attention so each token
|
| 7 |
+
attends to all others — required for cross-encoder scoring of query-document pairs.
|
| 8 |
+
|
| 9 |
+
Provides three classes:
|
| 10 |
+
- LlamaBidirectionalConfig: Adds pooling and temperature to LlamaConfig.
|
| 11 |
+
- LlamaBidirectionalModel: LlamaModel with causal masking replaced by
|
| 12 |
+
bidirectional masking. Overrides forward() to support transformers >=4.44.
|
| 13 |
+
- LlamaBidirectionalForSequenceClassification: Pools hidden states and
|
| 14 |
+
projects to a relevance score via a linear head.
|
| 15 |
+
|
| 16 |
+
Transformers version compatibility (>=4.44 including 5.0+):
|
| 17 |
+
The forward() implementation handles these API changes at import time via
|
| 18 |
+
inspect.signature() on LlamaDecoderLayer and DynamicCache:
|
| 19 |
+
|
| 20 |
+
< 4.53: _update_causal_mask exists on LlamaModel (not used here).
|
| 21 |
+
4.53+: Masking moved to masking_utils; requires full forward() override.
|
| 22 |
+
< 4.54: Decoder layer returns a tuple.
|
| 23 |
+
4.54+: Decoder layer returns a tensor.
|
| 24 |
+
< 4.56: Cache kwarg is ``past_key_value`` (singular).
|
| 25 |
+
4.56+: Cache kwarg is ``past_key_values`` (plural); DynamicCache accepts config.
|
| 26 |
+
5.0+: Native ``create_bidirectional_mask`` in masking_utils.
|
| 27 |
"""
|
| 28 |
|
| 29 |
import inspect
|