nvidia-oliver-holworthy committed on
Commit 5f8124a · unverified · 1 Parent(s): 35dc76a

Update the llama_bidirectional_model.py docstring for improved clarity


Signed-off-by: Oliver Holworthy <nvidia-oliver-holworthy@users.noreply.huggingface.co>

Files changed (1)
  1. llama_bidirectional_model.py +23 -21
llama_bidirectional_model.py CHANGED
@@ -1,27 +1,29 @@
  # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  # SPDX-License-Identifier: Apache-2.0.
  """
- Bidirectional Llama model for embedding tasks.
-
- This module provides a modified LlamaModel that uses bidirectional (non-causal)
- attention, suitable for generating embeddings where each token should attend
- to all other tokens in the sequence.
-
- Supports transformers version 4.44 and above with a unified forward() implementation.
-
- Version compatibility notes:
- - transformers 4.47: Setting _attn_implementation in __init__ had no effect due to
-   attention initialization order
- - transformers 4.48+: Attention refactor (transformers#35235) activated the
-   _attn_implementation setting, which defaulted to "eager" instead of "sdpa"
- - transformers < 4.53: LlamaModel has _update_causal_mask method that can be overridden
- - transformers 4.53+: _update_causal_mask removed; masking moved to masking_utils module,
-   necessitating a full forward() override for custom attention masks
- - transformers < 4.54: Decoder layer returns tuple, uses past_key_value (singular)
- - transformers 4.54-4.55: Decoder layer returns tensor, uses past_key_value (singular)
- - transformers 4.56+: Decoder layer returns tensor, uses past_key_values (plural),
-   DynamicCache accepts config parameter
- - transformers 5.0+: Has native create_bidirectional_mask in masking_utils
+ Bidirectional Llama model for cross-encoder reranking.
+
+ Modifies LlamaModel to use bidirectional (non-causal) attention so each token
+ attends to all others, as required for cross-encoder scoring of query-document pairs.
+
+ Provides three classes:
+ - LlamaBidirectionalConfig: Adds pooling and temperature to LlamaConfig.
+ - LlamaBidirectionalModel: LlamaModel with causal masking replaced by
+   bidirectional masking. Overrides forward() to support transformers >=4.44.
+ - LlamaBidirectionalForSequenceClassification: Pools hidden states and
+   projects to a relevance score via a linear head.
+
+ Transformers version compatibility (>=4.44, including 5.0+):
+ The forward() implementation handles these API changes at import time via
+ inspect.signature() on LlamaDecoderLayer and DynamicCache:
+
+ < 4.53: _update_causal_mask exists on LlamaModel (not used here).
+ 4.53+: Masking moved to masking_utils; requires full forward() override.
+ < 4.54: Decoder layer returns a tuple.
+ 4.54+: Decoder layer returns a tensor.
+ < 4.56: Cache kwarg is ``past_key_value`` (singular).
+ 4.56+: Cache kwarg is ``past_key_values`` (plural); DynamicCache accepts config.
+ 5.0+: Native ``create_bidirectional_mask`` in masking_utils.
  """

  import inspect
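
The updated docstring says the forward() override detects these API differences at import time with inspect.signature() on LlamaDecoderLayer and DynamicCache. A minimal sketch of that detection pattern follows; the flag names are illustrative rather than taken from the module, and it assumes transformers >=4.44 is installed.

import inspect

from transformers import DynamicCache
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Does the decoder layer take the plural cache kwarg (transformers >= 4.56)?
_decoder_params = inspect.signature(LlamaDecoderLayer.forward).parameters
_use_plural_cache_kwarg = "past_key_values" in _decoder_params

# Can DynamicCache be constructed with a config argument (transformers >= 4.56)?
_cache_accepts_config = "config" in inspect.signature(DynamicCache.__init__).parameters

# Is the native bidirectional mask helper available (transformers 5.0+)?
try:
    from transformers.masking_utils import create_bidirectional_mask  # noqa: F401
    _has_native_bidirectional_mask = True
except ImportError:
    _has_native_bidirectional_mask = False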
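
For context, a hypothetical usage sketch of the reranker described by the new docstring. The checkpoint name below is a placeholder (no model ID appears in this commit), and loading through AutoModelForSequenceClassification with trust_remote_code is an assumption about how the repository is packaged, not something stated here.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "org/llama-bidirectional-reranker"  # placeholder, not a real checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
model.eval()

query = "What does bidirectional attention change for reranking?"
document = "Cross-encoders score the query and document jointly in one forward pass."

# Query and document are encoded together so every token can attend to every other token.
inputs = tokenizer(query, document, return_tensors="pt", truncation=True)
with torch.no_grad():
    relevance = model(**inputs).logits.squeeze().item()
print(f"relevance score: {relevance:.4f}")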