VisionLanguageGroup commited on
Commit
02e04fb
·
1 Parent(s): f10f497
_utils/attn_utils.py DELETED
@@ -1,592 +0,0 @@
1
- import abc
2
-
3
- import cv2
4
- import numpy as np
5
- import torch
6
- from IPython.display import display
7
- from PIL import Image
8
- from typing import Union, Tuple, List
9
- from einops import rearrange, repeat
10
- import math
11
- from torch import nn, einsum
12
- from inspect import isfunction
13
- from diffusers.utils import logging
14
- try:
15
- from diffusers.models.unet_2d_condition import UNet2DConditionOutput
16
- except:
17
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
18
-
19
- try:
20
- from diffusers.models.cross_attention import CrossAttention
21
- except:
22
- from diffusers.models.attention_processor import Attention as CrossAttention
23
-
24
- MAX_NUM_WORDS = 77
25
- LOW_RESOURCE = False
26
-
27
- class CountingCrossAttnProcessor1:
28
-
29
- def __init__(self, attnstore, place_in_unet):
30
- super().__init__()
31
- self.attnstore = attnstore
32
- self.place_in_unet = place_in_unet
33
-
34
- def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
35
- batch_size, sequence_length, dim = hidden_states.shape
36
- h = attn_layer.heads
37
- q = attn_layer.to_q(hidden_states)
38
- is_cross = encoder_hidden_states is not None
39
- context = encoder_hidden_states if is_cross else hidden_states
40
- k = attn_layer.to_k(context)
41
- v = attn_layer.to_v(context)
42
- # q = attn_layer.reshape_heads_to_batch_dim(q)
43
- # k = attn_layer.reshape_heads_to_batch_dim(k)
44
- # v = attn_layer.reshape_heads_to_batch_dim(v)
45
- # q = attn_layer.head_to_batch_dim(q)
46
- # k = attn_layer.head_to_batch_dim(k)
47
- # v = attn_layer.head_to_batch_dim(v)
48
- q = self.head_to_batch_dim(q, h)
49
- k = self.head_to_batch_dim(k, h)
50
- v = self.head_to_batch_dim(v, h)
51
-
52
- sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale
53
-
54
- if attention_mask is not None:
55
- attention_mask = attention_mask.reshape(batch_size, -1)
56
- max_neg_value = -torch.finfo(sim.dtype).max
57
- attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
58
- sim.masked_fill_(~attention_mask, max_neg_value)
59
-
60
- # attention, what we cannot get enough of
61
- attn_ = sim.softmax(dim=-1).clone()
62
- # softmax = nn.Softmax(dim=-1)
63
- # attn_ = softmax(sim)
64
- self.attnstore(attn_, is_cross, self.place_in_unet)
65
- out = torch.einsum("b i j, b j d -> b i d", attn_, v)
66
- # out = attn_layer.batch_to_head_dim(out)
67
- out = self.batch_to_head_dim(out, h)
68
-
69
- if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
70
- to_out = attn_layer.to_out[0]
71
- else:
72
- to_out = attn_layer.to_out
73
-
74
- out = to_out(out)
75
- return out
76
-
77
- def batch_to_head_dim(self, tensor, head_size):
78
- # head_size = self.heads
79
- batch_size, seq_len, dim = tensor.shape
80
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
81
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
82
- return tensor
83
-
84
- def head_to_batch_dim(self, tensor, head_size, out_dim=3):
85
- # head_size = self.heads
86
- batch_size, seq_len, dim = tensor.shape
87
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
88
- tensor = tensor.permute(0, 2, 1, 3)
89
-
90
- if out_dim == 3:
91
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
92
-
93
- return tensor
94
-
95
-
96
- def register_attention_control(model, controller):
97
-
98
- attn_procs = {}
99
- cross_att_count = 0
100
- for name in model.unet.attn_processors.keys():
101
- cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
102
- if name.startswith("mid_block"):
103
- hidden_size = model.unet.config.block_out_channels[-1]
104
- place_in_unet = "mid"
105
- elif name.startswith("up_blocks"):
106
- block_id = int(name[len("up_blocks.")])
107
- hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
108
- place_in_unet = "up"
109
- elif name.startswith("down_blocks"):
110
- block_id = int(name[len("down_blocks.")])
111
- hidden_size = model.unet.config.block_out_channels[block_id]
112
- place_in_unet = "down"
113
- else:
114
- continue
115
-
116
- cross_att_count += 1
117
- # attn_procs[name] = AttendExciteCrossAttnProcessor(
118
- # attnstore=controller, place_in_unet=place_in_unet
119
- # )
120
- attn_procs[name] = CountingCrossAttnProcessor1(
121
- attnstore=controller, place_in_unet=place_in_unet
122
- )
123
-
124
- model.unet.set_attn_processor(attn_procs)
125
- controller.num_att_layers = cross_att_count
126
-
127
- def register_hier_output(model):
128
- self = model.unet
129
- from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding
130
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
131
- def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
132
- attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None, down_block_additional_residuals=None,
133
- mid_block_additional_residual=None, encoder_attention_mask=None, return_dict=True):
134
-
135
- out_list = []
136
-
137
-
138
- default_overall_up_factor = 2**self.num_upsamplers
139
-
140
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
141
- forward_upsample_size = False
142
- upsample_size = None
143
-
144
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
145
- logger.info("Forward upsample size to force interpolation output size.")
146
- forward_upsample_size = True
147
-
148
- if attention_mask is not None:
149
- # assume that mask is expressed as:
150
- # (1 = keep, 0 = discard)
151
- # convert mask into a bias that can be added to attention scores:
152
- # (keep = +0, discard = -10000.0)
153
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
154
- attention_mask = attention_mask.unsqueeze(1)
155
-
156
- if encoder_attention_mask is not None:
157
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
158
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
159
-
160
- if self.config.center_input_sample:
161
- sample = 2 * sample - 1.0
162
-
163
- timesteps = timestep
164
- if not torch.is_tensor(timesteps):
165
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
166
- # This would be a good case for the `match` statement (Python 3.10+)
167
- is_mps = sample.device.type == "mps"
168
- if isinstance(timestep, float):
169
- dtype = torch.float32 if is_mps else torch.float64
170
- else:
171
- dtype = torch.int32 if is_mps else torch.int64
172
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
173
- elif len(timesteps.shape) == 0:
174
- timesteps = timesteps[None].to(sample.device)
175
-
176
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
177
- timesteps = timesteps.expand(sample.shape[0])
178
-
179
- t_emb = self.time_proj(timesteps)
180
-
181
- t_emb = t_emb.to(dtype=sample.dtype)
182
-
183
- emb = self.time_embedding(t_emb, timestep_cond)
184
- aug_emb = None
185
-
186
- if self.class_embedding is not None:
187
- if class_labels is None:
188
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
189
-
190
- if self.config.class_embed_type == "timestep":
191
- class_labels = self.time_proj(class_labels)
192
-
193
- # `Timesteps` does not contain any weights and will always return f32 tensors
194
- # there might be better ways to encapsulate this.
195
- class_labels = class_labels.to(dtype=sample.dtype)
196
-
197
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
198
-
199
- if self.config.class_embeddings_concat:
200
- emb = torch.cat([emb, class_emb], dim=-1)
201
- else:
202
- emb = emb + class_emb
203
-
204
- if self.config.addition_embed_type == "text":
205
- aug_emb = self.add_embedding(encoder_hidden_states)
206
- elif self.config.addition_embed_type == "text_image":
207
- # Kandinsky 2.1 - style
208
- if "image_embeds" not in added_cond_kwargs:
209
- raise ValueError(
210
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
211
- )
212
-
213
- image_embs = added_cond_kwargs.get("image_embeds")
214
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
215
- aug_emb = self.add_embedding(text_embs, image_embs)
216
- elif self.config.addition_embed_type == "text_time":
217
- # SDXL - style
218
- if "text_embeds" not in added_cond_kwargs:
219
- raise ValueError(
220
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
221
- )
222
- text_embeds = added_cond_kwargs.get("text_embeds")
223
- if "time_ids" not in added_cond_kwargs:
224
- raise ValueError(
225
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
226
- )
227
- time_ids = added_cond_kwargs.get("time_ids")
228
- time_embeds = self.add_time_proj(time_ids.flatten())
229
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
230
-
231
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
232
- add_embeds = add_embeds.to(emb.dtype)
233
- aug_emb = self.add_embedding(add_embeds)
234
- elif self.config.addition_embed_type == "image":
235
- # Kandinsky 2.2 - style
236
- if "image_embeds" not in added_cond_kwargs:
237
- raise ValueError(
238
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
239
- )
240
- image_embs = added_cond_kwargs.get("image_embeds")
241
- aug_emb = self.add_embedding(image_embs)
242
- elif self.config.addition_embed_type == "image_hint":
243
- # Kandinsky 2.2 - style
244
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
245
- raise ValueError(
246
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
247
- )
248
- image_embs = added_cond_kwargs.get("image_embeds")
249
- hint = added_cond_kwargs.get("hint")
250
- aug_emb, hint = self.add_embedding(image_embs, hint)
251
- sample = torch.cat([sample, hint], dim=1)
252
-
253
- emb = emb + aug_emb if aug_emb is not None else emb
254
-
255
- if self.time_embed_act is not None:
256
- emb = self.time_embed_act(emb)
257
-
258
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
259
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
260
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
261
- # Kadinsky 2.1 - style
262
- if "image_embeds" not in added_cond_kwargs:
263
- raise ValueError(
264
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
265
- )
266
-
267
- image_embeds = added_cond_kwargs.get("image_embeds")
268
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
269
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
270
- # Kandinsky 2.2 - style
271
- if "image_embeds" not in added_cond_kwargs:
272
- raise ValueError(
273
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
274
- )
275
- image_embeds = added_cond_kwargs.get("image_embeds")
276
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
277
- # 2. pre-process
278
- sample = self.conv_in(sample) # 1, 320, 64, 64
279
-
280
- # 2.5 GLIGEN position net
281
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
282
- cross_attention_kwargs = cross_attention_kwargs.copy()
283
- gligen_args = cross_attention_kwargs.pop("gligen")
284
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
285
-
286
- # 3. down
287
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
288
-
289
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
290
- is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
291
-
292
- down_block_res_samples = (sample,)
293
-
294
- for downsample_block in self.down_blocks:
295
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
296
- # For t2i-adapter CrossAttnDownBlock2D
297
- additional_residuals = {}
298
- if is_adapter and len(down_block_additional_residuals) > 0:
299
- additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
300
-
301
- sample, res_samples = downsample_block(
302
- hidden_states=sample,
303
- temb=emb,
304
- encoder_hidden_states=encoder_hidden_states,
305
- attention_mask=attention_mask,
306
- cross_attention_kwargs=cross_attention_kwargs,
307
- encoder_attention_mask=encoder_attention_mask,
308
- **additional_residuals,
309
- )
310
- else:
311
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
312
-
313
- if is_adapter and len(down_block_additional_residuals) > 0:
314
- sample += down_block_additional_residuals.pop(0)
315
-
316
- down_block_res_samples += res_samples
317
-
318
- if is_controlnet:
319
- new_down_block_res_samples = ()
320
-
321
- for down_block_res_sample, down_block_additional_residual in zip(
322
- down_block_res_samples, down_block_additional_residuals
323
- ):
324
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
325
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
326
-
327
- down_block_res_samples = new_down_block_res_samples
328
-
329
- # 4. mid
330
- if self.mid_block is not None:
331
- sample = self.mid_block(
332
- sample,
333
- emb,
334
- encoder_hidden_states=encoder_hidden_states,
335
- attention_mask=attention_mask,
336
- cross_attention_kwargs=cross_attention_kwargs,
337
- encoder_attention_mask=encoder_attention_mask,
338
- )
339
- # To support T2I-Adapter-XL
340
- if (
341
- is_adapter
342
- and len(down_block_additional_residuals) > 0
343
- and sample.shape == down_block_additional_residuals[0].shape
344
- ):
345
- sample += down_block_additional_residuals.pop(0)
346
-
347
- if is_controlnet:
348
- sample = sample + mid_block_additional_residual
349
-
350
- # 5. up
351
- for i, upsample_block in enumerate(self.up_blocks):
352
- is_final_block = i == len(self.up_blocks) - 1
353
-
354
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
355
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
356
-
357
- # if we have not reached the final block and need to forward the
358
- # upsample size, we do it here
359
- if not is_final_block and forward_upsample_size:
360
- upsample_size = down_block_res_samples[-1].shape[2:]
361
-
362
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
363
- sample = upsample_block(
364
- hidden_states=sample,
365
- temb=emb,
366
- res_hidden_states_tuple=res_samples,
367
- encoder_hidden_states=encoder_hidden_states,
368
- cross_attention_kwargs=cross_attention_kwargs,
369
- upsample_size=upsample_size,
370
- attention_mask=attention_mask,
371
- encoder_attention_mask=encoder_attention_mask,
372
- )
373
- else:
374
- sample = upsample_block(
375
- hidden_states=sample,
376
- temb=emb,
377
- res_hidden_states_tuple=res_samples,
378
- upsample_size=upsample_size,
379
- scale=lora_scale,
380
- )
381
-
382
- # if i in [1, 4, 7]:
383
- out_list.append(sample)
384
-
385
- # 6. post-process
386
- if self.conv_norm_out:
387
- sample = self.conv_norm_out(sample)
388
- sample = self.conv_act(sample)
389
- sample = self.conv_out(sample)
390
-
391
- if not return_dict:
392
- return (sample,)
393
-
394
- return UNet2DConditionOutput(sample=sample), out_list
395
-
396
- self.forward = forward
397
-
398
-
399
- class AttentionControl(abc.ABC):
400
-
401
- def step_callback(self, x_t):
402
- return x_t
403
-
404
- def between_steps(self):
405
- return
406
-
407
- @property
408
- def num_uncond_att_layers(self):
409
- return 0
410
-
411
- @abc.abstractmethod
412
- def forward(self, attn, is_cross: bool, place_in_unet: str):
413
- raise NotImplementedError
414
-
415
- def __call__(self, attn, is_cross: bool, place_in_unet: str):
416
- if self.cur_att_layer >= self.num_uncond_att_layers:
417
- # self.forward(attn, is_cross, place_in_unet)
418
- if LOW_RESOURCE:
419
- attn = self.forward(attn, is_cross, place_in_unet)
420
- else:
421
- h = attn.shape[0]
422
- attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
423
- self.cur_att_layer += 1
424
- if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
425
- self.cur_att_layer = 0
426
- self.cur_step += 1
427
- self.between_steps()
428
- return attn
429
-
430
- def reset(self):
431
- self.cur_step = 0
432
- self.cur_att_layer = 0
433
-
434
- def __init__(self):
435
- self.cur_step = 0
436
- self.num_att_layers = -1
437
- self.cur_att_layer = 0
438
-
439
-
440
- class EmptyControl(AttentionControl):
441
-
442
- def forward(self, attn, is_cross: bool, place_in_unet: str):
443
- return attn
444
-
445
-
446
- class AttentionStore(AttentionControl):
447
-
448
- @staticmethod
449
- def get_empty_store():
450
- return {"down_cross": [], "mid_cross": [], "up_cross": [],
451
- "down_self": [], "mid_self": [], "up_self": []}
452
-
453
- def forward(self, attn, is_cross: bool, place_in_unet: str):
454
- key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
455
- if attn.shape[1] <= self.max_size ** 2: # avoid memory overhead
456
- self.step_store[key].append(attn)
457
- return attn
458
-
459
- def between_steps(self):
460
- self.attention_store = self.step_store
461
- if self.save_global_store:
462
- with torch.no_grad():
463
- if len(self.global_store) == 0:
464
- self.global_store = self.step_store
465
- else:
466
- for key in self.global_store:
467
- for i in range(len(self.global_store[key])):
468
- self.global_store[key][i] += self.step_store[key][i].detach()
469
- self.step_store = self.get_empty_store()
470
- self.step_store = self.get_empty_store()
471
-
472
- def get_average_attention(self):
473
- average_attention = self.attention_store
474
- return average_attention
475
-
476
- def get_average_global_attention(self):
477
- average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
478
- self.attention_store}
479
- return average_attention
480
-
481
- def reset(self):
482
- super(AttentionStore, self).reset()
483
- self.step_store = self.get_empty_store()
484
- self.attention_store = {}
485
- self.global_store = {}
486
-
487
- def __init__(self, max_size=32, save_global_store=False):
488
- '''
489
- Initialize an empty AttentionStore
490
- :param step_index: used to visualize only a specific step in the diffusion process
491
- '''
492
- super(AttentionStore, self).__init__()
493
- self.save_global_store = save_global_store
494
- self.max_size = max_size
495
- self.step_store = self.get_empty_store()
496
- self.attention_store = {}
497
- self.global_store = {}
498
- self.curr_step_index = 0
499
-
500
- def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
501
- out = []
502
- attention_maps = attention_store.get_average_attention()
503
- num_pixels = res ** 2
504
- for location in from_where:
505
- for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
506
- if item.shape[1] == num_pixels:
507
- cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
508
- out.append(cross_maps)
509
- out = torch.cat(out, dim=0)
510
- out = out.sum(0) / out.shape[0]
511
- return out
512
-
513
-
514
- def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
515
- tokens = tokenizer.encode(prompts[select])
516
- decoder = tokenizer.decode
517
- attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
518
- images = []
519
- for i in range(len(tokens)):
520
- image = attention_maps[:, :, i]
521
- image = 255 * image / image.max()
522
- image = image.unsqueeze(-1).expand(*image.shape, 3)
523
- image = image.numpy().astype(np.uint8)
524
- image = np.array(Image.fromarray(image).resize((256, 256)))
525
- image = text_under_image(image, decoder(int(tokens[i])))
526
- images.append(image)
527
- view_images(np.stack(images, axis=0))
528
-
529
-
530
- def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
531
- max_com=10, select: int = 0):
532
- attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
533
- u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
534
- images = []
535
- for i in range(max_com):
536
- image = vh[i].reshape(res, res)
537
- image = image - image.min()
538
- image = 255 * image / image.max()
539
- image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
540
- image = Image.fromarray(image).resize((256, 256))
541
- image = np.array(image)
542
- images.append(image)
543
- view_images(np.concatenate(images, axis=1))
544
-
545
- def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
546
- h, w, c = image.shape
547
- offset = int(h * .2)
548
- img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
549
- font = cv2.FONT_HERSHEY_SIMPLEX
550
- # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
551
- img[:h] = image
552
- textsize = cv2.getTextSize(text, font, 1, 2)[0]
553
- text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
554
- cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
555
- return img
556
-
557
-
558
- def view_images(images, num_rows=1, offset_ratio=0.02):
559
- if type(images) is list:
560
- num_empty = len(images) % num_rows
561
- elif images.ndim == 4:
562
- num_empty = images.shape[0] % num_rows
563
- else:
564
- images = [images]
565
- num_empty = 0
566
-
567
- empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
568
- images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
569
- num_items = len(images)
570
-
571
- h, w, c = images[0].shape
572
- offset = int(h * offset_ratio)
573
- num_cols = num_items // num_rows
574
- image_ = np.ones((h * num_rows + offset * (num_rows - 1),
575
- w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
576
- for i in range(num_rows):
577
- for j in range(num_cols):
578
- image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
579
- i * num_cols + j]
580
-
581
- pil_img = Image.fromarray(image_)
582
- display(pil_img)
583
-
584
- def self_cross_attn(self_attn, cross_attn):
585
- res = self_attn.shape[0]
586
- assert res == cross_attn.shape[0]
587
- # cross attn [res, res] -> [res*res]
588
- cross_attn_ = cross_attn.reshape([res*res])
589
- # self_attn [res, res, res*res]
590
- self_cross_attn = cross_attn_ * self_attn
591
- self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
592
- return self_cross_attn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
_utils/attn_utils_new.py CHANGED
@@ -19,7 +19,6 @@ try:
19
  from diffusers.models.cross_attention import CrossAttention
20
  except:
21
  from diffusers.models.attention_processor import Attention as CrossAttention
22
- from typing import Any, Dict, List, Optional, Tuple, Union
23
  MAX_NUM_WORDS = 77
24
  LOW_RESOURCE = False
25
 
@@ -512,91 +511,7 @@ def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from
512
  out = out.sum(0) / out.shape[0]
513
  return out
514
 
515
- def aggregate_attention1(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
516
- out = []
517
- attention_maps = attention_store.get_average_attention()
518
- num_pixels = res ** 2
519
- for location in from_where:
520
- for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
521
- if item.shape[1] == num_pixels:
522
- cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
523
- out.append(cross_maps)
524
- # out = torch.cat(out, dim=0)
525
- # out = out.sum(0) / out.shape[0]
526
- out = out[1]
527
- out = out.sum(0) / out.shape[0]
528
- return out
529
-
530
-
531
- def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
532
- tokens = tokenizer.encode(prompts[select])
533
- decoder = tokenizer.decode
534
- attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
535
- images = []
536
- for i in range(len(tokens)):
537
- image = attention_maps[:, :, i]
538
- image = 255 * image / image.max()
539
- image = image.unsqueeze(-1).expand(*image.shape, 3)
540
- image = image.numpy().astype(np.uint8)
541
- image = np.array(Image.fromarray(image).resize((256, 256)))
542
- image = text_under_image(image, decoder(int(tokens[i])))
543
- images.append(image)
544
- view_images(np.stack(images, axis=0))
545
-
546
 
547
- def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
548
- max_com=10, select: int = 0):
549
- attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
550
- u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
551
- images = []
552
- for i in range(max_com):
553
- image = vh[i].reshape(res, res)
554
- image = image - image.min()
555
- image = 255 * image / image.max()
556
- image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
557
- image = Image.fromarray(image).resize((256, 256))
558
- image = np.array(image)
559
- images.append(image)
560
- view_images(np.concatenate(images, axis=1))
561
-
562
- def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
563
- h, w, c = image.shape
564
- offset = int(h * .2)
565
- img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
566
- font = cv2.FONT_HERSHEY_SIMPLEX
567
- # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
568
- img[:h] = image
569
- textsize = cv2.getTextSize(text, font, 1, 2)[0]
570
- text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
571
- cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
572
- return img
573
-
574
-
575
- def view_images(images, num_rows=1, offset_ratio=0.02):
576
- if type(images) is list:
577
- num_empty = len(images) % num_rows
578
- elif images.ndim == 4:
579
- num_empty = images.shape[0] % num_rows
580
- else:
581
- images = [images]
582
- num_empty = 0
583
-
584
- empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
585
- images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
586
- num_items = len(images)
587
-
588
- h, w, c = images[0].shape
589
- offset = int(h * offset_ratio)
590
- num_cols = num_items // num_rows
591
- image_ = np.ones((h * num_rows + offset * (num_rows - 1),
592
- w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
593
- for i in range(num_rows):
594
- for j in range(num_cols):
595
- image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
596
- i * num_cols + j]
597
-
598
- pil_img = Image.fromarray(image_)
599
- display(pil_img)
600
 
601
  def self_cross_attn(self_attn, cross_attn):
602
  cross_attn = cross_attn.squeeze()
 
19
  from diffusers.models.cross_attention import CrossAttention
20
  except:
21
  from diffusers.models.attention_processor import Attention as CrossAttention
 
22
  MAX_NUM_WORDS = 77
23
  LOW_RESOURCE = False
24
 
 
511
  out = out.sum(0) / out.shape[0]
512
  return out
513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
  def self_cross_attn(self_attn, cross_attn):
517
  cross_attn = cross_attn.squeeze()
_utils/load_models.py CHANGED
@@ -6,11 +6,7 @@ import torch.nn as nn
6
  def load_stable_diffusion_model(config: RunConfig):
7
  device = torch.device('cpu')
8
 
9
- if config.sd_2_1:
10
- stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
11
- else:
12
- stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
13
- # stable = StableCountingPipeline.from_pretrained(stable_diffusion_version).to(device)
14
  stable = StableDiffusionPipeline.from_pretrained(stable_diffusion_version).to(device)
15
  return stable
16
 
 
6
  def load_stable_diffusion_model(config: RunConfig):
7
  device = torch.device('cpu')
8
 
9
+ stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
 
 
 
 
10
  stable = StableDiffusionPipeline.from_pretrained(stable_diffusion_version).to(device)
11
  return stable
12
 
_utils/seg_eval.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
-
3
-
4
- def iou_torch(inst1, inst2):
5
- inter = torch.logical_and(inst1, inst2).sum().float()
6
- union = torch.logical_or(inst1, inst2).sum().float()
7
- if union == 0:
8
- return torch.tensor(float('nan'))
9
- return inter / union
10
-
11
- def get_instances_torch(mask):
12
- # 返回所有非背景的 instance mask(布尔型)
13
- ids = torch.unique(mask)
14
- return [(mask == i) for i in ids if i != 0]
15
-
16
- def compute_instance_miou(pred_mask, gt_mask):
17
- # pred_mask 和 gt_mask 都是 torch.Tensor, shape [H, W], 整数类型
18
- pred_instances = get_instances_torch(pred_mask)
19
- gt_instances = get_instances_torch(gt_mask)
20
-
21
- ious = []
22
- for gt in gt_instances:
23
- best_iou = torch.tensor(0.0).to(pred_mask.device)
24
- for pred in pred_instances:
25
- i = iou_torch(pred, gt)
26
- if i > best_iou:
27
- best_iou = i
28
- ious.append(best_iou)
29
-
30
- # 处理空情况
31
- if len(ious) == 0:
32
- return torch.tensor(float('nan'))
33
- return torch.nanmean(torch.stack(ious))
34
-
35
- from torch import Tensor
36
-
37
-
38
- def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
39
- # Average of Dice coefficient for all batches, or for a single mask
40
- assert input.size() == target.size()
41
- assert input.dim() == 3 or not reduce_batch_first
42
-
43
- sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
44
-
45
- inter = 2 * (input * target).sum(dim=sum_dim)
46
- sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
47
- sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
48
-
49
- dice = (inter + epsilon) / (sets_sum + epsilon)
50
- return dice.mean()
51
-
52
-
53
- def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
54
- # Average of Dice coefficient for all classes
55
- return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)
56
-
57
-
58
- def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
59
- # Dice loss (objective to minimize) between 0 and 1
60
- fn = multiclass_dice_coeff if multiclass else dice_coeff
61
- return 1 - fn(input, target, reduce_batch_first=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py CHANGED
@@ -7,8 +7,6 @@ from typing import Dict, List
7
  class RunConfig:
8
  # Guiding text prompt
9
  prompt: str = "<task-prompt>"
10
- # Whether to use Stable Diffusion v2.1
11
- sd_2_1: bool = False
12
  # Which token indices to alter with attend-and-excite
13
  token_indices: List[int] = field(default_factory=lambda: [2,5])
14
  # Which random seeds to use when generating
 
7
  class RunConfig:
8
  # Guiding text prompt
9
  prompt: str = "<task-prompt>"
 
 
10
  # Which token indices to alter with attend-and-excite
11
  token_indices: List[int] = field(default_factory=lambda: [2,5])
12
  # Which random seeds to use when generating
counting.py CHANGED
@@ -12,19 +12,16 @@ from PIL import Image
12
  import numpy as np
13
  from config import RunConfig
14
  from _utils import attn_utils_new as attn_utils
15
- from _utils.attn_utils import AttentionStore
16
  from _utils.misc_helper import *
17
  import torch.nn.functional as F
18
  import matplotlib.pyplot as plt
19
  import cv2
20
  import warnings
21
- from pytorch_lightning.callbacks import ModelCheckpoint
22
  warnings.filterwarnings("ignore", category=UserWarning)
23
  import pytorch_lightning as pl
24
  from _utils.load_models import load_stable_diffusion_model
25
  from models.model import Counting_with_SD_features_loca as Counting
26
- from pytorch_lightning.loggers import WandbLogger
27
- from models.enc_model.loca_args import get_argparser as loca_get_argparser
28
  from models.enc_model.loca import build_model as build_loca_model
29
  import time
30
  import torchvision.transforms as T
@@ -44,12 +41,7 @@ class CountingModule(pl.LightningModule):
44
  def initialize_model(self):
45
 
46
  # load loca model
47
- loca_args = loca_get_argparser().parse_args()
48
- self.loca_model = build_loca_model(loca_args)
49
- # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
50
- # weights = {k.replace("module","") : v for k, v in weights.items()}
51
- # self.loca_model.load_state_dict(weights, strict=False)
52
- # del weights
53
 
54
  self.counting_adapter = Counting(scale_factor=SCALE)
55
  # if os.path.isfile(self.args.adapter_weight):
 
12
  import numpy as np
13
  from config import RunConfig
14
  from _utils import attn_utils_new as attn_utils
15
+ from _utils.attn_utils_new import AttentionStore
16
  from _utils.misc_helper import *
17
  import torch.nn.functional as F
18
  import matplotlib.pyplot as plt
19
  import cv2
20
  import warnings
 
21
  warnings.filterwarnings("ignore", category=UserWarning)
22
  import pytorch_lightning as pl
23
  from _utils.load_models import load_stable_diffusion_model
24
  from models.model import Counting_with_SD_features_loca as Counting
 
 
25
  from models.enc_model.loca import build_model as build_loca_model
26
  import time
27
  import torchvision.transforms as T
 
41
  def initialize_model(self):
42
 
43
  # load loca model
44
+ self.loca_model = build_loca_model()
 
 
 
 
 
45
 
46
  self.counting_adapter = Counting(scale_factor=SCALE)
47
  # if os.path.isfile(self.args.adapter_weight):
models/enc_model/loca.py CHANGED
@@ -78,12 +78,6 @@ class LOCA(nn.Module):
78
  nn.LayerNorm((64, 64))
79
  )
80
 
81
- # self.fuse1 = nn.Sequential(
82
- # nn.Conv2d(322, 256, kernel_size=1, stride=1),
83
- # nn.LeakyReLU(),
84
- # nn.LayerNorm((64, 64))
85
- # )
86
-
87
  def forward_before_reg(self, x, bboxes):
88
  num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects
89
  # backbone
@@ -105,7 +99,6 @@ class LOCA(nn.Module):
105
 
106
  all_prototypes = self.ope(f_e, pos_emb, bboxes) # [3, 27, 1, 256]
107
 
108
- outputs = list()
109
  response_maps_list = []
110
  for i in range(all_prototypes.size(0)):
111
  prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
@@ -122,18 +115,10 @@ class LOCA(nn.Module):
122
  bs, num_objects, self.emb_dim, h, w
123
  ).max(dim=1)[0]
124
 
125
- # # send through regression heads
126
- # if i == all_prototypes.size(0) - 1:
127
- # predicted_dmaps = self.regression_head(response_maps)
128
- # else:
129
- # predicted_dmaps = self.aux_heads[i](response_maps)
130
- # outputs.append(predicted_dmaps)
131
  response_maps_list.append(response_maps)
132
 
133
  out = {
134
- # "pred": outputs[-1],
135
  "feature_bf_regression": response_maps_list[-1],
136
- # "aux_pred": outputs[:-1],
137
  "aux_feature_bf_regression": response_maps_list[:-1]
138
  }
139
 
@@ -162,71 +147,61 @@ class LOCA(nn.Module):
162
 
163
  return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
164
 
165
- def forward_reg1(self, response_maps, self_attn):
166
- # attn_stack = self.attn_norm(attn_stack)
167
- # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
168
- # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
169
- # unet_feature = unet_feature * attn_stack_mean
170
- # if unet_feature.shape[1] == 322:
171
- # unet_feature = self.fuse1(unet_feature)
172
- # else:
173
- # unet_feature = self.fuse(unet_feature)
174
-
175
-
176
-
177
- response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
178
-
179
- outputs = []
180
- for i in range(len(response_maps)):
181
- response_map = response_maps[i] + self_attn
182
- if i == len(response_maps) - 1:
183
- predicted_dmaps = self.regression_head(response_map)
184
- else:
185
- predicted_dmaps = self.aux_heads[i](response_map)
186
- outputs.append(predicted_dmaps)
187
 
188
- return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
189
 
190
- def forward_reg_without_unet(self, response_maps, attn_stack):
191
- # attn_stack = self.attn_norm(attn_stack)
192
- attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
193
-
194
- response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
195
-
196
- outputs = []
197
- for i in range(len(response_maps)):
198
- response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i]
199
- if i == len(response_maps) - 1:
200
- predicted_dmaps = self.regression_head(response_map)
201
- else:
202
- predicted_dmaps = self.aux_heads[i](response_map)
203
- outputs.append(predicted_dmaps)
204
 
205
- return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
206
-
207
-
208
- def build_model(args):
209
 
210
- assert args.backbone in ['resnet18', 'resnet50', 'resnet101']
211
- assert args.reduction in [4, 8, 16]
212
 
 
 
 
 
 
213
  return LOCA(
214
- image_size=args.image_size,
215
- num_encoder_layers=args.num_enc_layers,
216
- num_ope_iterative_steps=args.num_ope_iterative_steps,
217
- num_objects=args.num_objects,
218
- zero_shot=args.zero_shot,
219
- emb_dim=args.emb_dim,
220
- num_heads=args.num_heads,
221
- kernel_dim=args.kernel_dim,
222
- backbone_name=args.backbone,
223
- swav_backbone=args.swav_backbone,
224
- train_backbone=args.backbone_lr > 0,
225
- reduction=args.reduction,
226
- dropout=args.dropout,
227
  layer_norm_eps=1e-5,
228
  mlp_factor=8,
229
- norm_first=args.pre_norm,
230
  activation=nn.GELU,
231
  norm=True,
232
  )
 
78
  nn.LayerNorm((64, 64))
79
  )
80
 
 
 
 
 
 
 
81
  def forward_before_reg(self, x, bboxes):
82
  num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects
83
  # backbone
 
99
 
100
  all_prototypes = self.ope(f_e, pos_emb, bboxes) # [3, 27, 1, 256]
101
 
 
102
  response_maps_list = []
103
  for i in range(all_prototypes.size(0)):
104
  prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
 
115
  bs, num_objects, self.emb_dim, h, w
116
  ).max(dim=1)[0]
117
 
 
 
 
 
 
 
118
  response_maps_list.append(response_maps)
119
 
120
  out = {
 
121
  "feature_bf_regression": response_maps_list[-1],
 
122
  "aux_feature_bf_regression": response_maps_list[:-1]
123
  }
124
 
 
147
 
148
  return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
149
 
150
+ # def forward_reg1(self, response_maps, self_attn):
151
+
152
+ # response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
153
+
154
+ # outputs = []
155
+ # for i in range(len(response_maps)):
156
+ # response_map = response_maps[i] + self_attn
157
+ # if i == len(response_maps) - 1:
158
+ # predicted_dmaps = self.regression_head(response_map)
159
+ # else:
160
+ # predicted_dmaps = self.aux_heads[i](response_map)
161
+ # outputs.append(predicted_dmaps)
 
 
 
 
 
 
 
 
 
 
162
 
163
+ # return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
164
 
165
+ # def forward_reg_without_unet(self, response_maps, attn_stack):
166
+ # # attn_stack = self.attn_norm(attn_stack)
167
+ # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
168
+
169
+ # response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
170
+
171
+ # outputs = []
172
+ # for i in range(len(response_maps)):
173
+ # response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i]
174
+ # if i == len(response_maps) - 1:
175
+ # predicted_dmaps = self.regression_head(response_map)
176
+ # else:
177
+ # predicted_dmaps = self.aux_heads[i](response_map)
178
+ # outputs.append(predicted_dmaps)
179
 
180
+ # return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
 
 
 
181
 
 
 
182
 
183
+ def build_model():
184
+ """
185
+ Build LOCA with a fixed configuration based on defaults in `loca_args.py`.
186
+ The `args` parameter is accepted for backward compatibility but ignored.
187
+ """
188
  return LOCA(
189
+ image_size=512,
190
+ num_encoder_layers=3,
191
+ num_ope_iterative_steps=3,
192
+ num_objects=3,
193
+ zero_shot=False,
194
+ emb_dim=256,
195
+ num_heads=8,
196
+ kernel_dim=3,
197
+ backbone_name="resnet50",
198
+ swav_backbone=True,
199
+ train_backbone=False, # backbone_lr default is 0 in loca_args.py
200
+ reduction=8,
201
+ dropout=0.1,
202
  layer_norm_eps=1e-5,
203
  mlp_factor=8,
204
+ norm_first=True,
205
  activation=nn.GELU,
206
  norm=True,
207
  )
models/enc_model/loca_args.py DELETED
@@ -1,44 +0,0 @@
1
- import argparse
2
-
3
-
4
- def get_argparser():
5
-
6
- parser = argparse.ArgumentParser("LOCA parser", add_help=False)
7
-
8
- parser.add_argument('--model_name', default='loca_few_shot', type=str)
9
- parser.add_argument(
10
- '--data_path',
11
- default='./data/FSC147_384_V2',
12
- type=str
13
- )
14
- parser.add_argument(
15
- '--model_path',
16
- default='ckpt',
17
- type=str
18
- )
19
- parser.add_argument('--backbone', default='resnet50', type=str)
20
- parser.add_argument('--swav_backbone', action='store_true', default=True)
21
- parser.add_argument('--reduction', default=8, type=int)
22
- parser.add_argument('--image_size', default=512, type=int)
23
- parser.add_argument('--num_enc_layers', default=3, type=int)
24
- parser.add_argument('--num_ope_iterative_steps', default=3, type=int)
25
- parser.add_argument('--emb_dim', default=256, type=int)
26
- parser.add_argument('--num_heads', default=8, type=int)
27
- parser.add_argument('--kernel_dim', default=3, type=int)
28
- parser.add_argument('--num_objects', default=3, type=int)
29
- parser.add_argument('--epochs', default=200, type=int)
30
- parser.add_argument('--resume_training', action='store_true')
31
- parser.add_argument('--lr', default=1e-4, type=float)
32
- parser.add_argument('--backbone_lr', default=0, type=float)
33
- parser.add_argument('--lr_drop', default=200, type=int)
34
- parser.add_argument('--weight_decay', default=1e-4, type=float)
35
- parser.add_argument('--batch_size', default=1, type=int)
36
- parser.add_argument('--dropout', default=0.1, type=float)
37
- parser.add_argument('--num_workers', default=8, type=int)
38
- parser.add_argument('--max_grad_norm', default=0.1, type=float)
39
- parser.add_argument('--aux_weight', default=0.3, type=float)
40
- parser.add_argument('--tiling_p', default=0.5, type=float)
41
- parser.add_argument('--zero_shot', action='store_true')
42
- parser.add_argument('--pre_norm', action='store_true', default=True)
43
-
44
- return parser
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/enc_model/regression_head.py CHANGED
@@ -55,38 +55,3 @@ class DensityMapRegressor(nn.Module):
55
  nn.init.constant_(module.bias, 0)
56
 
57
 
58
- class DensityMapRegressor_(nn.Module):
59
-
60
- def __init__(self, in_channels, reduction):
61
-
62
- super(DensityMapRegressor, self).__init__()
63
-
64
- if reduction == 8:
65
- self.regressor = nn.Sequential(
66
- UpsamplingLayer(in_channels, 128),
67
- UpsamplingLayer(128, 64),
68
- UpsamplingLayer(64, 32),
69
- nn.Conv2d(32, 1, kernel_size=1),
70
- nn.LeakyReLU()
71
- )
72
- elif reduction == 16:
73
- self.regressor = nn.Sequential(
74
- UpsamplingLayer(in_channels, 128),
75
- UpsamplingLayer(128, 64),
76
- UpsamplingLayer(64, 32),
77
- UpsamplingLayer(32, 16),
78
- nn.Conv2d(16, 1, kernel_size=1),
79
- nn.LeakyReLU()
80
- )
81
-
82
- self.reset_parameters()
83
-
84
- def forward(self, x):
85
- return self.regressor(x)
86
-
87
- def reset_parameters(self):
88
- for module in self.modules():
89
- if isinstance(module, nn.Conv2d):
90
- nn.init.normal_(module.weight, std=0.01)
91
- if module.bias is not None:
92
- nn.init.constant_(module.bias, 0)
 
55
  nn.init.constant_(module.bias, 0)
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/tra_post_model/trackastra/model/model.py CHANGED
@@ -139,138 +139,6 @@ class DecoderLayer(nn.Module):
139
  return x
140
 
141
 
142
- # class BidirectionalRelativePositionalAttention(RelativePositionalAttention):
143
- # def forward(
144
- # self,
145
- # query1: torch.Tensor,
146
- # query2: torch.Tensor,
147
- # coords: torch.Tensor,
148
- # padding_mask: torch.Tensor = None,
149
- # ):
150
- # B, N, D = query1.size()
151
- # q1 = self.q_pro(query1) # (B, N, D)
152
- # q2 = self.q_pro(query2) # (B, N, D)
153
- # v1 = self.v_pro(query1) # (B, N, D)
154
- # v2 = self.v_pro(query2) # (B, N, D)
155
-
156
- # # (B, nh, N, hs)
157
- # q1 = q1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
158
- # v1 = v1.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
159
- # q2 = q2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
160
- # v2 = v2.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
161
-
162
- # attn_mask = torch.zeros(
163
- # (B, self.n_head, N, N), device=query1.device, dtype=q1.dtype
164
- # )
165
-
166
- # # add negative value but not too large to keep mixed precision loss from becoming nan
167
- # attn_ignore_val = -1e3
168
-
169
- # # spatial cutoff
170
- # yx = coords[..., 1:]
171
- # spatial_dist = torch.cdist(yx, yx)
172
- # spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
173
- # attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
174
-
175
- # # dont add positional bias to self-attention if coords is None
176
- # if coords is not None:
177
- # if self._mode == "bias":
178
- # attn_mask = attn_mask + self.pos_bias(coords)
179
- # elif self._mode == "rope":
180
- # q1, q2 = self.rot_pos_enc(q1, q2, coords)
181
- # else:
182
- # pass
183
-
184
- # dist = torch.cdist(coords, coords, p=2)
185
- # attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
186
-
187
- # # if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens)
188
- # if padding_mask is not None:
189
- # ignore_mask = torch.logical_or(
190
- # padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
191
- # ).unsqueeze(1)
192
- # attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
193
-
194
- # self.attn_mask = attn_mask.clone()
195
-
196
- # y1 = nn.functional.scaled_dot_product_attention(
197
- # q1,
198
- # q2,
199
- # v1,
200
- # attn_mask=attn_mask,
201
- # dropout_p=self.dropout if self.training else 0,
202
- # )
203
- # y2 = nn.functional.scaled_dot_product_attention(
204
- # q2,
205
- # q1,
206
- # v2,
207
- # attn_mask=attn_mask,
208
- # dropout_p=self.dropout if self.training else 0,
209
- # )
210
-
211
- # y1 = y1.transpose(1, 2).contiguous().view(B, N, D)
212
- # y1 = self.proj(y1)
213
- # y2 = y2.transpose(1, 2).contiguous().view(B, N, D)
214
- # y2 = self.proj(y2)
215
- # return y1, y2
216
-
217
-
218
- # class BidirectionalCrossAttention(nn.Module):
219
- # def __init__(
220
- # self,
221
- # coord_dim: int = 2,
222
- # d_model=256,
223
- # num_heads=4,
224
- # dropout=0.1,
225
- # window: int = 16,
226
- # cutoff_spatial: int = 256,
227
- # positional_bias: Literal["bias", "rope", "none"] = "bias",
228
- # positional_bias_n_spatial: int = 32,
229
- # ):
230
- # super().__init__()
231
- # self.positional_bias = positional_bias
232
- # self.attn = BidirectionalRelativePositionalAttention(
233
- # coord_dim,
234
- # d_model,
235
- # num_heads,
236
- # cutoff_spatial=cutoff_spatial,
237
- # n_spatial=positional_bias_n_spatial,
238
- # cutoff_temporal=window,
239
- # n_temporal=window,
240
- # dropout=dropout,
241
- # mode=positional_bias,
242
- # )
243
-
244
- # self.mlp = FeedForward(d_model)
245
- # self.norm1 = nn.LayerNorm(d_model)
246
- # self.norm2 = nn.LayerNorm(d_model)
247
-
248
- # def forward(
249
- # self,
250
- # x: torch.Tensor,
251
- # y: torch.Tensor,
252
- # coords: torch.Tensor,
253
- # padding_mask: torch.Tensor = None,
254
- # ):
255
- # x = self.norm1(x)
256
- # y = self.norm1(y)
257
-
258
- # # cross attention
259
- # # setting coords to None disables positional bias
260
- # x2, y2 = self.attn(
261
- # x,
262
- # y,
263
- # coords=coords if self.positional_bias else None,
264
- # padding_mask=padding_mask,
265
- # )
266
- # # print(torch.norm(x2).item()/torch.norm(x).item())
267
- # x = x + x2
268
- # x = x + self.mlp(self.norm2(x))
269
- # y = y + y2
270
- # y = y + self.mlp(self.norm2(y))
271
-
272
- # return x, y
273
-
274
 
275
  class TrackingTransformer(torch.nn.Module):
276
  def __init__(
 
139
  return x
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  class TrackingTransformer(torch.nn.Module):
144
  def __init__(
segmentation.py CHANGED
@@ -8,7 +8,7 @@ from PIL import Image
8
  import numpy as np
9
  from config import RunConfig
10
  from _utils import attn_utils_new as attn_utils
11
- from _utils.attn_utils import AttentionStore
12
  from _utils.misc_helper import *
13
  import torch.nn.functional as F
14
  import logging
@@ -20,10 +20,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
20
  import pytorch_lightning as pl
21
  from _utils.load_models import load_stable_diffusion_model
22
  from models.model import Counting_with_SD_features_dino_vit_c3 as Counting
23
- from models.enc_model.loca_args import get_argparser as loca_get_argparser
24
  from models.enc_model.loca import build_model as build_loca_model
25
  import time
26
- from _utils.seg_eval import *
27
  from models.seg_post_model import metrics
28
  from datetime import datetime
29
  import json
@@ -49,8 +47,7 @@ class SegmentationModule(pl.LightningModule):
49
  def initialize_model(self):
50
 
51
  # load loca model
52
- loca_args = loca_get_argparser().parse_args()
53
- self.loca_model = build_loca_model(loca_args)
54
  self.loca_model.eval()
55
 
56
  self.counting_adapter = Counting(scale_factor=SCALE)
 
8
  import numpy as np
9
  from config import RunConfig
10
  from _utils import attn_utils_new as attn_utils
11
+ from _utils.attn_utils_new import AttentionStore
12
  from _utils.misc_helper import *
13
  import torch.nn.functional as F
14
  import logging
 
20
  import pytorch_lightning as pl
21
  from _utils.load_models import load_stable_diffusion_model
22
  from models.model import Counting_with_SD_features_dino_vit_c3 as Counting
 
23
  from models.enc_model.loca import build_model as build_loca_model
24
  import time
 
25
  from models.seg_post_model import metrics
26
  from datetime import datetime
27
  import json
 
47
  def initialize_model(self):
48
 
49
  # load loca model
50
+ self.loca_model = build_loca_model()
 
51
  self.loca_model.eval()
52
 
53
  self.counting_adapter = Counting(scale_factor=SCALE)
tracking_one.py CHANGED
@@ -13,13 +13,9 @@ import tifffile
13
  import skimage.io as io
14
  from config import RunConfig
15
  from _utils import attn_utils_new as attn_utils
16
- from _utils.attn_utils import AttentionStore
17
  from _utils.misc_helper import *
18
- from torch.autograd import Variable
19
- import itertools
20
- from accelerate import Accelerator
21
  import torch.nn.functional as F
22
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
23
  from tqdm import tqdm
24
  import torch.nn as nn
25
  import matplotlib.pyplot as plt
@@ -29,19 +25,14 @@ warnings.filterwarnings("ignore", category=UserWarning)
29
  import pytorch_lightning as pl
30
  from _utils.load_models import load_stable_diffusion_model
31
  from models.model import Counting_with_SD_features_track as Counting
32
- from models.enc_model.loca_args import get_argparser as loca_get_argparser
33
  from models.enc_model.loca import build_model as build_loca_model
34
  import time
35
- from _utils.seg_eval import *
36
- from models.tra_post_model.trackastra.model import Trackastra
37
  from models.tra_post_model.trackastra.model import TrackingTransformer
38
  from models.tra_post_model.trackastra.utils import (
39
- blockwise_causal_norm,
40
- blockwise_sum,
41
  normalize,
42
  )
43
- from models.tra_post_model.trackastra.data import build_windows_sd, get_features, load_tiff_timeseries
44
- from models.tra_post_model.trackastra.tracking import TrackGraph, build_graph, track_greedy, graph_to_ctc
45
  from _utils.track_args import parse_train_args as get_track_args
46
  import torchvision.transforms as T
47
  from pathlib import Path
@@ -90,8 +81,7 @@ class TrackingModule(pl.LightningModule):
90
  def initialize_model(self):
91
 
92
  # load loca model
93
- loca_args = loca_get_argparser().parse_args()
94
- self.loca_model = build_loca_model(loca_args)
95
  # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
96
  # weights = {k.replace("module","") : v for k, v in weights.items()}
97
  # self.loca_model.load_state_dict(weights, strict=False)
@@ -985,7 +975,6 @@ class TrackingModule(pl.LightningModule):
985
 
986
  self.eval()
987
  imgs, imgs_raw, images_stable, tra_imgs, imgs_01, height, width = load_track_images(file_dir)
988
- # tra_imgs = torch.from_numpy(imgs_).float().to(self.device)
989
  imgs_stable = torch.from_numpy(images_stable).float().to(self.device)
990
  imgs_enc = torch.from_numpy(imgs).float().to(self.device)
991
 
@@ -1032,37 +1021,31 @@ class TrackingModule(pl.LightningModule):
1032
  )
1033
  track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
1034
 
1035
- # ctc_tracks, masks_tracked = graph_to_ctc(
1036
- # track_graph,
1037
- # masks,
1038
- # outdir=f"tracked/{dataname}",
1039
- # )
1040
-
1041
  return track_graph, masks
1042
 
1043
 
1044
 
1045
- def inference(data_path, box=None):
1046
- if box is not None:
1047
- use_box = True
1048
- else:
1049
- use_box = False
1050
 
1051
- model = TrackingModule(use_box=use_box)
1052
- load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_tra.pth"), strict=True)
1053
 
1054
- model.move_to_device(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
1055
 
1056
 
1057
- track_graph, masks = model.track(file_dir=data_path, dataname="inference_sequence")
1058
 
1059
- if not os.path.exists(f"tracked_ours_seg_pred3/"):
1060
- os.makedirs(f"tracked_ours_seg_pred3/")
1061
- ctc_tracks, masks_tracked = graph_to_ctc(
1062
- track_graph,
1063
- masks,
1064
- outdir=f"tracked_ours_seg_pred3/",
1065
- )
1066
 
1067
- if __name__ == "__main__":
1068
- inference(data_path="example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02")
 
13
  import skimage.io as io
14
  from config import RunConfig
15
  from _utils import attn_utils_new as attn_utils
16
+ from _utils.attn_utils_new import AttentionStore
17
  from _utils.misc_helper import *
 
 
 
18
  import torch.nn.functional as F
 
19
  from tqdm import tqdm
20
  import torch.nn as nn
21
  import matplotlib.pyplot as plt
 
25
  import pytorch_lightning as pl
26
  from _utils.load_models import load_stable_diffusion_model
27
  from models.model import Counting_with_SD_features_track as Counting
 
28
  from models.enc_model.loca import build_model as build_loca_model
29
  import time
 
 
30
  from models.tra_post_model.trackastra.model import TrackingTransformer
31
  from models.tra_post_model.trackastra.utils import (
 
 
32
  normalize,
33
  )
34
+ from models.tra_post_model.trackastra.data import build_windows_sd, get_features
35
+ from models.tra_post_model.trackastra.tracking import TrackGraph, build_graph, track_greedy
36
  from _utils.track_args import parse_train_args as get_track_args
37
  import torchvision.transforms as T
38
  from pathlib import Path
 
81
  def initialize_model(self):
82
 
83
  # load loca model
84
+ self.loca_model = build_loca_model()
 
85
  # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
86
  # weights = {k.replace("module","") : v for k, v in weights.items()}
87
  # self.loca_model.load_state_dict(weights, strict=False)
 
975
 
976
  self.eval()
977
  imgs, imgs_raw, images_stable, tra_imgs, imgs_01, height, width = load_track_images(file_dir)
 
978
  imgs_stable = torch.from_numpy(images_stable).float().to(self.device)
979
  imgs_enc = torch.from_numpy(imgs).float().to(self.device)
980
 
 
1021
  )
1022
  track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
1023
 
 
 
 
 
 
 
1024
  return track_graph, masks
1025
 
1026
 
1027
 
1028
+ # def inference(data_path, box=None):
1029
+ # if box is not None:
1030
+ # use_box = True
1031
+ # else:
1032
+ # use_box = False
1033
 
1034
+ # model = TrackingModule(use_box=use_box)
1035
+ # load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_tra.pth"), strict=True)
1036
 
1037
+ # model.move_to_device(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
1038
 
1039
 
1040
+ # track_graph, masks = model.track(file_dir=data_path, dataname="inference_sequence")
1041
 
1042
+ # if not os.path.exists(f"tracked_ours_seg_pred3/"):
1043
+ # os.makedirs(f"tracked_ours_seg_pred3/")
1044
+ # ctc_tracks, masks_tracked = graph_to_ctc(
1045
+ # track_graph,
1046
+ # masks,
1047
+ # outdir=f"tracked_ours_seg_pred3/",
1048
+ # )
1049
 
1050
+ # if __name__ == "__main__":
1051
+ # inference(data_path="example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02")