Spaces:
Running on Zero
Running on Zero
Commit ยท
f10f497
1
Parent(s): 102cd7d
clean up
Browse files- app.py +1 -3
- models/model.py +5 -400
- models/seg_post_model/__init__.py +1 -0
- models/seg_post_model/cellpose/__init__.py +0 -1
- models/seg_post_model/cellpose/io.py +0 -816
- models/seg_post_model/cellpose/plot.py +0 -281
- models/seg_post_model/cellpose/utils.py +0 -667
- models/seg_post_model/cellpose/version.py +0 -18
- models/seg_post_model/{cellpose/core.py โ core.py} +8 -88
- models/seg_post_model/{cellpose/dynamics.py โ dynamics.py} +1 -146
- models/seg_post_model/io.py +174 -0
- models/seg_post_model/{cellpose/metrics.py โ metrics.py} +0 -73
- models/seg_post_model/{cellpose/models.py โ models.py} +74 -164
- models/seg_post_model/{cellpose/transforms.py โ transforms.py} +13 -257
- models/seg_post_model/utils.py +214 -0
- models/seg_post_model/{cellpose/vit_sam.py โ vit_sam.py} +0 -63
- segmentation.py +4 -20
app.py
CHANGED
|
@@ -314,7 +314,7 @@ def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
|
|
| 314 |
|
| 315 |
try:
|
| 316 |
mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
|
| 317 |
-
print("๐ mask shape:", mask.shape, "dtype:", mask.dtype
|
| 318 |
except Exception as e:
|
| 319 |
print(f"โ Inference failed: {str(e)}")
|
| 320 |
return None, None, {}
|
|
@@ -344,8 +344,6 @@ def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
|
|
| 344 |
inst_mask = mask_np.astype(np.int32)
|
| 345 |
unique_ids = np.unique(inst_mask)
|
| 346 |
num_instances = len(unique_ids[unique_ids != 0])
|
| 347 |
-
print(f"โ
Instance IDs found: {unique_ids}, Total instances: {num_instances}")
|
| 348 |
-
|
| 349 |
if num_instances == 0:
|
| 350 |
print("โ ๏ธ No instance found, returning dummy red image")
|
| 351 |
return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None, {}
|
|
|
|
| 314 |
|
| 315 |
try:
|
| 316 |
mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
|
| 317 |
+
print("๐ mask shape:", mask.shape, "dtype:", mask.dtype)
|
| 318 |
except Exception as e:
|
| 319 |
print(f"โ Inference failed: {str(e)}")
|
| 320 |
return None, None, {}
|
|
|
|
| 344 |
inst_mask = mask_np.astype(np.int32)
|
| 345 |
unique_ids = np.unique(inst_mask)
|
| 346 |
num_instances = len(unique_ids[unique_ids != 0])
|
|
|
|
|
|
|
| 347 |
if num_instances == 0:
|
| 348 |
print("โ ๏ธ No instance found, returning dummy red image")
|
| 349 |
return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None, {}
|
models/model.py
CHANGED
|
@@ -5,35 +5,10 @@ import os
|
|
| 5 |
import clip
|
| 6 |
import sys
|
| 7 |
import numpy as np
|
| 8 |
-
from models.seg_post_model.
|
| 9 |
|
| 10 |
from torchvision.ops import roi_align
|
| 11 |
-
def crop_roi_feat(feat, boxes):
|
| 12 |
-
"""
|
| 13 |
-
feat: 1 x c x h x w
|
| 14 |
-
boxes: m x 4, 4: [y_tl, x_tl, y_br, x_br]
|
| 15 |
-
"""
|
| 16 |
-
_, _, h, w = feat.shape
|
| 17 |
-
out_stride = 512 / h
|
| 18 |
-
boxes_scaled = boxes / out_stride
|
| 19 |
-
boxes_scaled[:, :2] = torch.floor(boxes_scaled[:, :2]) # y_tl, x_tl: floor
|
| 20 |
-
boxes_scaled[:, 2:] = torch.ceil(boxes_scaled[:, 2:]) # y_br, x_br: ceil
|
| 21 |
-
boxes_scaled[:, :2] = torch.clamp_min(boxes_scaled[:, :2], 0)
|
| 22 |
-
boxes_scaled[:, 2] = torch.clamp_max(boxes_scaled[:, 2], h)
|
| 23 |
-
boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], w)
|
| 24 |
-
feat_boxes = []
|
| 25 |
-
for idx_box in range(0, boxes.shape[0]):
|
| 26 |
-
y_tl, x_tl, y_br, x_br = boxes_scaled[idx_box]
|
| 27 |
-
y_tl, x_tl, y_br, x_br = int(y_tl), int(x_tl), int(y_br), int(x_br)
|
| 28 |
-
feat_box = feat[:, :, y_tl : (y_br + 1), x_tl : (x_br + 1)]
|
| 29 |
-
feat_boxes.append(feat_box)
|
| 30 |
-
return feat_boxes
|
| 31 |
|
| 32 |
-
class Counting_with_SD_features(nn.Module):
|
| 33 |
-
def __init__(self, scale_factor):
|
| 34 |
-
super(Counting_with_SD_features, self).__init__()
|
| 35 |
-
self.adapter = adapter_roi()
|
| 36 |
-
# self.regressor = regressor_with_SD_features()
|
| 37 |
|
| 38 |
class Counting_with_SD_features_loca(nn.Module):
|
| 39 |
def __init__(self, scale_factor):
|
|
@@ -55,60 +30,6 @@ class Counting_with_SD_features_track(nn.Module):
|
|
| 55 |
self.regressor = regressor_with_SD_features_tra()
|
| 56 |
|
| 57 |
|
| 58 |
-
class adapter_roi(nn.Module):
|
| 59 |
-
def __init__(self, pool_size=[3, 3]):
|
| 60 |
-
super(adapter_roi, self).__init__()
|
| 61 |
-
self.pool_size = pool_size
|
| 62 |
-
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 63 |
-
# self.relu = nn.ReLU()
|
| 64 |
-
# self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 65 |
-
self.pool = nn.MaxPool2d(2)
|
| 66 |
-
self.fc = nn.Linear(256 * 3 * 3, 768)
|
| 67 |
-
# **new
|
| 68 |
-
self.fc1 = nn.Sequential(
|
| 69 |
-
nn.ReLU(),
|
| 70 |
-
nn.Linear(768, 768 // 4, bias=False),
|
| 71 |
-
nn.ReLU()
|
| 72 |
-
)
|
| 73 |
-
self.fc2 = nn.Sequential(
|
| 74 |
-
nn.Linear(768 // 4, 768, bias=False),
|
| 75 |
-
# nn.ReLU()
|
| 76 |
-
)
|
| 77 |
-
self.initialize_weights()
|
| 78 |
-
|
| 79 |
-
def forward(self, x, boxes):
|
| 80 |
-
num_of_boxes = boxes.shape[1]
|
| 81 |
-
rois = []
|
| 82 |
-
bs, _, h, w = x.shape
|
| 83 |
-
boxes = torch.cat([
|
| 84 |
-
torch.arange(
|
| 85 |
-
bs, requires_grad=False
|
| 86 |
-
).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
|
| 87 |
-
boxes.flatten(0, 1),
|
| 88 |
-
], dim=1)
|
| 89 |
-
rois = roi_align(
|
| 90 |
-
x,
|
| 91 |
-
boxes=boxes, output_size=3,
|
| 92 |
-
spatial_scale=1.0 / 8, aligned=True
|
| 93 |
-
)
|
| 94 |
-
rois = torch.mean(rois, dim=0, keepdim=True)
|
| 95 |
-
x = self.conv1(rois)
|
| 96 |
-
x = x.view(x.size(0), -1)
|
| 97 |
-
x = self.fc(x)
|
| 98 |
-
|
| 99 |
-
x = self.fc1(x)
|
| 100 |
-
x = self.fc2(x)
|
| 101 |
-
return x
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def initialize_weights(self):
|
| 105 |
-
for m in self.modules():
|
| 106 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 107 |
-
nn.init.xavier_normal_(m.weight)
|
| 108 |
-
if m.bias is not None:
|
| 109 |
-
nn.init.constant_(m.bias, 0)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
class adapter_roi_loca(nn.Module):
|
| 113 |
def __init__(self, pool_size=[3, 3]):
|
| 114 |
super(adapter_roi_loca, self).__init__()
|
|
@@ -188,69 +109,6 @@ class adapter_roi_loca(nn.Module):
|
|
| 188 |
|
| 189 |
|
| 190 |
|
| 191 |
-
|
| 192 |
-
class regressor1(nn.Module):
|
| 193 |
-
def __init__(self):
|
| 194 |
-
super(regressor1, self).__init__()
|
| 195 |
-
self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 196 |
-
self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 197 |
-
self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
|
| 198 |
-
self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
|
| 199 |
-
self.leaky_relu = nn.LeakyReLU()
|
| 200 |
-
self.relu = nn.ReLU()
|
| 201 |
-
self.initialize_weights()
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def forward(self, x):
|
| 206 |
-
x_ = self.conv1(x)
|
| 207 |
-
x_ = self.leaky_relu(x_)
|
| 208 |
-
x_ = self.upsampler(x_)
|
| 209 |
-
x_ = self.conv2(x_)
|
| 210 |
-
x_ = self.leaky_relu(x_)
|
| 211 |
-
x_ = self.upsampler(x_)
|
| 212 |
-
x_ = self.conv3(x_)
|
| 213 |
-
x_ = self.relu(x_)
|
| 214 |
-
out = x_
|
| 215 |
-
return out
|
| 216 |
-
|
| 217 |
-
def initialize_weights(self):
|
| 218 |
-
for m in self.modules():
|
| 219 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 220 |
-
nn.init.xavier_normal_(m.weight)
|
| 221 |
-
if m.bias is not None:
|
| 222 |
-
nn.init.constant_(m.bias, 0)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
class regressor1(nn.Module):
|
| 226 |
-
def __init__(self):
|
| 227 |
-
super(regressor1, self).__init__()
|
| 228 |
-
self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 229 |
-
self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 230 |
-
self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
|
| 231 |
-
self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
|
| 232 |
-
self.leaky_relu = nn.LeakyReLU()
|
| 233 |
-
self.relu = nn.ReLU()
|
| 234 |
-
|
| 235 |
-
def forward(self, x):
|
| 236 |
-
x_ = self.conv1(x)
|
| 237 |
-
x_ = self.leaky_relu(x_)
|
| 238 |
-
x_ = self.upsampler(x_)
|
| 239 |
-
x_ = self.conv2(x_)
|
| 240 |
-
x_ = self.leaky_relu(x_)
|
| 241 |
-
x_ = self.upsampler(x_)
|
| 242 |
-
x_ = self.conv3(x_)
|
| 243 |
-
x_ = self.relu(x_)
|
| 244 |
-
out = x_
|
| 245 |
-
return out
|
| 246 |
-
def initialize_weights(self):
|
| 247 |
-
for m in self.modules():
|
| 248 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 249 |
-
nn.init.xavier_normal_(m.weight)
|
| 250 |
-
if m.bias is not None:
|
| 251 |
-
nn.init.constant_(m.bias, 0)
|
| 252 |
-
|
| 253 |
-
|
| 254 |
class regressor_with_SD_features(nn.Module):
|
| 255 |
def __init__(self):
|
| 256 |
super(regressor_with_SD_features, self).__init__()
|
|
@@ -301,57 +159,6 @@ class regressor_with_SD_features(nn.Module):
|
|
| 301 |
if m.bias is not None:
|
| 302 |
nn.init.constant_(m.bias, 0)
|
| 303 |
|
| 304 |
-
class regressor_with_SD_features_seg(nn.Module):
|
| 305 |
-
def __init__(self):
|
| 306 |
-
super(regressor_with_SD_features_seg, self).__init__()
|
| 307 |
-
self.layer1 = nn.Sequential(
|
| 308 |
-
nn.Conv2d(324, 256, kernel_size=1, stride=1),
|
| 309 |
-
nn.LeakyReLU(),
|
| 310 |
-
nn.LayerNorm((64, 64))
|
| 311 |
-
)
|
| 312 |
-
self.layer2 = nn.Sequential(
|
| 313 |
-
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 314 |
-
nn.LeakyReLU(),
|
| 315 |
-
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 316 |
-
)
|
| 317 |
-
self.layer3 = nn.Sequential(
|
| 318 |
-
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 319 |
-
nn.ReLU(),
|
| 320 |
-
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 321 |
-
)
|
| 322 |
-
self.layer4 = nn.Sequential(
|
| 323 |
-
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 324 |
-
nn.LeakyReLU(),
|
| 325 |
-
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 326 |
-
)
|
| 327 |
-
self.conv = nn.Sequential(
|
| 328 |
-
nn.Conv2d(32, 2, kernel_size=1),
|
| 329 |
-
# nn.ReLU()
|
| 330 |
-
)
|
| 331 |
-
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 332 |
-
self.initialize_weights()
|
| 333 |
-
|
| 334 |
-
def forward(self, attn_stack, feature_list):
|
| 335 |
-
attn_stack = self.norm(attn_stack)
|
| 336 |
-
unet_feature = feature_list[-1]
|
| 337 |
-
attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 338 |
-
unet_feature = unet_feature * attn_stack_mean
|
| 339 |
-
unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 340 |
-
x = self.layer1(unet_feature)
|
| 341 |
-
x = self.layer2(x)
|
| 342 |
-
x = self.layer3(x)
|
| 343 |
-
x = self.layer4(x)
|
| 344 |
-
out = self.conv(x)
|
| 345 |
-
return out
|
| 346 |
-
|
| 347 |
-
def initialize_weights(self):
|
| 348 |
-
for m in self.modules():
|
| 349 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 350 |
-
nn.init.xavier_normal_(m.weight)
|
| 351 |
-
if m.bias is not None:
|
| 352 |
-
nn.init.constant_(m.bias, 0)
|
| 353 |
-
|
| 354 |
-
|
| 355 |
from models.enc_model.unet_parts import *
|
| 356 |
|
| 357 |
|
|
@@ -363,7 +170,7 @@ class regressor_with_SD_features_seg_vit_c3(nn.Module):
|
|
| 363 |
self.bilinear = bilinear
|
| 364 |
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 365 |
self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
|
| 366 |
-
self.vit_model =
|
| 367 |
self.vit = self.vit_model.net
|
| 368 |
|
| 369 |
def forward(self, img, attn_stack, feature_list):
|
|
@@ -380,7 +187,7 @@ class regressor_with_SD_features_seg_vit_c3(nn.Module):
|
|
| 380 |
|
| 381 |
|
| 382 |
|
| 383 |
-
out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
|
| 384 |
if out.dtype == np.uint16:
|
| 385 |
out = out.astype(np.int16)
|
| 386 |
out = torch.from_numpy(out).unsqueeze(0).to(x.device)
|
|
@@ -403,12 +210,11 @@ class regressor_with_SD_features_tra(nn.Module):
|
|
| 403 |
|
| 404 |
# segmentation
|
| 405 |
self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
|
| 406 |
-
self.vit_model =
|
| 407 |
self.vit = self.vit_model.net
|
| 408 |
|
| 409 |
self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
|
| 410 |
self.mlp = nn.Linear(64 * 64, 320)
|
| 411 |
-
# self.vit = self.vit_model.net.float()
|
| 412 |
|
| 413 |
def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
|
| 414 |
attn_stack = attn_stack[:, [1,3], ...]
|
|
@@ -422,7 +228,7 @@ class regressor_with_SD_features_tra(nn.Module):
|
|
| 422 |
x = self.inc_0(x)
|
| 423 |
feat = x
|
| 424 |
|
| 425 |
-
out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
|
| 426 |
if out.dtype == np.uint16:
|
| 427 |
out = out.astype(np.int16)
|
| 428 |
out = torch.from_numpy(out).unsqueeze(0).to(x.device)
|
|
@@ -450,204 +256,3 @@ class regressor_with_SD_features_tra(nn.Module):
|
|
| 450 |
nn.init.xavier_normal_(m.weight)
|
| 451 |
if m.bias is not None:
|
| 452 |
nn.init.constant_(m.bias, 0)
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
class regressor_with_SD_features_inst_seg_unet(nn.Module):
|
| 457 |
-
def __init__(self, n_channels=8, n_classes=3, bilinear=False):
|
| 458 |
-
super(regressor_with_SD_features_inst_seg_unet, self).__init__()
|
| 459 |
-
self.n_channels = n_channels
|
| 460 |
-
self.n_classes = n_classes
|
| 461 |
-
self.bilinear = bilinear
|
| 462 |
-
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 463 |
-
self.inc_0 = (DoubleConv(n_channels, 3))
|
| 464 |
-
self.inc = (DoubleConv(3, 64))
|
| 465 |
-
self.down1 = (Down(64, 128))
|
| 466 |
-
self.down2 = (Down(128, 256))
|
| 467 |
-
self.down3 = (Down(256, 512))
|
| 468 |
-
factor = 2 if bilinear else 1
|
| 469 |
-
self.down4 = (Down(512, 1024 // factor))
|
| 470 |
-
self.up1 = (Up(1024, 512 // factor, bilinear))
|
| 471 |
-
self.up2 = (Up(512, 256 // factor, bilinear))
|
| 472 |
-
self.up3 = (Up(256, 128 // factor, bilinear))
|
| 473 |
-
self.up4 = (Up(128, 64, bilinear))
|
| 474 |
-
self.outc = (OutConv(64, n_classes))
|
| 475 |
-
|
| 476 |
-
def forward(self, img, attn_stack, feature_list):
|
| 477 |
-
attn_stack = self.norm(attn_stack)
|
| 478 |
-
unet_feature = feature_list[-1]
|
| 479 |
-
unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
|
| 480 |
-
attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 481 |
-
unet_feature_mean = unet_feature_mean * attn_stack_mean
|
| 482 |
-
x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 483 |
-
if x.shape[-1] != 512:
|
| 484 |
-
x = F.interpolate(x, size=(512, 512), mode="bilinear")
|
| 485 |
-
x = torch.cat([img, x], dim=1) # [1, 8, 512, 512]
|
| 486 |
-
x = self.inc_0(x)
|
| 487 |
-
x1 = self.inc(x)
|
| 488 |
-
x2 = self.down1(x1)
|
| 489 |
-
x3 = self.down2(x2)
|
| 490 |
-
x4 = self.down3(x3)
|
| 491 |
-
x5 = self.down4(x4)
|
| 492 |
-
x = self.up1(x5, x4)
|
| 493 |
-
x = self.up2(x, x3)
|
| 494 |
-
x = self.up3(x, x2)
|
| 495 |
-
x = self.up4(x, x1)
|
| 496 |
-
out = self.outc(x)
|
| 497 |
-
return out
|
| 498 |
-
|
| 499 |
-
def initialize_weights(self):
|
| 500 |
-
for m in self.modules():
|
| 501 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 502 |
-
nn.init.xavier_normal_(m.weight)
|
| 503 |
-
if m.bias is not None:
|
| 504 |
-
nn.init.constant_(m.bias, 0)
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
class regressor_with_SD_features_self(nn.Module):
|
| 508 |
-
def __init__(self):
|
| 509 |
-
super(regressor_with_SD_features_self, self).__init__()
|
| 510 |
-
self.layer = nn.Sequential(
|
| 511 |
-
nn.Conv2d(4096, 1024, kernel_size=1, stride=1),
|
| 512 |
-
nn.LeakyReLU(),
|
| 513 |
-
nn.LayerNorm((64, 64)),
|
| 514 |
-
nn.Conv2d(1024, 256, kernel_size=1, stride=1),
|
| 515 |
-
nn.LeakyReLU(),
|
| 516 |
-
nn.LayerNorm((64, 64)),
|
| 517 |
-
)
|
| 518 |
-
self.layer2 = nn.Sequential(
|
| 519 |
-
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 520 |
-
nn.LeakyReLU(),
|
| 521 |
-
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 522 |
-
)
|
| 523 |
-
self.layer3 = nn.Sequential(
|
| 524 |
-
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 525 |
-
nn.ReLU(),
|
| 526 |
-
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 527 |
-
)
|
| 528 |
-
self.layer4 = nn.Sequential(
|
| 529 |
-
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 530 |
-
nn.LeakyReLU(),
|
| 531 |
-
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 532 |
-
)
|
| 533 |
-
self.conv = nn.Sequential(
|
| 534 |
-
nn.Conv2d(32, 1, kernel_size=1),
|
| 535 |
-
nn.ReLU()
|
| 536 |
-
)
|
| 537 |
-
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 538 |
-
self.initialize_weights()
|
| 539 |
-
|
| 540 |
-
def forward(self, self_attn):
|
| 541 |
-
self_attn = self_attn.permute(2, 0, 1)
|
| 542 |
-
self_attn = self.layer(self_attn)
|
| 543 |
-
return self_attn
|
| 544 |
-
# attn_stack = self.norm(attn_stack)
|
| 545 |
-
# unet_feature = feature_list[-1]
|
| 546 |
-
# attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 547 |
-
# unet_feature = unet_feature * attn_stack_mean
|
| 548 |
-
# unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 549 |
-
# x = self.layer(unet_feature)
|
| 550 |
-
# x = self.layer2(x)
|
| 551 |
-
# x = self.layer3(x)
|
| 552 |
-
# x = self.layer4(x)
|
| 553 |
-
# out = self.conv(x)
|
| 554 |
-
# return out / 100
|
| 555 |
-
|
| 556 |
-
def initialize_weights(self):
|
| 557 |
-
for m in self.modules():
|
| 558 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 559 |
-
nn.init.xavier_normal_(m.weight)
|
| 560 |
-
if m.bias is not None:
|
| 561 |
-
nn.init.constant_(m.bias, 0)
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
class regressor_with_SD_features_latent(nn.Module):
|
| 565 |
-
def __init__(self):
|
| 566 |
-
super(regressor_with_SD_features_latent, self).__init__()
|
| 567 |
-
self.layer = nn.Sequential(
|
| 568 |
-
nn.Conv2d(4, 256, kernel_size=1, stride=1),
|
| 569 |
-
nn.LeakyReLU(),
|
| 570 |
-
nn.LayerNorm((64, 64))
|
| 571 |
-
)
|
| 572 |
-
self.layer2 = nn.Sequential(
|
| 573 |
-
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 574 |
-
nn.LeakyReLU(),
|
| 575 |
-
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 576 |
-
)
|
| 577 |
-
self.layer3 = nn.Sequential(
|
| 578 |
-
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 579 |
-
nn.ReLU(),
|
| 580 |
-
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 581 |
-
)
|
| 582 |
-
self.layer4 = nn.Sequential(
|
| 583 |
-
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 584 |
-
nn.LeakyReLU(),
|
| 585 |
-
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 586 |
-
)
|
| 587 |
-
self.conv = nn.Sequential(
|
| 588 |
-
nn.Conv2d(32, 1, kernel_size=1),
|
| 589 |
-
nn.ReLU()
|
| 590 |
-
)
|
| 591 |
-
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 592 |
-
self.initialize_weights()
|
| 593 |
-
|
| 594 |
-
def forward(self, self_attn):
|
| 595 |
-
# self_attn = self_attn.permute(2, 0, 1)
|
| 596 |
-
self_attn = self.layer(self_attn)
|
| 597 |
-
return self_attn
|
| 598 |
-
# attn_stack = self.norm(attn_stack)
|
| 599 |
-
# unet_feature = feature_list[-1]
|
| 600 |
-
# attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 601 |
-
# unet_feature = unet_feature * attn_stack_mean
|
| 602 |
-
# unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 603 |
-
# x = self.layer(unet_feature)
|
| 604 |
-
# x = self.layer2(x)
|
| 605 |
-
# x = self.layer3(x)
|
| 606 |
-
# x = self.layer4(x)
|
| 607 |
-
# out = self.conv(x)
|
| 608 |
-
# return out / 100
|
| 609 |
-
|
| 610 |
-
def initialize_weights(self):
|
| 611 |
-
for m in self.modules():
|
| 612 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 613 |
-
nn.init.xavier_normal_(m.weight)
|
| 614 |
-
if m.bias is not None:
|
| 615 |
-
nn.init.constant_(m.bias, 0)
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
class regressor_with_deconv(nn.Module):
|
| 622 |
-
def __init__(self):
|
| 623 |
-
super(regressor_with_deconv, self).__init__()
|
| 624 |
-
self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 625 |
-
self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 626 |
-
self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
|
| 627 |
-
self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
|
| 628 |
-
self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
|
| 629 |
-
self.leaky_relu = nn.LeakyReLU()
|
| 630 |
-
self.relu = nn.ReLU()
|
| 631 |
-
self.initialize_weights()
|
| 632 |
-
|
| 633 |
-
def forward(self, x):
|
| 634 |
-
x_ = self.conv1(x)
|
| 635 |
-
x_ = self.leaky_relu(x_)
|
| 636 |
-
x_ = self.deconv1(x_)
|
| 637 |
-
x_ = self.conv2(x_)
|
| 638 |
-
x_ = self.leaky_relu(x_)
|
| 639 |
-
x_ = self.deconv2(x_)
|
| 640 |
-
x_ = self.conv3(x_)
|
| 641 |
-
x_ = self.relu(x_)
|
| 642 |
-
out = x_
|
| 643 |
-
return out
|
| 644 |
-
|
| 645 |
-
def initialize_weights(self):
|
| 646 |
-
for m in self.modules():
|
| 647 |
-
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
|
| 648 |
-
nn.init.xavier_normal_(m.weight)
|
| 649 |
-
if m.bias is not None:
|
| 650 |
-
nn.init.constant_(m.bias, 0)
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
|
|
|
| 5 |
import clip
|
| 6 |
import sys
|
| 7 |
import numpy as np
|
| 8 |
+
from models.seg_post_model.models import SegModel
|
| 9 |
|
| 10 |
from torchvision.ops import roi_align
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class Counting_with_SD_features_loca(nn.Module):
|
| 14 |
def __init__(self, scale_factor):
|
|
|
|
| 30 |
self.regressor = regressor_with_SD_features_tra()
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class adapter_roi_loca(nn.Module):
|
| 34 |
def __init__(self, pool_size=[3, 3]):
|
| 35 |
super(adapter_roi_loca, self).__init__()
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
class regressor_with_SD_features(nn.Module):
|
| 113 |
def __init__(self):
|
| 114 |
super(regressor_with_SD_features, self).__init__()
|
|
|
|
| 159 |
if m.bias is not None:
|
| 160 |
nn.init.constant_(m.bias, 0)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
from models.enc_model.unet_parts import *
|
| 163 |
|
| 164 |
|
|
|
|
| 170 |
self.bilinear = bilinear
|
| 171 |
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 172 |
self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
|
| 173 |
+
self.vit_model = SegModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
|
| 174 |
self.vit = self.vit_model.net
|
| 175 |
|
| 176 |
def forward(self, img, attn_stack, feature_list):
|
|
|
|
| 187 |
|
| 188 |
|
| 189 |
|
| 190 |
+
out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
|
| 191 |
if out.dtype == np.uint16:
|
| 192 |
out = out.astype(np.int16)
|
| 193 |
out = torch.from_numpy(out).unsqueeze(0).to(x.device)
|
|
|
|
| 210 |
|
| 211 |
# segmentation
|
| 212 |
self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
|
| 213 |
+
self.vit_model = SegModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
|
| 214 |
self.vit = self.vit_model.net
|
| 215 |
|
| 216 |
self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
|
| 217 |
self.mlp = nn.Linear(64 * 64, 320)
|
|
|
|
| 218 |
|
| 219 |
def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
|
| 220 |
attn_stack = attn_stack[:, [1,3], ...]
|
|
|
|
| 228 |
x = self.inc_0(x)
|
| 229 |
feat = x
|
| 230 |
|
| 231 |
+
out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())
|
| 232 |
if out.dtype == np.uint16:
|
| 233 |
out = out.astype(np.int16)
|
| 234 |
out = torch.from_numpy(out).unsqueeze(0).to(x.device)
|
|
|
|
| 256 |
nn.init.xavier_normal_(m.weight)
|
| 257 |
if m.bias is not None:
|
| 258 |
nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
models/seg_post_model/cellpose/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# from .version import version, version_str
|
|
|
|
|
|
models/seg_post_model/cellpose/io.py
DELETED
|
@@ -1,816 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
import os, warnings, glob, shutil
|
| 5 |
-
from natsort import natsorted
|
| 6 |
-
import numpy as np
|
| 7 |
-
import cv2
|
| 8 |
-
import tifffile
|
| 9 |
-
import logging, pathlib, sys
|
| 10 |
-
from tqdm import tqdm
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
import re
|
| 13 |
-
from .version import version_str
|
| 14 |
-
from roifile import ImagejRoi, roiwrite
|
| 15 |
-
|
| 16 |
-
try:
|
| 17 |
-
from qtpy import QtGui, QtCore, Qt, QtWidgets
|
| 18 |
-
from qtpy.QtWidgets import QMessageBox
|
| 19 |
-
GUI = True
|
| 20 |
-
except:
|
| 21 |
-
GUI = False
|
| 22 |
-
|
| 23 |
-
try:
|
| 24 |
-
import matplotlib.pyplot as plt
|
| 25 |
-
MATPLOTLIB = True
|
| 26 |
-
except:
|
| 27 |
-
MATPLOTLIB = False
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
import nd2
|
| 31 |
-
ND2 = True
|
| 32 |
-
except:
|
| 33 |
-
ND2 = False
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
import nrrd
|
| 37 |
-
NRRD = True
|
| 38 |
-
except:
|
| 39 |
-
NRRD = False
|
| 40 |
-
|
| 41 |
-
try:
|
| 42 |
-
from google.cloud import storage
|
| 43 |
-
SERVER_UPLOAD = True
|
| 44 |
-
except:
|
| 45 |
-
SERVER_UPLOAD = False
|
| 46 |
-
|
| 47 |
-
io_logger = logging.getLogger(__name__)
|
| 48 |
-
|
| 49 |
-
def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None):
|
| 50 |
-
cp_dir = pathlib.Path.home().joinpath(cp_path)
|
| 51 |
-
cp_dir.mkdir(exist_ok=True)
|
| 52 |
-
log_file = cp_dir.joinpath(logfile_name)
|
| 53 |
-
try:
|
| 54 |
-
log_file.unlink()
|
| 55 |
-
except:
|
| 56 |
-
print('creating new log file')
|
| 57 |
-
handlers = [logging.FileHandler(log_file),]
|
| 58 |
-
if stdout_file_replacement is not None:
|
| 59 |
-
handlers.append(logging.FileHandler(stdout_file_replacement))
|
| 60 |
-
else:
|
| 61 |
-
handlers.append(logging.StreamHandler(sys.stdout))
|
| 62 |
-
logging.basicConfig(
|
| 63 |
-
level=logging.INFO,
|
| 64 |
-
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 65 |
-
handlers=handlers,
|
| 66 |
-
force=True
|
| 67 |
-
)
|
| 68 |
-
logger = logging.getLogger(__name__)
|
| 69 |
-
logger.info(f"WRITING LOG OUTPUT TO {log_file}")
|
| 70 |
-
logger.info(version_str)
|
| 71 |
-
|
| 72 |
-
return logger, log_file
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
from . import utils, plot, transforms
|
| 76 |
-
|
| 77 |
-
# helper function to check for a path; if it doesn't exist, make it
|
| 78 |
-
def check_dir(path):
|
| 79 |
-
if not os.path.isdir(path):
|
| 80 |
-
os.mkdir(path)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def outlines_to_text(base, outlines):
|
| 84 |
-
with open(base + "_cp_outlines.txt", "w") as f:
|
| 85 |
-
for o in outlines:
|
| 86 |
-
xy = list(o.flatten())
|
| 87 |
-
xy_str = ",".join(map(str, xy))
|
| 88 |
-
f.write(xy_str)
|
| 89 |
-
f.write("\n")
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def load_dax(filename):
|
| 93 |
-
### modified from ZhuangLab github:
|
| 94 |
-
### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
|
| 95 |
-
|
| 96 |
-
inf_filename = os.path.splitext(filename)[0] + ".inf"
|
| 97 |
-
if not os.path.exists(inf_filename):
|
| 98 |
-
io_logger.critical(
|
| 99 |
-
f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
|
| 100 |
-
)
|
| 101 |
-
return None
|
| 102 |
-
|
| 103 |
-
### get metadata
|
| 104 |
-
image_height, image_width = None, None
|
| 105 |
-
# extract the movie information from the associated inf file
|
| 106 |
-
size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
|
| 107 |
-
length_re = re.compile(r"number of frames = ([\d]+)")
|
| 108 |
-
endian_re = re.compile(r" (big|little) endian")
|
| 109 |
-
|
| 110 |
-
with open(inf_filename, "r") as inf_file:
|
| 111 |
-
lines = inf_file.read().split("\n")
|
| 112 |
-
for line in lines:
|
| 113 |
-
m = size_re.match(line)
|
| 114 |
-
if m:
|
| 115 |
-
image_height = int(m.group(2))
|
| 116 |
-
image_width = int(m.group(1))
|
| 117 |
-
m = length_re.match(line)
|
| 118 |
-
if m:
|
| 119 |
-
number_frames = int(m.group(1))
|
| 120 |
-
m = endian_re.search(line)
|
| 121 |
-
if m:
|
| 122 |
-
if m.group(1) == "big":
|
| 123 |
-
bigendian = 1
|
| 124 |
-
else:
|
| 125 |
-
bigendian = 0
|
| 126 |
-
# set defaults, warn the user that they couldn"t be determined from the inf file.
|
| 127 |
-
if not image_height:
|
| 128 |
-
io_logger.warning("could not determine dax image size, assuming 256x256")
|
| 129 |
-
image_height = 256
|
| 130 |
-
image_width = 256
|
| 131 |
-
|
| 132 |
-
### load image
|
| 133 |
-
img = np.memmap(filename, dtype="uint16",
|
| 134 |
-
shape=(number_frames, image_height, image_width))
|
| 135 |
-
if bigendian:
|
| 136 |
-
img = img.byteswap()
|
| 137 |
-
img = np.array(img)
|
| 138 |
-
|
| 139 |
-
return img
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def imread(filename):
|
| 143 |
-
"""
|
| 144 |
-
Read in an image file with tif or image file type supported by cv2.
|
| 145 |
-
|
| 146 |
-
Args:
|
| 147 |
-
filename (str): The path to the image file.
|
| 148 |
-
|
| 149 |
-
Returns:
|
| 150 |
-
numpy.ndarray: The image data as a NumPy array.
|
| 151 |
-
|
| 152 |
-
Raises:
|
| 153 |
-
None
|
| 154 |
-
|
| 155 |
-
Raises an error if the image file format is not supported.
|
| 156 |
-
|
| 157 |
-
Examples:
|
| 158 |
-
>>> img = imread("image.tif")
|
| 159 |
-
"""
|
| 160 |
-
# ensure that extension check is not case sensitive
|
| 161 |
-
ext = os.path.splitext(filename)[-1].lower()
|
| 162 |
-
if ext == ".tif" or ext == ".tiff" or ext == ".flex":
|
| 163 |
-
with tifffile.TiffFile(filename) as tif:
|
| 164 |
-
ltif = len(tif.pages)
|
| 165 |
-
try:
|
| 166 |
-
full_shape = tif.shaped_metadata[0]["shape"]
|
| 167 |
-
except:
|
| 168 |
-
try:
|
| 169 |
-
page = tif.series[0][0]
|
| 170 |
-
full_shape = tif.series[0].shape
|
| 171 |
-
except:
|
| 172 |
-
ltif = 0
|
| 173 |
-
if ltif < 10:
|
| 174 |
-
img = tif.asarray()
|
| 175 |
-
else:
|
| 176 |
-
page = tif.series[0][0]
|
| 177 |
-
shape, dtype = page.shape, page.dtype
|
| 178 |
-
ltif = int(np.prod(full_shape) / np.prod(shape))
|
| 179 |
-
io_logger.info(f"reading tiff with {ltif} planes")
|
| 180 |
-
img = np.zeros((ltif, *shape), dtype=dtype)
|
| 181 |
-
for i, page in enumerate(tqdm(tif.series[0])):
|
| 182 |
-
img[i] = page.asarray()
|
| 183 |
-
img = img.reshape(full_shape)
|
| 184 |
-
return img
|
| 185 |
-
elif ext == ".dax":
|
| 186 |
-
img = load_dax(filename)
|
| 187 |
-
return img
|
| 188 |
-
elif ext == ".nd2":
|
| 189 |
-
if not ND2:
|
| 190 |
-
io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
|
| 191 |
-
return None
|
| 192 |
-
elif ext == ".nrrd":
|
| 193 |
-
if not NRRD:
|
| 194 |
-
io_logger.critical(
|
| 195 |
-
"ERROR: need to 'pip install pynrrd' to load in .nrrd file")
|
| 196 |
-
return None
|
| 197 |
-
else:
|
| 198 |
-
img, metadata = nrrd.read(filename)
|
| 199 |
-
if img.ndim == 3:
|
| 200 |
-
img = img.transpose(2, 0, 1)
|
| 201 |
-
return img
|
| 202 |
-
elif ext != ".npy":
|
| 203 |
-
try:
|
| 204 |
-
img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
|
| 205 |
-
if img.ndim > 2:
|
| 206 |
-
img = img[..., [2, 1, 0]]
|
| 207 |
-
return img
|
| 208 |
-
except Exception as e:
|
| 209 |
-
io_logger.critical("ERROR: could not read file, %s" % e)
|
| 210 |
-
return None
|
| 211 |
-
else:
|
| 212 |
-
try:
|
| 213 |
-
dat = np.load(filename, allow_pickle=True).item()
|
| 214 |
-
masks = dat["masks"]
|
| 215 |
-
return masks
|
| 216 |
-
except Exception as e:
|
| 217 |
-
io_logger.critical("ERROR: could not read masks from file, %s" % e)
|
| 218 |
-
return None
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
def imread_2D(img_file):
|
| 222 |
-
"""
|
| 223 |
-
Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
|
| 224 |
-
If the image has more than 3 channels, only the first 3 channels are kept.
|
| 225 |
-
|
| 226 |
-
Args:
|
| 227 |
-
img_file (str): The path to the image file.
|
| 228 |
-
|
| 229 |
-
Returns:
|
| 230 |
-
img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
|
| 231 |
-
"""
|
| 232 |
-
img = imread(img_file)
|
| 233 |
-
return transforms.convert_image(img, do_3D=False)
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
def imread_3D(img_file):
|
| 237 |
-
"""
|
| 238 |
-
Read in a 3D image file and convert it to have a channel axis last automatically. Attempts to do this for multi-channel and grayscale images.
|
| 239 |
-
|
| 240 |
-
If multichannel image, the channel axis is assumed to be the smallest dimension, and the z axis is the next smallest dimension.
|
| 241 |
-
Use `cellpose.io.imread()` to load the full image without selecting the z and channel axes.
|
| 242 |
-
|
| 243 |
-
Args:
|
| 244 |
-
img_file (str): The path to the image file.
|
| 245 |
-
|
| 246 |
-
Returns:
|
| 247 |
-
img_out (numpy.ndarray): The image data as a NumPy array.
|
| 248 |
-
"""
|
| 249 |
-
img = imread(img_file)
|
| 250 |
-
|
| 251 |
-
dimension_lengths = list(img.shape)
|
| 252 |
-
|
| 253 |
-
# grayscale images:
|
| 254 |
-
if img.ndim == 3:
|
| 255 |
-
channel_axis = None
|
| 256 |
-
# guess at z axis:
|
| 257 |
-
z_axis = np.argmin(dimension_lengths)
|
| 258 |
-
|
| 259 |
-
elif img.ndim == 4:
|
| 260 |
-
# guess at channel axis:
|
| 261 |
-
channel_axis = np.argmin(dimension_lengths)
|
| 262 |
-
|
| 263 |
-
# guess at z axis:
|
| 264 |
-
# set channel axis to max so argmin works:
|
| 265 |
-
dimension_lengths[channel_axis] = max(dimension_lengths)
|
| 266 |
-
z_axis = np.argmin(dimension_lengths)
|
| 267 |
-
|
| 268 |
-
else:
|
| 269 |
-
raise ValueError(f'image shape error, 3D image must 3 or 4 dimensional. Number of dimensions: {img.ndim}')
|
| 270 |
-
|
| 271 |
-
try:
|
| 272 |
-
return transforms.convert_image(img, channel_axis=channel_axis, z_axis=z_axis, do_3D=True)
|
| 273 |
-
except Exception as e:
|
| 274 |
-
io_logger.critical("ERROR: could not read file, %s" % e)
|
| 275 |
-
io_logger.critical("ERROR: Guessed z_axis: %s, channel_axis: %s" % (z_axis, channel_axis))
|
| 276 |
-
return None
|
| 277 |
-
|
| 278 |
-
def remove_model(filename, delete=False):
|
| 279 |
-
""" remove model from .cellpose custom model list """
|
| 280 |
-
filename = os.path.split(filename)[-1]
|
| 281 |
-
from . import models
|
| 282 |
-
model_strings = models.get_user_models()
|
| 283 |
-
if len(model_strings) > 0:
|
| 284 |
-
with open(models.MODEL_LIST_PATH, "w") as textfile:
|
| 285 |
-
for fname in model_strings:
|
| 286 |
-
textfile.write(fname + "\n")
|
| 287 |
-
else:
|
| 288 |
-
# write empty file
|
| 289 |
-
textfile = open(models.MODEL_LIST_PATH, "w")
|
| 290 |
-
textfile.close()
|
| 291 |
-
print(f"{filename} removed from custom model list")
|
| 292 |
-
if delete:
|
| 293 |
-
os.remove(os.fspath(models.MODEL_DIR.joinpath(fname)))
|
| 294 |
-
print("model deleted")
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
def add_model(filename):
|
| 298 |
-
""" add model to .cellpose models folder to use with GUI or CLI """
|
| 299 |
-
from . import models
|
| 300 |
-
fname = os.path.split(filename)[-1]
|
| 301 |
-
try:
|
| 302 |
-
shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
|
| 303 |
-
except shutil.SameFileError:
|
| 304 |
-
pass
|
| 305 |
-
print(f"{filename} copied to models folder {os.fspath(models.MODEL_DIR)}")
|
| 306 |
-
if fname not in models.get_user_models():
|
| 307 |
-
with open(models.MODEL_LIST_PATH, "a") as textfile:
|
| 308 |
-
textfile.write(fname + "\n")
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
def imsave(filename, arr):
|
| 312 |
-
"""
|
| 313 |
-
Saves an image array to a file.
|
| 314 |
-
|
| 315 |
-
Args:
|
| 316 |
-
filename (str): The name of the file to save the image to.
|
| 317 |
-
arr (numpy.ndarray): The image array to be saved.
|
| 318 |
-
|
| 319 |
-
Returns:
|
| 320 |
-
None
|
| 321 |
-
"""
|
| 322 |
-
ext = os.path.splitext(filename)[-1].lower()
|
| 323 |
-
if ext == ".tif" or ext == ".tiff":
|
| 324 |
-
tifffile.imwrite(filename, data=arr, compression="zlib")
|
| 325 |
-
else:
|
| 326 |
-
if len(arr.shape) > 2:
|
| 327 |
-
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
|
| 328 |
-
cv2.imwrite(filename, arr)
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
|
| 332 |
-
"""
|
| 333 |
-
Finds all images in a folder and its subfolders (if specified) with the given file extensions.
|
| 334 |
-
|
| 335 |
-
Args:
|
| 336 |
-
folder (str): The path to the folder to search for images.
|
| 337 |
-
mask_filter (str): The filter for mask files.
|
| 338 |
-
imf (str, optional): The additional filter for image files. Defaults to None.
|
| 339 |
-
look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False.
|
| 340 |
-
|
| 341 |
-
Returns:
|
| 342 |
-
list: A list of image file paths.
|
| 343 |
-
|
| 344 |
-
Raises:
|
| 345 |
-
ValueError: If no files are found in the specified folder.
|
| 346 |
-
ValueError: If no images are found in the specified folder with the supported file extensions.
|
| 347 |
-
ValueError: If no images are found in the specified folder without the mask or flow file endings.
|
| 348 |
-
"""
|
| 349 |
-
mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1",
|
| 350 |
-
"_flows_2", "_cellprob", "_masks", mask_filter]
|
| 351 |
-
image_names = []
|
| 352 |
-
if imf is None:
|
| 353 |
-
imf = ""
|
| 354 |
-
|
| 355 |
-
folders = []
|
| 356 |
-
if look_one_level_down:
|
| 357 |
-
folders = natsorted(glob.glob(os.path.join(folder, "*/")))
|
| 358 |
-
folders.append(folder)
|
| 359 |
-
exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"]
|
| 360 |
-
l0 = 0
|
| 361 |
-
al = 0
|
| 362 |
-
for folder in folders:
|
| 363 |
-
all_files = glob.glob(folder + "/*")
|
| 364 |
-
al += len(all_files)
|
| 365 |
-
for ext in exts:
|
| 366 |
-
image_names.extend(glob.glob(folder + f"/*{imf}{ext}"))
|
| 367 |
-
image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}"))
|
| 368 |
-
l0 += len(image_names)
|
| 369 |
-
|
| 370 |
-
# return error if no files found
|
| 371 |
-
if al == 0:
|
| 372 |
-
raise ValueError("ERROR: no files in --dir folder ")
|
| 373 |
-
elif l0 == 0:
|
| 374 |
-
raise ValueError(
|
| 375 |
-
"ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex"
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
image_names = natsorted(image_names)
|
| 379 |
-
imn = []
|
| 380 |
-
for im in image_names:
|
| 381 |
-
imfile = os.path.splitext(im)[0]
|
| 382 |
-
igood = all([(len(imfile) > len(mask_filter) and
|
| 383 |
-
imfile[-len(mask_filter):] != mask_filter) or
|
| 384 |
-
len(imfile) <= len(mask_filter) for mask_filter in mask_filters])
|
| 385 |
-
if len(imf) > 0:
|
| 386 |
-
igood &= imfile[-len(imf):] == imf
|
| 387 |
-
if igood:
|
| 388 |
-
imn.append(im)
|
| 389 |
-
|
| 390 |
-
image_names = imn
|
| 391 |
-
|
| 392 |
-
# remove duplicates
|
| 393 |
-
image_names = [*set(image_names)]
|
| 394 |
-
image_names = natsorted(image_names)
|
| 395 |
-
|
| 396 |
-
if len(image_names) == 0:
|
| 397 |
-
raise ValueError(
|
| 398 |
-
"ERROR: no images in --dir folder without _masks or _flows or _cellprob ending")
|
| 399 |
-
|
| 400 |
-
return image_names
|
| 401 |
-
|
| 402 |
-
def get_label_files(image_names, mask_filter, imf=None):
|
| 403 |
-
"""
|
| 404 |
-
Get the label files corresponding to the given image names and mask filter.
|
| 405 |
-
|
| 406 |
-
Args:
|
| 407 |
-
image_names (list): List of image names.
|
| 408 |
-
mask_filter (str): Mask filter to be applied.
|
| 409 |
-
imf (str, optional): Image file extension. Defaults to None.
|
| 410 |
-
|
| 411 |
-
Returns:
|
| 412 |
-
tuple: A tuple containing the label file names and flow file names (if present).
|
| 413 |
-
"""
|
| 414 |
-
nimg = len(image_names)
|
| 415 |
-
label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
|
| 416 |
-
|
| 417 |
-
if imf is not None and len(imf) > 0:
|
| 418 |
-
label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
|
| 419 |
-
else:
|
| 420 |
-
label_names = label_names0
|
| 421 |
-
|
| 422 |
-
# check for flows
|
| 423 |
-
if os.path.exists(label_names0[0] + "_flows.tif"):
|
| 424 |
-
flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)]
|
| 425 |
-
else:
|
| 426 |
-
flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)]
|
| 427 |
-
if not all([os.path.exists(flow) for flow in flow_names]):
|
| 428 |
-
io_logger.info(
|
| 429 |
-
"not all flows are present, running flow generation for all images")
|
| 430 |
-
flow_names = None
|
| 431 |
-
|
| 432 |
-
# check for masks
|
| 433 |
-
if mask_filter == "_seg.npy":
|
| 434 |
-
label_names = [label_names[n] + mask_filter for n in range(nimg)]
|
| 435 |
-
return label_names, None
|
| 436 |
-
|
| 437 |
-
if os.path.exists(label_names[0] + mask_filter + ".tif"):
|
| 438 |
-
label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)]
|
| 439 |
-
elif os.path.exists(label_names[0] + mask_filter + ".tiff"):
|
| 440 |
-
label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)]
|
| 441 |
-
elif os.path.exists(label_names[0] + mask_filter + ".png"):
|
| 442 |
-
label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)]
|
| 443 |
-
# TODO, allow _seg.npy
|
| 444 |
-
#elif os.path.exists(label_names[0] + "_seg.npy"):
|
| 445 |
-
# io_logger.info("labels found as _seg.npy files, converting to tif")
|
| 446 |
-
else:
|
| 447 |
-
if not flow_names:
|
| 448 |
-
raise ValueError("labels not provided with correct --mask_filter")
|
| 449 |
-
else:
|
| 450 |
-
label_names = None
|
| 451 |
-
if not all([os.path.exists(label) for label in label_names]):
|
| 452 |
-
if not flow_names:
|
| 453 |
-
raise ValueError(
|
| 454 |
-
"labels not provided for all images in train and/or test set")
|
| 455 |
-
else:
|
| 456 |
-
label_names = None
|
| 457 |
-
|
| 458 |
-
return label_names, flow_names
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
|
| 462 |
-
look_one_level_down=False):
|
| 463 |
-
"""
|
| 464 |
-
Loads images and corresponding labels from a directory.
|
| 465 |
-
|
| 466 |
-
Args:
|
| 467 |
-
tdir (str): The directory path.
|
| 468 |
-
mask_filter (str, optional): The filter for mask files. Defaults to "_masks".
|
| 469 |
-
image_filter (str, optional): The filter for image files. Defaults to None.
|
| 470 |
-
look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False.
|
| 471 |
-
|
| 472 |
-
Returns:
|
| 473 |
-
tuple: A tuple containing a list of images, a list of labels, and a list of image names.
|
| 474 |
-
"""
|
| 475 |
-
image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
|
| 476 |
-
nimg = len(image_names)
|
| 477 |
-
|
| 478 |
-
# training data
|
| 479 |
-
label_names, flow_names = get_label_files(image_names, mask_filter,
|
| 480 |
-
imf=image_filter)
|
| 481 |
-
|
| 482 |
-
images = []
|
| 483 |
-
labels = []
|
| 484 |
-
k = 0
|
| 485 |
-
for n in range(nimg):
|
| 486 |
-
if (os.path.isfile(label_names[n]) or
|
| 487 |
-
(flow_names is not None and os.path.isfile(flow_names[0]))):
|
| 488 |
-
image = imread(image_names[n])
|
| 489 |
-
if label_names is not None:
|
| 490 |
-
label = imread(label_names[n])
|
| 491 |
-
if flow_names is not None:
|
| 492 |
-
flow = imread(flow_names[n])
|
| 493 |
-
if flow.shape[0] < 4:
|
| 494 |
-
label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
|
| 495 |
-
else:
|
| 496 |
-
label = flow
|
| 497 |
-
images.append(image)
|
| 498 |
-
labels.append(label)
|
| 499 |
-
k += 1
|
| 500 |
-
io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
|
| 501 |
-
return images, labels, image_names
|
| 502 |
-
|
| 503 |
-
def load_train_test_data(train_dir, test_dir=None, image_filter=None,
|
| 504 |
-
mask_filter="_masks", look_one_level_down=False):
|
| 505 |
-
"""
|
| 506 |
-
Loads training and testing data for a Cellpose model.
|
| 507 |
-
|
| 508 |
-
Args:
|
| 509 |
-
train_dir (str): The directory path containing the training data.
|
| 510 |
-
test_dir (str, optional): The directory path containing the testing data. Defaults to None.
|
| 511 |
-
image_filter (str, optional): The filter for selecting image files. Defaults to None.
|
| 512 |
-
mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks".
|
| 513 |
-
look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.
|
| 514 |
-
|
| 515 |
-
Returns:
|
| 516 |
-
images, labels, image_names, test_images, test_labels, test_image_names
|
| 517 |
-
|
| 518 |
-
"""
|
| 519 |
-
images, labels, image_names = load_images_labels(train_dir, mask_filter,
|
| 520 |
-
image_filter, look_one_level_down)
|
| 521 |
-
# testing data
|
| 522 |
-
test_images, test_labels, test_image_names = None, None, None
|
| 523 |
-
if test_dir is not None:
|
| 524 |
-
test_images, test_labels, test_image_names = load_images_labels(
|
| 525 |
-
test_dir, mask_filter, image_filter, look_one_level_down)
|
| 526 |
-
|
| 527 |
-
return images, labels, image_names, test_images, test_labels, test_image_names
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
def masks_flows_to_seg(images, masks, flows, file_names,
|
| 531 |
-
channels=None,
|
| 532 |
-
imgs_restore=None, restore_type=None, ratio=1.):
|
| 533 |
-
"""Save output of model eval to be loaded in GUI.
|
| 534 |
-
|
| 535 |
-
Can be list output (run on multiple images) or single output (run on single image).
|
| 536 |
-
|
| 537 |
-
Saved to file_names[k]+"_seg.npy".
|
| 538 |
-
|
| 539 |
-
Args:
|
| 540 |
-
images (list): Images input into cellpose.
|
| 541 |
-
masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
|
| 542 |
-
flows (list): Flows output from Cellpose.eval.
|
| 543 |
-
file_names (list, str): Names of files of images.
|
| 544 |
-
diams (float array): Diameters used to run Cellpose. Defaults to 30. TODO: remove this
|
| 545 |
-
channels (list, int, optional): Channels used to run Cellpose. Defaults to None.
|
| 546 |
-
|
| 547 |
-
Returns:
|
| 548 |
-
None
|
| 549 |
-
"""
|
| 550 |
-
|
| 551 |
-
if channels is None:
|
| 552 |
-
channels = [0, 0]
|
| 553 |
-
|
| 554 |
-
if isinstance(masks, list):
|
| 555 |
-
if imgs_restore is None:
|
| 556 |
-
imgs_restore = [None] * len(masks)
|
| 557 |
-
if isinstance(file_names, str):
|
| 558 |
-
file_names = [file_names] * len(masks)
|
| 559 |
-
for k, [image, mask, flow,
|
| 560 |
-
# diam,
|
| 561 |
-
file_name, img_restore
|
| 562 |
-
] in enumerate(zip(images, masks, flows,
|
| 563 |
-
# diams,
|
| 564 |
-
file_names,
|
| 565 |
-
imgs_restore)):
|
| 566 |
-
channels_img = channels
|
| 567 |
-
if channels_img is not None and len(channels) > 2:
|
| 568 |
-
channels_img = channels[k]
|
| 569 |
-
masks_flows_to_seg(image, mask, flow, file_name,
|
| 570 |
-
# diams=diam,
|
| 571 |
-
channels=channels_img, imgs_restore=img_restore,
|
| 572 |
-
restore_type=restore_type, ratio=ratio)
|
| 573 |
-
return
|
| 574 |
-
|
| 575 |
-
if len(channels) == 1:
|
| 576 |
-
channels = channels[0]
|
| 577 |
-
|
| 578 |
-
flowi = []
|
| 579 |
-
if flows[0].ndim == 3:
|
| 580 |
-
Ly, Lx = masks.shape[-2:]
|
| 581 |
-
flowi.append(
|
| 582 |
-
cv2.resize(flows[0], (Lx, Ly), interpolation=cv2.INTER_NEAREST)[np.newaxis,
|
| 583 |
-
...])
|
| 584 |
-
else:
|
| 585 |
-
flowi.append(flows[0])
|
| 586 |
-
|
| 587 |
-
if flows[0].ndim == 3:
|
| 588 |
-
cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(
|
| 589 |
-
np.uint8)
|
| 590 |
-
cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
|
| 591 |
-
flowi.append(cellprob[np.newaxis, ...])
|
| 592 |
-
flowi.append(np.zeros(flows[0].shape, dtype=np.uint8))
|
| 593 |
-
flowi[-1] = flowi[-1][np.newaxis, ...]
|
| 594 |
-
else:
|
| 595 |
-
flowi.append(
|
| 596 |
-
(np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8))
|
| 597 |
-
flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8))
|
| 598 |
-
if len(flows) > 2:
|
| 599 |
-
if len(flows) > 3:
|
| 600 |
-
flowi.append(flows[3])
|
| 601 |
-
else:
|
| 602 |
-
flowi.append([])
|
| 603 |
-
flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0))
|
| 604 |
-
outlines = masks * utils.masks_to_outlines(masks)
|
| 605 |
-
base = os.path.splitext(file_names)[0]
|
| 606 |
-
|
| 607 |
-
dat = {
|
| 608 |
-
"outlines":
|
| 609 |
-
outlines.astype(np.uint16) if outlines.max() < 2**16 -
|
| 610 |
-
1 else outlines.astype(np.uint32),
|
| 611 |
-
"masks":
|
| 612 |
-
masks.astype(np.uint16) if outlines.max() < 2**16 -
|
| 613 |
-
1 else masks.astype(np.uint32),
|
| 614 |
-
"chan_choose":
|
| 615 |
-
channels,
|
| 616 |
-
"ismanual":
|
| 617 |
-
np.zeros(masks.max(), bool),
|
| 618 |
-
"filename":
|
| 619 |
-
file_names,
|
| 620 |
-
"flows":
|
| 621 |
-
flowi,
|
| 622 |
-
"diameter":
|
| 623 |
-
np.nan
|
| 624 |
-
}
|
| 625 |
-
if restore_type is not None and imgs_restore is not None:
|
| 626 |
-
dat["restore"] = restore_type
|
| 627 |
-
dat["ratio"] = ratio
|
| 628 |
-
dat["img_restore"] = imgs_restore
|
| 629 |
-
|
| 630 |
-
np.save(base + "_seg.npy", dat)
|
| 631 |
-
|
| 632 |
-
def save_to_png(images, masks, flows, file_names):
|
| 633 |
-
""" deprecated (runs io.save_masks with png=True)
|
| 634 |
-
|
| 635 |
-
does not work for 3D images
|
| 636 |
-
|
| 637 |
-
"""
|
| 638 |
-
save_masks(images, masks, flows, file_names, png=True)
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
def save_rois(masks, file_name, multiprocessing=None):
|
| 642 |
-
""" save masks to .roi files in .zip archive for ImageJ/Fiji
|
| 643 |
-
|
| 644 |
-
Args:
|
| 645 |
-
masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels
|
| 646 |
-
file_name (str): name to save the .zip file to
|
| 647 |
-
|
| 648 |
-
Returns:
|
| 649 |
-
None
|
| 650 |
-
"""
|
| 651 |
-
outlines = utils.outlines_list(masks, multiprocessing=multiprocessing)
|
| 652 |
-
nonempty_outlines = [outline for outline in outlines if len(outline)!=0]
|
| 653 |
-
if len(outlines)!=len(nonempty_outlines):
|
| 654 |
-
print(f"empty outlines found, saving {len(nonempty_outlines)} ImageJ ROIs to .zip archive.")
|
| 655 |
-
rois = [ImagejRoi.frompoints(outline) for outline in nonempty_outlines]
|
| 656 |
-
file_name = os.path.splitext(file_name)[0] + '_rois.zip'
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
# Delete file if it exists; the roifile lib appends to existing zip files.
|
| 660 |
-
# If the user removed a mask it will still be in the zip file
|
| 661 |
-
if os.path.exists(file_name):
|
| 662 |
-
os.remove(file_name)
|
| 663 |
-
|
| 664 |
-
roiwrite(file_name, rois)
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
|
| 668 |
-
suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False,
|
| 669 |
-
in_folders=False, savedir=None, save_txt=False, save_mpl=False):
|
| 670 |
-
""" Save masks + nicely plotted segmentation image to png and/or tiff.
|
| 671 |
-
|
| 672 |
-
Can save masks, flows to different directories, if in_folders is True.
|
| 673 |
-
|
| 674 |
-
If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png".
|
| 675 |
-
|
| 676 |
-
If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif".
|
| 677 |
-
|
| 678 |
-
If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png".
|
| 679 |
-
|
| 680 |
-
Only tif option works for 3D data, and only tif option works for empty masks.
|
| 681 |
-
|
| 682 |
-
Args:
|
| 683 |
-
images (list): Images input into cellpose.
|
| 684 |
-
masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
|
| 685 |
-
flows (list): Flows output from Cellpose.eval.
|
| 686 |
-
file_names (list, str): Names of files of images.
|
| 687 |
-
png (bool, optional): Save masks to PNG. Defaults to True.
|
| 688 |
-
tif (bool, optional): Save masks to TIF. Defaults to False.
|
| 689 |
-
channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0].
|
| 690 |
-
suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks".
|
| 691 |
-
save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False.
|
| 692 |
-
save_outlines (bool, optional): Save outlines of masks. Defaults to False.
|
| 693 |
-
dir_above (bool, optional): Save masks/flows in directory above. Defaults to False.
|
| 694 |
-
in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False.
|
| 695 |
-
savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
|
| 696 |
-
save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
|
| 697 |
-
save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
|
| 698 |
-
This takes a long time for large images. Defaults to False.
|
| 699 |
-
|
| 700 |
-
Returns:
|
| 701 |
-
None
|
| 702 |
-
"""
|
| 703 |
-
|
| 704 |
-
if isinstance(masks, list):
|
| 705 |
-
for image, mask, flow, file_name in zip(images, masks, flows, file_names):
|
| 706 |
-
save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix,
|
| 707 |
-
dir_above=dir_above, save_flows=save_flows,
|
| 708 |
-
save_outlines=save_outlines, savedir=savedir, save_txt=save_txt,
|
| 709 |
-
in_folders=in_folders, save_mpl=save_mpl)
|
| 710 |
-
return
|
| 711 |
-
|
| 712 |
-
if masks.ndim > 2 and not tif:
|
| 713 |
-
raise ValueError("cannot save 3D outputs as PNG, use tif option instead")
|
| 714 |
-
|
| 715 |
-
if masks.max() == 0:
|
| 716 |
-
io_logger.warning("no masks found, will not save PNG or outlines")
|
| 717 |
-
if not tif:
|
| 718 |
-
return
|
| 719 |
-
else:
|
| 720 |
-
png = False
|
| 721 |
-
save_outlines = False
|
| 722 |
-
save_flows = False
|
| 723 |
-
save_txt = False
|
| 724 |
-
|
| 725 |
-
if savedir is None:
|
| 726 |
-
if dir_above:
|
| 727 |
-
savedir = Path(file_names).parent.parent.absolute(
|
| 728 |
-
) #go up a level to save in its own folder
|
| 729 |
-
else:
|
| 730 |
-
savedir = Path(file_names).parent.absolute()
|
| 731 |
-
|
| 732 |
-
check_dir(savedir)
|
| 733 |
-
|
| 734 |
-
basename = os.path.splitext(os.path.basename(file_names))[0]
|
| 735 |
-
if in_folders:
|
| 736 |
-
maskdir = os.path.join(savedir, "masks")
|
| 737 |
-
outlinedir = os.path.join(savedir, "outlines")
|
| 738 |
-
txtdir = os.path.join(savedir, "txt_outlines")
|
| 739 |
-
flowdir = os.path.join(savedir, "flows")
|
| 740 |
-
else:
|
| 741 |
-
maskdir = savedir
|
| 742 |
-
outlinedir = savedir
|
| 743 |
-
txtdir = savedir
|
| 744 |
-
flowdir = savedir
|
| 745 |
-
|
| 746 |
-
check_dir(maskdir)
|
| 747 |
-
|
| 748 |
-
exts = []
|
| 749 |
-
if masks.ndim > 2:
|
| 750 |
-
png = False
|
| 751 |
-
tif = True
|
| 752 |
-
if png:
|
| 753 |
-
if masks.max() < 2**16:
|
| 754 |
-
masks = masks.astype(np.uint16)
|
| 755 |
-
exts.append(".png")
|
| 756 |
-
else:
|
| 757 |
-
png = False
|
| 758 |
-
tif = True
|
| 759 |
-
io_logger.warning(
|
| 760 |
-
"found more than 65535 masks in each image, cannot save PNG, saving as TIF"
|
| 761 |
-
)
|
| 762 |
-
if tif:
|
| 763 |
-
exts.append(".tif")
|
| 764 |
-
|
| 765 |
-
# save masks
|
| 766 |
-
with warnings.catch_warnings():
|
| 767 |
-
warnings.simplefilter("ignore")
|
| 768 |
-
for ext in exts:
|
| 769 |
-
imsave(os.path.join(maskdir, basename + suffix + ext), masks)
|
| 770 |
-
|
| 771 |
-
if save_mpl and png and MATPLOTLIB and not min(images.shape) > 3:
|
| 772 |
-
# Make and save original/segmentation/flows image
|
| 773 |
-
|
| 774 |
-
img = images.copy()
|
| 775 |
-
if img.ndim < 3:
|
| 776 |
-
img = img[:, :, np.newaxis]
|
| 777 |
-
elif img.shape[0] < 8:
|
| 778 |
-
np.transpose(img, (1, 2, 0))
|
| 779 |
-
|
| 780 |
-
fig = plt.figure(figsize=(12, 3))
|
| 781 |
-
plot.show_segmentation(fig, img, masks, flows[0])
|
| 782 |
-
fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"),
|
| 783 |
-
dpi=300)
|
| 784 |
-
plt.close(fig)
|
| 785 |
-
|
| 786 |
-
# ImageJ txt outline files
|
| 787 |
-
if masks.ndim < 3 and save_txt:
|
| 788 |
-
check_dir(txtdir)
|
| 789 |
-
outlines = utils.outlines_list(masks)
|
| 790 |
-
outlines_to_text(os.path.join(txtdir, basename), outlines)
|
| 791 |
-
|
| 792 |
-
# RGB outline images
|
| 793 |
-
if masks.ndim < 3 and save_outlines:
|
| 794 |
-
check_dir(outlinedir)
|
| 795 |
-
outlines = utils.masks_to_outlines(masks)
|
| 796 |
-
outX, outY = np.nonzero(outlines)
|
| 797 |
-
img0 = transforms.normalize99(images)
|
| 798 |
-
if img0.shape[0] < 4:
|
| 799 |
-
img0 = np.transpose(img0, (1, 2, 0))
|
| 800 |
-
if img0.shape[-1] < 3 or img0.ndim < 3:
|
| 801 |
-
img0 = plot.image_to_rgb(img0, channels=channels)
|
| 802 |
-
else:
|
| 803 |
-
if img0.max() <= 50.0:
|
| 804 |
-
img0 = np.uint8(np.clip(img0 * 255, 0, 1))
|
| 805 |
-
imgout = img0.copy()
|
| 806 |
-
imgout[outX, outY] = np.array([255, 0, 0]) #pure red
|
| 807 |
-
imsave(os.path.join(outlinedir, basename + "_outlines" + suffix + ".png"),
|
| 808 |
-
imgout)
|
| 809 |
-
|
| 810 |
-
# save RGB flow picture
|
| 811 |
-
if masks.ndim < 3 and save_flows:
|
| 812 |
-
check_dir(flowdir)
|
| 813 |
-
imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"),
|
| 814 |
-
(flows[0] * (2**16 - 1)).astype(np.uint16))
|
| 815 |
-
#save full flow data
|
| 816 |
-
imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/plot.py
DELETED
|
@@ -1,281 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
import os
|
| 5 |
-
import numpy as np
|
| 6 |
-
import cv2
|
| 7 |
-
from scipy.ndimage import gaussian_filter
|
| 8 |
-
from . import utils, io, transforms
|
| 9 |
-
|
| 10 |
-
try:
|
| 11 |
-
import matplotlib
|
| 12 |
-
MATPLOTLIB_ENABLED = True
|
| 13 |
-
except:
|
| 14 |
-
MATPLOTLIB_ENABLED = False
|
| 15 |
-
|
| 16 |
-
try:
|
| 17 |
-
from skimage import color
|
| 18 |
-
from skimage.segmentation import find_boundaries
|
| 19 |
-
SKIMAGE_ENABLED = True
|
| 20 |
-
except:
|
| 21 |
-
SKIMAGE_ENABLED = False
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# modified to use sinebow color
|
| 25 |
-
def dx_to_circ(dP):
|
| 26 |
-
"""Converts the optic flow representation to a circular color representation.
|
| 27 |
-
|
| 28 |
-
Args:
|
| 29 |
-
dP (ndarray): Flow field components [dy, dx].
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
ndarray: The circular color representation of the optic flow.
|
| 33 |
-
|
| 34 |
-
"""
|
| 35 |
-
mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.)
|
| 36 |
-
angles = np.arctan2(dP[1], dP[0]) + np.pi
|
| 37 |
-
a = 2
|
| 38 |
-
mag /= a
|
| 39 |
-
rgb = np.zeros((*dP.shape[1:], 3), "uint8")
|
| 40 |
-
rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8")
|
| 41 |
-
rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8")
|
| 42 |
-
rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8")
|
| 43 |
-
|
| 44 |
-
return rgb
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
|
| 48 |
-
"""Plot segmentation results (like on website).
|
| 49 |
-
|
| 50 |
-
Can save each panel of figure with file_name option. Use channels option if
|
| 51 |
-
img input is not an RGB image with 3 channels.
|
| 52 |
-
|
| 53 |
-
Args:
|
| 54 |
-
fig (matplotlib.pyplot.figure): Figure in which to make plot.
|
| 55 |
-
img (ndarray): 2D or 3D array. Image input into cellpose.
|
| 56 |
-
maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
|
| 57 |
-
flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows).
|
| 58 |
-
channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
|
| 59 |
-
file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
|
| 60 |
-
seg_norm (bool, optional): Improve cell visibility under labels. Defaults to False.
|
| 61 |
-
"""
|
| 62 |
-
if not MATPLOTLIB_ENABLED:
|
| 63 |
-
raise ImportError(
|
| 64 |
-
"matplotlib not installed, install with 'pip install matplotlib'")
|
| 65 |
-
ax = fig.add_subplot(1, 4, 1)
|
| 66 |
-
img0 = img.copy()
|
| 67 |
-
|
| 68 |
-
if img0.shape[0] < 4:
|
| 69 |
-
img0 = np.transpose(img0, (1, 2, 0))
|
| 70 |
-
if img0.shape[-1] < 3 or img0.ndim < 3:
|
| 71 |
-
img0 = image_to_rgb(img0, channels=channels)
|
| 72 |
-
else:
|
| 73 |
-
if img0.max() <= 50.0:
|
| 74 |
-
img0 = np.uint8(np.clip(img0, 0, 1) * 255)
|
| 75 |
-
ax.imshow(img0)
|
| 76 |
-
ax.set_title("original image")
|
| 77 |
-
ax.axis("off")
|
| 78 |
-
|
| 79 |
-
outlines = utils.masks_to_outlines(maski)
|
| 80 |
-
|
| 81 |
-
overlay = mask_overlay(img0, maski)
|
| 82 |
-
|
| 83 |
-
ax = fig.add_subplot(1, 4, 2)
|
| 84 |
-
outX, outY = np.nonzero(outlines)
|
| 85 |
-
imgout = img0.copy()
|
| 86 |
-
imgout[outX, outY] = np.array([255, 0, 0]) # pure red
|
| 87 |
-
|
| 88 |
-
ax.imshow(imgout)
|
| 89 |
-
ax.set_title("predicted outlines")
|
| 90 |
-
ax.axis("off")
|
| 91 |
-
|
| 92 |
-
ax = fig.add_subplot(1, 4, 3)
|
| 93 |
-
ax.imshow(overlay)
|
| 94 |
-
ax.set_title("predicted masks")
|
| 95 |
-
ax.axis("off")
|
| 96 |
-
|
| 97 |
-
ax = fig.add_subplot(1, 4, 4)
|
| 98 |
-
ax.imshow(flowi)
|
| 99 |
-
ax.set_title("predicted cell pose")
|
| 100 |
-
ax.axis("off")
|
| 101 |
-
|
| 102 |
-
if file_name is not None:
|
| 103 |
-
save_path = os.path.splitext(file_name)[0]
|
| 104 |
-
io.imsave(save_path + "_overlay.jpg", overlay)
|
| 105 |
-
io.imsave(save_path + "_outlines.jpg", imgout)
|
| 106 |
-
io.imsave(save_path + "_flows.jpg", flowi)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def mask_rgb(masks, colors=None):
|
| 110 |
-
"""Masks in random RGB colors.
|
| 111 |
-
|
| 112 |
-
Args:
|
| 113 |
-
masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
|
| 114 |
-
colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
|
| 115 |
-
|
| 116 |
-
Returns:
|
| 117 |
-
RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
|
| 118 |
-
"""
|
| 119 |
-
if colors is not None:
|
| 120 |
-
if colors.max() > 1:
|
| 121 |
-
colors = np.float32(colors)
|
| 122 |
-
colors /= 255
|
| 123 |
-
colors = utils.rgb_to_hsv(colors)
|
| 124 |
-
|
| 125 |
-
HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32)
|
| 126 |
-
HSV[:, :, 2] = 1.0
|
| 127 |
-
for n in range(int(masks.max())):
|
| 128 |
-
ipix = (masks == n + 1).nonzero()
|
| 129 |
-
if colors is None:
|
| 130 |
-
HSV[ipix[0], ipix[1], 0] = np.random.rand()
|
| 131 |
-
else:
|
| 132 |
-
HSV[ipix[0], ipix[1], 0] = colors[n, 0]
|
| 133 |
-
HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5
|
| 134 |
-
HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5
|
| 135 |
-
RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
|
| 136 |
-
return RGB
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def mask_overlay(img, masks, colors=None):
|
| 140 |
-
"""Overlay masks on image (set image to grayscale).
|
| 141 |
-
|
| 142 |
-
Args:
|
| 143 |
-
img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)].
|
| 144 |
-
masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
|
| 145 |
-
colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
|
| 146 |
-
|
| 147 |
-
Returns:
|
| 148 |
-
RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
|
| 149 |
-
"""
|
| 150 |
-
if colors is not None:
|
| 151 |
-
if colors.max() > 1:
|
| 152 |
-
colors = np.float32(colors)
|
| 153 |
-
colors /= 255
|
| 154 |
-
colors = utils.rgb_to_hsv(colors)
|
| 155 |
-
if img.ndim > 2:
|
| 156 |
-
img = img.astype(np.float32).mean(axis=-1)
|
| 157 |
-
else:
|
| 158 |
-
img = img.astype(np.float32)
|
| 159 |
-
|
| 160 |
-
HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
|
| 161 |
-
HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1)
|
| 162 |
-
hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())]
|
| 163 |
-
for n in range(int(masks.max())):
|
| 164 |
-
ipix = (masks == n + 1).nonzero()
|
| 165 |
-
if colors is None:
|
| 166 |
-
HSV[ipix[0], ipix[1], 0] = hues[n]
|
| 167 |
-
else:
|
| 168 |
-
HSV[ipix[0], ipix[1], 0] = colors[n, 0]
|
| 169 |
-
HSV[ipix[0], ipix[1], 1] = 1.0
|
| 170 |
-
RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
|
| 171 |
-
return RGB
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def image_to_rgb(img0, channels=[0, 0]):
|
| 175 |
-
"""Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2.
|
| 179 |
-
|
| 180 |
-
Returns:
|
| 181 |
-
ndarray: RGB image of shape Ly x Lx x 3.
|
| 182 |
-
|
| 183 |
-
"""
|
| 184 |
-
img = img0.copy()
|
| 185 |
-
img = img.astype(np.float32)
|
| 186 |
-
if img.ndim < 3:
|
| 187 |
-
img = img[:, :, np.newaxis]
|
| 188 |
-
if img.shape[0] < 5:
|
| 189 |
-
img = np.transpose(img, (1, 2, 0))
|
| 190 |
-
if channels[0] == 0:
|
| 191 |
-
img = img.mean(axis=-1)[:, :, np.newaxis]
|
| 192 |
-
for i in range(img.shape[-1]):
|
| 193 |
-
if np.ptp(img[:, :, i]) > 0:
|
| 194 |
-
img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1)
|
| 195 |
-
img[:, :, i] = np.clip(img[:, :, i], 0, 1)
|
| 196 |
-
img *= 255
|
| 197 |
-
img = np.uint8(img)
|
| 198 |
-
RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8)
|
| 199 |
-
if img.shape[-1] == 1:
|
| 200 |
-
RGB = np.tile(img, (1, 1, 3))
|
| 201 |
-
else:
|
| 202 |
-
RGB[:, :, channels[0] - 1] = img[:, :, 0]
|
| 203 |
-
if channels[1] > 0:
|
| 204 |
-
RGB[:, :, channels[1] - 1] = img[:, :, 1]
|
| 205 |
-
return RGB
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def interesting_patch(mask, bsize=130):
|
| 209 |
-
"""
|
| 210 |
-
Get patch of size bsize x bsize with most masks.
|
| 211 |
-
|
| 212 |
-
Args:
|
| 213 |
-
mask (ndarray): Input mask.
|
| 214 |
-
bsize (int): Size of the patch.
|
| 215 |
-
|
| 216 |
-
Returns:
|
| 217 |
-
tuple: Patch coordinates (y, x).
|
| 218 |
-
|
| 219 |
-
"""
|
| 220 |
-
Ly, Lx = mask.shape
|
| 221 |
-
m = np.float32(mask > 0)
|
| 222 |
-
m = gaussian_filter(m, bsize / 2)
|
| 223 |
-
y, x = np.unravel_index(np.argmax(m), m.shape)
|
| 224 |
-
ycent = max(bsize // 2, min(y, Ly - bsize // 2))
|
| 225 |
-
xcent = max(bsize // 2, min(x, Lx - bsize // 2))
|
| 226 |
-
patch = [
|
| 227 |
-
np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int),
|
| 228 |
-
np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int)
|
| 229 |
-
]
|
| 230 |
-
return patch
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
def disk(med, r, Ly, Lx):
|
| 234 |
-
"""Returns the pixels of a disk with a given radius and center.
|
| 235 |
-
|
| 236 |
-
Args:
|
| 237 |
-
med (tuple): The center coordinates of the disk.
|
| 238 |
-
r (float): The radius of the disk.
|
| 239 |
-
Ly (int): The height of the image.
|
| 240 |
-
Lx (int): The width of the image.
|
| 241 |
-
|
| 242 |
-
Returns:
|
| 243 |
-
tuple: A tuple containing the y and x coordinates of the pixels within the disk.
|
| 244 |
-
|
| 245 |
-
"""
|
| 246 |
-
yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
|
| 247 |
-
indexing="ij")
|
| 248 |
-
inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r
|
| 249 |
-
y = yy[inds].flatten()
|
| 250 |
-
x = xx[inds].flatten()
|
| 251 |
-
return y, x
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
def outline_view(img0, maski, color=[1, 0, 0], mode="inner"):
|
| 255 |
-
"""
|
| 256 |
-
Generates a red outline overlay onto the image.
|
| 257 |
-
|
| 258 |
-
Args:
|
| 259 |
-
img0 (numpy.ndarray): The input image.
|
| 260 |
-
maski (numpy.ndarray): The mask representing the region of interest.
|
| 261 |
-
color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red).
|
| 262 |
-
mode (str, optional): The mode for generating the outline. Defaults to "inner".
|
| 263 |
-
|
| 264 |
-
Returns:
|
| 265 |
-
numpy.ndarray: The image with the red outline overlay.
|
| 266 |
-
|
| 267 |
-
"""
|
| 268 |
-
if img0.ndim == 2:
|
| 269 |
-
img0 = np.stack([img0] * 3, axis=-1)
|
| 270 |
-
elif img0.ndim != 3:
|
| 271 |
-
raise ValueError("img0 not right size (must have ndim 2 or 3)")
|
| 272 |
-
|
| 273 |
-
if SKIMAGE_ENABLED:
|
| 274 |
-
outlines = find_boundaries(maski, mode=mode)
|
| 275 |
-
else:
|
| 276 |
-
outlines = utils.masks_to_outlines(maski, mode=mode)
|
| 277 |
-
outY, outX = np.nonzero(outlines)
|
| 278 |
-
imgout = img0.copy()
|
| 279 |
-
imgout[outY, outX] = np.array(color)
|
| 280 |
-
|
| 281 |
-
return imgout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/utils.py
DELETED
|
@@ -1,667 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
import logging
|
| 5 |
-
import os, tempfile, shutil, io
|
| 6 |
-
from tqdm import tqdm, trange
|
| 7 |
-
from urllib.request import urlopen
|
| 8 |
-
import cv2
|
| 9 |
-
from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label
|
| 10 |
-
from scipy.spatial import ConvexHull
|
| 11 |
-
import numpy as np
|
| 12 |
-
import colorsys
|
| 13 |
-
import fastremap
|
| 14 |
-
import fill_voids
|
| 15 |
-
from multiprocessing import Pool, cpu_count
|
| 16 |
-
# try:
|
| 17 |
-
# from cellpose import metrics
|
| 18 |
-
# except:
|
| 19 |
-
# import metrics as metrics
|
| 20 |
-
from models.seg_post_model.cellpose import metrics
|
| 21 |
-
|
| 22 |
-
try:
|
| 23 |
-
from skimage.morphology import remove_small_holes
|
| 24 |
-
SKIMAGE_ENABLED = True
|
| 25 |
-
except:
|
| 26 |
-
SKIMAGE_ENABLED = False
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class TqdmToLogger(io.StringIO):
|
| 30 |
-
"""
|
| 31 |
-
Output stream for TQDM which will output to logger module instead of
|
| 32 |
-
the StdOut.
|
| 33 |
-
"""
|
| 34 |
-
logger = None
|
| 35 |
-
level = None
|
| 36 |
-
buf = ""
|
| 37 |
-
|
| 38 |
-
def __init__(self, logger, level=None):
|
| 39 |
-
super(TqdmToLogger, self).__init__()
|
| 40 |
-
self.logger = logger
|
| 41 |
-
self.level = level or logging.INFO
|
| 42 |
-
|
| 43 |
-
def write(self, buf):
|
| 44 |
-
self.buf = buf.strip("\r\n\t ")
|
| 45 |
-
|
| 46 |
-
def flush(self):
|
| 47 |
-
self.logger.log(self.level, self.buf)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def rgb_to_hsv(arr):
|
| 51 |
-
rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv)
|
| 52 |
-
r, g, b = np.rollaxis(arr, axis=-1)
|
| 53 |
-
h, s, v = rgb_to_hsv_channels(r, g, b)
|
| 54 |
-
hsv = np.stack((h, s, v), axis=-1)
|
| 55 |
-
return hsv
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def hsv_to_rgb(arr):
|
| 59 |
-
hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb)
|
| 60 |
-
h, s, v = np.rollaxis(arr, axis=-1)
|
| 61 |
-
r, g, b = hsv_to_rgb_channels(h, s, v)
|
| 62 |
-
rgb = np.stack((r, g, b), axis=-1)
|
| 63 |
-
return rgb
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def download_url_to_file(url, dst, progress=True):
|
| 67 |
-
r"""Download object at the given URL to a local path.
|
| 68 |
-
Thanks to torch, slightly modified
|
| 69 |
-
Args:
|
| 70 |
-
url (string): URL of the object to download
|
| 71 |
-
dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
|
| 72 |
-
progress (bool, optional): whether or not to display a progress bar to stderr
|
| 73 |
-
Default: True
|
| 74 |
-
"""
|
| 75 |
-
file_size = None
|
| 76 |
-
import ssl
|
| 77 |
-
ssl._create_default_https_context = ssl._create_unverified_context
|
| 78 |
-
u = urlopen(url)
|
| 79 |
-
meta = u.info()
|
| 80 |
-
if hasattr(meta, "getheaders"):
|
| 81 |
-
content_length = meta.getheaders("Content-Length")
|
| 82 |
-
else:
|
| 83 |
-
content_length = meta.get_all("Content-Length")
|
| 84 |
-
if content_length is not None and len(content_length) > 0:
|
| 85 |
-
file_size = int(content_length[0])
|
| 86 |
-
# We deliberately save it in a temp file and move it after
|
| 87 |
-
dst = os.path.expanduser(dst)
|
| 88 |
-
dst_dir = os.path.dirname(dst)
|
| 89 |
-
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
|
| 90 |
-
try:
|
| 91 |
-
with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True,
|
| 92 |
-
unit_divisor=1024) as pbar:
|
| 93 |
-
while True:
|
| 94 |
-
buffer = u.read(8192)
|
| 95 |
-
if len(buffer) == 0:
|
| 96 |
-
break
|
| 97 |
-
f.write(buffer)
|
| 98 |
-
pbar.update(len(buffer))
|
| 99 |
-
f.close()
|
| 100 |
-
shutil.move(f.name, dst)
|
| 101 |
-
finally:
|
| 102 |
-
f.close()
|
| 103 |
-
if os.path.exists(f.name):
|
| 104 |
-
os.remove(f.name)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def distance_to_boundary(masks):
|
| 108 |
-
"""Get the distance to the boundary of mask pixels.
|
| 109 |
-
|
| 110 |
-
Args:
|
| 111 |
-
masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx].
|
| 115 |
-
|
| 116 |
-
Raises:
|
| 117 |
-
ValueError: If the masks array is not 2D or 3D.
|
| 118 |
-
|
| 119 |
-
"""
|
| 120 |
-
if masks.ndim > 3 or masks.ndim < 2:
|
| 121 |
-
raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" %
|
| 122 |
-
masks.ndim)
|
| 123 |
-
dist_to_bound = np.zeros(masks.shape, np.float64)
|
| 124 |
-
|
| 125 |
-
if masks.ndim == 3:
|
| 126 |
-
for i in range(masks.shape[0]):
|
| 127 |
-
dist_to_bound[i] = distance_to_boundary(masks[i])
|
| 128 |
-
return dist_to_bound
|
| 129 |
-
else:
|
| 130 |
-
slices = find_objects(masks)
|
| 131 |
-
for i, si in enumerate(slices):
|
| 132 |
-
if si is not None:
|
| 133 |
-
sr, sc = si
|
| 134 |
-
mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
|
| 135 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 136 |
-
cv2.CHAIN_APPROX_NONE)
|
| 137 |
-
pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
|
| 138 |
-
ypix, xpix = np.nonzero(mask)
|
| 139 |
-
min_dist = ((ypix[:, np.newaxis] - pvr)**2 +
|
| 140 |
-
(xpix[:, np.newaxis] - pvc)**2).min(axis=1)
|
| 141 |
-
dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist
|
| 142 |
-
return dist_to_bound
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def masks_to_edges(masks, threshold=1.0):
|
| 146 |
-
"""Get edges of masks as a 0-1 array.
|
| 147 |
-
|
| 148 |
-
Args:
|
| 149 |
-
masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
|
| 150 |
-
threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0.
|
| 151 |
-
|
| 152 |
-
Returns:
|
| 153 |
-
edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels.
|
| 154 |
-
"""
|
| 155 |
-
dist_to_bound = distance_to_boundary(masks)
|
| 156 |
-
edges = (dist_to_bound < threshold) * (masks > 0)
|
| 157 |
-
return edges
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def remove_edge_masks(masks, change_index=True):
|
| 161 |
-
"""Removes masks with pixels on the edge of the image.
|
| 162 |
-
|
| 163 |
-
Args:
|
| 164 |
-
masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
|
| 165 |
-
change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True.
|
| 166 |
-
|
| 167 |
-
Returns:
|
| 168 |
-
outlines (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
|
| 169 |
-
"""
|
| 170 |
-
slices = find_objects(masks.astype(int))
|
| 171 |
-
for i, si in enumerate(slices):
|
| 172 |
-
remove = False
|
| 173 |
-
if si is not None:
|
| 174 |
-
for d, sid in enumerate(si):
|
| 175 |
-
if sid.start == 0 or sid.stop == masks.shape[d]:
|
| 176 |
-
remove = True
|
| 177 |
-
break
|
| 178 |
-
if remove:
|
| 179 |
-
masks[si][masks[si] == i + 1] = 0
|
| 180 |
-
shape = masks.shape
|
| 181 |
-
if change_index:
|
| 182 |
-
_, masks = np.unique(masks, return_inverse=True)
|
| 183 |
-
masks = np.reshape(masks, shape).astype(np.int32)
|
| 184 |
-
|
| 185 |
-
return masks
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def masks_to_outlines(masks):
|
| 189 |
-
"""Get outlines of masks as a 0-1 array.
|
| 190 |
-
|
| 191 |
-
Args:
|
| 192 |
-
masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
|
| 193 |
-
|
| 194 |
-
Returns:
|
| 195 |
-
outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
|
| 196 |
-
"""
|
| 197 |
-
if masks.ndim > 3 or masks.ndim < 2:
|
| 198 |
-
raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
|
| 199 |
-
masks.ndim)
|
| 200 |
-
outlines = np.zeros(masks.shape, bool)
|
| 201 |
-
|
| 202 |
-
if masks.ndim == 3:
|
| 203 |
-
for i in range(masks.shape[0]):
|
| 204 |
-
outlines[i] = masks_to_outlines(masks[i])
|
| 205 |
-
return outlines
|
| 206 |
-
else:
|
| 207 |
-
slices = find_objects(masks.astype(int))
|
| 208 |
-
for i, si in enumerate(slices):
|
| 209 |
-
if si is not None:
|
| 210 |
-
sr, sc = si
|
| 211 |
-
mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
|
| 212 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 213 |
-
cv2.CHAIN_APPROX_NONE)
|
| 214 |
-
pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
|
| 215 |
-
vr, vc = pvr + sr.start, pvc + sc.start
|
| 216 |
-
outlines[vr, vc] = 1
|
| 217 |
-
return outlines
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None):
|
| 221 |
-
"""Get outlines of masks as a list to loop over for plotting.
|
| 222 |
-
|
| 223 |
-
Args:
|
| 224 |
-
masks (ndarray): Array of masks.
|
| 225 |
-
multiprocessing_threshold (int, optional): Threshold for enabling multiprocessing. Defaults to 1000.
|
| 226 |
-
multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None.
|
| 227 |
-
|
| 228 |
-
Returns:
|
| 229 |
-
list: List of outlines.
|
| 230 |
-
|
| 231 |
-
Raises:
|
| 232 |
-
None
|
| 233 |
-
|
| 234 |
-
Notes:
|
| 235 |
-
- This function is a wrapper for outlines_list_single and outlines_list_multi.
|
| 236 |
-
- Multiprocessing is disabled for Windows.
|
| 237 |
-
"""
|
| 238 |
-
# default to use multiprocessing if not few_masks, but allow user to override
|
| 239 |
-
if multiprocessing is None:
|
| 240 |
-
few_masks = np.max(masks) < multiprocessing_threshold
|
| 241 |
-
multiprocessing = not few_masks
|
| 242 |
-
|
| 243 |
-
# disable multiprocessing for Windows
|
| 244 |
-
if os.name == "nt":
|
| 245 |
-
if multiprocessing:
|
| 246 |
-
logging.getLogger(__name__).warning(
|
| 247 |
-
"Multiprocessing is disabled for Windows")
|
| 248 |
-
multiprocessing = False
|
| 249 |
-
|
| 250 |
-
if multiprocessing:
|
| 251 |
-
return outlines_list_multi(masks)
|
| 252 |
-
else:
|
| 253 |
-
return outlines_list_single(masks)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def outlines_list_single(masks):
|
| 257 |
-
"""Get outlines of masks as a list to loop over for plotting.
|
| 258 |
-
|
| 259 |
-
Args:
|
| 260 |
-
masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 261 |
-
|
| 262 |
-
Returns:
|
| 263 |
-
list: List of outlines as pixel coordinates.
|
| 264 |
-
|
| 265 |
-
"""
|
| 266 |
-
outpix = []
|
| 267 |
-
for n in np.unique(masks)[1:]:
|
| 268 |
-
mn = masks == n
|
| 269 |
-
if mn.sum() > 0:
|
| 270 |
-
contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
|
| 271 |
-
method=cv2.CHAIN_APPROX_NONE)
|
| 272 |
-
contours = contours[-2]
|
| 273 |
-
cmax = np.argmax([c.shape[0] for c in contours])
|
| 274 |
-
pix = contours[cmax].astype(int).squeeze()
|
| 275 |
-
if len(pix) > 4:
|
| 276 |
-
outpix.append(pix)
|
| 277 |
-
else:
|
| 278 |
-
outpix.append(np.zeros((0, 2)))
|
| 279 |
-
return outpix
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def outlines_list_multi(masks, num_processes=None):
|
| 283 |
-
"""
|
| 284 |
-
Get outlines of masks as a list to loop over for plotting.
|
| 285 |
-
|
| 286 |
-
Args:
|
| 287 |
-
masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 288 |
-
|
| 289 |
-
Returns:
|
| 290 |
-
list: List of outlines as pixel coordinates.
|
| 291 |
-
"""
|
| 292 |
-
if num_processes is None:
|
| 293 |
-
num_processes = cpu_count()
|
| 294 |
-
|
| 295 |
-
unique_masks = np.unique(masks)[1:]
|
| 296 |
-
with Pool(processes=num_processes) as pool:
|
| 297 |
-
outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
|
| 298 |
-
return outpix
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
def get_outline_multi(args):
|
| 302 |
-
"""Get the outline of a specific mask in a multi-mask image.
|
| 303 |
-
|
| 304 |
-
Args:
|
| 305 |
-
args (tuple): A tuple containing the masks and the mask number.
|
| 306 |
-
|
| 307 |
-
Returns:
|
| 308 |
-
numpy.ndarray: The outline of the specified mask as an array of coordinates.
|
| 309 |
-
|
| 310 |
-
"""
|
| 311 |
-
masks, n = args
|
| 312 |
-
mn = masks == n
|
| 313 |
-
if mn.sum() > 0:
|
| 314 |
-
contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
|
| 315 |
-
method=cv2.CHAIN_APPROX_NONE)
|
| 316 |
-
contours = contours[-2]
|
| 317 |
-
cmax = np.argmax([c.shape[0] for c in contours])
|
| 318 |
-
pix = contours[cmax].astype(int).squeeze()
|
| 319 |
-
return pix if len(pix) > 4 else np.zeros((0, 2))
|
| 320 |
-
return np.zeros((0, 2))
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
def dilate_masks(masks, n_iter=5):
|
| 324 |
-
"""Dilate masks by n_iter pixels.
|
| 325 |
-
|
| 326 |
-
Args:
|
| 327 |
-
masks (ndarray): Array of masks.
|
| 328 |
-
n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5.
|
| 329 |
-
|
| 330 |
-
Returns:
|
| 331 |
-
ndarray: Dilated masks.
|
| 332 |
-
"""
|
| 333 |
-
dilated_masks = masks.copy()
|
| 334 |
-
for n in range(n_iter):
|
| 335 |
-
# define the structuring element to use for dilation
|
| 336 |
-
kernel = np.ones((3, 3), "uint8")
|
| 337 |
-
# find the distance to each mask (distances are zero within masks)
|
| 338 |
-
dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"),
|
| 339 |
-
cv2.DIST_L2, 5)
|
| 340 |
-
# dilate each mask and assign to it the pixels along the border of the mask
|
| 341 |
-
# (does not allow dilation into other masks since dist_transform is zero there)
|
| 342 |
-
for i in range(1, np.max(masks) + 1):
|
| 343 |
-
mask = (dilated_masks == i).astype("uint8")
|
| 344 |
-
dilated_mask = cv2.dilate(mask, kernel, iterations=1)
|
| 345 |
-
dilated_mask = np.logical_and(dist_transform < 2, dilated_mask)
|
| 346 |
-
dilated_masks[dilated_mask > 0] = i
|
| 347 |
-
return dilated_masks
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
def get_perimeter(points):
|
| 351 |
-
"""
|
| 352 |
-
Calculate the perimeter of a set of points.
|
| 353 |
-
|
| 354 |
-
Parameters:
|
| 355 |
-
points (ndarray): An array of points with shape (npoints, ndim).
|
| 356 |
-
|
| 357 |
-
Returns:
|
| 358 |
-
float: The perimeter of the points.
|
| 359 |
-
|
| 360 |
-
"""
|
| 361 |
-
if points.shape[0] > 4:
|
| 362 |
-
points = np.append(points, points[:1], axis=0)
|
| 363 |
-
return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum()
|
| 364 |
-
else:
|
| 365 |
-
return 0
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
def get_mask_compactness(masks):
|
| 369 |
-
"""
|
| 370 |
-
Calculate the compactness of masks.
|
| 371 |
-
|
| 372 |
-
Parameters:
|
| 373 |
-
masks (ndarray): Binary masks representing objects.
|
| 374 |
-
|
| 375 |
-
Returns:
|
| 376 |
-
ndarray: Array of compactness values for each mask.
|
| 377 |
-
"""
|
| 378 |
-
perimeters = get_mask_perimeters(masks)
|
| 379 |
-
npoints = np.unique(masks, return_counts=True)[1][1:]
|
| 380 |
-
areas = npoints
|
| 381 |
-
compactness = 4 * np.pi * areas / perimeters**2
|
| 382 |
-
compactness[perimeters == 0] = 0
|
| 383 |
-
compactness[compactness > 1.0] = 1.0
|
| 384 |
-
return compactness
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
def get_mask_perimeters(masks):
|
| 388 |
-
"""
|
| 389 |
-
Calculate the perimeters of the given masks.
|
| 390 |
-
|
| 391 |
-
Parameters:
|
| 392 |
-
masks (numpy.ndarray): Binary masks representing objects.
|
| 393 |
-
|
| 394 |
-
Returns:
|
| 395 |
-
numpy.ndarray: Array containing the perimeters of each mask.
|
| 396 |
-
"""
|
| 397 |
-
perimeters = np.zeros(masks.max())
|
| 398 |
-
for n in range(masks.max()):
|
| 399 |
-
mn = masks == (n + 1)
|
| 400 |
-
if mn.sum() > 0:
|
| 401 |
-
contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
|
| 402 |
-
method=cv2.CHAIN_APPROX_NONE)[-2]
|
| 403 |
-
perimeters[n] = np.array(
|
| 404 |
-
[get_perimeter(c.astype(int).squeeze()) for c in contours]).sum()
|
| 405 |
-
|
| 406 |
-
return perimeters
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
def circleMask(d0):
|
| 410 |
-
"""
|
| 411 |
-
Creates an array with indices which are the radius of that x,y point.
|
| 412 |
-
|
| 413 |
-
Args:
|
| 414 |
-
d0 (tuple): Patch of (-d0, d0+1) over which radius is computed.
|
| 415 |
-
|
| 416 |
-
Returns:
|
| 417 |
-
tuple: A tuple containing:
|
| 418 |
-
- rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1).
|
| 419 |
-
- dx (ndarray): Indices of the patch along the x-axis.
|
| 420 |
-
- dy (ndarray): Indices of the patch along the y-axis.
|
| 421 |
-
"""
|
| 422 |
-
dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1))
|
| 423 |
-
dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1))
|
| 424 |
-
dy = dy.transpose()
|
| 425 |
-
|
| 426 |
-
rs = (dy**2 + dx**2)**0.5
|
| 427 |
-
return rs, dx, dy
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
def get_mask_stats(masks_true):
|
| 431 |
-
"""
|
| 432 |
-
Calculate various statistics for the given binary masks.
|
| 433 |
-
|
| 434 |
-
Parameters:
|
| 435 |
-
masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 436 |
-
|
| 437 |
-
Returns:
|
| 438 |
-
convexity (ndarray): Convexity values for each mask.
|
| 439 |
-
solidity (ndarray): Solidity values for each mask.
|
| 440 |
-
compactness (ndarray): Compactness values for each mask.
|
| 441 |
-
"""
|
| 442 |
-
mask_perimeters = get_mask_perimeters(masks_true)
|
| 443 |
-
|
| 444 |
-
# disk for compactness
|
| 445 |
-
rs, dy, dx = circleMask(np.array([100, 100]))
|
| 446 |
-
rsort = np.sort(rs.flatten())
|
| 447 |
-
|
| 448 |
-
# area for solidity
|
| 449 |
-
npoints = np.unique(masks_true, return_counts=True)[1][1:]
|
| 450 |
-
areas = npoints - mask_perimeters / 2 - 1
|
| 451 |
-
|
| 452 |
-
compactness = np.zeros(masks_true.max())
|
| 453 |
-
convexity = np.zeros(masks_true.max())
|
| 454 |
-
solidity = np.zeros(masks_true.max())
|
| 455 |
-
convex_perimeters = np.zeros(masks_true.max())
|
| 456 |
-
convex_areas = np.zeros(masks_true.max())
|
| 457 |
-
for ic in range(masks_true.max()):
|
| 458 |
-
points = np.array(np.nonzero(masks_true == (ic + 1))).T
|
| 459 |
-
if len(points) > 15 and mask_perimeters[ic] > 0:
|
| 460 |
-
med = np.median(points, axis=0)
|
| 461 |
-
# compute compactness of ROI
|
| 462 |
-
r2 = ((points - med)**2).sum(axis=1)**0.5
|
| 463 |
-
compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean()
|
| 464 |
-
try:
|
| 465 |
-
hull = ConvexHull(points)
|
| 466 |
-
convex_perimeters[ic] = hull.area
|
| 467 |
-
convex_areas[ic] = hull.volume
|
| 468 |
-
except:
|
| 469 |
-
convex_perimeters[ic] = 0
|
| 470 |
-
|
| 471 |
-
convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] /
|
| 472 |
-
mask_perimeters[mask_perimeters > 0.0])
|
| 473 |
-
solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] /
|
| 474 |
-
convex_areas[convex_areas > 0.0])
|
| 475 |
-
convexity = np.clip(convexity, 0.0, 1.0)
|
| 476 |
-
solidity = np.clip(solidity, 0.0, 1.0)
|
| 477 |
-
compactness = np.clip(compactness, 0.0, 1.0)
|
| 478 |
-
return convexity, solidity, compactness
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
def get_masks_unet(output, cell_threshold=0, boundary_threshold=0):
|
| 482 |
-
"""Create masks using cell probability and cell boundary.
|
| 483 |
-
|
| 484 |
-
Args:
|
| 485 |
-
output (ndarray): The output array containing cell probability and cell boundary.
|
| 486 |
-
cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0.
|
| 487 |
-
boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0.
|
| 488 |
-
|
| 489 |
-
Returns:
|
| 490 |
-
ndarray: The masks representing the segmented cells.
|
| 491 |
-
|
| 492 |
-
"""
|
| 493 |
-
cells = (output[..., 1] - output[..., 0]) > cell_threshold
|
| 494 |
-
selem = generate_binary_structure(cells.ndim, connectivity=1)
|
| 495 |
-
labels, nlabels = label(cells, selem)
|
| 496 |
-
|
| 497 |
-
if output.shape[-1] > 2:
|
| 498 |
-
slices = find_objects(labels)
|
| 499 |
-
dists = 10000 * np.ones(labels.shape, np.float32)
|
| 500 |
-
mins = np.zeros(labels.shape, np.int32)
|
| 501 |
-
borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold)
|
| 502 |
-
pad = 10
|
| 503 |
-
for i, slc in enumerate(slices):
|
| 504 |
-
if slc is not None:
|
| 505 |
-
slc_pad = tuple([
|
| 506 |
-
slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad))
|
| 507 |
-
for j, sli in enumerate(slc)
|
| 508 |
-
])
|
| 509 |
-
msk = (labels[slc_pad] == (i + 1)).astype(np.float32)
|
| 510 |
-
msk = 1 - gaussian_filter(msk, 5)
|
| 511 |
-
dists[slc_pad] = np.minimum(dists[slc_pad], msk)
|
| 512 |
-
mins[slc_pad][dists[slc_pad] == msk] = (i + 1)
|
| 513 |
-
labels[labels == 0] = borders[labels == 0] * mins[labels == 0]
|
| 514 |
-
|
| 515 |
-
masks = labels
|
| 516 |
-
shape0 = masks.shape
|
| 517 |
-
_, masks = np.unique(masks, return_inverse=True)
|
| 518 |
-
masks = np.reshape(masks, shape0)
|
| 519 |
-
return masks
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
def stitch3D(masks, stitch_threshold=0.25):
|
| 523 |
-
"""
|
| 524 |
-
Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.
|
| 525 |
-
|
| 526 |
-
Args:
|
| 527 |
-
masks (list or ndarray): List of 2D masks.
|
| 528 |
-
stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.
|
| 529 |
-
|
| 530 |
-
Returns:
|
| 531 |
-
list: List of stitched 3D masks.
|
| 532 |
-
"""
|
| 533 |
-
mmax = masks[0].max()
|
| 534 |
-
empty = 0
|
| 535 |
-
for i in trange(len(masks) - 1):
|
| 536 |
-
iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
|
| 537 |
-
if not iou.size and empty == 0:
|
| 538 |
-
masks[i + 1] = masks[i + 1]
|
| 539 |
-
mmax = masks[i + 1].max()
|
| 540 |
-
elif not iou.size and not empty == 0:
|
| 541 |
-
icount = masks[i + 1].max()
|
| 542 |
-
istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
|
| 543 |
-
mmax += icount
|
| 544 |
-
istitch = np.append(np.array(0), istitch)
|
| 545 |
-
masks[i + 1] = istitch[masks[i + 1]]
|
| 546 |
-
else:
|
| 547 |
-
iou[iou < stitch_threshold] = 0.0
|
| 548 |
-
iou[iou < iou.max(axis=0)] = 0.0
|
| 549 |
-
istitch = iou.argmax(axis=1) + 1
|
| 550 |
-
ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
|
| 551 |
-
istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
|
| 552 |
-
mmax += len(ino)
|
| 553 |
-
istitch = np.append(np.array(0), istitch)
|
| 554 |
-
masks[i + 1] = istitch[masks[i + 1]]
|
| 555 |
-
empty = 1
|
| 556 |
-
|
| 557 |
-
return masks
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
def diameters(masks):
|
| 561 |
-
"""
|
| 562 |
-
Calculate the diameters of the objects in the given masks.
|
| 563 |
-
|
| 564 |
-
Parameters:
|
| 565 |
-
masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 566 |
-
|
| 567 |
-
Returns:
|
| 568 |
-
tuple: A tuple containing the median diameter and an array of diameters for each object.
|
| 569 |
-
|
| 570 |
-
Examples:
|
| 571 |
-
>>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])
|
| 572 |
-
>>> diameters(masks)
|
| 573 |
-
(1.0, array([1.41421356, 1.0, 1.0]))
|
| 574 |
-
"""
|
| 575 |
-
uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
|
| 576 |
-
counts = counts[1:]
|
| 577 |
-
md = np.median(counts**0.5)
|
| 578 |
-
if np.isnan(md):
|
| 579 |
-
md = 0
|
| 580 |
-
md /= (np.pi**0.5) / 2
|
| 581 |
-
return md, counts**0.5
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
def radius_distribution(masks, bins):
|
| 585 |
-
"""
|
| 586 |
-
Calculate the radius distribution of masks.
|
| 587 |
-
|
| 588 |
-
Args:
|
| 589 |
-
masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 590 |
-
bins (int): Number of bins for the histogram.
|
| 591 |
-
|
| 592 |
-
Returns:
|
| 593 |
-
A tuple containing a normalized histogram of radii, median radius, array of radii.
|
| 594 |
-
|
| 595 |
-
"""
|
| 596 |
-
unique, counts = np.unique(masks, return_counts=True)
|
| 597 |
-
counts = counts[unique != 0]
|
| 598 |
-
nb, _ = np.histogram((counts**0.5) * 0.5, bins)
|
| 599 |
-
nb = nb.astype(np.float32)
|
| 600 |
-
if nb.sum() > 0:
|
| 601 |
-
nb = nb / nb.sum()
|
| 602 |
-
md = np.median(counts**0.5) * 0.5
|
| 603 |
-
if np.isnan(md):
|
| 604 |
-
md = 0
|
| 605 |
-
md /= (np.pi**0.5) / 2
|
| 606 |
-
return nb, md, (counts**0.5) / 2
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
def size_distribution(masks):
|
| 610 |
-
"""
|
| 611 |
-
Calculates the size distribution of masks.
|
| 612 |
-
|
| 613 |
-
Args:
|
| 614 |
-
masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 615 |
-
|
| 616 |
-
Returns:
|
| 617 |
-
float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
|
| 618 |
-
"""
|
| 619 |
-
counts = np.unique(masks, return_counts=True)[1][1:]
|
| 620 |
-
return np.percentile(counts, 25) / np.percentile(counts, 75)
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
def fill_holes_and_remove_small_masks(masks, min_size=15):
|
| 624 |
-
""" Fills holes in masks (2D/3D) and discards masks smaller than min_size.
|
| 625 |
-
|
| 626 |
-
This function fills holes in each mask using fill_voids.fill.
|
| 627 |
-
It also removes masks that are smaller than the specified min_size.
|
| 628 |
-
|
| 629 |
-
Parameters:
|
| 630 |
-
masks (ndarray): Int, 2D or 3D array of labelled masks.
|
| 631 |
-
0 represents no mask, while positive integers represent mask labels.
|
| 632 |
-
The size can be [Ly x Lx] or [Lz x Ly x Lx].
|
| 633 |
-
min_size (int, optional): Minimum number of pixels per mask.
|
| 634 |
-
Masks smaller than min_size will be removed.
|
| 635 |
-
Set to -1 to turn off this functionality. Default is 15.
|
| 636 |
-
|
| 637 |
-
Returns:
|
| 638 |
-
ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
|
| 639 |
-
0 represents no mask, while positive integers represent mask labels.
|
| 640 |
-
The size is [Ly x Lx] or [Lz x Ly x Lx].
|
| 641 |
-
"""
|
| 642 |
-
|
| 643 |
-
if masks.ndim > 3 or masks.ndim < 2:
|
| 644 |
-
raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
|
| 645 |
-
masks.ndim)
|
| 646 |
-
|
| 647 |
-
# Filter small masks
|
| 648 |
-
if min_size > 0:
|
| 649 |
-
counts = fastremap.unique(masks, return_counts=True)[1][1:]
|
| 650 |
-
masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
|
| 651 |
-
fastremap.renumber(masks, in_place=True)
|
| 652 |
-
|
| 653 |
-
slices = find_objects(masks)
|
| 654 |
-
j = 0
|
| 655 |
-
for i, slc in enumerate(slices):
|
| 656 |
-
if slc is not None:
|
| 657 |
-
msk = masks[slc] == (i + 1)
|
| 658 |
-
msk = fill_voids.fill(msk)
|
| 659 |
-
masks[slc][msk] = (j + 1)
|
| 660 |
-
j += 1
|
| 661 |
-
|
| 662 |
-
if min_size > 0:
|
| 663 |
-
counts = fastremap.unique(masks, return_counts=True)[1][1:]
|
| 664 |
-
masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
|
| 665 |
-
fastremap.renumber(masks, in_place=True)
|
| 666 |
-
|
| 667 |
-
return masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/version.py
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
from importlib.metadata import PackageNotFoundError, version
|
| 5 |
-
import sys
|
| 6 |
-
from platform import python_version
|
| 7 |
-
import torch
|
| 8 |
-
|
| 9 |
-
try:
|
| 10 |
-
version = version("cellpose")
|
| 11 |
-
except PackageNotFoundError:
|
| 12 |
-
version = "unknown"
|
| 13 |
-
|
| 14 |
-
version_str = f"""
|
| 15 |
-
cellpose version: \t{version}
|
| 16 |
-
platform: \t{sys.platform}
|
| 17 |
-
python version: \t{python_version()}
|
| 18 |
-
torch version: \t{torch.__version__}"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/{cellpose/core.py โ core.py}
RENAMED
|
@@ -109,39 +109,6 @@ def assign_device(use_torch=True, gpu=False, device=0):
|
|
| 109 |
return device, gpu
|
| 110 |
|
| 111 |
|
| 112 |
-
def _to_device(x, device, dtype=torch.float32):
|
| 113 |
-
"""
|
| 114 |
-
Converts the input tensor or numpy array to the specified device.
|
| 115 |
-
|
| 116 |
-
Args:
|
| 117 |
-
x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
|
| 118 |
-
device (torch.device): The target device.
|
| 119 |
-
|
| 120 |
-
Returns:
|
| 121 |
-
torch.Tensor: The converted tensor on the specified device.
|
| 122 |
-
"""
|
| 123 |
-
if not isinstance(x, torch.Tensor):
|
| 124 |
-
X = torch.from_numpy(x).to(device, dtype=dtype)
|
| 125 |
-
return X
|
| 126 |
-
else:
|
| 127 |
-
return x
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def _from_device(X):
|
| 131 |
-
"""
|
| 132 |
-
Converts a PyTorch tensor from the device to a NumPy array on the CPU.
|
| 133 |
-
|
| 134 |
-
Args:
|
| 135 |
-
X (torch.Tensor): The input PyTorch tensor.
|
| 136 |
-
|
| 137 |
-
Returns:
|
| 138 |
-
numpy.ndarray: The converted NumPy array.
|
| 139 |
-
"""
|
| 140 |
-
# The cast is so numpy conversion always works
|
| 141 |
-
x = X.detach().cpu().to(torch.float32).numpy()
|
| 142 |
-
return x
|
| 143 |
-
|
| 144 |
-
|
| 145 |
def _forward(net, x, feat=None):
|
| 146 |
"""Converts images to torch tensors, runs the network model, and returns numpy arrays.
|
| 147 |
|
|
@@ -152,15 +119,19 @@ def _forward(net, x, feat=None):
|
|
| 152 |
Returns:
|
| 153 |
Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
|
| 154 |
"""
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
| 156 |
if feat is not None:
|
| 157 |
-
|
|
|
|
| 158 |
net.eval()
|
| 159 |
with torch.no_grad():
|
| 160 |
y, style = net(X, feat=feat)[:2]
|
| 161 |
del X
|
| 162 |
-
y =
|
| 163 |
-
style =
|
| 164 |
return y, style
|
| 165 |
|
| 166 |
|
|
@@ -269,54 +240,3 @@ def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1,
|
|
| 269 |
yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
|
| 270 |
yf = yf.transpose(0,2,3,1)
|
| 271 |
return yf, np.array(styles)
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
def run_3D(net, imgs, batch_size=8, augment=False,
|
| 275 |
-
tile_overlap=0.1, bsize=224, net_ortho=None,
|
| 276 |
-
progress=None):
|
| 277 |
-
"""
|
| 278 |
-
Run network on image z-stack.
|
| 279 |
-
|
| 280 |
-
(faster if augment is False)
|
| 281 |
-
|
| 282 |
-
Args:
|
| 283 |
-
imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
|
| 284 |
-
batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
|
| 285 |
-
rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
|
| 286 |
-
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
|
| 287 |
-
augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
|
| 288 |
-
tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
|
| 289 |
-
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
|
| 290 |
-
net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
|
| 291 |
-
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
|
| 292 |
-
|
| 293 |
-
Returns:
|
| 294 |
-
Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
|
| 295 |
-
y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability.
|
| 296 |
-
style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
|
| 297 |
-
"""
|
| 298 |
-
sstr = ["YX", "ZY", "ZX"]
|
| 299 |
-
pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
|
| 300 |
-
ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
|
| 301 |
-
cp = [(1, 2), (0, 2), (0, 1)]
|
| 302 |
-
cpy = [(0, 1), (0, 1), (0, 1)]
|
| 303 |
-
shape = imgs.shape[:-1]
|
| 304 |
-
yf = np.zeros((*shape, 4), "float32")
|
| 305 |
-
for p in range(3):
|
| 306 |
-
xsl = imgs.transpose(pm[p])
|
| 307 |
-
# per image
|
| 308 |
-
core_logger.info("running %s: %d planes of size (%d, %d)" %
|
| 309 |
-
(sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
|
| 310 |
-
y, style = run_net(net,
|
| 311 |
-
xsl, batch_size=batch_size, augment=augment,
|
| 312 |
-
bsize=bsize, tile_overlap=tile_overlap,
|
| 313 |
-
rsz=None)
|
| 314 |
-
yf[..., -1] += y[..., -1].transpose(ipm[p])
|
| 315 |
-
for j in range(2):
|
| 316 |
-
yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
|
| 317 |
-
y = None; del y
|
| 318 |
-
|
| 319 |
-
if progress is not None:
|
| 320 |
-
progress.setValue(25 + 15 * p)
|
| 321 |
-
|
| 322 |
-
return yf, style
|
|
|
|
| 109 |
return device, gpu
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def _forward(net, x, feat=None):
|
| 113 |
"""Converts images to torch tensors, runs the network model, and returns numpy arrays.
|
| 114 |
|
|
|
|
| 119 |
Returns:
|
| 120 |
Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
|
| 121 |
"""
|
| 122 |
+
if not isinstance(x, torch.Tensor):
|
| 123 |
+
X = torch.from_numpy(x).to(net.device, dtype=net.dtype)
|
| 124 |
+
else:
|
| 125 |
+
X = x
|
| 126 |
if feat is not None:
|
| 127 |
+
if not isinstance(feat, torch.Tensor):
|
| 128 |
+
feat = torch.from_numpy(feat).to(net.device, dtype=net.dtype)
|
| 129 |
net.eval()
|
| 130 |
with torch.no_grad():
|
| 131 |
y, style = net(X, feat=feat)[:2]
|
| 132 |
del X
|
| 133 |
+
y = y.detach().cpu().to(torch.float32).numpy()
|
| 134 |
+
style = style.detach().cpu().to(torch.float32).numpy()
|
| 135 |
return y, style
|
| 136 |
|
| 137 |
|
|
|
|
| 240 |
yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
|
| 241 |
yf = yf.transpose(0,2,3,1)
|
| 242 |
return yf, np.array(styles)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/{cellpose/dynamics.py โ dynamics.py}
RENAMED
|
@@ -151,126 +151,6 @@ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
|
|
| 151 |
mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
|
| 152 |
return mu0
|
| 153 |
|
| 154 |
-
def masks_to_flows_gpu_3d(masks, device=None, niter=None):
|
| 155 |
-
"""Convert masks to flows using diffusion from center pixel.
|
| 156 |
-
|
| 157 |
-
Args:
|
| 158 |
-
masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
|
| 159 |
-
device (torch.device, optional): The device to run the computation on. Defaults to None.
|
| 160 |
-
niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
|
| 161 |
-
|
| 162 |
-
Returns:
|
| 163 |
-
np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y.
|
| 164 |
-
|
| 165 |
-
"""
|
| 166 |
-
if device is None:
|
| 167 |
-
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
|
| 168 |
-
|
| 169 |
-
Lz0, Ly0, Lx0 = masks.shape
|
| 170 |
-
Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
|
| 171 |
-
|
| 172 |
-
masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
|
| 173 |
-
masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))
|
| 174 |
-
|
| 175 |
-
# get mask pixel neighbors
|
| 176 |
-
z, y, x = torch.nonzero(masks_padded).T
|
| 177 |
-
neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
|
| 178 |
-
neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
|
| 179 |
-
neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)
|
| 180 |
-
|
| 181 |
-
neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)
|
| 182 |
-
|
| 183 |
-
# get mask centers
|
| 184 |
-
slices = find_objects(masks)
|
| 185 |
-
|
| 186 |
-
centers = np.zeros((masks.max(), 3), "int")
|
| 187 |
-
for i, si in enumerate(slices):
|
| 188 |
-
if si is not None:
|
| 189 |
-
sz, sy, sx = si
|
| 190 |
-
zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
|
| 191 |
-
zi = zi.astype(np.int32) + 1 # add padding
|
| 192 |
-
yi = yi.astype(np.int32) + 1 # add padding
|
| 193 |
-
xi = xi.astype(np.int32) + 1 # add padding
|
| 194 |
-
zmed = np.mean(zi)
|
| 195 |
-
ymed = np.mean(yi)
|
| 196 |
-
xmed = np.mean(xi)
|
| 197 |
-
imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
|
| 198 |
-
zmed = zi[imin]
|
| 199 |
-
ymed = yi[imin]
|
| 200 |
-
xmed = xi[imin]
|
| 201 |
-
centers[i, 0] = zmed + sz.start
|
| 202 |
-
centers[i, 1] = ymed + sy.start
|
| 203 |
-
centers[i, 2] = xmed + sx.start
|
| 204 |
-
|
| 205 |
-
# get neighbor validator (not all neighbors are in same mask)
|
| 206 |
-
neighbor_masks = masks_padded[tuple(neighbors)]
|
| 207 |
-
isneighbor = neighbor_masks == neighbor_masks[0]
|
| 208 |
-
ext = np.array(
|
| 209 |
-
[[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
|
| 210 |
-
for sz, sy, sx in slices])
|
| 211 |
-
n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter
|
| 212 |
-
|
| 213 |
-
# run diffusion
|
| 214 |
-
shape = masks_padded.shape
|
| 215 |
-
mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
|
| 216 |
-
device=device)
|
| 217 |
-
# normalize
|
| 218 |
-
mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
|
| 219 |
-
|
| 220 |
-
# put into original image
|
| 221 |
-
mu0 = np.zeros((3, Lz0, Ly0, Lx0))
|
| 222 |
-
mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
|
| 223 |
-
return mu0
|
| 224 |
-
|
| 225 |
-
def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
|
| 226 |
-
return_flows=True):
|
| 227 |
-
"""Converts labels (list of masks or flows) to flows for training model.
|
| 228 |
-
|
| 229 |
-
Args:
|
| 230 |
-
labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
|
| 231 |
-
it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
|
| 232 |
-
is used to create flows and cell probabilities.
|
| 233 |
-
files (list of str, optional): The files to save the flows to. If provided, flows are saved to
|
| 234 |
-
files to be reused. Defaults to None.
|
| 235 |
-
device (str, optional): The device to use for computation. Defaults to None.
|
| 236 |
-
redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
|
| 237 |
-
niter (int, optional): The number of iterations for computing flows. Defaults to None.
|
| 238 |
-
|
| 239 |
-
Returns:
|
| 240 |
-
list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
|
| 241 |
-
flows[k][1] is cell distance transform, flows[k][2] is Y flow, flows[k][3] is X flow,
|
| 242 |
-
and flows[k][4] is heat distribution.
|
| 243 |
-
"""
|
| 244 |
-
nimg = len(labels)
|
| 245 |
-
if labels[0].ndim < 3:
|
| 246 |
-
labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]
|
| 247 |
-
|
| 248 |
-
flows = []
|
| 249 |
-
# flows need to be recomputed
|
| 250 |
-
if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
|
| 251 |
-
dynamics_logger.info("computing flows for labels")
|
| 252 |
-
|
| 253 |
-
# compute flows; labels are fixed here to be unique, so they need to be passed back
|
| 254 |
-
# make sure labels are unique!
|
| 255 |
-
labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
|
| 256 |
-
iterator = trange if nimg > 1 else range
|
| 257 |
-
for n in iterator(nimg):
|
| 258 |
-
labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
|
| 259 |
-
vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)
|
| 260 |
-
|
| 261 |
-
# concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
|
| 262 |
-
flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
|
| 263 |
-
axis=0).astype(np.float32)
|
| 264 |
-
if files is not None:
|
| 265 |
-
file_name = os.path.splitext(files[n])[0]
|
| 266 |
-
tifffile.imwrite(file_name + "_flows.tif", flow)
|
| 267 |
-
if return_flows:
|
| 268 |
-
flows.append(flow)
|
| 269 |
-
else:
|
| 270 |
-
dynamics_logger.info("flows precomputed")
|
| 271 |
-
if return_flows:
|
| 272 |
-
flows = [labels[n].astype(np.float32) for n in range(nimg)]
|
| 273 |
-
return flows
|
| 274 |
|
| 275 |
|
| 276 |
def flow_error(maski, dP_net, device=None):
|
|
@@ -372,29 +252,6 @@ def steps_interp(dP, inds, niter, device=torch.device("cpu")):
|
|
| 372 |
pt = pt.unsqueeze(0) if pt.ndim==1 else pt
|
| 373 |
return pt.T
|
| 374 |
|
| 375 |
-
def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
|
| 376 |
-
""" Run dynamics to recover masks in 2D or 3D.
|
| 377 |
-
|
| 378 |
-
Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability
|
| 379 |
-
are used (as defined by inds).
|
| 380 |
-
|
| 381 |
-
Args:
|
| 382 |
-
dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
|
| 383 |
-
mask (np.ndarray, optional): Pixel mask to seed masks. Useful when flows have low magnitudes.
|
| 384 |
-
niter (int, optional): Number of iterations of dynamics to run. Default is 200.
|
| 385 |
-
interp (bool, optional): Interpolate during 2D dynamics (not available in 3D). Default is True.
|
| 386 |
-
device (torch.device, optional): Device to use for computation. Default is None.
|
| 387 |
-
|
| 388 |
-
Returns:
|
| 389 |
-
A tuple containing (p, inds): p (np.ndarray): Final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx];
|
| 390 |
-
inds (np.ndarray): Indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
|
| 391 |
-
"""
|
| 392 |
-
shape = np.array(dP.shape[1:]).astype(np.int32)
|
| 393 |
-
ndim = len(inds)
|
| 394 |
-
|
| 395 |
-
p = steps_interp(dP, inds, niter, device=device)
|
| 396 |
-
|
| 397 |
-
return p
|
| 398 |
|
| 399 |
|
| 400 |
def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
|
|
@@ -551,7 +408,6 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
|
|
| 551 |
seed_masks[:,5,5,5] = 1
|
| 552 |
|
| 553 |
for iter in range(5):
|
| 554 |
-
# extend
|
| 555 |
seed_masks = max_pool_nd(seed_masks, kernel_size=3)
|
| 556 |
seed_masks *= h_slc > 2
|
| 557 |
del h_slc
|
|
@@ -580,7 +436,6 @@ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
|
|
| 580 |
fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
|
| 581 |
M0 = M0.reshape(tuple(shape0))
|
| 582 |
|
| 583 |
-
#print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
|
| 584 |
return M0
|
| 585 |
|
| 586 |
|
|
@@ -652,7 +507,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
|
|
| 652 |
mask = np.zeros(shape, "uint16")
|
| 653 |
return mask
|
| 654 |
|
| 655 |
-
p_final =
|
| 656 |
inds=inds, niter=niter,
|
| 657 |
device=device)
|
| 658 |
if not torch.is_tensor(p_final):
|
|
|
|
| 151 |
mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
|
| 152 |
return mu0
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
def flow_error(maski, dP_net, device=None):
|
|
|
|
| 252 |
pt = pt.unsqueeze(0) if pt.ndim==1 else pt
|
| 253 |
return pt.T
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
|
| 257 |
def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
|
|
|
|
| 408 |
seed_masks[:,5,5,5] = 1
|
| 409 |
|
| 410 |
for iter in range(5):
|
|
|
|
| 411 |
seed_masks = max_pool_nd(seed_masks, kernel_size=3)
|
| 412 |
seed_masks *= h_slc > 2
|
| 413 |
del h_slc
|
|
|
|
| 436 |
fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
|
| 437 |
M0 = M0.reshape(tuple(shape0))
|
| 438 |
|
|
|
|
| 439 |
return M0
|
| 440 |
|
| 441 |
|
|
|
|
| 507 |
mask = np.zeros(shape, "uint16")
|
| 508 |
return mask
|
| 509 |
|
| 510 |
+
p_final = steps_interp(dP * (cellprob > cellprob_threshold) / 5.,
|
| 511 |
inds=inds, niter=niter,
|
| 512 |
device=device)
|
| 513 |
if not torch.is_tensor(p_final):
|
models/seg_post_model/io.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
import tifffile
|
| 8 |
+
import logging
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import nd2
|
| 14 |
+
ND2 = True
|
| 15 |
+
except:
|
| 16 |
+
ND2 = False
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import nrrd
|
| 20 |
+
NRRD = True
|
| 21 |
+
except:
|
| 22 |
+
NRRD = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
io_logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def load_dax(filename):
|
| 28 |
+
### modified from ZhuangLab github:
|
| 29 |
+
### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
|
| 30 |
+
|
| 31 |
+
inf_filename = os.path.splitext(filename)[0] + ".inf"
|
| 32 |
+
if not os.path.exists(inf_filename):
|
| 33 |
+
io_logger.critical(
|
| 34 |
+
f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
|
| 35 |
+
)
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
### get metadata
|
| 39 |
+
image_height, image_width = None, None
|
| 40 |
+
# extract the movie information from the associated inf file
|
| 41 |
+
size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
|
| 42 |
+
length_re = re.compile(r"number of frames = ([\d]+)")
|
| 43 |
+
endian_re = re.compile(r" (big|little) endian")
|
| 44 |
+
|
| 45 |
+
with open(inf_filename, "r") as inf_file:
|
| 46 |
+
lines = inf_file.read().split("\n")
|
| 47 |
+
for line in lines:
|
| 48 |
+
m = size_re.match(line)
|
| 49 |
+
if m:
|
| 50 |
+
image_height = int(m.group(2))
|
| 51 |
+
image_width = int(m.group(1))
|
| 52 |
+
m = length_re.match(line)
|
| 53 |
+
if m:
|
| 54 |
+
number_frames = int(m.group(1))
|
| 55 |
+
m = endian_re.search(line)
|
| 56 |
+
if m:
|
| 57 |
+
if m.group(1) == "big":
|
| 58 |
+
bigendian = 1
|
| 59 |
+
else:
|
| 60 |
+
bigendian = 0
|
| 61 |
+
# set defaults, warn the user that they couldn"t be determined from the inf file.
|
| 62 |
+
if not image_height:
|
| 63 |
+
io_logger.warning("could not determine dax image size, assuming 256x256")
|
| 64 |
+
image_height = 256
|
| 65 |
+
image_width = 256
|
| 66 |
+
|
| 67 |
+
### load image
|
| 68 |
+
img = np.memmap(filename, dtype="uint16",
|
| 69 |
+
shape=(number_frames, image_height, image_width))
|
| 70 |
+
if bigendian:
|
| 71 |
+
img = img.byteswap()
|
| 72 |
+
img = np.array(img)
|
| 73 |
+
|
| 74 |
+
return img
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def imread(filename):
|
| 78 |
+
"""
|
| 79 |
+
Read in an image file with tif or image file type supported by cv2.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
filename (str): The path to the image file.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
numpy.ndarray: The image data as a NumPy array.
|
| 86 |
+
|
| 87 |
+
Raises:
|
| 88 |
+
None
|
| 89 |
+
|
| 90 |
+
Raises an error if the image file format is not supported.
|
| 91 |
+
|
| 92 |
+
Examples:
|
| 93 |
+
>>> img = imread("image.tif")
|
| 94 |
+
"""
|
| 95 |
+
# ensure that extension check is not case sensitive
|
| 96 |
+
ext = os.path.splitext(filename)[-1].lower()
|
| 97 |
+
if ext == ".tif" or ext == ".tiff" or ext == ".flex":
|
| 98 |
+
with tifffile.TiffFile(filename) as tif:
|
| 99 |
+
ltif = len(tif.pages)
|
| 100 |
+
try:
|
| 101 |
+
full_shape = tif.shaped_metadata[0]["shape"]
|
| 102 |
+
except:
|
| 103 |
+
try:
|
| 104 |
+
page = tif.series[0][0]
|
| 105 |
+
full_shape = tif.series[0].shape
|
| 106 |
+
except:
|
| 107 |
+
ltif = 0
|
| 108 |
+
if ltif < 10:
|
| 109 |
+
img = tif.asarray()
|
| 110 |
+
else:
|
| 111 |
+
page = tif.series[0][0]
|
| 112 |
+
shape, dtype = page.shape, page.dtype
|
| 113 |
+
ltif = int(np.prod(full_shape) / np.prod(shape))
|
| 114 |
+
io_logger.info(f"reading tiff with {ltif} planes")
|
| 115 |
+
img = np.zeros((ltif, *shape), dtype=dtype)
|
| 116 |
+
for i, page in enumerate(tqdm(tif.series[0])):
|
| 117 |
+
img[i] = page.asarray()
|
| 118 |
+
img = img.reshape(full_shape)
|
| 119 |
+
return img
|
| 120 |
+
elif ext == ".dax":
|
| 121 |
+
img = load_dax(filename)
|
| 122 |
+
return img
|
| 123 |
+
elif ext == ".nd2":
|
| 124 |
+
if not ND2:
|
| 125 |
+
io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
|
| 126 |
+
return None
|
| 127 |
+
elif ext == ".nrrd":
|
| 128 |
+
if not NRRD:
|
| 129 |
+
io_logger.critical(
|
| 130 |
+
"ERROR: need to 'pip install pynrrd' to load in .nrrd file")
|
| 131 |
+
return None
|
| 132 |
+
else:
|
| 133 |
+
img, metadata = nrrd.read(filename)
|
| 134 |
+
if img.ndim == 3:
|
| 135 |
+
img = img.transpose(2, 0, 1)
|
| 136 |
+
return img
|
| 137 |
+
elif ext != ".npy":
|
| 138 |
+
try:
|
| 139 |
+
img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
|
| 140 |
+
if img.ndim > 2:
|
| 141 |
+
img = img[..., [2, 1, 0]]
|
| 142 |
+
return img
|
| 143 |
+
except Exception as e:
|
| 144 |
+
io_logger.critical("ERROR: could not read file, %s" % e)
|
| 145 |
+
return None
|
| 146 |
+
else:
|
| 147 |
+
try:
|
| 148 |
+
dat = np.load(filename, allow_pickle=True).item()
|
| 149 |
+
masks = dat["masks"]
|
| 150 |
+
return masks
|
| 151 |
+
except Exception as e:
|
| 152 |
+
io_logger.critical("ERROR: could not read masks from file, %s" % e)
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def imsave(filename, arr):
|
| 157 |
+
"""
|
| 158 |
+
Saves an image array to a file.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
filename (str): The name of the file to save the image to.
|
| 162 |
+
arr (numpy.ndarray): The image array to be saved.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
None
|
| 166 |
+
"""
|
| 167 |
+
ext = os.path.splitext(filename)[-1].lower()
|
| 168 |
+
if ext == ".tif" or ext == ".tiff":
|
| 169 |
+
tifffile.imwrite(filename, data=arr, compression="zlib")
|
| 170 |
+
else:
|
| 171 |
+
if len(arr.shape) > 2:
|
| 172 |
+
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
|
| 173 |
+
cv2.imwrite(filename, arr)
|
| 174 |
+
|
models/seg_post_model/{cellpose/metrics.py โ metrics.py}
RENAMED
|
@@ -2,83 +2,10 @@
|
|
| 2 |
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
"""
|
| 4 |
import numpy as np
|
| 5 |
-
from . import utils
|
| 6 |
from scipy.optimize import linear_sum_assignment
|
| 7 |
-
from scipy.ndimage import convolve
|
| 8 |
from scipy.sparse import csr_matrix
|
| 9 |
|
| 10 |
|
| 11 |
-
def mask_ious(masks_true, masks_pred):
|
| 12 |
-
"""Return best-matched masks."""
|
| 13 |
-
iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
|
| 14 |
-
n_min = min(iou.shape[0], iou.shape[1])
|
| 15 |
-
costs = -(iou >= 0.5).astype(float) - iou / (2 * n_min)
|
| 16 |
-
true_ind, pred_ind = linear_sum_assignment(costs)
|
| 17 |
-
iout = np.zeros(masks_true.max())
|
| 18 |
-
iout[true_ind] = iou[true_ind, pred_ind]
|
| 19 |
-
preds = np.zeros(masks_true.max(), "int")
|
| 20 |
-
preds[true_ind] = pred_ind + 1
|
| 21 |
-
return iout, preds
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def boundary_scores(masks_true, masks_pred, scales):
|
| 25 |
-
"""
|
| 26 |
-
Calculate boundary precision, recall, and F-score.
|
| 27 |
-
|
| 28 |
-
Args:
|
| 29 |
-
masks_true (list): List of true masks.
|
| 30 |
-
masks_pred (list): List of predicted masks.
|
| 31 |
-
scales (list): List of scales.
|
| 32 |
-
|
| 33 |
-
Returns:
|
| 34 |
-
tuple: A tuple containing precision, recall, and F-score arrays.
|
| 35 |
-
"""
|
| 36 |
-
diams = [utils.diameters(lbl)[0] for lbl in masks_true]
|
| 37 |
-
precision = np.zeros((len(scales), len(masks_true)))
|
| 38 |
-
recall = np.zeros((len(scales), len(masks_true)))
|
| 39 |
-
fscore = np.zeros((len(scales), len(masks_true)))
|
| 40 |
-
for j, scale in enumerate(scales):
|
| 41 |
-
for n in range(len(masks_true)):
|
| 42 |
-
diam = max(1, scale * diams[n])
|
| 43 |
-
rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
|
| 44 |
-
filt = (rs <= diam).astype(np.float32)
|
| 45 |
-
otrue = utils.masks_to_outlines(masks_true[n])
|
| 46 |
-
otrue = convolve(otrue, filt)
|
| 47 |
-
opred = utils.masks_to_outlines(masks_pred[n])
|
| 48 |
-
opred = convolve(opred, filt)
|
| 49 |
-
tp = np.logical_and(otrue == 1, opred == 1).sum()
|
| 50 |
-
fp = np.logical_and(otrue == 0, opred == 1).sum()
|
| 51 |
-
fn = np.logical_and(otrue == 1, opred == 0).sum()
|
| 52 |
-
precision[j, n] = tp / (tp + fp)
|
| 53 |
-
recall[j, n] = tp / (tp + fn)
|
| 54 |
-
fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
|
| 55 |
-
return precision, recall, fscore
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def aggregated_jaccard_index(masks_true, masks_pred):
|
| 59 |
-
"""
|
| 60 |
-
AJI = intersection of all matched masks / union of all masks
|
| 61 |
-
|
| 62 |
-
Args:
|
| 63 |
-
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
|
| 64 |
-
where 0=NO masks; 1,2... are mask labels
|
| 65 |
-
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
|
| 66 |
-
np.ndarray (int) where 0=NO masks; 1,2... are mask labels
|
| 67 |
-
|
| 68 |
-
Returns:
|
| 69 |
-
aji (float): aggregated jaccard index for each set of masks
|
| 70 |
-
"""
|
| 71 |
-
aji = np.zeros(len(masks_true))
|
| 72 |
-
for n in range(len(masks_true)):
|
| 73 |
-
iout, preds = mask_ious(masks_true[n], masks_pred[n])
|
| 74 |
-
inds = np.arange(0, masks_true[n].max(), 1, int)
|
| 75 |
-
overlap = _label_overlap(masks_true[n], masks_pred[n])
|
| 76 |
-
union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum()
|
| 77 |
-
overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)]
|
| 78 |
-
aji[n] = overlap.sum() / union
|
| 79 |
-
return aji
|
| 80 |
-
|
| 81 |
-
|
| 82 |
def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
|
| 83 |
"""
|
| 84 |
Average precision estimation: AP = TP / (TP + FP + FN)
|
|
|
|
| 2 |
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
"""
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
from scipy.optimize import linear_sum_assignment
|
|
|
|
| 6 |
from scipy.sparse import csr_matrix
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
|
| 10 |
"""
|
| 11 |
Average precision estimation: AP = TP / (TP + FP + FN)
|
models/seg_post_model/{cellpose/models.py โ models.py}
RENAMED
|
@@ -3,31 +3,29 @@ Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer,
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os, time
|
| 6 |
-
from pathlib import Path
|
| 7 |
import numpy as np
|
| 8 |
from tqdm import trange
|
| 9 |
import torch
|
| 10 |
from scipy.ndimage import gaussian_filter
|
| 11 |
-
import gc
|
| 12 |
import cv2
|
| 13 |
|
| 14 |
import logging
|
| 15 |
|
| 16 |
models_logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
-
from . import transforms, dynamics, utils
|
| 19 |
from .vit_sam import Transformer
|
| 20 |
-
from .core import assign_device, run_net
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
#
|
| 25 |
-
_MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
|
| 26 |
-
MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
|
| 27 |
|
| 28 |
-
MODEL_NAMES = ["cpsam"]
|
| 29 |
|
| 30 |
-
MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
|
| 31 |
|
| 32 |
normalize_default = {
|
| 33 |
"lowhigh": None,
|
|
@@ -42,30 +40,17 @@ normalize_default = {
|
|
| 42 |
}
|
| 43 |
|
| 44 |
|
| 45 |
-
# def
|
| 46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
-
|
| 50 |
-
# MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 51 |
-
# cached_file = os.fspath(MODEL_DIR.joinpath('cpsam'))
|
| 52 |
-
# if not os.path.exists(cached_file):
|
| 53 |
-
# models_logger.info('Downloading: "{}" to {}\n'.format(_CPSAM_MODEL_URL, cached_file))
|
| 54 |
-
# utils.download_url_to_file(_CPSAM_MODEL_URL, cached_file, progress=True)
|
| 55 |
-
# return cached_file
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def get_user_models():
|
| 59 |
-
model_strings = []
|
| 60 |
-
if os.path.exists(MODEL_LIST_PATH):
|
| 61 |
-
with open(MODEL_LIST_PATH, "r") as textfile:
|
| 62 |
-
lines = [line.rstrip() for line in textfile]
|
| 63 |
-
if len(lines) > 0:
|
| 64 |
-
model_strings.extend(lines)
|
| 65 |
-
return model_strings
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
class CellposeModel():
|
| 69 |
"""
|
| 70 |
Class representing a Cellpose model.
|
| 71 |
|
|
@@ -102,17 +87,6 @@ class CellposeModel():
|
|
| 102 |
device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
|
| 103 |
use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
|
| 104 |
"""
|
| 105 |
-
# if diam_mean is not None:
|
| 106 |
-
# models_logger.warning(
|
| 107 |
-
# "diam_mean argument are not used in v4.0.1+. Ignoring this argument..."
|
| 108 |
-
# )
|
| 109 |
-
# if model_type is not None:
|
| 110 |
-
# models_logger.warning(
|
| 111 |
-
# "model_type argument is not used in v4.0.1+. Ignoring this argument..."
|
| 112 |
-
# )
|
| 113 |
-
# if nchan is not None:
|
| 114 |
-
# models_logger.warning("nchan argument is deprecated in v4.0.1+. Ignoring this argument")
|
| 115 |
-
|
| 116 |
### assign model device
|
| 117 |
self.device = assign_device(gpu=gpu)[0] if device is None else device
|
| 118 |
if torch.cuda.is_available():
|
|
@@ -127,36 +101,10 @@ class CellposeModel():
|
|
| 127 |
# raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
|
| 128 |
pretrained_model = ""
|
| 129 |
|
| 130 |
-
### create neural network
|
| 131 |
-
if pretrained_model and not os.path.exists(pretrained_model):
|
| 132 |
-
# check if pretrained model is in the models directory
|
| 133 |
-
model_strings = get_user_models()
|
| 134 |
-
all_models = MODEL_NAMES.copy()
|
| 135 |
-
all_models.extend(model_strings)
|
| 136 |
-
if pretrained_model in all_models:
|
| 137 |
-
pretrained_model = os.path.join(MODEL_DIR, pretrained_model)
|
| 138 |
-
else:
|
| 139 |
-
pretrained_model = os.path.join(MODEL_DIR, "cpsam")
|
| 140 |
-
models_logger.warning(
|
| 141 |
-
f"pretrained model {pretrained_model} not found, using default model"
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
self.pretrained_model = pretrained_model
|
| 145 |
dtype = torch.bfloat16 if use_bfloat16 else torch.float32
|
| 146 |
self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)
|
| 147 |
|
| 148 |
-
if os.path.exists(self.pretrained_model):
|
| 149 |
-
models_logger.info(f">>>> loading model {self.pretrained_model}")
|
| 150 |
-
self.net.load_model(self.pretrained_model, device=self.device)
|
| 151 |
-
# else:
|
| 152 |
-
# try:
|
| 153 |
-
# if os.path.split(self.pretrained_model)[-1] != 'cpsam':
|
| 154 |
-
# raise FileNotFoundError('model file not recognized')
|
| 155 |
-
# cache_CPSAM_model_path()
|
| 156 |
-
# self.net.load_model(self.pretrained_model, device=self.device)
|
| 157 |
-
# except:
|
| 158 |
-
# print("ViT not initialized")
|
| 159 |
-
|
| 160 |
|
| 161 |
def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
|
| 162 |
z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
|
|
@@ -166,12 +114,6 @@ class CellposeModel():
|
|
| 166 |
augment=False, tile_overlap=0.1, bsize=256,
|
| 167 |
compute_masks=True, progress=None):
|
| 168 |
|
| 169 |
-
|
| 170 |
-
# if rescale is not None:
|
| 171 |
-
# models_logger.warning("rescaling deprecated in v4.0.1+")
|
| 172 |
-
# if channels is not None:
|
| 173 |
-
# models_logger.warning("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used")
|
| 174 |
-
|
| 175 |
if isinstance(x, list) or x.squeeze().ndim == 5:
|
| 176 |
self.timing = []
|
| 177 |
masks, styles, flows = [], [], []
|
|
@@ -230,7 +172,7 @@ class CellposeModel():
|
|
| 230 |
Ly_0 = x.shape[1]
|
| 231 |
Lx_0 = x.shape[2]
|
| 232 |
Lz_0 = None
|
| 233 |
-
if
|
| 234 |
Lz_0 = x.shape[0]
|
| 235 |
if diameter is not None:
|
| 236 |
image_scaling = 30. / diameter
|
|
@@ -273,7 +215,7 @@ class CellposeModel():
|
|
| 273 |
# transpose feat to have channels last
|
| 274 |
feat = np.moveaxis(feat, 1, -1)
|
| 275 |
|
| 276 |
-
#
|
| 277 |
if isinstance(anisotropy, (float, int)) and image_scaling:
|
| 278 |
anisotropy = image_scaling * anisotropy
|
| 279 |
|
|
@@ -287,12 +229,6 @@ class CellposeModel():
|
|
| 287 |
do_3D=do_3D,
|
| 288 |
anisotropy=anisotropy)
|
| 289 |
|
| 290 |
-
if do_3D:
|
| 291 |
-
if flow3D_smooth > 0:
|
| 292 |
-
models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
|
| 293 |
-
dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
|
| 294 |
-
torch.cuda.empty_cache()
|
| 295 |
-
gc.collect()
|
| 296 |
|
| 297 |
if resample:
|
| 298 |
# upsample flows before computing them:
|
|
@@ -310,28 +246,16 @@ class CellposeModel():
|
|
| 310 |
else:
|
| 311 |
masks = np.zeros(0) #pass back zeros if not compute_masks
|
| 312 |
|
| 313 |
-
masks
|
| 314 |
|
| 315 |
# undo resizing:
|
| 316 |
if image_scaling is not None or anisotropy is not None:
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
if do_3D:
|
| 322 |
-
if compute_masks:
|
| 323 |
-
# Rescale xy then xz:
|
| 324 |
-
masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
|
| 325 |
-
masks = masks.transpose(1, 0, 2)
|
| 326 |
-
masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
|
| 327 |
-
masks = masks.transpose(1, 0, 2)
|
| 328 |
|
| 329 |
-
|
| 330 |
-
# 2D or 3D stitching case:
|
| 331 |
-
if compute_masks:
|
| 332 |
-
masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
|
| 333 |
|
| 334 |
-
return masks, [plot.dx_to_circ(dP), dP, cellprob], styles
|
| 335 |
|
| 336 |
|
| 337 |
def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
|
|
@@ -428,29 +352,15 @@ class CellposeModel():
|
|
| 428 |
nimg = shape[0]
|
| 429 |
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
tile_overlap=tile_overlap,
|
| 441 |
-
bsize=bsize
|
| 442 |
-
)
|
| 443 |
-
cellprob = yf[..., -1]
|
| 444 |
-
dP = yf[..., :-1].transpose((3, 0, 1, 2))
|
| 445 |
-
else:
|
| 446 |
-
yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
|
| 447 |
-
batch_size=batch_size,
|
| 448 |
-
tile_overlap=tile_overlap,
|
| 449 |
-
)
|
| 450 |
-
cellprob = yf[..., -1]
|
| 451 |
-
dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
|
| 452 |
-
if yf.shape[-1] > 3:
|
| 453 |
-
styles = yf[..., :-3]
|
| 454 |
|
| 455 |
styles = styles.squeeze()
|
| 456 |
|
|
@@ -471,48 +381,48 @@ class CellposeModel():
|
|
| 471 |
changed_device_from = "mps"
|
| 472 |
Lz, Ly, Lx = shape[:3]
|
| 473 |
tic = time.time()
|
| 474 |
-
if do_3D:
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
device=self.device)
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
if nimg > 1:
|
| 501 |
-
masks[i] = outputs
|
| 502 |
-
else:
|
| 503 |
-
masks = outputs
|
| 504 |
-
|
| 505 |
-
if stitch_threshold > 0 and nimg > 1:
|
| 506 |
-
models_logger.info(
|
| 507 |
-
f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
|
| 508 |
-
)
|
| 509 |
-
masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
|
| 510 |
-
masks = utils.fill_holes_and_remove_small_masks(
|
| 511 |
-
masks, min_size=min_size)
|
| 512 |
-
elif nimg > 1:
|
| 513 |
-
models_logger.warning(
|
| 514 |
-
"3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
|
| 515 |
-
)
|
| 516 |
|
| 517 |
flow_time = time.time() - tic
|
| 518 |
if shape[0] > 1:
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os, time
|
| 6 |
+
# from pathlib import Path
|
| 7 |
import numpy as np
|
| 8 |
from tqdm import trange
|
| 9 |
import torch
|
| 10 |
from scipy.ndimage import gaussian_filter
|
| 11 |
+
# import gc
|
| 12 |
import cv2
|
| 13 |
|
| 14 |
import logging
|
| 15 |
|
| 16 |
models_logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
+
from . import transforms, dynamics, utils
|
| 19 |
from .vit_sam import Transformer
|
| 20 |
+
from .core import assign_device, run_net
|
| 21 |
|
| 22 |
+
# _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
|
| 23 |
+
# _MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
|
| 24 |
+
# MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
# MODEL_NAMES = ["cpsam"]
|
| 27 |
|
| 28 |
+
# MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
|
| 29 |
|
| 30 |
normalize_default = {
|
| 31 |
"lowhigh": None,
|
|
|
|
| 40 |
}
|
| 41 |
|
| 42 |
|
| 43 |
+
# def get_user_models():
|
| 44 |
+
# model_strings = []
|
| 45 |
+
# if os.path.exists(MODEL_LIST_PATH):
|
| 46 |
+
# with open(MODEL_LIST_PATH, "r") as textfile:
|
| 47 |
+
# lines = [line.rstrip() for line in textfile]
|
| 48 |
+
# if len(lines) > 0:
|
| 49 |
+
# model_strings.extend(lines)
|
| 50 |
+
# return model_strings
|
| 51 |
|
| 52 |
|
| 53 |
+
class SegModel():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
"""
|
| 55 |
Class representing a Cellpose model.
|
| 56 |
|
|
|
|
| 87 |
device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
|
| 88 |
use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
|
| 89 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
### assign model device
|
| 91 |
self.device = assign_device(gpu=gpu)[0] if device is None else device
|
| 92 |
if torch.cuda.is_available():
|
|
|
|
| 101 |
# raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
|
| 102 |
pretrained_model = ""
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
self.pretrained_model = pretrained_model
|
| 105 |
dtype = torch.bfloat16 if use_bfloat16 else torch.float32
|
| 106 |
self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
|
| 110 |
z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
|
|
|
|
| 114 |
augment=False, tile_overlap=0.1, bsize=256,
|
| 115 |
compute_masks=True, progress=None):
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
if isinstance(x, list) or x.squeeze().ndim == 5:
|
| 118 |
self.timing = []
|
| 119 |
masks, styles, flows = [], [], []
|
|
|
|
| 172 |
Ly_0 = x.shape[1]
|
| 173 |
Lx_0 = x.shape[2]
|
| 174 |
Lz_0 = None
|
| 175 |
+
if stitch_threshold > 0:
|
| 176 |
Lz_0 = x.shape[0]
|
| 177 |
if diameter is not None:
|
| 178 |
image_scaling = 30. / diameter
|
|
|
|
| 215 |
# transpose feat to have channels last
|
| 216 |
feat = np.moveaxis(feat, 1, -1)
|
| 217 |
|
| 218 |
+
# adjust the anisotropy when diameter is specified and images are resized:
|
| 219 |
if isinstance(anisotropy, (float, int)) and image_scaling:
|
| 220 |
anisotropy = image_scaling * anisotropy
|
| 221 |
|
|
|
|
| 229 |
do_3D=do_3D,
|
| 230 |
anisotropy=anisotropy)
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
if resample:
|
| 234 |
# upsample flows before computing them:
|
|
|
|
| 246 |
else:
|
| 247 |
masks = np.zeros(0) #pass back zeros if not compute_masks
|
| 248 |
|
| 249 |
+
masks = masks.squeeze()
|
| 250 |
|
| 251 |
# undo resizing:
|
| 252 |
if image_scaling is not None or anisotropy is not None:
|
| 253 |
|
| 254 |
+
if compute_masks:
|
| 255 |
+
masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
return masks
|
|
|
|
|
|
|
|
|
|
| 258 |
|
|
|
|
| 259 |
|
| 260 |
|
| 261 |
def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
|
|
|
|
| 352 |
nimg = shape[0]
|
| 353 |
|
| 354 |
|
| 355 |
+
|
| 356 |
+
yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
|
| 357 |
+
batch_size=batch_size,
|
| 358 |
+
tile_overlap=tile_overlap,
|
| 359 |
+
)
|
| 360 |
+
cellprob = yf[..., -1]
|
| 361 |
+
dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
|
| 362 |
+
if yf.shape[-1] > 3:
|
| 363 |
+
styles = yf[..., :-3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
styles = styles.squeeze()
|
| 366 |
|
|
|
|
| 381 |
changed_device_from = "mps"
|
| 382 |
Lz, Ly, Lx = shape[:3]
|
| 383 |
tic = time.time()
|
| 384 |
+
# if do_3D:
|
| 385 |
+
# masks = dynamics.resize_and_compute_masks(
|
| 386 |
+
# dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
|
| 387 |
+
# flow_threshold=flow_threshold, do_3D=do_3D,
|
| 388 |
+
# min_size=min_size, max_size_fraction=max_size_fraction,
|
| 389 |
+
# resize=shape[:3] if (np.array(dP.shape[-3:])!=np.array(shape[:3])).sum()
|
| 390 |
+
# else None,
|
| 391 |
+
# device=self.device)
|
| 392 |
+
# else:
|
| 393 |
+
nimg = shape[0]
|
| 394 |
+
Ly0, Lx0 = cellprob[0].shape
|
| 395 |
+
resize = None if Ly0==Ly and Lx0==Lx else [Ly, Lx]
|
| 396 |
+
tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
|
| 397 |
+
iterator = trange(nimg, file=tqdm_out,
|
| 398 |
+
mininterval=30) if nimg > 1 else range(nimg)
|
| 399 |
+
for i in iterator:
|
| 400 |
+
# turn off min_size for 3D stitching
|
| 401 |
+
min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
|
| 402 |
+
outputs = dynamics.resize_and_compute_masks(
|
| 403 |
+
dP[:, i], cellprob[i],
|
| 404 |
+
niter=niter, cellprob_threshold=cellprob_threshold,
|
| 405 |
+
flow_threshold=flow_threshold, resize=resize,
|
| 406 |
+
min_size=min_size0, max_size_fraction=max_size_fraction,
|
| 407 |
device=self.device)
|
| 408 |
+
if i==0 and nimg > 1:
|
| 409 |
+
masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
|
| 410 |
+
if nimg > 1:
|
| 411 |
+
masks[i] = outputs
|
| 412 |
+
else:
|
| 413 |
+
masks = outputs
|
| 414 |
+
|
| 415 |
+
if stitch_threshold > 0 and nimg > 1:
|
| 416 |
+
models_logger.info(
|
| 417 |
+
f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
|
| 418 |
+
)
|
| 419 |
+
masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
|
| 420 |
+
masks = utils.fill_holes_and_remove_small_masks(
|
| 421 |
+
masks, min_size=min_size)
|
| 422 |
+
elif nimg > 1:
|
| 423 |
+
models_logger.warning(
|
| 424 |
+
"3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
|
| 425 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
flow_time = time.time() - tic
|
| 428 |
if shape[0] > 1:
|
models/seg_post_model/{cellpose/transforms.py โ transforms.py}
RENAMED
|
@@ -398,145 +398,6 @@ def move_axis(img, m_axis=-1, first=True):
|
|
| 398 |
return img
|
| 399 |
|
| 400 |
|
| 401 |
-
def move_min_dim(img, force=False):
|
| 402 |
-
"""Move the minimum dimension last as channels if it is less than 10 or force is True.
|
| 403 |
-
|
| 404 |
-
Args:
|
| 405 |
-
img (ndarray): The input image.
|
| 406 |
-
force (bool, optional): If True, the minimum dimension will always be moved.
|
| 407 |
-
Defaults to False.
|
| 408 |
-
|
| 409 |
-
Returns:
|
| 410 |
-
ndarray: The image with the minimum dimension moved to the last axis as channels.
|
| 411 |
-
"""
|
| 412 |
-
if len(img.shape) > 2:
|
| 413 |
-
min_dim = min(img.shape)
|
| 414 |
-
if min_dim < 10 or force:
|
| 415 |
-
if img.shape[-1] == min_dim:
|
| 416 |
-
channel_axis = -1
|
| 417 |
-
else:
|
| 418 |
-
channel_axis = (img.shape).index(min_dim)
|
| 419 |
-
img = move_axis(img, m_axis=channel_axis, first=False)
|
| 420 |
-
return img
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
def update_axis(m_axis, to_squeeze, ndim):
|
| 424 |
-
"""
|
| 425 |
-
Squeeze the axis value based on the given parameters.
|
| 426 |
-
|
| 427 |
-
Args:
|
| 428 |
-
m_axis (int): The current axis value.
|
| 429 |
-
to_squeeze (numpy.ndarray): An array of indices to squeeze.
|
| 430 |
-
ndim (int): The number of dimensions.
|
| 431 |
-
|
| 432 |
-
Returns:
|
| 433 |
-
m_axis (int or None): The updated axis value.
|
| 434 |
-
"""
|
| 435 |
-
if m_axis == -1:
|
| 436 |
-
m_axis = ndim - 1
|
| 437 |
-
if (to_squeeze == m_axis).sum() == 1:
|
| 438 |
-
m_axis = None
|
| 439 |
-
else:
|
| 440 |
-
inds = np.ones(ndim, bool)
|
| 441 |
-
inds[to_squeeze] = False
|
| 442 |
-
m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0]
|
| 443 |
-
if len(m_axis) > 0:
|
| 444 |
-
m_axis = m_axis[0]
|
| 445 |
-
else:
|
| 446 |
-
m_axis = None
|
| 447 |
-
return m_axis
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
def _convert_image_3d(x, channel_axis=None, z_axis=None):
|
| 451 |
-
"""
|
| 452 |
-
Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C).
|
| 453 |
-
|
| 454 |
-
Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis.
|
| 455 |
-
Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified.
|
| 456 |
-
|
| 457 |
-
Args:
|
| 458 |
-
x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D.
|
| 459 |
-
channel_axis (int): The axis index corresponding to the channel dimension in the input array. \
|
| 460 |
-
Must be specified for 4D images.
|
| 461 |
-
z_axis (int): The axis index corresponding to the depth (Z) dimension in the input array. \
|
| 462 |
-
Must be specified for both 3D and 4D images.
|
| 463 |
-
|
| 464 |
-
Returns:
|
| 465 |
-
numpy.ndarray: A 4D image array with dimensions ordered as (Z, X, Y, C), where C is the channel
|
| 466 |
-
dimension. If the input has fewer than 3 channels, the output will be padded with zeros to \
|
| 467 |
-
have 3 channels. If the input has more than 3 channels, only the first 3 channels will be retained.
|
| 468 |
-
|
| 469 |
-
Raises:
|
| 470 |
-
ValueError: If `z_axis` is not specified for 3D images. If either `channel_axis` or `z_axis` \
|
| 471 |
-
is not specified for 4D images. If the input image does not have 3 or 4 dimensions.
|
| 472 |
-
|
| 473 |
-
Notes:
|
| 474 |
-
- For 3D images (ndim=3), the function assumes the input is grayscale and adds a singleton channel dimension.
|
| 475 |
-
- The function reorders the dimensions of the input array to ensure the output has the desired (Z, X, Y, C) order.
|
| 476 |
-
- If the number of channels is not equal to 3, the function either truncates or pads the \
|
| 477 |
-
channels to ensure the output has exactly 3 channels.
|
| 478 |
-
"""
|
| 479 |
-
|
| 480 |
-
if x.ndim < 3:
|
| 481 |
-
raise ValueError(f"Input image must have at least 3 dimensions, input shape: {x.shape}, ndim={x.ndim}")
|
| 482 |
-
|
| 483 |
-
if z_axis is not None and z_axis < 0:
|
| 484 |
-
z_axis += x.ndim
|
| 485 |
-
|
| 486 |
-
# if image is ndim==3, assume it is greyscale 3D and use provided z_axis
|
| 487 |
-
if x.ndim == 3 and z_axis is not None:
|
| 488 |
-
# add in channel axis
|
| 489 |
-
x = x[..., np.newaxis]
|
| 490 |
-
channel_axis = 3
|
| 491 |
-
elif x.ndim == 3 and z_axis is None:
|
| 492 |
-
raise ValueError("z_axis must be specified when segmenting 3D images of ndim=3")
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
if channel_axis is None or z_axis is None:
|
| 496 |
-
raise ValueError("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters.")
|
| 497 |
-
if channel_axis is not None and channel_axis < 0:
|
| 498 |
-
channel_axis += x.ndim
|
| 499 |
-
if channel_axis is None or channel_axis >= x.ndim:
|
| 500 |
-
raise IndexError(f"channel_axis {channel_axis} is out of bounds for input array with {x.ndim} dimensions")
|
| 501 |
-
assert x.ndim == 4, f"input image must have ndim == 4, ndim={x.ndim}"
|
| 502 |
-
|
| 503 |
-
x_dim_shapes = list(x.shape)
|
| 504 |
-
num_z_layers = x_dim_shapes[z_axis]
|
| 505 |
-
num_channels = x_dim_shapes[channel_axis]
|
| 506 |
-
x_xy_axes = [i for i in range(x.ndim)]
|
| 507 |
-
|
| 508 |
-
# need to remove the z and channels from the shapes:
|
| 509 |
-
# delete the one with the bigger index first
|
| 510 |
-
if z_axis > channel_axis:
|
| 511 |
-
del x_dim_shapes[z_axis]
|
| 512 |
-
del x_dim_shapes[channel_axis]
|
| 513 |
-
|
| 514 |
-
del x_xy_axes[z_axis]
|
| 515 |
-
del x_xy_axes[channel_axis]
|
| 516 |
-
|
| 517 |
-
else:
|
| 518 |
-
del x_dim_shapes[channel_axis]
|
| 519 |
-
del x_dim_shapes[z_axis]
|
| 520 |
-
|
| 521 |
-
del x_xy_axes[channel_axis]
|
| 522 |
-
del x_xy_axes[z_axis]
|
| 523 |
-
|
| 524 |
-
x = x.transpose((z_axis, x_xy_axes[0], x_xy_axes[1], channel_axis))
|
| 525 |
-
|
| 526 |
-
# Handle cases with not 3 channels:
|
| 527 |
-
if num_channels != 3:
|
| 528 |
-
x_chans_to_copy = min(3, num_channels)
|
| 529 |
-
|
| 530 |
-
if num_channels > 3:
|
| 531 |
-
transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
|
| 532 |
-
x = x[..., :x_chans_to_copy]
|
| 533 |
-
else:
|
| 534 |
-
# less than 3 channels: pad up to
|
| 535 |
-
pad_width = [(0, 0), (0, 0), (0, 0), (0, 3 - x_chans_to_copy)]
|
| 536 |
-
x = np.pad(x, pad_width, mode='constant', constant_values=0)
|
| 537 |
-
|
| 538 |
-
return x
|
| 539 |
-
|
| 540 |
|
| 541 |
def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
|
| 542 |
"""Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
|
|
@@ -571,13 +432,9 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
|
|
| 571 |
if z_axis is not None and not do_3D:
|
| 572 |
raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")
|
| 573 |
|
| 574 |
-
# make sure that channel_axis and z_axis are specified if 3D
|
| 575 |
if ndim == 4 and not do_3D:
|
| 576 |
raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
|
| 577 |
|
| 578 |
-
# make sure that channel_axis and z_axis are specified if 3D
|
| 579 |
-
if do_3D:
|
| 580 |
-
return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)
|
| 581 |
|
| 582 |
######################## 2D reshaping ########################
|
| 583 |
# if user specifies channel axis, return early
|
|
@@ -956,10 +813,10 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal
|
|
| 956 |
nchan = X[0].shape[0]
|
| 957 |
else:
|
| 958 |
nchan = 1
|
| 959 |
-
if do_3D and X[0].ndim > 3:
|
| 960 |
-
|
| 961 |
-
else:
|
| 962 |
-
|
| 963 |
imgi = np.zeros((nimg, nchan, *shape), "float32")
|
| 964 |
|
| 965 |
lbl = []
|
|
@@ -1008,15 +865,6 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal
|
|
| 1008 |
if labels.ndim < 3:
|
| 1009 |
labels = labels[np.newaxis, :, :]
|
| 1010 |
|
| 1011 |
-
if do_3D:
|
| 1012 |
-
Lz = X[n].shape[-3]
|
| 1013 |
-
flip_z = np.random.rand() > .5
|
| 1014 |
-
lz = int(np.round(zcrop / scale[n]))
|
| 1015 |
-
iz = np.random.randint(0, Lz - lz)
|
| 1016 |
-
img = img[:,iz:iz + lz,:,:]
|
| 1017 |
-
if Y is not None:
|
| 1018 |
-
labels = labels[:,iz:iz + lz,:,:]
|
| 1019 |
-
|
| 1020 |
if do_flip:
|
| 1021 |
if flip:
|
| 1022 |
img = img[..., ::-1]
|
|
@@ -1024,47 +872,15 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal
|
|
| 1024 |
labels = labels[..., ::-1]
|
| 1025 |
if nt > 1 and not unet:
|
| 1026 |
labels[-1] = -labels[-1]
|
| 1027 |
-
if do_3D and flip_z:
|
| 1028 |
-
img = img[:, ::-1]
|
| 1029 |
-
if Y is not None:
|
| 1030 |
-
labels = labels[:,::-1]
|
| 1031 |
-
if nt > 1 and not unet:
|
| 1032 |
-
labels[-3] = -labels[-3]
|
| 1033 |
|
| 1034 |
for k in range(nchan):
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
for z in range(lz):
|
| 1038 |
-
I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
|
| 1039 |
-
flags=cv2.INTER_LINEAR)
|
| 1040 |
-
img0[z] = I
|
| 1041 |
-
if scale[n] != 1.0:
|
| 1042 |
-
for y in range(imgi.shape[-2]):
|
| 1043 |
-
imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
|
| 1044 |
-
interpolation=cv2.INTER_LINEAR)
|
| 1045 |
-
else:
|
| 1046 |
-
imgi[n, k] = img0
|
| 1047 |
-
else:
|
| 1048 |
-
I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
|
| 1049 |
-
imgi[n, k] = I
|
| 1050 |
|
| 1051 |
if Y is not None:
|
| 1052 |
for k in range(nt):
|
| 1053 |
flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
|
| 1054 |
-
|
| 1055 |
-
lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
|
| 1056 |
-
for z in range(lz):
|
| 1057 |
-
I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
|
| 1058 |
-
flags=flag)
|
| 1059 |
-
lbl0[z] = I
|
| 1060 |
-
if scale[n] != 1.0:
|
| 1061 |
-
for y in range(lbl.shape[-2]):
|
| 1062 |
-
lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
|
| 1063 |
-
interpolation=flag)
|
| 1064 |
-
else:
|
| 1065 |
-
lbl[n, k] = lbl0
|
| 1066 |
-
else:
|
| 1067 |
-
lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
|
| 1068 |
|
| 1069 |
if nt > 1 and not unet:
|
| 1070 |
v1 = lbl[n, -1].copy()
|
|
@@ -1106,10 +922,7 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
|
|
| 1106 |
nchan = X[0].shape[0]
|
| 1107 |
else:
|
| 1108 |
nchan = 1
|
| 1109 |
-
|
| 1110 |
-
shape = (zcrop, xy[0], xy[1])
|
| 1111 |
-
else:
|
| 1112 |
-
shape = (xy[0], xy[1])
|
| 1113 |
imgi = np.zeros((nimg, nchan, *shape), "float32")
|
| 1114 |
|
| 1115 |
lbl = []
|
|
@@ -1169,17 +982,6 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
|
|
| 1169 |
if feats.ndim < 3:
|
| 1170 |
feats = feats[np.newaxis, :, :]
|
| 1171 |
|
| 1172 |
-
if do_3D:
|
| 1173 |
-
Lz = X[n].shape[-3]
|
| 1174 |
-
flip_z = np.random.rand() > .5
|
| 1175 |
-
lz = int(np.round(zcrop / scale[n]))
|
| 1176 |
-
iz = np.random.randint(0, Lz - lz)
|
| 1177 |
-
img = img[:,iz:iz + lz,:,:]
|
| 1178 |
-
if Y is not None:
|
| 1179 |
-
labels = labels[:,iz:iz + lz,:,:]
|
| 1180 |
-
if feat is not None:
|
| 1181 |
-
feats = feats[:,iz:iz + lz,:,:]
|
| 1182 |
-
|
| 1183 |
if do_flip:
|
| 1184 |
if flip:
|
| 1185 |
img = img[..., ::-1]
|
|
@@ -1189,49 +991,16 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
|
|
| 1189 |
labels[-1] = -labels[-1]
|
| 1190 |
if feat is not None:
|
| 1191 |
feats = feats[..., ::-1]
|
| 1192 |
-
|
| 1193 |
-
img = img[:, ::-1]
|
| 1194 |
-
if Y is not None:
|
| 1195 |
-
labels = labels[:,::-1]
|
| 1196 |
-
if nt > 1 and not unet:
|
| 1197 |
-
labels[-3] = -labels[-3]
|
| 1198 |
-
if feat is not None:
|
| 1199 |
-
feats = feats[:, ::-1]
|
| 1200 |
|
| 1201 |
for k in range(nchan):
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
-
for z in range(lz):
|
| 1205 |
-
I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
|
| 1206 |
-
flags=cv2.INTER_LINEAR)
|
| 1207 |
-
img0[z] = I
|
| 1208 |
-
if scale[n] != 1.0:
|
| 1209 |
-
for y in range(imgi.shape[-2]):
|
| 1210 |
-
imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
|
| 1211 |
-
interpolation=cv2.INTER_LINEAR)
|
| 1212 |
-
else:
|
| 1213 |
-
imgi[n, k] = img0
|
| 1214 |
-
else:
|
| 1215 |
-
I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
|
| 1216 |
-
imgi[n, k] = I
|
| 1217 |
|
| 1218 |
if Y is not None:
|
| 1219 |
for k in range(nt):
|
| 1220 |
flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
|
| 1221 |
-
|
| 1222 |
-
lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
|
| 1223 |
-
for z in range(lz):
|
| 1224 |
-
I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
|
| 1225 |
-
flags=flag)
|
| 1226 |
-
lbl0[z] = I
|
| 1227 |
-
if scale[n] != 1.0:
|
| 1228 |
-
for y in range(lbl.shape[-2]):
|
| 1229 |
-
lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
|
| 1230 |
-
interpolation=flag)
|
| 1231 |
-
else:
|
| 1232 |
-
lbl[n, k] = lbl0
|
| 1233 |
-
else:
|
| 1234 |
-
lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
|
| 1235 |
|
| 1236 |
if nt > 1 and not unet:
|
| 1237 |
v1 = lbl[n, -1].copy()
|
|
@@ -1241,20 +1010,7 @@ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=
|
|
| 1241 |
|
| 1242 |
if feat is not None:
|
| 1243 |
for k in range(nf):
|
| 1244 |
-
|
| 1245 |
-
feat0 = np.zeros((lz, xy[0], xy[1]), "float32")
|
| 1246 |
-
for z in range(lz):
|
| 1247 |
-
I = cv2.warpAffine(feats[k, z], M, (xy[1], xy[0]),
|
| 1248 |
-
flags=cv2.INTER_LINEAR)
|
| 1249 |
-
feat0[z] = I
|
| 1250 |
-
if scale[n] != 1.0:
|
| 1251 |
-
for y in range(feat_out.shape[-2]):
|
| 1252 |
-
feat_out[n, k, :, y] = cv2.resize(feat0[:, y], (xy[1], zcrop),
|
| 1253 |
-
interpolation=cv2.INTER_LINEAR)
|
| 1254 |
-
else:
|
| 1255 |
-
feat_out[n, k] = feat0
|
| 1256 |
-
else:
|
| 1257 |
-
feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
|
| 1258 |
|
| 1259 |
|
| 1260 |
|
|
|
|
| 398 |
return img
|
| 399 |
|
| 400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
|
| 403 |
"""Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
|
|
|
|
| 432 |
if z_axis is not None and not do_3D:
|
| 433 |
raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")
|
| 434 |
|
|
|
|
| 435 |
if ndim == 4 and not do_3D:
|
| 436 |
raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
|
| 437 |
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
######################## 2D reshaping ########################
|
| 440 |
# if user specifies channel axis, return early
|
|
|
|
| 813 |
nchan = X[0].shape[0]
|
| 814 |
else:
|
| 815 |
nchan = 1
|
| 816 |
+
# if do_3D and X[0].ndim > 3:
|
| 817 |
+
# shape = (zcrop, xy[0], xy[1])
|
| 818 |
+
# else:
|
| 819 |
+
shape = (xy[0], xy[1])
|
| 820 |
imgi = np.zeros((nimg, nchan, *shape), "float32")
|
| 821 |
|
| 822 |
lbl = []
|
|
|
|
| 865 |
if labels.ndim < 3:
|
| 866 |
labels = labels[np.newaxis, :, :]
|
| 867 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
if do_flip:
|
| 869 |
if flip:
|
| 870 |
img = img[..., ::-1]
|
|
|
|
| 872 |
labels = labels[..., ::-1]
|
| 873 |
if nt > 1 and not unet:
|
| 874 |
labels[-1] = -labels[-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 875 |
|
| 876 |
for k in range(nchan):
|
| 877 |
+
I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
|
| 878 |
+
imgi[n, k] = I
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 879 |
|
| 880 |
if Y is not None:
|
| 881 |
for k in range(nt):
|
| 882 |
flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
|
| 883 |
+
lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 884 |
|
| 885 |
if nt > 1 and not unet:
|
| 886 |
v1 = lbl[n, -1].copy()
|
|
|
|
| 922 |
nchan = X[0].shape[0]
|
| 923 |
else:
|
| 924 |
nchan = 1
|
| 925 |
+
shape = (xy[0], xy[1])
|
|
|
|
|
|
|
|
|
|
| 926 |
imgi = np.zeros((nimg, nchan, *shape), "float32")
|
| 927 |
|
| 928 |
lbl = []
|
|
|
|
| 982 |
if feats.ndim < 3:
|
| 983 |
feats = feats[np.newaxis, :, :]
|
| 984 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
if do_flip:
|
| 986 |
if flip:
|
| 987 |
img = img[..., ::-1]
|
|
|
|
| 991 |
labels[-1] = -labels[-1]
|
| 992 |
if feat is not None:
|
| 993 |
feats = feats[..., ::-1]
|
| 994 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
|
| 996 |
for k in range(nchan):
|
| 997 |
+
I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
|
| 998 |
+
imgi[n, k] = I
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 999 |
|
| 1000 |
if Y is not None:
|
| 1001 |
for k in range(nt):
|
| 1002 |
flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
|
| 1003 |
+
lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
|
| 1005 |
if nt > 1 and not unet:
|
| 1006 |
v1 = lbl[n, -1].copy()
|
|
|
|
| 1010 |
|
| 1011 |
if feat is not None:
|
| 1012 |
for k in range(nf):
|
| 1013 |
+
feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1014 |
|
| 1015 |
|
| 1016 |
|
models/seg_post_model/utils.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright ยฉ 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import io
|
| 6 |
+
from tqdm import tqdm, trange
|
| 7 |
+
import cv2
|
| 8 |
+
from scipy.ndimage import find_objects
|
| 9 |
+
import numpy as np
|
| 10 |
+
import fastremap
|
| 11 |
+
import fill_voids
|
| 12 |
+
from models.seg_post_model import metrics
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TqdmToLogger(io.StringIO):
|
| 16 |
+
"""
|
| 17 |
+
Output stream for TQDM which will output to logger module instead of
|
| 18 |
+
the StdOut.
|
| 19 |
+
"""
|
| 20 |
+
logger = None
|
| 21 |
+
level = None
|
| 22 |
+
buf = ""
|
| 23 |
+
|
| 24 |
+
def __init__(self, logger, level=None):
|
| 25 |
+
super(TqdmToLogger, self).__init__()
|
| 26 |
+
self.logger = logger
|
| 27 |
+
self.level = level or logging.INFO
|
| 28 |
+
|
| 29 |
+
def write(self, buf):
|
| 30 |
+
self.buf = buf.strip("\r\n\t ")
|
| 31 |
+
|
| 32 |
+
def flush(self):
|
| 33 |
+
self.logger.log(self.level, self.buf)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# def masks_to_outlines(masks):
|
| 38 |
+
# """Get outlines of masks as a 0-1 array.
|
| 39 |
+
|
| 40 |
+
# Args:
|
| 41 |
+
# masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
|
| 42 |
+
|
| 43 |
+
# Returns:
|
| 44 |
+
# outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
|
| 45 |
+
# """
|
| 46 |
+
# if masks.ndim > 3 or masks.ndim < 2:
|
| 47 |
+
# raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
|
| 48 |
+
# masks.ndim)
|
| 49 |
+
# outlines = np.zeros(masks.shape, bool)
|
| 50 |
+
|
| 51 |
+
# if masks.ndim == 3:
|
| 52 |
+
# for i in range(masks.shape[0]):
|
| 53 |
+
# outlines[i] = masks_to_outlines(masks[i])
|
| 54 |
+
# return outlines
|
| 55 |
+
# else:
|
| 56 |
+
# slices = find_objects(masks.astype(int))
|
| 57 |
+
# for i, si in enumerate(slices):
|
| 58 |
+
# if si is not None:
|
| 59 |
+
# sr, sc = si
|
| 60 |
+
# mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
|
| 61 |
+
# contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 62 |
+
# cv2.CHAIN_APPROX_NONE)
|
| 63 |
+
# pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
|
| 64 |
+
# vr, vc = pvr + sr.start, pvc + sc.start
|
| 65 |
+
# outlines[vr, vc] = 1
|
| 66 |
+
# return outlines
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def stitch3D(masks, stitch_threshold=0.25):
|
| 70 |
+
"""
|
| 71 |
+
Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
masks (list or ndarray): List of 2D masks.
|
| 75 |
+
stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
list: List of stitched 3D masks.
|
| 79 |
+
"""
|
| 80 |
+
mmax = masks[0].max()
|
| 81 |
+
empty = 0
|
| 82 |
+
for i in trange(len(masks) - 1):
|
| 83 |
+
iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
|
| 84 |
+
if not iou.size and empty == 0:
|
| 85 |
+
masks[i + 1] = masks[i + 1]
|
| 86 |
+
mmax = masks[i + 1].max()
|
| 87 |
+
elif not iou.size and not empty == 0:
|
| 88 |
+
icount = masks[i + 1].max()
|
| 89 |
+
istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
|
| 90 |
+
mmax += icount
|
| 91 |
+
istitch = np.append(np.array(0), istitch)
|
| 92 |
+
masks[i + 1] = istitch[masks[i + 1]]
|
| 93 |
+
else:
|
| 94 |
+
iou[iou < stitch_threshold] = 0.0
|
| 95 |
+
iou[iou < iou.max(axis=0)] = 0.0
|
| 96 |
+
istitch = iou.argmax(axis=1) + 1
|
| 97 |
+
ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
|
| 98 |
+
istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
|
| 99 |
+
mmax += len(ino)
|
| 100 |
+
istitch = np.append(np.array(0), istitch)
|
| 101 |
+
masks[i + 1] = istitch[masks[i + 1]]
|
| 102 |
+
empty = 1
|
| 103 |
+
|
| 104 |
+
return masks
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# def diameters(masks):
|
| 108 |
+
# """
|
| 109 |
+
# Calculate the diameters of the objects in the given masks.
|
| 110 |
+
|
| 111 |
+
# Parameters:
|
| 112 |
+
# masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 113 |
+
|
| 114 |
+
# Returns:
|
| 115 |
+
# tuple: A tuple containing the median diameter and an array of diameters for each object.
|
| 116 |
+
|
| 117 |
+
# Examples:
|
| 118 |
+
# >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])
|
| 119 |
+
# >>> diameters(masks)
|
| 120 |
+
# (1.0, array([1.41421356, 1.0, 1.0]))
|
| 121 |
+
# """
|
| 122 |
+
# uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
|
| 123 |
+
# counts = counts[1:]
|
| 124 |
+
# md = np.median(counts**0.5)
|
| 125 |
+
# if np.isnan(md):
|
| 126 |
+
# md = 0
|
| 127 |
+
# md /= (np.pi**0.5) / 2
|
| 128 |
+
# return md, counts**0.5
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# def radius_distribution(masks, bins):
|
| 132 |
+
# """
|
| 133 |
+
# Calculate the radius distribution of masks.
|
| 134 |
+
|
| 135 |
+
# Args:
|
| 136 |
+
# masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 137 |
+
# bins (int): Number of bins for the histogram.
|
| 138 |
+
|
| 139 |
+
# Returns:
|
| 140 |
+
# A tuple containing a normalized histogram of radii, median radius, array of radii.
|
| 141 |
+
|
| 142 |
+
# """
|
| 143 |
+
# unique, counts = np.unique(masks, return_counts=True)
|
| 144 |
+
# counts = counts[unique != 0]
|
| 145 |
+
# nb, _ = np.histogram((counts**0.5) * 0.5, bins)
|
| 146 |
+
# nb = nb.astype(np.float32)
|
| 147 |
+
# if nb.sum() > 0:
|
| 148 |
+
# nb = nb / nb.sum()
|
| 149 |
+
# md = np.median(counts**0.5) * 0.5
|
| 150 |
+
# if np.isnan(md):
|
| 151 |
+
# md = 0
|
| 152 |
+
# md /= (np.pi**0.5) / 2
|
| 153 |
+
# return nb, md, (counts**0.5) / 2
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# def size_distribution(masks):
|
| 157 |
+
# """
|
| 158 |
+
# Calculates the size distribution of masks.
|
| 159 |
+
|
| 160 |
+
# Args:
|
| 161 |
+
# masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
|
| 162 |
+
|
| 163 |
+
# Returns:
|
| 164 |
+
# float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
|
| 165 |
+
# """
|
| 166 |
+
# counts = np.unique(masks, return_counts=True)[1][1:]
|
| 167 |
+
# return np.percentile(counts, 25) / np.percentile(counts, 75)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def fill_holes_and_remove_small_masks(masks, min_size=15):
|
| 171 |
+
""" Fills holes in masks (2D/3D) and discards masks smaller than min_size.
|
| 172 |
+
|
| 173 |
+
This function fills holes in each mask using fill_voids.fill.
|
| 174 |
+
It also removes masks that are smaller than the specified min_size.
|
| 175 |
+
|
| 176 |
+
Parameters:
|
| 177 |
+
masks (ndarray): Int, 2D or 3D array of labelled masks.
|
| 178 |
+
0 represents no mask, while positive integers represent mask labels.
|
| 179 |
+
The size can be [Ly x Lx] or [Lz x Ly x Lx].
|
| 180 |
+
min_size (int, optional): Minimum number of pixels per mask.
|
| 181 |
+
Masks smaller than min_size will be removed.
|
| 182 |
+
Set to -1 to turn off this functionality. Default is 15.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
|
| 186 |
+
0 represents no mask, while positive integers represent mask labels.
|
| 187 |
+
The size is [Ly x Lx] or [Lz x Ly x Lx].
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
if masks.ndim > 3 or masks.ndim < 2:
|
| 191 |
+
raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
|
| 192 |
+
masks.ndim)
|
| 193 |
+
|
| 194 |
+
# Filter small masks
|
| 195 |
+
if min_size > 0:
|
| 196 |
+
counts = fastremap.unique(masks, return_counts=True)[1][1:]
|
| 197 |
+
masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
|
| 198 |
+
fastremap.renumber(masks, in_place=True)
|
| 199 |
+
|
| 200 |
+
slices = find_objects(masks)
|
| 201 |
+
j = 0
|
| 202 |
+
for i, slc in enumerate(slices):
|
| 203 |
+
if slc is not None:
|
| 204 |
+
msk = masks[slc] == (i + 1)
|
| 205 |
+
msk = fill_voids.fill(msk)
|
| 206 |
+
masks[slc][msk] = (j + 1)
|
| 207 |
+
j += 1
|
| 208 |
+
|
| 209 |
+
if min_size > 0:
|
| 210 |
+
counts = fastremap.unique(masks, return_counts=True)[1][1:]
|
| 211 |
+
masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
|
| 212 |
+
fastremap.renumber(masks, in_place=True)
|
| 213 |
+
|
| 214 |
+
return masks
|
models/seg_post_model/{cellpose/vit_sam.py โ vit_sam.py}
RENAMED
|
@@ -130,66 +130,3 @@ class Transformer(nn.Module):
|
|
| 130 |
"""
|
| 131 |
torch.save(self.state_dict(), filename)
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
class CPnetBioImageIO(Transformer):
|
| 136 |
-
"""
|
| 137 |
-
A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
|
| 138 |
-
|
| 139 |
-
This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
|
| 140 |
-
allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
|
| 141 |
-
"""
|
| 142 |
-
|
| 143 |
-
def forward(self, x):
|
| 144 |
-
"""
|
| 145 |
-
Perform a forward pass of the CPnet model and return unpacked tensors.
|
| 146 |
-
|
| 147 |
-
Args:
|
| 148 |
-
x (torch.Tensor): Input tensor.
|
| 149 |
-
|
| 150 |
-
Returns:
|
| 151 |
-
tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
|
| 152 |
-
"""
|
| 153 |
-
output_tensor, style_tensor, downsampled_tensors = super().forward(x)
|
| 154 |
-
return output_tensor, style_tensor, *downsampled_tensors
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def load_model(self, filename, device=None):
|
| 158 |
-
"""
|
| 159 |
-
Load the model from a file.
|
| 160 |
-
|
| 161 |
-
Args:
|
| 162 |
-
filename (str): The path to the file where the model is saved.
|
| 163 |
-
device (torch.device, optional): The device to load the model on. Defaults to None.
|
| 164 |
-
"""
|
| 165 |
-
if (device is not None) and (device.type != "cpu"):
|
| 166 |
-
state_dict = torch.load(filename, map_location=device, weights_only=True)
|
| 167 |
-
else:
|
| 168 |
-
self.__init__(self.nout)
|
| 169 |
-
state_dict = torch.load(filename, map_location=torch.device("cpu"),
|
| 170 |
-
weights_only=True)
|
| 171 |
-
|
| 172 |
-
self.load_state_dict(state_dict)
|
| 173 |
-
|
| 174 |
-
def load_state_dict(self, state_dict):
|
| 175 |
-
"""
|
| 176 |
-
Load the state dictionary into the model.
|
| 177 |
-
|
| 178 |
-
This method overrides the default `load_state_dict` to handle Cellpose's custom
|
| 179 |
-
loading mechanism and ensures compatibility with BioImage.IO Core.
|
| 180 |
-
|
| 181 |
-
Args:
|
| 182 |
-
state_dict (Mapping[str, Any]): A state dictionary to load into the model
|
| 183 |
-
"""
|
| 184 |
-
if state_dict["output.2.weight"].shape[0] != self.nout:
|
| 185 |
-
for name in self.state_dict():
|
| 186 |
-
if "output" not in name:
|
| 187 |
-
self.state_dict()[name].copy_(state_dict[name])
|
| 188 |
-
else:
|
| 189 |
-
super().load_state_dict(
|
| 190 |
-
{name: param for name, param in state_dict.items()},
|
| 191 |
-
strict=False)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
| 130 |
"""
|
| 131 |
torch.save(self.state_dict(), filename)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
segmentation.py
CHANGED
|
@@ -1,9 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
-
import pprint
|
| 3 |
from typing import Any, List, Optional
|
| 4 |
-
import argparse
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
-
import pyrallis
|
| 7 |
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 8 |
import torch
|
| 9 |
import os
|
|
@@ -27,7 +24,7 @@ from models.enc_model.loca_args import get_argparser as loca_get_argparser
|
|
| 27 |
from models.enc_model.loca import build_model as build_loca_model
|
| 28 |
import time
|
| 29 |
from _utils.seg_eval import *
|
| 30 |
-
from models.seg_post_model
|
| 31 |
from datetime import datetime
|
| 32 |
import json
|
| 33 |
import logging
|
|
@@ -35,7 +32,7 @@ from PIL import Image
|
|
| 35 |
import torchvision.transforms as T
|
| 36 |
import cv2
|
| 37 |
from skimage import io, measure
|
| 38 |
-
logging.getLogger('models.seg_post_model.
|
| 39 |
|
| 40 |
SCALE = 1
|
| 41 |
|
|
@@ -76,8 +73,6 @@ class SegmentationModule(pl.LightningModule):
|
|
| 76 |
" `placeholder_token` that is not already in the tokenizer."
|
| 77 |
)
|
| 78 |
try:
|
| 79 |
-
# print("loading pretrained task embedding from {}".format("pretrained/task_embed.pth"))
|
| 80 |
-
# task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
|
| 81 |
task_embed_from_pretrain = hf_hub_download(
|
| 82 |
repo_id="phoebe777777/111",
|
| 83 |
filename="task_embed.pth",
|
|
@@ -190,11 +185,6 @@ class SegmentationModule(pl.LightningModule):
|
|
| 190 |
exemplar_attention_maps2 = []
|
| 191 |
exemplar_attention_maps3 = []
|
| 192 |
|
| 193 |
-
cross_self_task_attn_maps = []
|
| 194 |
-
cross_self_exe_attn_maps1 = []
|
| 195 |
-
cross_self_exe_attn_maps2 = []
|
| 196 |
-
cross_self_exe_attn_maps3 = []
|
| 197 |
-
|
| 198 |
# only use 64x64 self-attention
|
| 199 |
self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
|
| 200 |
prompts=[self.config.prompt for i in range(bsz)], # ่ฟ้่ฆๆนไน
|
|
@@ -323,7 +313,7 @@ def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
|
|
| 323 |
ax[0].axis("off")
|
| 324 |
ax[1].imshow(img_show)
|
| 325 |
for inst_id in np.unique(mask_show):
|
| 326 |
-
if inst_id == 0: # 0
|
| 327 |
continue
|
| 328 |
# ็ๆไบๅผ mask
|
| 329 |
binary_mask = (mask_show == inst_id).astype(np.uint8)
|
|
@@ -352,12 +342,6 @@ def main():
|
|
| 352 |
from matplotlib import cm
|
| 353 |
|
| 354 |
def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
|
| 355 |
-
"""
|
| 356 |
-
img: ๅๅพ (H, W, 3)๏ผ่ๅด [0,255] ๆ [0,1]
|
| 357 |
-
mask: ๅฎไพๅๅฒ็ปๆ (H, W)๏ผ่ๆฏ=0๏ผๅฎไพ=1,2,...
|
| 358 |
-
alpha: ้ๆๅบฆ
|
| 359 |
-
cmap_name: ้ข่ฒๆ ๅฐ่กจ
|
| 360 |
-
"""
|
| 361 |
img = img.astype(np.float32)
|
| 362 |
if len(img.shape) == 2:
|
| 363 |
img = np.stack([img]*3, axis=-1)
|
|
@@ -369,7 +353,7 @@ def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
|
|
| 369 |
cmap = cm.get_cmap(cmap_name, np.max(mask)+1)
|
| 370 |
|
| 371 |
for inst_id in np.unique(mask):
|
| 372 |
-
if inst_id == 0:
|
| 373 |
continue
|
| 374 |
color = np.array(cmap(inst_id)[:3]) # RGB
|
| 375 |
overlay[mask == inst_id] = (1 - alpha) * overlay[mask == inst_id] + alpha * color
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
from typing import Any, List, Optional
|
|
|
|
| 3 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 4 |
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 5 |
import torch
|
| 6 |
import os
|
|
|
|
| 24 |
from models.enc_model.loca import build_model as build_loca_model
|
| 25 |
import time
|
| 26 |
from _utils.seg_eval import *
|
| 27 |
+
from models.seg_post_model import metrics
|
| 28 |
from datetime import datetime
|
| 29 |
import json
|
| 30 |
import logging
|
|
|
|
| 32 |
import torchvision.transforms as T
|
| 33 |
import cv2
|
| 34 |
from skimage import io, measure
|
| 35 |
+
logging.getLogger('models.seg_post_model.models').setLevel(logging.ERROR)
|
| 36 |
|
| 37 |
SCALE = 1
|
| 38 |
|
|
|
|
| 73 |
" `placeholder_token` that is not already in the tokenizer."
|
| 74 |
)
|
| 75 |
try:
|
|
|
|
|
|
|
| 76 |
task_embed_from_pretrain = hf_hub_download(
|
| 77 |
repo_id="phoebe777777/111",
|
| 78 |
filename="task_embed.pth",
|
|
|
|
| 185 |
exemplar_attention_maps2 = []
|
| 186 |
exemplar_attention_maps3 = []
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
# only use 64x64 self-attention
|
| 189 |
self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
|
| 190 |
prompts=[self.config.prompt for i in range(bsz)], # ่ฟ้่ฆๆนไน
|
|
|
|
| 313 |
ax[0].axis("off")
|
| 314 |
ax[1].imshow(img_show)
|
| 315 |
for inst_id in np.unique(mask_show):
|
| 316 |
+
if inst_id == 0: # 0 background
|
| 317 |
continue
|
| 318 |
# ็ๆไบๅผ mask
|
| 319 |
binary_mask = (mask_show == inst_id).astype(np.uint8)
|
|
|
|
| 342 |
from matplotlib import cm
|
| 343 |
|
| 344 |
def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
img = img.astype(np.float32)
|
| 346 |
if len(img.shape) == 2:
|
| 347 |
img = np.stack([img]*3, axis=-1)
|
|
|
|
| 353 |
cmap = cm.get_cmap(cmap_name, np.max(mask)+1)
|
| 354 |
|
| 355 |
for inst_id in np.unique(mask):
|
| 356 |
+
if inst_id == 0:
|
| 357 |
continue
|
| 358 |
color = np.array(cmap(inst_id)[:3]) # RGB
|
| 359 |
overlay[mask == inst_id] = (1 - alpha) * overlay[mask == inst_id] + alpha * color
|