Add support for transformers 4.44 through 5.0+

#11
Files changed (3)
  1. README.md +2 -2
  2. config.json +1 -1
  3. llama_bidirectional_model.py +234 -50
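Reviewer note: the change replaces the hard version pin with import-time feature detection, so the module adapts to whichever transformers API it finds rather than parsing version strings. A minimal illustration of the detection mechanism the new module relies on (illustrative snippet, not part of the diff):

```python
import inspect

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# On any supported transformers release this prints the decoder layer's
# parameter names, which is all the new module needs in order to pick the
# right cache kwarg (`past_key_value` before 4.56, `past_key_values` after).
print(list(inspect.signature(LlamaDecoderLayer.forward).parameters))
```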
README.md CHANGED
@@ -67,10 +67,10 @@ We trained the model on public datasets described in the Dataset and Training se

### **Installation**

- The model requires transformers version 4.47.1.
+ The model requires transformers version 4.44 or above.

```bash
- pip install transformers==4.47.1
+ pip install "transformers>=4.44"
```

### **Usage**
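To confirm an environment satisfies the relaxed requirement, one can check the installed version (illustrative snippet, not part of the diff):

```python
from importlib.metadata import version

# Any 4.44+ release, including 5.x, should satisfy the new floor per this PR.
print(version("transformers"))
```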
config.json CHANGED
@@ -40,7 +40,7 @@
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
- "temperature": 0.2,
+ "temperature": 1.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.2",
llama_bidirectional_model.py CHANGED
@@ -1,18 +1,43 @@
- from typing import List, Optional, Tuple, Union
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ Bidirectional Llama model for cross-encoder reranking.
+
+ Modifies LlamaModel to use bidirectional (non-causal) attention so each token
+ attends to all others, as required for cross-encoder scoring of query-document pairs.
+
+ Provides three classes:
+ - LlamaBidirectionalConfig: Adds pooling and temperature to LlamaConfig.
+ - LlamaBidirectionalModel: LlamaModel with causal masking replaced by
+   bidirectional masking. Overrides forward() to support transformers >=4.44.
+ - LlamaBidirectionalForSequenceClassification: Pools hidden states and
+   projects to a relevance score via a linear head.
+
+ Transformers version compatibility (>=4.44 including 5.0+):
+   The forward() implementation handles these API changes at import time via
+   inspect.signature() on LlamaDecoderLayer and DynamicCache:
+
+   < 4.53: _update_causal_mask exists on LlamaModel (not used here).
+   4.53+:  Masking moved to masking_utils; requires full forward() override.
+   < 4.54: Decoder layer returns a tuple.
+   4.54+:  Decoder layer returns a tensor.
+   < 4.56: Cache kwarg is ``past_key_value`` (singular).
+   4.56+:  Cache kwarg is ``past_key_values`` (plural); DynamicCache accepts config.
+   5.0+:   Native ``create_bidirectional_mask`` in masking_utils.
+ """
+
+ import inspect
+ from typing import Optional, Union, Tuple, List

import torch
- import torch.nn.functional as F
- from torch import Tensor, nn
+ import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from transformers.cache_utils import Cache, HybridCache
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
- from transformers.modeling_outputs import (
-     BaseModelOutputWithPast,
-     SequenceClassifierOutputWithPast,
- )
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
+     LlamaDecoderLayer,
    LlamaModel,
    LlamaPreTrainedModel,
)
@@ -20,8 +45,200 @@ from transformers.utils import logging

logger = logging.get_logger(__name__)

+ # Check if native create_bidirectional_mask exists (transformers >= 5.0)
+ try:
+     from transformers.masking_utils import create_bidirectional_mask
+
+     _HAS_NATIVE_BIDIRECTIONAL_MASK = True
+ except ImportError:
+     from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+
+     _HAS_NATIVE_BIDIRECTIONAL_MASK = False
+
+ # Detect API differences via introspection
+ _decoder_forward_params = inspect.signature(LlamaDecoderLayer.forward).parameters
+ _dynamic_cache_init_params = inspect.signature(DynamicCache.__init__).parameters
+
+ # past_key_value (singular) in < 4.56, past_key_values (plural) in >= 4.56
+ _USE_PLURAL_CACHE_PARAM = "past_key_values" in _decoder_forward_params
+ # DynamicCache accepts config parameter in >= 4.56
+ _DYNAMIC_CACHE_ACCEPTS_CONFIG = "config" in _dynamic_cache_init_params
+
+
+ class LlamaBidirectionalConfig(LlamaConfig):
+     """Configuration for LlamaBidirectionalModel with pooling and temperature settings."""
+
+     model_type = "llama_bidirec"
+
+     def __init__(
+         self, pooling: str = "avg", temperature: float = 1.0, **kwargs
+     ) -> None:
+         """
+         Initialize bidirectional Llama configuration.
+
+         Args:
+             pooling: Pooling strategy for embeddings ("avg", "cls", "last", etc.)
+             temperature: Temperature scaling for embeddings
+             **kwargs: Additional arguments passed to LlamaConfig
+         """
+         self.pooling = pooling
+         self.temperature = temperature
+         super().__init__(**kwargs)
+
+
+ class LlamaBidirectionalModel(LlamaModel):
+     """
+     LlamaModel modified to use bidirectional (non-causal) attention.
+
+     In standard Llama, each token can only attend to previous tokens (causal attention).
+     This model removes that restriction, allowing each token to attend to all tokens
+     in the sequence, which is useful for embedding tasks.
+
+     The key modifications are:
+     1. Setting is_causal=False on all attention layers
+     2. Using a bidirectional attention mask instead of a causal mask
+     """
+
+     config_class = LlamaBidirectionalConfig
+
+     def __init__(self, config: LlamaConfig) -> None:
+         super().__init__(config)
+         for layer in self.layers:
+             layer.self_attn.is_causal = False
+
+     def _create_bidirectional_mask(
+         self,
+         input_embeds: torch.Tensor,
+         attention_mask: torch.Tensor | None,
+     ) -> torch.Tensor | None:
+         """
+         Create a bidirectional attention mask.
+
+         Args:
+             input_embeds: Input embeddings tensor of shape (batch_size, seq_len, hidden_size)
+             attention_mask: Optional 2D attention mask of shape (batch_size, seq_len)
+                 where 1 indicates tokens to attend to and 0 indicates masked tokens
+
+         Returns:
+             4D attention mask suitable for the attention implementation, or None
+             if no masking is needed
+         """
+         if attention_mask is None:
+             return None
+
+         if _HAS_NATIVE_BIDIRECTIONAL_MASK:
+             return create_bidirectional_mask(
+                 config=self.config,
+                 input_embeds=input_embeds,
+                 attention_mask=attention_mask,
+             )
+
+         # Fallback for transformers < 5.0 without create_bidirectional_mask.
+         # Flash attention handles 2D masks internally; only pass the mask if
+         # there are actually masked tokens (zeros), otherwise return None for
+         # efficiency
+         if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
+             has_masked_tokens = (attention_mask == 0).any()
+             return attention_mask if has_masked_tokens else None
+
+         return _prepare_4d_attention_mask(attention_mask, input_embeds.dtype)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         inputs_embeds: torch.FloatTensor | None = None,
+         cache_position: torch.LongTensor | None = None,
+         use_cache: bool | None = None,
+         **kwargs,
+     ) -> BaseModelOutputWithPast:
+         """
+         Forward pass with bidirectional attention.
+
+         Args:
+             input_ids: Input token IDs of shape (batch_size, seq_len)
+             attention_mask: Attention mask of shape (batch_size, seq_len)
+             position_ids: Position IDs for rotary embeddings
+             past_key_values: Cached key/value states for incremental decoding
+             inputs_embeds: Pre-computed input embeddings (alternative to input_ids)
+             cache_position: Position indices for cache updates
+             use_cache: Whether to return cached key/value states
+             **kwargs: Additional arguments passed to decoder layers
+
+         Returns:
+             BaseModelOutputWithPast containing last_hidden_state and past_key_values
+         """
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds"
+             )
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         # Initialize cache if needed
+         if use_cache and past_key_values is None:
+             if _DYNAMIC_CACHE_ACCEPTS_CONFIG:
+                 past_key_values = DynamicCache(config=self.config)
+             else:
+                 past_key_values = DynamicCache()
+
+         if cache_position is None:
+             past_seen_tokens = (
+                 past_key_values.get_seq_length() if past_key_values is not None else 0
+             )
+             cache_position = torch.arange(
+                 past_seen_tokens,
+                 past_seen_tokens + inputs_embeds.shape[1],
+                 device=inputs_embeds.device,
+             )
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         bidirectional_mask = self._create_bidirectional_mask(
+             inputs_embeds, attention_mask
+         )
+
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # Build decoder layer kwargs with the correct cache parameter name
+         # (past_key_value in < 4.56, past_key_values in >= 4.56)
+         layer_kwargs = {
+             "attention_mask": bidirectional_mask,
+             "position_ids": position_ids,
+             "use_cache": use_cache,
+             "cache_position": cache_position,
+             "position_embeddings": position_embeddings,
+         }
+         if _USE_PLURAL_CACHE_PARAM:
+             layer_kwargs["past_key_values"] = past_key_values
+         else:
+             layer_kwargs["past_key_value"] = past_key_values
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+             layer_outputs = decoder_layer(hidden_states, **layer_kwargs)
+
+             # Decoder returns a tuple in < 4.54, a tensor in >= 4.54
+             if isinstance(layer_outputs, tuple):
+                 hidden_states = layer_outputs[0]
+             else:
+                 hidden_states = layer_outputs
+
+         hidden_states = self.norm(hidden_states)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values,
+         )
+
+
- def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) -> Tensor:
+ def pool(
+     last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
+ ) -> torch.Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
@@ -46,48 +263,13 @@ def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) ->
        return emb


- class LlamaBidirectionalConfig(LlamaConfig):
-     model_type = "llama_bidirec"
-
-     def __init__(
-         self, pooling="avg", temperature=1.0, **kwargs,
-     ):
-         self.pooling = pooling
-         self.temperature = temperature
-         super().__init__(**kwargs,)
-
-
- class LlamaBidirectionalModel(LlamaModel):
-     config_class = LlamaBidirectionalConfig
-
-     def __init__(self, config: LlamaConfig):
-         super().__init__(config)
-         for layer in self.layers:
-             layer.self_attn.is_causal = False
-         self.config._attn_implementation = "eager"
-
-     def _update_causal_mask(
-         self,
-         attention_mask: torch.Tensor,
-         input_tensor: torch.Tensor,
-         cache_position: torch.Tensor,
-         past_key_values: Cache,
-         output_attentions: bool,
-     ):
-         # Generates bi-directional attention.
-         causal_mask = _prepare_4d_attention_mask(attention_mask, input_tensor.dtype)
-         return causal_mask
-
-
- class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification):
+ class LlamaBidirectionalForSequenceClassification(LlamaPreTrainedModel):
    config_class = LlamaBidirectionalConfig

    def __init__(self, config):
        super().__init__(config)
-         # Releasing the parameters of LlamaModel
-         # created by parent LlamaForSequenceClassification
-         del self.model
-
+         self.num_labels = config.num_labels
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.model = LlamaBidirectionalModel(config)

        # Initialize weights and apply final processing
@@ -105,6 +287,7 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
+         **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -126,6 +309,7 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
+             **kwargs,
        )
        hidden_states = transformer_outputs[0]

@@ -140,7 +324,7 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification

        loss = None
        if labels is not None:
-             labels = labels.to(logits.device)
+             labels = labels.to(pooled_logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
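As a quick smoke test of the new forward() across transformers versions, one might run a tiny randomly initialized model and check the output shape. This is a sketch only: the config values are made up, and `out.logits` assumes the standard SequenceClassifierOutputWithPast layout rather than anything shown in the hunks above.

```python
import torch

from llama_bidirectional_model import (
    LlamaBidirectionalConfig,
    LlamaBidirectionalForSequenceClassification,
)

# Hypothetical tiny config purely for a local shape check; a real run would
# load the released checkpoint (and its config.json) instead.
config = LlamaBidirectionalConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_labels=1,
)
model = LlamaBidirectionalForSequenceClassification(config).eval()

input_ids = torch.randint(0, 128, (2, 16))
attention_mask = torch.ones(2, 16, dtype=torch.long)
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.logits.shape)  # expected: one relevance score per pair, torch.Size([2, 1])
```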