lkhphuc committed
Commit 1f8827e · 1 Parent(s): def4178

Add fused weights

Files changed (3):
  1. attention.py +33 -1
  2. model.safetensors +2 -2
  3. modeling_falcon_perception.py +15 -61
attention.py CHANGED
@@ -1,9 +1,12 @@
 import torch
 from torch import Tensor as T
 from torch.nn.attention.flex_attention import (
+    BlockMask,
     _mask_mod_signature,
+    and_masks,
     create_block_mask,
     flex_attention,
+    or_masks,
 )
 
 # ---------------------------------------------------------------------------
@@ -106,9 +109,38 @@ _compiled_create_block_mask = torch.compile(
 
 
 @torch.inference_mode()
-def create_attention_mask(*args, **kwargs):
+def create_attention_mask(*args, **kwargs) -> BlockMask:
     """
     NOTE: We compile this for performance/memory reasons in large masks. To reduce
     recompiles due to grad_mode flips, we always run mask creation under inference_mode.
     """
     return _compiled_create_block_mask(*args, **kwargs)
+
+
+def create_batch_attention_mask(
+    input_batch: T,
+    *,
+    pad_token_id: int,
+    eos_token_id: int,
+    soi_token_id: int,
+    eoi_token_id: int,
+    max_len: int | None = None,
+) -> BlockMask:
+    """Build the combined FlexAttention mask for the batch engine.
+
+    Composes causal + document + non-left-pad + image-prefix masks.
+    """
+    B, S = input_batch.size()
+    block_causal_mask_mod = and_masks(
+        get_causal_mask_mod(),
+        get_document_mask_mod(input_batch, eos_token_id),
+        get_non_left_pad_mask_mod(input_batch, pad_token_id),
+    )
+    image_prefix_mask_mod = get_image_prefix_mask_mod(
+        batch=input_batch,
+        soi_id=soi_token_id,
+        eoi_id=eoi_token_id,
+    )
+    mask_mod = or_masks(image_prefix_mask_mod, block_causal_mask_mod)
+    max_len = max_len or S
+    return create_attention_mask(mask_mod, B, None, max_len, max_len)
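
For reference, the new create_batch_attention_mask is a plain function of a token-ID batch, so it can be exercised on its own. A minimal usage sketch, assuming attention.py is importable as a module; the special-token IDs below are made up for illustration, while the model itself passes pad/eos/soi/eoi IDs from its tokenizer and config:

import torch

from attention import create_batch_attention_mask

# Hypothetical special-token IDs, for illustration only.
PAD, EOS, SOI, EOI = 0, 1, 2, 3

# One row packing two documents (split at EOS) with right padding;
# the 7s sit inside an image span delimited by SOI ... EOI.
input_batch = torch.tensor([[5, 6, EOS, SOI, 7, 7, EOI, 8, EOS, PAD, PAD, PAD]])

# Returns a FlexAttention BlockMask that ORs the image-prefix mask mod with
# the AND of the causal, per-document, and non-left-pad mask mods, built
# under inference_mode via the compiled create_block_mask helper.
block_mask = create_batch_attention_mask(
    input_batch,
    pad_token_id=PAD,
    eos_token_id=EOS,
    soi_token_id=SOI,
    eoi_token_id=EOI,
)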
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d842c991349d997852c99ebf0dc6d368fc70b73c658d117a3088132a6cbb68ca
-size 2529523048
+oid sha256:c680d5b1a834a2df61baae9dd694b95856872609d0a769702faf5ad658297641
+size 2529514176
modeling_falcon_perception.py CHANGED
@@ -1,5 +1,4 @@
 import math
-import time
 from pathlib import Path
 
 import einops as E
@@ -15,8 +14,6 @@ from torch import nn
 from torch.nn.attention.flex_attention import (
     AuxRequest,
     BlockMask,
-    and_masks,
-    or_masks,
 )
 from transformers import AutoTokenizer, PreTrainedModel
 
@@ -25,10 +22,7 @@ from .attention import (
     compiled_flex_attn_decode,
     compiled_flex_attn_prefill,
     create_attention_mask,
-    get_causal_mask_mod,
-    get_document_mask_mod,
-    get_image_prefix_mask_mod,
-    get_non_left_pad_mask_mod,
+    create_batch_attention_mask,
     offset_mask_mod,
 )
 from .configuration_falcon_perception import FalconPerceptionConfig
@@ -99,19 +93,12 @@ class Attention(nn.Module):
         self.q_dim = config.n_heads * self.head_dim
         self.kv_dim = self.n_kv_heads * self.head_dim
 
-        self.wq = nn.Linear(config.dim, self.q_dim, bias=False)
-        self.wk = nn.Linear(config.dim, self.kv_dim, bias=False)
-        self.wv = nn.Linear(config.dim, self.kv_dim, bias=False)
+        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
         self.wo = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)
         self.sinks = nn.Parameter(torch.empty((config.n_heads,)))
 
-    def _fuse_weights(self):
-        wqkv_weight = torch.cat([self.wq.weight.data, self.wk.weight.data, self.wv.weight.data], dim=0)
-        self.register_buffer("_wqkv_weight", wqkv_weight)
-        del self.wq, self.wk, self.wv
-
     def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
-        qkv = F.linear(F.rms_norm(x, (x.size(-1),)), self._wqkv_weight)
+        qkv = self.wqkv(F.rms_norm(x, (x.size(-1),)))
         xq, xk, xv = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
         xq = E.rearrange(xq, "b s (h d) -> b s h d", d=self.head_dim)
         xk = E.rearrange(xk, "b s (h d) -> b s h d", d=self.head_dim)
@@ -195,27 +182,13 @@ def squared_relu_gate(packed: T, hidden_dim: int) -> T:
 class FeedForward(nn.Module):
     def __init__(self, dim: int, hidden_dim: int):
         super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
         self.hidden_dim = hidden_dim
 
-    def _fuse_weights(self):
-        if hasattr(self, "_w13_weight"):
-            return
-        w1_weight_fused = self.w1.weight.data * math.sqrt(2.0)
-        w13_weight = torch.empty(
-            (2 * self.hidden_dim, self.w1.weight.shape[1]),
-            device=w1_weight_fused.device, dtype=w1_weight_fused.dtype,
-        )
-        w13_weight[0::2] = w1_weight_fused
-        w13_weight[1::2] = self.w3.weight.data
-        self.register_buffer("_w13_weight", w13_weight)
-        del self.w1, self.w3
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.rms_norm(x, (x.size(-1),))
-        w13_out = F.linear(x, self._w13_weight)
+        w13_out = self.w13(x)
         return self.w2(squared_relu_gate(w13_out, self.hidden_dim))
 
 
@@ -357,31 +330,17 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
 
     # -- Weight management ---------------------------------------------------
 
-    def _fuse_weights(self):
+    def _ensure_device_buffers(self):
+        """Recompute non-persistent buffers that HF meta-device loading may discard."""
         if self._weights_fused:
             return
-
         device = self.tok_embeddings.weight.device
         c = self.config
-
-        # Recompute freqs_cis on the actual device — non-persistent buffers
-        # get replaced with empty tensors by transformers' meta-device loading.
        rope_dim = c.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, c.max_seq_len, c.rope_theta).to(device)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
-
-        # Ensure freqs_cis_golden is on the right device (loaded from safetensors)
        if self.freqs_cis_golden.device != device:
            self.freqs_cis_golden = self.freqs_cis_golden.to(device)
-
-        for layer in self.layers.values():
-            layer.attention._fuse_weights()
-            layer.feed_forward._fuse_weights()
-        self.coord_decoder.w1.weight.mul_(math.sqrt(2))
-        self.size_decoder.w1.weight.mul_(math.sqrt(2))
-        if self.config.do_segmentation:
-            for layer in self.proj_segm.layers:
-                layer.weight.mul_(math.sqrt(2))
        self._weights_fused = True
 
    def compile_model(self):
@@ -418,19 +377,14 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
     # -- Attention mask ------------------------------------------------------
 
     def get_attention_mask(self, input_batch: T, max_len: int | None = None):
-        B, S = input_batch.size()
-        c = self.config
-        block_causal_mask_mod = and_masks(
-            get_causal_mask_mod(),
-            get_document_mask_mod(input_batch, c.eos_id),
-            get_non_left_pad_mask_mod(input_batch, self._pad_token_id),
-        )
-        image_prefix_mask_mod = get_image_prefix_mask_mod(
-            batch=input_batch, soi_id=c.image_cls_token_id, eoi_id=c.img_end_id,
+        return create_batch_attention_mask(
+            input_batch,
+            pad_token_id=self._pad_token_id,
+            eos_token_id=self.config.eos_id,
+            soi_token_id=self.config.image_cls_token_id,
+            eoi_token_id=self.config.img_end_id,
+            max_len=max_len,
         )
-        mask_mod = or_masks(image_prefix_mask_mod, block_causal_mask_mod)
-        max_len = max_len or S
-        return create_attention_mask(mask_mod, B, None, max_len, max_len)
 
     def get_upsampler_attn_mask(self, H, W, h, w, device):
         return create_attention_mask(
@@ -699,7 +653,7 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
             "mask_rle": {"counts": str, "size": [H, W]},
         }
         """
-        self._fuse_weights()
+        self._ensure_device_buffers()
        if compile:
            self.compile_model()
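
Since the runtime fusion path (_fuse_weights) is removed and model.safetensors now ships the fused tensors directly, a checkpoint saved before this commit would need a one-off conversion to the new layout. Below is a minimal sketch inferred from the removed _fuse_weights code; the state-dict key patterns ("attention.wq.weight", "feed_forward.w1.weight", etc.) and the fuse_state_dict helper are illustrative assumptions, not part of this repo:

import math
import torch

def fuse_state_dict(old_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Illustrative mapping from the pre-fusion layout to the fused one."""
    new_sd: dict[str, torch.Tensor] = {}
    for key, value in old_sd.items():
        if key.endswith("attention.wq.weight"):
            # QKV fusion: concatenate along the output dim, mirroring the removed
            # torch.cat([wq, wk, wv], dim=0).
            prefix = key[: -len("wq.weight")]
            new_sd[prefix + "wqkv.weight"] = torch.cat(
                [value, old_sd[prefix + "wk.weight"], old_sd[prefix + "wv.weight"]], dim=0
            )
        elif key.endswith("feed_forward.w1.weight"):
            # FFN fusion: pre-scale w1 by sqrt(2) and interleave its rows with w3,
            # matching w13_weight[0::2] = w1 * sqrt(2); w13_weight[1::2] = w3.
            prefix = key[: -len("w1.weight")]
            w1 = value * math.sqrt(2.0)
            w3 = old_sd[prefix + "w3.weight"]
            w13 = torch.empty((2 * w1.shape[0], w1.shape[1]), dtype=w1.dtype)
            w13[0::2] = w1
            w13[1::2] = w3
            new_sd[prefix + "w13.weight"] = w13
        elif key.endswith(("attention.wk.weight", "attention.wv.weight", "feed_forward.w3.weight")):
            continue  # already folded into the fused tensors above
        else:
            new_sd[key] = value
    # Note: the removed runtime path also scaled coord_decoder.w1, size_decoder.w1,
    # and (when do_segmentation) the proj_segm layer weights by sqrt(2); a real
    # conversion would need to bake those in as well.
    return new_sd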