VisionLanguageGroup commited on
Commit
f10f497
ยท
1 Parent(s): 102cd7d
app.py CHANGED
@@ -314,7 +314,7 @@ def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
314
 
315
  try:
316
  mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
317
- print("๐Ÿ“ mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
318
  except Exception as e:
319
  print(f"โŒ Inference failed: {str(e)}")
320
  return None, None, {}
@@ -344,8 +344,6 @@ def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
344
  inst_mask = mask_np.astype(np.int32)
345
  unique_ids = np.unique(inst_mask)
346
  num_instances = len(unique_ids[unique_ids != 0])
347
- print(f"โœ… Instance IDs found: {unique_ids}, Total instances: {num_instances}")
348
-
349
  if num_instances == 0:
350
  print("โš ๏ธ No instance found, returning dummy red image")
351
  return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None, {}
 
314
 
315
  try:
316
  mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
317
+ print("๐Ÿ“ mask shape:", mask.shape, "dtype:", mask.dtype)
318
  except Exception as e:
319
  print(f"โŒ Inference failed: {str(e)}")
320
  return None, None, {}
 
344
  inst_mask = mask_np.astype(np.int32)
345
  unique_ids = np.unique(inst_mask)
346
  num_instances = len(unique_ids[unique_ids != 0])
 
 
347
  if num_instances == 0:
348
  print("โš ๏ธ No instance found, returning dummy red image")
349
  return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None, {}
models/model.py CHANGED
@@ -5,35 +5,10 @@ import os
5
  import clip
6
  import sys
7
  import numpy as np
8
- from models.seg_post_model.cellpose.models import CellposeModel
9
 
10
  from torchvision.ops import roi_align
11
- def crop_roi_feat(feat, boxes):
12
- """
13
- feat: 1 x c x h x w
14
- boxes: m x 4, 4: [y_tl, x_tl, y_br, x_br]
15
- """
16
- _, _, h, w = feat.shape
17
- out_stride = 512 / h
18
- boxes_scaled = boxes / out_stride
19
- boxes_scaled[:, :2] = torch.floor(boxes_scaled[:, :2]) # y_tl, x_tl: floor
20
- boxes_scaled[:, 2:] = torch.ceil(boxes_scaled[:, 2:]) # y_br, x_br: ceil
21
- boxes_scaled[:, :2] = torch.clamp_min(boxes_scaled[:, :2], 0)
22
- boxes_scaled[:, 2] = torch.clamp_max(boxes_scaled[:, 2], h)
23
- boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], w)
24
- feat_boxes = []
25
- for idx_box in range(0, boxes.shape[0]):
26
- y_tl, x_tl, y_br, x_br = boxes_scaled[idx_box]
27
- y_tl, x_tl, y_br, x_br = int(y_tl), int(x_tl), int(y_br), int(x_br)
28
- feat_box = feat[:, :, y_tl : (y_br + 1), x_tl : (x_br + 1)]
29
- feat_boxes.append(feat_box)
30
- return feat_boxes
31
 
32
- class Counting_with_SD_features(nn.Module):
33
- def __init__(self, scale_factor):
34
- super(Counting_with_SD_features, self).__init__()
35
- self.adapter = adapter_roi()
36
- # self.regressor = regressor_with_SD_features()
37
 
38
  class Counting_with_SD_features_loca(nn.Module):
39
  def __init__(self, scale_factor):
@@ -55,60 +30,6 @@ class Counting_with_SD_features_track(nn.Module):
55
  self.regressor = regressor_with_SD_features_tra()
56
 
57
 
58
- class adapter_roi(nn.Module):
59
- def __init__(self, pool_size=[3, 3]):
60
- super(adapter_roi, self).__init__()
61
- self.pool_size = pool_size
62
- self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
63
- # self.relu = nn.ReLU()
64
- # self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
65
- self.pool = nn.MaxPool2d(2)
66
- self.fc = nn.Linear(256 * 3 * 3, 768)
67
- # **new
68
- self.fc1 = nn.Sequential(
69
- nn.ReLU(),
70
- nn.Linear(768, 768 // 4, bias=False),
71
- nn.ReLU()
72
- )
73
- self.fc2 = nn.Sequential(
74
- nn.Linear(768 // 4, 768, bias=False),
75
- # nn.ReLU()
76
- )
77
- self.initialize_weights()
78
-
79
- def forward(self, x, boxes):
80
- num_of_boxes = boxes.shape[1]
81
- rois = []
82
- bs, _, h, w = x.shape
83
- boxes = torch.cat([
84
- torch.arange(
85
- bs, requires_grad=False
86
- ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
87
- boxes.flatten(0, 1),
88
- ], dim=1)
89
- rois = roi_align(
90
- x,
91
- boxes=boxes, output_size=3,
92
- spatial_scale=1.0 / 8, aligned=True
93
- )
94
- rois = torch.mean(rois, dim=0, keepdim=True)
95
- x = self.conv1(rois)
96
- x = x.view(x.size(0), -1)
97
- x = self.fc(x)
98
-
99
- x = self.fc1(x)
100
- x = self.fc2(x)
101
- return x
102
-
103
-
104
- def initialize_weights(self):
105
- for m in self.modules():
106
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
107
- nn.init.xavier_normal_(m.weight)
108
- if m.bias is not None:
109
- nn.init.constant_(m.bias, 0)
110
-
111
-
112
  class adapter_roi_loca(nn.Module):
113
  def __init__(self, pool_size=[3, 3]):
114
  super(adapter_roi_loca, self).__init__()
@@ -188,69 +109,6 @@ class adapter_roi_loca(nn.Module):
188
 
189
 
190
 
191
-
192
- class regressor1(nn.Module):
193
- def __init__(self):
194
- super(regressor1, self).__init__()
195
- self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
196
- self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
197
- self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
198
- self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
199
- self.leaky_relu = nn.LeakyReLU()
200
- self.relu = nn.ReLU()
201
- self.initialize_weights()
202
-
203
-
204
-
205
- def forward(self, x):
206
- x_ = self.conv1(x)
207
- x_ = self.leaky_relu(x_)
208
- x_ = self.upsampler(x_)
209
- x_ = self.conv2(x_)
210
- x_ = self.leaky_relu(x_)
211
- x_ = self.upsampler(x_)
212
- x_ = self.conv3(x_)
213
- x_ = self.relu(x_)
214
- out = x_
215
- return out
216
-
217
- def initialize_weights(self):
218
- for m in self.modules():
219
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
220
- nn.init.xavier_normal_(m.weight)
221
- if m.bias is not None:
222
- nn.init.constant_(m.bias, 0)
223
-
224
-
225
- class regressor1(nn.Module):
226
- def __init__(self):
227
- super(regressor1, self).__init__()
228
- self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
229
- self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
230
- self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
231
- self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
232
- self.leaky_relu = nn.LeakyReLU()
233
- self.relu = nn.ReLU()
234
-
235
- def forward(self, x):
236
- x_ = self.conv1(x)
237
- x_ = self.leaky_relu(x_)
238
- x_ = self.upsampler(x_)
239
- x_ = self.conv2(x_)
240
- x_ = self.leaky_relu(x_)
241
- x_ = self.upsampler(x_)
242
- x_ = self.conv3(x_)
243
- x_ = self.relu(x_)
244
- out = x_
245
- return out
246
- def initialize_weights(self):
247
- for m in self.modules():
248
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
249
- nn.init.xavier_normal_(m.weight)
250
- if m.bias is not None:
251
- nn.init.constant_(m.bias, 0)
252
-
253
-
254
  class regressor_with_SD_features(nn.Module):
255
  def __init__(self):
256
  super(regressor_with_SD_features, self).__init__()
@@ -301,57 +159,6 @@ class regressor_with_SD_features(nn.Module):
301
  if m.bias is not None:
302
  nn.init.constant_(m.bias, 0)
303
 
304
- class regressor_with_SD_features_seg(nn.Module):
305
- def __init__(self):
306
- super(regressor_with_SD_features_seg, self).__init__()
307
- self.layer1 = nn.Sequential(
308
- nn.Conv2d(324, 256, kernel_size=1, stride=1),
309
- nn.LeakyReLU(),
310
- nn.LayerNorm((64, 64))
311
- )
312
- self.layer2 = nn.Sequential(
313
- nn.Conv2d(256, 128, kernel_size=3, padding=1),
314
- nn.LeakyReLU(),
315
- nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
316
- )
317
- self.layer3 = nn.Sequential(
318
- nn.Conv2d(128, 64, kernel_size=3, padding=1),
319
- nn.ReLU(),
320
- nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
321
- )
322
- self.layer4 = nn.Sequential(
323
- nn.Conv2d(64, 32, kernel_size=3, padding=1),
324
- nn.LeakyReLU(),
325
- nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
326
- )
327
- self.conv = nn.Sequential(
328
- nn.Conv2d(32, 2, kernel_size=1),
329
- # nn.ReLU()
330
- )
331
- self.norm = nn.LayerNorm(normalized_shape=(64, 64))
332
- self.initialize_weights()
333
-
334
- def forward(self, attn_stack, feature_list):
335
- attn_stack = self.norm(attn_stack)
336
- unet_feature = feature_list[-1]
337
- attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
338
- unet_feature = unet_feature * attn_stack_mean
339
- unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
340
- x = self.layer1(unet_feature)
341
- x = self.layer2(x)
342
- x = self.layer3(x)
343
- x = self.layer4(x)
344
- out = self.conv(x)
345
- return out
346
-
347
- def initialize_weights(self):
348
- for m in self.modules():
349
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
350
- nn.init.xavier_normal_(m.weight)
351
- if m.bias is not None:
352
- nn.init.constant_(m.bias, 0)
353
-
354
-
355
  from models.enc_model.unet_parts import *
356
 
357
 
@@ -363,7 +170,7 @@ class regressor_with_SD_features_seg_vit_c3(nn.Module):
363
  self.bilinear = bilinear
364
  self.norm = nn.LayerNorm(normalized_shape=(64, 64))
365
  self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
366
- self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
367
  self.vit = self.vit_model.net
368
 
369
  def forward(self, img, attn_stack, feature_list):
@@ -380,7 +187,7 @@ class regressor_with_SD_features_seg_vit_c3(nn.Module):
380
 
381
 
382
 
383
- out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
384
  if out.dtype == np.uint16:
385
  out = out.astype(np.int16)
386
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
@@ -403,12 +210,11 @@ class regressor_with_SD_features_tra(nn.Module):
403
 
404
  # segmentation
405
  self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
406
- self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
407
  self.vit = self.vit_model.net
408
 
409
  self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
410
  self.mlp = nn.Linear(64 * 64, 320)
411
- # self.vit = self.vit_model.net.float()
412
 
413
  def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
414
  attn_stack = attn_stack[:, [1,3], ...]
@@ -422,7 +228,7 @@ class regressor_with_SD_features_tra(nn.Module):
422
  x = self.inc_0(x)
423
  feat = x
424
 
425
- out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
426
  if out.dtype == np.uint16:
427
  out = out.astype(np.int16)
428
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
@@ -450,204 +256,3 @@ class regressor_with_SD_features_tra(nn.Module):
450
  nn.init.xavier_normal_(m.weight)
451
  if m.bias is not None:
452
  nn.init.constant_(m.bias, 0)
453
-
454
-
455
-
456
- class regressor_with_SD_features_inst_seg_unet(nn.Module):
457
- def __init__(self, n_channels=8, n_classes=3, bilinear=False):
458
- super(regressor_with_SD_features_inst_seg_unet, self).__init__()
459
- self.n_channels = n_channels
460
- self.n_classes = n_classes
461
- self.bilinear = bilinear
462
- self.norm = nn.LayerNorm(normalized_shape=(64, 64))
463
- self.inc_0 = (DoubleConv(n_channels, 3))
464
- self.inc = (DoubleConv(3, 64))
465
- self.down1 = (Down(64, 128))
466
- self.down2 = (Down(128, 256))
467
- self.down3 = (Down(256, 512))
468
- factor = 2 if bilinear else 1
469
- self.down4 = (Down(512, 1024 // factor))
470
- self.up1 = (Up(1024, 512 // factor, bilinear))
471
- self.up2 = (Up(512, 256 // factor, bilinear))
472
- self.up3 = (Up(256, 128 // factor, bilinear))
473
- self.up4 = (Up(128, 64, bilinear))
474
- self.outc = (OutConv(64, n_classes))
475
-
476
- def forward(self, img, attn_stack, feature_list):
477
- attn_stack = self.norm(attn_stack)
478
- unet_feature = feature_list[-1]
479
- unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
480
- attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
481
- unet_feature_mean = unet_feature_mean * attn_stack_mean
482
- x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 324, 64, 64]
483
- if x.shape[-1] != 512:
484
- x = F.interpolate(x, size=(512, 512), mode="bilinear")
485
- x = torch.cat([img, x], dim=1) # [1, 8, 512, 512]
486
- x = self.inc_0(x)
487
- x1 = self.inc(x)
488
- x2 = self.down1(x1)
489
- x3 = self.down2(x2)
490
- x4 = self.down3(x3)
491
- x5 = self.down4(x4)
492
- x = self.up1(x5, x4)
493
- x = self.up2(x, x3)
494
- x = self.up3(x, x2)
495
- x = self.up4(x, x1)
496
- out = self.outc(x)
497
- return out
498
-
499
- def initialize_weights(self):
500
- for m in self.modules():
501
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
502
- nn.init.xavier_normal_(m.weight)
503
- if m.bias is not None:
504
- nn.init.constant_(m.bias, 0)
505
-
506
-
507
- class regressor_with_SD_features_self(nn.Module):
508
- def __init__(self):
509
- super(regressor_with_SD_features_self, self).__init__()
510
- self.layer = nn.Sequential(
511
- nn.Conv2d(4096, 1024, kernel_size=1, stride=1),
512
- nn.LeakyReLU(),
513
- nn.LayerNorm((64, 64)),
514
- nn.Conv2d(1024, 256, kernel_size=1, stride=1),
515
- nn.LeakyReLU(),
516
- nn.LayerNorm((64, 64)),
517
- )
518
- self.layer2 = nn.Sequential(
519
- nn.Conv2d(256, 128, kernel_size=3, padding=1),
520
- nn.LeakyReLU(),
521
- nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
522
- )
523
- self.layer3 = nn.Sequential(
524
- nn.Conv2d(128, 64, kernel_size=3, padding=1),
525
- nn.ReLU(),
526
- nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
527
- )
528
- self.layer4 = nn.Sequential(
529
- nn.Conv2d(64, 32, kernel_size=3, padding=1),
530
- nn.LeakyReLU(),
531
- nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
532
- )
533
- self.conv = nn.Sequential(
534
- nn.Conv2d(32, 1, kernel_size=1),
535
- nn.ReLU()
536
- )
537
- self.norm = nn.LayerNorm(normalized_shape=(64, 64))
538
- self.initialize_weights()
539
-
540
- def forward(self, self_attn):
541
- self_attn = self_attn.permute(2, 0, 1)
542
- self_attn = self.layer(self_attn)
543
- return self_attn
544
- # attn_stack = self.norm(attn_stack)
545
- # unet_feature = feature_list[-1]
546
- # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
547
- # unet_feature = unet_feature * attn_stack_mean
548
- # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
549
- # x = self.layer(unet_feature)
550
- # x = self.layer2(x)
551
- # x = self.layer3(x)
552
- # x = self.layer4(x)
553
- # out = self.conv(x)
554
- # return out / 100
555
-
556
- def initialize_weights(self):
557
- for m in self.modules():
558
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
559
- nn.init.xavier_normal_(m.weight)
560
- if m.bias is not None:
561
- nn.init.constant_(m.bias, 0)
562
-
563
-
564
- class regressor_with_SD_features_latent(nn.Module):
565
- def __init__(self):
566
- super(regressor_with_SD_features_latent, self).__init__()
567
- self.layer = nn.Sequential(
568
- nn.Conv2d(4, 256, kernel_size=1, stride=1),
569
- nn.LeakyReLU(),
570
- nn.LayerNorm((64, 64))
571
- )
572
- self.layer2 = nn.Sequential(
573
- nn.Conv2d(256, 128, kernel_size=3, padding=1),
574
- nn.LeakyReLU(),
575
- nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
576
- )
577
- self.layer3 = nn.Sequential(
578
- nn.Conv2d(128, 64, kernel_size=3, padding=1),
579
- nn.ReLU(),
580
- nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
581
- )
582
- self.layer4 = nn.Sequential(
583
- nn.Conv2d(64, 32, kernel_size=3, padding=1),
584
- nn.LeakyReLU(),
585
- nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
586
- )
587
- self.conv = nn.Sequential(
588
- nn.Conv2d(32, 1, kernel_size=1),
589
- nn.ReLU()
590
- )
591
- self.norm = nn.LayerNorm(normalized_shape=(64, 64))
592
- self.initialize_weights()
593
-
594
- def forward(self, self_attn):
595
- # self_attn = self_attn.permute(2, 0, 1)
596
- self_attn = self.layer(self_attn)
597
- return self_attn
598
- # attn_stack = self.norm(attn_stack)
599
- # unet_feature = feature_list[-1]
600
- # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
601
- # unet_feature = unet_feature * attn_stack_mean
602
- # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
603
- # x = self.layer(unet_feature)
604
- # x = self.layer2(x)
605
- # x = self.layer3(x)
606
- # x = self.layer4(x)
607
- # out = self.conv(x)
608
- # return out / 100
609
-
610
- def initialize_weights(self):
611
- for m in self.modules():
612
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
613
- nn.init.xavier_normal_(m.weight)
614
- if m.bias is not None:
615
- nn.init.constant_(m.bias, 0)
616
-
617
-
618
-
619
-
620
-
621
- class regressor_with_deconv(nn.Module):
622
- def __init__(self):
623
- super(regressor_with_deconv, self).__init__()
624
- self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
625
- self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
626
- self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
627
- self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
628
- self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
629
- self.leaky_relu = nn.LeakyReLU()
630
- self.relu = nn.ReLU()
631
- self.initialize_weights()
632
-
633
- def forward(self, x):
634
- x_ = self.conv1(x)
635
- x_ = self.leaky_relu(x_)
636
- x_ = self.deconv1(x_)
637
- x_ = self.conv2(x_)
638
- x_ = self.leaky_relu(x_)
639
- x_ = self.deconv2(x_)
640
- x_ = self.conv3(x_)
641
- x_ = self.relu(x_)
642
- out = x_
643
- return out
644
-
645
- def initialize_weights(self):
646
- for m in self.modules():
647
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
648
- nn.init.xavier_normal_(m.weight)
649
- if m.bias is not None:
650
- nn.init.constant_(m.bias, 0)
651
-
652
-
653
-
 
5
  import clip
6
  import sys
7
  import numpy as np
8
+ from models.seg_post_model.models import SegModel
9
 
10
  from torchvision.ops import roi_align
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
 
 
 
12
 
13
  class Counting_with_SD_features_loca(nn.Module):
14
  def __init__(self, scale_factor):
 
30
  self.regressor = regressor_with_SD_features_tra()
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class adapter_roi_loca(nn.Module):
34
  def __init__(self, pool_size=[3, 3]):
35
  super(adapter_roi_loca, self).__init__()
 
109
 
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  class regressor_with_SD_features(nn.Module):
113
  def __init__(self):
114
  super(regressor_with_SD_features, self).__init__()
 
159
  if m.bias is not None:
160
  nn.init.constant_(m.bias, 0)
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  from models.enc_model.unet_parts import *
163
 
164
 
 
170
  self.bilinear = bilinear
171
  self.norm = nn.LayerNorm(normalized_shape=(64, 64))
172
  self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
173
+ self.vit_model = SegModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
174
  self.vit = self.vit_model.net
175
 
176
  def forward(self, img, attn_stack, feature_list):
 
187
 
188
 
189
 
190
+ out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
191
  if out.dtype == np.uint16:
192
  out = out.astype(np.int16)
193
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
 
210
 
211
  # segmentation
212
  self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
213
+ self.vit_model = SegModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
214
  self.vit = self.vit_model.net
215
 
216
  self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
217
  self.mlp = nn.Linear(64 * 64, 320)
 
218
 
219
  def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
220
  attn_stack = attn_stack[:, [1,3], ...]
 
228
  x = self.inc_0(x)
229
  feat = x
230
 
231
+ out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
232
  if out.dtype == np.uint16:
233
  out = out.astype(np.int16)
234
  out = torch.from_numpy(out).unsqueeze(0).to(x.device)
 
256
  nn.init.xavier_normal_(m.weight)
257
  if m.bias is not None:
258
  nn.init.constant_(m.bias, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/seg_post_model/cellpose/__init__.py DELETED
@@ -1 +0,0 @@
1
- # from .version import version, version_str
 
 
models/seg_post_model/cellpose/io.py DELETED
@@ -1,816 +0,0 @@
1
- """
2
- Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- import os, warnings, glob, shutil
5
- from natsort import natsorted
6
- import numpy as np
7
- import cv2
8
- import tifffile
9
- import logging, pathlib, sys
10
- from tqdm import tqdm
11
- from pathlib import Path
12
- import re
13
- from .version import version_str
14
- from roifile import ImagejRoi, roiwrite
15
-
16
- try:
17
- from qtpy import QtGui, QtCore, Qt, QtWidgets
18
- from qtpy.QtWidgets import QMessageBox
19
- GUI = True
20
- except:
21
- GUI = False
22
-
23
- try:
24
- import matplotlib.pyplot as plt
25
- MATPLOTLIB = True
26
- except:
27
- MATPLOTLIB = False
28
-
29
- try:
30
- import nd2
31
- ND2 = True
32
- except:
33
- ND2 = False
34
-
35
- try:
36
- import nrrd
37
- NRRD = True
38
- except:
39
- NRRD = False
40
-
41
- try:
42
- from google.cloud import storage
43
- SERVER_UPLOAD = True
44
- except:
45
- SERVER_UPLOAD = False
46
-
47
- io_logger = logging.getLogger(__name__)
48
-
49
- def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None):
50
- cp_dir = pathlib.Path.home().joinpath(cp_path)
51
- cp_dir.mkdir(exist_ok=True)
52
- log_file = cp_dir.joinpath(logfile_name)
53
- try:
54
- log_file.unlink()
55
- except:
56
- print('creating new log file')
57
- handlers = [logging.FileHandler(log_file),]
58
- if stdout_file_replacement is not None:
59
- handlers.append(logging.FileHandler(stdout_file_replacement))
60
- else:
61
- handlers.append(logging.StreamHandler(sys.stdout))
62
- logging.basicConfig(
63
- level=logging.INFO,
64
- format="%(asctime)s [%(levelname)s] %(message)s",
65
- handlers=handlers,
66
- force=True
67
- )
68
- logger = logging.getLogger(__name__)
69
- logger.info(f"WRITING LOG OUTPUT TO {log_file}")
70
- logger.info(version_str)
71
-
72
- return logger, log_file
73
-
74
-
75
- from . import utils, plot, transforms
76
-
77
- # helper function to check for a path; if it doesn't exist, make it
78
- def check_dir(path):
79
- if not os.path.isdir(path):
80
- os.mkdir(path)
81
-
82
-
83
- def outlines_to_text(base, outlines):
84
- with open(base + "_cp_outlines.txt", "w") as f:
85
- for o in outlines:
86
- xy = list(o.flatten())
87
- xy_str = ",".join(map(str, xy))
88
- f.write(xy_str)
89
- f.write("\n")
90
-
91
-
92
- def load_dax(filename):
93
- ### modified from ZhuangLab github:
94
- ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
95
-
96
- inf_filename = os.path.splitext(filename)[0] + ".inf"
97
- if not os.path.exists(inf_filename):
98
- io_logger.critical(
99
- f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
100
- )
101
- return None
102
-
103
- ### get metadata
104
- image_height, image_width = None, None
105
- # extract the movie information from the associated inf file
106
- size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
107
- length_re = re.compile(r"number of frames = ([\d]+)")
108
- endian_re = re.compile(r" (big|little) endian")
109
-
110
- with open(inf_filename, "r") as inf_file:
111
- lines = inf_file.read().split("\n")
112
- for line in lines:
113
- m = size_re.match(line)
114
- if m:
115
- image_height = int(m.group(2))
116
- image_width = int(m.group(1))
117
- m = length_re.match(line)
118
- if m:
119
- number_frames = int(m.group(1))
120
- m = endian_re.search(line)
121
- if m:
122
- if m.group(1) == "big":
123
- bigendian = 1
124
- else:
125
- bigendian = 0
126
- # set defaults, warn the user that they couldn"t be determined from the inf file.
127
- if not image_height:
128
- io_logger.warning("could not determine dax image size, assuming 256x256")
129
- image_height = 256
130
- image_width = 256
131
-
132
- ### load image
133
- img = np.memmap(filename, dtype="uint16",
134
- shape=(number_frames, image_height, image_width))
135
- if bigendian:
136
- img = img.byteswap()
137
- img = np.array(img)
138
-
139
- return img
140
-
141
-
142
- def imread(filename):
143
- """
144
- Read in an image file with tif or image file type supported by cv2.
145
-
146
- Args:
147
- filename (str): The path to the image file.
148
-
149
- Returns:
150
- numpy.ndarray: The image data as a NumPy array.
151
-
152
- Raises:
153
- None
154
-
155
- Raises an error if the image file format is not supported.
156
-
157
- Examples:
158
- >>> img = imread("image.tif")
159
- """
160
- # ensure that extension check is not case sensitive
161
- ext = os.path.splitext(filename)[-1].lower()
162
- if ext == ".tif" or ext == ".tiff" or ext == ".flex":
163
- with tifffile.TiffFile(filename) as tif:
164
- ltif = len(tif.pages)
165
- try:
166
- full_shape = tif.shaped_metadata[0]["shape"]
167
- except:
168
- try:
169
- page = tif.series[0][0]
170
- full_shape = tif.series[0].shape
171
- except:
172
- ltif = 0
173
- if ltif < 10:
174
- img = tif.asarray()
175
- else:
176
- page = tif.series[0][0]
177
- shape, dtype = page.shape, page.dtype
178
- ltif = int(np.prod(full_shape) / np.prod(shape))
179
- io_logger.info(f"reading tiff with {ltif} planes")
180
- img = np.zeros((ltif, *shape), dtype=dtype)
181
- for i, page in enumerate(tqdm(tif.series[0])):
182
- img[i] = page.asarray()
183
- img = img.reshape(full_shape)
184
- return img
185
- elif ext == ".dax":
186
- img = load_dax(filename)
187
- return img
188
- elif ext == ".nd2":
189
- if not ND2:
190
- io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
191
- return None
192
- elif ext == ".nrrd":
193
- if not NRRD:
194
- io_logger.critical(
195
- "ERROR: need to 'pip install pynrrd' to load in .nrrd file")
196
- return None
197
- else:
198
- img, metadata = nrrd.read(filename)
199
- if img.ndim == 3:
200
- img = img.transpose(2, 0, 1)
201
- return img
202
- elif ext != ".npy":
203
- try:
204
- img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
205
- if img.ndim > 2:
206
- img = img[..., [2, 1, 0]]
207
- return img
208
- except Exception as e:
209
- io_logger.critical("ERROR: could not read file, %s" % e)
210
- return None
211
- else:
212
- try:
213
- dat = np.load(filename, allow_pickle=True).item()
214
- masks = dat["masks"]
215
- return masks
216
- except Exception as e:
217
- io_logger.critical("ERROR: could not read masks from file, %s" % e)
218
- return None
219
-
220
-
221
- def imread_2D(img_file):
222
- """
223
- Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
224
- If the image has more than 3 channels, only the first 3 channels are kept.
225
-
226
- Args:
227
- img_file (str): The path to the image file.
228
-
229
- Returns:
230
- img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
231
- """
232
- img = imread(img_file)
233
- return transforms.convert_image(img, do_3D=False)
234
-
235
-
236
- def imread_3D(img_file):
237
- """
238
- Read in a 3D image file and convert it to have a channel axis last automatically. Attempts to do this for multi-channel and grayscale images.
239
-
240
- If multichannel image, the channel axis is assumed to be the smallest dimension, and the z axis is the next smallest dimension.
241
- Use `cellpose.io.imread()` to load the full image without selecting the z and channel axes.
242
-
243
- Args:
244
- img_file (str): The path to the image file.
245
-
246
- Returns:
247
- img_out (numpy.ndarray): The image data as a NumPy array.
248
- """
249
- img = imread(img_file)
250
-
251
- dimension_lengths = list(img.shape)
252
-
253
- # grayscale images:
254
- if img.ndim == 3:
255
- channel_axis = None
256
- # guess at z axis:
257
- z_axis = np.argmin(dimension_lengths)
258
-
259
- elif img.ndim == 4:
260
- # guess at channel axis:
261
- channel_axis = np.argmin(dimension_lengths)
262
-
263
- # guess at z axis:
264
- # set channel axis to max so argmin works:
265
- dimension_lengths[channel_axis] = max(dimension_lengths)
266
- z_axis = np.argmin(dimension_lengths)
267
-
268
- else:
269
- raise ValueError(f'image shape error, 3D image must 3 or 4 dimensional. Number of dimensions: {img.ndim}')
270
-
271
- try:
272
- return transforms.convert_image(img, channel_axis=channel_axis, z_axis=z_axis, do_3D=True)
273
- except Exception as e:
274
- io_logger.critical("ERROR: could not read file, %s" % e)
275
- io_logger.critical("ERROR: Guessed z_axis: %s, channel_axis: %s" % (z_axis, channel_axis))
276
- return None
277
-
278
- def remove_model(filename, delete=False):
279
- """ remove model from .cellpose custom model list """
280
- filename = os.path.split(filename)[-1]
281
- from . import models
282
- model_strings = models.get_user_models()
283
- if len(model_strings) > 0:
284
- with open(models.MODEL_LIST_PATH, "w") as textfile:
285
- for fname in model_strings:
286
- textfile.write(fname + "\n")
287
- else:
288
- # write empty file
289
- textfile = open(models.MODEL_LIST_PATH, "w")
290
- textfile.close()
291
- print(f"{filename} removed from custom model list")
292
- if delete:
293
- os.remove(os.fspath(models.MODEL_DIR.joinpath(fname)))
294
- print("model deleted")
295
-
296
-
297
- def add_model(filename):
298
- """ add model to .cellpose models folder to use with GUI or CLI """
299
- from . import models
300
- fname = os.path.split(filename)[-1]
301
- try:
302
- shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
303
- except shutil.SameFileError:
304
- pass
305
- print(f"{filename} copied to models folder {os.fspath(models.MODEL_DIR)}")
306
- if fname not in models.get_user_models():
307
- with open(models.MODEL_LIST_PATH, "a") as textfile:
308
- textfile.write(fname + "\n")
309
-
310
-
311
- def imsave(filename, arr):
312
- """
313
- Saves an image array to a file.
314
-
315
- Args:
316
- filename (str): The name of the file to save the image to.
317
- arr (numpy.ndarray): The image array to be saved.
318
-
319
- Returns:
320
- None
321
- """
322
- ext = os.path.splitext(filename)[-1].lower()
323
- if ext == ".tif" or ext == ".tiff":
324
- tifffile.imwrite(filename, data=arr, compression="zlib")
325
- else:
326
- if len(arr.shape) > 2:
327
- arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
328
- cv2.imwrite(filename, arr)
329
-
330
-
331
- def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
332
- """
333
- Finds all images in a folder and its subfolders (if specified) with the given file extensions.
334
-
335
- Args:
336
- folder (str): The path to the folder to search for images.
337
- mask_filter (str): The filter for mask files.
338
- imf (str, optional): The additional filter for image files. Defaults to None.
339
- look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False.
340
-
341
- Returns:
342
- list: A list of image file paths.
343
-
344
- Raises:
345
- ValueError: If no files are found in the specified folder.
346
- ValueError: If no images are found in the specified folder with the supported file extensions.
347
- ValueError: If no images are found in the specified folder without the mask or flow file endings.
348
- """
349
- mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1",
350
- "_flows_2", "_cellprob", "_masks", mask_filter]
351
- image_names = []
352
- if imf is None:
353
- imf = ""
354
-
355
- folders = []
356
- if look_one_level_down:
357
- folders = natsorted(glob.glob(os.path.join(folder, "*/")))
358
- folders.append(folder)
359
- exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"]
360
- l0 = 0
361
- al = 0
362
- for folder in folders:
363
- all_files = glob.glob(folder + "/*")
364
- al += len(all_files)
365
- for ext in exts:
366
- image_names.extend(glob.glob(folder + f"/*{imf}{ext}"))
367
- image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}"))
368
- l0 += len(image_names)
369
-
370
- # return error if no files found
371
- if al == 0:
372
- raise ValueError("ERROR: no files in --dir folder ")
373
- elif l0 == 0:
374
- raise ValueError(
375
- "ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex"
376
- )
377
-
378
- image_names = natsorted(image_names)
379
- imn = []
380
- for im in image_names:
381
- imfile = os.path.splitext(im)[0]
382
- igood = all([(len(imfile) > len(mask_filter) and
383
- imfile[-len(mask_filter):] != mask_filter) or
384
- len(imfile) <= len(mask_filter) for mask_filter in mask_filters])
385
- if len(imf) > 0:
386
- igood &= imfile[-len(imf):] == imf
387
- if igood:
388
- imn.append(im)
389
-
390
- image_names = imn
391
-
392
- # remove duplicates
393
- image_names = [*set(image_names)]
394
- image_names = natsorted(image_names)
395
-
396
- if len(image_names) == 0:
397
- raise ValueError(
398
- "ERROR: no images in --dir folder without _masks or _flows or _cellprob ending")
399
-
400
- return image_names
401
-
402
- def get_label_files(image_names, mask_filter, imf=None):
403
- """
404
- Get the label files corresponding to the given image names and mask filter.
405
-
406
- Args:
407
- image_names (list): List of image names.
408
- mask_filter (str): Mask filter to be applied.
409
- imf (str, optional): Image file extension. Defaults to None.
410
-
411
- Returns:
412
- tuple: A tuple containing the label file names and flow file names (if present).
413
- """
414
- nimg = len(image_names)
415
- label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
416
-
417
- if imf is not None and len(imf) > 0:
418
- label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
419
- else:
420
- label_names = label_names0
421
-
422
- # check for flows
423
- if os.path.exists(label_names0[0] + "_flows.tif"):
424
- flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)]
425
- else:
426
- flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)]
427
- if not all([os.path.exists(flow) for flow in flow_names]):
428
- io_logger.info(
429
- "not all flows are present, running flow generation for all images")
430
- flow_names = None
431
-
432
- # check for masks
433
- if mask_filter == "_seg.npy":
434
- label_names = [label_names[n] + mask_filter for n in range(nimg)]
435
- return label_names, None
436
-
437
- if os.path.exists(label_names[0] + mask_filter + ".tif"):
438
- label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)]
439
- elif os.path.exists(label_names[0] + mask_filter + ".tiff"):
440
- label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)]
441
- elif os.path.exists(label_names[0] + mask_filter + ".png"):
442
- label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)]
443
- # TODO, allow _seg.npy
444
- #elif os.path.exists(label_names[0] + "_seg.npy"):
445
- # io_logger.info("labels found as _seg.npy files, converting to tif")
446
- else:
447
- if not flow_names:
448
- raise ValueError("labels not provided with correct --mask_filter")
449
- else:
450
- label_names = None
451
- if not all([os.path.exists(label) for label in label_names]):
452
- if not flow_names:
453
- raise ValueError(
454
- "labels not provided for all images in train and/or test set")
455
- else:
456
- label_names = None
457
-
458
- return label_names, flow_names
459
-
460
-
461
- def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
462
- look_one_level_down=False):
463
- """
464
- Loads images and corresponding labels from a directory.
465
-
466
- Args:
467
- tdir (str): The directory path.
468
- mask_filter (str, optional): The filter for mask files. Defaults to "_masks".
469
- image_filter (str, optional): The filter for image files. Defaults to None.
470
- look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False.
471
-
472
- Returns:
473
- tuple: A tuple containing a list of images, a list of labels, and a list of image names.
474
- """
475
- image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
476
- nimg = len(image_names)
477
-
478
- # training data
479
- label_names, flow_names = get_label_files(image_names, mask_filter,
480
- imf=image_filter)
481
-
482
- images = []
483
- labels = []
484
- k = 0
485
- for n in range(nimg):
486
- if (os.path.isfile(label_names[n]) or
487
- (flow_names is not None and os.path.isfile(flow_names[0]))):
488
- image = imread(image_names[n])
489
- if label_names is not None:
490
- label = imread(label_names[n])
491
- if flow_names is not None:
492
- flow = imread(flow_names[n])
493
- if flow.shape[0] < 4:
494
- label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
495
- else:
496
- label = flow
497
- images.append(image)
498
- labels.append(label)
499
- k += 1
500
- io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
501
- return images, labels, image_names
502
-
503
- def load_train_test_data(train_dir, test_dir=None, image_filter=None,
504
- mask_filter="_masks", look_one_level_down=False):
505
- """
506
- Loads training and testing data for a Cellpose model.
507
-
508
- Args:
509
- train_dir (str): The directory path containing the training data.
510
- test_dir (str, optional): The directory path containing the testing data. Defaults to None.
511
- image_filter (str, optional): The filter for selecting image files. Defaults to None.
512
- mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks".
513
- look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.
514
-
515
- Returns:
516
- images, labels, image_names, test_images, test_labels, test_image_names
517
-
518
- """
519
- images, labels, image_names = load_images_labels(train_dir, mask_filter,
520
- image_filter, look_one_level_down)
521
- # testing data
522
- test_images, test_labels, test_image_names = None, None, None
523
- if test_dir is not None:
524
- test_images, test_labels, test_image_names = load_images_labels(
525
- test_dir, mask_filter, image_filter, look_one_level_down)
526
-
527
- return images, labels, image_names, test_images, test_labels, test_image_names
528
-
529
-
530
- def masks_flows_to_seg(images, masks, flows, file_names,
531
- channels=None,
532
- imgs_restore=None, restore_type=None, ratio=1.):
533
- """Save output of model eval to be loaded in GUI.
534
-
535
- Can be list output (run on multiple images) or single output (run on single image).
536
-
537
- Saved to file_names[k]+"_seg.npy".
538
-
539
- Args:
540
- images (list): Images input into cellpose.
541
- masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
542
- flows (list): Flows output from Cellpose.eval.
543
- file_names (list, str): Names of files of images.
544
- diams (float array): Diameters used to run Cellpose. Defaults to 30. TODO: remove this
545
- channels (list, int, optional): Channels used to run Cellpose. Defaults to None.
546
-
547
- Returns:
548
- None
549
- """
550
-
551
- if channels is None:
552
- channels = [0, 0]
553
-
554
- if isinstance(masks, list):
555
- if imgs_restore is None:
556
- imgs_restore = [None] * len(masks)
557
- if isinstance(file_names, str):
558
- file_names = [file_names] * len(masks)
559
- for k, [image, mask, flow,
560
- # diam,
561
- file_name, img_restore
562
- ] in enumerate(zip(images, masks, flows,
563
- # diams,
564
- file_names,
565
- imgs_restore)):
566
- channels_img = channels
567
- if channels_img is not None and len(channels) > 2:
568
- channels_img = channels[k]
569
- masks_flows_to_seg(image, mask, flow, file_name,
570
- # diams=diam,
571
- channels=channels_img, imgs_restore=img_restore,
572
- restore_type=restore_type, ratio=ratio)
573
- return
574
-
575
- if len(channels) == 1:
576
- channels = channels[0]
577
-
578
- flowi = []
579
- if flows[0].ndim == 3:
580
- Ly, Lx = masks.shape[-2:]
581
- flowi.append(
582
- cv2.resize(flows[0], (Lx, Ly), interpolation=cv2.INTER_NEAREST)[np.newaxis,
583
- ...])
584
- else:
585
- flowi.append(flows[0])
586
-
587
- if flows[0].ndim == 3:
588
- cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(
589
- np.uint8)
590
- cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
591
- flowi.append(cellprob[np.newaxis, ...])
592
- flowi.append(np.zeros(flows[0].shape, dtype=np.uint8))
593
- flowi[-1] = flowi[-1][np.newaxis, ...]
594
- else:
595
- flowi.append(
596
- (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8))
597
- flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8))
598
- if len(flows) > 2:
599
- if len(flows) > 3:
600
- flowi.append(flows[3])
601
- else:
602
- flowi.append([])
603
- flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0))
604
- outlines = masks * utils.masks_to_outlines(masks)
605
- base = os.path.splitext(file_names)[0]
606
-
607
- dat = {
608
- "outlines":
609
- outlines.astype(np.uint16) if outlines.max() < 2**16 -
610
- 1 else outlines.astype(np.uint32),
611
- "masks":
612
- masks.astype(np.uint16) if outlines.max() < 2**16 -
613
- 1 else masks.astype(np.uint32),
614
- "chan_choose":
615
- channels,
616
- "ismanual":
617
- np.zeros(masks.max(), bool),
618
- "filename":
619
- file_names,
620
- "flows":
621
- flowi,
622
- "diameter":
623
- np.nan
624
- }
625
- if restore_type is not None and imgs_restore is not None:
626
- dat["restore"] = restore_type
627
- dat["ratio"] = ratio
628
- dat["img_restore"] = imgs_restore
629
-
630
- np.save(base + "_seg.npy", dat)
631
-
632
- def save_to_png(images, masks, flows, file_names):
633
- """ deprecated (runs io.save_masks with png=True)
634
-
635
- does not work for 3D images
636
-
637
- """
638
- save_masks(images, masks, flows, file_names, png=True)
639
-
640
-
641
- def save_rois(masks, file_name, multiprocessing=None):
642
- """ save masks to .roi files in .zip archive for ImageJ/Fiji
643
-
644
- Args:
645
- masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels
646
- file_name (str): name to save the .zip file to
647
-
648
- Returns:
649
- None
650
- """
651
- outlines = utils.outlines_list(masks, multiprocessing=multiprocessing)
652
- nonempty_outlines = [outline for outline in outlines if len(outline)!=0]
653
- if len(outlines)!=len(nonempty_outlines):
654
- print(f"empty outlines found, saving {len(nonempty_outlines)} ImageJ ROIs to .zip archive.")
655
- rois = [ImagejRoi.frompoints(outline) for outline in nonempty_outlines]
656
- file_name = os.path.splitext(file_name)[0] + '_rois.zip'
657
-
658
-
659
- # Delete file if it exists; the roifile lib appends to existing zip files.
660
- # If the user removed a mask it will still be in the zip file
661
- if os.path.exists(file_name):
662
- os.remove(file_name)
663
-
664
- roiwrite(file_name, rois)
665
-
666
-
667
- def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
668
- suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False,
669
- in_folders=False, savedir=None, save_txt=False, save_mpl=False):
670
- """ Save masks + nicely plotted segmentation image to png and/or tiff.
671
-
672
- Can save masks, flows to different directories, if in_folders is True.
673
-
674
- If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png".
675
-
676
- If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif".
677
-
678
- If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png".
679
-
680
- Only tif option works for 3D data, and only tif option works for empty masks.
681
-
682
- Args:
683
- images (list): Images input into cellpose.
684
- masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
685
- flows (list): Flows output from Cellpose.eval.
686
- file_names (list, str): Names of files of images.
687
- png (bool, optional): Save masks to PNG. Defaults to True.
688
- tif (bool, optional): Save masks to TIF. Defaults to False.
689
- channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0].
690
- suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks".
691
- save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False.
692
- save_outlines (bool, optional): Save outlines of masks. Defaults to False.
693
- dir_above (bool, optional): Save masks/flows in directory above. Defaults to False.
694
- in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False.
695
- savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
696
- save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
697
- save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
698
- This takes a long time for large images. Defaults to False.
699
-
700
- Returns:
701
- None
702
- """
703
-
704
- if isinstance(masks, list):
705
- for image, mask, flow, file_name in zip(images, masks, flows, file_names):
706
- save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix,
707
- dir_above=dir_above, save_flows=save_flows,
708
- save_outlines=save_outlines, savedir=savedir, save_txt=save_txt,
709
- in_folders=in_folders, save_mpl=save_mpl)
710
- return
711
-
712
- if masks.ndim > 2 and not tif:
713
- raise ValueError("cannot save 3D outputs as PNG, use tif option instead")
714
-
715
- if masks.max() == 0:
716
- io_logger.warning("no masks found, will not save PNG or outlines")
717
- if not tif:
718
- return
719
- else:
720
- png = False
721
- save_outlines = False
722
- save_flows = False
723
- save_txt = False
724
-
725
- if savedir is None:
726
- if dir_above:
727
- savedir = Path(file_names).parent.parent.absolute(
728
- ) #go up a level to save in its own folder
729
- else:
730
- savedir = Path(file_names).parent.absolute()
731
-
732
- check_dir(savedir)
733
-
734
- basename = os.path.splitext(os.path.basename(file_names))[0]
735
- if in_folders:
736
- maskdir = os.path.join(savedir, "masks")
737
- outlinedir = os.path.join(savedir, "outlines")
738
- txtdir = os.path.join(savedir, "txt_outlines")
739
- flowdir = os.path.join(savedir, "flows")
740
- else:
741
- maskdir = savedir
742
- outlinedir = savedir
743
- txtdir = savedir
744
- flowdir = savedir
745
-
746
- check_dir(maskdir)
747
-
748
- exts = []
749
- if masks.ndim > 2:
750
- png = False
751
- tif = True
752
- if png:
753
- if masks.max() < 2**16:
754
- masks = masks.astype(np.uint16)
755
- exts.append(".png")
756
- else:
757
- png = False
758
- tif = True
759
- io_logger.warning(
760
- "found more than 65535 masks in each image, cannot save PNG, saving as TIF"
761
- )
762
- if tif:
763
- exts.append(".tif")
764
-
765
- # save masks
766
- with warnings.catch_warnings():
767
- warnings.simplefilter("ignore")
768
- for ext in exts:
769
- imsave(os.path.join(maskdir, basename + suffix + ext), masks)
770
-
771
- if save_mpl and png and MATPLOTLIB and not min(images.shape) > 3:
772
- # Make and save original/segmentation/flows image
773
-
774
- img = images.copy()
775
- if img.ndim < 3:
776
- img = img[:, :, np.newaxis]
777
- elif img.shape[0] < 8:
778
- np.transpose(img, (1, 2, 0))
779
-
780
- fig = plt.figure(figsize=(12, 3))
781
- plot.show_segmentation(fig, img, masks, flows[0])
782
- fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"),
783
- dpi=300)
784
- plt.close(fig)
785
-
786
- # ImageJ txt outline files
787
- if masks.ndim < 3 and save_txt:
788
- check_dir(txtdir)
789
- outlines = utils.outlines_list(masks)
790
- outlines_to_text(os.path.join(txtdir, basename), outlines)
791
-
792
- # RGB outline images
793
- if masks.ndim < 3 and save_outlines:
794
- check_dir(outlinedir)
795
- outlines = utils.masks_to_outlines(masks)
796
- outX, outY = np.nonzero(outlines)
797
- img0 = transforms.normalize99(images)
798
- if img0.shape[0] < 4:
799
- img0 = np.transpose(img0, (1, 2, 0))
800
- if img0.shape[-1] < 3 or img0.ndim < 3:
801
- img0 = plot.image_to_rgb(img0, channels=channels)
802
- else:
803
- if img0.max() <= 50.0:
804
- img0 = np.uint8(np.clip(img0 * 255, 0, 1))
805
- imgout = img0.copy()
806
- imgout[outX, outY] = np.array([255, 0, 0]) #pure red
807
- imsave(os.path.join(outlinedir, basename + "_outlines" + suffix + ".png"),
808
- imgout)
809
-
810
- # save RGB flow picture
811
- if masks.ndim < 3 and save_flows:
812
- check_dir(flowdir)
813
- imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"),
814
- (flows[0] * (2**16 - 1)).astype(np.uint16))
815
- #save full flow data
816
- imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/plot.py DELETED
@@ -1,281 +0,0 @@
1
- """
2
- Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- import os
5
- import numpy as np
6
- import cv2
7
- from scipy.ndimage import gaussian_filter
8
- from . import utils, io, transforms
9
-
10
- try:
11
- import matplotlib
12
- MATPLOTLIB_ENABLED = True
13
- except:
14
- MATPLOTLIB_ENABLED = False
15
-
16
- try:
17
- from skimage import color
18
- from skimage.segmentation import find_boundaries
19
- SKIMAGE_ENABLED = True
20
- except:
21
- SKIMAGE_ENABLED = False
22
-
23
-
24
- # modified to use sinebow color
25
- def dx_to_circ(dP):
26
- """Converts the optic flow representation to a circular color representation.
27
-
28
- Args:
29
- dP (ndarray): Flow field components [dy, dx].
30
-
31
- Returns:
32
- ndarray: The circular color representation of the optic flow.
33
-
34
- """
35
- mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.)
36
- angles = np.arctan2(dP[1], dP[0]) + np.pi
37
- a = 2
38
- mag /= a
39
- rgb = np.zeros((*dP.shape[1:], 3), "uint8")
40
- rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8")
41
- rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8")
42
- rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8")
43
-
44
- return rgb
45
-
46
-
47
- def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
48
- """Plot segmentation results (like on website).
49
-
50
- Can save each panel of figure with file_name option. Use channels option if
51
- img input is not an RGB image with 3 channels.
52
-
53
- Args:
54
- fig (matplotlib.pyplot.figure): Figure in which to make plot.
55
- img (ndarray): 2D or 3D array. Image input into cellpose.
56
- maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
57
- flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows).
58
- channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
59
- file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
60
- seg_norm (bool, optional): Improve cell visibility under labels. Defaults to False.
61
- """
62
- if not MATPLOTLIB_ENABLED:
63
- raise ImportError(
64
- "matplotlib not installed, install with 'pip install matplotlib'")
65
- ax = fig.add_subplot(1, 4, 1)
66
- img0 = img.copy()
67
-
68
- if img0.shape[0] < 4:
69
- img0 = np.transpose(img0, (1, 2, 0))
70
- if img0.shape[-1] < 3 or img0.ndim < 3:
71
- img0 = image_to_rgb(img0, channels=channels)
72
- else:
73
- if img0.max() <= 50.0:
74
- img0 = np.uint8(np.clip(img0, 0, 1) * 255)
75
- ax.imshow(img0)
76
- ax.set_title("original image")
77
- ax.axis("off")
78
-
79
- outlines = utils.masks_to_outlines(maski)
80
-
81
- overlay = mask_overlay(img0, maski)
82
-
83
- ax = fig.add_subplot(1, 4, 2)
84
- outX, outY = np.nonzero(outlines)
85
- imgout = img0.copy()
86
- imgout[outX, outY] = np.array([255, 0, 0]) # pure red
87
-
88
- ax.imshow(imgout)
89
- ax.set_title("predicted outlines")
90
- ax.axis("off")
91
-
92
- ax = fig.add_subplot(1, 4, 3)
93
- ax.imshow(overlay)
94
- ax.set_title("predicted masks")
95
- ax.axis("off")
96
-
97
- ax = fig.add_subplot(1, 4, 4)
98
- ax.imshow(flowi)
99
- ax.set_title("predicted cell pose")
100
- ax.axis("off")
101
-
102
- if file_name is not None:
103
- save_path = os.path.splitext(file_name)[0]
104
- io.imsave(save_path + "_overlay.jpg", overlay)
105
- io.imsave(save_path + "_outlines.jpg", imgout)
106
- io.imsave(save_path + "_flows.jpg", flowi)
107
-
108
-
109
- def mask_rgb(masks, colors=None):
110
- """Masks in random RGB colors.
111
-
112
- Args:
113
- masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
114
- colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
115
-
116
- Returns:
117
- RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
118
- """
119
- if colors is not None:
120
- if colors.max() > 1:
121
- colors = np.float32(colors)
122
- colors /= 255
123
- colors = utils.rgb_to_hsv(colors)
124
-
125
- HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32)
126
- HSV[:, :, 2] = 1.0
127
- for n in range(int(masks.max())):
128
- ipix = (masks == n + 1).nonzero()
129
- if colors is None:
130
- HSV[ipix[0], ipix[1], 0] = np.random.rand()
131
- else:
132
- HSV[ipix[0], ipix[1], 0] = colors[n, 0]
133
- HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5
134
- HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5
135
- RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
136
- return RGB
137
-
138
-
139
- def mask_overlay(img, masks, colors=None):
140
- """Overlay masks on image (set image to grayscale).
141
-
142
- Args:
143
- img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)].
144
- masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
145
- colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
146
-
147
- Returns:
148
- RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
149
- """
150
- if colors is not None:
151
- if colors.max() > 1:
152
- colors = np.float32(colors)
153
- colors /= 255
154
- colors = utils.rgb_to_hsv(colors)
155
- if img.ndim > 2:
156
- img = img.astype(np.float32).mean(axis=-1)
157
- else:
158
- img = img.astype(np.float32)
159
-
160
- HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
161
- HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1)
162
- hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())]
163
- for n in range(int(masks.max())):
164
- ipix = (masks == n + 1).nonzero()
165
- if colors is None:
166
- HSV[ipix[0], ipix[1], 0] = hues[n]
167
- else:
168
- HSV[ipix[0], ipix[1], 0] = colors[n, 0]
169
- HSV[ipix[0], ipix[1], 1] = 1.0
170
- RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
171
- return RGB
172
-
173
-
174
- def image_to_rgb(img0, channels=[0, 0]):
175
- """Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3.
176
-
177
- Args:
178
- img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2.
179
-
180
- Returns:
181
- ndarray: RGB image of shape Ly x Lx x 3.
182
-
183
- """
184
- img = img0.copy()
185
- img = img.astype(np.float32)
186
- if img.ndim < 3:
187
- img = img[:, :, np.newaxis]
188
- if img.shape[0] < 5:
189
- img = np.transpose(img, (1, 2, 0))
190
- if channels[0] == 0:
191
- img = img.mean(axis=-1)[:, :, np.newaxis]
192
- for i in range(img.shape[-1]):
193
- if np.ptp(img[:, :, i]) > 0:
194
- img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1)
195
- img[:, :, i] = np.clip(img[:, :, i], 0, 1)
196
- img *= 255
197
- img = np.uint8(img)
198
- RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8)
199
- if img.shape[-1] == 1:
200
- RGB = np.tile(img, (1, 1, 3))
201
- else:
202
- RGB[:, :, channels[0] - 1] = img[:, :, 0]
203
- if channels[1] > 0:
204
- RGB[:, :, channels[1] - 1] = img[:, :, 1]
205
- return RGB
206
-
207
-
208
- def interesting_patch(mask, bsize=130):
209
- """
210
- Get patch of size bsize x bsize with most masks.
211
-
212
- Args:
213
- mask (ndarray): Input mask.
214
- bsize (int): Size of the patch.
215
-
216
- Returns:
217
- tuple: Patch coordinates (y, x).
218
-
219
- """
220
- Ly, Lx = mask.shape
221
- m = np.float32(mask > 0)
222
- m = gaussian_filter(m, bsize / 2)
223
- y, x = np.unravel_index(np.argmax(m), m.shape)
224
- ycent = max(bsize // 2, min(y, Ly - bsize // 2))
225
- xcent = max(bsize // 2, min(x, Lx - bsize // 2))
226
- patch = [
227
- np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int),
228
- np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int)
229
- ]
230
- return patch
231
-
232
-
233
- def disk(med, r, Ly, Lx):
234
- """Returns the pixels of a disk with a given radius and center.
235
-
236
- Args:
237
- med (tuple): The center coordinates of the disk.
238
- r (float): The radius of the disk.
239
- Ly (int): The height of the image.
240
- Lx (int): The width of the image.
241
-
242
- Returns:
243
- tuple: A tuple containing the y and x coordinates of the pixels within the disk.
244
-
245
- """
246
- yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
247
- indexing="ij")
248
- inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r
249
- y = yy[inds].flatten()
250
- x = xx[inds].flatten()
251
- return y, x
252
-
253
-
254
- def outline_view(img0, maski, color=[1, 0, 0], mode="inner"):
255
- """
256
- Generates a red outline overlay onto the image.
257
-
258
- Args:
259
- img0 (numpy.ndarray): The input image.
260
- maski (numpy.ndarray): The mask representing the region of interest.
261
- color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red).
262
- mode (str, optional): The mode for generating the outline. Defaults to "inner".
263
-
264
- Returns:
265
- numpy.ndarray: The image with the red outline overlay.
266
-
267
- """
268
- if img0.ndim == 2:
269
- img0 = np.stack([img0] * 3, axis=-1)
270
- elif img0.ndim != 3:
271
- raise ValueError("img0 not right size (must have ndim 2 or 3)")
272
-
273
- if SKIMAGE_ENABLED:
274
- outlines = find_boundaries(maski, mode=mode)
275
- else:
276
- outlines = utils.masks_to_outlines(maski, mode=mode)
277
- outY, outX = np.nonzero(outlines)
278
- imgout = img0.copy()
279
- imgout[outY, outX] = np.array(color)
280
-
281
- return imgout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/utils.py DELETED
@@ -1,667 +0,0 @@
1
- """
2
- Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- import logging
5
- import os, tempfile, shutil, io
6
- from tqdm import tqdm, trange
7
- from urllib.request import urlopen
8
- import cv2
9
- from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label
10
- from scipy.spatial import ConvexHull
11
- import numpy as np
12
- import colorsys
13
- import fastremap
14
- import fill_voids
15
- from multiprocessing import Pool, cpu_count
16
- # try:
17
- # from cellpose import metrics
18
- # except:
19
- # import metrics as metrics
20
- from models.seg_post_model.cellpose import metrics
21
-
22
- try:
23
- from skimage.morphology import remove_small_holes
24
- SKIMAGE_ENABLED = True
25
- except:
26
- SKIMAGE_ENABLED = False
27
-
28
-
29
- class TqdmToLogger(io.StringIO):
30
- """
31
- Output stream for TQDM which will output to logger module instead of
32
- the StdOut.
33
- """
34
- logger = None
35
- level = None
36
- buf = ""
37
-
38
- def __init__(self, logger, level=None):
39
- super(TqdmToLogger, self).__init__()
40
- self.logger = logger
41
- self.level = level or logging.INFO
42
-
43
- def write(self, buf):
44
- self.buf = buf.strip("\r\n\t ")
45
-
46
- def flush(self):
47
- self.logger.log(self.level, self.buf)
48
-
49
-
50
- def rgb_to_hsv(arr):
51
- rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv)
52
- r, g, b = np.rollaxis(arr, axis=-1)
53
- h, s, v = rgb_to_hsv_channels(r, g, b)
54
- hsv = np.stack((h, s, v), axis=-1)
55
- return hsv
56
-
57
-
58
- def hsv_to_rgb(arr):
59
- hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb)
60
- h, s, v = np.rollaxis(arr, axis=-1)
61
- r, g, b = hsv_to_rgb_channels(h, s, v)
62
- rgb = np.stack((r, g, b), axis=-1)
63
- return rgb
64
-
65
-
66
- def download_url_to_file(url, dst, progress=True):
67
- r"""Download object at the given URL to a local path.
68
- Thanks to torch, slightly modified
69
- Args:
70
- url (string): URL of the object to download
71
- dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
72
- progress (bool, optional): whether or not to display a progress bar to stderr
73
- Default: True
74
- """
75
- file_size = None
76
- import ssl
77
- ssl._create_default_https_context = ssl._create_unverified_context
78
- u = urlopen(url)
79
- meta = u.info()
80
- if hasattr(meta, "getheaders"):
81
- content_length = meta.getheaders("Content-Length")
82
- else:
83
- content_length = meta.get_all("Content-Length")
84
- if content_length is not None and len(content_length) > 0:
85
- file_size = int(content_length[0])
86
- # We deliberately save it in a temp file and move it after
87
- dst = os.path.expanduser(dst)
88
- dst_dir = os.path.dirname(dst)
89
- f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
90
- try:
91
- with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True,
92
- unit_divisor=1024) as pbar:
93
- while True:
94
- buffer = u.read(8192)
95
- if len(buffer) == 0:
96
- break
97
- f.write(buffer)
98
- pbar.update(len(buffer))
99
- f.close()
100
- shutil.move(f.name, dst)
101
- finally:
102
- f.close()
103
- if os.path.exists(f.name):
104
- os.remove(f.name)
105
-
106
-
107
- def distance_to_boundary(masks):
108
- """Get the distance to the boundary of mask pixels.
109
-
110
- Args:
111
- masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
112
-
113
- Returns:
114
- dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx].
115
-
116
- Raises:
117
- ValueError: If the masks array is not 2D or 3D.
118
-
119
- """
120
- if masks.ndim > 3 or masks.ndim < 2:
121
- raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" %
122
- masks.ndim)
123
- dist_to_bound = np.zeros(masks.shape, np.float64)
124
-
125
- if masks.ndim == 3:
126
- for i in range(masks.shape[0]):
127
- dist_to_bound[i] = distance_to_boundary(masks[i])
128
- return dist_to_bound
129
- else:
130
- slices = find_objects(masks)
131
- for i, si in enumerate(slices):
132
- if si is not None:
133
- sr, sc = si
134
- mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
135
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
136
- cv2.CHAIN_APPROX_NONE)
137
- pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
138
- ypix, xpix = np.nonzero(mask)
139
- min_dist = ((ypix[:, np.newaxis] - pvr)**2 +
140
- (xpix[:, np.newaxis] - pvc)**2).min(axis=1)
141
- dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist
142
- return dist_to_bound
143
-
144
-
145
- def masks_to_edges(masks, threshold=1.0):
146
- """Get edges of masks as a 0-1 array.
147
-
148
- Args:
149
- masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
150
- threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0.
151
-
152
- Returns:
153
- edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels.
154
- """
155
- dist_to_bound = distance_to_boundary(masks)
156
- edges = (dist_to_bound < threshold) * (masks > 0)
157
- return edges
158
-
159
-
160
- def remove_edge_masks(masks, change_index=True):
161
- """Removes masks with pixels on the edge of the image.
162
-
163
- Args:
164
- masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
165
- change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True.
166
-
167
- Returns:
168
- outlines (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
169
- """
170
- slices = find_objects(masks.astype(int))
171
- for i, si in enumerate(slices):
172
- remove = False
173
- if si is not None:
174
- for d, sid in enumerate(si):
175
- if sid.start == 0 or sid.stop == masks.shape[d]:
176
- remove = True
177
- break
178
- if remove:
179
- masks[si][masks[si] == i + 1] = 0
180
- shape = masks.shape
181
- if change_index:
182
- _, masks = np.unique(masks, return_inverse=True)
183
- masks = np.reshape(masks, shape).astype(np.int32)
184
-
185
- return masks
186
-
187
-
188
- def masks_to_outlines(masks):
189
- """Get outlines of masks as a 0-1 array.
190
-
191
- Args:
192
- masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
193
-
194
- Returns:
195
- outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
196
- """
197
- if masks.ndim > 3 or masks.ndim < 2:
198
- raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
199
- masks.ndim)
200
- outlines = np.zeros(masks.shape, bool)
201
-
202
- if masks.ndim == 3:
203
- for i in range(masks.shape[0]):
204
- outlines[i] = masks_to_outlines(masks[i])
205
- return outlines
206
- else:
207
- slices = find_objects(masks.astype(int))
208
- for i, si in enumerate(slices):
209
- if si is not None:
210
- sr, sc = si
211
- mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
212
- contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
213
- cv2.CHAIN_APPROX_NONE)
214
- pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
215
- vr, vc = pvr + sr.start, pvc + sc.start
216
- outlines[vr, vc] = 1
217
- return outlines
218
-
219
-
220
- def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None):
221
- """Get outlines of masks as a list to loop over for plotting.
222
-
223
- Args:
224
- masks (ndarray): Array of masks.
225
- multiprocessing_threshold (int, optional): Threshold for enabling multiprocessing. Defaults to 1000.
226
- multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None.
227
-
228
- Returns:
229
- list: List of outlines.
230
-
231
- Raises:
232
- None
233
-
234
- Notes:
235
- - This function is a wrapper for outlines_list_single and outlines_list_multi.
236
- - Multiprocessing is disabled for Windows.
237
- """
238
- # default to use multiprocessing if not few_masks, but allow user to override
239
- if multiprocessing is None:
240
- few_masks = np.max(masks) < multiprocessing_threshold
241
- multiprocessing = not few_masks
242
-
243
- # disable multiprocessing for Windows
244
- if os.name == "nt":
245
- if multiprocessing:
246
- logging.getLogger(__name__).warning(
247
- "Multiprocessing is disabled for Windows")
248
- multiprocessing = False
249
-
250
- if multiprocessing:
251
- return outlines_list_multi(masks)
252
- else:
253
- return outlines_list_single(masks)
254
-
255
-
256
- def outlines_list_single(masks):
257
- """Get outlines of masks as a list to loop over for plotting.
258
-
259
- Args:
260
- masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
261
-
262
- Returns:
263
- list: List of outlines as pixel coordinates.
264
-
265
- """
266
- outpix = []
267
- for n in np.unique(masks)[1:]:
268
- mn = masks == n
269
- if mn.sum() > 0:
270
- contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
271
- method=cv2.CHAIN_APPROX_NONE)
272
- contours = contours[-2]
273
- cmax = np.argmax([c.shape[0] for c in contours])
274
- pix = contours[cmax].astype(int).squeeze()
275
- if len(pix) > 4:
276
- outpix.append(pix)
277
- else:
278
- outpix.append(np.zeros((0, 2)))
279
- return outpix
280
-
281
-
282
- def outlines_list_multi(masks, num_processes=None):
283
- """
284
- Get outlines of masks as a list to loop over for plotting.
285
-
286
- Args:
287
- masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
288
-
289
- Returns:
290
- list: List of outlines as pixel coordinates.
291
- """
292
- if num_processes is None:
293
- num_processes = cpu_count()
294
-
295
- unique_masks = np.unique(masks)[1:]
296
- with Pool(processes=num_processes) as pool:
297
- outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
298
- return outpix
299
-
300
-
301
- def get_outline_multi(args):
302
- """Get the outline of a specific mask in a multi-mask image.
303
-
304
- Args:
305
- args (tuple): A tuple containing the masks and the mask number.
306
-
307
- Returns:
308
- numpy.ndarray: The outline of the specified mask as an array of coordinates.
309
-
310
- """
311
- masks, n = args
312
- mn = masks == n
313
- if mn.sum() > 0:
314
- contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
315
- method=cv2.CHAIN_APPROX_NONE)
316
- contours = contours[-2]
317
- cmax = np.argmax([c.shape[0] for c in contours])
318
- pix = contours[cmax].astype(int).squeeze()
319
- return pix if len(pix) > 4 else np.zeros((0, 2))
320
- return np.zeros((0, 2))
321
-
322
-
323
- def dilate_masks(masks, n_iter=5):
324
- """Dilate masks by n_iter pixels.
325
-
326
- Args:
327
- masks (ndarray): Array of masks.
328
- n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5.
329
-
330
- Returns:
331
- ndarray: Dilated masks.
332
- """
333
- dilated_masks = masks.copy()
334
- for n in range(n_iter):
335
- # define the structuring element to use for dilation
336
- kernel = np.ones((3, 3), "uint8")
337
- # find the distance to each mask (distances are zero within masks)
338
- dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"),
339
- cv2.DIST_L2, 5)
340
- # dilate each mask and assign to it the pixels along the border of the mask
341
- # (does not allow dilation into other masks since dist_transform is zero there)
342
- for i in range(1, np.max(masks) + 1):
343
- mask = (dilated_masks == i).astype("uint8")
344
- dilated_mask = cv2.dilate(mask, kernel, iterations=1)
345
- dilated_mask = np.logical_and(dist_transform < 2, dilated_mask)
346
- dilated_masks[dilated_mask > 0] = i
347
- return dilated_masks
348
-
349
-
350
- def get_perimeter(points):
351
- """
352
- Calculate the perimeter of a set of points.
353
-
354
- Parameters:
355
- points (ndarray): An array of points with shape (npoints, ndim).
356
-
357
- Returns:
358
- float: The perimeter of the points.
359
-
360
- """
361
- if points.shape[0] > 4:
362
- points = np.append(points, points[:1], axis=0)
363
- return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum()
364
- else:
365
- return 0
366
-
367
-
368
- def get_mask_compactness(masks):
369
- """
370
- Calculate the compactness of masks.
371
-
372
- Parameters:
373
- masks (ndarray): Binary masks representing objects.
374
-
375
- Returns:
376
- ndarray: Array of compactness values for each mask.
377
- """
378
- perimeters = get_mask_perimeters(masks)
379
- npoints = np.unique(masks, return_counts=True)[1][1:]
380
- areas = npoints
381
- compactness = 4 * np.pi * areas / perimeters**2
382
- compactness[perimeters == 0] = 0
383
- compactness[compactness > 1.0] = 1.0
384
- return compactness
385
-
386
-
387
- def get_mask_perimeters(masks):
388
- """
389
- Calculate the perimeters of the given masks.
390
-
391
- Parameters:
392
- masks (numpy.ndarray): Binary masks representing objects.
393
-
394
- Returns:
395
- numpy.ndarray: Array containing the perimeters of each mask.
396
- """
397
- perimeters = np.zeros(masks.max())
398
- for n in range(masks.max()):
399
- mn = masks == (n + 1)
400
- if mn.sum() > 0:
401
- contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
402
- method=cv2.CHAIN_APPROX_NONE)[-2]
403
- perimeters[n] = np.array(
404
- [get_perimeter(c.astype(int).squeeze()) for c in contours]).sum()
405
-
406
- return perimeters
407
-
408
-
409
- def circleMask(d0):
410
- """
411
- Creates an array with indices which are the radius of that x,y point.
412
-
413
- Args:
414
- d0 (tuple): Patch of (-d0, d0+1) over which radius is computed.
415
-
416
- Returns:
417
- tuple: A tuple containing:
418
- - rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1).
419
- - dx (ndarray): Indices of the patch along the x-axis.
420
- - dy (ndarray): Indices of the patch along the y-axis.
421
- """
422
- dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1))
423
- dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1))
424
- dy = dy.transpose()
425
-
426
- rs = (dy**2 + dx**2)**0.5
427
- return rs, dx, dy
428
-
429
-
430
- def get_mask_stats(masks_true):
431
- """
432
- Calculate various statistics for the given binary masks.
433
-
434
- Parameters:
435
- masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
436
-
437
- Returns:
438
- convexity (ndarray): Convexity values for each mask.
439
- solidity (ndarray): Solidity values for each mask.
440
- compactness (ndarray): Compactness values for each mask.
441
- """
442
- mask_perimeters = get_mask_perimeters(masks_true)
443
-
444
- # disk for compactness
445
- rs, dy, dx = circleMask(np.array([100, 100]))
446
- rsort = np.sort(rs.flatten())
447
-
448
- # area for solidity
449
- npoints = np.unique(masks_true, return_counts=True)[1][1:]
450
- areas = npoints - mask_perimeters / 2 - 1
451
-
452
- compactness = np.zeros(masks_true.max())
453
- convexity = np.zeros(masks_true.max())
454
- solidity = np.zeros(masks_true.max())
455
- convex_perimeters = np.zeros(masks_true.max())
456
- convex_areas = np.zeros(masks_true.max())
457
- for ic in range(masks_true.max()):
458
- points = np.array(np.nonzero(masks_true == (ic + 1))).T
459
- if len(points) > 15 and mask_perimeters[ic] > 0:
460
- med = np.median(points, axis=0)
461
- # compute compactness of ROI
462
- r2 = ((points - med)**2).sum(axis=1)**0.5
463
- compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean()
464
- try:
465
- hull = ConvexHull(points)
466
- convex_perimeters[ic] = hull.area
467
- convex_areas[ic] = hull.volume
468
- except:
469
- convex_perimeters[ic] = 0
470
-
471
- convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] /
472
- mask_perimeters[mask_perimeters > 0.0])
473
- solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] /
474
- convex_areas[convex_areas > 0.0])
475
- convexity = np.clip(convexity, 0.0, 1.0)
476
- solidity = np.clip(solidity, 0.0, 1.0)
477
- compactness = np.clip(compactness, 0.0, 1.0)
478
- return convexity, solidity, compactness
479
-
480
-
481
- def get_masks_unet(output, cell_threshold=0, boundary_threshold=0):
482
- """Create masks using cell probability and cell boundary.
483
-
484
- Args:
485
- output (ndarray): The output array containing cell probability and cell boundary.
486
- cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0.
487
- boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0.
488
-
489
- Returns:
490
- ndarray: The masks representing the segmented cells.
491
-
492
- """
493
- cells = (output[..., 1] - output[..., 0]) > cell_threshold
494
- selem = generate_binary_structure(cells.ndim, connectivity=1)
495
- labels, nlabels = label(cells, selem)
496
-
497
- if output.shape[-1] > 2:
498
- slices = find_objects(labels)
499
- dists = 10000 * np.ones(labels.shape, np.float32)
500
- mins = np.zeros(labels.shape, np.int32)
501
- borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold)
502
- pad = 10
503
- for i, slc in enumerate(slices):
504
- if slc is not None:
505
- slc_pad = tuple([
506
- slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad))
507
- for j, sli in enumerate(slc)
508
- ])
509
- msk = (labels[slc_pad] == (i + 1)).astype(np.float32)
510
- msk = 1 - gaussian_filter(msk, 5)
511
- dists[slc_pad] = np.minimum(dists[slc_pad], msk)
512
- mins[slc_pad][dists[slc_pad] == msk] = (i + 1)
513
- labels[labels == 0] = borders[labels == 0] * mins[labels == 0]
514
-
515
- masks = labels
516
- shape0 = masks.shape
517
- _, masks = np.unique(masks, return_inverse=True)
518
- masks = np.reshape(masks, shape0)
519
- return masks
520
-
521
-
522
- def stitch3D(masks, stitch_threshold=0.25):
523
- """
524
- Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.
525
-
526
- Args:
527
- masks (list or ndarray): List of 2D masks.
528
- stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.
529
-
530
- Returns:
531
- list: List of stitched 3D masks.
532
- """
533
- mmax = masks[0].max()
534
- empty = 0
535
- for i in trange(len(masks) - 1):
536
- iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
537
- if not iou.size and empty == 0:
538
- masks[i + 1] = masks[i + 1]
539
- mmax = masks[i + 1].max()
540
- elif not iou.size and not empty == 0:
541
- icount = masks[i + 1].max()
542
- istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
543
- mmax += icount
544
- istitch = np.append(np.array(0), istitch)
545
- masks[i + 1] = istitch[masks[i + 1]]
546
- else:
547
- iou[iou < stitch_threshold] = 0.0
548
- iou[iou < iou.max(axis=0)] = 0.0
549
- istitch = iou.argmax(axis=1) + 1
550
- ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
551
- istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
552
- mmax += len(ino)
553
- istitch = np.append(np.array(0), istitch)
554
- masks[i + 1] = istitch[masks[i + 1]]
555
- empty = 1
556
-
557
- return masks
558
-
559
-
560
- def diameters(masks):
561
- """
562
- Calculate the diameters of the objects in the given masks.
563
-
564
- Parameters:
565
- masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
566
-
567
- Returns:
568
- tuple: A tuple containing the median diameter and an array of diameters for each object.
569
-
570
- Examples:
571
- >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])
572
- >>> diameters(masks)
573
- (1.0, array([1.41421356, 1.0, 1.0]))
574
- """
575
- uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
576
- counts = counts[1:]
577
- md = np.median(counts**0.5)
578
- if np.isnan(md):
579
- md = 0
580
- md /= (np.pi**0.5) / 2
581
- return md, counts**0.5
582
-
583
-
584
- def radius_distribution(masks, bins):
585
- """
586
- Calculate the radius distribution of masks.
587
-
588
- Args:
589
- masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
590
- bins (int): Number of bins for the histogram.
591
-
592
- Returns:
593
- A tuple containing a normalized histogram of radii, median radius, array of radii.
594
-
595
- """
596
- unique, counts = np.unique(masks, return_counts=True)
597
- counts = counts[unique != 0]
598
- nb, _ = np.histogram((counts**0.5) * 0.5, bins)
599
- nb = nb.astype(np.float32)
600
- if nb.sum() > 0:
601
- nb = nb / nb.sum()
602
- md = np.median(counts**0.5) * 0.5
603
- if np.isnan(md):
604
- md = 0
605
- md /= (np.pi**0.5) / 2
606
- return nb, md, (counts**0.5) / 2
607
-
608
-
609
- def size_distribution(masks):
610
- """
611
- Calculates the size distribution of masks.
612
-
613
- Args:
614
- masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
615
-
616
- Returns:
617
- float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
618
- """
619
- counts = np.unique(masks, return_counts=True)[1][1:]
620
- return np.percentile(counts, 25) / np.percentile(counts, 75)
621
-
622
-
623
- def fill_holes_and_remove_small_masks(masks, min_size=15):
624
- """ Fills holes in masks (2D/3D) and discards masks smaller than min_size.
625
-
626
- This function fills holes in each mask using fill_voids.fill.
627
- It also removes masks that are smaller than the specified min_size.
628
-
629
- Parameters:
630
- masks (ndarray): Int, 2D or 3D array of labelled masks.
631
- 0 represents no mask, while positive integers represent mask labels.
632
- The size can be [Ly x Lx] or [Lz x Ly x Lx].
633
- min_size (int, optional): Minimum number of pixels per mask.
634
- Masks smaller than min_size will be removed.
635
- Set to -1 to turn off this functionality. Default is 15.
636
-
637
- Returns:
638
- ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
639
- 0 represents no mask, while positive integers represent mask labels.
640
- The size is [Ly x Lx] or [Lz x Ly x Lx].
641
- """
642
-
643
- if masks.ndim > 3 or masks.ndim < 2:
644
- raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
645
- masks.ndim)
646
-
647
- # Filter small masks
648
- if min_size > 0:
649
- counts = fastremap.unique(masks, return_counts=True)[1][1:]
650
- masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
651
- fastremap.renumber(masks, in_place=True)
652
-
653
- slices = find_objects(masks)
654
- j = 0
655
- for i, slc in enumerate(slices):
656
- if slc is not None:
657
- msk = masks[slc] == (i + 1)
658
- msk = fill_voids.fill(msk)
659
- masks[slc][msk] = (j + 1)
660
- j += 1
661
-
662
- if min_size > 0:
663
- counts = fastremap.unique(masks, return_counts=True)[1][1:]
664
- masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
665
- fastremap.renumber(masks, in_place=True)
666
-
667
- return masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/cellpose/version.py DELETED
@@ -1,18 +0,0 @@
1
- """
2
- Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
- """
4
- from importlib.metadata import PackageNotFoundError, version
5
- import sys
6
- from platform import python_version
7
- import torch
8
-
9
- try:
10
- version = version("cellpose")
11
- except PackageNotFoundError:
12
- version = "unknown"
13
-
14
- version_str = f"""
15
- cellpose version: \t{version}
16
- platform: \t{sys.platform}
17
- python version: \t{python_version()}
18
- torch version: \t{torch.__version__}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/{cellpose/core.py โ†’ core.py} RENAMED
@@ -109,39 +109,6 @@ def assign_device(use_torch=True, gpu=False, device=0):
109
  return device, gpu
110
 
111
 
112
- def _to_device(x, device, dtype=torch.float32):
113
- """
114
- Converts the input tensor or numpy array to the specified device.
115
-
116
- Args:
117
- x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
118
- device (torch.device): The target device.
119
-
120
- Returns:
121
- torch.Tensor: The converted tensor on the specified device.
122
- """
123
- if not isinstance(x, torch.Tensor):
124
- X = torch.from_numpy(x).to(device, dtype=dtype)
125
- return X
126
- else:
127
- return x
128
-
129
-
130
- def _from_device(X):
131
- """
132
- Converts a PyTorch tensor from the device to a NumPy array on the CPU.
133
-
134
- Args:
135
- X (torch.Tensor): The input PyTorch tensor.
136
-
137
- Returns:
138
- numpy.ndarray: The converted NumPy array.
139
- """
140
- # The cast is so numpy conversion always works
141
- x = X.detach().cpu().to(torch.float32).numpy()
142
- return x
143
-
144
-
145
  def _forward(net, x, feat=None):
146
  """Converts images to torch tensors, runs the network model, and returns numpy arrays.
147
 
@@ -152,15 +119,19 @@ def _forward(net, x, feat=None):
152
  Returns:
153
  Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
154
  """
155
- X = _to_device(x, device=net.device, dtype=net.dtype)
 
 
 
156
  if feat is not None:
157
- feat = _to_device(feat, device=net.device, dtype=net.dtype)
 
158
  net.eval()
159
  with torch.no_grad():
160
  y, style = net(X, feat=feat)[:2]
161
  del X
162
- y = _from_device(y)
163
- style = _from_device(style)
164
  return y, style
165
 
166
 
@@ -269,54 +240,3 @@ def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1,
269
  yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
270
  yf = yf.transpose(0,2,3,1)
271
  return yf, np.array(styles)
272
-
273
-
274
- def run_3D(net, imgs, batch_size=8, augment=False,
275
- tile_overlap=0.1, bsize=224, net_ortho=None,
276
- progress=None):
277
- """
278
- Run network on image z-stack.
279
-
280
- (faster if augment is False)
281
-
282
- Args:
283
- imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
284
- batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
285
- rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
286
- anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
287
- augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
288
- tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
289
- bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
290
- net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
291
- progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
292
-
293
- Returns:
294
- Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
295
- y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability.
296
- style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
297
- """
298
- sstr = ["YX", "ZY", "ZX"]
299
- pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
300
- ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
301
- cp = [(1, 2), (0, 2), (0, 1)]
302
- cpy = [(0, 1), (0, 1), (0, 1)]
303
- shape = imgs.shape[:-1]
304
- yf = np.zeros((*shape, 4), "float32")
305
- for p in range(3):
306
- xsl = imgs.transpose(pm[p])
307
- # per image
308
- core_logger.info("running %s: %d planes of size (%d, %d)" %
309
- (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
310
- y, style = run_net(net,
311
- xsl, batch_size=batch_size, augment=augment,
312
- bsize=bsize, tile_overlap=tile_overlap,
313
- rsz=None)
314
- yf[..., -1] += y[..., -1].transpose(ipm[p])
315
- for j in range(2):
316
- yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
317
- y = None; del y
318
-
319
- if progress is not None:
320
- progress.setValue(25 + 15 * p)
321
-
322
- return yf, style
 
109
  return device, gpu
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def _forward(net, x, feat=None):
113
  """Converts images to torch tensors, runs the network model, and returns numpy arrays.
114
 
 
119
  Returns:
120
  Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
121
  """
122
+ if not isinstance(x, torch.Tensor):
123
+ X = torch.from_numpy(x).to(net.device, dtype=net.dtype)
124
+ else:
125
+ X = x
126
  if feat is not None:
127
+ if not isinstance(feat, torch.Tensor):
128
+ feat = torch.from_numpy(feat).to(net.device, dtype=net.dtype)
129
  net.eval()
130
  with torch.no_grad():
131
  y, style = net(X, feat=feat)[:2]
132
  del X
133
+ y = y.detach().cpu().to(torch.float32).numpy()
134
+ style = style.detach().cpu().to(torch.float32).numpy()
135
  return y, style
136
 
137
 
 
240
  yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
241
  yf = yf.transpose(0,2,3,1)
242
  return yf, np.array(styles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/seg_post_model/{cellpose/dynamics.py โ†’ dynamics.py} RENAMED
@@ -151,126 +151,6 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
151
  mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
152
  return mu0
153
 
154
- def masks_to_flows_gpu_3d(masks, device=None, niter=None):
155
- """Convert masks to flows using diffusion from center pixel.
156
-
157
- Args:
158
- masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
159
- device (torch.device, optional): The device to run the computation on. Defaults to None.
160
- niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
161
-
162
- Returns:
163
- np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y.
164
-
165
- """
166
- if device is None:
167
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
168
-
169
- Lz0, Ly0, Lx0 = masks.shape
170
- Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
171
-
172
- masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
173
- masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))
174
-
175
- # get mask pixel neighbors
176
- z, y, x = torch.nonzero(masks_padded).T
177
- neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
178
- neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
179
- neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)
180
-
181
- neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)
182
-
183
- # get mask centers
184
- slices = find_objects(masks)
185
-
186
- centers = np.zeros((masks.max(), 3), "int")
187
- for i, si in enumerate(slices):
188
- if si is not None:
189
- sz, sy, sx = si
190
- zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
191
- zi = zi.astype(np.int32) + 1 # add padding
192
- yi = yi.astype(np.int32) + 1 # add padding
193
- xi = xi.astype(np.int32) + 1 # add padding
194
- zmed = np.mean(zi)
195
- ymed = np.mean(yi)
196
- xmed = np.mean(xi)
197
- imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
198
- zmed = zi[imin]
199
- ymed = yi[imin]
200
- xmed = xi[imin]
201
- centers[i, 0] = zmed + sz.start
202
- centers[i, 1] = ymed + sy.start
203
- centers[i, 2] = xmed + sx.start
204
-
205
- # get neighbor validator (not all neighbors are in same mask)
206
- neighbor_masks = masks_padded[tuple(neighbors)]
207
- isneighbor = neighbor_masks == neighbor_masks[0]
208
- ext = np.array(
209
- [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
210
- for sz, sy, sx in slices])
211
- n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter
212
-
213
- # run diffusion
214
- shape = masks_padded.shape
215
- mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
216
- device=device)
217
- # normalize
218
- mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
219
-
220
- # put into original image
221
- mu0 = np.zeros((3, Lz0, Ly0, Lx0))
222
- mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
223
- return mu0
224
-
225
- def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
226
- return_flows=True):
227
- """Converts labels (list of masks or flows) to flows for training model.
228
-
229
- Args:
230
- labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
231
- it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
232
- is used to create flows and cell probabilities.
233
- files (list of str, optional): The files to save the flows to. If provided, flows are saved to
234
- files to be reused. Defaults to None.
235
- device (str, optional): The device to use for computation. Defaults to None.
236
- redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
237
- niter (int, optional): The number of iterations for computing flows. Defaults to None.
238
-
239
- Returns:
240
- list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
241
- flows[k][1] is cell distance transform, flows[k][2] is Y flow, flows[k][3] is X flow,
242
- and flows[k][4] is heat distribution.
243
- """
244
- nimg = len(labels)
245
- if labels[0].ndim < 3:
246
- labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]
247
-
248
- flows = []
249
- # flows need to be recomputed
250
- if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
251
- dynamics_logger.info("computing flows for labels")
252
-
253
- # compute flows; labels are fixed here to be unique, so they need to be passed back
254
- # make sure labels are unique!
255
- labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
256
- iterator = trange if nimg > 1 else range
257
- for n in iterator(nimg):
258
- labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
259
- vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)
260
-
261
- # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
262
- flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
263
- axis=0).astype(np.float32)
264
- if files is not None:
265
- file_name = os.path.splitext(files[n])[0]
266
- tifffile.imwrite(file_name + "_flows.tif", flow)
267
- if return_flows:
268
- flows.append(flow)
269
- else:
270
- dynamics_logger.info("flows precomputed")
271
- if return_flows:
272
- flows = [labels[n].astype(np.float32) for n in range(nimg)]
273
- return flows
274
 
275
 
276
  def flow_error(maski, dP_net, device=None):
@@ -372,29 +252,6 @@ def steps_interp(dP, inds, niter, device=torch.device("cpu")):
372
  pt = pt.unsqueeze(0) if pt.ndim==1 else pt
373
  return pt.T
374
 
375
- def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
376
- """ Run dynamics to recover masks in 2D or 3D.
377
-
378
- Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability
379
- are used (as defined by inds).
380
-
381
- Args:
382
- dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
383
- mask (np.ndarray, optional): Pixel mask to seed masks. Useful when flows have low magnitudes.
384
- niter (int, optional): Number of iterations of dynamics to run. Default is 200.
385
- interp (bool, optional): Interpolate during 2D dynamics (not available in 3D). Default is True.
386
- device (torch.device, optional): Device to use for computation. Default is None.
387
-
388
- Returns:
389
- A tuple containing (p, inds): p (np.ndarray): Final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx];
390
- inds (np.ndarray): Indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
391
- """
392
- shape = np.array(dP.shape[1:]).astype(np.int32)
393
- ndim = len(inds)
394
-
395
- p = steps_interp(dP, inds, niter, device=device)
396
-
397
- return p
398
 
399
 
400
  def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
@@ -551,7 +408,6 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
551
  seed_masks[:,5,5,5] = 1
552
 
553
  for iter in range(5):
554
- # extend
555
  seed_masks = max_pool_nd(seed_masks, kernel_size=3)
556
  seed_masks *= h_slc > 2
557
  del h_slc
@@ -580,7 +436,6 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
580
  fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
581
  M0 = M0.reshape(tuple(shape0))
582
 
583
- #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
584
  return M0
585
 
586
 
@@ -652,7 +507,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
652
  mask = np.zeros(shape, "uint16")
653
  return mask
654
 
655
- p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5.,
656
  inds=inds, niter=niter,
657
  device=device)
658
  if not torch.is_tensor(p_final):
 
151
  mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
152
  return mu0
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
 
156
  def flow_error(maski, dP_net, device=None):
 
252
  pt = pt.unsqueeze(0) if pt.ndim==1 else pt
253
  return pt.T
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
 
257
  def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
 
408
  seed_masks[:,5,5,5] = 1
409
 
410
  for iter in range(5):
 
411
  seed_masks = max_pool_nd(seed_masks, kernel_size=3)
412
  seed_masks *= h_slc > 2
413
  del h_slc
 
436
  fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
437
  M0 = M0.reshape(tuple(shape0))
438
 
 
439
  return M0
440
 
441
 
 
507
  mask = np.zeros(shape, "uint16")
508
  return mask
509
 
510
+ p_final = steps_interp(dP * (cellprob > cellprob_threshold) / 5.,
511
  inds=inds, niter=niter,
512
  device=device)
513
  if not torch.is_tensor(p_final):
models/seg_post_model/io.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import cv2
7
+ import tifffile
8
+ import logging
9
+ from tqdm import tqdm
10
+ import re
11
+
12
+ try:
13
+ import nd2
14
+ ND2 = True
15
+ except:
16
+ ND2 = False
17
+
18
+ try:
19
+ import nrrd
20
+ NRRD = True
21
+ except:
22
+ NRRD = False
23
+
24
+
25
+ io_logger = logging.getLogger(__name__)
26
+
27
+ def load_dax(filename):
28
+ ### modified from ZhuangLab github:
29
+ ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
30
+
31
+ inf_filename = os.path.splitext(filename)[0] + ".inf"
32
+ if not os.path.exists(inf_filename):
33
+ io_logger.critical(
34
+ f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
35
+ )
36
+ return None
37
+
38
+ ### get metadata
39
+ image_height, image_width = None, None
40
+ # extract the movie information from the associated inf file
41
+ size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
42
+ length_re = re.compile(r"number of frames = ([\d]+)")
43
+ endian_re = re.compile(r" (big|little) endian")
44
+
45
+ with open(inf_filename, "r") as inf_file:
46
+ lines = inf_file.read().split("\n")
47
+ for line in lines:
48
+ m = size_re.match(line)
49
+ if m:
50
+ image_height = int(m.group(2))
51
+ image_width = int(m.group(1))
52
+ m = length_re.match(line)
53
+ if m:
54
+ number_frames = int(m.group(1))
55
+ m = endian_re.search(line)
56
+ if m:
57
+ if m.group(1) == "big":
58
+ bigendian = 1
59
+ else:
60
+ bigendian = 0
61
+ # set defaults, warn the user that they couldn"t be determined from the inf file.
62
+ if not image_height:
63
+ io_logger.warning("could not determine dax image size, assuming 256x256")
64
+ image_height = 256
65
+ image_width = 256
66
+
67
+ ### load image
68
+ img = np.memmap(filename, dtype="uint16",
69
+ shape=(number_frames, image_height, image_width))
70
+ if bigendian:
71
+ img = img.byteswap()
72
+ img = np.array(img)
73
+
74
+ return img
75
+
76
+
77
+ def imread(filename):
78
+ """
79
+ Read in an image file with tif or image file type supported by cv2.
80
+
81
+ Args:
82
+ filename (str): The path to the image file.
83
+
84
+ Returns:
85
+ numpy.ndarray: The image data as a NumPy array.
86
+
87
+ Raises:
88
+ None
89
+
90
+ Raises an error if the image file format is not supported.
91
+
92
+ Examples:
93
+ >>> img = imread("image.tif")
94
+ """
95
+ # ensure that extension check is not case sensitive
96
+ ext = os.path.splitext(filename)[-1].lower()
97
+ if ext == ".tif" or ext == ".tiff" or ext == ".flex":
98
+ with tifffile.TiffFile(filename) as tif:
99
+ ltif = len(tif.pages)
100
+ try:
101
+ full_shape = tif.shaped_metadata[0]["shape"]
102
+ except:
103
+ try:
104
+ page = tif.series[0][0]
105
+ full_shape = tif.series[0].shape
106
+ except:
107
+ ltif = 0
108
+ if ltif < 10:
109
+ img = tif.asarray()
110
+ else:
111
+ page = tif.series[0][0]
112
+ shape, dtype = page.shape, page.dtype
113
+ ltif = int(np.prod(full_shape) / np.prod(shape))
114
+ io_logger.info(f"reading tiff with {ltif} planes")
115
+ img = np.zeros((ltif, *shape), dtype=dtype)
116
+ for i, page in enumerate(tqdm(tif.series[0])):
117
+ img[i] = page.asarray()
118
+ img = img.reshape(full_shape)
119
+ return img
120
+ elif ext == ".dax":
121
+ img = load_dax(filename)
122
+ return img
123
+ elif ext == ".nd2":
124
+ if not ND2:
125
+ io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
126
+ return None
127
+ elif ext == ".nrrd":
128
+ if not NRRD:
129
+ io_logger.critical(
130
+ "ERROR: need to 'pip install pynrrd' to load in .nrrd file")
131
+ return None
132
+ else:
133
+ img, metadata = nrrd.read(filename)
134
+ if img.ndim == 3:
135
+ img = img.transpose(2, 0, 1)
136
+ return img
137
+ elif ext != ".npy":
138
+ try:
139
+ img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
140
+ if img.ndim > 2:
141
+ img = img[..., [2, 1, 0]]
142
+ return img
143
+ except Exception as e:
144
+ io_logger.critical("ERROR: could not read file, %s" % e)
145
+ return None
146
+ else:
147
+ try:
148
+ dat = np.load(filename, allow_pickle=True).item()
149
+ masks = dat["masks"]
150
+ return masks
151
+ except Exception as e:
152
+ io_logger.critical("ERROR: could not read masks from file, %s" % e)
153
+ return None
154
+
155
+
156
+ def imsave(filename, arr):
157
+ """
158
+ Saves an image array to a file.
159
+
160
+ Args:
161
+ filename (str): The name of the file to save the image to.
162
+ arr (numpy.ndarray): The image array to be saved.
163
+
164
+ Returns:
165
+ None
166
+ """
167
+ ext = os.path.splitext(filename)[-1].lower()
168
+ if ext == ".tif" or ext == ".tiff":
169
+ tifffile.imwrite(filename, data=arr, compression="zlib")
170
+ else:
171
+ if len(arr.shape) > 2:
172
+ arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
173
+ cv2.imwrite(filename, arr)
174
+
models/seg_post_model/{cellpose/metrics.py โ†’ metrics.py} RENAMED
@@ -2,83 +2,10 @@
2
  Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
  """
4
  import numpy as np
5
- from . import utils
6
  from scipy.optimize import linear_sum_assignment
7
- from scipy.ndimage import convolve
8
  from scipy.sparse import csr_matrix
9
 
10
 
11
- def mask_ious(masks_true, masks_pred):
12
- """Return best-matched masks."""
13
- iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
14
- n_min = min(iou.shape[0], iou.shape[1])
15
- costs = -(iou >= 0.5).astype(float) - iou / (2 * n_min)
16
- true_ind, pred_ind = linear_sum_assignment(costs)
17
- iout = np.zeros(masks_true.max())
18
- iout[true_ind] = iou[true_ind, pred_ind]
19
- preds = np.zeros(masks_true.max(), "int")
20
- preds[true_ind] = pred_ind + 1
21
- return iout, preds
22
-
23
-
24
- def boundary_scores(masks_true, masks_pred, scales):
25
- """
26
- Calculate boundary precision, recall, and F-score.
27
-
28
- Args:
29
- masks_true (list): List of true masks.
30
- masks_pred (list): List of predicted masks.
31
- scales (list): List of scales.
32
-
33
- Returns:
34
- tuple: A tuple containing precision, recall, and F-score arrays.
35
- """
36
- diams = [utils.diameters(lbl)[0] for lbl in masks_true]
37
- precision = np.zeros((len(scales), len(masks_true)))
38
- recall = np.zeros((len(scales), len(masks_true)))
39
- fscore = np.zeros((len(scales), len(masks_true)))
40
- for j, scale in enumerate(scales):
41
- for n in range(len(masks_true)):
42
- diam = max(1, scale * diams[n])
43
- rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
44
- filt = (rs <= diam).astype(np.float32)
45
- otrue = utils.masks_to_outlines(masks_true[n])
46
- otrue = convolve(otrue, filt)
47
- opred = utils.masks_to_outlines(masks_pred[n])
48
- opred = convolve(opred, filt)
49
- tp = np.logical_and(otrue == 1, opred == 1).sum()
50
- fp = np.logical_and(otrue == 0, opred == 1).sum()
51
- fn = np.logical_and(otrue == 1, opred == 0).sum()
52
- precision[j, n] = tp / (tp + fp)
53
- recall[j, n] = tp / (tp + fn)
54
- fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
55
- return precision, recall, fscore
56
-
57
-
58
- def aggregated_jaccard_index(masks_true, masks_pred):
59
- """
60
- AJI = intersection of all matched masks / union of all masks
61
-
62
- Args:
63
- masks_true (list of np.ndarrays (int) or np.ndarray (int)):
64
- where 0=NO masks; 1,2... are mask labels
65
- masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
66
- np.ndarray (int) where 0=NO masks; 1,2... are mask labels
67
-
68
- Returns:
69
- aji (float): aggregated jaccard index for each set of masks
70
- """
71
- aji = np.zeros(len(masks_true))
72
- for n in range(len(masks_true)):
73
- iout, preds = mask_ious(masks_true[n], masks_pred[n])
74
- inds = np.arange(0, masks_true[n].max(), 1, int)
75
- overlap = _label_overlap(masks_true[n], masks_pred[n])
76
- union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum()
77
- overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)]
78
- aji[n] = overlap.sum() / union
79
- return aji
80
-
81
-
82
  def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
83
  """
84
  Average precision estimation: AP = TP / (TP + FP + FN)
 
2
  Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
  """
4
  import numpy as np
 
5
  from scipy.optimize import linear_sum_assignment
 
6
  from scipy.sparse import csr_matrix
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
10
  """
11
  Average precision estimation: AP = TP / (TP + FP + FN)
models/seg_post_model/{cellpose/models.py โ†’ models.py} RENAMED
@@ -3,31 +3,29 @@ Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer,
3
  """
4
 
5
  import os, time
6
- from pathlib import Path
7
  import numpy as np
8
  from tqdm import trange
9
  import torch
10
  from scipy.ndimage import gaussian_filter
11
- import gc
12
  import cv2
13
 
14
  import logging
15
 
16
  models_logger = logging.getLogger(__name__)
17
 
18
- from . import transforms, dynamics, utils, plot
19
  from .vit_sam import Transformer
20
- from .core import assign_device, run_net, run_3D
21
 
22
- _CPSAM_MODEL_URL = "https://huggingface.co/mouseland/cellpose-sam/resolve/main/cpsam"
23
- _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
24
- # _MODEL_DIR_DEFAULT = Path.home().joinpath(".cellpose", "models")
25
- _MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
26
- MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
27
 
28
- MODEL_NAMES = ["cpsam"]
29
 
30
- MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
31
 
32
  normalize_default = {
33
  "lowhigh": None,
@@ -42,30 +40,17 @@ normalize_default = {
42
  }
43
 
44
 
45
- # def model_path(model_type, model_index=0):
46
- # return cache_CPSAM_model_path()
 
 
 
 
 
 
47
 
48
 
49
- # def cache_CPSAM_model_path():
50
- # MODEL_DIR.mkdir(parents=True, exist_ok=True)
51
- # cached_file = os.fspath(MODEL_DIR.joinpath('cpsam'))
52
- # if not os.path.exists(cached_file):
53
- # models_logger.info('Downloading: "{}" to {}\n'.format(_CPSAM_MODEL_URL, cached_file))
54
- # utils.download_url_to_file(_CPSAM_MODEL_URL, cached_file, progress=True)
55
- # return cached_file
56
-
57
-
58
- def get_user_models():
59
- model_strings = []
60
- if os.path.exists(MODEL_LIST_PATH):
61
- with open(MODEL_LIST_PATH, "r") as textfile:
62
- lines = [line.rstrip() for line in textfile]
63
- if len(lines) > 0:
64
- model_strings.extend(lines)
65
- return model_strings
66
-
67
-
68
- class CellposeModel():
69
  """
70
  Class representing a Cellpose model.
71
 
@@ -102,17 +87,6 @@ class CellposeModel():
102
  device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
103
  use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
104
  """
105
- # if diam_mean is not None:
106
- # models_logger.warning(
107
- # "diam_mean argument are not used in v4.0.1+. Ignoring this argument..."
108
- # )
109
- # if model_type is not None:
110
- # models_logger.warning(
111
- # "model_type argument is not used in v4.0.1+. Ignoring this argument..."
112
- # )
113
- # if nchan is not None:
114
- # models_logger.warning("nchan argument is deprecated in v4.0.1+. Ignoring this argument")
115
-
116
  ### assign model device
117
  self.device = assign_device(gpu=gpu)[0] if device is None else device
118
  if torch.cuda.is_available():
@@ -127,36 +101,10 @@ class CellposeModel():
127
  # raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
128
  pretrained_model = ""
129
 
130
- ### create neural network
131
- if pretrained_model and not os.path.exists(pretrained_model):
132
- # check if pretrained model is in the models directory
133
- model_strings = get_user_models()
134
- all_models = MODEL_NAMES.copy()
135
- all_models.extend(model_strings)
136
- if pretrained_model in all_models:
137
- pretrained_model = os.path.join(MODEL_DIR, pretrained_model)
138
- else:
139
- pretrained_model = os.path.join(MODEL_DIR, "cpsam")
140
- models_logger.warning(
141
- f"pretrained model {pretrained_model} not found, using default model"
142
- )
143
-
144
  self.pretrained_model = pretrained_model
145
  dtype = torch.bfloat16 if use_bfloat16 else torch.float32
146
  self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)
147
 
148
- if os.path.exists(self.pretrained_model):
149
- models_logger.info(f">>>> loading model {self.pretrained_model}")
150
- self.net.load_model(self.pretrained_model, device=self.device)
151
- # else:
152
- # try:
153
- # if os.path.split(self.pretrained_model)[-1] != 'cpsam':
154
- # raise FileNotFoundError('model file not recognized')
155
- # cache_CPSAM_model_path()
156
- # self.net.load_model(self.pretrained_model, device=self.device)
157
- # except:
158
- # print("ViT not initialized")
159
-
160
 
161
  def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
162
  z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
@@ -166,12 +114,6 @@ class CellposeModel():
166
  augment=False, tile_overlap=0.1, bsize=256,
167
  compute_masks=True, progress=None):
168
 
169
-
170
- # if rescale is not None:
171
- # models_logger.warning("rescaling deprecated in v4.0.1+")
172
- # if channels is not None:
173
- # models_logger.warning("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used")
174
-
175
  if isinstance(x, list) or x.squeeze().ndim == 5:
176
  self.timing = []
177
  masks, styles, flows = [], [], []
@@ -230,7 +172,7 @@ class CellposeModel():
230
  Ly_0 = x.shape[1]
231
  Lx_0 = x.shape[2]
232
  Lz_0 = None
233
- if do_3D or stitch_threshold > 0:
234
  Lz_0 = x.shape[0]
235
  if diameter is not None:
236
  image_scaling = 30. / diameter
@@ -273,7 +215,7 @@ class CellposeModel():
273
  # transpose feat to have channels last
274
  feat = np.moveaxis(feat, 1, -1)
275
 
276
- # ajust the anisotropy when diameter is specified and images are resized:
277
  if isinstance(anisotropy, (float, int)) and image_scaling:
278
  anisotropy = image_scaling * anisotropy
279
 
@@ -287,12 +229,6 @@ class CellposeModel():
287
  do_3D=do_3D,
288
  anisotropy=anisotropy)
289
 
290
- if do_3D:
291
- if flow3D_smooth > 0:
292
- models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
293
- dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
294
- torch.cuda.empty_cache()
295
- gc.collect()
296
 
297
  if resample:
298
  # upsample flows before computing them:
@@ -310,28 +246,16 @@ class CellposeModel():
310
  else:
311
  masks = np.zeros(0) #pass back zeros if not compute_masks
312
 
313
- masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()
314
 
315
  # undo resizing:
316
  if image_scaling is not None or anisotropy is not None:
317
 
318
- dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0) # works for 2 or 3D:
319
- cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
320
-
321
- if do_3D:
322
- if compute_masks:
323
- # Rescale xy then xz:
324
- masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
325
- masks = masks.transpose(1, 0, 2)
326
- masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
327
- masks = masks.transpose(1, 0, 2)
328
 
329
- else:
330
- # 2D or 3D stitching case:
331
- if compute_masks:
332
- masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
333
 
334
- return masks, [plot.dx_to_circ(dP), dP, cellprob], styles
335
 
336
 
337
  def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
@@ -428,29 +352,15 @@ class CellposeModel():
428
  nimg = shape[0]
429
 
430
 
431
- if do_3D:
432
- Lz, Ly, Lx = shape[:-1]
433
- if anisotropy is not None and anisotropy != 1.0:
434
- models_logger.info(f"resizing 3D image with anisotropy={anisotropy}")
435
- x = transforms.resize_image(x.transpose(1,0,2,3),
436
- Ly=int(Lz*anisotropy),
437
- Lx=int(Lx)).transpose(1,0,2,3)
438
- yf, styles = run_3D(self.net, x,
439
- batch_size=batch_size, augment=augment,
440
- tile_overlap=tile_overlap,
441
- bsize=bsize
442
- )
443
- cellprob = yf[..., -1]
444
- dP = yf[..., :-1].transpose((3, 0, 1, 2))
445
- else:
446
- yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
447
- batch_size=batch_size,
448
- tile_overlap=tile_overlap,
449
- )
450
- cellprob = yf[..., -1]
451
- dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
452
- if yf.shape[-1] > 3:
453
- styles = yf[..., :-3]
454
 
455
  styles = styles.squeeze()
456
 
@@ -471,48 +381,48 @@ class CellposeModel():
471
  changed_device_from = "mps"
472
  Lz, Ly, Lx = shape[:3]
473
  tic = time.time()
474
- if do_3D:
475
- masks = dynamics.resize_and_compute_masks(
476
- dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
477
- flow_threshold=flow_threshold, do_3D=do_3D,
478
- min_size=min_size, max_size_fraction=max_size_fraction,
479
- resize=shape[:3] if (np.array(dP.shape[-3:])!=np.array(shape[:3])).sum()
480
- else None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  device=self.device)
482
- else:
483
- nimg = shape[0]
484
- Ly0, Lx0 = cellprob[0].shape
485
- resize = None if Ly0==Ly and Lx0==Lx else [Ly, Lx]
486
- tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
487
- iterator = trange(nimg, file=tqdm_out,
488
- mininterval=30) if nimg > 1 else range(nimg)
489
- for i in iterator:
490
- # turn off min_size for 3D stitching
491
- min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
492
- outputs = dynamics.resize_and_compute_masks(
493
- dP[:, i], cellprob[i],
494
- niter=niter, cellprob_threshold=cellprob_threshold,
495
- flow_threshold=flow_threshold, resize=resize,
496
- min_size=min_size0, max_size_fraction=max_size_fraction,
497
- device=self.device)
498
- if i==0 and nimg > 1:
499
- masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
500
- if nimg > 1:
501
- masks[i] = outputs
502
- else:
503
- masks = outputs
504
-
505
- if stitch_threshold > 0 and nimg > 1:
506
- models_logger.info(
507
- f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
508
- )
509
- masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
510
- masks = utils.fill_holes_and_remove_small_masks(
511
- masks, min_size=min_size)
512
- elif nimg > 1:
513
- models_logger.warning(
514
- "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
515
- )
516
 
517
  flow_time = time.time() - tic
518
  if shape[0] > 1:
 
3
  """
4
 
5
  import os, time
6
+ # from pathlib import Path
7
  import numpy as np
8
  from tqdm import trange
9
  import torch
10
  from scipy.ndimage import gaussian_filter
11
+ # import gc
12
  import cv2
13
 
14
  import logging
15
 
16
  models_logger = logging.getLogger(__name__)
17
 
18
+ from . import transforms, dynamics, utils
19
  from .vit_sam import Transformer
20
+ from .core import assign_device, run_net
21
 
22
+ # _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
23
+ # _MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
24
+ # MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
 
 
25
 
26
+ # MODEL_NAMES = ["cpsam"]
27
 
28
+ # MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
29
 
30
  normalize_default = {
31
  "lowhigh": None,
 
40
  }
41
 
42
 
43
+ # def get_user_models():
44
+ # model_strings = []
45
+ # if os.path.exists(MODEL_LIST_PATH):
46
+ # with open(MODEL_LIST_PATH, "r") as textfile:
47
+ # lines = [line.rstrip() for line in textfile]
48
+ # if len(lines) > 0:
49
+ # model_strings.extend(lines)
50
+ # return model_strings
51
 
52
 
53
+ class SegModel():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  """
55
  Class representing a Cellpose model.
56
 
 
87
  device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
88
  use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
89
  """
 
 
 
 
 
 
 
 
 
 
 
90
  ### assign model device
91
  self.device = assign_device(gpu=gpu)[0] if device is None else device
92
  if torch.cuda.is_available():
 
101
  # raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
102
  pretrained_model = ""
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  self.pretrained_model = pretrained_model
105
  dtype = torch.bfloat16 if use_bfloat16 else torch.float32
106
  self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
110
  z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
 
114
  augment=False, tile_overlap=0.1, bsize=256,
115
  compute_masks=True, progress=None):
116
 
 
 
 
 
 
 
117
  if isinstance(x, list) or x.squeeze().ndim == 5:
118
  self.timing = []
119
  masks, styles, flows = [], [], []
 
172
  Ly_0 = x.shape[1]
173
  Lx_0 = x.shape[2]
174
  Lz_0 = None
175
+ if stitch_threshold > 0:
176
  Lz_0 = x.shape[0]
177
  if diameter is not None:
178
  image_scaling = 30. / diameter
 
215
  # transpose feat to have channels last
216
  feat = np.moveaxis(feat, 1, -1)
217
 
218
+ # adjust the anisotropy when diameter is specified and images are resized:
219
  if isinstance(anisotropy, (float, int)) and image_scaling:
220
  anisotropy = image_scaling * anisotropy
221
 
 
229
  do_3D=do_3D,
230
  anisotropy=anisotropy)
231
 
 
 
 
 
 
 
232
 
233
  if resample:
234
  # upsample flows before computing them:
 
246
  else:
247
  masks = np.zeros(0) #pass back zeros if not compute_masks
248
 
249
+ masks = masks.squeeze()
250
 
251
  # undo resizing:
252
  if image_scaling is not None or anisotropy is not None:
253
 
254
+ if compute_masks:
255
+ masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
 
 
 
 
 
 
 
 
256
 
257
+ return masks
 
 
 
258
 
 
259
 
260
 
261
  def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
 
352
  nimg = shape[0]
353
 
354
 
355
+
356
+ yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
357
+ batch_size=batch_size,
358
+ tile_overlap=tile_overlap,
359
+ )
360
+ cellprob = yf[..., -1]
361
+ dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
362
+ if yf.shape[-1] > 3:
363
+ styles = yf[..., :-3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  styles = styles.squeeze()
366
 
 
381
  changed_device_from = "mps"
382
  Lz, Ly, Lx = shape[:3]
383
  tic = time.time()
384
+ # if do_3D:
385
+ # masks = dynamics.resize_and_compute_masks(
386
+ # dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
387
+ # flow_threshold=flow_threshold, do_3D=do_3D,
388
+ # min_size=min_size, max_size_fraction=max_size_fraction,
389
+ # resize=shape[:3] if (np.array(dP.shape[-3:])!=np.array(shape[:3])).sum()
390
+ # else None,
391
+ # device=self.device)
392
+ # else:
393
+ nimg = shape[0]
394
+ Ly0, Lx0 = cellprob[0].shape
395
+ resize = None if Ly0==Ly and Lx0==Lx else [Ly, Lx]
396
+ tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
397
+ iterator = trange(nimg, file=tqdm_out,
398
+ mininterval=30) if nimg > 1 else range(nimg)
399
+ for i in iterator:
400
+ # turn off min_size for 3D stitching
401
+ min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
402
+ outputs = dynamics.resize_and_compute_masks(
403
+ dP[:, i], cellprob[i],
404
+ niter=niter, cellprob_threshold=cellprob_threshold,
405
+ flow_threshold=flow_threshold, resize=resize,
406
+ min_size=min_size0, max_size_fraction=max_size_fraction,
407
  device=self.device)
408
+ if i==0 and nimg > 1:
409
+ masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
410
+ if nimg > 1:
411
+ masks[i] = outputs
412
+ else:
413
+ masks = outputs
414
+
415
+ if stitch_threshold > 0 and nimg > 1:
416
+ models_logger.info(
417
+ f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
418
+ )
419
+ masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
420
+ masks = utils.fill_holes_and_remove_small_masks(
421
+ masks, min_size=min_size)
422
+ elif nimg > 1:
423
+ models_logger.warning(
424
+ "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
425
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
  flow_time = time.time() - tic
428
  if shape[0] > 1:
models/seg_post_model/{cellpose/transforms.py โ†’ transforms.py} RENAMED
@@ -398,145 +398,6 @@ def move_axis(img, m_axis=-1, first=True):
398
  return img
399
 
400
 
401
- def move_min_dim(img, force=False):
402
- """Move the minimum dimension last as channels if it is less than 10 or force is True.
403
-
404
- Args:
405
- img (ndarray): The input image.
406
- force (bool, optional): If True, the minimum dimension will always be moved.
407
- Defaults to False.
408
-
409
- Returns:
410
- ndarray: The image with the minimum dimension moved to the last axis as channels.
411
- """
412
- if len(img.shape) > 2:
413
- min_dim = min(img.shape)
414
- if min_dim < 10 or force:
415
- if img.shape[-1] == min_dim:
416
- channel_axis = -1
417
- else:
418
- channel_axis = (img.shape).index(min_dim)
419
- img = move_axis(img, m_axis=channel_axis, first=False)
420
- return img
421
-
422
-
423
- def update_axis(m_axis, to_squeeze, ndim):
424
- """
425
- Squeeze the axis value based on the given parameters.
426
-
427
- Args:
428
- m_axis (int): The current axis value.
429
- to_squeeze (numpy.ndarray): An array of indices to squeeze.
430
- ndim (int): The number of dimensions.
431
-
432
- Returns:
433
- m_axis (int or None): The updated axis value.
434
- """
435
- if m_axis == -1:
436
- m_axis = ndim - 1
437
- if (to_squeeze == m_axis).sum() == 1:
438
- m_axis = None
439
- else:
440
- inds = np.ones(ndim, bool)
441
- inds[to_squeeze] = False
442
- m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0]
443
- if len(m_axis) > 0:
444
- m_axis = m_axis[0]
445
- else:
446
- m_axis = None
447
- return m_axis
448
-
449
-
450
- def _convert_image_3d(x, channel_axis=None, z_axis=None):
451
- """
452
- Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C).
453
-
454
- Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis.
455
- Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified.
456
-
457
- Args:
458
- x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D.
459
- channel_axis (int): The axis index corresponding to the channel dimension in the input array. \
460
- Must be specified for 4D images.
461
- z_axis (int): The axis index corresponding to the depth (Z) dimension in the input array. \
462
- Must be specified for both 3D and 4D images.
463
-
464
- Returns:
465
- numpy.ndarray: A 4D image array with dimensions ordered as (Z, X, Y, C), where C is the channel
466
- dimension. If the input has fewer than 3 channels, the output will be padded with zeros to \
467
- have 3 channels. If the input has more than 3 channels, only the first 3 channels will be retained.
468
-
469
- Raises:
470
- ValueError: If `z_axis` is not specified for 3D images. If either `channel_axis` or `z_axis` \
471
- is not specified for 4D images. If the input image does not have 3 or 4 dimensions.
472
-
473
- Notes:
474
- - For 3D images (ndim=3), the function assumes the input is grayscale and adds a singleton channel dimension.
475
- - The function reorders the dimensions of the input array to ensure the output has the desired (Z, X, Y, C) order.
476
- - If the number of channels is not equal to 3, the function either truncates or pads the \
477
- channels to ensure the output has exactly 3 channels.
478
- """
479
-
480
- if x.ndim < 3:
481
- raise ValueError(f"Input image must have at least 3 dimensions, input shape: {x.shape}, ndim={x.ndim}")
482
-
483
- if z_axis is not None and z_axis < 0:
484
- z_axis += x.ndim
485
-
486
- # if image is ndim==3, assume it is greyscale 3D and use provided z_axis
487
- if x.ndim == 3 and z_axis is not None:
488
- # add in channel axis
489
- x = x[..., np.newaxis]
490
- channel_axis = 3
491
- elif x.ndim == 3 and z_axis is None:
492
- raise ValueError("z_axis must be specified when segmenting 3D images of ndim=3")
493
-
494
-
495
- if channel_axis is None or z_axis is None:
496
- raise ValueError("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters.")
497
- if channel_axis is not None and channel_axis < 0:
498
- channel_axis += x.ndim
499
- if channel_axis is None or channel_axis >= x.ndim:
500
- raise IndexError(f"channel_axis {channel_axis} is out of bounds for input array with {x.ndim} dimensions")
501
- assert x.ndim == 4, f"input image must have ndim == 4, ndim={x.ndim}"
502
-
503
- x_dim_shapes = list(x.shape)
504
- num_z_layers = x_dim_shapes[z_axis]
505
- num_channels = x_dim_shapes[channel_axis]
506
- x_xy_axes = [i for i in range(x.ndim)]
507
-
508
- # need to remove the z and channels from the shapes:
509
- # delete the one with the bigger index first
510
- if z_axis > channel_axis:
511
- del x_dim_shapes[z_axis]
512
- del x_dim_shapes[channel_axis]
513
-
514
- del x_xy_axes[z_axis]
515
- del x_xy_axes[channel_axis]
516
-
517
- else:
518
- del x_dim_shapes[channel_axis]
519
- del x_dim_shapes[z_axis]
520
-
521
- del x_xy_axes[channel_axis]
522
- del x_xy_axes[z_axis]
523
-
524
- x = x.transpose((z_axis, x_xy_axes[0], x_xy_axes[1], channel_axis))
525
-
526
- # Handle cases with not 3 channels:
527
- if num_channels != 3:
528
- x_chans_to_copy = min(3, num_channels)
529
-
530
- if num_channels > 3:
531
- transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
532
- x = x[..., :x_chans_to_copy]
533
- else:
534
- # less than 3 channels: pad up to
535
- pad_width = [(0, 0), (0, 0), (0, 0), (0, 3 - x_chans_to_copy)]
536
- x = np.pad(x, pad_width, mode='constant', constant_values=0)
537
-
538
- return x
539
-
540
 
541
  def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
542
  """Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
@@ -571,13 +432,9 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
571
  if z_axis is not None and not do_3D:
572
  raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")
573
 
574
- # make sure that channel_axis and z_axis are specified if 3D
575
  if ndim == 4 and not do_3D:
576
  raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
577
 
578
- # make sure that channel_axis and z_axis are specified if 3D
579
- if do_3D:
580
- return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)
581
 
582
  ######################## 2D reshaping ########################
583
  # if user specifies channel axis, return early
@@ -956,10 +813,10 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal
956
  nchan = X[0].shape[0]
957
  else:
958
  nchan = 1
959
- if do_3D and X[0].ndim > 3:
960
- shape = (zcrop, xy[0], xy[1])
961
- else:
962
- shape = (xy[0], xy[1])
963
  imgi = np.zeros((nimg, nchan, *shape), "float32")
964
 
965
  lbl = []
@@ -1008,15 +865,6 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal
1008
  if labels.ndim < 3:
1009
  labels = labels[np.newaxis, :, :]
1010
 
1011
- if do_3D:
1012
- Lz = X[n].shape[-3]
1013
- flip_z = np.random.rand() > .5
1014
- lz = int(np.round(zcrop / scale[n]))
1015
- iz = np.random.randint(0, Lz - lz)
1016
- img = img[:,iz:iz + lz,:,:]
1017
- if Y is not None:
1018
- labels = labels[:,iz:iz + lz,:,:]
1019
-
1020
  if do_flip:
1021
  if flip:
1022
  img = img[..., ::-1]
@@ -1024,47 +872,15 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal
1024
  labels = labels[..., ::-1]
1025
  if nt > 1 and not unet:
1026
  labels[-1] = -labels[-1]
1027
- if do_3D and flip_z:
1028
- img = img[:, ::-1]
1029
- if Y is not None:
1030
- labels = labels[:,::-1]
1031
- if nt > 1 and not unet:
1032
- labels[-3] = -labels[-3]
1033
 
1034
  for k in range(nchan):
1035
- if do_3D:
1036
- img0 = np.zeros((lz, xy[0], xy[1]), "float32")
1037
- for z in range(lz):
1038
- I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
1039
- flags=cv2.INTER_LINEAR)
1040
- img0[z] = I
1041
- if scale[n] != 1.0:
1042
- for y in range(imgi.shape[-2]):
1043
- imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
1044
- interpolation=cv2.INTER_LINEAR)
1045
- else:
1046
- imgi[n, k] = img0
1047
- else:
1048
- I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
1049
- imgi[n, k] = I
1050
 
1051
  if Y is not None:
1052
  for k in range(nt):
1053
  flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
1054
- if do_3D:
1055
- lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
1056
- for z in range(lz):
1057
- I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
1058
- flags=flag)
1059
- lbl0[z] = I
1060
- if scale[n] != 1.0:
1061
- for y in range(lbl.shape[-2]):
1062
- lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
1063
- interpolation=flag)
1064
- else:
1065
- lbl[n, k] = lbl0
1066
- else:
1067
- lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
1068
 
1069
  if nt > 1 and not unet:
1070
  v1 = lbl[n, -1].copy()
@@ -1106,10 +922,7 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
1106
  nchan = X[0].shape[0]
1107
  else:
1108
  nchan = 1
1109
- if do_3D and X[0].ndim > 3:
1110
- shape = (zcrop, xy[0], xy[1])
1111
- else:
1112
- shape = (xy[0], xy[1])
1113
  imgi = np.zeros((nimg, nchan, *shape), "float32")
1114
 
1115
  lbl = []
@@ -1169,17 +982,6 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
1169
  if feats.ndim < 3:
1170
  feats = feats[np.newaxis, :, :]
1171
 
1172
- if do_3D:
1173
- Lz = X[n].shape[-3]
1174
- flip_z = np.random.rand() > .5
1175
- lz = int(np.round(zcrop / scale[n]))
1176
- iz = np.random.randint(0, Lz - lz)
1177
- img = img[:,iz:iz + lz,:,:]
1178
- if Y is not None:
1179
- labels = labels[:,iz:iz + lz,:,:]
1180
- if feat is not None:
1181
- feats = feats[:,iz:iz + lz,:,:]
1182
-
1183
  if do_flip:
1184
  if flip:
1185
  img = img[..., ::-1]
@@ -1189,49 +991,16 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
1189
  labels[-1] = -labels[-1]
1190
  if feat is not None:
1191
  feats = feats[..., ::-1]
1192
- if do_3D and flip_z:
1193
- img = img[:, ::-1]
1194
- if Y is not None:
1195
- labels = labels[:,::-1]
1196
- if nt > 1 and not unet:
1197
- labels[-3] = -labels[-3]
1198
- if feat is not None:
1199
- feats = feats[:, ::-1]
1200
 
1201
  for k in range(nchan):
1202
- if do_3D:
1203
- img0 = np.zeros((lz, xy[0], xy[1]), "float32")
1204
- for z in range(lz):
1205
- I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
1206
- flags=cv2.INTER_LINEAR)
1207
- img0[z] = I
1208
- if scale[n] != 1.0:
1209
- for y in range(imgi.shape[-2]):
1210
- imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
1211
- interpolation=cv2.INTER_LINEAR)
1212
- else:
1213
- imgi[n, k] = img0
1214
- else:
1215
- I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
1216
- imgi[n, k] = I
1217
 
1218
  if Y is not None:
1219
  for k in range(nt):
1220
  flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
1221
- if do_3D:
1222
- lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
1223
- for z in range(lz):
1224
- I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
1225
- flags=flag)
1226
- lbl0[z] = I
1227
- if scale[n] != 1.0:
1228
- for y in range(lbl.shape[-2]):
1229
- lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
1230
- interpolation=flag)
1231
- else:
1232
- lbl[n, k] = lbl0
1233
- else:
1234
- lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
1235
 
1236
  if nt > 1 and not unet:
1237
  v1 = lbl[n, -1].copy()
@@ -1241,20 +1010,7 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
1241
 
1242
  if feat is not None:
1243
  for k in range(nf):
1244
- if do_3D:
1245
- feat0 = np.zeros((lz, xy[0], xy[1]), "float32")
1246
- for z in range(lz):
1247
- I = cv2.warpAffine(feats[k, z], M, (xy[1], xy[0]),
1248
- flags=cv2.INTER_LINEAR)
1249
- feat0[z] = I
1250
- if scale[n] != 1.0:
1251
- for y in range(feat_out.shape[-2]):
1252
- feat_out[n, k, :, y] = cv2.resize(feat0[:, y], (xy[1], zcrop),
1253
- interpolation=cv2.INTER_LINEAR)
1254
- else:
1255
- feat_out[n, k] = feat0
1256
- else:
1257
- feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
1258
 
1259
 
1260
 
 
398
  return img
399
 
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
  def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
403
  """Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
 
432
  if z_axis is not None and not do_3D:
433
  raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")
434
 
 
435
  if ndim == 4 and not do_3D:
436
  raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
437
 
 
 
 
438
 
439
  ######################## 2D reshaping ########################
440
  # if user specifies channel axis, return early
 
813
  nchan = X[0].shape[0]
814
  else:
815
  nchan = 1
816
+ # if do_3D and X[0].ndim > 3:
817
+ # shape = (zcrop, xy[0], xy[1])
818
+ # else:
819
+ shape = (xy[0], xy[1])
820
  imgi = np.zeros((nimg, nchan, *shape), "float32")
821
 
822
  lbl = []
 
865
  if labels.ndim < 3:
866
  labels = labels[np.newaxis, :, :]
867
 
 
 
 
 
 
 
 
 
 
868
  if do_flip:
869
  if flip:
870
  img = img[..., ::-1]
 
872
  labels = labels[..., ::-1]
873
  if nt > 1 and not unet:
874
  labels[-1] = -labels[-1]
 
 
 
 
 
 
875
 
876
  for k in range(nchan):
877
+ I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
878
+ imgi[n, k] = I
 
 
 
 
 
 
 
 
 
 
 
 
 
879
 
880
  if Y is not None:
881
  for k in range(nt):
882
  flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
883
+ lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
 
 
 
 
 
 
 
 
 
 
 
 
 
884
 
885
  if nt > 1 and not unet:
886
  v1 = lbl[n, -1].copy()
 
922
  nchan = X[0].shape[0]
923
  else:
924
  nchan = 1
925
+ shape = (xy[0], xy[1])
 
 
 
926
  imgi = np.zeros((nimg, nchan, *shape), "float32")
927
 
928
  lbl = []
 
982
  if feats.ndim < 3:
983
  feats = feats[np.newaxis, :, :]
984
 
 
 
 
 
 
 
 
 
 
 
 
985
  if do_flip:
986
  if flip:
987
  img = img[..., ::-1]
 
991
  labels[-1] = -labels[-1]
992
  if feat is not None:
993
  feats = feats[..., ::-1]
994
+
 
 
 
 
 
 
 
995
 
996
  for k in range(nchan):
997
+ I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
998
+ imgi[n, k] = I
 
 
 
 
 
 
 
 
 
 
 
 
 
999
 
1000
  if Y is not None:
1001
  for k in range(nt):
1002
  flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
1003
+ lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
 
 
 
 
 
 
 
 
 
 
 
 
 
1004
 
1005
  if nt > 1 and not unet:
1006
  v1 = lbl[n, -1].copy()
 
1010
 
1011
  if feat is not None:
1012
  for k in range(nf):
1013
+ feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
 
 
 
 
 
 
 
 
 
 
 
 
 
1014
 
1015
 
1016
 
models/seg_post_model/utils.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import logging
5
+ import io
6
+ from tqdm import tqdm, trange
7
+ import cv2
8
+ from scipy.ndimage import find_objects
9
+ import numpy as np
10
+ import fastremap
11
+ import fill_voids
12
+ from models.seg_post_model import metrics
13
+
14
+
15
+ class TqdmToLogger(io.StringIO):
16
+ """
17
+ Output stream for TQDM which will output to logger module instead of
18
+ the StdOut.
19
+ """
20
+ logger = None
21
+ level = None
22
+ buf = ""
23
+
24
+ def __init__(self, logger, level=None):
25
+ super(TqdmToLogger, self).__init__()
26
+ self.logger = logger
27
+ self.level = level or logging.INFO
28
+
29
+ def write(self, buf):
30
+ self.buf = buf.strip("\r\n\t ")
31
+
32
+ def flush(self):
33
+ self.logger.log(self.level, self.buf)
34
+
35
+
36
+
37
+ # def masks_to_outlines(masks):
38
+ # """Get outlines of masks as a 0-1 array.
39
+
40
+ # Args:
41
+ # masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
42
+
43
+ # Returns:
44
+ # outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
45
+ # """
46
+ # if masks.ndim > 3 or masks.ndim < 2:
47
+ # raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
48
+ # masks.ndim)
49
+ # outlines = np.zeros(masks.shape, bool)
50
+
51
+ # if masks.ndim == 3:
52
+ # for i in range(masks.shape[0]):
53
+ # outlines[i] = masks_to_outlines(masks[i])
54
+ # return outlines
55
+ # else:
56
+ # slices = find_objects(masks.astype(int))
57
+ # for i, si in enumerate(slices):
58
+ # if si is not None:
59
+ # sr, sc = si
60
+ # mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
61
+ # contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
62
+ # cv2.CHAIN_APPROX_NONE)
63
+ # pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
64
+ # vr, vc = pvr + sr.start, pvc + sc.start
65
+ # outlines[vr, vc] = 1
66
+ # return outlines
67
+
68
+
69
+ def stitch3D(masks, stitch_threshold=0.25):
70
+ """
71
+ Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.
72
+
73
+ Args:
74
+ masks (list or ndarray): List of 2D masks.
75
+ stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.
76
+
77
+ Returns:
78
+ list: List of stitched 3D masks.
79
+ """
80
+ mmax = masks[0].max()
81
+ empty = 0
82
+ for i in trange(len(masks) - 1):
83
+ iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
84
+ if not iou.size and empty == 0:
85
+ masks[i + 1] = masks[i + 1]
86
+ mmax = masks[i + 1].max()
87
+ elif not iou.size and not empty == 0:
88
+ icount = masks[i + 1].max()
89
+ istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
90
+ mmax += icount
91
+ istitch = np.append(np.array(0), istitch)
92
+ masks[i + 1] = istitch[masks[i + 1]]
93
+ else:
94
+ iou[iou < stitch_threshold] = 0.0
95
+ iou[iou < iou.max(axis=0)] = 0.0
96
+ istitch = iou.argmax(axis=1) + 1
97
+ ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
98
+ istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
99
+ mmax += len(ino)
100
+ istitch = np.append(np.array(0), istitch)
101
+ masks[i + 1] = istitch[masks[i + 1]]
102
+ empty = 1
103
+
104
+ return masks
105
+
106
+
107
+ # def diameters(masks):
108
+ # """
109
+ # Calculate the diameters of the objects in the given masks.
110
+
111
+ # Parameters:
112
+ # masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
113
+
114
+ # Returns:
115
+ # tuple: A tuple containing the median diameter and an array of diameters for each object.
116
+
117
+ # Examples:
118
+ # >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])
119
+ # >>> diameters(masks)
120
+ # (1.0, array([1.41421356, 1.0, 1.0]))
121
+ # """
122
+ # uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
123
+ # counts = counts[1:]
124
+ # md = np.median(counts**0.5)
125
+ # if np.isnan(md):
126
+ # md = 0
127
+ # md /= (np.pi**0.5) / 2
128
+ # return md, counts**0.5
129
+
130
+
131
+ # def radius_distribution(masks, bins):
132
+ # """
133
+ # Calculate the radius distribution of masks.
134
+
135
+ # Args:
136
+ # masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
137
+ # bins (int): Number of bins for the histogram.
138
+
139
+ # Returns:
140
+ # A tuple containing a normalized histogram of radii, median radius, array of radii.
141
+
142
+ # """
143
+ # unique, counts = np.unique(masks, return_counts=True)
144
+ # counts = counts[unique != 0]
145
+ # nb, _ = np.histogram((counts**0.5) * 0.5, bins)
146
+ # nb = nb.astype(np.float32)
147
+ # if nb.sum() > 0:
148
+ # nb = nb / nb.sum()
149
+ # md = np.median(counts**0.5) * 0.5
150
+ # if np.isnan(md):
151
+ # md = 0
152
+ # md /= (np.pi**0.5) / 2
153
+ # return nb, md, (counts**0.5) / 2
154
+
155
+
156
+ # def size_distribution(masks):
157
+ # """
158
+ # Calculates the size distribution of masks.
159
+
160
+ # Args:
161
+ # masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
162
+
163
+ # Returns:
164
+ # float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
165
+ # """
166
+ # counts = np.unique(masks, return_counts=True)[1][1:]
167
+ # return np.percentile(counts, 25) / np.percentile(counts, 75)
168
+
169
+
170
+ def fill_holes_and_remove_small_masks(masks, min_size=15):
171
+ """ Fills holes in masks (2D/3D) and discards masks smaller than min_size.
172
+
173
+ This function fills holes in each mask using fill_voids.fill.
174
+ It also removes masks that are smaller than the specified min_size.
175
+
176
+ Parameters:
177
+ masks (ndarray): Int, 2D or 3D array of labelled masks.
178
+ 0 represents no mask, while positive integers represent mask labels.
179
+ The size can be [Ly x Lx] or [Lz x Ly x Lx].
180
+ min_size (int, optional): Minimum number of pixels per mask.
181
+ Masks smaller than min_size will be removed.
182
+ Set to -1 to turn off this functionality. Default is 15.
183
+
184
+ Returns:
185
+ ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
186
+ 0 represents no mask, while positive integers represent mask labels.
187
+ The size is [Ly x Lx] or [Lz x Ly x Lx].
188
+ """
189
+
190
+ if masks.ndim > 3 or masks.ndim < 2:
191
+ raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
192
+ masks.ndim)
193
+
194
+ # Filter small masks
195
+ if min_size > 0:
196
+ counts = fastremap.unique(masks, return_counts=True)[1][1:]
197
+ masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
198
+ fastremap.renumber(masks, in_place=True)
199
+
200
+ slices = find_objects(masks)
201
+ j = 0
202
+ for i, slc in enumerate(slices):
203
+ if slc is not None:
204
+ msk = masks[slc] == (i + 1)
205
+ msk = fill_voids.fill(msk)
206
+ masks[slc][msk] = (j + 1)
207
+ j += 1
208
+
209
+ if min_size > 0:
210
+ counts = fastremap.unique(masks, return_counts=True)[1][1:]
211
+ masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
212
+ fastremap.renumber(masks, in_place=True)
213
+
214
+ return masks
models/seg_post_model/{cellpose/vit_sam.py โ†’ vit_sam.py} RENAMED
@@ -130,66 +130,3 @@ class Transformer(nn.Module):
130
  """
131
  torch.save(self.state_dict(), filename)
132
 
133
-
134
-
135
- class CPnetBioImageIO(Transformer):
136
- """
137
- A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
138
-
139
- This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
140
- allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
141
- """
142
-
143
- def forward(self, x):
144
- """
145
- Perform a forward pass of the CPnet model and return unpacked tensors.
146
-
147
- Args:
148
- x (torch.Tensor): Input tensor.
149
-
150
- Returns:
151
- tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
152
- """
153
- output_tensor, style_tensor, downsampled_tensors = super().forward(x)
154
- return output_tensor, style_tensor, *downsampled_tensors
155
-
156
-
157
- def load_model(self, filename, device=None):
158
- """
159
- Load the model from a file.
160
-
161
- Args:
162
- filename (str): The path to the file where the model is saved.
163
- device (torch.device, optional): The device to load the model on. Defaults to None.
164
- """
165
- if (device is not None) and (device.type != "cpu"):
166
- state_dict = torch.load(filename, map_location=device, weights_only=True)
167
- else:
168
- self.__init__(self.nout)
169
- state_dict = torch.load(filename, map_location=torch.device("cpu"),
170
- weights_only=True)
171
-
172
- self.load_state_dict(state_dict)
173
-
174
- def load_state_dict(self, state_dict):
175
- """
176
- Load the state dictionary into the model.
177
-
178
- This method overrides the default `load_state_dict` to handle Cellpose's custom
179
- loading mechanism and ensures compatibility with BioImage.IO Core.
180
-
181
- Args:
182
- state_dict (Mapping[str, Any]): A state dictionary to load into the model
183
- """
184
- if state_dict["output.2.weight"].shape[0] != self.nout:
185
- for name in self.state_dict():
186
- if "output" not in name:
187
- self.state_dict()[name].copy_(state_dict[name])
188
- else:
189
- super().load_state_dict(
190
- {name: param for name, param in state_dict.items()},
191
- strict=False)
192
-
193
-
194
-
195
-
 
130
  """
131
  torch.save(self.state_dict(), filename)
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
segmentation.py CHANGED
@@ -1,9 +1,6 @@
1
  import os
2
- import pprint
3
  from typing import Any, List, Optional
4
- import argparse
5
  from huggingface_hub import hf_hub_download
6
- import pyrallis
7
  from pytorch_lightning.utilities.types import STEP_OUTPUT
8
  import torch
9
  import os
@@ -27,7 +24,7 @@ from models.enc_model.loca_args import get_argparser as loca_get_argparser
27
  from models.enc_model.loca import build_model as build_loca_model
28
  import time
29
  from _utils.seg_eval import *
30
- from models.seg_post_model.cellpose import metrics
31
  from datetime import datetime
32
  import json
33
  import logging
@@ -35,7 +32,7 @@ from PIL import Image
35
  import torchvision.transforms as T
36
  import cv2
37
  from skimage import io, measure
38
- logging.getLogger('models.seg_post_model.cellpose.models').setLevel(logging.ERROR)
39
 
40
  SCALE = 1
41
 
@@ -76,8 +73,6 @@ class SegmentationModule(pl.LightningModule):
76
  " `placeholder_token` that is not already in the tokenizer."
77
  )
78
  try:
79
- # print("loading pretrained task embedding from {}".format("pretrained/task_embed.pth"))
80
- # task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
81
  task_embed_from_pretrain = hf_hub_download(
82
  repo_id="phoebe777777/111",
83
  filename="task_embed.pth",
@@ -190,11 +185,6 @@ class SegmentationModule(pl.LightningModule):
190
  exemplar_attention_maps2 = []
191
  exemplar_attention_maps3 = []
192
 
193
- cross_self_task_attn_maps = []
194
- cross_self_exe_attn_maps1 = []
195
- cross_self_exe_attn_maps2 = []
196
- cross_self_exe_attn_maps3 = []
197
-
198
  # only use 64x64 self-attention
199
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
200
  prompts=[self.config.prompt for i in range(bsz)], # ่ฟ™้‡Œ่ฆๆ”นไนˆ
@@ -323,7 +313,7 @@ def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
323
  ax[0].axis("off")
324
  ax[1].imshow(img_show)
325
  for inst_id in np.unique(mask_show):
326
- if inst_id == 0: # 0 ้€šๅธธๆ˜ฏ่ƒŒๆ™ฏ
327
  continue
328
  # ็”ŸๆˆไบŒๅ€ผ mask
329
  binary_mask = (mask_show == inst_id).astype(np.uint8)
@@ -352,12 +342,6 @@ def main():
352
  from matplotlib import cm
353
 
354
  def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
355
- """
356
- img: ๅŽŸๅ›พ (H, W, 3)๏ผŒ่Œƒๅ›ด [0,255] ๆˆ– [0,1]
357
- mask: ๅฎžไพ‹ๅˆ†ๅ‰ฒ็ป“ๆžœ (H, W)๏ผŒ่ƒŒๆ™ฏ=0๏ผŒๅฎžไพ‹=1,2,...
358
- alpha: ้€ๆ˜Žๅบฆ
359
- cmap_name: ้ขœ่‰ฒๆ˜ ๅฐ„่กจ
360
- """
361
  img = img.astype(np.float32)
362
  if len(img.shape) == 2:
363
  img = np.stack([img]*3, axis=-1)
@@ -369,7 +353,7 @@ def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
369
  cmap = cm.get_cmap(cmap_name, np.max(mask)+1)
370
 
371
  for inst_id in np.unique(mask):
372
- if inst_id == 0: # ่ƒŒๆ™ฏ่ทณ่ฟ‡
373
  continue
374
  color = np.array(cmap(inst_id)[:3]) # RGB
375
  overlay[mask == inst_id] = (1 - alpha) * overlay[mask == inst_id] + alpha * color
 
1
  import os
 
2
  from typing import Any, List, Optional
 
3
  from huggingface_hub import hf_hub_download
 
4
  from pytorch_lightning.utilities.types import STEP_OUTPUT
5
  import torch
6
  import os
 
24
  from models.enc_model.loca import build_model as build_loca_model
25
  import time
26
  from _utils.seg_eval import *
27
+ from models.seg_post_model import metrics
28
  from datetime import datetime
29
  import json
30
  import logging
 
32
  import torchvision.transforms as T
33
  import cv2
34
  from skimage import io, measure
35
+ logging.getLogger('models.seg_post_model.models').setLevel(logging.ERROR)
36
 
37
  SCALE = 1
38
 
 
73
  " `placeholder_token` that is not already in the tokenizer."
74
  )
75
  try:
 
 
76
  task_embed_from_pretrain = hf_hub_download(
77
  repo_id="phoebe777777/111",
78
  filename="task_embed.pth",
 
185
  exemplar_attention_maps2 = []
186
  exemplar_attention_maps3 = []
187
 
 
 
 
 
 
188
  # only use 64x64 self-attention
189
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
190
  prompts=[self.config.prompt for i in range(bsz)], # ่ฟ™้‡Œ่ฆๆ”นไนˆ
 
313
  ax[0].axis("off")
314
  ax[1].imshow(img_show)
315
  for inst_id in np.unique(mask_show):
316
+ if inst_id == 0: # 0 background
317
  continue
318
  # ็”ŸๆˆไบŒๅ€ผ mask
319
  binary_mask = (mask_show == inst_id).astype(np.uint8)
 
342
  from matplotlib import cm
343
 
344
  def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
 
 
 
 
 
 
345
  img = img.astype(np.float32)
346
  if len(img.shape) == 2:
347
  img = np.stack([img]*3, axis=-1)
 
353
  cmap = cm.get_cmap(cmap_name, np.max(mask)+1)
354
 
355
  for inst_id in np.unique(mask):
356
+ if inst_id == 0:
357
  continue
358
  color = np.array(cmap(inst_id)[:3]) # RGB
359
  overlay[mask == inst_id] = (1 - alpha) * overlay[mask == inst_id] + alpha * color