Spaces:
Runtime error
Runtime error
Commit ·
102cd7d
1
Parent(s): 25ea9b7
update
Browse files- models/seg_post_model/cellpose/__init__.py +1 -1
- models/seg_post_model/cellpose/__main__.py +0 -272
- models/seg_post_model/cellpose/cli.py +0 -240
- models/seg_post_model/cellpose/denoise.py +0 -1474
- models/seg_post_model/cellpose/export.py +0 -405
- models/seg_post_model/cellpose/gui/gui.py +0 -2007
- models/seg_post_model/cellpose/gui/gui3d.py +0 -667
- models/seg_post_model/cellpose/gui/guihelpwindowtext.html +0 -143
- models/seg_post_model/cellpose/gui/guiparts.py +0 -793
- models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html +0 -25
- models/seg_post_model/cellpose/gui/io.py +0 -634
- models/seg_post_model/cellpose/gui/make_train.py +0 -107
- models/seg_post_model/cellpose/gui/menus.py +0 -145
- models/seg_post_model/cellpose/vit_sam_new.py +0 -197
models/seg_post_model/cellpose/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
from .version import version, version_str
|
|
|
|
| 1 |
+
# from .version import version, version_str
|
models/seg_post_model/cellpose/__main__.py
DELETED
|
@@ -1,272 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
import os, time
|
| 5 |
-
import numpy as np
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
from cellpose import utils, models, io, train
|
| 8 |
-
from .version import version_str
|
| 9 |
-
from cellpose.cli import get_arg_parser
|
| 10 |
-
|
| 11 |
-
try:
|
| 12 |
-
from cellpose.gui import gui3d, gui
|
| 13 |
-
GUI_ENABLED = True
|
| 14 |
-
except ImportError as err:
|
| 15 |
-
GUI_ERROR = err
|
| 16 |
-
GUI_ENABLED = False
|
| 17 |
-
GUI_IMPORT = True
|
| 18 |
-
except Exception as err:
|
| 19 |
-
GUI_ENABLED = False
|
| 20 |
-
GUI_ERROR = err
|
| 21 |
-
GUI_IMPORT = False
|
| 22 |
-
raise
|
| 23 |
-
|
| 24 |
-
import logging
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def main():
|
| 28 |
-
""" Run cellpose from command line
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
args = get_arg_parser().parse_args() # this has to be in a separate file for autodoc to work
|
| 32 |
-
|
| 33 |
-
if args.version:
|
| 34 |
-
print(version_str)
|
| 35 |
-
return
|
| 36 |
-
|
| 37 |
-
######## if no image arguments are provided, run GUI or add model and exit ########
|
| 38 |
-
if len(args.dir) == 0 and len(args.image_path) == 0:
|
| 39 |
-
if args.add_model:
|
| 40 |
-
io.add_model(args.add_model)
|
| 41 |
-
return
|
| 42 |
-
else:
|
| 43 |
-
if not GUI_ENABLED:
|
| 44 |
-
print("GUI ERROR: %s" % GUI_ERROR)
|
| 45 |
-
if GUI_IMPORT:
|
| 46 |
-
print(
|
| 47 |
-
"GUI FAILED: GUI dependencies may not be installed, to install, run"
|
| 48 |
-
)
|
| 49 |
-
print(" pip install 'cellpose[gui]'")
|
| 50 |
-
else:
|
| 51 |
-
if args.Zstack:
|
| 52 |
-
gui3d.run()
|
| 53 |
-
else:
|
| 54 |
-
gui.run()
|
| 55 |
-
return
|
| 56 |
-
|
| 57 |
-
############################## run cellpose on images ##############################
|
| 58 |
-
if args.verbose:
|
| 59 |
-
from .io import logger_setup
|
| 60 |
-
logger, log_file = logger_setup()
|
| 61 |
-
else:
|
| 62 |
-
print(
|
| 63 |
-
">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
|
| 64 |
-
print("No --verbose => no progress or info printed")
|
| 65 |
-
logger = logging.getLogger(__name__)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
# find images
|
| 69 |
-
if len(args.img_filter) > 0:
|
| 70 |
-
image_filter = args.img_filter
|
| 71 |
-
else:
|
| 72 |
-
image_filter = None
|
| 73 |
-
|
| 74 |
-
device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
|
| 75 |
-
device=args.gpu_device)
|
| 76 |
-
|
| 77 |
-
if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
|
| 78 |
-
pretrained_model = "cpsam"
|
| 79 |
-
logger.warning("training from scratch is disabled, using 'cpsam' model")
|
| 80 |
-
else:
|
| 81 |
-
pretrained_model = args.pretrained_model
|
| 82 |
-
|
| 83 |
-
# Warn users about old arguments from CP3:
|
| 84 |
-
if args.pretrained_model_ortho:
|
| 85 |
-
logger.warning(
|
| 86 |
-
"the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used")
|
| 87 |
-
if args.train_size:
|
| 88 |
-
logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
|
| 89 |
-
if args.chan or args.chan2:
|
| 90 |
-
logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
|
| 91 |
-
if args.all_channels:
|
| 92 |
-
logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used")
|
| 93 |
-
if args.restore_type:
|
| 94 |
-
logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used")
|
| 95 |
-
if args.transformer:
|
| 96 |
-
logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used")
|
| 97 |
-
if args.invert:
|
| 98 |
-
logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
|
| 99 |
-
if args.chan2_restore:
|
| 100 |
-
logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
|
| 101 |
-
if args.diam_mean:
|
| 102 |
-
logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
|
| 103 |
-
if args.train_size:
|
| 104 |
-
logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
|
| 105 |
-
|
| 106 |
-
if args.norm_percentile is not None:
|
| 107 |
-
value1, value2 = args.norm_percentile
|
| 108 |
-
normalize = {'percentile': (float(value1), float(value2))}
|
| 109 |
-
else:
|
| 110 |
-
normalize = (not args.no_norm)
|
| 111 |
-
|
| 112 |
-
if args.save_each:
|
| 113 |
-
if not args.save_every:
|
| 114 |
-
raise ValueError("ERROR: --save_each requires --save_every")
|
| 115 |
-
|
| 116 |
-
if len(args.image_path) > 0 and args.train:
|
| 117 |
-
raise ValueError("ERROR: cannot train model with single image input")
|
| 118 |
-
|
| 119 |
-
## Run evaluation on images
|
| 120 |
-
if not args.train:
|
| 121 |
-
_evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
|
| 122 |
-
|
| 123 |
-
## Train a model ##
|
| 124 |
-
else:
|
| 125 |
-
_train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
|
| 129 |
-
test_dir = None if len(args.test_dir) == 0 else args.test_dir
|
| 130 |
-
images, labels, image_names, train_probs = None, None, None, None
|
| 131 |
-
test_images, test_labels, image_names_test, test_probs = None, None, None, None
|
| 132 |
-
compute_flows = False
|
| 133 |
-
if len(args.file_list) > 0:
|
| 134 |
-
if os.path.exists(args.file_list):
|
| 135 |
-
dat = np.load(args.file_list, allow_pickle=True).item()
|
| 136 |
-
image_names = dat["train_files"]
|
| 137 |
-
image_names_test = dat.get("test_files", None)
|
| 138 |
-
train_probs = dat.get("train_probs", None)
|
| 139 |
-
test_probs = dat.get("test_probs", None)
|
| 140 |
-
compute_flows = dat.get("compute_flows", False)
|
| 141 |
-
load_files = False
|
| 142 |
-
else:
|
| 143 |
-
logger.critical(f"ERROR: {args.file_list} does not exist")
|
| 144 |
-
else:
|
| 145 |
-
output = io.load_train_test_data(args.dir, test_dir, image_filter,
|
| 146 |
-
args.mask_filter,
|
| 147 |
-
args.look_one_level_down)
|
| 148 |
-
images, labels, image_names, test_images, test_labels, image_names_test = output
|
| 149 |
-
load_files = True
|
| 150 |
-
|
| 151 |
-
# initialize model
|
| 152 |
-
model = models.CellposeModel(device=device, pretrained_model=pretrained_model)
|
| 153 |
-
|
| 154 |
-
# train segmentation model
|
| 155 |
-
cpmodel_path = train.train_seg(
|
| 156 |
-
model.net, images, labels, train_files=image_names,
|
| 157 |
-
test_data=test_images, test_labels=test_labels,
|
| 158 |
-
test_files=image_names_test, train_probs=train_probs,
|
| 159 |
-
test_probs=test_probs, compute_flows=compute_flows,
|
| 160 |
-
load_files=load_files, normalize=normalize,
|
| 161 |
-
channel_axis=args.channel_axis,
|
| 162 |
-
learning_rate=args.learning_rate, weight_decay=args.weight_decay,
|
| 163 |
-
SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
|
| 164 |
-
min_train_masks=args.min_train_masks,
|
| 165 |
-
nimg_per_epoch=args.nimg_per_epoch,
|
| 166 |
-
nimg_test_per_epoch=args.nimg_test_per_epoch,
|
| 167 |
-
save_path=os.path.realpath(args.dir),
|
| 168 |
-
save_every=args.save_every,
|
| 169 |
-
save_each=args.save_each,
|
| 170 |
-
model_name=args.model_name_out)[0]
|
| 171 |
-
model.pretrained_model = cpmodel_path
|
| 172 |
-
logger.info(">>>> model trained and saved to %s" % cpmodel_path)
|
| 173 |
-
return model
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
|
| 177 |
-
# Check with user if they REALLY mean to run without saving anything
|
| 178 |
-
if not args.train:
|
| 179 |
-
saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt
|
| 180 |
-
|
| 181 |
-
tic = time.time()
|
| 182 |
-
if len(args.dir) > 0:
|
| 183 |
-
image_names = io.get_image_files(
|
| 184 |
-
args.dir, args.mask_filter, imf=imf,
|
| 185 |
-
look_one_level_down=args.look_one_level_down)
|
| 186 |
-
else:
|
| 187 |
-
if os.path.exists(args.image_path):
|
| 188 |
-
image_names = [args.image_path]
|
| 189 |
-
else:
|
| 190 |
-
raise ValueError(f"ERROR: no file found at {args.image_path}")
|
| 191 |
-
nimg = len(image_names)
|
| 192 |
-
|
| 193 |
-
if args.savedir:
|
| 194 |
-
if not os.path.exists(args.savedir):
|
| 195 |
-
raise FileExistsError(f"--savedir {args.savedir} does not exist")
|
| 196 |
-
|
| 197 |
-
logger.info(
|
| 198 |
-
">>>> running cellpose on %d images using all channels" % nimg)
|
| 199 |
-
|
| 200 |
-
# handle built-in model exceptions
|
| 201 |
-
model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)
|
| 202 |
-
|
| 203 |
-
tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
|
| 204 |
-
|
| 205 |
-
channel_axis = args.channel_axis
|
| 206 |
-
z_axis = args.z_axis
|
| 207 |
-
|
| 208 |
-
for image_name in tqdm(image_names, file=tqdm_out):
|
| 209 |
-
if args.do_3D or args.stitch_threshold > 0.:
|
| 210 |
-
logger.info('loading image as 3D zstack')
|
| 211 |
-
image = io.imread_3D(image_name)
|
| 212 |
-
if channel_axis is None:
|
| 213 |
-
channel_axis = 3
|
| 214 |
-
if z_axis is None:
|
| 215 |
-
z_axis = 0
|
| 216 |
-
|
| 217 |
-
else:
|
| 218 |
-
image = io.imread_2D(image_name)
|
| 219 |
-
out = model.eval(
|
| 220 |
-
image,
|
| 221 |
-
diameter=args.diameter,
|
| 222 |
-
do_3D=args.do_3D,
|
| 223 |
-
augment=args.augment,
|
| 224 |
-
flow_threshold=args.flow_threshold,
|
| 225 |
-
cellprob_threshold=args.cellprob_threshold,
|
| 226 |
-
stitch_threshold=args.stitch_threshold,
|
| 227 |
-
min_size=args.min_size,
|
| 228 |
-
batch_size=args.batch_size,
|
| 229 |
-
bsize=args.bsize,
|
| 230 |
-
resample=not args.no_resample,
|
| 231 |
-
normalize=normalize,
|
| 232 |
-
channel_axis=channel_axis,
|
| 233 |
-
z_axis=z_axis,
|
| 234 |
-
anisotropy=args.anisotropy,
|
| 235 |
-
niter=args.niter,
|
| 236 |
-
flow3D_smooth=args.flow3D_smooth)
|
| 237 |
-
masks, flows = out[:2]
|
| 238 |
-
|
| 239 |
-
if args.exclude_on_edges:
|
| 240 |
-
masks = utils.remove_edge_masks(masks)
|
| 241 |
-
if not args.no_npy:
|
| 242 |
-
io.masks_flows_to_seg(image, masks, flows, image_name,
|
| 243 |
-
imgs_restore=None,
|
| 244 |
-
restore_type=None,
|
| 245 |
-
ratio=1.)
|
| 246 |
-
if saving_something:
|
| 247 |
-
suffix = "_cp_masks"
|
| 248 |
-
if args.output_name is not None:
|
| 249 |
-
# (1) If `savedir` is not defined, then must have a non-zero `suffix`
|
| 250 |
-
if args.savedir is None and len(args.output_name) > 0:
|
| 251 |
-
suffix = args.output_name
|
| 252 |
-
elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
|
| 253 |
-
# (2) If `savedir` is defined, and different from `dir` then
|
| 254 |
-
# takes the value passed as a param. (which can be empty string)
|
| 255 |
-
suffix = args.output_name
|
| 256 |
-
|
| 257 |
-
io.save_masks(image, masks, flows, image_name,
|
| 258 |
-
suffix=suffix, png=args.save_png,
|
| 259 |
-
tif=args.save_tif, save_flows=args.save_flows,
|
| 260 |
-
save_outlines=args.save_outlines,
|
| 261 |
-
dir_above=args.dir_above, savedir=args.savedir,
|
| 262 |
-
save_txt=args.save_txt, in_folders=args.in_folders,
|
| 263 |
-
save_mpl=args.save_mpl)
|
| 264 |
-
if args.save_rois:
|
| 265 |
-
io.save_rois(masks, image_name)
|
| 266 |
-
logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
|
| 267 |
-
|
| 268 |
-
return model
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
if __name__ == "__main__":
|
| 272 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/cli.py
DELETED
|
@@ -1,240 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def get_arg_parser():
|
| 9 |
-
""" Parses command line arguments for cellpose main function
|
| 10 |
-
|
| 11 |
-
Note: this function has to be in a separate file to allow autodoc to work for CLI.
|
| 12 |
-
The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
|
| 13 |
-
see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
|
| 17 |
-
|
| 18 |
-
# misc settings
|
| 19 |
-
parser.add_argument("--version", action="store_true",
|
| 20 |
-
help="show cellpose version info")
|
| 21 |
-
parser.add_argument(
|
| 22 |
-
"--verbose", action="store_true",
|
| 23 |
-
help="show information about running and settings and save to log")
|
| 24 |
-
parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")
|
| 25 |
-
|
| 26 |
-
# settings for CPU vs GPU
|
| 27 |
-
hardware_args = parser.add_argument_group("Hardware Arguments")
|
| 28 |
-
hardware_args.add_argument("--use_gpu", action="store_true",
|
| 29 |
-
help="use gpu if torch with cuda installed")
|
| 30 |
-
hardware_args.add_argument(
|
| 31 |
-
"--gpu_device", required=False, default="0", type=str,
|
| 32 |
-
help="which gpu device to use, use an integer for torch, or mps for M1")
|
| 33 |
-
|
| 34 |
-
# settings for locating and formatting images
|
| 35 |
-
input_img_args = parser.add_argument_group("Input Image Arguments")
|
| 36 |
-
input_img_args.add_argument("--dir", default=[], type=str,
|
| 37 |
-
help="folder containing data to run or train on.")
|
| 38 |
-
input_img_args.add_argument(
|
| 39 |
-
"--image_path", default=[], type=str, help=
|
| 40 |
-
"if given and --dir not given, run on single image instead of folder (cannot train with this option)"
|
| 41 |
-
)
|
| 42 |
-
input_img_args.add_argument(
|
| 43 |
-
"--look_one_level_down", action="store_true",
|
| 44 |
-
help="run processing on all subdirectories of current folder")
|
| 45 |
-
input_img_args.add_argument("--img_filter", default=[], type=str,
|
| 46 |
-
help="end string for images to run on")
|
| 47 |
-
input_img_args.add_argument(
|
| 48 |
-
"--channel_axis", default=None, type=int,
|
| 49 |
-
help="axis of image which corresponds to image channels")
|
| 50 |
-
input_img_args.add_argument("--z_axis", default=None, type=int,
|
| 51 |
-
help="axis of image which corresponds to Z dimension")
|
| 52 |
-
|
| 53 |
-
# TODO: remove deprecated in future version
|
| 54 |
-
input_img_args.add_argument(
|
| 55 |
-
"--chan", default=0, type=int, help=
|
| 56 |
-
"Deprecated in v4.0.1+, not used. ")
|
| 57 |
-
input_img_args.add_argument(
|
| 58 |
-
"--chan2", default=0, type=int, help=
|
| 59 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 60 |
-
input_img_args.add_argument("--invert", action="store_true", help=
|
| 61 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 62 |
-
input_img_args.add_argument(
|
| 63 |
-
"--all_channels", action="store_true", help=
|
| 64 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 65 |
-
|
| 66 |
-
# model settings
|
| 67 |
-
model_args = parser.add_argument_group("Model Arguments")
|
| 68 |
-
model_args.add_argument("--pretrained_model", required=False, default="cpsam",
|
| 69 |
-
type=str,
|
| 70 |
-
help="model to use for running or starting training")
|
| 71 |
-
model_args.add_argument(
|
| 72 |
-
"--add_model", required=False, default=None, type=str,
|
| 73 |
-
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
|
| 74 |
-
model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
|
| 75 |
-
type=str,
|
| 76 |
-
help="Deprecated in v4.0.1+, not used. ")
|
| 77 |
-
|
| 78 |
-
# TODO: remove deprecated in future version
|
| 79 |
-
model_args.add_argument("--restore_type", required=False, default=None, type=str, help=
|
| 80 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 81 |
-
model_args.add_argument("--chan2_restore", action="store_true", help=
|
| 82 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 83 |
-
model_args.add_argument(
|
| 84 |
-
"--transformer", action="store_true", help=
|
| 85 |
-
"use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")
|
| 86 |
-
|
| 87 |
-
# algorithm settings
|
| 88 |
-
algorithm_args = parser.add_argument_group("Algorithm Arguments")
|
| 89 |
-
algorithm_args.add_argument("--no_norm", action="store_true",
|
| 90 |
-
help="do not normalize images (normalize=False)")
|
| 91 |
-
algorithm_args.add_argument(
|
| 92 |
-
'--norm_percentile',
|
| 93 |
-
nargs=2, # Require exactly two values
|
| 94 |
-
metavar=('VALUE1', 'VALUE2'),
|
| 95 |
-
help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
|
| 96 |
-
)
|
| 97 |
-
algorithm_args.add_argument(
|
| 98 |
-
"--do_3D", action="store_true",
|
| 99 |
-
help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx")
|
| 100 |
-
algorithm_args.add_argument(
|
| 101 |
-
"--diameter", required=False, default=None, type=float, help=
|
| 102 |
-
"use to resize cells to the training diameter (30 pixels)"
|
| 103 |
-
)
|
| 104 |
-
algorithm_args.add_argument(
|
| 105 |
-
"--stitch_threshold", required=False, default=0.0, type=float,
|
| 106 |
-
help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
|
| 107 |
-
)
|
| 108 |
-
algorithm_args.add_argument(
|
| 109 |
-
"--min_size", required=False, default=15, type=int,
|
| 110 |
-
help="minimum number of pixels per mask, can turn off with -1")
|
| 111 |
-
algorithm_args.add_argument(
|
| 112 |
-
"--flow3D_smooth", required=False, default=0, type=float,
|
| 113 |
-
help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
|
| 114 |
-
algorithm_args.add_argument(
|
| 115 |
-
"--flow_threshold", default=0.4, type=float, help=
|
| 116 |
-
"flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
|
| 117 |
-
algorithm_args.add_argument(
|
| 118 |
-
"--cellprob_threshold", default=0, type=float,
|
| 119 |
-
help="cellprob threshold, default is 0, decrease to find more and larger masks")
|
| 120 |
-
algorithm_args.add_argument(
|
| 121 |
-
"--niter", default=0, type=int, help=
|
| 122 |
-
"niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
|
| 123 |
-
)
|
| 124 |
-
algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
|
| 125 |
-
help="anisotropy of volume in 3D")
|
| 126 |
-
algorithm_args.add_argument("--exclude_on_edges", action="store_true",
|
| 127 |
-
help="discard masks which touch edges of image")
|
| 128 |
-
algorithm_args.add_argument(
|
| 129 |
-
"--augment", action="store_true",
|
| 130 |
-
help="tiles image with overlapping tiles and flips overlapped regions to augment"
|
| 131 |
-
)
|
| 132 |
-
algorithm_args.add_argument("--batch_size", default=8, type=int,
|
| 133 |
-
help="inference batch size. Default: %(default)s")
|
| 134 |
-
|
| 135 |
-
# TODO: remove deprecated in future version
|
| 136 |
-
algorithm_args.add_argument(
|
| 137 |
-
"--no_resample", action="store_true",
|
| 138 |
-
help="disables flows/cellprob resampling to original image size before computing masks. Using this flag will make more masks more jagged with larger diameter settings.")
|
| 139 |
-
algorithm_args.add_argument(
|
| 140 |
-
"--no_interp", action="store_true",
|
| 141 |
-
help="do not interpolate when running dynamics (was default)")
|
| 142 |
-
|
| 143 |
-
# output settings
|
| 144 |
-
output_args = parser.add_argument_group("Output Arguments")
|
| 145 |
-
output_args.add_argument(
|
| 146 |
-
"--save_png", action="store_true",
|
| 147 |
-
help="save masks as png")
|
| 148 |
-
output_args.add_argument(
|
| 149 |
-
"--save_tif", action="store_true",
|
| 150 |
-
help="save masks as tif")
|
| 151 |
-
output_args.add_argument(
|
| 152 |
-
"--output_name", default=None, type=str,
|
| 153 |
-
help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
|
| 154 |
-
output_args.add_argument("--no_npy", action="store_true",
|
| 155 |
-
help="suppress saving of npy")
|
| 156 |
-
output_args.add_argument(
|
| 157 |
-
"--savedir", default=None, type=str, help=
|
| 158 |
-
"folder to which segmentation results will be saved (defaults to input image directory)"
|
| 159 |
-
)
|
| 160 |
-
output_args.add_argument(
|
| 161 |
-
"--dir_above", action="store_true", help=
|
| 162 |
-
"save output folders adjacent to image folder instead of inside it (off by default)"
|
| 163 |
-
)
|
| 164 |
-
output_args.add_argument("--in_folders", action="store_true",
|
| 165 |
-
help="flag to save output in folders (off by default)")
|
| 166 |
-
output_args.add_argument(
|
| 167 |
-
"--save_flows", action="store_true", help=
|
| 168 |
-
"whether or not to save RGB images of flows when masks are saved (disabled by default)"
|
| 169 |
-
)
|
| 170 |
-
output_args.add_argument(
|
| 171 |
-
"--save_outlines", action="store_true", help=
|
| 172 |
-
"whether or not to save RGB outline images when masks are saved (disabled by default)"
|
| 173 |
-
)
|
| 174 |
-
output_args.add_argument(
|
| 175 |
-
"--save_rois", action="store_true",
|
| 176 |
-
help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
|
| 177 |
-
)
|
| 178 |
-
output_args.add_argument(
|
| 179 |
-
"--save_txt", action="store_true",
|
| 180 |
-
help="flag to enable txt outlines for ImageJ (disabled by default)")
|
| 181 |
-
output_args.add_argument(
|
| 182 |
-
"--save_mpl", action="store_true",
|
| 183 |
-
help="save a figure of image/mask/flows using matplotlib (disabled by default). "
|
| 184 |
-
"This is slow, especially with large images.")
|
| 185 |
-
|
| 186 |
-
# training settings
|
| 187 |
-
training_args = parser.add_argument_group("Training Arguments")
|
| 188 |
-
training_args.add_argument("--train", action="store_true",
|
| 189 |
-
help="train network using images in dir")
|
| 190 |
-
training_args.add_argument("--test_dir", default=[], type=str,
|
| 191 |
-
help="folder containing test data (optional)")
|
| 192 |
-
training_args.add_argument(
|
| 193 |
-
"--file_list", default=[], type=str, help=
|
| 194 |
-
"path to list of files for training and testing and probabilities for each image (optional)"
|
| 195 |
-
)
|
| 196 |
-
training_args.add_argument(
|
| 197 |
-
"--mask_filter", default="_masks", type=str, help=
|
| 198 |
-
"end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
|
| 199 |
-
)
|
| 200 |
-
training_args.add_argument("--learning_rate", default=1e-5, type=float,
|
| 201 |
-
help="learning rate. Default: %(default)s")
|
| 202 |
-
training_args.add_argument("--weight_decay", default=0.1, type=float,
|
| 203 |
-
help="weight decay. Default: %(default)s")
|
| 204 |
-
training_args.add_argument("--n_epochs", default=100, type=int,
|
| 205 |
-
help="number of epochs. Default: %(default)s")
|
| 206 |
-
training_args.add_argument("--train_batch_size", default=1, type=int,
|
| 207 |
-
help="training batch size. Default: %(default)s")
|
| 208 |
-
training_args.add_argument("--bsize", default=256, type=int,
|
| 209 |
-
help="block size for tiles. Default: %(default)s")
|
| 210 |
-
training_args.add_argument(
|
| 211 |
-
"--nimg_per_epoch", default=None, type=int,
|
| 212 |
-
help="number of train images per epoch. Default is to use all train images.")
|
| 213 |
-
training_args.add_argument(
|
| 214 |
-
"--nimg_test_per_epoch", default=None, type=int,
|
| 215 |
-
help="number of test images per epoch. Default is to use all test images.")
|
| 216 |
-
training_args.add_argument(
|
| 217 |
-
"--min_train_masks", default=5, type=int, help=
|
| 218 |
-
"minimum number of masks a training image must have to be used. Default: %(default)s"
|
| 219 |
-
)
|
| 220 |
-
training_args.add_argument("--SGD", default=0, type=int,
|
| 221 |
-
help="Deprecated in v4.0.1+, not used - AdamW used instead. ")
|
| 222 |
-
training_args.add_argument(
|
| 223 |
-
"--save_every", default=100, type=int,
|
| 224 |
-
help="number of epochs to skip between saves. Default: %(default)s")
|
| 225 |
-
training_args.add_argument(
|
| 226 |
-
"--save_each", action="store_true",
|
| 227 |
-
help="wether or not to save each epoch. Must also use --save_every. (default: False)")
|
| 228 |
-
training_args.add_argument(
|
| 229 |
-
"--model_name_out", default=None, type=str,
|
| 230 |
-
help="Name of model to save as, defaults to name describing model architecture. "
|
| 231 |
-
"Model is saved in the folder specified by --dir in models subfolder.")
|
| 232 |
-
|
| 233 |
-
# TODO: remove deprecated in future version
|
| 234 |
-
training_args.add_argument(
|
| 235 |
-
"--diam_mean", default=30., type=float, help=
|
| 236 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 237 |
-
training_args.add_argument("--train_size", action="store_true", help=
|
| 238 |
-
'Deprecated in v4.0.1+, not used. ')
|
| 239 |
-
|
| 240 |
-
return parser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/denoise.py
DELETED
|
@@ -1,1474 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
import os, time, datetime
|
| 5 |
-
import numpy as np
|
| 6 |
-
from scipy.stats import mode
|
| 7 |
-
import cv2
|
| 8 |
-
import torch
|
| 9 |
-
from torch import nn
|
| 10 |
-
from torch.nn.functional import conv2d, interpolate
|
| 11 |
-
from tqdm import trange
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
|
| 14 |
-
import logging
|
| 15 |
-
|
| 16 |
-
denoise_logger = logging.getLogger(__name__)
|
| 17 |
-
|
| 18 |
-
from cellpose import transforms, utils, io
|
| 19 |
-
from cellpose.core import run_net
|
| 20 |
-
from cellpose.models import CellposeModel, model_path, normalize_default, assign_device
|
| 21 |
-
|
| 22 |
-
MODEL_NAMES = []
|
| 23 |
-
for ctype in ["cyto3", "cyto2", "nuclei"]:
|
| 24 |
-
for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
|
| 25 |
-
MODEL_NAMES.append(f"{ntype}_{ctype}")
|
| 26 |
-
if ctype != "cyto3":
|
| 27 |
-
for ltype in ["per", "seg", "rec"]:
|
| 28 |
-
MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}")
|
| 29 |
-
if ctype != "cyto3":
|
| 30 |
-
MODEL_NAMES.append(f"aniso_{ctype}")
|
| 31 |
-
|
| 32 |
-
criterion = nn.MSELoss(reduction="mean")
|
| 33 |
-
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def deterministic(seed=0):
|
| 37 |
-
""" set random seeds to create test data """
|
| 38 |
-
import random
|
| 39 |
-
torch.manual_seed(seed)
|
| 40 |
-
torch.cuda.manual_seed(seed)
|
| 41 |
-
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
|
| 42 |
-
np.random.seed(seed) # Numpy module.
|
| 43 |
-
random.seed(seed) # Python random module.
|
| 44 |
-
torch.manual_seed(seed)
|
| 45 |
-
torch.backends.cudnn.benchmark = False
|
| 46 |
-
torch.backends.cudnn.deterministic = True
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def loss_fn_rec(lbl, y):
|
| 50 |
-
""" loss function between true labels lbl and prediction y """
|
| 51 |
-
loss = 80. * criterion(y, lbl)
|
| 52 |
-
return loss
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def loss_fn_seg(lbl, y):
|
| 56 |
-
""" loss function between true labels lbl and prediction y """
|
| 57 |
-
veci = 5. * lbl[:, 1:]
|
| 58 |
-
lbl = (lbl[:, 0] > .5).float()
|
| 59 |
-
loss = criterion(y[:, :2], veci)
|
| 60 |
-
loss /= 2.
|
| 61 |
-
loss2 = criterion2(y[:, 2], lbl)
|
| 62 |
-
loss = loss + loss2
|
| 63 |
-
return loss
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def get_sigma(Tdown):
|
| 67 |
-
""" Calculates the correlation matrices across channels for the perceptual loss.
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
Tdown (list): List of tensors output by each downsampling block of network.
|
| 71 |
-
|
| 72 |
-
Returns:
|
| 73 |
-
list: List of correlations for each input tensor.
|
| 74 |
-
"""
|
| 75 |
-
Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown]
|
| 76 |
-
Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm]
|
| 77 |
-
Sigma = [
|
| 78 |
-
torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1])
|
| 79 |
-
for x in Tnorm
|
| 80 |
-
]
|
| 81 |
-
return Sigma
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def imstats(X, net1):
|
| 85 |
-
"""
|
| 86 |
-
Calculates the image correlation matrices for the perceptual loss.
|
| 87 |
-
|
| 88 |
-
Args:
|
| 89 |
-
X (torch.Tensor): Input image tensor.
|
| 90 |
-
net1: Cellpose net.
|
| 91 |
-
|
| 92 |
-
Returns:
|
| 93 |
-
list: A list of tensors of correlation matrices.
|
| 94 |
-
"""
|
| 95 |
-
_, _, Tdown = net1(X)
|
| 96 |
-
Sigma = get_sigma(Tdown)
|
| 97 |
-
Sigma = [x.detach() for x in Sigma]
|
| 98 |
-
return Sigma
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def loss_fn_per(img, net1, yl):
|
| 102 |
-
"""
|
| 103 |
-
Calculates the perceptual loss function for image restoration.
|
| 104 |
-
|
| 105 |
-
Args:
|
| 106 |
-
img (torch.Tensor): Input image tensor (noisy/blurry/downsampled).
|
| 107 |
-
net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
|
| 108 |
-
yl (torch.Tensor): Clean image tensor.
|
| 109 |
-
|
| 110 |
-
Returns:
|
| 111 |
-
torch.Tensor: Mean perceptual loss.
|
| 112 |
-
"""
|
| 113 |
-
Sigma = imstats(img, net1)
|
| 114 |
-
sd = [x.std((1, 2)) + 1e-6 for x in Sigma]
|
| 115 |
-
Sigma_test = get_sigma(yl)
|
| 116 |
-
losses = torch.zeros(len(Sigma[0]), device=img.device)
|
| 117 |
-
for k in range(len(Sigma)):
|
| 118 |
-
losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2)
|
| 119 |
-
return losses.mean()
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
|
| 123 |
-
"""
|
| 124 |
-
Calculates the test loss for image restoration tasks.
|
| 125 |
-
|
| 126 |
-
Args:
|
| 127 |
-
net0 (torch.nn.Module): The image restoration network.
|
| 128 |
-
X (torch.Tensor): The input image tensor.
|
| 129 |
-
net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
|
| 130 |
-
img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
|
| 131 |
-
lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
|
| 132 |
-
lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
|
| 133 |
-
|
| 134 |
-
Returns:
|
| 135 |
-
tuple: A tuple containing the total loss and the perceptual loss.
|
| 136 |
-
"""
|
| 137 |
-
net0.eval()
|
| 138 |
-
if net1 is not None:
|
| 139 |
-
net1.eval()
|
| 140 |
-
loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
|
| 141 |
-
|
| 142 |
-
with torch.no_grad():
|
| 143 |
-
img_dn = net0(X)[0]
|
| 144 |
-
if lam[2] > 0.:
|
| 145 |
-
loss += lam[2] * loss_fn_rec(img, img_dn)
|
| 146 |
-
if lam[1] > 0. or lam[0] > 0.:
|
| 147 |
-
y, _, ydown = net1(img_dn)
|
| 148 |
-
if lam[1] > 0.:
|
| 149 |
-
loss += lam[1] * loss_fn_seg(lbl, y)
|
| 150 |
-
if lam[0] > 0.:
|
| 151 |
-
loss_per = loss_fn_per(img, net1, ydown)
|
| 152 |
-
loss += lam[0] * loss_per
|
| 153 |
-
return loss, loss_per
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
|
| 157 |
-
"""
|
| 158 |
-
Calculates the train loss for image restoration tasks.
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
net0 (torch.nn.Module): The image restoration network.
|
| 162 |
-
X (torch.Tensor): The input image tensor.
|
| 163 |
-
net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
|
| 164 |
-
img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
|
| 165 |
-
lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
|
| 166 |
-
lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
|
| 167 |
-
|
| 168 |
-
Returns:
|
| 169 |
-
tuple: A tuple containing the total loss and the perceptual loss.
|
| 170 |
-
"""
|
| 171 |
-
net0.train()
|
| 172 |
-
if net1 is not None:
|
| 173 |
-
net1.eval()
|
| 174 |
-
loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
|
| 175 |
-
|
| 176 |
-
img_dn = net0(X)[0]
|
| 177 |
-
if lam[2] > 0.:
|
| 178 |
-
loss += lam[2] * loss_fn_rec(img, img_dn)
|
| 179 |
-
if lam[1] > 0. or lam[0] > 0.:
|
| 180 |
-
y, _, ydown = net1(img_dn)
|
| 181 |
-
if lam[1] > 0.:
|
| 182 |
-
loss += lam[1] * loss_fn_seg(lbl, y)
|
| 183 |
-
if lam[0] > 0.:
|
| 184 |
-
loss_per = loss_fn_per(img, net1, ydown)
|
| 185 |
-
loss += lam[0] * loss_per
|
| 186 |
-
return loss, loss_per
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def img_norm(imgi):
|
| 190 |
-
"""
|
| 191 |
-
Normalizes the input image by subtracting the 1st percentile and dividing by the difference between the 99th and 1st percentiles.
|
| 192 |
-
|
| 193 |
-
Args:
|
| 194 |
-
imgi (torch.Tensor): Input image tensor.
|
| 195 |
-
|
| 196 |
-
Returns:
|
| 197 |
-
torch.Tensor: Normalized image tensor.
|
| 198 |
-
"""
|
| 199 |
-
shape = imgi.shape
|
| 200 |
-
imgi = imgi.reshape(imgi.shape[0], imgi.shape[1], -1)
|
| 201 |
-
perc = torch.quantile(imgi, torch.tensor([0.01, 0.99], device=imgi.device), dim=-1,
|
| 202 |
-
keepdim=True)
|
| 203 |
-
for k in range(imgi.shape[1]):
|
| 204 |
-
hask = (perc[1, :, k, 0] - perc[0, :, k, 0]) > 1e-3
|
| 205 |
-
imgi[hask, k] -= perc[0, hask, k]
|
| 206 |
-
imgi[hask, k] /= (perc[1, hask, k] - perc[0, hask, k])
|
| 207 |
-
imgi = imgi.reshape(shape)
|
| 208 |
-
return imgi
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
|
| 212 |
-
ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
|
| 213 |
-
ds=None, uniform_blur=False, partial_blur=False):
|
| 214 |
-
"""Adds noise to the input image.
|
| 215 |
-
|
| 216 |
-
Args:
|
| 217 |
-
lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
|
| 218 |
-
alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
|
| 219 |
-
beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
|
| 220 |
-
poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
|
| 221 |
-
blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
|
| 222 |
-
gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
|
| 223 |
-
downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
|
| 224 |
-
ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
|
| 225 |
-
diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
|
| 226 |
-
pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
|
| 227 |
-
iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
|
| 228 |
-
sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
|
| 229 |
-
sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
|
| 230 |
-
ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
|
| 231 |
-
|
| 232 |
-
Returns:
|
| 233 |
-
torch.Tensor: The noisy image tensor of the same shape as the input image.
|
| 234 |
-
"""
|
| 235 |
-
device = lbl.device
|
| 236 |
-
imgi = torch.zeros_like(lbl)
|
| 237 |
-
Ly, Lx = lbl.shape[-2:]
|
| 238 |
-
|
| 239 |
-
diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
|
| 240 |
-
#ds0 = 1 if ds is None else ds.item()
|
| 241 |
-
ds = ds * torch.ones(
|
| 242 |
-
(len(lbl),), device=device, dtype=torch.long) if ds is not None else ds
|
| 243 |
-
|
| 244 |
-
# downsample
|
| 245 |
-
ii = []
|
| 246 |
-
idownsample = np.random.rand(len(lbl)) < downsample
|
| 247 |
-
if (ds is None and idownsample.sum() > 0.) or not iso:
|
| 248 |
-
ds = torch.ones(len(lbl), dtype=torch.long, device=device)
|
| 249 |
-
ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
|
| 250 |
-
device=device)
|
| 251 |
-
ii = torch.nonzero(ds > 1).flatten()
|
| 252 |
-
elif ds is not None and (ds > 1).sum():
|
| 253 |
-
ii = torch.nonzero(ds > 1).flatten()
|
| 254 |
-
|
| 255 |
-
# add gaussian blur
|
| 256 |
-
iblur = torch.rand(len(lbl), device=device) < blur
|
| 257 |
-
iblur[ii] = True
|
| 258 |
-
if iblur.sum() > 0:
|
| 259 |
-
if sigma0 is None:
|
| 260 |
-
if uniform_blur and iso:
|
| 261 |
-
xr = torch.rand(len(lbl), device=device)
|
| 262 |
-
if len(ii) > 0:
|
| 263 |
-
xr[ii] = ds[ii].float() / 2. / gblur
|
| 264 |
-
sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
|
| 265 |
-
sigma1 = sigma0.clone()
|
| 266 |
-
elif not iso:
|
| 267 |
-
xr = torch.rand(len(lbl), device=device)
|
| 268 |
-
if len(ii) > 0:
|
| 269 |
-
xr[ii] = (ds[ii].float()) / gblur
|
| 270 |
-
xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
|
| 271 |
-
xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
|
| 272 |
-
sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
|
| 273 |
-
sigma1 = sigma0.clone() / 10.
|
| 274 |
-
else:
|
| 275 |
-
xrand = np.random.exponential(1, size=iblur.sum())
|
| 276 |
-
xrand = np.clip(xrand * 0.5, 0.1, 1.0)
|
| 277 |
-
xrand *= gblur
|
| 278 |
-
sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
|
| 279 |
-
device)
|
| 280 |
-
sigma1 = sigma0.clone()
|
| 281 |
-
else:
|
| 282 |
-
sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
|
| 283 |
-
sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)
|
| 284 |
-
|
| 285 |
-
# create gaussian filter
|
| 286 |
-
xr = max(8, sigma0.max().long() * 2)
|
| 287 |
-
gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
|
| 288 |
-
(2 * sigma0.unsqueeze(-1)**2))
|
| 289 |
-
gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
|
| 290 |
-
gfilt1 = torch.zeros_like(gfilt0)
|
| 291 |
-
gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
|
| 292 |
-
gfilt1[sigma1 != sigma0] = torch.exp(
|
| 293 |
-
-torch.arange(-xr + 1, xr, device=device)**2 /
|
| 294 |
-
(2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
|
| 295 |
-
gfilt1[sigma1 == 0] = 0.
|
| 296 |
-
gfilt1[sigma1 == 0, xr] = 1.
|
| 297 |
-
gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
|
| 298 |
-
gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
|
| 299 |
-
gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)
|
| 300 |
-
|
| 301 |
-
lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
|
| 302 |
-
padding=gfilt.shape[-1] // 2,
|
| 303 |
-
groups=gfilt.shape[0]).transpose(1, 0)
|
| 304 |
-
if partial_blur:
|
| 305 |
-
#yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
|
| 306 |
-
imgi[iblur] = lbl[iblur].clone()
|
| 307 |
-
Lxc = int(Lx * 0.85)
|
| 308 |
-
ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
|
| 309 |
-
torch.arange(0, Lxc, dtype=torch.float32),
|
| 310 |
-
indexing="ij")
|
| 311 |
-
mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
|
| 312 |
-
mask -= mask.min()
|
| 313 |
-
mask /= mask.max()
|
| 314 |
-
lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
|
| 315 |
-
imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
|
| 316 |
-
(1-mask) * imgi[iblur, :, :, :Lxc])
|
| 317 |
-
else:
|
| 318 |
-
imgi[iblur] = lbl_blur
|
| 319 |
-
|
| 320 |
-
imgi[~iblur] = lbl[~iblur]
|
| 321 |
-
|
| 322 |
-
# apply downsample
|
| 323 |
-
for k in ii:
|
| 324 |
-
i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
|
| 325 |
-
imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")
|
| 326 |
-
|
| 327 |
-
# add poisson noise
|
| 328 |
-
ipoisson = np.random.rand(len(lbl)) < poisson
|
| 329 |
-
if ipoisson.sum() > 0:
|
| 330 |
-
if pscale is None:
|
| 331 |
-
pscale = torch.zeros(len(lbl))
|
| 332 |
-
m = torch.distributions.gamma.Gamma(alpha, beta)
|
| 333 |
-
pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
|
| 334 |
-
#pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
|
| 335 |
-
pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
|
| 336 |
-
else:
|
| 337 |
-
pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
|
| 338 |
-
imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
|
| 339 |
-
imgi[~ipoisson] = imgi[~ipoisson]
|
| 340 |
-
|
| 341 |
-
# renormalize
|
| 342 |
-
imgi = img_norm(imgi)
|
| 343 |
-
|
| 344 |
-
return imgi
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
|
| 348 |
-
downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
|
| 349 |
-
ds_max=7, uniform_blur=False, iso=True, rotate=True,
|
| 350 |
-
device=torch.device("cuda"), xy=(224, 224),
|
| 351 |
-
nchan_noise=1, keep_raw=True):
|
| 352 |
-
"""
|
| 353 |
-
Applies random rotation, resizing, and noise to the input data.
|
| 354 |
-
|
| 355 |
-
Args:
|
| 356 |
-
data (numpy.ndarray): The input data.
|
| 357 |
-
labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
|
| 358 |
-
diams (float, optional): The diameter of the objects. Defaults to None.
|
| 359 |
-
poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
|
| 360 |
-
blur (float, optional): The blur probability. Defaults to 0.7.
|
| 361 |
-
downsample (float, optional): The downsample probability. Defaults to 0.0.
|
| 362 |
-
beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
|
| 363 |
-
gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
|
| 364 |
-
diam_mean (float, optional): The mean diameter. Defaults to 30.
|
| 365 |
-
ds_max (int, optional): The maximum downsample value. Defaults to 7.
|
| 366 |
-
iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
|
| 367 |
-
rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
|
| 368 |
-
device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
|
| 369 |
-
xy (tuple, optional): The size of the output image. Defaults to (224, 224).
|
| 370 |
-
nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
|
| 371 |
-
keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.
|
| 372 |
-
|
| 373 |
-
Returns:
|
| 374 |
-
torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
|
| 375 |
-
torch.Tensor: The augmented labels.
|
| 376 |
-
float: The scale factor applied to the image.
|
| 377 |
-
"""
|
| 378 |
-
if device == None:
|
| 379 |
-
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
|
| 380 |
-
|
| 381 |
-
diams = 30 if diams is None else diams
|
| 382 |
-
random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
|
| 383 |
-
random_rsc = diams / random_diam #/ random_diam
|
| 384 |
-
#rsc /= random_scale
|
| 385 |
-
xy0 = (340, 340)
|
| 386 |
-
nchan = data[0].shape[0]
|
| 387 |
-
data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
|
| 388 |
-
labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
|
| 389 |
-
for i in range(
|
| 390 |
-
len(data)): #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)):
|
| 391 |
-
sc = random_rsc[i]
|
| 392 |
-
img = data[i]
|
| 393 |
-
lbl = labels[i] if labels is not None else None
|
| 394 |
-
# create affine transform to resize
|
| 395 |
-
Ly, Lx = img.shape[-2:]
|
| 396 |
-
dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
|
| 397 |
-
dxy = (np.random.rand(2,) - .5) * dxy
|
| 398 |
-
cc = np.array([Lx / 2, Ly / 2])
|
| 399 |
-
cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
|
| 400 |
-
pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
|
| 401 |
-
pts2 = np.float32(
|
| 402 |
-
[cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
|
| 403 |
-
M = cv2.getAffineTransform(pts1, pts2)
|
| 404 |
-
|
| 405 |
-
# apply to image
|
| 406 |
-
for c in range(nchan):
|
| 407 |
-
img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
|
| 408 |
-
#img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0)
|
| 409 |
-
data_new[i, c] = img_rsz
|
| 410 |
-
if keep_raw:
|
| 411 |
-
data_new[i, c + nchan] = img_rsz
|
| 412 |
-
|
| 413 |
-
if lbl is not None:
|
| 414 |
-
# apply to labels
|
| 415 |
-
labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
|
| 416 |
-
labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
|
| 417 |
-
labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)
|
| 418 |
-
|
| 419 |
-
rsc = random_diam / diam_mean
|
| 420 |
-
|
| 421 |
-
# add noise before augmentations
|
| 422 |
-
img = torch.from_numpy(data_new).to(device)
|
| 423 |
-
img = torch.clamp(img, 0.)
|
| 424 |
-
# just add noise to cyto if nchan_noise=1
|
| 425 |
-
img[:, :nchan_noise] = add_noise(
|
| 426 |
-
img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
|
| 427 |
-
downsample=downsample, beta=beta, gblur=gblur,
|
| 428 |
-
diams=torch.from_numpy(random_diam).to(device).float())
|
| 429 |
-
# img -= img.mean(dim=(-2,-1), keepdim=True)
|
| 430 |
-
# img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3
|
| 431 |
-
img = img.cpu().numpy()
|
| 432 |
-
|
| 433 |
-
# augmentations
|
| 434 |
-
img, lbl, scale = transforms.random_rotate_and_resize(
|
| 435 |
-
img,
|
| 436 |
-
Y=labels_new,
|
| 437 |
-
xy=xy,
|
| 438 |
-
rotate=False if not iso else rotate,
|
| 439 |
-
#(iso and downsample==0),
|
| 440 |
-
rescale=rsc,
|
| 441 |
-
scale_range=0.5)
|
| 442 |
-
img = torch.from_numpy(img).to(device)
|
| 443 |
-
lbl = torch.from_numpy(lbl).to(device)
|
| 444 |
-
|
| 445 |
-
return img, lbl, scale
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
|
| 449 |
-
"""
|
| 450 |
-
Creates a Cellpose network with a single input channel.
|
| 451 |
-
|
| 452 |
-
Args:
|
| 453 |
-
device (str): The device to run the network on.
|
| 454 |
-
model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
|
| 455 |
-
pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.
|
| 456 |
-
|
| 457 |
-
Returns:
|
| 458 |
-
torch.nn.Module: The Cellpose network with a single input channel.
|
| 459 |
-
"""
|
| 460 |
-
if pretrained_model is not None and not os.path.exists(pretrained_model):
|
| 461 |
-
model_type = pretrained_model
|
| 462 |
-
pretrained_model = None
|
| 463 |
-
nbase = [32, 64, 128, 256]
|
| 464 |
-
nchan = 1
|
| 465 |
-
net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
|
| 466 |
-
filename = model_path(model_type,
|
| 467 |
-
0) if pretrained_model is None else pretrained_model
|
| 468 |
-
weights = torch.load(filename, weights_only=True)
|
| 469 |
-
zp = 0
|
| 470 |
-
print(filename)
|
| 471 |
-
for name in net1.state_dict():
|
| 472 |
-
if ("res_down_0.conv.conv_0" not in name and
|
| 473 |
-
#"output" not in name and
|
| 474 |
-
"res_down_0.proj" not in name and name != "diam_mean" and
|
| 475 |
-
name != "diam_labels"):
|
| 476 |
-
net1.state_dict()[name].copy_(weights[name])
|
| 477 |
-
elif "res_down_0" in name:
|
| 478 |
-
if len(weights[name].shape) > 0:
|
| 479 |
-
new_weight = torch.zeros_like(net1.state_dict()[name])
|
| 480 |
-
if weights[name].shape[0] == 2:
|
| 481 |
-
new_weight[:] = weights[name][0]
|
| 482 |
-
elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2:
|
| 483 |
-
new_weight[:, zp] = weights[name][:, 0]
|
| 484 |
-
else:
|
| 485 |
-
new_weight = weights[name]
|
| 486 |
-
else:
|
| 487 |
-
new_weight = weights[name]
|
| 488 |
-
net1.state_dict()[name].copy_(new_weight)
|
| 489 |
-
return net1
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
class CellposeDenoiseModel():
|
| 493 |
-
""" model to run Cellpose and Image restoration """
|
| 494 |
-
|
| 495 |
-
def __init__(self, gpu=False, pretrained_model=False, model_type=None,
|
| 496 |
-
restore_type="denoise_cyto3", nchan=2,
|
| 497 |
-
chan2_restore=False, device=None):
|
| 498 |
-
|
| 499 |
-
self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
|
| 500 |
-
device=device)
|
| 501 |
-
self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
|
| 502 |
-
pretrained_model=pretrained_model, device=device)
|
| 503 |
-
|
| 504 |
-
def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
|
| 505 |
-
normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
|
| 506 |
-
augment=False, resample=True, invert=False, flow_threshold=0.4,
|
| 507 |
-
cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
|
| 508 |
-
min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
|
| 509 |
-
"""
|
| 510 |
-
Restore array or list of images using the image restoration model, and then segment.
|
| 511 |
-
|
| 512 |
-
Args:
|
| 513 |
-
x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
|
| 514 |
-
batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
|
| 515 |
-
(can make smaller or bigger depending on GPU memory usage). Defaults to 8.
|
| 516 |
-
channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
|
| 517 |
-
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
|
| 518 |
-
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
|
| 519 |
-
For instance, to segment grayscale images, input [0,0]. To segment images with cells
|
| 520 |
-
in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
|
| 521 |
-
image with cells in green and nuclei in blue, input [[0,0], [2,3]].
|
| 522 |
-
Defaults to None.
|
| 523 |
-
channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
|
| 524 |
-
if None, channels dimension is attempted to be automatically determined. Defaults to None.
|
| 525 |
-
z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
|
| 526 |
-
if None, z dimension is attempted to be automatically determined. Defaults to None.
|
| 527 |
-
normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
|
| 528 |
-
can also pass dictionary of parameters (all keys are optional, default values shown):
|
| 529 |
-
- "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
|
| 530 |
-
- "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
|
| 531 |
-
- "normalize"=True ; run normalization (if False, all following parameters ignored)
|
| 532 |
-
- "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
|
| 533 |
-
- "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
|
| 534 |
-
- "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
|
| 535 |
-
Defaults to True.
|
| 536 |
-
rescale (float, optional): resize factor for each image, if None, set to 1.0;
|
| 537 |
-
(only used if diameter is None). Defaults to None.
|
| 538 |
-
diameter (float, optional): diameter for each image,
|
| 539 |
-
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
|
| 540 |
-
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
|
| 541 |
-
augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
|
| 542 |
-
resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
|
| 543 |
-
invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
|
| 544 |
-
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
|
| 545 |
-
cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
|
| 546 |
-
do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
|
| 547 |
-
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
|
| 548 |
-
stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
|
| 549 |
-
min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
|
| 550 |
-
flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
|
| 551 |
-
niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
|
| 552 |
-
interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
|
| 553 |
-
|
| 554 |
-
Returns:
|
| 555 |
-
A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
|
| 556 |
-
flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
|
| 557 |
-
styles: style vector summarizing each image of size 256;
|
| 558 |
-
imgs: Restored images.
|
| 559 |
-
"""
|
| 560 |
-
|
| 561 |
-
if isinstance(normalize, dict):
|
| 562 |
-
normalize_params = {**normalize_default, **normalize}
|
| 563 |
-
elif not isinstance(normalize, bool):
|
| 564 |
-
raise ValueError("normalize parameter must be a bool or a dict")
|
| 565 |
-
else:
|
| 566 |
-
normalize_params = normalize_default
|
| 567 |
-
normalize_params["normalize"] = normalize
|
| 568 |
-
normalize_params["invert"] = invert
|
| 569 |
-
|
| 570 |
-
img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
|
| 571 |
-
channel_axis=channel_axis, z_axis=z_axis,
|
| 572 |
-
do_3D=do_3D,
|
| 573 |
-
normalize=normalize_params, rescale=rescale,
|
| 574 |
-
diameter=diameter,
|
| 575 |
-
tile_overlap=tile_overlap, bsize=bsize)
|
| 576 |
-
|
| 577 |
-
# turn off special normalization for segmentation
|
| 578 |
-
normalize_params = normalize_default
|
| 579 |
-
|
| 580 |
-
# change channels for segmentation
|
| 581 |
-
if channels is not None:
|
| 582 |
-
channels_new = [0, 0] if channels[0] == 0 else [1, 2]
|
| 583 |
-
else:
|
| 584 |
-
channels_new = None
|
| 585 |
-
# change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
|
| 586 |
-
diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
|
| 587 |
-
masks, flows, styles = self.cp.eval(
|
| 588 |
-
img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
|
| 589 |
-
z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
|
| 590 |
-
normalize=normalize_params, rescale=rescale, diameter=diameter,
|
| 591 |
-
tile_overlap=tile_overlap, augment=augment, resample=resample,
|
| 592 |
-
invert=invert, flow_threshold=flow_threshold,
|
| 593 |
-
cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
|
| 594 |
-
stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
|
| 595 |
-
interp=interp, bsize=bsize)
|
| 596 |
-
|
| 597 |
-
return masks, flows, styles, img_restore
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
class DenoiseModel():
|
| 601 |
-
"""
|
| 602 |
-
DenoiseModel class for denoising images using Cellpose denoising model.
|
| 603 |
-
|
| 604 |
-
Args:
|
| 605 |
-
gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
|
| 606 |
-
pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
|
| 607 |
-
Can be a string or path. Defaults to False.
|
| 608 |
-
nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1.
|
| 609 |
-
model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
|
| 610 |
-
chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
|
| 611 |
-
diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
|
| 612 |
-
device (torch.device, optional): Device to use for computation. Defaults to None.
|
| 613 |
-
|
| 614 |
-
Attributes:
|
| 615 |
-
nchan (int): Number of channels in the input images.
|
| 616 |
-
diam_mean (float): Mean diameter of the objects in the images.
|
| 617 |
-
net (CPnet): Cellpose network for denoising.
|
| 618 |
-
pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
|
| 619 |
-
net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
|
| 620 |
-
net_type (str): Type of the denoising network.
|
| 621 |
-
|
| 622 |
-
Methods:
|
| 623 |
-
eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
|
| 624 |
-
normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1)
|
| 625 |
-
Denoise array or list of images using the denoising model.
|
| 626 |
-
|
| 627 |
-
_eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
|
| 628 |
-
tile_overlap=0.1)
|
| 629 |
-
Run denoising model on a single channel.
|
| 630 |
-
"""
|
| 631 |
-
|
| 632 |
-
def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
|
| 633 |
-
chan2=False, diam_mean=30., device=None):
|
| 634 |
-
self.nchan = nchan
|
| 635 |
-
if pretrained_model and (not isinstance(pretrained_model, str) and
|
| 636 |
-
not isinstance(pretrained_model, Path)):
|
| 637 |
-
raise ValueError("pretrained_model must be a string or path")
|
| 638 |
-
|
| 639 |
-
self.diam_mean = diam_mean
|
| 640 |
-
builtin = True
|
| 641 |
-
if model_type is not None or (pretrained_model and
|
| 642 |
-
not os.path.exists(pretrained_model)):
|
| 643 |
-
pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
|
| 644 |
-
if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
|
| 645 |
-
pretrained_model_string = "denoise_cyto3"
|
| 646 |
-
pretrained_model = model_path(pretrained_model_string)
|
| 647 |
-
if (pretrained_model and not os.path.exists(pretrained_model)):
|
| 648 |
-
denoise_logger.warning("pretrained model has incorrect path")
|
| 649 |
-
denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
|
| 650 |
-
self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
|
| 651 |
-
else:
|
| 652 |
-
if pretrained_model:
|
| 653 |
-
builtin = False
|
| 654 |
-
pretrained_model_string = pretrained_model
|
| 655 |
-
denoise_logger.info(f">>>> loading model {pretrained_model_string}")
|
| 656 |
-
|
| 657 |
-
# assign network device
|
| 658 |
-
if device is None:
|
| 659 |
-
sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
|
| 660 |
-
self.device = device if device is not None else sdevice
|
| 661 |
-
if device is not None:
|
| 662 |
-
device_gpu = self.device.type == "cuda"
|
| 663 |
-
self.gpu = gpu if device is None else device_gpu
|
| 664 |
-
|
| 665 |
-
# create network
|
| 666 |
-
self.nchan = nchan
|
| 667 |
-
self.nclasses = 1
|
| 668 |
-
nbase = [32, 64, 128, 256]
|
| 669 |
-
self.nchan = nchan
|
| 670 |
-
self.nbase = [nchan, *nbase]
|
| 671 |
-
|
| 672 |
-
self.net = CPnet(self.nbase, self.nclasses, sz=3,
|
| 673 |
-
max_pool=True, diam_mean=diam_mean).to(self.device)
|
| 674 |
-
|
| 675 |
-
self.pretrained_model = pretrained_model
|
| 676 |
-
self.net_chan2 = None
|
| 677 |
-
if self.pretrained_model:
|
| 678 |
-
self.net.load_model(self.pretrained_model, device=self.device)
|
| 679 |
-
denoise_logger.info(
|
| 680 |
-
f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
|
| 681 |
-
)
|
| 682 |
-
if chan2 and builtin:
|
| 683 |
-
chan2_path = model_path(
|
| 684 |
-
os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
|
| 685 |
-
print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
|
| 686 |
-
self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
|
| 687 |
-
max_pool=True,
|
| 688 |
-
diam_mean=17.).to(self.device)
|
| 689 |
-
self.net_chan2.load_model(chan2_path, device=self.device)
|
| 690 |
-
self.net_type = "cellpose_denoise"
|
| 691 |
-
|
| 692 |
-
def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
|
| 693 |
-
normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
|
| 694 |
-
tile_overlap=0.1, bsize=224):
|
| 695 |
-
"""
|
| 696 |
-
Restore array or list of images using the image restoration model.
|
| 697 |
-
|
| 698 |
-
Args:
|
| 699 |
-
x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
|
| 700 |
-
batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
|
| 701 |
-
(can make smaller or bigger depending on GPU memory usage). Defaults to 8.
|
| 702 |
-
channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
|
| 703 |
-
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
|
| 704 |
-
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
|
| 705 |
-
For instance, to segment grayscale images, input [0,0]. To segment images with cells
|
| 706 |
-
in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
|
| 707 |
-
image with cells in green and nuclei in blue, input [[0,0], [2,3]].
|
| 708 |
-
Defaults to None.
|
| 709 |
-
channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
|
| 710 |
-
if None, channels dimension is attempted to be automatically determined. Defaults to None.
|
| 711 |
-
z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
|
| 712 |
-
if None, z dimension is attempted to be automatically determined. Defaults to None.
|
| 713 |
-
normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
|
| 714 |
-
can also pass dictionary of parameters (all keys are optional, default values shown):
|
| 715 |
-
- "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
|
| 716 |
-
- "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
|
| 717 |
-
- "normalize"=True ; run normalization (if False, all following parameters ignored)
|
| 718 |
-
- "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
|
| 719 |
-
- "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
|
| 720 |
-
- "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
|
| 721 |
-
Defaults to True.
|
| 722 |
-
rescale (float, optional): resize factor for each image, if None, set to 1.0;
|
| 723 |
-
(only used if diameter is None). Defaults to None.
|
| 724 |
-
diameter (float, optional): diameter for each image,
|
| 725 |
-
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
|
| 726 |
-
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
|
| 727 |
-
|
| 728 |
-
Returns:
|
| 729 |
-
list: A list of 2D/3D arrays of restored images
|
| 730 |
-
|
| 731 |
-
"""
|
| 732 |
-
if isinstance(x, list) or x.squeeze().ndim == 5:
|
| 733 |
-
tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
|
| 734 |
-
nimg = len(x)
|
| 735 |
-
iterator = trange(nimg, file=tqdm_out,
|
| 736 |
-
mininterval=30) if nimg > 1 else range(nimg)
|
| 737 |
-
imgs = []
|
| 738 |
-
for i in iterator:
|
| 739 |
-
imgi = self.eval(
|
| 740 |
-
x[i], batch_size=batch_size,
|
| 741 |
-
channels=channels[i] if channels is not None and
|
| 742 |
-
((len(channels) == len(x) and
|
| 743 |
-
(isinstance(channels[i], list) or
|
| 744 |
-
isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
|
| 745 |
-
else channels, channel_axis=channel_axis, z_axis=z_axis,
|
| 746 |
-
normalize=normalize,
|
| 747 |
-
do_3D=do_3D,
|
| 748 |
-
rescale=rescale[i] if isinstance(rescale, list) or
|
| 749 |
-
isinstance(rescale, np.ndarray) else rescale,
|
| 750 |
-
diameter=diameter[i] if isinstance(diameter, list) or
|
| 751 |
-
isinstance(diameter, np.ndarray) else diameter,
|
| 752 |
-
tile_overlap=tile_overlap, bsize=bsize)
|
| 753 |
-
imgs.append(imgi)
|
| 754 |
-
if isinstance(x, np.ndarray):
|
| 755 |
-
imgs = np.array(imgs)
|
| 756 |
-
return imgs
|
| 757 |
-
|
| 758 |
-
else:
|
| 759 |
-
# reshape image
|
| 760 |
-
x = transforms.convert_image(x, channels, channel_axis=channel_axis,
|
| 761 |
-
z_axis=z_axis, do_3D=do_3D, nchan=None)
|
| 762 |
-
if x.ndim < 4:
|
| 763 |
-
squeeze = True
|
| 764 |
-
x = x[np.newaxis, ...]
|
| 765 |
-
else:
|
| 766 |
-
squeeze = False
|
| 767 |
-
|
| 768 |
-
# may need to interpolate image before running upsampling
|
| 769 |
-
self.ratio = 1.
|
| 770 |
-
if "upsample" in self.pretrained_model:
|
| 771 |
-
Ly, Lx = x.shape[-3:-1]
|
| 772 |
-
if diameter is not None and 3 <= diameter < self.diam_mean:
|
| 773 |
-
self.ratio = self.diam_mean / diameter
|
| 774 |
-
denoise_logger.info(
|
| 775 |
-
f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
|
| 776 |
-
)
|
| 777 |
-
Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
|
| 778 |
-
x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
|
| 779 |
-
else:
|
| 780 |
-
denoise_logger.warning(
|
| 781 |
-
f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
|
| 782 |
-
)
|
| 783 |
-
#raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}")
|
| 784 |
-
|
| 785 |
-
self.batch_size = batch_size
|
| 786 |
-
|
| 787 |
-
if diameter is not None and diameter > 0:
|
| 788 |
-
rescale = self.diam_mean / diameter
|
| 789 |
-
elif rescale is None:
|
| 790 |
-
rescale = 1.0
|
| 791 |
-
|
| 792 |
-
if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
|
| 793 |
-
x = x[..., :1]
|
| 794 |
-
|
| 795 |
-
for c in range(x.shape[-1]):
|
| 796 |
-
rescale0 = rescale * 30. / 17. if c == 1 else rescale
|
| 797 |
-
if c == 0 or self.net_chan2 is None:
|
| 798 |
-
x[...,
|
| 799 |
-
c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size,
|
| 800 |
-
normalize=normalize, rescale=rescale0,
|
| 801 |
-
tile_overlap=tile_overlap, bsize=bsize)[...,0]
|
| 802 |
-
else:
|
| 803 |
-
x[...,
|
| 804 |
-
c] = self._eval(self.net_chan2, x[...,
|
| 805 |
-
c:c + 1], batch_size=batch_size,
|
| 806 |
-
normalize=normalize, rescale=rescale0,
|
| 807 |
-
tile_overlap=tile_overlap, bsize=bsize)[...,0]
|
| 808 |
-
x = x[0] if squeeze else x
|
| 809 |
-
return x
|
| 810 |
-
|
| 811 |
-
def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
|
| 812 |
-
tile_overlap=0.1, bsize=224):
|
| 813 |
-
"""
|
| 814 |
-
Run image restoration model on a single channel.
|
| 815 |
-
|
| 816 |
-
Args:
|
| 817 |
-
x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
|
| 818 |
-
batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
|
| 819 |
-
(can make smaller or bigger depending on GPU memory usage). Defaults to 8.
|
| 820 |
-
normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
|
| 821 |
-
can also pass dictionary of parameters (all keys are optional, default values shown):
|
| 822 |
-
- "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
|
| 823 |
-
- "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
|
| 824 |
-
- "normalize"=True ; run normalization (if False, all following parameters ignored)
|
| 825 |
-
- "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
|
| 826 |
-
- "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
|
| 827 |
-
- "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
|
| 828 |
-
Defaults to True.
|
| 829 |
-
rescale (float, optional): resize factor for each image, if None, set to 1.0;
|
| 830 |
-
(only used if diameter is None). Defaults to None.
|
| 831 |
-
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
|
| 832 |
-
|
| 833 |
-
Returns:
|
| 834 |
-
list: A list of 2D/3D arrays of restored images
|
| 835 |
-
|
| 836 |
-
"""
|
| 837 |
-
if isinstance(normalize, dict):
|
| 838 |
-
normalize_params = {**normalize_default, **normalize}
|
| 839 |
-
elif not isinstance(normalize, bool):
|
| 840 |
-
raise ValueError("normalize parameter must be a bool or a dict")
|
| 841 |
-
else:
|
| 842 |
-
normalize_params = normalize_default
|
| 843 |
-
normalize_params["normalize"] = normalize
|
| 844 |
-
|
| 845 |
-
tic = time.time()
|
| 846 |
-
shape = x.shape
|
| 847 |
-
nimg = shape[0]
|
| 848 |
-
|
| 849 |
-
do_normalization = True if normalize_params["normalize"] else False
|
| 850 |
-
|
| 851 |
-
img = np.asarray(x)
|
| 852 |
-
if do_normalization:
|
| 853 |
-
img = transforms.normalize_img(img, **normalize_params)
|
| 854 |
-
if rescale != 1.0:
|
| 855 |
-
img = transforms.resize_image(img, rsz=rescale)
|
| 856 |
-
yf, style = run_net(self.net, img, bsize=bsize,
|
| 857 |
-
tile_overlap=tile_overlap)
|
| 858 |
-
yf = transforms.resize_image(yf, shape[1], shape[2])
|
| 859 |
-
imgs = yf
|
| 860 |
-
del yf, style
|
| 861 |
-
|
| 862 |
-
# imgs = np.zeros((*x.shape[:-1], 1), np.float32)
|
| 863 |
-
# for i in iterator:
|
| 864 |
-
# img = np.asarray(x[i])
|
| 865 |
-
# if do_normalization:
|
| 866 |
-
# img = transforms.normalize_img(img, **normalize_params)
|
| 867 |
-
# if rescale != 1.0:
|
| 868 |
-
# img = transforms.resize_image(img, rsz=[rescale, rescale])
|
| 869 |
-
# if img.ndim == 2:
|
| 870 |
-
# img = img[:, :, np.newaxis]
|
| 871 |
-
# yf, style = run_net(net, img, batch_size=batch_size, augment=False,
|
| 872 |
-
# tile=tile, tile_overlap=tile_overlap, bsize=bsize)
|
| 873 |
-
# img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2])
|
| 874 |
-
|
| 875 |
-
# if img.ndim == 2:
|
| 876 |
-
# img = img[:, :, np.newaxis]
|
| 877 |
-
# imgs[i] = img
|
| 878 |
-
# del yf, style
|
| 879 |
-
net_time = time.time() - tic
|
| 880 |
-
if nimg > 1:
|
| 881 |
-
denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
|
| 882 |
-
|
| 883 |
-
return imgs
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
|
| 887 |
-
test_labels=None, test_files=None, train_probs=None, test_probs=None,
|
| 888 |
-
lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
|
| 889 |
-
save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
|
| 890 |
-
iso=True, uniform_blur=False, downsample=0., ds_max=7,
|
| 891 |
-
learning_rate=0.005, n_epochs=500,
|
| 892 |
-
weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
|
| 893 |
-
nimg_test_per_epoch=None, model_name=None):
|
| 894 |
-
|
| 895 |
-
# net properties
|
| 896 |
-
device = net.device
|
| 897 |
-
nchan = net.nchan
|
| 898 |
-
diam_mean = net.diam_mean.item()
|
| 899 |
-
|
| 900 |
-
args = np.array([poisson, beta, blur, gblur, downsample])
|
| 901 |
-
if args.ndim == 1:
|
| 902 |
-
args = args[:, np.newaxis]
|
| 903 |
-
poisson, beta, blur, gblur, downsample = args
|
| 904 |
-
nnoise = len(poisson)
|
| 905 |
-
|
| 906 |
-
d = datetime.datetime.now()
|
| 907 |
-
if save_path is not None:
|
| 908 |
-
if model_name is None:
|
| 909 |
-
filename = ""
|
| 910 |
-
lstrs = ["per", "seg", "rec"]
|
| 911 |
-
for k, (l, s) in enumerate(zip(lam, lstrs)):
|
| 912 |
-
filename += f"{s}_{l:.2f}_"
|
| 913 |
-
if not iso:
|
| 914 |
-
filename += "aniso_"
|
| 915 |
-
if poisson.sum() > 0:
|
| 916 |
-
filename += "poisson_"
|
| 917 |
-
if blur.sum() > 0:
|
| 918 |
-
filename += "blur_"
|
| 919 |
-
if downsample.sum() > 0:
|
| 920 |
-
filename += "downsample_"
|
| 921 |
-
filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
|
| 922 |
-
filename = os.path.join(save_path, filename)
|
| 923 |
-
else:
|
| 924 |
-
filename = os.path.join(save_path, model_name)
|
| 925 |
-
print(filename)
|
| 926 |
-
for i in range(len(poisson)):
|
| 927 |
-
denoise_logger.info(
|
| 928 |
-
f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
|
| 929 |
-
)
|
| 930 |
-
net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)
|
| 931 |
-
|
| 932 |
-
learning_rate_const = learning_rate
|
| 933 |
-
LR = np.linspace(0, learning_rate_const, 10)
|
| 934 |
-
LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100))
|
| 935 |
-
for i in range(10):
|
| 936 |
-
LR = np.append(LR, LR[-1] / 2 * np.ones(10))
|
| 937 |
-
learning_rate = LR
|
| 938 |
-
|
| 939 |
-
batch_size = 8
|
| 940 |
-
optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
|
| 941 |
-
weight_decay=weight_decay)
|
| 942 |
-
if train_data is not None:
|
| 943 |
-
nimg = len(train_data)
|
| 944 |
-
diam_train = np.array(
|
| 945 |
-
[utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
|
| 946 |
-
diam_train[diam_train < 5] = 5.
|
| 947 |
-
if test_data is not None:
|
| 948 |
-
diam_test = np.array(
|
| 949 |
-
[utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
|
| 950 |
-
diam_test[diam_test < 5] = 5.
|
| 951 |
-
nimg_test = len(test_data)
|
| 952 |
-
else:
|
| 953 |
-
nimg = len(train_files)
|
| 954 |
-
denoise_logger.info(">>> using files instead of loading dataset")
|
| 955 |
-
train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
|
| 956 |
-
denoise_logger.info(">>> computing diameters")
|
| 957 |
-
diam_train = np.array([
|
| 958 |
-
utils.diameters(io.imread(train_labels_files[k])[0])[0]
|
| 959 |
-
for k in trange(len(train_labels_files))
|
| 960 |
-
])
|
| 961 |
-
diam_train[diam_train < 5] = 5.
|
| 962 |
-
if test_files is not None:
|
| 963 |
-
nimg_test = len(test_files)
|
| 964 |
-
test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
|
| 965 |
-
diam_test = np.array([
|
| 966 |
-
utils.diameters(io.imread(test_labels_files[k])[0])[0]
|
| 967 |
-
for k in trange(len(test_labels_files))
|
| 968 |
-
])
|
| 969 |
-
diam_test[diam_test < 5] = 5.
|
| 970 |
-
train_probs = 1. / nimg * np.ones(nimg,
|
| 971 |
-
"float64") if train_probs is None else train_probs
|
| 972 |
-
if test_files is not None or test_data is not None:
|
| 973 |
-
test_probs = 1. / nimg_test * np.ones(
|
| 974 |
-
nimg_test, "float64") if test_probs is None else test_probs
|
| 975 |
-
|
| 976 |
-
tic = time.time()
|
| 977 |
-
|
| 978 |
-
nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
|
| 979 |
-
if test_files is not None or test_data is not None:
|
| 980 |
-
nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
|
| 981 |
-
|
| 982 |
-
nbatch = 0
|
| 983 |
-
train_losses, test_losses = [], []
|
| 984 |
-
for iepoch in range(n_epochs):
|
| 985 |
-
np.random.seed(iepoch)
|
| 986 |
-
rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
|
| 987 |
-
p=train_probs)
|
| 988 |
-
torch.manual_seed(iepoch)
|
| 989 |
-
np.random.seed(iepoch)
|
| 990 |
-
for param_group in optimizer.param_groups:
|
| 991 |
-
param_group["lr"] = learning_rate[iepoch]
|
| 992 |
-
lavg, lavg_per, nsum = 0, 0, 0
|
| 993 |
-
for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
|
| 994 |
-
inds = rperm[ibatch : ibatch + batch_size * nnoise]
|
| 995 |
-
if train_data is None:
|
| 996 |
-
imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
|
| 997 |
-
lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
|
| 998 |
-
else:
|
| 999 |
-
imgs = [train_data[i][:nchan] for i in inds]
|
| 1000 |
-
lbls = [train_labels[i][1:] for i in inds]
|
| 1001 |
-
#inoise = nbatch % nnoise
|
| 1002 |
-
rnoise = np.random.permutation(nnoise)
|
| 1003 |
-
for i, inoise in enumerate(rnoise):
|
| 1004 |
-
if i * batch_size < len(imgs):
|
| 1005 |
-
imgi, lbli, scale = random_rotate_and_resize_noise(
|
| 1006 |
-
imgs[i * batch_size : (i + 1) * batch_size],
|
| 1007 |
-
lbls[i * batch_size : (i + 1) * batch_size],
|
| 1008 |
-
diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
|
| 1009 |
-
poisson=poisson[inoise],
|
| 1010 |
-
beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
|
| 1011 |
-
downsample=downsample[inoise], uniform_blur=uniform_blur,
|
| 1012 |
-
diam_mean=diam_mean, ds_max=ds_max,
|
| 1013 |
-
device=device)
|
| 1014 |
-
if i == 0:
|
| 1015 |
-
img = imgi
|
| 1016 |
-
lbl = lbli
|
| 1017 |
-
else:
|
| 1018 |
-
img = torch.cat((img, imgi), axis=0)
|
| 1019 |
-
lbl = torch.cat((lbl, lbli), axis=0)
|
| 1020 |
-
|
| 1021 |
-
if nnoise > 0:
|
| 1022 |
-
iperm = np.random.permutation(img.shape[0])
|
| 1023 |
-
img, lbl = img[iperm], lbl[iperm]
|
| 1024 |
-
|
| 1025 |
-
for i in range(nnoise):
|
| 1026 |
-
optimizer.zero_grad()
|
| 1027 |
-
imgi = img[i * batch_size: (i + 1) * batch_size]
|
| 1028 |
-
lbli = lbl[i * batch_size: (i + 1) * batch_size]
|
| 1029 |
-
if imgi.shape[0] > 0:
|
| 1030 |
-
loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
|
| 1031 |
-
img=imgi[:, nchan:], lbl=lbli, lam=lam)
|
| 1032 |
-
loss.backward()
|
| 1033 |
-
optimizer.step()
|
| 1034 |
-
lavg += loss.item() * imgi.shape[0]
|
| 1035 |
-
lavg_per += loss_per.item() * imgi.shape[0]
|
| 1036 |
-
|
| 1037 |
-
nsum += len(img)
|
| 1038 |
-
nbatch += 1
|
| 1039 |
-
|
| 1040 |
-
if iepoch % 5 == 0 or iepoch < 10:
|
| 1041 |
-
lavg = lavg / nsum
|
| 1042 |
-
lavg_per = lavg_per / nsum
|
| 1043 |
-
if test_data is not None or test_files is not None:
|
| 1044 |
-
lavgt, nsum = 0., 0
|
| 1045 |
-
np.random.seed(42)
|
| 1046 |
-
rperm = np.random.choice(np.arange(0, nimg_test),
|
| 1047 |
-
size=(nimg_test_per_epoch,), p=test_probs)
|
| 1048 |
-
inoise = iepoch % nnoise
|
| 1049 |
-
torch.manual_seed(inoise)
|
| 1050 |
-
for ibatch in range(0, nimg_test_per_epoch, batch_size):
|
| 1051 |
-
inds = rperm[ibatch:ibatch + batch_size]
|
| 1052 |
-
if test_data is None:
|
| 1053 |
-
imgs = [
|
| 1054 |
-
np.maximum(0,
|
| 1055 |
-
io.imread(test_files[i])[:nchan]) for i in inds
|
| 1056 |
-
]
|
| 1057 |
-
lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
|
| 1058 |
-
else:
|
| 1059 |
-
imgs = [test_data[i][:nchan] for i in inds]
|
| 1060 |
-
lbls = [test_labels[i][1:] for i in inds]
|
| 1061 |
-
img, lbl, scale = random_rotate_and_resize_noise(
|
| 1062 |
-
imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
|
| 1063 |
-
beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
|
| 1064 |
-
iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
|
| 1065 |
-
diam_mean=diam_mean, ds_max=ds_max, device=device)
|
| 1066 |
-
loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
|
| 1067 |
-
img=img[:, nchan:], lbl=lbl, lam=lam)
|
| 1068 |
-
|
| 1069 |
-
lavgt += loss.item() * img.shape[0]
|
| 1070 |
-
nsum += len(img)
|
| 1071 |
-
lavgt = lavgt / nsum
|
| 1072 |
-
denoise_logger.info(
|
| 1073 |
-
"Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
|
| 1074 |
-
% (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
|
| 1075 |
-
learning_rate[iepoch]))
|
| 1076 |
-
test_losses.append(lavgt)
|
| 1077 |
-
else:
|
| 1078 |
-
denoise_logger.info(
|
| 1079 |
-
"Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
|
| 1080 |
-
(iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
|
| 1081 |
-
train_losses.append(lavg)
|
| 1082 |
-
|
| 1083 |
-
if save_path is not None:
|
| 1084 |
-
if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
|
| 1085 |
-
if save_each: #separate files as model progresses
|
| 1086 |
-
filename0 = str(filename) + f"_epoch_{iepoch:%04d}"
|
| 1087 |
-
else:
|
| 1088 |
-
filename0 = filename
|
| 1089 |
-
denoise_logger.info(f"saving network parameters to {filename0}")
|
| 1090 |
-
net.save_model(filename0)
|
| 1091 |
-
else:
|
| 1092 |
-
filename = save_path
|
| 1093 |
-
|
| 1094 |
-
return filename, train_losses, test_losses
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
if __name__ == "__main__":
|
| 1098 |
-
import argparse
|
| 1099 |
-
parser = argparse.ArgumentParser(description="cellpose parameters")
|
| 1100 |
-
|
| 1101 |
-
input_img_args = parser.add_argument_group("input image arguments")
|
| 1102 |
-
input_img_args.add_argument("--dir", default=[], type=str,
|
| 1103 |
-
help="folder containing data to run or train on.")
|
| 1104 |
-
input_img_args.add_argument("--img_filter", default=[], type=str,
|
| 1105 |
-
help="end string for images to run on")
|
| 1106 |
-
|
| 1107 |
-
model_args = parser.add_argument_group("model arguments")
|
| 1108 |
-
model_args.add_argument("--pretrained_model", default=[], type=str,
|
| 1109 |
-
help="pretrained denoising model")
|
| 1110 |
-
|
| 1111 |
-
training_args = parser.add_argument_group("training arguments")
|
| 1112 |
-
training_args.add_argument("--test_dir", default=[], type=str,
|
| 1113 |
-
help="folder containing test data (optional)")
|
| 1114 |
-
training_args.add_argument("--file_list", default=[], type=str,
|
| 1115 |
-
help="npy file containing list of train and test files")
|
| 1116 |
-
training_args.add_argument("--seg_model_type", default="cyto2", type=str,
|
| 1117 |
-
help="model to use for seg training loss")
|
| 1118 |
-
training_args.add_argument(
|
| 1119 |
-
"--noise_type", default=[], type=str,
|
| 1120 |
-
help="noise type to use (if input, then other noise params are ignored)")
|
| 1121 |
-
training_args.add_argument("--poisson", default=0.8, type=float,
|
| 1122 |
-
help="fraction of images to add poisson noise to")
|
| 1123 |
-
training_args.add_argument("--beta", default=0.7, type=float,
|
| 1124 |
-
help="scale of poisson noise")
|
| 1125 |
-
training_args.add_argument("--blur", default=0., type=float,
|
| 1126 |
-
help="fraction of images to blur")
|
| 1127 |
-
training_args.add_argument("--gblur", default=1.0, type=float,
|
| 1128 |
-
help="scale of gaussian blurring stddev")
|
| 1129 |
-
training_args.add_argument("--downsample", default=0., type=float,
|
| 1130 |
-
help="fraction of images to downsample")
|
| 1131 |
-
training_args.add_argument("--ds_max", default=7, type=int,
|
| 1132 |
-
help="max downsampling factor")
|
| 1133 |
-
training_args.add_argument("--lam_per", default=1.0, type=float,
|
| 1134 |
-
help="weighting of perceptual loss")
|
| 1135 |
-
training_args.add_argument("--lam_seg", default=1.5, type=float,
|
| 1136 |
-
help="weighting of segmentation loss")
|
| 1137 |
-
training_args.add_argument("--lam_rec", default=0., type=float,
|
| 1138 |
-
help="weighting of reconstruction loss")
|
| 1139 |
-
training_args.add_argument(
|
| 1140 |
-
"--diam_mean", default=30., type=float, help=
|
| 1141 |
-
"mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
|
| 1142 |
-
)
|
| 1143 |
-
training_args.add_argument("--learning_rate", default=0.001, type=float,
|
| 1144 |
-
help="learning rate. Default: %(default)s")
|
| 1145 |
-
training_args.add_argument("--n_epochs", default=2000, type=int,
|
| 1146 |
-
help="number of epochs. Default: %(default)s")
|
| 1147 |
-
training_args.add_argument(
|
| 1148 |
-
"--save_each", default=False, action="store_true",
|
| 1149 |
-
help="save each epoch as separate model")
|
| 1150 |
-
training_args.add_argument(
|
| 1151 |
-
"--nimg_per_epoch", default=0, type=int,
|
| 1152 |
-
help="number of images per epoch. Default is length of training images")
|
| 1153 |
-
training_args.add_argument(
|
| 1154 |
-
"--nimg_test_per_epoch", default=0, type=int,
|
| 1155 |
-
help="number of test images per epoch. Default is length of testing images")
|
| 1156 |
-
|
| 1157 |
-
io.logger_setup()
|
| 1158 |
-
|
| 1159 |
-
args = parser.parse_args()
|
| 1160 |
-
lams = [args.lam_per, args.lam_seg, args.lam_rec]
|
| 1161 |
-
print("lam", lams)
|
| 1162 |
-
|
| 1163 |
-
if len(args.noise_type) > 0:
|
| 1164 |
-
noise_type = args.noise_type
|
| 1165 |
-
uniform_blur = False
|
| 1166 |
-
iso = True
|
| 1167 |
-
if noise_type == "poisson":
|
| 1168 |
-
poisson = 0.8
|
| 1169 |
-
blur = 0.
|
| 1170 |
-
downsample = 0.
|
| 1171 |
-
beta = 0.7
|
| 1172 |
-
gblur = 1.0
|
| 1173 |
-
elif noise_type == "blur_expr":
|
| 1174 |
-
poisson = 0.8
|
| 1175 |
-
blur = 0.8
|
| 1176 |
-
downsample = 0.
|
| 1177 |
-
beta = 0.1
|
| 1178 |
-
gblur = 0.5
|
| 1179 |
-
elif noise_type == "blur":
|
| 1180 |
-
poisson = 0.8
|
| 1181 |
-
blur = 0.8
|
| 1182 |
-
downsample = 0.
|
| 1183 |
-
beta = 0.1
|
| 1184 |
-
gblur = 10.0
|
| 1185 |
-
uniform_blur = True
|
| 1186 |
-
elif noise_type == "downsample_expr":
|
| 1187 |
-
poisson = 0.8
|
| 1188 |
-
blur = 0.8
|
| 1189 |
-
downsample = 0.8
|
| 1190 |
-
beta = 0.03
|
| 1191 |
-
gblur = 1.0
|
| 1192 |
-
elif noise_type == "downsample":
|
| 1193 |
-
poisson = 0.8
|
| 1194 |
-
blur = 0.8
|
| 1195 |
-
downsample = 0.8
|
| 1196 |
-
beta = 0.03
|
| 1197 |
-
gblur = 5.0
|
| 1198 |
-
uniform_blur = True
|
| 1199 |
-
elif noise_type == "all":
|
| 1200 |
-
poisson = [0.8, 0.8, 0.8]
|
| 1201 |
-
blur = [0., 0.8, 0.8]
|
| 1202 |
-
downsample = [0., 0., 0.8]
|
| 1203 |
-
beta = [0.7, 0.1, 0.03]
|
| 1204 |
-
gblur = [0., 10.0, 5.0]
|
| 1205 |
-
uniform_blur = True
|
| 1206 |
-
elif noise_type == "aniso":
|
| 1207 |
-
poisson = 0.8
|
| 1208 |
-
blur = 0.8
|
| 1209 |
-
downsample = 0.8
|
| 1210 |
-
beta = 0.1
|
| 1211 |
-
gblur = args.ds_max * 1.5
|
| 1212 |
-
iso = False
|
| 1213 |
-
else:
|
| 1214 |
-
raise ValueError(f"{noise_type} noise_type is not supported")
|
| 1215 |
-
else:
|
| 1216 |
-
poisson, beta = args.poisson, args.beta
|
| 1217 |
-
blur, gblur = args.blur, args.gblur
|
| 1218 |
-
downsample = args.downsample
|
| 1219 |
-
|
| 1220 |
-
pretrained_model = None if len(
|
| 1221 |
-
args.pretrained_model) == 0 else args.pretrained_model
|
| 1222 |
-
model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
|
| 1223 |
-
pretrained_model=pretrained_model)
|
| 1224 |
-
|
| 1225 |
-
train_data, labels, train_files, train_probs = None, None, None, None
|
| 1226 |
-
test_data, test_labels, test_files, test_probs = None, None, None, None
|
| 1227 |
-
if len(args.file_list) == 0:
|
| 1228 |
-
output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
|
| 1229 |
-
images, labels, image_names, test_images, test_labels, image_names_test = output
|
| 1230 |
-
train_data = []
|
| 1231 |
-
for i in range(len(images)):
|
| 1232 |
-
img = images[i].astype("float32")
|
| 1233 |
-
if img.ndim > 2:
|
| 1234 |
-
img = img[0]
|
| 1235 |
-
train_data.append(
|
| 1236 |
-
np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
|
| 1237 |
-
if len(args.test_dir) > 0:
|
| 1238 |
-
test_data = []
|
| 1239 |
-
for i in range(len(test_images)):
|
| 1240 |
-
img = test_images[i].astype("float32")
|
| 1241 |
-
if img.ndim > 2:
|
| 1242 |
-
img = img[0]
|
| 1243 |
-
test_data.append(
|
| 1244 |
-
np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
|
| 1245 |
-
save_path = os.path.join(args.dir, "../models/")
|
| 1246 |
-
else:
|
| 1247 |
-
root = args.dir
|
| 1248 |
-
denoise_logger.info(
|
| 1249 |
-
">>> using file_list (assumes images are normalized and have flows!)")
|
| 1250 |
-
dat = np.load(args.file_list, allow_pickle=True).item()
|
| 1251 |
-
train_files = dat["train_files"]
|
| 1252 |
-
test_files = dat["test_files"]
|
| 1253 |
-
train_probs = dat["train_probs"] if "train_probs" in dat else None
|
| 1254 |
-
test_probs = dat["test_probs"] if "test_probs" in dat else None
|
| 1255 |
-
if str(train_files[0])[:len(str(root))] != str(root):
|
| 1256 |
-
for i in range(len(train_files)):
|
| 1257 |
-
new_path = root / Path(*train_files[i].parts[-3:])
|
| 1258 |
-
if i == 0:
|
| 1259 |
-
print(f"changing path from {train_files[i]} to {new_path}")
|
| 1260 |
-
train_files[i] = new_path
|
| 1261 |
-
|
| 1262 |
-
for i in range(len(test_files)):
|
| 1263 |
-
new_path = root / Path(*test_files[i].parts[-3:])
|
| 1264 |
-
test_files[i] = new_path
|
| 1265 |
-
save_path = os.path.join(args.dir, "models/")
|
| 1266 |
-
|
| 1267 |
-
os.makedirs(save_path, exist_ok=True)
|
| 1268 |
-
|
| 1269 |
-
nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
|
| 1270 |
-
nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch
|
| 1271 |
-
|
| 1272 |
-
model_path = train(
|
| 1273 |
-
model.net, train_data=train_data, train_labels=labels, train_files=train_files,
|
| 1274 |
-
test_data=test_data, test_labels=test_labels, test_files=test_files,
|
| 1275 |
-
train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
|
| 1276 |
-
blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
|
| 1277 |
-
iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
|
| 1278 |
-
learning_rate=args.learning_rate,
|
| 1279 |
-
lam=lams,
|
| 1280 |
-
seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
|
| 1281 |
-
nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
|
| 1285 |
-
poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
|
| 1286 |
-
save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
|
| 1287 |
-
momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
|
| 1288 |
-
nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
|
| 1289 |
-
model_name=None):
|
| 1290 |
-
""" train function uses loss function model.loss_fn in models.py
|
| 1291 |
-
|
| 1292 |
-
(data should already be normalized)
|
| 1293 |
-
|
| 1294 |
-
"""
|
| 1295 |
-
|
| 1296 |
-
d = datetime.datetime.now()
|
| 1297 |
-
|
| 1298 |
-
model.n_epochs = n_epochs
|
| 1299 |
-
if isinstance(learning_rate, (list, np.ndarray)):
|
| 1300 |
-
if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
|
| 1301 |
-
raise ValueError("learning_rate.ndim must equal 1")
|
| 1302 |
-
elif len(learning_rate) != n_epochs:
|
| 1303 |
-
raise ValueError(
|
| 1304 |
-
"if learning_rate given as list or np.ndarray it must have length n_epochs"
|
| 1305 |
-
)
|
| 1306 |
-
model.learning_rate = learning_rate
|
| 1307 |
-
model.learning_rate_const = mode(learning_rate)[0][0]
|
| 1308 |
-
else:
|
| 1309 |
-
model.learning_rate_const = learning_rate
|
| 1310 |
-
# set learning rate schedule
|
| 1311 |
-
if SGD:
|
| 1312 |
-
LR = np.linspace(0, model.learning_rate_const, 10)
|
| 1313 |
-
if model.n_epochs > 250:
|
| 1314 |
-
LR = np.append(
|
| 1315 |
-
LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
|
| 1316 |
-
for i in range(10):
|
| 1317 |
-
LR = np.append(LR, LR[-1] / 2 * np.ones(10))
|
| 1318 |
-
else:
|
| 1319 |
-
LR = np.append(
|
| 1320 |
-
LR,
|
| 1321 |
-
model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
|
| 1322 |
-
else:
|
| 1323 |
-
LR = model.learning_rate_const * np.ones(model.n_epochs)
|
| 1324 |
-
model.learning_rate = LR
|
| 1325 |
-
|
| 1326 |
-
model.batch_size = batch_size
|
| 1327 |
-
model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
|
| 1328 |
-
model._set_criterion()
|
| 1329 |
-
|
| 1330 |
-
nimg = len(train_data)
|
| 1331 |
-
|
| 1332 |
-
# compute average cell diameter
|
| 1333 |
-
if diameter is None:
|
| 1334 |
-
diam_train = np.array(
|
| 1335 |
-
[utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
|
| 1336 |
-
diam_train_mean = diam_train[diam_train > 0].mean()
|
| 1337 |
-
model.diam_labels = diam_train_mean
|
| 1338 |
-
if rescale:
|
| 1339 |
-
diam_train[diam_train < 5] = 5.
|
| 1340 |
-
if test_data is not None:
|
| 1341 |
-
diam_test = np.array([
|
| 1342 |
-
utils.diameters(test_labels[k][0])[0]
|
| 1343 |
-
for k in range(len(test_labels))
|
| 1344 |
-
])
|
| 1345 |
-
diam_test[diam_test < 5] = 5.
|
| 1346 |
-
denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
|
| 1347 |
-
elif rescale:
|
| 1348 |
-
diam_train_mean = diameter
|
| 1349 |
-
model.diam_labels = diameter
|
| 1350 |
-
denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
|
| 1351 |
-
diam_train = diameter * np.ones(len(train_labels), "float32")
|
| 1352 |
-
if test_data is not None:
|
| 1353 |
-
diam_test = diameter * np.ones(len(test_labels), "float32")
|
| 1354 |
-
|
| 1355 |
-
denoise_logger.info(
|
| 1356 |
-
f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
|
| 1357 |
-
)
|
| 1358 |
-
model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean
|
| 1359 |
-
|
| 1360 |
-
nchan = train_data[0].shape[0]
|
| 1361 |
-
denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
|
| 1362 |
-
denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
|
| 1363 |
-
(model.learning_rate_const, model.batch_size, weight_decay))
|
| 1364 |
-
|
| 1365 |
-
if test_data is not None:
|
| 1366 |
-
denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
|
| 1367 |
-
else:
|
| 1368 |
-
denoise_logger.info(f">>>> ntrain = {nimg}")
|
| 1369 |
-
|
| 1370 |
-
tic = time.time()
|
| 1371 |
-
|
| 1372 |
-
lavg, nsum = 0, 0
|
| 1373 |
-
|
| 1374 |
-
if save_path is not None:
|
| 1375 |
-
_, file_label = os.path.split(save_path)
|
| 1376 |
-
file_path = os.path.join(save_path, "models/")
|
| 1377 |
-
|
| 1378 |
-
if not os.path.exists(file_path):
|
| 1379 |
-
os.makedirs(file_path)
|
| 1380 |
-
else:
|
| 1381 |
-
denoise_logger.warning("WARNING: no save_path given, model not saving")
|
| 1382 |
-
|
| 1383 |
-
ksave = 0
|
| 1384 |
-
|
| 1385 |
-
# get indices for each epoch for training
|
| 1386 |
-
np.random.seed(0)
|
| 1387 |
-
inds_all = np.zeros((0,), "int32")
|
| 1388 |
-
if nimg_per_epoch is None or nimg > nimg_per_epoch:
|
| 1389 |
-
nimg_per_epoch = nimg
|
| 1390 |
-
denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
|
| 1391 |
-
while len(inds_all) < n_epochs * nimg_per_epoch:
|
| 1392 |
-
rperm = np.random.permutation(nimg)
|
| 1393 |
-
inds_all = np.hstack((inds_all, rperm))
|
| 1394 |
-
|
| 1395 |
-
for iepoch in range(model.n_epochs):
|
| 1396 |
-
if SGD:
|
| 1397 |
-
model._set_learning_rate(model.learning_rate[iepoch])
|
| 1398 |
-
np.random.seed(iepoch)
|
| 1399 |
-
rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
|
| 1400 |
-
for ibatch in range(0, nimg_per_epoch, batch_size):
|
| 1401 |
-
inds = rperm[ibatch:ibatch + batch_size]
|
| 1402 |
-
imgi, lbl, scale = random_rotate_and_resize_noise(
|
| 1403 |
-
[train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
|
| 1404 |
-
poisson=poisson, blur=blur, downsample=downsample,
|
| 1405 |
-
diams=diam_train[inds], diam_mean=model.diam_mean)
|
| 1406 |
-
imgi = imgi[:, :1] # keep noisy only
|
| 1407 |
-
if z_masking:
|
| 1408 |
-
nc = imgi.shape[1]
|
| 1409 |
-
nb = imgi.shape[0]
|
| 1410 |
-
ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
|
| 1411 |
-
nc // 2 - 1, size=nb))
|
| 1412 |
-
ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
|
| 1413 |
-
nc // 2 - 1, size=nb))
|
| 1414 |
-
for b in range(nb):
|
| 1415 |
-
imgi[b, :ncmin[b]] = 0
|
| 1416 |
-
imgi[b, ncmax[b]:] = 0
|
| 1417 |
-
|
| 1418 |
-
train_loss = model._train_step(imgi, lbl)
|
| 1419 |
-
lavg += train_loss
|
| 1420 |
-
nsum += len(imgi)
|
| 1421 |
-
|
| 1422 |
-
if iepoch % 10 == 0 or iepoch == 5:
|
| 1423 |
-
lavg = lavg / nsum
|
| 1424 |
-
if test_data is not None:
|
| 1425 |
-
lavgt, nsum = 0., 0
|
| 1426 |
-
np.random.seed(42)
|
| 1427 |
-
rperm = np.arange(0, len(test_data), 1, int)
|
| 1428 |
-
for ibatch in range(0, len(test_data), batch_size):
|
| 1429 |
-
inds = rperm[ibatch:ibatch + batch_size]
|
| 1430 |
-
imgi, lbl, scale = random_rotate_and_resize_noise(
|
| 1431 |
-
[test_data[i] for i in inds],
|
| 1432 |
-
[test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
|
| 1433 |
-
downsample=downsample, diams=diam_test[inds],
|
| 1434 |
-
diam_mean=model.diam_mean)
|
| 1435 |
-
imgi = imgi[:, :1] # keep noisy only
|
| 1436 |
-
test_loss = model._test_eval(imgi, lbl)
|
| 1437 |
-
lavgt += test_loss
|
| 1438 |
-
nsum += len(imgi)
|
| 1439 |
-
|
| 1440 |
-
denoise_logger.info(
|
| 1441 |
-
"Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
|
| 1442 |
-
(iepoch, time.time() - tic, lavg, lavgt / nsum,
|
| 1443 |
-
model.learning_rate[iepoch]))
|
| 1444 |
-
else:
|
| 1445 |
-
denoise_logger.info(
|
| 1446 |
-
"Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
|
| 1447 |
-
(iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))
|
| 1448 |
-
|
| 1449 |
-
lavg, nsum = 0, 0
|
| 1450 |
-
|
| 1451 |
-
if save_path is not None:
|
| 1452 |
-
if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
|
| 1453 |
-
# save model at the end
|
| 1454 |
-
if save_each: #separate files as model progresses
|
| 1455 |
-
if model_name is None:
|
| 1456 |
-
filename = "{}_{}_{}_{}".format(
|
| 1457 |
-
model.net_type, file_label,
|
| 1458 |
-
d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
|
| 1459 |
-
else:
|
| 1460 |
-
filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
|
| 1461 |
-
else:
|
| 1462 |
-
if model_name is None:
|
| 1463 |
-
filename = "{}_{}_{}".format(model.net_type, file_label,
|
| 1464 |
-
d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
|
| 1465 |
-
else:
|
| 1466 |
-
filename = model_name
|
| 1467 |
-
filename = os.path.join(file_path, filename)
|
| 1468 |
-
ksave += 1
|
| 1469 |
-
denoise_logger.info(f"saving network parameters to {filename}")
|
| 1470 |
-
model.net.save_model(filename)
|
| 1471 |
-
else:
|
| 1472 |
-
filename = save_path
|
| 1473 |
-
|
| 1474 |
-
return filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/export.py
DELETED
|
@@ -1,405 +0,0 @@
|
|
| 1 |
-
"""Auxiliary module for bioimageio format export
|
| 2 |
-
|
| 3 |
-
Example usage:
|
| 4 |
-
|
| 5 |
-
```bash
|
| 6 |
-
#!/bin/bash
|
| 7 |
-
|
| 8 |
-
# Define default paths and parameters
|
| 9 |
-
DEFAULT_CHANNELS="1 0"
|
| 10 |
-
DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
|
| 11 |
-
DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
|
| 12 |
-
DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
|
| 13 |
-
DEFAULT_MODEL_ID="philosophical-panda"
|
| 14 |
-
DEFAULT_MODEL_ICON="🐼"
|
| 15 |
-
DEFAULT_MODEL_VERSION="0.1.0"
|
| 16 |
-
DEFAULT_MODEL_NAME="My Cool Cellpose"
|
| 17 |
-
DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
|
| 18 |
-
DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
|
| 19 |
-
DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
|
| 20 |
-
DEFAULT_MODEL_TAGS="cellpose 3d 2d"
|
| 21 |
-
DEFAULT_MODEL_LICENSE="MIT"
|
| 22 |
-
DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"
|
| 23 |
-
|
| 24 |
-
# Run the Python script with default parameters
|
| 25 |
-
python export.py \
|
| 26 |
-
--channels $DEFAULT_CHANNELS \
|
| 27 |
-
--path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
|
| 28 |
-
--path_readme "$DEFAULT_PATH_README" \
|
| 29 |
-
--list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
|
| 30 |
-
--model_version "$DEFAULT_MODEL_VERSION" \
|
| 31 |
-
--model_name "$DEFAULT_MODEL_NAME" \
|
| 32 |
-
--model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
|
| 33 |
-
--model_authors "$DEFAULT_MODEL_AUTHORS" \
|
| 34 |
-
--model_cite "$DEFAULT_MODEL_CITE" \
|
| 35 |
-
--model_tags $DEFAULT_MODEL_TAGS \
|
| 36 |
-
--model_license "$DEFAULT_MODEL_LICENSE" \
|
| 37 |
-
--model_repo "$DEFAULT_MODEL_REPO"
|
| 38 |
-
```
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
-
import os
|
| 42 |
-
import sys
|
| 43 |
-
import json
|
| 44 |
-
import argparse
|
| 45 |
-
from pathlib import Path
|
| 46 |
-
from urllib.parse import urlparse
|
| 47 |
-
|
| 48 |
-
import torch
|
| 49 |
-
import numpy as np
|
| 50 |
-
|
| 51 |
-
from cellpose.io import imread
|
| 52 |
-
from cellpose.utils import download_url_to_file
|
| 53 |
-
from cellpose.transforms import pad_image_ND, normalize_img, convert_image
|
| 54 |
-
from cellpose.vit_sam import CPnetBioImageIO
|
| 55 |
-
|
| 56 |
-
from bioimageio.spec.model.v0_5 import (
|
| 57 |
-
ArchitectureFromFileDescr,
|
| 58 |
-
Author,
|
| 59 |
-
AxisId,
|
| 60 |
-
ChannelAxis,
|
| 61 |
-
CiteEntry,
|
| 62 |
-
Doi,
|
| 63 |
-
FileDescr,
|
| 64 |
-
Identifier,
|
| 65 |
-
InputTensorDescr,
|
| 66 |
-
IntervalOrRatioDataDescr,
|
| 67 |
-
LicenseId,
|
| 68 |
-
ModelDescr,
|
| 69 |
-
ModelId,
|
| 70 |
-
OrcidId,
|
| 71 |
-
OutputTensorDescr,
|
| 72 |
-
ParameterizedSize,
|
| 73 |
-
PytorchStateDictWeightsDescr,
|
| 74 |
-
SizeReference,
|
| 75 |
-
SpaceInputAxis,
|
| 76 |
-
SpaceOutputAxis,
|
| 77 |
-
TensorId,
|
| 78 |
-
TorchscriptWeightsDescr,
|
| 79 |
-
Version,
|
| 80 |
-
WeightsDescr,
|
| 81 |
-
)
|
| 82 |
-
# Define ARBITRARY_SIZE if it is not available in the module
|
| 83 |
-
try:
|
| 84 |
-
from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
|
| 85 |
-
except ImportError:
|
| 86 |
-
ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
|
| 87 |
-
|
| 88 |
-
from bioimageio.spec.common import HttpUrl
|
| 89 |
-
from bioimageio.spec import save_bioimageio_package
|
| 90 |
-
from bioimageio.core import test_model
|
| 91 |
-
|
| 92 |
-
DEFAULT_CHANNELS = [2, 1]
|
| 93 |
-
DEFAULT_NORMALIZE_PARAMS = {
|
| 94 |
-
"axis": -1,
|
| 95 |
-
"lowhigh": None,
|
| 96 |
-
"percentile": None,
|
| 97 |
-
"normalize": True,
|
| 98 |
-
"norm3D": False,
|
| 99 |
-
"sharpen_radius": 0,
|
| 100 |
-
"smooth_radius": 0,
|
| 101 |
-
"tile_norm_blocksize": 0,
|
| 102 |
-
"tile_norm_smooth3D": 1,
|
| 103 |
-
"invert": False,
|
| 104 |
-
}
|
| 105 |
-
IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
|
| 109 |
-
"""
|
| 110 |
-
Download and normalize image.
|
| 111 |
-
"""
|
| 112 |
-
filename = os.path.basename(urlparse(IMAGE_URL).path)
|
| 113 |
-
path_image = path_dir_temp / filename
|
| 114 |
-
if not path_image.exists():
|
| 115 |
-
sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
|
| 116 |
-
download_url_to_file(IMAGE_URL, path_image)
|
| 117 |
-
img = imread(path_image).astype(np.float32)
|
| 118 |
-
img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
|
| 119 |
-
img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS)
|
| 120 |
-
img = np.transpose(img, (0, 3, 1, 2))
|
| 121 |
-
img, _, _ = pad_image_ND(img)
|
| 122 |
-
return img
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
|
| 126 |
-
cpnet_kwargs = {
|
| 127 |
-
"nout": 3,
|
| 128 |
-
}
|
| 129 |
-
cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
|
| 130 |
-
state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
|
| 131 |
-
cpnet_biio.load_state_dict(state_dict_cuda)
|
| 132 |
-
cpnet_biio.eval() # crucial for the prediction results
|
| 133 |
-
return cpnet_biio, cpnet_kwargs
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def descr_gen_input(path_test_input, nchan=2):
|
| 137 |
-
input_axes = [
|
| 138 |
-
SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
|
| 139 |
-
ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
|
| 140 |
-
SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
|
| 141 |
-
SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
|
| 142 |
-
]
|
| 143 |
-
data_descr = IntervalOrRatioDataDescr(type="float32")
|
| 144 |
-
path_test_input = Path(path_test_input)
|
| 145 |
-
descr_input = InputTensorDescr(
|
| 146 |
-
id=TensorId("raw"),
|
| 147 |
-
axes=input_axes,
|
| 148 |
-
test_tensor=FileDescr(source=path_test_input),
|
| 149 |
-
data=data_descr,
|
| 150 |
-
)
|
| 151 |
-
return descr_input
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def descr_gen_output_flow(path_test_output):
|
| 155 |
-
output_axes_output_tensor = [
|
| 156 |
-
SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
|
| 157 |
-
ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
|
| 158 |
-
SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
|
| 159 |
-
SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
|
| 160 |
-
]
|
| 161 |
-
path_test_output = Path(path_test_output)
|
| 162 |
-
descr_output = OutputTensorDescr(
|
| 163 |
-
id=TensorId("flow"),
|
| 164 |
-
axes=output_axes_output_tensor,
|
| 165 |
-
test_tensor=FileDescr(source=path_test_output),
|
| 166 |
-
)
|
| 167 |
-
return descr_output
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def descr_gen_output_downsampled(path_dir_temp, nbase=None):
|
| 171 |
-
if nbase is None:
|
| 172 |
-
nbase = [32, 64, 128, 256]
|
| 173 |
-
|
| 174 |
-
output_axes_downsampled_tensors = [
|
| 175 |
-
[
|
| 176 |
-
SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
|
| 177 |
-
ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]),
|
| 178 |
-
SpaceOutputAxis(
|
| 179 |
-
id=AxisId("y"),
|
| 180 |
-
size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
|
| 181 |
-
scale=2**offset,
|
| 182 |
-
),
|
| 183 |
-
SpaceOutputAxis(
|
| 184 |
-
id=AxisId("x"),
|
| 185 |
-
size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
|
| 186 |
-
scale=2**offset,
|
| 187 |
-
),
|
| 188 |
-
]
|
| 189 |
-
for offset, base in enumerate(nbase)
|
| 190 |
-
]
|
| 191 |
-
path_downsampled_tensors = [
|
| 192 |
-
Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors))
|
| 193 |
-
]
|
| 194 |
-
descr_output_downsampled_tensors = [
|
| 195 |
-
OutputTensorDescr(
|
| 196 |
-
id=TensorId(f"downsampled_{i}"),
|
| 197 |
-
axes=axes,
|
| 198 |
-
test_tensor=FileDescr(source=path),
|
| 199 |
-
)
|
| 200 |
-
for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors))
|
| 201 |
-
]
|
| 202 |
-
return descr_output_downsampled_tensors
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def descr_gen_output_style(path_test_style, nchannel=256):
|
| 206 |
-
output_axes_style_tensor = [
|
| 207 |
-
SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
|
| 208 |
-
ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]),
|
| 209 |
-
]
|
| 210 |
-
path_style_tensor = Path(path_test_style)
|
| 211 |
-
descr_output_style_tensor = OutputTensorDescr(
|
| 212 |
-
id=TensorId("style"),
|
| 213 |
-
axes=output_axes_style_tensor,
|
| 214 |
-
test_tensor=FileDescr(source=path_style_tensor),
|
| 215 |
-
)
|
| 216 |
-
return descr_output_style_tensor
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
|
| 220 |
-
if path_cpnet_wrapper is None:
|
| 221 |
-
path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py"
|
| 222 |
-
pytorch_architecture = ArchitectureFromFileDescr(
|
| 223 |
-
callable=Identifier("CPnetBioImageIO"),
|
| 224 |
-
source=Path(path_cpnet_wrapper),
|
| 225 |
-
kwargs=cpnet_kwargs,
|
| 226 |
-
)
|
| 227 |
-
return pytorch_architecture
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def descr_gen_documentation(path_doc, markdown_text):
|
| 231 |
-
with open(path_doc, "w") as f:
|
| 232 |
-
f.write(markdown_text)
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
def package_to_bioimageio(
|
| 236 |
-
path_pretrained_model,
|
| 237 |
-
path_save_trace,
|
| 238 |
-
path_readme,
|
| 239 |
-
list_path_cover_images,
|
| 240 |
-
descr_input,
|
| 241 |
-
descr_output,
|
| 242 |
-
descr_output_downsampled_tensors,
|
| 243 |
-
descr_output_style_tensor,
|
| 244 |
-
pytorch_version,
|
| 245 |
-
pytorch_architecture,
|
| 246 |
-
model_id,
|
| 247 |
-
model_icon,
|
| 248 |
-
model_version,
|
| 249 |
-
model_name,
|
| 250 |
-
model_documentation,
|
| 251 |
-
model_authors,
|
| 252 |
-
model_cite,
|
| 253 |
-
model_tags,
|
| 254 |
-
model_license,
|
| 255 |
-
model_repo,
|
| 256 |
-
):
|
| 257 |
-
"""Package model description to BioImage.IO format."""
|
| 258 |
-
my_model_descr = ModelDescr(
|
| 259 |
-
id=ModelId(model_id) if model_id is not None else None,
|
| 260 |
-
id_emoji=model_icon,
|
| 261 |
-
version=Version(model_version),
|
| 262 |
-
name=model_name,
|
| 263 |
-
description=model_documentation,
|
| 264 |
-
authors=[
|
| 265 |
-
Author(
|
| 266 |
-
name=author["name"],
|
| 267 |
-
affiliation=author["affiliation"],
|
| 268 |
-
github_user=author["github_user"],
|
| 269 |
-
orcid=OrcidId(author["orcid"]),
|
| 270 |
-
)
|
| 271 |
-
for author in model_authors
|
| 272 |
-
],
|
| 273 |
-
cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite],
|
| 274 |
-
covers=[Path(img) for img in list_path_cover_images],
|
| 275 |
-
license=LicenseId(model_license),
|
| 276 |
-
tags=model_tags,
|
| 277 |
-
documentation=Path(path_readme),
|
| 278 |
-
git_repo=HttpUrl(model_repo),
|
| 279 |
-
inputs=[descr_input],
|
| 280 |
-
outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors,
|
| 281 |
-
weights=WeightsDescr(
|
| 282 |
-
pytorch_state_dict=PytorchStateDictWeightsDescr(
|
| 283 |
-
source=Path(path_pretrained_model),
|
| 284 |
-
architecture=pytorch_architecture,
|
| 285 |
-
pytorch_version=pytorch_version,
|
| 286 |
-
),
|
| 287 |
-
torchscript=TorchscriptWeightsDescr(
|
| 288 |
-
source=Path(path_save_trace),
|
| 289 |
-
pytorch_version=pytorch_version,
|
| 290 |
-
parent="pytorch_state_dict", # these weights were converted from the pytorch_state_dict weights.
|
| 291 |
-
),
|
| 292 |
-
),
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
return my_model_descr
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
def parse_args():
|
| 299 |
-
# fmt: off
|
| 300 |
-
parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
|
| 301 |
-
parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
|
| 302 |
-
parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
|
| 303 |
-
parser.add_argument("--path_readme", required=True, type=str, help="Path to README file")
|
| 304 |
-
parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
|
| 305 |
-
parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
|
| 306 |
-
parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
|
| 307 |
-
parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
|
| 308 |
-
parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
|
| 309 |
-
parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
|
| 310 |
-
parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
|
| 311 |
-
parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
|
| 312 |
-
parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
|
| 313 |
-
parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
|
| 314 |
-
parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
|
| 315 |
-
return parser.parse_args()
|
| 316 |
-
# fmt: on
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
def main():
|
| 320 |
-
args = parse_args()
|
| 321 |
-
|
| 322 |
-
# Parse user-provided paths and arguments
|
| 323 |
-
channels = args.channels
|
| 324 |
-
model_cite = json.loads(args.model_cite)
|
| 325 |
-
model_authors = json.loads(args.model_authors)
|
| 326 |
-
|
| 327 |
-
path_readme = Path(args.path_readme)
|
| 328 |
-
path_pretrained_model = Path(args.path_pretrained_model)
|
| 329 |
-
list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]
|
| 330 |
-
|
| 331 |
-
# Auto-generated paths
|
| 332 |
-
path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
|
| 333 |
-
path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
|
| 334 |
-
path_dir_temp.mkdir(parents=True, exist_ok=True)
|
| 335 |
-
|
| 336 |
-
path_save_trace = path_dir_temp / "cp_traced.pt"
|
| 337 |
-
path_test_input = path_dir_temp / "test_input.npy"
|
| 338 |
-
path_test_output = path_dir_temp / "test_output.npy"
|
| 339 |
-
path_test_style = path_dir_temp / "test_style.npy"
|
| 340 |
-
path_bioimageio_package = path_dir_temp / "cellpose_model.zip"
|
| 341 |
-
|
| 342 |
-
# Download test input image
|
| 343 |
-
img_np = download_and_normalize_image(path_dir_temp, channels=channels)
|
| 344 |
-
np.save(path_test_input, img_np)
|
| 345 |
-
img = torch.tensor(img_np).float()
|
| 346 |
-
|
| 347 |
-
# Load model
|
| 348 |
-
cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)
|
| 349 |
-
|
| 350 |
-
# Test model and save output
|
| 351 |
-
tuple_output_tensor = cpnet_biio(img)
|
| 352 |
-
np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
|
| 353 |
-
np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
|
| 354 |
-
for i, t in enumerate(tuple_output_tensor[2:]):
|
| 355 |
-
np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())
|
| 356 |
-
|
| 357 |
-
# Save traced model
|
| 358 |
-
model_traced = torch.jit.trace(cpnet_biio, img)
|
| 359 |
-
model_traced.save(path_save_trace)
|
| 360 |
-
|
| 361 |
-
# Generate model description
|
| 362 |
-
descr_input = descr_gen_input(path_test_input)
|
| 363 |
-
descr_output = descr_gen_output_flow(path_test_output)
|
| 364 |
-
descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
|
| 365 |
-
descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
|
| 366 |
-
pytorch_version = Version(torch.__version__)
|
| 367 |
-
pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)
|
| 368 |
-
|
| 369 |
-
# Package model
|
| 370 |
-
my_model_descr = package_to_bioimageio(
|
| 371 |
-
path_pretrained_model,
|
| 372 |
-
path_save_trace,
|
| 373 |
-
path_readme,
|
| 374 |
-
list_path_cover_images,
|
| 375 |
-
descr_input,
|
| 376 |
-
descr_output,
|
| 377 |
-
descr_output_downsampled_tensors,
|
| 378 |
-
descr_output_style_tensor,
|
| 379 |
-
pytorch_version,
|
| 380 |
-
pytorch_architecture,
|
| 381 |
-
args.model_id,
|
| 382 |
-
args.model_icon,
|
| 383 |
-
args.model_version,
|
| 384 |
-
args.model_name,
|
| 385 |
-
args.model_documentation,
|
| 386 |
-
model_authors,
|
| 387 |
-
model_cite,
|
| 388 |
-
args.model_tags,
|
| 389 |
-
args.model_license,
|
| 390 |
-
args.model_repo,
|
| 391 |
-
)
|
| 392 |
-
|
| 393 |
-
# Test model
|
| 394 |
-
summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
|
| 395 |
-
summary.display()
|
| 396 |
-
summary = test_model(my_model_descr, weight_format="torchscript")
|
| 397 |
-
summary.display()
|
| 398 |
-
|
| 399 |
-
# Save BioImage.IO package
|
| 400 |
-
package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
|
| 401 |
-
print("package path:", package_path)
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
if __name__ == "__main__":
|
| 405 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/gui.py
DELETED
|
@@ -1,2007 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import sys, os, pathlib, warnings, datetime, time, copy
|
| 6 |
-
|
| 7 |
-
from qtpy import QtGui, QtCore
|
| 8 |
-
from superqt import QRangeSlider, QCollapsible
|
| 9 |
-
from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \
|
| 10 |
-
QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \
|
| 11 |
-
QLineEdit, QMessageBox, QGroupBox, QMenu, QAction
|
| 12 |
-
import pyqtgraph as pg
|
| 13 |
-
|
| 14 |
-
import numpy as np
|
| 15 |
-
from scipy.stats import mode
|
| 16 |
-
import cv2
|
| 17 |
-
|
| 18 |
-
from . import guiparts, menus, io
|
| 19 |
-
from .. import models, core, dynamics, version, train
|
| 20 |
-
from ..utils import download_url_to_file, masks_to_outlines, diameters
|
| 21 |
-
from ..io import get_image_files, imsave, imread
|
| 22 |
-
from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
|
| 23 |
-
from ..models import normalize_default
|
| 24 |
-
from ..plot import disk
|
| 25 |
-
|
| 26 |
-
try:
|
| 27 |
-
import matplotlib.pyplot as plt
|
| 28 |
-
MATPLOTLIB = True
|
| 29 |
-
except:
|
| 30 |
-
MATPLOTLIB = False
|
| 31 |
-
|
| 32 |
-
Horizontal = QtCore.Qt.Orientation.Horizontal
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class Slider(QRangeSlider):
|
| 36 |
-
|
| 37 |
-
def __init__(self, parent, name, color):
|
| 38 |
-
super().__init__(Horizontal)
|
| 39 |
-
self.setEnabled(False)
|
| 40 |
-
self.valueChanged.connect(lambda: self.levelChanged(parent))
|
| 41 |
-
self.name = name
|
| 42 |
-
|
| 43 |
-
self.setStyleSheet(""" QSlider{
|
| 44 |
-
background-color: transparent;
|
| 45 |
-
}
|
| 46 |
-
""")
|
| 47 |
-
self.show()
|
| 48 |
-
|
| 49 |
-
def levelChanged(self, parent):
|
| 50 |
-
parent.level_change(self.name)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
class QHLine(QFrame):
|
| 54 |
-
|
| 55 |
-
def __init__(self):
|
| 56 |
-
super(QHLine, self).__init__()
|
| 57 |
-
self.setFrameShape(QFrame.HLine)
|
| 58 |
-
self.setLineWidth(8)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def make_bwr():
|
| 62 |
-
# make a bwr colormap
|
| 63 |
-
b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
|
| 64 |
-
r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis]
|
| 65 |
-
g = np.append(np.linspace(0, 255, 128),
|
| 66 |
-
np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
|
| 67 |
-
color = np.concatenate((r, g, b), axis=-1).astype(np.uint8)
|
| 68 |
-
bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
|
| 69 |
-
return bwr
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def make_spectral():
|
| 73 |
-
# make spectral colormap
|
| 74 |
-
r = np.array([
|
| 75 |
-
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80,
|
| 76 |
-
84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128,
|
| 77 |
-
128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88,
|
| 78 |
-
80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 79 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23,
|
| 80 |
-
27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103,
|
| 81 |
-
107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167,
|
| 82 |
-
171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231,
|
| 83 |
-
235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 84 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 85 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 86 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 87 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 88 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 89 |
-
255, 255, 255, 255, 255
|
| 90 |
-
])
|
| 91 |
-
g = np.array([
|
| 92 |
-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3,
|
| 93 |
-
2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111,
|
| 94 |
-
119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239,
|
| 95 |
-
247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143,
|
| 96 |
-
135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150,
|
| 97 |
-
151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175,
|
| 98 |
-
177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201,
|
| 99 |
-
202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226,
|
| 100 |
-
228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251,
|
| 101 |
-
253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199,
|
| 102 |
-
195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135,
|
| 103 |
-
131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63,
|
| 104 |
-
59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41,
|
| 105 |
-
49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180,
|
| 106 |
-
189, 197, 205, 213, 222, 230, 238, 246, 254
|
| 107 |
-
])
|
| 108 |
-
b = np.array([
|
| 109 |
-
0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143,
|
| 110 |
-
151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255,
|
| 111 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
| 112 |
-
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247,
|
| 113 |
-
243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183,
|
| 114 |
-
179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124,
|
| 115 |
-
122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90,
|
| 116 |
-
88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50,
|
| 117 |
-
48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10,
|
| 118 |
-
8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 119 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 120 |
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74,
|
| 121 |
-
82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205,
|
| 122 |
-
213, 222, 230, 238, 246, 254
|
| 123 |
-
])
|
| 124 |
-
color = (np.vstack((r, g, b)).T).astype(np.uint8)
|
| 125 |
-
spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
|
| 126 |
-
return spectral
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def make_cmap(cm=0):
|
| 130 |
-
# make a single channel colormap
|
| 131 |
-
r = np.arange(0, 256)
|
| 132 |
-
color = np.zeros((256, 3))
|
| 133 |
-
color[:, cm] = r
|
| 134 |
-
color = color.astype(np.uint8)
|
| 135 |
-
cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
|
| 136 |
-
return cmap
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def run(image=None):
|
| 140 |
-
from ..io import logger_setup
|
| 141 |
-
logger, log_file = logger_setup()
|
| 142 |
-
# Always start by initializing Qt (only once per application)
|
| 143 |
-
warnings.filterwarnings("ignore")
|
| 144 |
-
app = QApplication(sys.argv)
|
| 145 |
-
icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
|
| 146 |
-
guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
|
| 147 |
-
if not icon_path.is_file():
|
| 148 |
-
cp_dir = pathlib.Path.home().joinpath(".cellpose")
|
| 149 |
-
cp_dir.mkdir(exist_ok=True)
|
| 150 |
-
print("downloading logo")
|
| 151 |
-
download_url_to_file(
|
| 152 |
-
"https://www.cellpose.org/static/images/cellpose_transparent.png",
|
| 153 |
-
icon_path, progress=True)
|
| 154 |
-
if not guip_path.is_file():
|
| 155 |
-
print("downloading help window image")
|
| 156 |
-
download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png",
|
| 157 |
-
guip_path, progress=True)
|
| 158 |
-
icon_path = str(icon_path.resolve())
|
| 159 |
-
app_icon = QtGui.QIcon()
|
| 160 |
-
app_icon.addFile(icon_path, QtCore.QSize(16, 16))
|
| 161 |
-
app_icon.addFile(icon_path, QtCore.QSize(24, 24))
|
| 162 |
-
app_icon.addFile(icon_path, QtCore.QSize(32, 32))
|
| 163 |
-
app_icon.addFile(icon_path, QtCore.QSize(48, 48))
|
| 164 |
-
app_icon.addFile(icon_path, QtCore.QSize(64, 64))
|
| 165 |
-
app_icon.addFile(icon_path, QtCore.QSize(256, 256))
|
| 166 |
-
app.setWindowIcon(app_icon)
|
| 167 |
-
app.setStyle("Fusion")
|
| 168 |
-
app.setPalette(guiparts.DarkPalette())
|
| 169 |
-
MainW(image=image, logger=logger)
|
| 170 |
-
ret = app.exec_()
|
| 171 |
-
sys.exit(ret)
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
class MainW(QMainWindow):
|
| 175 |
-
|
| 176 |
-
def __init__(self, image=None, logger=None):
|
| 177 |
-
super(MainW, self).__init__()
|
| 178 |
-
|
| 179 |
-
self.logger = logger
|
| 180 |
-
pg.setConfigOptions(imageAxisOrder="row-major")
|
| 181 |
-
self.setGeometry(50, 50, 1200, 1000)
|
| 182 |
-
self.setWindowTitle(f"cellpose v{version}")
|
| 183 |
-
self.cp_path = os.path.dirname(os.path.realpath(__file__))
|
| 184 |
-
app_icon = QtGui.QIcon()
|
| 185 |
-
icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
|
| 186 |
-
icon_path = str(icon_path.resolve())
|
| 187 |
-
app_icon.addFile(icon_path, QtCore.QSize(16, 16))
|
| 188 |
-
app_icon.addFile(icon_path, QtCore.QSize(24, 24))
|
| 189 |
-
app_icon.addFile(icon_path, QtCore.QSize(32, 32))
|
| 190 |
-
app_icon.addFile(icon_path, QtCore.QSize(48, 48))
|
| 191 |
-
app_icon.addFile(icon_path, QtCore.QSize(64, 64))
|
| 192 |
-
app_icon.addFile(icon_path, QtCore.QSize(256, 256))
|
| 193 |
-
self.setWindowIcon(app_icon)
|
| 194 |
-
# rgb(150,255,150)
|
| 195 |
-
self.setStyleSheet(guiparts.stylesheet())
|
| 196 |
-
|
| 197 |
-
menus.mainmenu(self)
|
| 198 |
-
menus.editmenu(self)
|
| 199 |
-
menus.modelmenu(self)
|
| 200 |
-
menus.helpmenu(self)
|
| 201 |
-
|
| 202 |
-
self.stylePressed = """QPushButton {Text-align: center;
|
| 203 |
-
background-color: rgb(150,50,150);
|
| 204 |
-
border-color: white;
|
| 205 |
-
color:white;}
|
| 206 |
-
QToolTip {
|
| 207 |
-
background-color: black;
|
| 208 |
-
color: white;
|
| 209 |
-
border: black solid 1px
|
| 210 |
-
}"""
|
| 211 |
-
self.styleUnpressed = """QPushButton {Text-align: center;
|
| 212 |
-
background-color: rgb(50,50,50);
|
| 213 |
-
border-color: white;
|
| 214 |
-
color:white;}
|
| 215 |
-
QToolTip {
|
| 216 |
-
background-color: black;
|
| 217 |
-
color: white;
|
| 218 |
-
border: black solid 1px
|
| 219 |
-
}"""
|
| 220 |
-
self.loaded = False
|
| 221 |
-
|
| 222 |
-
# ---- MAIN WIDGET LAYOUT ---- #
|
| 223 |
-
self.cwidget = QWidget(self)
|
| 224 |
-
self.lmain = QGridLayout()
|
| 225 |
-
self.cwidget.setLayout(self.lmain)
|
| 226 |
-
self.setCentralWidget(self.cwidget)
|
| 227 |
-
self.lmain.setVerticalSpacing(0)
|
| 228 |
-
self.lmain.setContentsMargins(0, 0, 0, 10)
|
| 229 |
-
|
| 230 |
-
self.imask = 0
|
| 231 |
-
self.scrollarea = QScrollArea()
|
| 232 |
-
self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
|
| 233 |
-
self.scrollarea.setStyleSheet("""QScrollArea { border: none }""")
|
| 234 |
-
self.scrollarea.setWidgetResizable(True)
|
| 235 |
-
self.swidget = QWidget(self)
|
| 236 |
-
self.scrollarea.setWidget(self.swidget)
|
| 237 |
-
self.l0 = QGridLayout()
|
| 238 |
-
self.swidget.setLayout(self.l0)
|
| 239 |
-
b = self.make_buttons()
|
| 240 |
-
self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9)
|
| 241 |
-
|
| 242 |
-
# ---- drawing area ---- #
|
| 243 |
-
self.win = pg.GraphicsLayoutWidget()
|
| 244 |
-
|
| 245 |
-
self.lmain.addWidget(self.win, 0, 9, 40, 30)
|
| 246 |
-
|
| 247 |
-
self.win.scene().sigMouseClicked.connect(self.plot_clicked)
|
| 248 |
-
self.win.scene().sigMouseMoved.connect(self.mouse_moved)
|
| 249 |
-
self.make_viewbox()
|
| 250 |
-
self.lmain.setColumnStretch(10, 1)
|
| 251 |
-
bwrmap = make_bwr()
|
| 252 |
-
self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False)
|
| 253 |
-
self.cmap = []
|
| 254 |
-
# spectral colormap
|
| 255 |
-
self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0,
|
| 256 |
-
alpha=False))
|
| 257 |
-
# single channel colormaps
|
| 258 |
-
for i in range(3):
|
| 259 |
-
self.cmap.append(
|
| 260 |
-
make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False))
|
| 261 |
-
|
| 262 |
-
if MATPLOTLIB:
|
| 263 |
-
self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) *
|
| 264 |
-
255).astype(np.uint8)
|
| 265 |
-
np.random.seed(42) # make colors stable
|
| 266 |
-
self.colormap = self.colormap[np.random.permutation(1000000)]
|
| 267 |
-
else:
|
| 268 |
-
np.random.seed(42) # make colors stable
|
| 269 |
-
self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype(
|
| 270 |
-
np.uint8)
|
| 271 |
-
self.NZ = 1
|
| 272 |
-
self.restore = None
|
| 273 |
-
self.ratio = 1.
|
| 274 |
-
self.reset()
|
| 275 |
-
|
| 276 |
-
# This needs to go after .reset() is called to get state fully set up:
|
| 277 |
-
self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked)
|
| 278 |
-
|
| 279 |
-
self.load_3D = False
|
| 280 |
-
|
| 281 |
-
# if called with image, load it
|
| 282 |
-
if image is not None:
|
| 283 |
-
self.filename = image
|
| 284 |
-
io._load_image(self, self.filename)
|
| 285 |
-
|
| 286 |
-
# training settings
|
| 287 |
-
d = datetime.datetime.now()
|
| 288 |
-
self.training_params = {
|
| 289 |
-
"model_index": 0,
|
| 290 |
-
"learning_rate": 1e-5,
|
| 291 |
-
"weight_decay": 0.1,
|
| 292 |
-
"n_epochs": 100,
|
| 293 |
-
"model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"),
|
| 294 |
-
}
|
| 295 |
-
|
| 296 |
-
self.stitch_threshold = 0.
|
| 297 |
-
self.flow3D_smooth = 0.
|
| 298 |
-
self.anisotropy = 1.
|
| 299 |
-
self.min_size = 15
|
| 300 |
-
|
| 301 |
-
self.setAcceptDrops(True)
|
| 302 |
-
self.win.show()
|
| 303 |
-
self.show()
|
| 304 |
-
|
| 305 |
-
def help_window(self):
|
| 306 |
-
HW = guiparts.HelpWindow(self)
|
| 307 |
-
HW.show()
|
| 308 |
-
|
| 309 |
-
def train_help_window(self):
|
| 310 |
-
THW = guiparts.TrainHelpWindow(self)
|
| 311 |
-
THW.show()
|
| 312 |
-
|
| 313 |
-
def gui_window(self):
|
| 314 |
-
EG = guiparts.ExampleGUI(self)
|
| 315 |
-
EG.show()
|
| 316 |
-
|
| 317 |
-
def make_buttons(self):
|
| 318 |
-
self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold)
|
| 319 |
-
self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold)
|
| 320 |
-
self.medfont = QtGui.QFont("Arial", 9)
|
| 321 |
-
self.smallfont = QtGui.QFont("Arial", 8)
|
| 322 |
-
|
| 323 |
-
b = 0
|
| 324 |
-
self.satBox = QGroupBox("Views")
|
| 325 |
-
self.satBox.setFont(self.boldfont)
|
| 326 |
-
self.satBoxG = QGridLayout()
|
| 327 |
-
self.satBox.setLayout(self.satBoxG)
|
| 328 |
-
self.l0.addWidget(self.satBox, b, 0, 1, 9)
|
| 329 |
-
|
| 330 |
-
widget_row = 0
|
| 331 |
-
self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob
|
| 332 |
-
self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B
|
| 333 |
-
self.RGBDropDown = QComboBox()
|
| 334 |
-
self.RGBDropDown.addItems(
|
| 335 |
-
["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"])
|
| 336 |
-
self.RGBDropDown.setFont(self.medfont)
|
| 337 |
-
self.RGBDropDown.currentIndexChanged.connect(self.color_choose)
|
| 338 |
-
self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3)
|
| 339 |
-
|
| 340 |
-
label = QLabel("<p>[↑ / ↓ or W/S]</p>")
|
| 341 |
-
label.setFont(self.smallfont)
|
| 342 |
-
self.satBoxG.addWidget(label, widget_row, 3, 1, 3)
|
| 343 |
-
label = QLabel("[R / G / B \n toggles color ]")
|
| 344 |
-
label.setFont(self.smallfont)
|
| 345 |
-
self.satBoxG.addWidget(label, widget_row, 6, 1, 3)
|
| 346 |
-
|
| 347 |
-
widget_row += 1
|
| 348 |
-
self.ViewDropDown = QComboBox()
|
| 349 |
-
self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"])
|
| 350 |
-
self.ViewDropDown.setFont(self.medfont)
|
| 351 |
-
self.ViewDropDown.model().item(3).setEnabled(False)
|
| 352 |
-
self.ViewDropDown.currentIndexChanged.connect(self.update_plot)
|
| 353 |
-
self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3)
|
| 354 |
-
|
| 355 |
-
label = QLabel("[pageup / pagedown]")
|
| 356 |
-
label.setFont(self.smallfont)
|
| 357 |
-
self.satBoxG.addWidget(label, widget_row, 3, 1, 5)
|
| 358 |
-
|
| 359 |
-
widget_row += 2
|
| 360 |
-
label = QLabel("")
|
| 361 |
-
label.setToolTip(
|
| 362 |
-
"NOTE: manually changing the saturation bars does not affect normalization in segmentation"
|
| 363 |
-
)
|
| 364 |
-
self.satBoxG.addWidget(label, widget_row, 0, 1, 5)
|
| 365 |
-
|
| 366 |
-
self.autobtn = QCheckBox("auto-adjust saturation")
|
| 367 |
-
self.autobtn.setToolTip("sets scale-bars as normalized for segmentation")
|
| 368 |
-
self.autobtn.setFont(self.medfont)
|
| 369 |
-
self.autobtn.setChecked(True)
|
| 370 |
-
self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8)
|
| 371 |
-
|
| 372 |
-
widget_row += 1
|
| 373 |
-
self.sliders = []
|
| 374 |
-
colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]]
|
| 375 |
-
colornames = ["red", "Chartreuse", "DodgerBlue"]
|
| 376 |
-
names = ["red", "green", "blue"]
|
| 377 |
-
for r in range(3):
|
| 378 |
-
widget_row += 1
|
| 379 |
-
if r == 0:
|
| 380 |
-
label = QLabel('<font color="gray">gray/</font><br>red')
|
| 381 |
-
else:
|
| 382 |
-
label = QLabel(names[r] + ":")
|
| 383 |
-
label.setStyleSheet(f"color: {colornames[r]}")
|
| 384 |
-
label.setFont(self.boldmedfont)
|
| 385 |
-
self.satBoxG.addWidget(label, widget_row, 0, 1, 2)
|
| 386 |
-
self.sliders.append(Slider(self, names[r], colors[r]))
|
| 387 |
-
self.sliders[-1].setMinimum(-.1)
|
| 388 |
-
self.sliders[-1].setMaximum(255.1)
|
| 389 |
-
self.sliders[-1].setValue([0, 255])
|
| 390 |
-
self.sliders[-1].setToolTip(
|
| 391 |
-
"NOTE: manually changing the saturation bars does not affect normalization in segmentation"
|
| 392 |
-
)
|
| 393 |
-
self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7)
|
| 394 |
-
|
| 395 |
-
b += 1
|
| 396 |
-
self.drawBox = QGroupBox("Drawing")
|
| 397 |
-
self.drawBox.setFont(self.boldfont)
|
| 398 |
-
self.drawBoxG = QGridLayout()
|
| 399 |
-
self.drawBox.setLayout(self.drawBoxG)
|
| 400 |
-
self.l0.addWidget(self.drawBox, b, 0, 1, 9)
|
| 401 |
-
self.autosave = True
|
| 402 |
-
|
| 403 |
-
widget_row = 0
|
| 404 |
-
self.brush_size = 3
|
| 405 |
-
self.BrushChoose = QComboBox()
|
| 406 |
-
self.BrushChoose.addItems(["1", "3", "5", "7", "9"])
|
| 407 |
-
self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
|
| 408 |
-
self.BrushChoose.setFixedWidth(40)
|
| 409 |
-
self.BrushChoose.setFont(self.medfont)
|
| 410 |
-
self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2)
|
| 411 |
-
label = QLabel("brush size:")
|
| 412 |
-
label.setFont(self.medfont)
|
| 413 |
-
self.drawBoxG.addWidget(label, widget_row, 0, 1, 3)
|
| 414 |
-
|
| 415 |
-
widget_row += 1
|
| 416 |
-
# turn off masks
|
| 417 |
-
self.layer_off = False
|
| 418 |
-
self.masksOn = True
|
| 419 |
-
self.MCheckBox = QCheckBox("MASKS ON [X]")
|
| 420 |
-
self.MCheckBox.setFont(self.medfont)
|
| 421 |
-
self.MCheckBox.setChecked(True)
|
| 422 |
-
self.MCheckBox.toggled.connect(self.toggle_masks)
|
| 423 |
-
self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5)
|
| 424 |
-
|
| 425 |
-
widget_row += 1
|
| 426 |
-
# turn off outlines
|
| 427 |
-
self.outlinesOn = False # turn off by default
|
| 428 |
-
self.OCheckBox = QCheckBox("outlines on [Z]")
|
| 429 |
-
self.OCheckBox.setFont(self.medfont)
|
| 430 |
-
self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5)
|
| 431 |
-
self.OCheckBox.setChecked(False)
|
| 432 |
-
self.OCheckBox.toggled.connect(self.toggle_masks)
|
| 433 |
-
|
| 434 |
-
widget_row += 1
|
| 435 |
-
self.SCheckBox = QCheckBox("single stroke")
|
| 436 |
-
self.SCheckBox.setFont(self.medfont)
|
| 437 |
-
self.SCheckBox.setChecked(True)
|
| 438 |
-
self.SCheckBox.toggled.connect(self.autosave_on)
|
| 439 |
-
self.SCheckBox.setEnabled(True)
|
| 440 |
-
self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5)
|
| 441 |
-
|
| 442 |
-
# buttons for deleting multiple cells
|
| 443 |
-
self.deleteBox = QGroupBox("delete multiple ROIs")
|
| 444 |
-
self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)")
|
| 445 |
-
self.deleteBox.setFont(self.medfont)
|
| 446 |
-
self.deleteBoxG = QGridLayout()
|
| 447 |
-
self.deleteBox.setLayout(self.deleteBoxG)
|
| 448 |
-
self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4)
|
| 449 |
-
self.MakeDeletionRegionButton = QPushButton("region-select")
|
| 450 |
-
self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells)
|
| 451 |
-
self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4)
|
| 452 |
-
self.MakeDeletionRegionButton.setFont(self.smallfont)
|
| 453 |
-
self.MakeDeletionRegionButton.setFixedWidth(70)
|
| 454 |
-
self.DeleteMultipleROIButton = QPushButton("click-select")
|
| 455 |
-
self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells)
|
| 456 |
-
self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4)
|
| 457 |
-
self.DeleteMultipleROIButton.setFont(self.smallfont)
|
| 458 |
-
self.DeleteMultipleROIButton.setFixedWidth(70)
|
| 459 |
-
self.DoneDeleteMultipleROIButton = QPushButton("done")
|
| 460 |
-
self.DoneDeleteMultipleROIButton.clicked.connect(
|
| 461 |
-
self.done_remove_multiple_cells)
|
| 462 |
-
self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2)
|
| 463 |
-
self.DoneDeleteMultipleROIButton.setFont(self.smallfont)
|
| 464 |
-
self.DoneDeleteMultipleROIButton.setFixedWidth(35)
|
| 465 |
-
self.CancelDeleteMultipleROIButton = QPushButton("cancel")
|
| 466 |
-
self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple)
|
| 467 |
-
self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2)
|
| 468 |
-
self.CancelDeleteMultipleROIButton.setFont(self.smallfont)
|
| 469 |
-
self.CancelDeleteMultipleROIButton.setFixedWidth(35)
|
| 470 |
-
|
| 471 |
-
b += 1
|
| 472 |
-
widget_row = 0
|
| 473 |
-
self.segBox = QGroupBox("Segmentation")
|
| 474 |
-
self.segBoxG = QGridLayout()
|
| 475 |
-
self.segBox.setLayout(self.segBoxG)
|
| 476 |
-
self.l0.addWidget(self.segBox, b, 0, 1, 9)
|
| 477 |
-
self.segBox.setFont(self.boldfont)
|
| 478 |
-
|
| 479 |
-
widget_row += 1
|
| 480 |
-
|
| 481 |
-
# use GPU
|
| 482 |
-
self.useGPU = QCheckBox("use GPU")
|
| 483 |
-
self.useGPU.setToolTip(
|
| 484 |
-
"if you have specially installed the <i>cuda</i> version of torch, then you can activate this"
|
| 485 |
-
)
|
| 486 |
-
self.useGPU.setFont(self.medfont)
|
| 487 |
-
self.check_gpu()
|
| 488 |
-
self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)
|
| 489 |
-
|
| 490 |
-
# compute segmentation with general models
|
| 491 |
-
self.net_text = ["run CPSAM"]
|
| 492 |
-
nett = ["cellpose super-generalist model"]
|
| 493 |
-
|
| 494 |
-
self.StyleButtons = []
|
| 495 |
-
jj = 4
|
| 496 |
-
for j in range(len(self.net_text)):
|
| 497 |
-
self.StyleButtons.append(
|
| 498 |
-
guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
|
| 499 |
-
w = 5
|
| 500 |
-
self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w)
|
| 501 |
-
jj += w
|
| 502 |
-
self.StyleButtons[-1].setToolTip(nett[j])
|
| 503 |
-
|
| 504 |
-
widget_row += 1
|
| 505 |
-
self.ncells = guiparts.ObservableVariable(0)
|
| 506 |
-
self.roi_count = QLabel()
|
| 507 |
-
self.roi_count.setFont(self.boldfont)
|
| 508 |
-
self.roi_count.setAlignment(QtCore.Qt.AlignLeft)
|
| 509 |
-
self.ncells.valueChanged.connect(
|
| 510 |
-
lambda n: self.roi_count.setText(f'{str(n)} ROIs')
|
| 511 |
-
)
|
| 512 |
-
|
| 513 |
-
self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4)
|
| 514 |
-
|
| 515 |
-
self.progress = QProgressBar(self)
|
| 516 |
-
self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5)
|
| 517 |
-
|
| 518 |
-
widget_row += 1
|
| 519 |
-
|
| 520 |
-
############################### Segmentation settings ###############################
|
| 521 |
-
self.additional_seg_settings_qcollapsible = QCollapsible("additional settings")
|
| 522 |
-
self.additional_seg_settings_qcollapsible.setFont(self.medfont)
|
| 523 |
-
self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont)
|
| 524 |
-
self.segmentation_settings = guiparts.SegmentationSettings(self.medfont)
|
| 525 |
-
self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings)
|
| 526 |
-
self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9)
|
| 527 |
-
|
| 528 |
-
# connect edits to image processing steps:
|
| 529 |
-
self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale)
|
| 530 |
-
self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob)
|
| 531 |
-
self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob)
|
| 532 |
-
self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob)
|
| 533 |
-
|
| 534 |
-
# Needed to do this for the drop down to not be open on startup
|
| 535 |
-
self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True)
|
| 536 |
-
self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False)
|
| 537 |
-
|
| 538 |
-
b += 1
|
| 539 |
-
self.modelBox = QGroupBox("user-trained models")
|
| 540 |
-
self.modelBoxG = QGridLayout()
|
| 541 |
-
self.modelBox.setLayout(self.modelBoxG)
|
| 542 |
-
self.l0.addWidget(self.modelBox, b, 0, 1, 9)
|
| 543 |
-
self.modelBox.setFont(self.boldfont)
|
| 544 |
-
# choose models
|
| 545 |
-
self.ModelChooseC = QComboBox()
|
| 546 |
-
self.ModelChooseC.setFont(self.medfont)
|
| 547 |
-
current_index = 0
|
| 548 |
-
self.ModelChooseC.addItems(["custom models"])
|
| 549 |
-
if len(self.model_strings) > 0:
|
| 550 |
-
self.ModelChooseC.addItems(self.model_strings)
|
| 551 |
-
self.ModelChooseC.setFixedWidth(175)
|
| 552 |
-
self.ModelChooseC.setCurrentIndex(current_index)
|
| 553 |
-
tipstr = 'add or train your own models in the "Models" file menu and choose model here'
|
| 554 |
-
self.ModelChooseC.setToolTip(tipstr)
|
| 555 |
-
self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True))
|
| 556 |
-
self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8)
|
| 557 |
-
|
| 558 |
-
# compute segmentation w/ custom model
|
| 559 |
-
self.ModelButtonC = QPushButton(u"run")
|
| 560 |
-
self.ModelButtonC.setFont(self.medfont)
|
| 561 |
-
self.ModelButtonC.setFixedWidth(35)
|
| 562 |
-
self.ModelButtonC.clicked.connect(
|
| 563 |
-
lambda: self.compute_segmentation(custom=True))
|
| 564 |
-
self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1)
|
| 565 |
-
self.ModelButtonC.setEnabled(False)
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
b += 1
|
| 569 |
-
self.filterBox = QGroupBox("Image filtering")
|
| 570 |
-
self.filterBox.setFont(self.boldfont)
|
| 571 |
-
self.filterBox_grid_layout = QGridLayout()
|
| 572 |
-
self.filterBox.setLayout(self.filterBox_grid_layout)
|
| 573 |
-
self.l0.addWidget(self.filterBox, b, 0, 1, 9)
|
| 574 |
-
|
| 575 |
-
widget_row = 0
|
| 576 |
-
|
| 577 |
-
# Filtering
|
| 578 |
-
self.FilterButtons = []
|
| 579 |
-
nett = [
|
| 580 |
-
"clear restore/filter",
|
| 581 |
-
"filter image (settings below)",
|
| 582 |
-
]
|
| 583 |
-
self.filter_text = ["none",
|
| 584 |
-
"filter",
|
| 585 |
-
]
|
| 586 |
-
self.restore = None
|
| 587 |
-
self.ratio = 1.
|
| 588 |
-
jj = 0
|
| 589 |
-
w = 3
|
| 590 |
-
for j in range(len(self.filter_text)):
|
| 591 |
-
self.FilterButtons.append(
|
| 592 |
-
guiparts.FilterButton(self, self.filter_text[j]))
|
| 593 |
-
self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w)
|
| 594 |
-
self.FilterButtons[-1].setFixedWidth(75)
|
| 595 |
-
self.FilterButtons[-1].setToolTip(nett[j])
|
| 596 |
-
self.FilterButtons[-1].setFont(self.medfont)
|
| 597 |
-
widget_row += 1 if j%2==1 else 0
|
| 598 |
-
jj = 0 if j%2==1 else jj + w
|
| 599 |
-
|
| 600 |
-
self.save_norm = QCheckBox("save restored/filtered image")
|
| 601 |
-
self.save_norm.setFont(self.medfont)
|
| 602 |
-
self.save_norm.setToolTip("save restored/filtered image in _seg.npy file")
|
| 603 |
-
self.save_norm.setChecked(True)
|
| 604 |
-
|
| 605 |
-
widget_row += 2
|
| 606 |
-
|
| 607 |
-
self.filtBox = QCollapsible("custom filter settings")
|
| 608 |
-
self.filtBox._toggle_btn.setFont(self.medfont)
|
| 609 |
-
self.filtBoxG = QGridLayout()
|
| 610 |
-
_content = QWidget()
|
| 611 |
-
_content.setLayout(self.filtBoxG)
|
| 612 |
-
_content.setMaximumHeight(0)
|
| 613 |
-
_content.setMinimumHeight(0)
|
| 614 |
-
self.filtBox.setContent(_content)
|
| 615 |
-
self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9)
|
| 616 |
-
|
| 617 |
-
self.filt_vals = [0., 0., 0., 0.]
|
| 618 |
-
self.filt_edits = []
|
| 619 |
-
labels = [
|
| 620 |
-
"sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize",
|
| 621 |
-
"tile_norm\nsmooth3D"
|
| 622 |
-
]
|
| 623 |
-
tooltips = [
|
| 624 |
-
"set size of surround-subtraction filter for sharpening image",
|
| 625 |
-
"set size of gaussian filter for smoothing image",
|
| 626 |
-
"set size of tiles to use to normalize image",
|
| 627 |
-
"set amount of smoothing of normalization values across planes"
|
| 628 |
-
]
|
| 629 |
-
|
| 630 |
-
for p in range(4):
|
| 631 |
-
label = QLabel(f"{labels[p]}:")
|
| 632 |
-
label.setToolTip(tooltips[p])
|
| 633 |
-
label.setFont(self.medfont)
|
| 634 |
-
self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2)
|
| 635 |
-
self.filt_edits.append(QLineEdit())
|
| 636 |
-
self.filt_edits[p].setText(str(self.filt_vals[p]))
|
| 637 |
-
self.filt_edits[p].setFixedWidth(40)
|
| 638 |
-
self.filt_edits[p].setFont(self.medfont)
|
| 639 |
-
self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2, 4 * (p % 2) + 2, 1,
|
| 640 |
-
2)
|
| 641 |
-
self.filt_edits[p].setToolTip(tooltips[p])
|
| 642 |
-
|
| 643 |
-
widget_row += 3
|
| 644 |
-
self.norm3D_cb = QCheckBox("norm3D")
|
| 645 |
-
self.norm3D_cb.setFont(self.medfont)
|
| 646 |
-
self.norm3D_cb.setChecked(True)
|
| 647 |
-
self.norm3D_cb.setToolTip("run same normalization across planes")
|
| 648 |
-
self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3)
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
return b
|
| 652 |
-
|
| 653 |
-
def level_change(self, r):
|
| 654 |
-
r = ["red", "green", "blue"].index(r)
|
| 655 |
-
if self.loaded:
|
| 656 |
-
sval = self.sliders[r].value()
|
| 657 |
-
self.saturation[r][self.currentZ] = sval
|
| 658 |
-
if not self.autobtn.isChecked():
|
| 659 |
-
for r in range(3):
|
| 660 |
-
for i in range(len(self.saturation[r])):
|
| 661 |
-
self.saturation[r][i] = self.saturation[r][self.currentZ]
|
| 662 |
-
self.update_plot()
|
| 663 |
-
|
| 664 |
-
def keyPressEvent(self, event):
|
| 665 |
-
if self.loaded:
|
| 666 |
-
if not (event.modifiers() &
|
| 667 |
-
(QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
|
| 668 |
-
QtCore.Qt.AltModifier) or self.in_stroke):
|
| 669 |
-
updated = False
|
| 670 |
-
if len(self.current_point_set) > 0:
|
| 671 |
-
if event.key() == QtCore.Qt.Key_Return:
|
| 672 |
-
self.add_set()
|
| 673 |
-
else:
|
| 674 |
-
nviews = self.ViewDropDown.count() - 1
|
| 675 |
-
nviews += int(
|
| 676 |
-
self.ViewDropDown.model().item(self.ViewDropDown.count() -
|
| 677 |
-
1).isEnabled())
|
| 678 |
-
if event.key() == QtCore.Qt.Key_X:
|
| 679 |
-
self.MCheckBox.toggle()
|
| 680 |
-
if event.key() == QtCore.Qt.Key_Z:
|
| 681 |
-
self.OCheckBox.toggle()
|
| 682 |
-
if event.key() == QtCore.Qt.Key_Left or event.key(
|
| 683 |
-
) == QtCore.Qt.Key_A:
|
| 684 |
-
self.get_prev_image()
|
| 685 |
-
elif event.key() == QtCore.Qt.Key_Right or event.key(
|
| 686 |
-
) == QtCore.Qt.Key_D:
|
| 687 |
-
self.get_next_image()
|
| 688 |
-
elif event.key() == QtCore.Qt.Key_PageDown:
|
| 689 |
-
self.view = (self.view + 1) % (nviews)
|
| 690 |
-
self.ViewDropDown.setCurrentIndex(self.view)
|
| 691 |
-
elif event.key() == QtCore.Qt.Key_PageUp:
|
| 692 |
-
self.view = (self.view - 1) % (nviews)
|
| 693 |
-
self.ViewDropDown.setCurrentIndex(self.view)
|
| 694 |
-
|
| 695 |
-
# can change background or stroke size if cell not finished
|
| 696 |
-
if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
|
| 697 |
-
self.color = (self.color - 1) % (6)
|
| 698 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 699 |
-
elif event.key() == QtCore.Qt.Key_Down or event.key(
|
| 700 |
-
) == QtCore.Qt.Key_S:
|
| 701 |
-
self.color = (self.color + 1) % (6)
|
| 702 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 703 |
-
elif event.key() == QtCore.Qt.Key_R:
|
| 704 |
-
if self.color != 1:
|
| 705 |
-
self.color = 1
|
| 706 |
-
else:
|
| 707 |
-
self.color = 0
|
| 708 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 709 |
-
elif event.key() == QtCore.Qt.Key_G:
|
| 710 |
-
if self.color != 2:
|
| 711 |
-
self.color = 2
|
| 712 |
-
else:
|
| 713 |
-
self.color = 0
|
| 714 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 715 |
-
elif event.key() == QtCore.Qt.Key_B:
|
| 716 |
-
if self.color != 3:
|
| 717 |
-
self.color = 3
|
| 718 |
-
else:
|
| 719 |
-
self.color = 0
|
| 720 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 721 |
-
elif (event.key() == QtCore.Qt.Key_Comma or
|
| 722 |
-
event.key() == QtCore.Qt.Key_Period):
|
| 723 |
-
count = self.BrushChoose.count()
|
| 724 |
-
gci = self.BrushChoose.currentIndex()
|
| 725 |
-
if event.key() == QtCore.Qt.Key_Comma:
|
| 726 |
-
gci = max(0, gci - 1)
|
| 727 |
-
else:
|
| 728 |
-
gci = min(count - 1, gci + 1)
|
| 729 |
-
self.BrushChoose.setCurrentIndex(gci)
|
| 730 |
-
self.brush_choose()
|
| 731 |
-
if not updated:
|
| 732 |
-
self.update_plot()
|
| 733 |
-
if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
|
| 734 |
-
self.p0.keyPressEvent(event)
|
| 735 |
-
|
| 736 |
-
def autosave_on(self):
|
| 737 |
-
if self.SCheckBox.isChecked():
|
| 738 |
-
self.autosave = True
|
| 739 |
-
else:
|
| 740 |
-
self.autosave = False
|
| 741 |
-
|
| 742 |
-
def check_gpu(self, torch=True):
|
| 743 |
-
# also decide whether or not to use torch
|
| 744 |
-
self.useGPU.setChecked(False)
|
| 745 |
-
self.useGPU.setEnabled(False)
|
| 746 |
-
if core.use_gpu(use_torch=True):
|
| 747 |
-
self.useGPU.setEnabled(True)
|
| 748 |
-
self.useGPU.setChecked(True)
|
| 749 |
-
else:
|
| 750 |
-
self.useGPU.setStyleSheet("color: rgb(80,80,80);")
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
def model_choose(self, custom=False):
|
| 754 |
-
index = self.ModelChooseC.currentIndex(
|
| 755 |
-
) if custom else self.ModelChooseB.currentIndex()
|
| 756 |
-
if index > 0:
|
| 757 |
-
if custom:
|
| 758 |
-
model_name = self.ModelChooseC.currentText()
|
| 759 |
-
else:
|
| 760 |
-
model_name = self.net_names[index - 1]
|
| 761 |
-
print(f"GUI_INFO: selected model {model_name}, loading now")
|
| 762 |
-
self.initialize_model(model_name=model_name, custom=custom)
|
| 763 |
-
|
| 764 |
-
def toggle_scale(self):
|
| 765 |
-
if self.scale_on:
|
| 766 |
-
self.p0.removeItem(self.scale)
|
| 767 |
-
self.scale_on = False
|
| 768 |
-
else:
|
| 769 |
-
self.p0.addItem(self.scale)
|
| 770 |
-
self.scale_on = True
|
| 771 |
-
|
| 772 |
-
def enable_buttons(self):
|
| 773 |
-
if len(self.model_strings) > 0:
|
| 774 |
-
self.ModelButtonC.setEnabled(True)
|
| 775 |
-
for i in range(len(self.StyleButtons)):
|
| 776 |
-
self.StyleButtons[i].setEnabled(True)
|
| 777 |
-
|
| 778 |
-
for i in range(len(self.FilterButtons)):
|
| 779 |
-
self.FilterButtons[i].setEnabled(True)
|
| 780 |
-
if self.load_3D:
|
| 781 |
-
self.FilterButtons[-2].setEnabled(False)
|
| 782 |
-
|
| 783 |
-
self.newmodel.setEnabled(True)
|
| 784 |
-
self.loadMasks.setEnabled(True)
|
| 785 |
-
|
| 786 |
-
for n in range(self.nchan):
|
| 787 |
-
self.sliders[n].setEnabled(True)
|
| 788 |
-
for n in range(self.nchan, 3):
|
| 789 |
-
self.sliders[n].setEnabled(True)
|
| 790 |
-
|
| 791 |
-
self.toggle_mask_ops()
|
| 792 |
-
|
| 793 |
-
self.update_plot()
|
| 794 |
-
self.setWindowTitle(self.filename)
|
| 795 |
-
|
| 796 |
-
def disable_buttons_removeROIs(self):
|
| 797 |
-
if len(self.model_strings) > 0:
|
| 798 |
-
self.ModelButtonC.setEnabled(False)
|
| 799 |
-
for i in range(len(self.StyleButtons)):
|
| 800 |
-
self.StyleButtons[i].setEnabled(False)
|
| 801 |
-
self.newmodel.setEnabled(False)
|
| 802 |
-
self.loadMasks.setEnabled(False)
|
| 803 |
-
self.saveSet.setEnabled(False)
|
| 804 |
-
self.savePNG.setEnabled(False)
|
| 805 |
-
self.saveFlows.setEnabled(False)
|
| 806 |
-
self.saveOutlines.setEnabled(False)
|
| 807 |
-
self.saveROIs.setEnabled(False)
|
| 808 |
-
|
| 809 |
-
self.MakeDeletionRegionButton.setEnabled(False)
|
| 810 |
-
self.DeleteMultipleROIButton.setEnabled(False)
|
| 811 |
-
self.DoneDeleteMultipleROIButton.setEnabled(True)
|
| 812 |
-
self.CancelDeleteMultipleROIButton.setEnabled(True)
|
| 813 |
-
|
| 814 |
-
def toggle_mask_ops(self):
|
| 815 |
-
self.update_layer()
|
| 816 |
-
self.toggle_saving()
|
| 817 |
-
self.toggle_removals()
|
| 818 |
-
|
| 819 |
-
def toggle_saving(self):
|
| 820 |
-
if self.ncells > 0:
|
| 821 |
-
self.saveSet.setEnabled(True)
|
| 822 |
-
self.savePNG.setEnabled(True)
|
| 823 |
-
self.saveFlows.setEnabled(True)
|
| 824 |
-
self.saveOutlines.setEnabled(True)
|
| 825 |
-
self.saveROIs.setEnabled(True)
|
| 826 |
-
else:
|
| 827 |
-
self.saveSet.setEnabled(False)
|
| 828 |
-
self.savePNG.setEnabled(False)
|
| 829 |
-
self.saveFlows.setEnabled(False)
|
| 830 |
-
self.saveOutlines.setEnabled(False)
|
| 831 |
-
self.saveROIs.setEnabled(False)
|
| 832 |
-
|
| 833 |
-
def toggle_removals(self):
|
| 834 |
-
if self.ncells > 0:
|
| 835 |
-
self.ClearButton.setEnabled(True)
|
| 836 |
-
self.remcell.setEnabled(True)
|
| 837 |
-
self.undo.setEnabled(True)
|
| 838 |
-
self.MakeDeletionRegionButton.setEnabled(True)
|
| 839 |
-
self.DeleteMultipleROIButton.setEnabled(True)
|
| 840 |
-
self.DoneDeleteMultipleROIButton.setEnabled(False)
|
| 841 |
-
self.CancelDeleteMultipleROIButton.setEnabled(False)
|
| 842 |
-
else:
|
| 843 |
-
self.ClearButton.setEnabled(False)
|
| 844 |
-
self.remcell.setEnabled(False)
|
| 845 |
-
self.undo.setEnabled(False)
|
| 846 |
-
self.MakeDeletionRegionButton.setEnabled(False)
|
| 847 |
-
self.DeleteMultipleROIButton.setEnabled(False)
|
| 848 |
-
self.DoneDeleteMultipleROIButton.setEnabled(False)
|
| 849 |
-
self.CancelDeleteMultipleROIButton.setEnabled(False)
|
| 850 |
-
|
| 851 |
-
def remove_action(self):
|
| 852 |
-
if self.selected > 0:
|
| 853 |
-
self.remove_cell(self.selected)
|
| 854 |
-
|
| 855 |
-
def undo_action(self):
|
| 856 |
-
if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ):
|
| 857 |
-
self.remove_stroke()
|
| 858 |
-
else:
|
| 859 |
-
# remove previous cell
|
| 860 |
-
if self.ncells > 0:
|
| 861 |
-
self.remove_cell(self.ncells.get())
|
| 862 |
-
|
| 863 |
-
def undo_remove_action(self):
|
| 864 |
-
self.undo_remove_cell()
|
| 865 |
-
|
| 866 |
-
def get_files(self):
|
| 867 |
-
folder = os.path.dirname(self.filename)
|
| 868 |
-
mask_filter = "_masks"
|
| 869 |
-
images = get_image_files(folder, mask_filter)
|
| 870 |
-
fnames = [os.path.split(images[k])[-1] for k in range(len(images))]
|
| 871 |
-
f0 = os.path.split(self.filename)[-1]
|
| 872 |
-
idx = np.nonzero(np.array(fnames) == f0)[0][0]
|
| 873 |
-
return images, idx
|
| 874 |
-
|
| 875 |
-
def get_prev_image(self):
|
| 876 |
-
images, idx = self.get_files()
|
| 877 |
-
idx = (idx - 1) % len(images)
|
| 878 |
-
io._load_image(self, filename=images[idx])
|
| 879 |
-
|
| 880 |
-
def get_next_image(self, load_seg=True):
|
| 881 |
-
images, idx = self.get_files()
|
| 882 |
-
idx = (idx + 1) % len(images)
|
| 883 |
-
io._load_image(self, filename=images[idx], load_seg=load_seg)
|
| 884 |
-
|
| 885 |
-
def dragEnterEvent(self, event):
|
| 886 |
-
if event.mimeData().hasUrls():
|
| 887 |
-
event.accept()
|
| 888 |
-
else:
|
| 889 |
-
event.ignore()
|
| 890 |
-
|
| 891 |
-
def dropEvent(self, event):
|
| 892 |
-
files = [u.toLocalFile() for u in event.mimeData().urls()]
|
| 893 |
-
if os.path.splitext(files[0])[-1] == ".npy":
|
| 894 |
-
io._load_seg(self, filename=files[0], load_3D=self.load_3D)
|
| 895 |
-
else:
|
| 896 |
-
io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D)
|
| 897 |
-
|
| 898 |
-
def toggle_masks(self):
|
| 899 |
-
if self.MCheckBox.isChecked():
|
| 900 |
-
self.masksOn = True
|
| 901 |
-
else:
|
| 902 |
-
self.masksOn = False
|
| 903 |
-
if self.OCheckBox.isChecked():
|
| 904 |
-
self.outlinesOn = True
|
| 905 |
-
else:
|
| 906 |
-
self.outlinesOn = False
|
| 907 |
-
if not self.masksOn and not self.outlinesOn:
|
| 908 |
-
self.p0.removeItem(self.layer)
|
| 909 |
-
self.layer_off = True
|
| 910 |
-
else:
|
| 911 |
-
if self.layer_off:
|
| 912 |
-
self.p0.addItem(self.layer)
|
| 913 |
-
self.draw_layer()
|
| 914 |
-
self.update_layer()
|
| 915 |
-
if self.loaded:
|
| 916 |
-
self.update_plot()
|
| 917 |
-
self.update_layer()
|
| 918 |
-
|
| 919 |
-
def make_viewbox(self):
|
| 920 |
-
self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True,
|
| 921 |
-
name="plot1", border=[100, 100,
|
| 922 |
-
100], invertY=True)
|
| 923 |
-
self.p0.setCursor(QtCore.Qt.CrossCursor)
|
| 924 |
-
self.brush_size = 3
|
| 925 |
-
self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1)
|
| 926 |
-
self.p0.setMenuEnabled(False)
|
| 927 |
-
self.p0.setMouseEnabled(x=True, y=True)
|
| 928 |
-
self.img = pg.ImageItem(viewbox=self.p0, parent=self)
|
| 929 |
-
self.img.autoDownsample = False
|
| 930 |
-
self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self)
|
| 931 |
-
self.layer.setLevels([0, 255])
|
| 932 |
-
self.scale = pg.ImageItem(viewbox=self.p0, parent=self)
|
| 933 |
-
self.scale.setLevels([0, 255])
|
| 934 |
-
self.p0.scene().contextMenuItem = self.p0
|
| 935 |
-
self.Ly, self.Lx = 512, 512
|
| 936 |
-
self.p0.addItem(self.img)
|
| 937 |
-
self.p0.addItem(self.layer)
|
| 938 |
-
self.p0.addItem(self.scale)
|
| 939 |
-
|
| 940 |
-
def reset(self):
|
| 941 |
-
# ---- start sets of points ---- #
|
| 942 |
-
self.selected = 0
|
| 943 |
-
self.nchan = 3
|
| 944 |
-
self.loaded = False
|
| 945 |
-
self.channel = [0, 1]
|
| 946 |
-
self.current_point_set = []
|
| 947 |
-
self.in_stroke = False
|
| 948 |
-
self.strokes = []
|
| 949 |
-
self.stroke_appended = True
|
| 950 |
-
self.resize = False
|
| 951 |
-
self.ncells.reset()
|
| 952 |
-
self.zdraw = []
|
| 953 |
-
self.removed_cell = []
|
| 954 |
-
self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
|
| 955 |
-
|
| 956 |
-
# -- zero out image stack -- #
|
| 957 |
-
self.opacity = 128 # how opaque masks should be
|
| 958 |
-
self.outcolor = [200, 200, 255, 200]
|
| 959 |
-
self.NZ, self.Ly, self.Lx = 1, 256, 256
|
| 960 |
-
self.saturation = self.saturation if hasattr(self, 'saturation') else []
|
| 961 |
-
|
| 962 |
-
# only adjust the saturation if auto-adjust is on:
|
| 963 |
-
if self.autobtn.isChecked():
|
| 964 |
-
for r in range(3):
|
| 965 |
-
self.saturation.append([[0, 255] for n in range(self.NZ)])
|
| 966 |
-
self.sliders[r].setValue([0, 255])
|
| 967 |
-
self.sliders[r].setEnabled(False)
|
| 968 |
-
self.sliders[r].show()
|
| 969 |
-
self.currentZ = 0
|
| 970 |
-
self.flows = [[], [], [], [], [[]]]
|
| 971 |
-
# masks matrix
|
| 972 |
-
# image matrix with a scale disk
|
| 973 |
-
self.stack = np.zeros((1, self.Ly, self.Lx, 3))
|
| 974 |
-
self.Lyr, self.Lxr = self.Ly, self.Lx
|
| 975 |
-
self.Ly0, self.Lx0 = self.Ly, self.Lx
|
| 976 |
-
self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
|
| 977 |
-
self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
|
| 978 |
-
self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
|
| 979 |
-
self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
|
| 980 |
-
self.ismanual = np.zeros(0, "bool")
|
| 981 |
-
|
| 982 |
-
# -- set menus to default -- #
|
| 983 |
-
self.color = 0
|
| 984 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 985 |
-
self.view = 0
|
| 986 |
-
self.ViewDropDown.setCurrentIndex(0)
|
| 987 |
-
self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
|
| 988 |
-
self.delete_restore()
|
| 989 |
-
|
| 990 |
-
self.clear_all()
|
| 991 |
-
|
| 992 |
-
self.filename = []
|
| 993 |
-
self.loaded = False
|
| 994 |
-
self.recompute_masks = False
|
| 995 |
-
|
| 996 |
-
self.deleting_multiple = False
|
| 997 |
-
self.removing_cells_list = []
|
| 998 |
-
self.removing_region = False
|
| 999 |
-
self.remove_roi_obj = None
|
| 1000 |
-
|
| 1001 |
-
def delete_restore(self):
|
| 1002 |
-
""" delete restored imgs but don't reset settings """
|
| 1003 |
-
if hasattr(self, "stack_filtered"):
|
| 1004 |
-
del self.stack_filtered
|
| 1005 |
-
if hasattr(self, "cellpix_orig"):
|
| 1006 |
-
self.cellpix = self.cellpix_orig.copy()
|
| 1007 |
-
self.outpix = self.outpix_orig.copy()
|
| 1008 |
-
del self.outpix_orig, self.outpix_resize
|
| 1009 |
-
del self.cellpix_orig, self.cellpix_resize
|
| 1010 |
-
|
| 1011 |
-
def clear_restore(self):
|
| 1012 |
-
""" delete restored imgs and reset settings """
|
| 1013 |
-
print("GUI_INFO: clearing restored image")
|
| 1014 |
-
self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
|
| 1015 |
-
if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1:
|
| 1016 |
-
self.ViewDropDown.setCurrentIndex(0)
|
| 1017 |
-
self.delete_restore()
|
| 1018 |
-
self.restore = None
|
| 1019 |
-
self.ratio = 1.
|
| 1020 |
-
self.set_normalize_params(self.get_normalize_params())
|
| 1021 |
-
|
| 1022 |
-
def brush_choose(self):
|
| 1023 |
-
self.brush_size = self.BrushChoose.currentIndex() * 2 + 1
|
| 1024 |
-
if self.loaded:
|
| 1025 |
-
self.layer.setDrawKernel(kernel_size=self.brush_size)
|
| 1026 |
-
self.update_layer()
|
| 1027 |
-
|
| 1028 |
-
def clear_all(self):
|
| 1029 |
-
self.prev_selected = 0
|
| 1030 |
-
self.selected = 0
|
| 1031 |
-
if self.restore and "upsample" in self.restore:
|
| 1032 |
-
self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8)
|
| 1033 |
-
self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
|
| 1034 |
-
self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
|
| 1035 |
-
self.cellpix_resize = self.cellpix.copy()
|
| 1036 |
-
self.outpix_resize = self.outpix.copy()
|
| 1037 |
-
self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
|
| 1038 |
-
self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
|
| 1039 |
-
else:
|
| 1040 |
-
self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
|
| 1041 |
-
self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
|
| 1042 |
-
self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
|
| 1043 |
-
|
| 1044 |
-
self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
|
| 1045 |
-
self.ncells.reset()
|
| 1046 |
-
self.toggle_removals()
|
| 1047 |
-
self.update_scale()
|
| 1048 |
-
self.update_layer()
|
| 1049 |
-
|
| 1050 |
-
def select_cell(self, idx):
|
| 1051 |
-
self.prev_selected = self.selected
|
| 1052 |
-
self.selected = idx
|
| 1053 |
-
if self.selected > 0:
|
| 1054 |
-
z = self.currentZ
|
| 1055 |
-
self.layerz[self.cellpix[z] == idx] = np.array(
|
| 1056 |
-
[255, 255, 255, self.opacity])
|
| 1057 |
-
self.update_layer()
|
| 1058 |
-
|
| 1059 |
-
def select_cell_multi(self, idx):
|
| 1060 |
-
if idx > 0:
|
| 1061 |
-
z = self.currentZ
|
| 1062 |
-
self.layerz[self.cellpix[z] == idx] = np.array(
|
| 1063 |
-
[255, 255, 255, self.opacity])
|
| 1064 |
-
self.update_layer()
|
| 1065 |
-
|
| 1066 |
-
def unselect_cell(self):
|
| 1067 |
-
if self.selected > 0:
|
| 1068 |
-
idx = self.selected
|
| 1069 |
-
if idx < (self.ncells.get() + 1):
|
| 1070 |
-
z = self.currentZ
|
| 1071 |
-
self.layerz[self.cellpix[z] == idx] = np.append(
|
| 1072 |
-
self.cellcolors[idx], self.opacity)
|
| 1073 |
-
if self.outlinesOn:
|
| 1074 |
-
self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
|
| 1075 |
-
np.uint8)
|
| 1076 |
-
#[0,0,0,self.opacity])
|
| 1077 |
-
self.update_layer()
|
| 1078 |
-
self.selected = 0
|
| 1079 |
-
|
| 1080 |
-
def unselect_cell_multi(self, idx):
|
| 1081 |
-
z = self.currentZ
|
| 1082 |
-
self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx],
|
| 1083 |
-
self.opacity)
|
| 1084 |
-
if self.outlinesOn:
|
| 1085 |
-
self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
|
| 1086 |
-
np.uint8)
|
| 1087 |
-
# [0,0,0,self.opacity])
|
| 1088 |
-
self.update_layer()
|
| 1089 |
-
|
| 1090 |
-
def remove_cell(self, idx):
|
| 1091 |
-
if isinstance(idx, (int, np.integer)):
|
| 1092 |
-
idx = [idx]
|
| 1093 |
-
# because the function remove_single_cell updates the state of the cellpix and outpix arrays
|
| 1094 |
-
# by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order
|
| 1095 |
-
# so that the indices are correct
|
| 1096 |
-
idx.sort(reverse=True)
|
| 1097 |
-
for i in idx:
|
| 1098 |
-
self.remove_single_cell(i)
|
| 1099 |
-
self.ncells -= len(idx) # _save_sets uses ncells
|
| 1100 |
-
self.update_layer()
|
| 1101 |
-
|
| 1102 |
-
if self.ncells == 0:
|
| 1103 |
-
self.ClearButton.setEnabled(False)
|
| 1104 |
-
if self.NZ == 1:
|
| 1105 |
-
io._save_sets_with_check(self)
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
def remove_single_cell(self, idx):
|
| 1109 |
-
# remove from manual array
|
| 1110 |
-
self.selected = 0
|
| 1111 |
-
if self.NZ > 1:
|
| 1112 |
-
zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0]
|
| 1113 |
-
else:
|
| 1114 |
-
zextent = [0]
|
| 1115 |
-
for z in zextent:
|
| 1116 |
-
cp = self.cellpix[z] == idx
|
| 1117 |
-
op = self.outpix[z] == idx
|
| 1118 |
-
# remove from self.cellpix and self.outpix
|
| 1119 |
-
self.cellpix[z, cp] = 0
|
| 1120 |
-
self.outpix[z, op] = 0
|
| 1121 |
-
if z == self.currentZ:
|
| 1122 |
-
# remove from mask layer
|
| 1123 |
-
self.layerz[cp] = np.array([0, 0, 0, 0])
|
| 1124 |
-
|
| 1125 |
-
# reduce other pixels by -1
|
| 1126 |
-
self.cellpix[self.cellpix > idx] -= 1
|
| 1127 |
-
self.outpix[self.outpix > idx] -= 1
|
| 1128 |
-
|
| 1129 |
-
if self.NZ == 1:
|
| 1130 |
-
self.removed_cell = [
|
| 1131 |
-
self.ismanual[idx - 1], self.cellcolors[idx],
|
| 1132 |
-
np.nonzero(cp),
|
| 1133 |
-
np.nonzero(op)
|
| 1134 |
-
]
|
| 1135 |
-
self.redo.setEnabled(True)
|
| 1136 |
-
ar, ac = self.removed_cell[2]
|
| 1137 |
-
d = datetime.datetime.now()
|
| 1138 |
-
self.track_changes.append(
|
| 1139 |
-
[d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
|
| 1140 |
-
# remove cell from lists
|
| 1141 |
-
self.ismanual = np.delete(self.ismanual, idx - 1)
|
| 1142 |
-
self.cellcolors = np.delete(self.cellcolors, [idx], axis=0)
|
| 1143 |
-
del self.zdraw[idx - 1]
|
| 1144 |
-
print("GUI_INFO: removed cell %d" % (idx - 1))
|
| 1145 |
-
|
| 1146 |
-
def remove_region_cells(self):
|
| 1147 |
-
if self.removing_cells_list:
|
| 1148 |
-
for idx in self.removing_cells_list:
|
| 1149 |
-
self.unselect_cell_multi(idx)
|
| 1150 |
-
self.removing_cells_list.clear()
|
| 1151 |
-
self.disable_buttons_removeROIs()
|
| 1152 |
-
self.removing_region = True
|
| 1153 |
-
|
| 1154 |
-
self.clear_multi_selected_cells()
|
| 1155 |
-
|
| 1156 |
-
# make roi region here in center of view, making ROI half the size of the view
|
| 1157 |
-
roi_width = self.p0.viewRect().width() / 2
|
| 1158 |
-
x_loc = self.p0.viewRect().x() + (roi_width / 2)
|
| 1159 |
-
roi_height = self.p0.viewRect().height() / 2
|
| 1160 |
-
y_loc = self.p0.viewRect().y() + (roi_height / 2)
|
| 1161 |
-
|
| 1162 |
-
pos = [x_loc, y_loc]
|
| 1163 |
-
roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2),
|
| 1164 |
-
removable=True)
|
| 1165 |
-
roi.sigRemoveRequested.connect(self.remove_roi)
|
| 1166 |
-
roi.sigRegionChangeFinished.connect(self.roi_changed)
|
| 1167 |
-
self.p0.addItem(roi)
|
| 1168 |
-
self.remove_roi_obj = roi
|
| 1169 |
-
self.roi_changed(roi)
|
| 1170 |
-
|
| 1171 |
-
def delete_multiple_cells(self):
|
| 1172 |
-
self.unselect_cell()
|
| 1173 |
-
self.disable_buttons_removeROIs()
|
| 1174 |
-
self.DoneDeleteMultipleROIButton.setEnabled(True)
|
| 1175 |
-
self.MakeDeletionRegionButton.setEnabled(True)
|
| 1176 |
-
self.CancelDeleteMultipleROIButton.setEnabled(True)
|
| 1177 |
-
self.deleting_multiple = True
|
| 1178 |
-
|
| 1179 |
-
def done_remove_multiple_cells(self):
|
| 1180 |
-
self.deleting_multiple = False
|
| 1181 |
-
self.removing_region = False
|
| 1182 |
-
self.DoneDeleteMultipleROIButton.setEnabled(False)
|
| 1183 |
-
self.MakeDeletionRegionButton.setEnabled(False)
|
| 1184 |
-
self.CancelDeleteMultipleROIButton.setEnabled(False)
|
| 1185 |
-
|
| 1186 |
-
if self.removing_cells_list:
|
| 1187 |
-
self.removing_cells_list = list(set(self.removing_cells_list))
|
| 1188 |
-
display_remove_list = [i - 1 for i in self.removing_cells_list]
|
| 1189 |
-
print(f"GUI_INFO: removing cells: {display_remove_list}")
|
| 1190 |
-
self.remove_cell(self.removing_cells_list)
|
| 1191 |
-
self.removing_cells_list.clear()
|
| 1192 |
-
self.unselect_cell()
|
| 1193 |
-
self.enable_buttons()
|
| 1194 |
-
|
| 1195 |
-
if self.remove_roi_obj is not None:
|
| 1196 |
-
self.remove_roi(self.remove_roi_obj)
|
| 1197 |
-
|
| 1198 |
-
def merge_cells(self, idx):
|
| 1199 |
-
self.prev_selected = self.selected
|
| 1200 |
-
self.selected = idx
|
| 1201 |
-
if self.selected != self.prev_selected:
|
| 1202 |
-
for z in range(self.NZ):
|
| 1203 |
-
ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected)
|
| 1204 |
-
ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected)
|
| 1205 |
-
touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3,
|
| 1206 |
-
(ac0[:, np.newaxis] - ac1) < 3).sum()
|
| 1207 |
-
ar = np.hstack((ar0, ar1))
|
| 1208 |
-
ac = np.hstack((ac0, ac1))
|
| 1209 |
-
vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected)
|
| 1210 |
-
vr1, vc1 = np.nonzero(self.outpix[z] == self.selected)
|
| 1211 |
-
self.outpix[z, vr0, vc0] = 0
|
| 1212 |
-
self.outpix[z, vr1, vc1] = 0
|
| 1213 |
-
if touching > 0:
|
| 1214 |
-
mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
|
| 1215 |
-
mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
|
| 1216 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 1217 |
-
cv2.CHAIN_APPROX_NONE)
|
| 1218 |
-
pvc, pvr = contours[-2][0].squeeze().T
|
| 1219 |
-
vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
|
| 1220 |
-
|
| 1221 |
-
else:
|
| 1222 |
-
vr = np.hstack((vr0, vr1))
|
| 1223 |
-
vc = np.hstack((vc0, vc1))
|
| 1224 |
-
color = self.cellcolors[self.prev_selected]
|
| 1225 |
-
self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
|
| 1226 |
-
self.remove_cell(self.selected)
|
| 1227 |
-
print("GUI_INFO: merged two cells")
|
| 1228 |
-
self.update_layer()
|
| 1229 |
-
io._save_sets_with_check(self)
|
| 1230 |
-
self.undo.setEnabled(False)
|
| 1231 |
-
self.redo.setEnabled(False)
|
| 1232 |
-
|
| 1233 |
-
def undo_remove_cell(self):
|
| 1234 |
-
if len(self.removed_cell) > 0:
|
| 1235 |
-
z = 0
|
| 1236 |
-
ar, ac = self.removed_cell[2]
|
| 1237 |
-
vr, vc = self.removed_cell[3]
|
| 1238 |
-
color = self.removed_cell[1]
|
| 1239 |
-
self.draw_mask(z, ar, ac, vr, vc, color)
|
| 1240 |
-
self.toggle_mask_ops()
|
| 1241 |
-
self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0)
|
| 1242 |
-
self.ncells += 1
|
| 1243 |
-
self.ismanual = np.append(self.ismanual, self.removed_cell[0])
|
| 1244 |
-
self.zdraw.append([])
|
| 1245 |
-
print(">>> added back removed cell")
|
| 1246 |
-
self.update_layer()
|
| 1247 |
-
io._save_sets_with_check(self)
|
| 1248 |
-
self.removed_cell = []
|
| 1249 |
-
self.redo.setEnabled(False)
|
| 1250 |
-
|
| 1251 |
-
def remove_stroke(self, delete_points=True, stroke_ind=-1):
|
| 1252 |
-
stroke = np.array(self.strokes[stroke_ind])
|
| 1253 |
-
cZ = self.currentZ
|
| 1254 |
-
inZ = stroke[0, 0] == cZ
|
| 1255 |
-
if inZ:
|
| 1256 |
-
outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0
|
| 1257 |
-
self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0])
|
| 1258 |
-
cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]]
|
| 1259 |
-
ccol = self.cellcolors.copy()
|
| 1260 |
-
if self.selected > 0:
|
| 1261 |
-
ccol[self.selected] = np.array([255, 255, 255])
|
| 1262 |
-
col2mask = ccol[cellpix]
|
| 1263 |
-
if self.masksOn:
|
| 1264 |
-
col2mask = np.concatenate(
|
| 1265 |
-
(col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1)
|
| 1266 |
-
else:
|
| 1267 |
-
col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)),
|
| 1268 |
-
axis=-1)
|
| 1269 |
-
self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask
|
| 1270 |
-
if self.outlinesOn:
|
| 1271 |
-
self.layerz[stroke[outpix, 1], stroke[outpix,
|
| 1272 |
-
2]] = np.array(self.outcolor)
|
| 1273 |
-
if delete_points:
|
| 1274 |
-
del self.current_point_set[stroke_ind]
|
| 1275 |
-
self.update_layer()
|
| 1276 |
-
|
| 1277 |
-
del self.strokes[stroke_ind]
|
| 1278 |
-
|
| 1279 |
-
def plot_clicked(self, event):
|
| 1280 |
-
if event.button()==QtCore.Qt.LeftButton \
|
| 1281 |
-
and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
|
| 1282 |
-
and not self.removing_region:
|
| 1283 |
-
if event.double():
|
| 1284 |
-
try:
|
| 1285 |
-
self.p0.setYRange(0, self.Ly + self.pr)
|
| 1286 |
-
except:
|
| 1287 |
-
self.p0.setYRange(0, self.Ly)
|
| 1288 |
-
self.p0.setXRange(0, self.Lx)
|
| 1289 |
-
|
| 1290 |
-
def cancel_remove_multiple(self):
|
| 1291 |
-
self.clear_multi_selected_cells()
|
| 1292 |
-
self.done_remove_multiple_cells()
|
| 1293 |
-
|
| 1294 |
-
def clear_multi_selected_cells(self):
|
| 1295 |
-
# unselect all previously selected cells:
|
| 1296 |
-
for idx in self.removing_cells_list:
|
| 1297 |
-
self.unselect_cell_multi(idx)
|
| 1298 |
-
self.removing_cells_list.clear()
|
| 1299 |
-
|
| 1300 |
-
def add_roi(self, roi):
|
| 1301 |
-
self.p0.addItem(roi)
|
| 1302 |
-
self.remove_roi_obj = roi
|
| 1303 |
-
|
| 1304 |
-
def remove_roi(self, roi):
|
| 1305 |
-
self.clear_multi_selected_cells()
|
| 1306 |
-
assert roi == self.remove_roi_obj
|
| 1307 |
-
self.remove_roi_obj = None
|
| 1308 |
-
self.p0.removeItem(roi)
|
| 1309 |
-
self.removing_region = False
|
| 1310 |
-
|
| 1311 |
-
def roi_changed(self, roi):
|
| 1312 |
-
# find the overlapping cells and make them selected
|
| 1313 |
-
pos = roi.pos()
|
| 1314 |
-
size = roi.size()
|
| 1315 |
-
x0 = int(pos.x())
|
| 1316 |
-
y0 = int(pos.y())
|
| 1317 |
-
x1 = int(pos.x() + size.x())
|
| 1318 |
-
y1 = int(pos.y() + size.y())
|
| 1319 |
-
if x0 < 0:
|
| 1320 |
-
x0 = 0
|
| 1321 |
-
if y0 < 0:
|
| 1322 |
-
y0 = 0
|
| 1323 |
-
if x1 > self.Lx:
|
| 1324 |
-
x1 = self.Lx
|
| 1325 |
-
if y1 > self.Ly:
|
| 1326 |
-
y1 = self.Ly
|
| 1327 |
-
|
| 1328 |
-
# find cells in that region
|
| 1329 |
-
cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1])
|
| 1330 |
-
cell_idxs = np.trim_zeros(cell_idxs)
|
| 1331 |
-
# deselect cells not in region by deselecting all and then selecting the ones in the region
|
| 1332 |
-
self.clear_multi_selected_cells()
|
| 1333 |
-
|
| 1334 |
-
for idx in cell_idxs:
|
| 1335 |
-
self.select_cell_multi(idx)
|
| 1336 |
-
self.removing_cells_list.append(idx)
|
| 1337 |
-
|
| 1338 |
-
self.update_layer()
|
| 1339 |
-
|
| 1340 |
-
def mouse_moved(self, pos):
|
| 1341 |
-
items = self.win.scene().items(pos)
|
| 1342 |
-
|
| 1343 |
-
def color_choose(self):
|
| 1344 |
-
self.color = self.RGBDropDown.currentIndex()
|
| 1345 |
-
self.view = 0
|
| 1346 |
-
self.ViewDropDown.setCurrentIndex(self.view)
|
| 1347 |
-
self.update_plot()
|
| 1348 |
-
|
| 1349 |
-
def update_plot(self):
|
| 1350 |
-
self.view = self.ViewDropDown.currentIndex()
|
| 1351 |
-
self.Ly, self.Lx, _ = self.stack[self.currentZ].shape
|
| 1352 |
-
|
| 1353 |
-
if self.view == 0 or self.view == self.ViewDropDown.count() - 1:
|
| 1354 |
-
image = self.stack[
|
| 1355 |
-
self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ]
|
| 1356 |
-
if self.color == 0:
|
| 1357 |
-
self.img.setImage(image, autoLevels=False, lut=None)
|
| 1358 |
-
if self.nchan > 1:
|
| 1359 |
-
levels = np.array([
|
| 1360 |
-
self.saturation[0][self.currentZ],
|
| 1361 |
-
self.saturation[1][self.currentZ],
|
| 1362 |
-
self.saturation[2][self.currentZ]
|
| 1363 |
-
])
|
| 1364 |
-
self.img.setLevels(levels)
|
| 1365 |
-
else:
|
| 1366 |
-
self.img.setLevels(self.saturation[0][self.currentZ])
|
| 1367 |
-
elif self.color > 0 and self.color < 4:
|
| 1368 |
-
if self.nchan > 1:
|
| 1369 |
-
image = image[:, :, self.color - 1]
|
| 1370 |
-
self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color])
|
| 1371 |
-
if self.nchan > 1:
|
| 1372 |
-
self.img.setLevels(self.saturation[self.color - 1][self.currentZ])
|
| 1373 |
-
else:
|
| 1374 |
-
self.img.setLevels(self.saturation[0][self.currentZ])
|
| 1375 |
-
elif self.color == 4:
|
| 1376 |
-
if self.nchan > 1:
|
| 1377 |
-
image = image.mean(axis=-1)
|
| 1378 |
-
self.img.setImage(image, autoLevels=False, lut=None)
|
| 1379 |
-
self.img.setLevels(self.saturation[0][self.currentZ])
|
| 1380 |
-
elif self.color == 5:
|
| 1381 |
-
if self.nchan > 1:
|
| 1382 |
-
image = image.mean(axis=-1)
|
| 1383 |
-
self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
|
| 1384 |
-
self.img.setLevels(self.saturation[0][self.currentZ])
|
| 1385 |
-
else:
|
| 1386 |
-
image = np.zeros((self.Ly, self.Lx), np.uint8)
|
| 1387 |
-
if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0:
|
| 1388 |
-
image = self.flows[self.view - 1][self.currentZ]
|
| 1389 |
-
if self.view > 1:
|
| 1390 |
-
self.img.setImage(image, autoLevels=False, lut=self.bwr)
|
| 1391 |
-
else:
|
| 1392 |
-
self.img.setImage(image, autoLevels=False, lut=None)
|
| 1393 |
-
self.img.setLevels([0.0, 255.0])
|
| 1394 |
-
|
| 1395 |
-
for r in range(3):
|
| 1396 |
-
self.sliders[r].setValue([
|
| 1397 |
-
self.saturation[r][self.currentZ][0],
|
| 1398 |
-
self.saturation[r][self.currentZ][1]
|
| 1399 |
-
])
|
| 1400 |
-
self.win.show()
|
| 1401 |
-
self.show()
|
| 1402 |
-
|
| 1403 |
-
|
| 1404 |
-
def update_layer(self):
|
| 1405 |
-
if self.masksOn or self.outlinesOn:
|
| 1406 |
-
self.layer.setImage(self.layerz, autoLevels=False)
|
| 1407 |
-
self.win.show()
|
| 1408 |
-
self.show()
|
| 1409 |
-
|
| 1410 |
-
|
| 1411 |
-
def add_set(self):
|
| 1412 |
-
if len(self.current_point_set) > 0:
|
| 1413 |
-
while len(self.strokes) > 0:
|
| 1414 |
-
self.remove_stroke(delete_points=False)
|
| 1415 |
-
if len(self.current_point_set[0]) > 8:
|
| 1416 |
-
color = self.colormap[self.ncells.get(), :3]
|
| 1417 |
-
median = self.add_mask(points=self.current_point_set, color=color)
|
| 1418 |
-
if median is not None:
|
| 1419 |
-
self.removed_cell = []
|
| 1420 |
-
self.toggle_mask_ops()
|
| 1421 |
-
self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :],
|
| 1422 |
-
axis=0)
|
| 1423 |
-
self.ncells += 1
|
| 1424 |
-
self.ismanual = np.append(self.ismanual, True)
|
| 1425 |
-
if self.NZ == 1:
|
| 1426 |
-
# only save after each cell if single image
|
| 1427 |
-
io._save_sets_with_check(self)
|
| 1428 |
-
else:
|
| 1429 |
-
print("GUI_ERROR: cell too small, not drawn")
|
| 1430 |
-
self.current_stroke = []
|
| 1431 |
-
self.strokes = []
|
| 1432 |
-
self.current_point_set = []
|
| 1433 |
-
self.update_layer()
|
| 1434 |
-
|
| 1435 |
-
def add_mask(self, points=None, color=(100, 200, 50), dense=True):
|
| 1436 |
-
# points is list of strokes
|
| 1437 |
-
points_all = np.concatenate(points, axis=0)
|
| 1438 |
-
|
| 1439 |
-
# loop over z values
|
| 1440 |
-
median = []
|
| 1441 |
-
zdraw = np.unique(points_all[:, 0])
|
| 1442 |
-
z = 0
|
| 1443 |
-
ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
|
| 1444 |
-
0, "int"), np.zeros(0, "int")
|
| 1445 |
-
for stroke in points:
|
| 1446 |
-
stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
|
| 1447 |
-
vr = stroke[:, 1]
|
| 1448 |
-
vc = stroke[:, 2]
|
| 1449 |
-
# get points inside drawn points
|
| 1450 |
-
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
|
| 1451 |
-
pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
|
| 1452 |
-
axis=-1)[:, np.newaxis, :]
|
| 1453 |
-
mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
|
| 1454 |
-
ar, ac = np.nonzero(mask)
|
| 1455 |
-
ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
|
| 1456 |
-
# get dense outline
|
| 1457 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
| 1458 |
-
pvc, pvr = contours[-2][0][:,0].T
|
| 1459 |
-
vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
|
| 1460 |
-
# concatenate all points
|
| 1461 |
-
ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
|
| 1462 |
-
# if these pixels are overlapping with another cell, reassign them
|
| 1463 |
-
ioverlap = self.cellpix[z][ar, ac] > 0
|
| 1464 |
-
if (~ioverlap).sum() < 10:
|
| 1465 |
-
print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn")
|
| 1466 |
-
return None
|
| 1467 |
-
elif ioverlap.sum() > 0:
|
| 1468 |
-
ar, ac = ar[~ioverlap], ac[~ioverlap]
|
| 1469 |
-
# compute outline of new mask
|
| 1470 |
-
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
|
| 1471 |
-
mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1
|
| 1472 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 1473 |
-
cv2.CHAIN_APPROX_NONE)
|
| 1474 |
-
pvc, pvr = contours[-2][0][:,0].T
|
| 1475 |
-
vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
|
| 1476 |
-
ars = np.concatenate((ars, ar), axis=0)
|
| 1477 |
-
acs = np.concatenate((acs, ac), axis=0)
|
| 1478 |
-
vrs = np.concatenate((vrs, vr), axis=0)
|
| 1479 |
-
vcs = np.concatenate((vcs, vc), axis=0)
|
| 1480 |
-
|
| 1481 |
-
self.draw_mask(z, ars, acs, vrs, vcs, color)
|
| 1482 |
-
median.append(np.array([np.median(ars), np.median(acs)]))
|
| 1483 |
-
|
| 1484 |
-
self.zdraw.append(zdraw)
|
| 1485 |
-
d = datetime.datetime.now()
|
| 1486 |
-
self.track_changes.append(
|
| 1487 |
-
[d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]])
|
| 1488 |
-
return median
|
| 1489 |
-
|
| 1490 |
-
def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
|
| 1491 |
-
""" draw single mask using outlines and area """
|
| 1492 |
-
if idx is None:
|
| 1493 |
-
idx = self.ncells + 1
|
| 1494 |
-
self.cellpix[z, vr, vc] = idx
|
| 1495 |
-
self.cellpix[z, ar, ac] = idx
|
| 1496 |
-
self.outpix[z, vr, vc] = idx
|
| 1497 |
-
if self.restore and "upsample" in self.restore:
|
| 1498 |
-
if self.resize:
|
| 1499 |
-
self.cellpix_resize[z, vr, vc] = idx
|
| 1500 |
-
self.cellpix_resize[z, ar, ac] = idx
|
| 1501 |
-
self.outpix_resize[z, vr, vc] = idx
|
| 1502 |
-
self.cellpix_orig[z, (vr / self.ratio).astype(int),
|
| 1503 |
-
(vc / self.ratio).astype(int)] = idx
|
| 1504 |
-
self.cellpix_orig[z, (ar / self.ratio).astype(int),
|
| 1505 |
-
(ac / self.ratio).astype(int)] = idx
|
| 1506 |
-
self.outpix_orig[z, (vr / self.ratio).astype(int),
|
| 1507 |
-
(vc / self.ratio).astype(int)] = idx
|
| 1508 |
-
else:
|
| 1509 |
-
self.cellpix_orig[z, vr, vc] = idx
|
| 1510 |
-
self.cellpix_orig[z, ar, ac] = idx
|
| 1511 |
-
self.outpix_orig[z, vr, vc] = idx
|
| 1512 |
-
|
| 1513 |
-
# get upsampled mask
|
| 1514 |
-
vrr = (vr.copy() * self.ratio).astype(int)
|
| 1515 |
-
vcr = (vc.copy() * self.ratio).astype(int)
|
| 1516 |
-
mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8)
|
| 1517 |
-
pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2),
|
| 1518 |
-
axis=-1)[:, np.newaxis, :]
|
| 1519 |
-
mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
|
| 1520 |
-
arr, acr = np.nonzero(mask)
|
| 1521 |
-
arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2
|
| 1522 |
-
# get dense outline
|
| 1523 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 1524 |
-
cv2.CHAIN_APPROX_NONE)
|
| 1525 |
-
pvc, pvr = contours[-2][0].squeeze().T
|
| 1526 |
-
vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2
|
| 1527 |
-
# concatenate all points
|
| 1528 |
-
arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr))))
|
| 1529 |
-
self.cellpix_resize[z, vrr, vcr] = idx
|
| 1530 |
-
self.cellpix_resize[z, arr, acr] = idx
|
| 1531 |
-
self.outpix_resize[z, vrr, vcr] = idx
|
| 1532 |
-
|
| 1533 |
-
if z == self.currentZ:
|
| 1534 |
-
self.layerz[ar, ac, :3] = color
|
| 1535 |
-
if self.masksOn:
|
| 1536 |
-
self.layerz[ar, ac, -1] = self.opacity
|
| 1537 |
-
if self.outlinesOn:
|
| 1538 |
-
self.layerz[vr, vc] = np.array(self.outcolor)
|
| 1539 |
-
|
| 1540 |
-
def compute_scale(self):
|
| 1541 |
-
# get diameter from gui
|
| 1542 |
-
diameter = self.segmentation_settings.diameter
|
| 1543 |
-
if not diameter:
|
| 1544 |
-
diameter = 30
|
| 1545 |
-
|
| 1546 |
-
self.pr = int(diameter)
|
| 1547 |
-
self.radii_padding = int(self.pr * 1.25)
|
| 1548 |
-
self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8)
|
| 1549 |
-
yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1],
|
| 1550 |
-
self.pr / 2, self.Ly + self.radii_padding, self.Lx)
|
| 1551 |
-
# rgb(150,50,150)
|
| 1552 |
-
self.radii[yy, xx, 0] = 150
|
| 1553 |
-
self.radii[yy, xx, 1] = 50
|
| 1554 |
-
self.radii[yy, xx, 2] = 150
|
| 1555 |
-
self.radii[yy, xx, 3] = 255
|
| 1556 |
-
self.p0.setYRange(0, self.Ly + self.radii_padding)
|
| 1557 |
-
self.p0.setXRange(0, self.Lx)
|
| 1558 |
-
|
| 1559 |
-
def update_scale(self):
|
| 1560 |
-
self.compute_scale()
|
| 1561 |
-
self.scale.setImage(self.radii, autoLevels=False)
|
| 1562 |
-
self.scale.setLevels([0.0, 255.0])
|
| 1563 |
-
self.win.show()
|
| 1564 |
-
self.show()
|
| 1565 |
-
|
| 1566 |
-
|
| 1567 |
-
def draw_layer(self):
|
| 1568 |
-
if self.resize:
|
| 1569 |
-
self.Ly, self.Lx = self.Lyr, self.Lxr
|
| 1570 |
-
else:
|
| 1571 |
-
self.Ly, self.Lx = self.Ly0, self.Lx0
|
| 1572 |
-
|
| 1573 |
-
if self.masksOn or self.outlinesOn:
|
| 1574 |
-
if self.restore and "upsample" in self.restore:
|
| 1575 |
-
if self.resize:
|
| 1576 |
-
self.cellpix = self.cellpix_resize.copy()
|
| 1577 |
-
self.outpix = self.outpix_resize.copy()
|
| 1578 |
-
else:
|
| 1579 |
-
self.cellpix = self.cellpix_orig.copy()
|
| 1580 |
-
self.outpix = self.outpix_orig.copy()
|
| 1581 |
-
|
| 1582 |
-
self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8)
|
| 1583 |
-
if self.masksOn:
|
| 1584 |
-
self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :]
|
| 1585 |
-
self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ]
|
| 1586 |
-
> 0).astype(np.uint8)
|
| 1587 |
-
if self.selected > 0:
|
| 1588 |
-
self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array(
|
| 1589 |
-
[255, 255, 255, self.opacity])
|
| 1590 |
-
cZ = self.currentZ
|
| 1591 |
-
stroke_z = np.array([s[0][0] for s in self.strokes])
|
| 1592 |
-
inZ = np.nonzero(stroke_z == cZ)[0]
|
| 1593 |
-
if len(inZ) > 0:
|
| 1594 |
-
for i in inZ:
|
| 1595 |
-
stroke = np.array(self.strokes[i])
|
| 1596 |
-
self.layerz[stroke[:, 1], stroke[:,
|
| 1597 |
-
2]] = np.array([255, 0, 255, 100])
|
| 1598 |
-
else:
|
| 1599 |
-
self.layerz[..., 3] = 0
|
| 1600 |
-
|
| 1601 |
-
if self.outlinesOn:
|
| 1602 |
-
self.layerz[self.outpix[self.currentZ] > 0] = np.array(
|
| 1603 |
-
self.outcolor).astype(np.uint8)
|
| 1604 |
-
|
| 1605 |
-
|
| 1606 |
-
def set_normalize_params(self, normalize_params):
|
| 1607 |
-
from cellpose.models import normalize_default
|
| 1608 |
-
if self.restore != "filter":
|
| 1609 |
-
keys = list(normalize_params.keys()).copy()
|
| 1610 |
-
for key in keys:
|
| 1611 |
-
if key != "percentile":
|
| 1612 |
-
normalize_params[key] = normalize_default[key]
|
| 1613 |
-
normalize_params = {**normalize_default, **normalize_params}
|
| 1614 |
-
out = self.check_filter_params(normalize_params["sharpen_radius"],
|
| 1615 |
-
normalize_params["smooth_radius"],
|
| 1616 |
-
normalize_params["tile_norm_blocksize"],
|
| 1617 |
-
normalize_params["tile_norm_smooth3D"],
|
| 1618 |
-
normalize_params["norm3D"],
|
| 1619 |
-
normalize_params["invert"])
|
| 1620 |
-
|
| 1621 |
-
|
| 1622 |
-
def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert):
|
| 1623 |
-
tile_norm = 0 if tile_norm < 0 else tile_norm
|
| 1624 |
-
sharpen = 0 if sharpen < 0 else sharpen
|
| 1625 |
-
smooth = 0 if smooth < 0 else smooth
|
| 1626 |
-
smooth3D = 0 if smooth3D < 0 else smooth3D
|
| 1627 |
-
norm3D = bool(norm3D)
|
| 1628 |
-
invert = bool(invert)
|
| 1629 |
-
if tile_norm > self.Ly and tile_norm > self.Lx:
|
| 1630 |
-
print(
|
| 1631 |
-
"GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling"
|
| 1632 |
-
)
|
| 1633 |
-
tile_norm = 0
|
| 1634 |
-
self.filt_edits[0].setText(str(sharpen))
|
| 1635 |
-
self.filt_edits[1].setText(str(smooth))
|
| 1636 |
-
self.filt_edits[2].setText(str(tile_norm))
|
| 1637 |
-
self.filt_edits[3].setText(str(smooth3D))
|
| 1638 |
-
self.norm3D_cb.setChecked(norm3D)
|
| 1639 |
-
return sharpen, smooth, tile_norm, smooth3D, norm3D, invert
|
| 1640 |
-
|
| 1641 |
-
def get_normalize_params(self):
|
| 1642 |
-
percentile = [
|
| 1643 |
-
self.segmentation_settings.low_percentile,
|
| 1644 |
-
self.segmentation_settings.high_percentile,
|
| 1645 |
-
]
|
| 1646 |
-
normalize_params = {"percentile": percentile}
|
| 1647 |
-
norm3D = self.norm3D_cb.isChecked()
|
| 1648 |
-
normalize_params["norm3D"] = norm3D
|
| 1649 |
-
sharpen = float(self.filt_edits[0].text())
|
| 1650 |
-
smooth = float(self.filt_edits[1].text())
|
| 1651 |
-
tile_norm = float(self.filt_edits[2].text())
|
| 1652 |
-
smooth3D = float(self.filt_edits[3].text())
|
| 1653 |
-
invert = False
|
| 1654 |
-
out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D,
|
| 1655 |
-
invert)
|
| 1656 |
-
sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out
|
| 1657 |
-
normalize_params["sharpen_radius"] = sharpen
|
| 1658 |
-
normalize_params["smooth_radius"] = smooth
|
| 1659 |
-
normalize_params["tile_norm_blocksize"] = tile_norm
|
| 1660 |
-
normalize_params["tile_norm_smooth3D"] = smooth3D
|
| 1661 |
-
normalize_params["invert"] = invert
|
| 1662 |
-
|
| 1663 |
-
from cellpose.models import normalize_default
|
| 1664 |
-
normalize_params = {**normalize_default, **normalize_params}
|
| 1665 |
-
|
| 1666 |
-
return normalize_params
|
| 1667 |
-
|
| 1668 |
-
def compute_saturation_if_checked(self):
|
| 1669 |
-
if self.autobtn.isChecked():
|
| 1670 |
-
self.compute_saturation()
|
| 1671 |
-
|
| 1672 |
-
def compute_saturation(self, return_img=False):
|
| 1673 |
-
norm = self.get_normalize_params()
|
| 1674 |
-
print(norm)
|
| 1675 |
-
sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"]
|
| 1676 |
-
percentile = norm["percentile"]
|
| 1677 |
-
tile_norm = norm["tile_norm_blocksize"]
|
| 1678 |
-
invert = norm["invert"]
|
| 1679 |
-
norm3D = norm["norm3D"]
|
| 1680 |
-
smooth3D = norm["tile_norm_smooth3D"]
|
| 1681 |
-
tile_norm = norm["tile_norm_blocksize"]
|
| 1682 |
-
|
| 1683 |
-
if sharpen > 0 or smooth > 0 or tile_norm > 0:
|
| 1684 |
-
img_norm = self.stack.copy()
|
| 1685 |
-
else:
|
| 1686 |
-
img_norm = self.stack
|
| 1687 |
-
|
| 1688 |
-
if sharpen > 0 or smooth > 0 or tile_norm > 0:
|
| 1689 |
-
self.restore = "filter"
|
| 1690 |
-
print(
|
| 1691 |
-
"GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0"
|
| 1692 |
-
)
|
| 1693 |
-
print(
|
| 1694 |
-
"GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this"
|
| 1695 |
-
)
|
| 1696 |
-
img_norm = self.stack.copy()
|
| 1697 |
-
if sharpen > 0 or smooth > 0:
|
| 1698 |
-
img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen,
|
| 1699 |
-
smooth_radius=smooth)
|
| 1700 |
-
|
| 1701 |
-
if tile_norm > 0:
|
| 1702 |
-
img_norm = normalize99_tile(img_norm, blocksize=tile_norm,
|
| 1703 |
-
lower=percentile[0], upper=percentile[1],
|
| 1704 |
-
smooth3D=smooth3D, norm3D=norm3D)
|
| 1705 |
-
# convert to 0->255
|
| 1706 |
-
img_norm_min = img_norm.min()
|
| 1707 |
-
img_norm_max = img_norm.max()
|
| 1708 |
-
for c in range(img_norm.shape[-1]):
|
| 1709 |
-
if np.ptp(img_norm[..., c]) > 1e-3:
|
| 1710 |
-
img_norm[..., c] -= img_norm_min
|
| 1711 |
-
img_norm[..., c] /= (img_norm_max - img_norm_min)
|
| 1712 |
-
img_norm *= 255
|
| 1713 |
-
self.stack_filtered = img_norm
|
| 1714 |
-
self.ViewDropDown.model().item(self.ViewDropDown.count() -
|
| 1715 |
-
1).setEnabled(True)
|
| 1716 |
-
self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
|
| 1717 |
-
else:
|
| 1718 |
-
img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered
|
| 1719 |
-
|
| 1720 |
-
if self.autobtn.isChecked():
|
| 1721 |
-
self.saturation = []
|
| 1722 |
-
for c in range(img_norm.shape[-1]):
|
| 1723 |
-
self.saturation.append([])
|
| 1724 |
-
if np.ptp(img_norm[..., c]) > 1e-3:
|
| 1725 |
-
if norm3D:
|
| 1726 |
-
x01 = np.percentile(img_norm[..., c], percentile[0])
|
| 1727 |
-
x99 = np.percentile(img_norm[..., c], percentile[1])
|
| 1728 |
-
if invert:
|
| 1729 |
-
x01i = 255. - x99
|
| 1730 |
-
x99i = 255. - x01
|
| 1731 |
-
x01, x99 = x01i, x99i
|
| 1732 |
-
for n in range(self.NZ):
|
| 1733 |
-
self.saturation[-1].append([x01, x99])
|
| 1734 |
-
else:
|
| 1735 |
-
for z in range(self.NZ):
|
| 1736 |
-
if self.NZ > 1:
|
| 1737 |
-
x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
|
| 1738 |
-
x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
|
| 1739 |
-
else:
|
| 1740 |
-
x01 = np.percentile(img_norm[..., c], percentile[0])
|
| 1741 |
-
x99 = np.percentile(img_norm[..., c], percentile[1])
|
| 1742 |
-
if invert:
|
| 1743 |
-
x01i = 255. - x99
|
| 1744 |
-
x99i = 255. - x01
|
| 1745 |
-
x01, x99 = x01i, x99i
|
| 1746 |
-
self.saturation[-1].append([x01, x99])
|
| 1747 |
-
else:
|
| 1748 |
-
for n in range(self.NZ):
|
| 1749 |
-
self.saturation[-1].append([0, 255.])
|
| 1750 |
-
print(self.saturation[2][self.currentZ])
|
| 1751 |
-
|
| 1752 |
-
if img_norm.shape[-1] == 1:
|
| 1753 |
-
self.saturation.append(self.saturation[0])
|
| 1754 |
-
self.saturation.append(self.saturation[0])
|
| 1755 |
-
|
| 1756 |
-
# self.autobtn.setChecked(True)
|
| 1757 |
-
self.update_plot()
|
| 1758 |
-
|
| 1759 |
-
|
| 1760 |
-
def get_model_path(self, custom=False):
|
| 1761 |
-
if custom:
|
| 1762 |
-
self.current_model = self.ModelChooseC.currentText()
|
| 1763 |
-
self.current_model_path = os.fspath(
|
| 1764 |
-
models.MODEL_DIR.joinpath(self.current_model))
|
| 1765 |
-
else:
|
| 1766 |
-
self.current_model = "cpsam"
|
| 1767 |
-
self.current_model_path = models.model_path(self.current_model)
|
| 1768 |
-
|
| 1769 |
-
def initialize_model(self, model_name=None, custom=False):
|
| 1770 |
-
if model_name is None or custom:
|
| 1771 |
-
self.get_model_path(custom=custom)
|
| 1772 |
-
if not os.path.exists(self.current_model_path):
|
| 1773 |
-
raise ValueError("need to specify model (use dropdown)")
|
| 1774 |
-
|
| 1775 |
-
if model_name is None or not isinstance(model_name, str):
|
| 1776 |
-
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
|
| 1777 |
-
pretrained_model=self.current_model_path)
|
| 1778 |
-
else:
|
| 1779 |
-
self.current_model = model_name
|
| 1780 |
-
self.current_model_path = os.fspath(
|
| 1781 |
-
models.MODEL_DIR.joinpath(self.current_model))
|
| 1782 |
-
|
| 1783 |
-
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
|
| 1784 |
-
pretrained_model=self.current_model)
|
| 1785 |
-
|
| 1786 |
-
def add_model(self):
|
| 1787 |
-
io._add_model(self)
|
| 1788 |
-
return
|
| 1789 |
-
|
| 1790 |
-
def remove_model(self):
|
| 1791 |
-
io._remove_model(self)
|
| 1792 |
-
return
|
| 1793 |
-
|
| 1794 |
-
def new_model(self):
|
| 1795 |
-
if self.NZ != 1:
|
| 1796 |
-
print("ERROR: cannot train model on 3D data")
|
| 1797 |
-
return
|
| 1798 |
-
|
| 1799 |
-
# train model
|
| 1800 |
-
image_names = self.get_files()[0]
|
| 1801 |
-
self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
|
| 1802 |
-
image_names)
|
| 1803 |
-
TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
|
| 1804 |
-
train = TW.exec_()
|
| 1805 |
-
if train:
|
| 1806 |
-
self.logger.info(
|
| 1807 |
-
f"training with {[os.path.split(f)[1] for f in self.train_files]}")
|
| 1808 |
-
self.train_model(restore=restore, normalize_params=normalize_params)
|
| 1809 |
-
else:
|
| 1810 |
-
print("GUI_INFO: training cancelled")
|
| 1811 |
-
|
| 1812 |
-
def train_model(self, restore=None, normalize_params=None):
|
| 1813 |
-
from cellpose.models import normalize_default
|
| 1814 |
-
if normalize_params is None:
|
| 1815 |
-
normalize_params = copy.deepcopy(normalize_default)
|
| 1816 |
-
model_type = models.MODEL_NAMES[self.training_params["model_index"]]
|
| 1817 |
-
self.logger.info(f"training new model starting at model {model_type}")
|
| 1818 |
-
self.current_model = model_type
|
| 1819 |
-
|
| 1820 |
-
self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
|
| 1821 |
-
model_type=model_type)
|
| 1822 |
-
save_path = os.path.dirname(self.filename)
|
| 1823 |
-
|
| 1824 |
-
print("GUI_INFO: name of new model: " + self.training_params["model_name"])
|
| 1825 |
-
self.new_model_path, train_losses = train.train_seg(
|
| 1826 |
-
self.model.net, train_data=self.train_data, train_labels=self.train_labels,
|
| 1827 |
-
normalize=normalize_params, min_train_masks=0,
|
| 1828 |
-
save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)),
|
| 1829 |
-
learning_rate=self.training_params["learning_rate"],
|
| 1830 |
-
weight_decay=self.training_params["weight_decay"],
|
| 1831 |
-
n_epochs=self.training_params["n_epochs"],
|
| 1832 |
-
model_name=self.training_params["model_name"])[:2]
|
| 1833 |
-
# save train losses
|
| 1834 |
-
np.save(str(self.new_model_path) + "_train_losses.npy", train_losses)
|
| 1835 |
-
# run model on next image
|
| 1836 |
-
io._add_model(self, self.new_model_path)
|
| 1837 |
-
diam_labels = self.model.net.diam_labels.item() #.copy()
|
| 1838 |
-
self.new_model_ind = len(self.model_strings)
|
| 1839 |
-
self.autorun = True
|
| 1840 |
-
self.clear_all()
|
| 1841 |
-
self.restore = restore
|
| 1842 |
-
self.set_normalize_params(normalize_params)
|
| 1843 |
-
self.get_next_image(load_seg=False)
|
| 1844 |
-
|
| 1845 |
-
self.compute_segmentation(custom=True)
|
| 1846 |
-
self.logger.info(
|
| 1847 |
-
f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
|
| 1848 |
-
)
|
| 1849 |
-
|
| 1850 |
-
|
| 1851 |
-
def compute_cprob(self):
|
| 1852 |
-
if self.recompute_masks:
|
| 1853 |
-
flow_threshold = self.segmentation_settings.flow_threshold
|
| 1854 |
-
cellprob_threshold = self.segmentation_settings.cellprob_threshold
|
| 1855 |
-
niter = self.segmentation_settings.niter
|
| 1856 |
-
min_size = int(self.min_size.text()) if not isinstance(
|
| 1857 |
-
self.min_size, int) else self.min_size
|
| 1858 |
-
|
| 1859 |
-
self.logger.info(
|
| 1860 |
-
"computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
|
| 1861 |
-
(cellprob_threshold, flow_threshold))
|
| 1862 |
-
|
| 1863 |
-
try:
|
| 1864 |
-
dP = self.flows[2].squeeze()
|
| 1865 |
-
cellprob = self.flows[3].squeeze()
|
| 1866 |
-
except IndexError:
|
| 1867 |
-
self.logger.error("Flows don't exist, try running model again.")
|
| 1868 |
-
return
|
| 1869 |
-
|
| 1870 |
-
maski = dynamics.resize_and_compute_masks(
|
| 1871 |
-
dP=dP,
|
| 1872 |
-
cellprob=cellprob,
|
| 1873 |
-
niter=niter,
|
| 1874 |
-
do_3D=self.load_3D,
|
| 1875 |
-
min_size=min_size,
|
| 1876 |
-
# max_size_fraction=min_size_fraction, # Leave as default
|
| 1877 |
-
cellprob_threshold=cellprob_threshold,
|
| 1878 |
-
flow_threshold=flow_threshold)
|
| 1879 |
-
|
| 1880 |
-
self.masksOn = True
|
| 1881 |
-
if not self.OCheckBox.isChecked():
|
| 1882 |
-
self.MCheckBox.setChecked(True)
|
| 1883 |
-
if maski.ndim < 3:
|
| 1884 |
-
maski = maski[np.newaxis, ...]
|
| 1885 |
-
self.logger.info("%d cells found" % (len(np.unique(maski)[1:])))
|
| 1886 |
-
io._masks_to_gui(self, maski, outlines=None)
|
| 1887 |
-
self.show()
|
| 1888 |
-
|
| 1889 |
-
|
| 1890 |
-
def compute_segmentation(self, custom=False, model_name=None, load_model=True):
|
| 1891 |
-
self.progress.setValue(0)
|
| 1892 |
-
try:
|
| 1893 |
-
tic = time.time()
|
| 1894 |
-
self.clear_all()
|
| 1895 |
-
self.flows = [[], [], []]
|
| 1896 |
-
if load_model:
|
| 1897 |
-
self.initialize_model(model_name=model_name, custom=custom)
|
| 1898 |
-
self.progress.setValue(10)
|
| 1899 |
-
do_3D = self.load_3D
|
| 1900 |
-
stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
|
| 1901 |
-
self.stitch_threshold, float) else self.stitch_threshold
|
| 1902 |
-
anisotropy = float(self.anisotropy.text()) if not isinstance(
|
| 1903 |
-
self.anisotropy, float) else self.anisotropy
|
| 1904 |
-
flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance(
|
| 1905 |
-
self.flow3D_smooth, float) else self.flow3D_smooth
|
| 1906 |
-
min_size = int(self.min_size.text()) if not isinstance(
|
| 1907 |
-
self.min_size, int) else self.min_size
|
| 1908 |
-
|
| 1909 |
-
do_3D = False if stitch_threshold > 0. else do_3D
|
| 1910 |
-
|
| 1911 |
-
if self.restore == "filter":
|
| 1912 |
-
data = self.stack_filtered.copy().squeeze()
|
| 1913 |
-
else:
|
| 1914 |
-
data = self.stack.copy().squeeze()
|
| 1915 |
-
|
| 1916 |
-
flow_threshold = self.segmentation_settings.flow_threshold
|
| 1917 |
-
cellprob_threshold = self.segmentation_settings.cellprob_threshold
|
| 1918 |
-
diameter = self.segmentation_settings.diameter
|
| 1919 |
-
niter = self.segmentation_settings.niter
|
| 1920 |
-
|
| 1921 |
-
normalize_params = self.get_normalize_params()
|
| 1922 |
-
print(normalize_params)
|
| 1923 |
-
try:
|
| 1924 |
-
masks, flows = self.model.eval(
|
| 1925 |
-
data,
|
| 1926 |
-
diameter=diameter,
|
| 1927 |
-
cellprob_threshold=cellprob_threshold,
|
| 1928 |
-
flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
|
| 1929 |
-
normalize=normalize_params, stitch_threshold=stitch_threshold,
|
| 1930 |
-
anisotropy=anisotropy, flow3D_smooth=flow3D_smooth,
|
| 1931 |
-
min_size=min_size, channel_axis=-1,
|
| 1932 |
-
progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
|
| 1933 |
-
except Exception as e:
|
| 1934 |
-
print("NET ERROR: %s" % e)
|
| 1935 |
-
self.progress.setValue(0)
|
| 1936 |
-
return
|
| 1937 |
-
|
| 1938 |
-
self.progress.setValue(75)
|
| 1939 |
-
|
| 1940 |
-
# convert flows to uint8 and resize to original image size
|
| 1941 |
-
flows_new = []
|
| 1942 |
-
flows_new.append(flows[0].copy()) # RGB flow
|
| 1943 |
-
flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
|
| 1944 |
-
255).astype("uint8")) # cellprob
|
| 1945 |
-
flows_new.append(flows[1].copy()) # XY flows
|
| 1946 |
-
flows_new.append(flows[2].copy()) # original cellprob
|
| 1947 |
-
|
| 1948 |
-
if self.load_3D:
|
| 1949 |
-
if stitch_threshold == 0.:
|
| 1950 |
-
flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
|
| 1951 |
-
else:
|
| 1952 |
-
flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))
|
| 1953 |
-
|
| 1954 |
-
if not self.load_3D:
|
| 1955 |
-
if self.restore and "upsample" in self.restore:
|
| 1956 |
-
self.Ly, self.Lx = self.Lyr, self.Lxr
|
| 1957 |
-
|
| 1958 |
-
if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
|
| 1959 |
-
self.flows = []
|
| 1960 |
-
for j in range(len(flows_new)):
|
| 1961 |
-
self.flows.append(
|
| 1962 |
-
resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
|
| 1963 |
-
interpolation=cv2.INTER_NEAREST))
|
| 1964 |
-
else:
|
| 1965 |
-
self.flows = flows_new
|
| 1966 |
-
else:
|
| 1967 |
-
self.flows = []
|
| 1968 |
-
Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
|
| 1969 |
-
Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
|
| 1970 |
-
print("GUI_INFO: resizing flows to original image size")
|
| 1971 |
-
for j in range(len(flows_new)):
|
| 1972 |
-
flow0 = flows_new[j]
|
| 1973 |
-
if Ly0 != Ly:
|
| 1974 |
-
flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
|
| 1975 |
-
no_channels=flow0.ndim==3,
|
| 1976 |
-
interpolation=cv2.INTER_NEAREST)
|
| 1977 |
-
if Lz0 != Lz:
|
| 1978 |
-
flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
|
| 1979 |
-
Ly=Lz, Lx=Lx,
|
| 1980 |
-
no_channels=flow0.ndim==3,
|
| 1981 |
-
interpolation=cv2.INTER_NEAREST), 0, 1)
|
| 1982 |
-
self.flows.append(flow0)
|
| 1983 |
-
|
| 1984 |
-
# add first axis
|
| 1985 |
-
if self.NZ == 1:
|
| 1986 |
-
masks = masks[np.newaxis, ...]
|
| 1987 |
-
self.flows = [
|
| 1988 |
-
self.flows[n][np.newaxis, ...] for n in range(len(self.flows))
|
| 1989 |
-
]
|
| 1990 |
-
|
| 1991 |
-
self.logger.info("%d cells found with model in %0.3f sec" %
|
| 1992 |
-
(len(np.unique(masks)[1:]), time.time() - tic))
|
| 1993 |
-
self.progress.setValue(80)
|
| 1994 |
-
z = 0
|
| 1995 |
-
|
| 1996 |
-
io._masks_to_gui(self, masks, outlines=None)
|
| 1997 |
-
self.masksOn = True
|
| 1998 |
-
self.MCheckBox.setChecked(True)
|
| 1999 |
-
self.progress.setValue(100)
|
| 2000 |
-
if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked():
|
| 2001 |
-
self.compute_saturation()
|
| 2002 |
-
if not do_3D and not stitch_threshold > 0:
|
| 2003 |
-
self.recompute_masks = True
|
| 2004 |
-
else:
|
| 2005 |
-
self.recompute_masks = False
|
| 2006 |
-
except Exception as e:
|
| 2007 |
-
print("ERROR: %s" % e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/gui3d.py
DELETED
|
@@ -1,667 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import sys, pathlib, warnings
|
| 6 |
-
|
| 7 |
-
from qtpy import QtGui, QtCore
|
| 8 |
-
from qtpy.QtWidgets import QApplication, QScrollBar, QCheckBox, QLabel, QLineEdit
|
| 9 |
-
import pyqtgraph as pg
|
| 10 |
-
|
| 11 |
-
import numpy as np
|
| 12 |
-
from scipy.stats import mode
|
| 13 |
-
import cv2
|
| 14 |
-
|
| 15 |
-
from . import guiparts, io
|
| 16 |
-
from ..utils import download_url_to_file, masks_to_outlines
|
| 17 |
-
from .gui import MainW
|
| 18 |
-
|
| 19 |
-
try:
|
| 20 |
-
import matplotlib.pyplot as plt
|
| 21 |
-
MATPLOTLIB = True
|
| 22 |
-
except:
|
| 23 |
-
MATPLOTLIB = False
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def avg3d(C):
|
| 27 |
-
""" smooth value of c across nearby points
|
| 28 |
-
(c is center of grid directly below point)
|
| 29 |
-
b -- a -- b
|
| 30 |
-
a -- c -- a
|
| 31 |
-
b -- a -- b
|
| 32 |
-
"""
|
| 33 |
-
Ly, Lx = C.shape
|
| 34 |
-
# pad T by 2
|
| 35 |
-
T = np.zeros((Ly + 2, Lx + 2), "float32")
|
| 36 |
-
M = np.zeros((Ly, Lx), "float32")
|
| 37 |
-
T[1:-1, 1:-1] = C.copy()
|
| 38 |
-
y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
|
| 39 |
-
indexing="ij")
|
| 40 |
-
y += 1
|
| 41 |
-
x += 1
|
| 42 |
-
a = 1. / 2 #/(z**2 + 1)**0.5
|
| 43 |
-
b = 1. / (1 + 2**0.5) #(z**2 + 2)**0.5
|
| 44 |
-
c = 1.
|
| 45 |
-
M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] +
|
| 46 |
-
c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] +
|
| 47 |
-
b * T[y + 1, x + 1])
|
| 48 |
-
M /= 4 * a + 4 * b + c
|
| 49 |
-
return M
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def interpZ(mask, zdraw):
|
| 53 |
-
""" find nearby planes and average their values using grid of points
|
| 54 |
-
zfill is in ascending order
|
| 55 |
-
"""
|
| 56 |
-
ifill = np.ones(mask.shape[0], "bool")
|
| 57 |
-
zall = np.arange(0, mask.shape[0], 1, int)
|
| 58 |
-
ifill[zdraw] = False
|
| 59 |
-
zfill = zall[ifill]
|
| 60 |
-
zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1]
|
| 61 |
-
zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")]
|
| 62 |
-
for k, z in enumerate(zfill):
|
| 63 |
-
Z = zupper[k] - zlower[k]
|
| 64 |
-
zl = (z - zlower[k]) / Z
|
| 65 |
-
plower = avg3d(mask[zlower[k]]) * (1 - zl)
|
| 66 |
-
pupper = avg3d(mask[zupper[k]]) * zl
|
| 67 |
-
mask[z] = (plower + pupper) > 0.33
|
| 68 |
-
return mask, zfill
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def run(image=None):
|
| 72 |
-
from ..io import logger_setup
|
| 73 |
-
logger, log_file = logger_setup()
|
| 74 |
-
# Always start by initializing Qt (only once per application)
|
| 75 |
-
warnings.filterwarnings("ignore")
|
| 76 |
-
app = QApplication(sys.argv)
|
| 77 |
-
icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
|
| 78 |
-
guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
|
| 79 |
-
style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy")
|
| 80 |
-
if not icon_path.is_file():
|
| 81 |
-
cp_dir = pathlib.Path.home().joinpath(".cellpose")
|
| 82 |
-
cp_dir.mkdir(exist_ok=True)
|
| 83 |
-
print("downloading logo")
|
| 84 |
-
download_url_to_file(
|
| 85 |
-
"https://www.cellpose.org/static/images/cellpose_transparent.png",
|
| 86 |
-
icon_path, progress=True)
|
| 87 |
-
if not guip_path.is_file():
|
| 88 |
-
print("downloading help window image")
|
| 89 |
-
download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
|
| 90 |
-
guip_path, progress=True)
|
| 91 |
-
icon_path = str(icon_path.resolve())
|
| 92 |
-
app_icon = QtGui.QIcon()
|
| 93 |
-
app_icon.addFile(icon_path, QtCore.QSize(16, 16))
|
| 94 |
-
app_icon.addFile(icon_path, QtCore.QSize(24, 24))
|
| 95 |
-
app_icon.addFile(icon_path, QtCore.QSize(32, 32))
|
| 96 |
-
app_icon.addFile(icon_path, QtCore.QSize(48, 48))
|
| 97 |
-
app_icon.addFile(icon_path, QtCore.QSize(64, 64))
|
| 98 |
-
app_icon.addFile(icon_path, QtCore.QSize(256, 256))
|
| 99 |
-
app.setWindowIcon(app_icon)
|
| 100 |
-
app.setStyle("Fusion")
|
| 101 |
-
app.setPalette(guiparts.DarkPalette())
|
| 102 |
-
MainW_3d(image=image, logger=logger)
|
| 103 |
-
ret = app.exec_()
|
| 104 |
-
sys.exit(ret)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
class MainW_3d(MainW):
|
| 108 |
-
|
| 109 |
-
def __init__(self, image=None, logger=None):
|
| 110 |
-
# MainW init
|
| 111 |
-
MainW.__init__(self, image=image, logger=logger)
|
| 112 |
-
|
| 113 |
-
# add gradZ view
|
| 114 |
-
self.ViewDropDown.insertItem(3, "gradZ")
|
| 115 |
-
|
| 116 |
-
# turn off single stroke
|
| 117 |
-
self.SCheckBox.setChecked(False)
|
| 118 |
-
|
| 119 |
-
### add orthoviews and z-bar
|
| 120 |
-
# ortho crosshair lines
|
| 121 |
-
self.vLine = pg.InfiniteLine(angle=90, movable=False)
|
| 122 |
-
self.hLine = pg.InfiniteLine(angle=0, movable=False)
|
| 123 |
-
self.vLineOrtho = [
|
| 124 |
-
pg.InfiniteLine(angle=90, movable=False),
|
| 125 |
-
pg.InfiniteLine(angle=90, movable=False)
|
| 126 |
-
]
|
| 127 |
-
self.hLineOrtho = [
|
| 128 |
-
pg.InfiniteLine(angle=0, movable=False),
|
| 129 |
-
pg.InfiniteLine(angle=0, movable=False)
|
| 130 |
-
]
|
| 131 |
-
self.make_orthoviews()
|
| 132 |
-
|
| 133 |
-
# z scrollbar underneath
|
| 134 |
-
self.scroll = QScrollBar(QtCore.Qt.Horizontal)
|
| 135 |
-
self.scroll.setMaximum(10)
|
| 136 |
-
self.scroll.valueChanged.connect(self.move_in_Z)
|
| 137 |
-
self.lmain.addWidget(self.scroll, 40, 9, 1, 30)
|
| 138 |
-
|
| 139 |
-
b = 22
|
| 140 |
-
|
| 141 |
-
label = QLabel("stitch\nthreshold:")
|
| 142 |
-
label.setToolTip(
|
| 143 |
-
"for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
|
| 144 |
-
)
|
| 145 |
-
label.setFont(self.medfont)
|
| 146 |
-
self.segBoxG.addWidget(label, b, 0, 1, 4)
|
| 147 |
-
self.stitch_threshold = QLineEdit()
|
| 148 |
-
self.stitch_threshold.setText("0.0")
|
| 149 |
-
self.stitch_threshold.setFixedWidth(30)
|
| 150 |
-
self.stitch_threshold.setFont(self.medfont)
|
| 151 |
-
self.stitch_threshold.setToolTip(
|
| 152 |
-
"for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
|
| 153 |
-
)
|
| 154 |
-
self.segBoxG.addWidget(self.stitch_threshold, b, 3, 1, 1)
|
| 155 |
-
|
| 156 |
-
label = QLabel("flow3D\nsmooth:")
|
| 157 |
-
label.setToolTip(
|
| 158 |
-
"for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
|
| 159 |
-
)
|
| 160 |
-
label.setFont(self.medfont)
|
| 161 |
-
self.segBoxG.addWidget(label, b, 4, 1, 3)
|
| 162 |
-
self.flow3D_smooth = QLineEdit()
|
| 163 |
-
self.flow3D_smooth.setText("0.0")
|
| 164 |
-
self.flow3D_smooth.setFixedWidth(30)
|
| 165 |
-
self.flow3D_smooth.setFont(self.medfont)
|
| 166 |
-
self.flow3D_smooth.setToolTip(
|
| 167 |
-
"for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
|
| 168 |
-
)
|
| 169 |
-
self.segBoxG.addWidget(self.flow3D_smooth, b, 7, 1, 1)
|
| 170 |
-
|
| 171 |
-
b+=1
|
| 172 |
-
label = QLabel("anisotropy:")
|
| 173 |
-
label.setToolTip(
|
| 174 |
-
"for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
|
| 175 |
-
)
|
| 176 |
-
label.setFont(self.medfont)
|
| 177 |
-
self.segBoxG.addWidget(label, b, 0, 1, 3)
|
| 178 |
-
self.anisotropy = QLineEdit()
|
| 179 |
-
self.anisotropy.setText("1.0")
|
| 180 |
-
self.anisotropy.setFixedWidth(30)
|
| 181 |
-
self.anisotropy.setFont(self.medfont)
|
| 182 |
-
self.anisotropy.setToolTip(
|
| 183 |
-
"for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
|
| 184 |
-
)
|
| 185 |
-
self.segBoxG.addWidget(self.anisotropy, b, 3, 1, 1)
|
| 186 |
-
|
| 187 |
-
b+=1
|
| 188 |
-
label = QLabel("min\nsize:")
|
| 189 |
-
label.setToolTip(
|
| 190 |
-
"all masks less than this size in pixels (volume) will be removed"
|
| 191 |
-
)
|
| 192 |
-
label.setFont(self.medfont)
|
| 193 |
-
self.segBoxG.addWidget(label, b, 0, 1, 4)
|
| 194 |
-
self.min_size = QLineEdit()
|
| 195 |
-
self.min_size.setText("15")
|
| 196 |
-
self.min_size.setFixedWidth(50)
|
| 197 |
-
self.min_size.setFont(self.medfont)
|
| 198 |
-
self.min_size.setToolTip(
|
| 199 |
-
"all masks less than this size in pixels (volume) will be removed"
|
| 200 |
-
)
|
| 201 |
-
self.segBoxG.addWidget(self.min_size, b, 3, 1, 1)
|
| 202 |
-
|
| 203 |
-
b += 1
|
| 204 |
-
self.orthobtn = QCheckBox("ortho")
|
| 205 |
-
self.orthobtn.setToolTip("activate orthoviews with 3D image")
|
| 206 |
-
self.orthobtn.setFont(self.medfont)
|
| 207 |
-
self.orthobtn.setChecked(False)
|
| 208 |
-
self.l0.addWidget(self.orthobtn, b, 0, 1, 2)
|
| 209 |
-
self.orthobtn.toggled.connect(self.toggle_ortho)
|
| 210 |
-
|
| 211 |
-
label = QLabel("dz:")
|
| 212 |
-
label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 213 |
-
label.setFont(self.medfont)
|
| 214 |
-
self.l0.addWidget(label, b, 2, 1, 1)
|
| 215 |
-
self.dz = 10
|
| 216 |
-
self.dzedit = QLineEdit()
|
| 217 |
-
self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 218 |
-
self.dzedit.setText(str(self.dz))
|
| 219 |
-
self.dzedit.returnPressed.connect(self.update_ortho)
|
| 220 |
-
self.dzedit.setFixedWidth(40)
|
| 221 |
-
self.dzedit.setFont(self.medfont)
|
| 222 |
-
self.l0.addWidget(self.dzedit, b, 3, 1, 2)
|
| 223 |
-
|
| 224 |
-
label = QLabel("z-aspect:")
|
| 225 |
-
label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 226 |
-
label.setFont(self.medfont)
|
| 227 |
-
self.l0.addWidget(label, b, 5, 1, 2)
|
| 228 |
-
self.zaspect = 1.0
|
| 229 |
-
self.zaspectedit = QLineEdit()
|
| 230 |
-
self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 231 |
-
self.zaspectedit.setText(str(self.zaspect))
|
| 232 |
-
self.zaspectedit.returnPressed.connect(self.update_ortho)
|
| 233 |
-
self.zaspectedit.setFixedWidth(40)
|
| 234 |
-
self.zaspectedit.setFont(self.medfont)
|
| 235 |
-
self.l0.addWidget(self.zaspectedit, b, 7, 1, 2)
|
| 236 |
-
|
| 237 |
-
b += 1
|
| 238 |
-
# add z position underneath
|
| 239 |
-
self.currentZ = 0
|
| 240 |
-
label = QLabel("Z:")
|
| 241 |
-
label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 242 |
-
self.l0.addWidget(label, b, 5, 1, 2)
|
| 243 |
-
self.zpos = QLineEdit()
|
| 244 |
-
self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 245 |
-
self.zpos.setText(str(self.currentZ))
|
| 246 |
-
self.zpos.returnPressed.connect(self.update_ztext)
|
| 247 |
-
self.zpos.setFixedWidth(40)
|
| 248 |
-
self.zpos.setFont(self.medfont)
|
| 249 |
-
self.l0.addWidget(self.zpos, b, 7, 1, 2)
|
| 250 |
-
|
| 251 |
-
# if called with image, load it
|
| 252 |
-
if image is not None:
|
| 253 |
-
self.filename = image
|
| 254 |
-
io._load_image(self, self.filename, load_3D=True)
|
| 255 |
-
|
| 256 |
-
self.load_3D = True
|
| 257 |
-
|
| 258 |
-
def add_mask(self, points=None, color=(100, 200, 50), dense=True):
|
| 259 |
-
# points is list of strokes
|
| 260 |
-
|
| 261 |
-
points_all = np.concatenate(points, axis=0)
|
| 262 |
-
|
| 263 |
-
# loop over z values
|
| 264 |
-
median = []
|
| 265 |
-
zdraw = np.unique(points_all[:, 0])
|
| 266 |
-
zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int)
|
| 267 |
-
zmin = zdraw.min()
|
| 268 |
-
pix = np.zeros((2, 0), "uint16")
|
| 269 |
-
mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool")
|
| 270 |
-
k = 0
|
| 271 |
-
for z in zdraw:
|
| 272 |
-
ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
|
| 273 |
-
0, "int"), np.zeros(0, "int")
|
| 274 |
-
for stroke in points:
|
| 275 |
-
stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
|
| 276 |
-
iz = stroke[:, 0] == z
|
| 277 |
-
vr = stroke[iz, 1]
|
| 278 |
-
vc = stroke[iz, 2]
|
| 279 |
-
if iz.sum() > 0:
|
| 280 |
-
# get points inside drawn points
|
| 281 |
-
mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
|
| 282 |
-
pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
|
| 283 |
-
axis=-1)[:, np.newaxis, :]
|
| 284 |
-
mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
|
| 285 |
-
ar, ac = np.nonzero(mask)
|
| 286 |
-
ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
|
| 287 |
-
# get dense outline
|
| 288 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 289 |
-
cv2.CHAIN_APPROX_NONE)
|
| 290 |
-
pvc, pvr = contours[-2][0].squeeze().T
|
| 291 |
-
vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
|
| 292 |
-
# concatenate all points
|
| 293 |
-
ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
|
| 294 |
-
# if these pixels are overlapping with another cell, reassign them
|
| 295 |
-
ioverlap = self.cellpix[z][ar, ac] > 0
|
| 296 |
-
if (~ioverlap).sum() < 8:
|
| 297 |
-
print("ERROR: cell too small without overlaps, not drawn")
|
| 298 |
-
return None
|
| 299 |
-
elif ioverlap.sum() > 0:
|
| 300 |
-
ar, ac = ar[~ioverlap], ac[~ioverlap]
|
| 301 |
-
# compute outline of new mask
|
| 302 |
-
mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
|
| 303 |
-
mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
|
| 304 |
-
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
| 305 |
-
cv2.CHAIN_APPROX_NONE)
|
| 306 |
-
pvc, pvr = contours[-2][0].squeeze().T
|
| 307 |
-
vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
|
| 308 |
-
ars = np.concatenate((ars, ar), axis=0)
|
| 309 |
-
acs = np.concatenate((acs, ac), axis=0)
|
| 310 |
-
vrs = np.concatenate((vrs, vr), axis=0)
|
| 311 |
-
vcs = np.concatenate((vcs, vc), axis=0)
|
| 312 |
-
self.draw_mask(z, ars, acs, vrs, vcs, color)
|
| 313 |
-
|
| 314 |
-
median.append(np.array([np.median(ars), np.median(acs)]))
|
| 315 |
-
mall[z - zmin, ars, acs] = True
|
| 316 |
-
pix = np.append(pix, np.vstack((ars, acs)), axis=-1)
|
| 317 |
-
|
| 318 |
-
mall = mall[:, pix[0].min():pix[0].max() + 1,
|
| 319 |
-
pix[1].min():pix[1].max() + 1].astype("float32")
|
| 320 |
-
ymin, xmin = pix[0].min(), pix[1].min()
|
| 321 |
-
if len(zdraw) > 1:
|
| 322 |
-
mall, zfill = interpZ(mall, zdraw - zmin)
|
| 323 |
-
for z in zfill:
|
| 324 |
-
mask = mall[z].copy()
|
| 325 |
-
ar, ac = np.nonzero(mask)
|
| 326 |
-
ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0
|
| 327 |
-
if (~ioverlap).sum() < 5:
|
| 328 |
-
print("WARNING: stroke on plane %d not included due to overlaps" %
|
| 329 |
-
z)
|
| 330 |
-
elif ioverlap.sum() > 0:
|
| 331 |
-
mask[ar[ioverlap], ac[ioverlap]] = 0
|
| 332 |
-
ar, ac = ar[~ioverlap], ac[~ioverlap]
|
| 333 |
-
# compute outline of mask
|
| 334 |
-
outlines = masks_to_outlines(mask)
|
| 335 |
-
vr, vc = np.nonzero(outlines)
|
| 336 |
-
vr, vc = vr + ymin, vc + xmin
|
| 337 |
-
ar, ac = ar + ymin, ac + xmin
|
| 338 |
-
self.draw_mask(z + zmin, ar, ac, vr, vc, color)
|
| 339 |
-
|
| 340 |
-
self.zdraw.append(zdraw)
|
| 341 |
-
|
| 342 |
-
return median
|
| 343 |
-
|
| 344 |
-
def move_in_Z(self):
|
| 345 |
-
if self.loaded:
|
| 346 |
-
self.currentZ = min(self.NZ, max(0, int(self.scroll.value())))
|
| 347 |
-
self.zpos.setText(str(self.currentZ))
|
| 348 |
-
self.update_plot()
|
| 349 |
-
self.draw_layer()
|
| 350 |
-
self.update_layer()
|
| 351 |
-
|
| 352 |
-
def make_orthoviews(self):
|
| 353 |
-
self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], []
|
| 354 |
-
for j in range(2):
|
| 355 |
-
self.pOrtho.append(
|
| 356 |
-
pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}",
|
| 357 |
-
border=[100, 100, 100], invertY=True, enableMouse=False))
|
| 358 |
-
self.pOrtho[j].setMenuEnabled(False)
|
| 359 |
-
|
| 360 |
-
self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
|
| 361 |
-
self.imgOrtho[j].autoDownsample = False
|
| 362 |
-
|
| 363 |
-
self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
|
| 364 |
-
self.layerOrtho[j].setLevels([0., 255.])
|
| 365 |
-
|
| 366 |
-
#self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j]
|
| 367 |
-
self.pOrtho[j].addItem(self.imgOrtho[j])
|
| 368 |
-
self.pOrtho[j].addItem(self.layerOrtho[j])
|
| 369 |
-
self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False)
|
| 370 |
-
self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False)
|
| 371 |
-
|
| 372 |
-
self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0)
|
| 373 |
-
self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0)
|
| 374 |
-
|
| 375 |
-
def add_orthoviews(self):
|
| 376 |
-
self.yortho = self.Ly // 2
|
| 377 |
-
self.xortho = self.Lx // 2
|
| 378 |
-
if self.NZ > 1:
|
| 379 |
-
self.update_ortho()
|
| 380 |
-
|
| 381 |
-
self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1)
|
| 382 |
-
self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1)
|
| 383 |
-
|
| 384 |
-
qGraphicsGridLayout = self.win.ci.layout
|
| 385 |
-
qGraphicsGridLayout.setColumnStretchFactor(0, 2)
|
| 386 |
-
qGraphicsGridLayout.setColumnStretchFactor(1, 1)
|
| 387 |
-
qGraphicsGridLayout.setRowStretchFactor(0, 2)
|
| 388 |
-
qGraphicsGridLayout.setRowStretchFactor(1, 1)
|
| 389 |
-
|
| 390 |
-
self.pOrtho[0].setYRange(0, self.Lx)
|
| 391 |
-
self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
|
| 392 |
-
self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
|
| 393 |
-
self.pOrtho[1].setXRange(0, self.Ly)
|
| 394 |
-
|
| 395 |
-
self.p0.addItem(self.vLine, ignoreBounds=False)
|
| 396 |
-
self.p0.addItem(self.hLine, ignoreBounds=False)
|
| 397 |
-
self.p0.setYRange(0, self.Lx)
|
| 398 |
-
self.p0.setXRange(0, self.Ly)
|
| 399 |
-
|
| 400 |
-
self.win.show()
|
| 401 |
-
self.show()
|
| 402 |
-
|
| 403 |
-
def remove_orthoviews(self):
|
| 404 |
-
self.win.removeItem(self.pOrtho[0])
|
| 405 |
-
self.win.removeItem(self.pOrtho[1])
|
| 406 |
-
self.p0.removeItem(self.vLine)
|
| 407 |
-
self.p0.removeItem(self.hLine)
|
| 408 |
-
self.win.show()
|
| 409 |
-
self.show()
|
| 410 |
-
|
| 411 |
-
def update_crosshairs(self):
|
| 412 |
-
self.yortho = min(self.Ly - 1, max(0, int(self.yortho)))
|
| 413 |
-
self.xortho = min(self.Lx - 1, max(0, int(self.xortho)))
|
| 414 |
-
self.vLine.setPos(self.xortho)
|
| 415 |
-
self.hLine.setPos(self.yortho)
|
| 416 |
-
self.vLineOrtho[1].setPos(self.xortho)
|
| 417 |
-
self.hLineOrtho[1].setPos(self.zc)
|
| 418 |
-
self.vLineOrtho[0].setPos(self.zc)
|
| 419 |
-
self.hLineOrtho[0].setPos(self.yortho)
|
| 420 |
-
|
| 421 |
-
def update_ortho(self):
|
| 422 |
-
if self.NZ > 1 and self.orthobtn.isChecked():
|
| 423 |
-
dzcurrent = self.dz
|
| 424 |
-
self.dz = min(100, max(3, int(self.dzedit.text())))
|
| 425 |
-
self.zaspect = max(0.01, min(100., float(self.zaspectedit.text())))
|
| 426 |
-
self.dzedit.setText(str(self.dz))
|
| 427 |
-
self.zaspectedit.setText(str(self.zaspect))
|
| 428 |
-
if self.dz != dzcurrent:
|
| 429 |
-
self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
|
| 430 |
-
self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
|
| 431 |
-
dztot = min(self.NZ, self.dz * 2)
|
| 432 |
-
y = self.yortho
|
| 433 |
-
x = self.xortho
|
| 434 |
-
z = self.currentZ
|
| 435 |
-
if dztot == self.NZ:
|
| 436 |
-
zmin, zmax = 0, self.NZ
|
| 437 |
-
else:
|
| 438 |
-
if z - self.dz < 0:
|
| 439 |
-
zmin = 0
|
| 440 |
-
zmax = zmin + self.dz * 2
|
| 441 |
-
elif z + self.dz >= self.NZ:
|
| 442 |
-
zmax = self.NZ
|
| 443 |
-
zmin = zmax - self.dz * 2
|
| 444 |
-
else:
|
| 445 |
-
zmin, zmax = z - self.dz, z + self.dz
|
| 446 |
-
self.zc = z - zmin
|
| 447 |
-
self.update_crosshairs()
|
| 448 |
-
if self.view == 0 or self.view == 4:
|
| 449 |
-
for j in range(2):
|
| 450 |
-
if j == 0:
|
| 451 |
-
if self.view == 0:
|
| 452 |
-
image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
|
| 453 |
-
else:
|
| 454 |
-
image = self.stack_filtered[zmin:zmax, :,
|
| 455 |
-
x].transpose(1, 0, 2).copy()
|
| 456 |
-
else:
|
| 457 |
-
image = self.stack[
|
| 458 |
-
zmin:zmax,
|
| 459 |
-
y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
|
| 460 |
-
y, :].copy()
|
| 461 |
-
if self.nchan == 1:
|
| 462 |
-
# show single channel
|
| 463 |
-
image = image[..., 0]
|
| 464 |
-
if self.color == 0:
|
| 465 |
-
self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
|
| 466 |
-
if self.nchan > 1:
|
| 467 |
-
levels = np.array([
|
| 468 |
-
self.saturation[0][self.currentZ],
|
| 469 |
-
self.saturation[1][self.currentZ],
|
| 470 |
-
self.saturation[2][self.currentZ]
|
| 471 |
-
])
|
| 472 |
-
self.imgOrtho[j].setLevels(levels)
|
| 473 |
-
else:
|
| 474 |
-
self.imgOrtho[j].setLevels(
|
| 475 |
-
self.saturation[0][self.currentZ])
|
| 476 |
-
elif self.color > 0 and self.color < 4:
|
| 477 |
-
if self.nchan > 1:
|
| 478 |
-
image = image[..., self.color - 1]
|
| 479 |
-
self.imgOrtho[j].setImage(image, autoLevels=False,
|
| 480 |
-
lut=self.cmap[self.color])
|
| 481 |
-
if self.nchan > 1:
|
| 482 |
-
self.imgOrtho[j].setLevels(
|
| 483 |
-
self.saturation[self.color - 1][self.currentZ])
|
| 484 |
-
else:
|
| 485 |
-
self.imgOrtho[j].setLevels(
|
| 486 |
-
self.saturation[0][self.currentZ])
|
| 487 |
-
elif self.color == 4:
|
| 488 |
-
if image.ndim > 2:
|
| 489 |
-
image = image.astype("float32").mean(axis=2).astype("uint8")
|
| 490 |
-
self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
|
| 491 |
-
self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
|
| 492 |
-
elif self.color == 5:
|
| 493 |
-
if image.ndim > 2:
|
| 494 |
-
image = image.astype("float32").mean(axis=2).astype("uint8")
|
| 495 |
-
self.imgOrtho[j].setImage(image, autoLevels=False,
|
| 496 |
-
lut=self.cmap[0])
|
| 497 |
-
self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
|
| 498 |
-
self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
|
| 499 |
-
self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)
|
| 500 |
-
|
| 501 |
-
else:
|
| 502 |
-
image = np.zeros((10, 10), "uint8")
|
| 503 |
-
self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
|
| 504 |
-
self.imgOrtho[0].setLevels([0.0, 255.0])
|
| 505 |
-
self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
|
| 506 |
-
self.imgOrtho[1].setLevels([0.0, 255.0])
|
| 507 |
-
|
| 508 |
-
zrange = zmax - zmin
|
| 509 |
-
self.layer_ortho = [
|
| 510 |
-
np.zeros((self.Ly, zrange, 4), "uint8"),
|
| 511 |
-
np.zeros((zrange, self.Lx, 4), "uint8")
|
| 512 |
-
]
|
| 513 |
-
if self.masksOn:
|
| 514 |
-
for j in range(2):
|
| 515 |
-
if j == 0:
|
| 516 |
-
cp = self.cellpix[zmin:zmax, :, x].T
|
| 517 |
-
else:
|
| 518 |
-
cp = self.cellpix[zmin:zmax, y]
|
| 519 |
-
self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
|
| 520 |
-
self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
|
| 521 |
-
if self.selected > 0:
|
| 522 |
-
self.layer_ortho[j][cp == self.selected] = np.array(
|
| 523 |
-
[255, 255, 255, self.opacity])
|
| 524 |
-
|
| 525 |
-
if self.outlinesOn:
|
| 526 |
-
for j in range(2):
|
| 527 |
-
if j == 0:
|
| 528 |
-
op = self.outpix[zmin:zmax, :, x].T
|
| 529 |
-
else:
|
| 530 |
-
op = self.outpix[zmin:zmax, y]
|
| 531 |
-
self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")
|
| 532 |
-
|
| 533 |
-
for j in range(2):
|
| 534 |
-
self.layerOrtho[j].setImage(self.layer_ortho[j])
|
| 535 |
-
self.win.show()
|
| 536 |
-
self.show()
|
| 537 |
-
|
| 538 |
-
def toggle_ortho(self):
|
| 539 |
-
if self.orthobtn.isChecked():
|
| 540 |
-
self.add_orthoviews()
|
| 541 |
-
else:
|
| 542 |
-
self.remove_orthoviews()
|
| 543 |
-
|
| 544 |
-
def plot_clicked(self, event):
|
| 545 |
-
if event.button()==QtCore.Qt.LeftButton \
|
| 546 |
-
and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
|
| 547 |
-
and not self.removing_region:
|
| 548 |
-
if event.double():
|
| 549 |
-
try:
|
| 550 |
-
self.p0.setYRange(0, self.Ly + self.pr)
|
| 551 |
-
except:
|
| 552 |
-
self.p0.setYRange(0, self.Ly)
|
| 553 |
-
self.p0.setXRange(0, self.Lx)
|
| 554 |
-
elif self.loaded and not self.in_stroke:
|
| 555 |
-
if self.orthobtn.isChecked():
|
| 556 |
-
items = self.win.scene().items(event.scenePos())
|
| 557 |
-
for x in items:
|
| 558 |
-
if x == self.p0:
|
| 559 |
-
pos = self.p0.mapSceneToView(event.scenePos())
|
| 560 |
-
x = int(pos.x())
|
| 561 |
-
y = int(pos.y())
|
| 562 |
-
if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx:
|
| 563 |
-
self.yortho = y
|
| 564 |
-
self.xortho = x
|
| 565 |
-
self.update_ortho()
|
| 566 |
-
|
| 567 |
-
def update_plot(self):
|
| 568 |
-
super().update_plot()
|
| 569 |
-
if self.NZ > 1 and self.orthobtn.isChecked():
|
| 570 |
-
self.update_ortho()
|
| 571 |
-
self.win.show()
|
| 572 |
-
self.show()
|
| 573 |
-
|
| 574 |
-
def keyPressEvent(self, event):
|
| 575 |
-
if self.loaded:
|
| 576 |
-
if not (event.modifiers() &
|
| 577 |
-
(QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
|
| 578 |
-
QtCore.Qt.AltModifier) or self.in_stroke):
|
| 579 |
-
updated = False
|
| 580 |
-
if len(self.current_point_set) > 0:
|
| 581 |
-
if event.key() == QtCore.Qt.Key_Return:
|
| 582 |
-
self.add_set()
|
| 583 |
-
if self.NZ > 1:
|
| 584 |
-
if event.key() == QtCore.Qt.Key_Left:
|
| 585 |
-
self.currentZ = max(0, self.currentZ - 1)
|
| 586 |
-
self.scroll.setValue(self.currentZ)
|
| 587 |
-
updated = True
|
| 588 |
-
elif event.key() == QtCore.Qt.Key_Right:
|
| 589 |
-
self.currentZ = min(self.NZ - 1, self.currentZ + 1)
|
| 590 |
-
self.scroll.setValue(self.currentZ)
|
| 591 |
-
updated = True
|
| 592 |
-
else:
|
| 593 |
-
nviews = self.ViewDropDown.count() - 1
|
| 594 |
-
nviews += int(
|
| 595 |
-
self.ViewDropDown.model().item(self.ViewDropDown.count() -
|
| 596 |
-
1).isEnabled())
|
| 597 |
-
if event.key() == QtCore.Qt.Key_X:
|
| 598 |
-
self.MCheckBox.toggle()
|
| 599 |
-
if event.key() == QtCore.Qt.Key_Z:
|
| 600 |
-
self.OCheckBox.toggle()
|
| 601 |
-
if event.key() == QtCore.Qt.Key_Left or event.key(
|
| 602 |
-
) == QtCore.Qt.Key_A:
|
| 603 |
-
self.currentZ = max(0, self.currentZ - 1)
|
| 604 |
-
self.scroll.setValue(self.currentZ)
|
| 605 |
-
updated = True
|
| 606 |
-
elif event.key() == QtCore.Qt.Key_Right or event.key(
|
| 607 |
-
) == QtCore.Qt.Key_D:
|
| 608 |
-
self.currentZ = min(self.NZ - 1, self.currentZ + 1)
|
| 609 |
-
self.scroll.setValue(self.currentZ)
|
| 610 |
-
updated = True
|
| 611 |
-
elif event.key() == QtCore.Qt.Key_PageDown:
|
| 612 |
-
self.view = (self.view + 1) % (nviews)
|
| 613 |
-
self.ViewDropDown.setCurrentIndex(self.view)
|
| 614 |
-
elif event.key() == QtCore.Qt.Key_PageUp:
|
| 615 |
-
self.view = (self.view - 1) % (nviews)
|
| 616 |
-
self.ViewDropDown.setCurrentIndex(self.view)
|
| 617 |
-
|
| 618 |
-
# can change background or stroke size if cell not finished
|
| 619 |
-
if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
|
| 620 |
-
self.color = (self.color - 1) % (6)
|
| 621 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 622 |
-
elif event.key() == QtCore.Qt.Key_Down or event.key(
|
| 623 |
-
) == QtCore.Qt.Key_S:
|
| 624 |
-
self.color = (self.color + 1) % (6)
|
| 625 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 626 |
-
elif event.key() == QtCore.Qt.Key_R:
|
| 627 |
-
if self.color != 1:
|
| 628 |
-
self.color = 1
|
| 629 |
-
else:
|
| 630 |
-
self.color = 0
|
| 631 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 632 |
-
elif event.key() == QtCore.Qt.Key_G:
|
| 633 |
-
if self.color != 2:
|
| 634 |
-
self.color = 2
|
| 635 |
-
else:
|
| 636 |
-
self.color = 0
|
| 637 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 638 |
-
elif event.key() == QtCore.Qt.Key_B:
|
| 639 |
-
if self.color != 3:
|
| 640 |
-
self.color = 3
|
| 641 |
-
else:
|
| 642 |
-
self.color = 0
|
| 643 |
-
self.RGBDropDown.setCurrentIndex(self.color)
|
| 644 |
-
elif (event.key() == QtCore.Qt.Key_Comma or
|
| 645 |
-
event.key() == QtCore.Qt.Key_Period):
|
| 646 |
-
count = self.BrushChoose.count()
|
| 647 |
-
gci = self.BrushChoose.currentIndex()
|
| 648 |
-
if event.key() == QtCore.Qt.Key_Comma:
|
| 649 |
-
gci = max(0, gci - 1)
|
| 650 |
-
else:
|
| 651 |
-
gci = min(count - 1, gci + 1)
|
| 652 |
-
self.BrushChoose.setCurrentIndex(gci)
|
| 653 |
-
self.brush_choose()
|
| 654 |
-
if not updated:
|
| 655 |
-
self.update_plot()
|
| 656 |
-
if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
|
| 657 |
-
self.p0.keyPressEvent(event)
|
| 658 |
-
|
| 659 |
-
def update_ztext(self):
|
| 660 |
-
zpos = self.currentZ
|
| 661 |
-
try:
|
| 662 |
-
zpos = int(self.zpos.text())
|
| 663 |
-
except:
|
| 664 |
-
print("ERROR: zposition is not a number")
|
| 665 |
-
self.currentZ = max(0, min(self.NZ - 1, zpos))
|
| 666 |
-
self.zpos.setText(str(self.currentZ))
|
| 667 |
-
self.scroll.setValue(self.currentZ)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/guihelpwindowtext.html
DELETED
|
@@ -1,143 +0,0 @@
|
|
| 1 |
-
<qt>
|
| 2 |
-
<p class="has-line-data" data-line-start="5" data-line-end="6">
|
| 3 |
-
<b>Main GUI mouse controls:</b>
|
| 4 |
-
</p>
|
| 5 |
-
<ul>
|
| 6 |
-
<li class="has-line-data" data-line-start="7" data-line-end="8">Pan = left-click + drag</li>
|
| 7 |
-
<li class="has-line-data" data-line-start="8" data-line-end="9">Zoom = scroll wheel (or +/= and - buttons)</li>
|
| 8 |
-
<li class="has-line-data" data-line-start="9" data-line-end="10">Full view = double left-click</li>
|
| 9 |
-
<li class="has-line-data" data-line-start="10" data-line-end="11">Select mask = left-click on mask</li>
|
| 10 |
-
<li class="has-line-data" data-line-start="11" data-line-end="12">Delete mask = Ctrl (or COMMAND on Mac) +
|
| 11 |
-
left-click
|
| 12 |
-
</li>
|
| 13 |
-
<li class="has-line-data" data-line-start="11" data-line-end="12">Merge masks = Alt + left-click (will merge
|
| 14 |
-
last two)
|
| 15 |
-
</li>
|
| 16 |
-
<li class="has-line-data" data-line-start="12" data-line-end="13">Start draw mask = right-click</li>
|
| 17 |
-
<li class="has-line-data" data-line-start="13" data-line-end="15">End draw mask = right-click, or return to
|
| 18 |
-
circle at beginning
|
| 19 |
-
</li>
|
| 20 |
-
</ul>
|
| 21 |
-
<p class="has-line-data" data-line-start="15" data-line-end="16">Overlaps in masks are NOT allowed. If you
|
| 22 |
-
draw a mask on top of another mask, it is cropped so that it doesn’t overlap with the old mask. Masks in 2D
|
| 23 |
-
should be single strokes (single stroke is checked). If you want to draw masks in 3D (experimental), then
|
| 24 |
-
you can turn this option off and draw a stroke on each plane with the cell and then press ENTER. 3D
|
| 25 |
-
labelling will fill in planes that you have not labelled so that you do not have to as densely label.
|
| 26 |
-
</p>
|
| 27 |
-
<p class="has-line-data" data-line-start="17" data-line-end="18"> <b>!NOTE!:</b> The GUI automatically saves after
|
| 28 |
-
you draw a mask in 2D but NOT after 3D mask drawing and NOT after segmentation. Save in the file menu or
|
| 29 |
-
with Ctrl+S. The output file is in the same folder as the loaded image with <code>_seg.npy</code> appended.
|
| 30 |
-
</p>
|
| 31 |
-
|
| 32 |
-
<p class="has-line-data" data-line-start="19" data-line-end="20"> <b>Bulk Mask Deletion</b>
|
| 33 |
-
Clicking the 'delete multiple' button will allow you to select and delete multiple masks at once.
|
| 34 |
-
Masks can be deselected by clicking on them again. Once you have selected all the masks you want to delete,
|
| 35 |
-
click the 'done' button to delete them.
|
| 36 |
-
<br>
|
| 37 |
-
<br>
|
| 38 |
-
Alternatively, you can create a rectangular region to delete a regions of masks by clicking the
|
| 39 |
-
'delete multiple' button, and then moving and/or resizing the region to select the masks you want to delete.
|
| 40 |
-
Once you have selected the masks you want to delete, click the 'done' button to delete them.
|
| 41 |
-
<br>
|
| 42 |
-
<br>
|
| 43 |
-
At any point in the process, you can click the 'cancel' button to cancel the bulk deletion.
|
| 44 |
-
</p>
|
| 45 |
-
<hr>
|
| 46 |
-
<table class="table table-striped table-bordered">
|
| 47 |
-
<br>
|
| 48 |
-
<br>
|
| 49 |
-
FYI there are tooltips throughout the GUI (hover over text to see)
|
| 50 |
-
<br>
|
| 51 |
-
<thead>
|
| 52 |
-
<tr>
|
| 53 |
-
<th>Keyboard shortcuts</th>
|
| 54 |
-
<th>Description</th>
|
| 55 |
-
</tr>
|
| 56 |
-
</thead>
|
| 57 |
-
<tbody>
|
| 58 |
-
<tr>
|
| 59 |
-
<td>=/+ button // - button</td>
|
| 60 |
-
<td>zoom in // zoom out</td>
|
| 61 |
-
</tr>
|
| 62 |
-
<tr>
|
| 63 |
-
<td>CTRL+Z</td>
|
| 64 |
-
<td>undo previously drawn mask/stroke</td>
|
| 65 |
-
</tr>
|
| 66 |
-
<tr>
|
| 67 |
-
<td>CTRL+Y</td>
|
| 68 |
-
<td>undo remove mask</td>
|
| 69 |
-
</tr>
|
| 70 |
-
<tr>
|
| 71 |
-
<td>CTRL+0</td>
|
| 72 |
-
<td>clear all masks</td>
|
| 73 |
-
</tr>
|
| 74 |
-
<tr>
|
| 75 |
-
<td>CTRL+L</td>
|
| 76 |
-
<td>load image (can alternatively drag and drop image)</td>
|
| 77 |
-
</tr>
|
| 78 |
-
<tr>
|
| 79 |
-
<td>CTRL+S</td>
|
| 80 |
-
<td>SAVE MASKS IN IMAGE to <code>_seg.npy</code> file</td>
|
| 81 |
-
</tr>
|
| 82 |
-
<tr>
|
| 83 |
-
<td>CTRL+T</td>
|
| 84 |
-
<td>train model using _seg.npy files in folder
|
| 85 |
-
</tr>
|
| 86 |
-
<tr>
|
| 87 |
-
<td>CTRL+P</td>
|
| 88 |
-
<td>load <code>_seg.npy</code> file (note: it will load automatically with image if it exists)</td>
|
| 89 |
-
</tr>
|
| 90 |
-
<tr>
|
| 91 |
-
<td>CTRL+M</td>
|
| 92 |
-
<td>load masks file (must be same size as image with 0 for NO mask, and 1,2,3… for masks)</td>
|
| 93 |
-
</tr>
|
| 94 |
-
<tr>
|
| 95 |
-
<td>CTRL+N</td>
|
| 96 |
-
<td>save masks as PNG</td>
|
| 97 |
-
</tr>
|
| 98 |
-
<tr>
|
| 99 |
-
<td>CTRL+R</td>
|
| 100 |
-
<td>save ROIs to native ImageJ ROI format</td>
|
| 101 |
-
</tr>
|
| 102 |
-
<tr>
|
| 103 |
-
<td>CTRL+F</td>
|
| 104 |
-
<td>save flows to image file</td>
|
| 105 |
-
</tr>
|
| 106 |
-
<tr>
|
| 107 |
-
<td>A/D or LEFT/RIGHT</td>
|
| 108 |
-
<td>cycle through images in current directory</td>
|
| 109 |
-
</tr>
|
| 110 |
-
<tr>
|
| 111 |
-
<td>W/S or UP/DOWN</td>
|
| 112 |
-
<td>change color (RGB/gray/red/green/blue)</td>
|
| 113 |
-
</tr>
|
| 114 |
-
<tr>
|
| 115 |
-
<td>R / G / B</td>
|
| 116 |
-
<td>toggle between RGB and Red or Green or Blue</td>
|
| 117 |
-
</tr>
|
| 118 |
-
<tr>
|
| 119 |
-
<td>PAGE-UP / PAGE-DOWN</td>
|
| 120 |
-
<td>change to flows and cell prob views (if segmentation computed)</td>
|
| 121 |
-
</tr>
|
| 122 |
-
<tr>
|
| 123 |
-
<td>X</td>
|
| 124 |
-
<td>turn masks ON or OFF</td>
|
| 125 |
-
</tr>
|
| 126 |
-
<tr>
|
| 127 |
-
<td>Z</td>
|
| 128 |
-
<td>toggle outlines ON or OFF</td>
|
| 129 |
-
</tr>
|
| 130 |
-
<tr>
|
| 131 |
-
<td>, / .</td>
|
| 132 |
-
<td>increase / decrease brush size for drawing masks</td>
|
| 133 |
-
</tr>
|
| 134 |
-
</tbody>
|
| 135 |
-
</table>
|
| 136 |
-
<p class="has-line-data" data-line-start="36" data-line-end="37"><strong>Segmentation options
|
| 137 |
-
(2D only) </strong></p>
|
| 138 |
-
<p class="has-line-data" data-line-start="38" data-line-end="39">use GPU: if you have specially
|
| 139 |
-
installed the cuda version of torch, then you can activate this. Due to the size of the
|
| 140 |
-
transformer network, it will greatly speed up the processing time.</p>
|
| 141 |
-
<p class="has-line-data" data-line-start="40" data-line-end="41">There are no channel options
|
| 142 |
-
in v4.0.1+ since all 3 channels are used for segmentation. </p>
|
| 143 |
-
</qt>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/guiparts.py
DELETED
|
@@ -1,793 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
from qtpy import QtGui, QtCore
|
| 5 |
-
from qtpy.QtGui import QPixmap, QDoubleValidator
|
| 6 |
-
from qtpy.QtWidgets import QWidget, QDialog, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox, QVBoxLayout
|
| 7 |
-
import pyqtgraph as pg
|
| 8 |
-
import numpy as np
|
| 9 |
-
import pathlib, os
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def stylesheet():
|
| 13 |
-
return """
|
| 14 |
-
QToolTip {
|
| 15 |
-
background-color: black;
|
| 16 |
-
color: white;
|
| 17 |
-
border: black solid 1px
|
| 18 |
-
}
|
| 19 |
-
QComboBox {color: white;
|
| 20 |
-
background-color: rgb(40,40,40);}
|
| 21 |
-
QComboBox::item:enabled { color: white;
|
| 22 |
-
background-color: rgb(40,40,40);
|
| 23 |
-
selection-color: white;
|
| 24 |
-
selection-background-color: rgb(50,100,50);}
|
| 25 |
-
QComboBox::item:!enabled {
|
| 26 |
-
background-color: rgb(40,40,40);
|
| 27 |
-
color: rgb(100,100,100);
|
| 28 |
-
}
|
| 29 |
-
QScrollArea > QWidget > QWidget
|
| 30 |
-
{
|
| 31 |
-
background: transparent;
|
| 32 |
-
border: none;
|
| 33 |
-
margin: 0px 0px 0px 0px;
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
QGroupBox
|
| 37 |
-
{ border: 1px solid white; color: rgb(255,255,255);
|
| 38 |
-
border-radius: 6px;
|
| 39 |
-
margin-top: 8px;
|
| 40 |
-
padding: 0px 0px;}
|
| 41 |
-
|
| 42 |
-
QPushButton:pressed {Text-align: center;
|
| 43 |
-
background-color: rgb(150,50,150);
|
| 44 |
-
border-color: white;
|
| 45 |
-
color:white;}
|
| 46 |
-
QToolTip {
|
| 47 |
-
background-color: black;
|
| 48 |
-
color: white;
|
| 49 |
-
border: black solid 1px
|
| 50 |
-
}
|
| 51 |
-
QPushButton:!pressed {Text-align: center;
|
| 52 |
-
background-color: rgb(50,50,50);
|
| 53 |
-
border-color: white;
|
| 54 |
-
color:white;}
|
| 55 |
-
QToolTip {
|
| 56 |
-
background-color: black;
|
| 57 |
-
color: white;
|
| 58 |
-
border: black solid 1px
|
| 59 |
-
}
|
| 60 |
-
QPushButton:disabled {Text-align: center;
|
| 61 |
-
background-color: rgb(30,30,30);
|
| 62 |
-
border-color: white;
|
| 63 |
-
color:rgb(80,80,80);}
|
| 64 |
-
QToolTip {
|
| 65 |
-
background-color: black;
|
| 66 |
-
color: white;
|
| 67 |
-
border: black solid 1px
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class DarkPalette(QtGui.QPalette):
|
| 74 |
-
"""Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application.
|
| 75 |
-
(from pykilosort/kilosort4)
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
def __init__(self):
|
| 79 |
-
QtGui.QPalette.__init__(self)
|
| 80 |
-
self.setup()
|
| 81 |
-
|
| 82 |
-
def setup(self):
|
| 83 |
-
self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40))
|
| 84 |
-
self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255))
|
| 85 |
-
self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24))
|
| 86 |
-
self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47))
|
| 87 |
-
self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255))
|
| 88 |
-
self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255))
|
| 89 |
-
self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255))
|
| 90 |
-
self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47))
|
| 91 |
-
self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255))
|
| 92 |
-
self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0))
|
| 93 |
-
self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218))
|
| 94 |
-
self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218))
|
| 95 |
-
self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0))
|
| 96 |
-
self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text,
|
| 97 |
-
QtGui.QColor(128, 128, 128))
|
| 98 |
-
self.setColor(
|
| 99 |
-
QtGui.QPalette.Disabled,
|
| 100 |
-
QtGui.QPalette.ButtonText,
|
| 101 |
-
QtGui.QColor(128, 128, 128),
|
| 102 |
-
)
|
| 103 |
-
self.setColor(
|
| 104 |
-
QtGui.QPalette.Disabled,
|
| 105 |
-
QtGui.QPalette.WindowText,
|
| 106 |
-
QtGui.QColor(128, 128, 128),
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
# def create_channel_choose():
|
| 111 |
-
# # choose channel
|
| 112 |
-
# ChannelChoose = [QComboBox(), QComboBox()]
|
| 113 |
-
# ChannelLabels = []
|
| 114 |
-
# ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
|
| 115 |
-
# ChannelChoose[1].addItems(["none", "red", "green", "blue"])
|
| 116 |
-
# cstr = ["chan to segment:", "chan2 (optional): "]
|
| 117 |
-
# for i in range(2):
|
| 118 |
-
# ChannelLabels.append(QLabel(cstr[i]))
|
| 119 |
-
# if i == 0:
|
| 120 |
-
# ChannelLabels[i].setToolTip(
|
| 121 |
-
# "this is the channel in which the cytoplasm or nuclei exist \
|
| 122 |
-
# that you want to segment")
|
| 123 |
-
# ChannelChoose[i].setToolTip(
|
| 124 |
-
# "this is the channel in which the cytoplasm or nuclei exist \
|
| 125 |
-
# that you want to segment")
|
| 126 |
-
# else:
|
| 127 |
-
# ChannelLabels[i].setToolTip(
|
| 128 |
-
# "if <em>cytoplasm</em> model is chosen, and you also have a \
|
| 129 |
-
# nuclear channel, then choose the nuclear channel for this option")
|
| 130 |
-
# ChannelChoose[i].setToolTip(
|
| 131 |
-
# "if <em>cytoplasm</em> model is chosen, and you also have a \
|
| 132 |
-
# nuclear channel, then choose the nuclear channel for this option")
|
| 133 |
-
|
| 134 |
-
# return ChannelChoose, ChannelLabels
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class ModelButton(QPushButton):
|
| 138 |
-
|
| 139 |
-
def __init__(self, parent, model_name, text):
|
| 140 |
-
super().__init__()
|
| 141 |
-
self.setEnabled(False)
|
| 142 |
-
self.setText(text)
|
| 143 |
-
self.setFont(parent.boldfont)
|
| 144 |
-
self.clicked.connect(lambda: self.press(parent))
|
| 145 |
-
self.model_name = "cpsam"
|
| 146 |
-
|
| 147 |
-
def press(self, parent):
|
| 148 |
-
parent.compute_segmentation(model_name="cpsam")
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
class FilterButton(QPushButton):
|
| 152 |
-
|
| 153 |
-
def __init__(self, parent, text):
|
| 154 |
-
super().__init__()
|
| 155 |
-
self.setEnabled(False)
|
| 156 |
-
self.model_type = text
|
| 157 |
-
self.setText(text)
|
| 158 |
-
self.setFont(parent.medfont)
|
| 159 |
-
self.clicked.connect(lambda: self.press(parent))
|
| 160 |
-
|
| 161 |
-
def press(self, parent):
|
| 162 |
-
if self.model_type == "filter":
|
| 163 |
-
parent.restore = "filter"
|
| 164 |
-
normalize_params = parent.get_normalize_params()
|
| 165 |
-
if (normalize_params["sharpen_radius"] == 0 and
|
| 166 |
-
normalize_params["smooth_radius"] == 0 and
|
| 167 |
-
normalize_params["tile_norm_blocksize"] == 0):
|
| 168 |
-
print(
|
| 169 |
-
"GUI_ERROR: no filtering settings on (use custom filter settings)")
|
| 170 |
-
parent.restore = None
|
| 171 |
-
return
|
| 172 |
-
parent.restore = self.model_type
|
| 173 |
-
parent.compute_saturation()
|
| 174 |
-
# elif self.model_type != "none":
|
| 175 |
-
# parent.compute_denoise_model(model_type=self.model_type)
|
| 176 |
-
else:
|
| 177 |
-
parent.clear_restore()
|
| 178 |
-
# parent.set_restore_button()
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
class ObservableVariable(QtCore.QObject):
|
| 182 |
-
valueChanged = QtCore.Signal(object)
|
| 183 |
-
|
| 184 |
-
def __init__(self, initial=None):
|
| 185 |
-
super().__init__()
|
| 186 |
-
self._value = initial
|
| 187 |
-
|
| 188 |
-
def set(self, new_value):
|
| 189 |
-
""" Use this method to get emit the value changing and update the ROI count"""
|
| 190 |
-
if new_value != self._value:
|
| 191 |
-
self._value = new_value
|
| 192 |
-
self.valueChanged.emit(new_value)
|
| 193 |
-
|
| 194 |
-
def get(self):
|
| 195 |
-
return self._value
|
| 196 |
-
|
| 197 |
-
def __call__(self):
|
| 198 |
-
return self._value
|
| 199 |
-
|
| 200 |
-
def reset(self):
|
| 201 |
-
self.set(0)
|
| 202 |
-
|
| 203 |
-
def __iadd__(self, amount):
|
| 204 |
-
if not isinstance(amount, (int, float)):
|
| 205 |
-
raise TypeError("Value must be numeric.")
|
| 206 |
-
self.set(self._value + amount)
|
| 207 |
-
return self
|
| 208 |
-
|
| 209 |
-
def __radd__(self, other):
|
| 210 |
-
return other + self._value
|
| 211 |
-
|
| 212 |
-
def __add__(self, other):
|
| 213 |
-
return other + self._value
|
| 214 |
-
|
| 215 |
-
def __isub__(self, amount):
|
| 216 |
-
if not isinstance(amount, (int, float)):
|
| 217 |
-
raise TypeError("Value must be numeric.")
|
| 218 |
-
self.set(self._value - amount)
|
| 219 |
-
return self
|
| 220 |
-
|
| 221 |
-
def __str__(self):
|
| 222 |
-
return str(self._value)
|
| 223 |
-
|
| 224 |
-
def __lt__(self, x):
|
| 225 |
-
return self._value < x
|
| 226 |
-
|
| 227 |
-
def __gt__(self, x):
|
| 228 |
-
return self._value > x
|
| 229 |
-
|
| 230 |
-
def __eq__(self, x):
|
| 231 |
-
return self._value == x
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
class NormalizationSettings(QWidget):
|
| 235 |
-
# TODO
|
| 236 |
-
pass
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class SegmentationSettings(QWidget):
|
| 240 |
-
""" Container for gui settings. Validation is done automatically so any attributes can
|
| 241 |
-
be acessed without concern.
|
| 242 |
-
"""
|
| 243 |
-
def __init__(self, font):
|
| 244 |
-
super().__init__()
|
| 245 |
-
|
| 246 |
-
# Put everything in a grid layout:
|
| 247 |
-
grid_layout = QGridLayout()
|
| 248 |
-
widget_container = QWidget()
|
| 249 |
-
widget_container.setLayout(grid_layout)
|
| 250 |
-
row = 0
|
| 251 |
-
|
| 252 |
-
########################### Diameter ###########################
|
| 253 |
-
# TODO: Validate inputs
|
| 254 |
-
diam_qlabel = QLabel("diameter:")
|
| 255 |
-
diam_qlabel.setToolTip("diameter of cells in pixels. If not 30, image will be resized to this")
|
| 256 |
-
diam_qlabel.setFont(font)
|
| 257 |
-
grid_layout.addWidget(diam_qlabel, row, 0, 1, 2)
|
| 258 |
-
self.diameter_box = QLineEdit()
|
| 259 |
-
self.diameter_box.setToolTip("diameter of cells in pixels. If not blank, image will be resized relative to 30 pixel cell diameters")
|
| 260 |
-
self.diameter_box.setFont(font)
|
| 261 |
-
self.diameter_box.setFixedWidth(40)
|
| 262 |
-
self.diameter_box.setText(' ')
|
| 263 |
-
grid_layout.addWidget(self.diameter_box, row, 2, 1, 2)
|
| 264 |
-
|
| 265 |
-
row += 1
|
| 266 |
-
|
| 267 |
-
########################### Flow threshold ###########################
|
| 268 |
-
# TODO: Validate inputs
|
| 269 |
-
flow_threshold_qlabel = QLabel("flow\nthreshold:")
|
| 270 |
-
flow_threshold_qlabel.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
|
| 271 |
-
flow_threshold_qlabel.setFont(font)
|
| 272 |
-
grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
|
| 273 |
-
self.flow_threshold_box = QLineEdit()
|
| 274 |
-
self.flow_threshold_box.setText("0.4")
|
| 275 |
-
self.flow_threshold_box.setFixedWidth(40)
|
| 276 |
-
self.flow_threshold_box.setFont(font)
|
| 277 |
-
grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
|
| 278 |
-
self.flow_threshold_box.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
|
| 279 |
-
|
| 280 |
-
########################### Cellprob threshold ###########################
|
| 281 |
-
# TODO: Validate inputs
|
| 282 |
-
cellprob_qlabel = QLabel("cellprob\nthreshold:")
|
| 283 |
-
cellprob_qlabel.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
|
| 284 |
-
cellprob_qlabel.setFont(font)
|
| 285 |
-
grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
|
| 286 |
-
self.cellprob_threshold_box = QLineEdit()
|
| 287 |
-
self.cellprob_threshold_box.setText("0.0")
|
| 288 |
-
self.cellprob_threshold_box.setFixedWidth(40)
|
| 289 |
-
self.cellprob_threshold_box.setFont(font)
|
| 290 |
-
self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
|
| 291 |
-
grid_layout.addWidget(self.cellprob_threshold_box, row, 6, 1, 2)
|
| 292 |
-
|
| 293 |
-
row += 1
|
| 294 |
-
|
| 295 |
-
########################### Norm percentiles ###########################
|
| 296 |
-
norm_percentiles_qlabel = QLabel("norm percentiles:")
|
| 297 |
-
norm_percentiles_qlabel.setToolTip("sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)")
|
| 298 |
-
norm_percentiles_qlabel.setFont(font)
|
| 299 |
-
grid_layout.addWidget(norm_percentiles_qlabel, row, 0, 1, 8)
|
| 300 |
-
|
| 301 |
-
row += 1
|
| 302 |
-
validator = QDoubleValidator(0.0, 100.0, 2)
|
| 303 |
-
validator.setNotation(QDoubleValidator.StandardNotation)
|
| 304 |
-
|
| 305 |
-
low_norm_qlabel = QLabel('lower:')
|
| 306 |
-
low_norm_qlabel.setToolTip("pixels at this percentile set to 0 (default 1.0)")
|
| 307 |
-
low_norm_qlabel.setFont(font)
|
| 308 |
-
grid_layout.addWidget(low_norm_qlabel, row, 0, 1, 2)
|
| 309 |
-
self.norm_percentile_low_box = QLineEdit()
|
| 310 |
-
self.norm_percentile_low_box.setText("1.0")
|
| 311 |
-
self.norm_percentile_low_box.setFont(font)
|
| 312 |
-
self.norm_percentile_low_box.setFixedWidth(40)
|
| 313 |
-
self.norm_percentile_low_box.setToolTip("pixels at this percentile set to 0 (default 1.0)")
|
| 314 |
-
self.norm_percentile_low_box.setValidator(validator)
|
| 315 |
-
self.norm_percentile_low_box.editingFinished.connect(self.validate_normalization_range)
|
| 316 |
-
grid_layout.addWidget(self.norm_percentile_low_box, row, 2, 1, 1)
|
| 317 |
-
|
| 318 |
-
high_norm_qlabel = QLabel('upper:')
|
| 319 |
-
high_norm_qlabel.setToolTip("pixels at this percentile set to 1 (default 99.0)")
|
| 320 |
-
high_norm_qlabel.setFont(font)
|
| 321 |
-
grid_layout.addWidget(high_norm_qlabel, row, 4, 1, 2)
|
| 322 |
-
self.norm_percentile_high_box = QLineEdit()
|
| 323 |
-
self.norm_percentile_high_box.setText("99.0")
|
| 324 |
-
self.norm_percentile_high_box.setFont(font)
|
| 325 |
-
self.norm_percentile_high_box.setFixedWidth(40)
|
| 326 |
-
self.norm_percentile_high_box.setToolTip("pixels at this percentile set to 1 (default 99.0)")
|
| 327 |
-
self.norm_percentile_high_box.setValidator(validator)
|
| 328 |
-
self.norm_percentile_high_box.editingFinished.connect(self.validate_normalization_range)
|
| 329 |
-
grid_layout.addWidget(self.norm_percentile_high_box, row, 6, 1, 2)
|
| 330 |
-
|
| 331 |
-
row += 1
|
| 332 |
-
|
| 333 |
-
########################### niter ###########################
|
| 334 |
-
# TODO: change this to follow the same default logic as 'diameter' above
|
| 335 |
-
# TODO: input validation
|
| 336 |
-
niter_qlabel = QLabel("niter dynamics:")
|
| 337 |
-
niter_qlabel.setFont(font)
|
| 338 |
-
niter_qlabel.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
|
| 339 |
-
grid_layout.addWidget(niter_qlabel, row, 0, 1, 4)
|
| 340 |
-
self.niter_box = QLineEdit()
|
| 341 |
-
self.niter_box.setText("0")
|
| 342 |
-
self.niter_box.setFixedWidth(40)
|
| 343 |
-
self.niter_box.setFont(font)
|
| 344 |
-
self.niter_box.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
|
| 345 |
-
grid_layout.addWidget(self.niter_box, row, 4, 1, 2)
|
| 346 |
-
|
| 347 |
-
self.setLayout(grid_layout)
|
| 348 |
-
|
| 349 |
-
def validate_normalization_range(self):
|
| 350 |
-
low_text = self.norm_percentile_low_box.text()
|
| 351 |
-
high_text = self.norm_percentile_high_box.text()
|
| 352 |
-
|
| 353 |
-
if not low_text or low_text.isspace():
|
| 354 |
-
self.norm_percentile_low_box.setText('1.0')
|
| 355 |
-
low_text = '1.0'
|
| 356 |
-
elif not high_text or high_text.isspace():
|
| 357 |
-
self.norm_percentile_high_box.setText('1.0')
|
| 358 |
-
high_text = '99.0'
|
| 359 |
-
|
| 360 |
-
low = float(low_text)
|
| 361 |
-
high = float(high_text)
|
| 362 |
-
|
| 363 |
-
if low >= high:
|
| 364 |
-
# Invalid: show error and mark fields
|
| 365 |
-
self.norm_percentile_low_box.setStyleSheet("border: 1px solid red;")
|
| 366 |
-
self.norm_percentile_high_box.setStyleSheet("border: 1px solid red;")
|
| 367 |
-
else:
|
| 368 |
-
# Valid: clear style
|
| 369 |
-
self.norm_percentile_low_box.setStyleSheet("")
|
| 370 |
-
self.norm_percentile_high_box.setStyleSheet("")
|
| 371 |
-
|
| 372 |
-
@property
|
| 373 |
-
def low_percentile(self):
|
| 374 |
-
""" Also validate the low input by returning 1.0 if text doesn't work """
|
| 375 |
-
low_text = self.norm_percentile_low_box.text()
|
| 376 |
-
if not low_text or low_text.isspace():
|
| 377 |
-
self.norm_percentile_low_box.setText('1.0')
|
| 378 |
-
low_text = '1.0'
|
| 379 |
-
return float(self.norm_percentile_low_box.text())
|
| 380 |
-
|
| 381 |
-
@property
|
| 382 |
-
def high_percentile(self):
|
| 383 |
-
""" Also validate the high input by returning 99.0 if text doesn't work """
|
| 384 |
-
high_text = self.norm_percentile_high_box.text()
|
| 385 |
-
if not high_text or high_text.isspace():
|
| 386 |
-
self.norm_percentile_high_box.setText('99.0')
|
| 387 |
-
high_text = '99.0'
|
| 388 |
-
return float(self.norm_percentile_high_box.text())
|
| 389 |
-
|
| 390 |
-
@property
|
| 391 |
-
def diameter(self):
|
| 392 |
-
""" Get the diameter from the diameter box, if box isn't a number return None"""
|
| 393 |
-
try:
|
| 394 |
-
d = float(self.diameter_box.text())
|
| 395 |
-
except ValueError:
|
| 396 |
-
d = None
|
| 397 |
-
return d
|
| 398 |
-
|
| 399 |
-
@property
|
| 400 |
-
def flow_threshold(self):
|
| 401 |
-
return float(self.flow_threshold_box.text())
|
| 402 |
-
|
| 403 |
-
@property
|
| 404 |
-
def cellprob_threshold(self):
|
| 405 |
-
return float(self.cellprob_threshold_box.text())
|
| 406 |
-
|
| 407 |
-
@property
|
| 408 |
-
def niter(self):
|
| 409 |
-
num = int(self.niter_box.text())
|
| 410 |
-
if num < 1:
|
| 411 |
-
self.niter_box.setText('200')
|
| 412 |
-
return 200
|
| 413 |
-
else:
|
| 414 |
-
return num
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
class TrainWindow(QDialog):
|
| 419 |
-
|
| 420 |
-
def __init__(self, parent, model_strings):
|
| 421 |
-
super().__init__(parent)
|
| 422 |
-
self.setGeometry(100, 100, 900, 550)
|
| 423 |
-
self.setWindowTitle("train settings")
|
| 424 |
-
self.win = QWidget(self)
|
| 425 |
-
self.l0 = QGridLayout()
|
| 426 |
-
self.win.setLayout(self.l0)
|
| 427 |
-
|
| 428 |
-
yoff = 0
|
| 429 |
-
qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
|
| 430 |
-
qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
|
| 431 |
-
|
| 432 |
-
qlabel.setAlignment(QtCore.Qt.AlignVCenter)
|
| 433 |
-
self.l0.addWidget(qlabel, yoff, 0, 1, 2)
|
| 434 |
-
|
| 435 |
-
# choose initial model
|
| 436 |
-
yoff += 1
|
| 437 |
-
self.ModelChoose = QComboBox()
|
| 438 |
-
self.ModelChoose.addItems(model_strings)
|
| 439 |
-
self.ModelChoose.setFixedWidth(150)
|
| 440 |
-
self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
|
| 441 |
-
self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
|
| 442 |
-
qlabel = QLabel("initial model: ")
|
| 443 |
-
qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 444 |
-
self.l0.addWidget(qlabel, yoff, 0, 1, 1)
|
| 445 |
-
|
| 446 |
-
# choose parameters
|
| 447 |
-
labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
|
| 448 |
-
self.edits = []
|
| 449 |
-
yoff += 1
|
| 450 |
-
for i, label in enumerate(labels):
|
| 451 |
-
qlabel = QLabel(label)
|
| 452 |
-
qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 453 |
-
self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
|
| 454 |
-
self.edits.append(QLineEdit())
|
| 455 |
-
self.edits[-1].setText(str(parent.training_params[label]))
|
| 456 |
-
self.edits[-1].setFixedWidth(200)
|
| 457 |
-
self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)
|
| 458 |
-
|
| 459 |
-
yoff += len(labels)
|
| 460 |
-
|
| 461 |
-
yoff += 1
|
| 462 |
-
self.use_norm = QCheckBox(f"use restored/filtered image")
|
| 463 |
-
self.use_norm.setChecked(True)
|
| 464 |
-
|
| 465 |
-
yoff += 2
|
| 466 |
-
qlabel = QLabel(
|
| 467 |
-
"(to remove files, click cancel then remove \nfrom folder and reopen train window)"
|
| 468 |
-
)
|
| 469 |
-
self.l0.addWidget(qlabel, yoff, 0, 2, 4)
|
| 470 |
-
|
| 471 |
-
# click button
|
| 472 |
-
yoff += 3
|
| 473 |
-
QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
|
| 474 |
-
self.buttonBox = QDialogButtonBox(QBtn)
|
| 475 |
-
self.buttonBox.accepted.connect(lambda: self.accept(parent))
|
| 476 |
-
self.buttonBox.rejected.connect(self.reject)
|
| 477 |
-
self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)
|
| 478 |
-
|
| 479 |
-
# list files in folder
|
| 480 |
-
qlabel = QLabel("filenames")
|
| 481 |
-
qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
|
| 482 |
-
self.l0.addWidget(qlabel, 0, 4, 1, 1)
|
| 483 |
-
qlabel = QLabel("# of masks")
|
| 484 |
-
qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
|
| 485 |
-
self.l0.addWidget(qlabel, 0, 5, 1, 1)
|
| 486 |
-
|
| 487 |
-
for i in range(10):
|
| 488 |
-
if i > len(parent.train_files) - 1:
|
| 489 |
-
break
|
| 490 |
-
elif i == 9 and len(parent.train_files) > 10:
|
| 491 |
-
label = "..."
|
| 492 |
-
nmasks = "..."
|
| 493 |
-
else:
|
| 494 |
-
label = os.path.split(parent.train_files[i])[-1]
|
| 495 |
-
nmasks = str(parent.train_labels[i].max())
|
| 496 |
-
qlabel = QLabel(label)
|
| 497 |
-
self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
|
| 498 |
-
qlabel = QLabel(nmasks)
|
| 499 |
-
qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
|
| 500 |
-
self.l0.addWidget(qlabel, i + 1, 5, 1, 1)
|
| 501 |
-
|
| 502 |
-
def accept(self, parent):
|
| 503 |
-
# set training params
|
| 504 |
-
parent.training_params = {
|
| 505 |
-
"model_index": self.ModelChoose.currentIndex(),
|
| 506 |
-
"learning_rate": float(self.edits[0].text()),
|
| 507 |
-
"weight_decay": float(self.edits[1].text()),
|
| 508 |
-
"n_epochs": int(self.edits[2].text()),
|
| 509 |
-
"model_name": self.edits[3].text(),
|
| 510 |
-
#"use_norm": True if self.use_norm.isChecked() else False,
|
| 511 |
-
}
|
| 512 |
-
self.done(1)
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
class ExampleGUI(QDialog):
|
| 516 |
-
|
| 517 |
-
def __init__(self, parent=None):
|
| 518 |
-
super(ExampleGUI, self).__init__(parent)
|
| 519 |
-
self.setGeometry(100, 100, 1300, 900)
|
| 520 |
-
self.setWindowTitle("GUI layout")
|
| 521 |
-
self.win = QWidget(self)
|
| 522 |
-
layout = QGridLayout()
|
| 523 |
-
self.win.setLayout(layout)
|
| 524 |
-
guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
|
| 525 |
-
guip_path = str(guip_path.resolve())
|
| 526 |
-
pixmap = QPixmap(guip_path)
|
| 527 |
-
label = QLabel(self)
|
| 528 |
-
label.setPixmap(pixmap)
|
| 529 |
-
pixmap.scaled
|
| 530 |
-
layout.addWidget(label, 0, 0, 1, 1)
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
class HelpWindow(QDialog):
|
| 534 |
-
|
| 535 |
-
def __init__(self, parent=None):
|
| 536 |
-
super(HelpWindow, self).__init__(parent)
|
| 537 |
-
self.setGeometry(100, 50, 700, 1000)
|
| 538 |
-
self.setWindowTitle("cellpose help")
|
| 539 |
-
self.win = QWidget(self)
|
| 540 |
-
layout = QGridLayout()
|
| 541 |
-
self.win.setLayout(layout)
|
| 542 |
-
|
| 543 |
-
text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
|
| 544 |
-
with open(str(text_file.resolve()), "r") as f:
|
| 545 |
-
text = f.read()
|
| 546 |
-
|
| 547 |
-
label = QLabel(text)
|
| 548 |
-
label.setFont(QtGui.QFont("Arial", 8))
|
| 549 |
-
label.setWordWrap(True)
|
| 550 |
-
layout.addWidget(label, 0, 0, 1, 1)
|
| 551 |
-
self.show()
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
class TrainHelpWindow(QDialog):
|
| 555 |
-
|
| 556 |
-
def __init__(self, parent=None):
|
| 557 |
-
super(TrainHelpWindow, self).__init__(parent)
|
| 558 |
-
self.setGeometry(100, 50, 700, 300)
|
| 559 |
-
self.setWindowTitle("training instructions")
|
| 560 |
-
self.win = QWidget(self)
|
| 561 |
-
layout = QGridLayout()
|
| 562 |
-
self.win.setLayout(layout)
|
| 563 |
-
|
| 564 |
-
text_file = pathlib.Path(__file__).parent.joinpath(
|
| 565 |
-
"guitrainhelpwindowtext.html")
|
| 566 |
-
with open(str(text_file.resolve()), "r") as f:
|
| 567 |
-
text = f.read()
|
| 568 |
-
|
| 569 |
-
label = QLabel(text)
|
| 570 |
-
label.setFont(QtGui.QFont("Arial", 8))
|
| 571 |
-
label.setWordWrap(True)
|
| 572 |
-
layout.addWidget(label, 0, 0, 1, 1)
|
| 573 |
-
self.show()
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
class ViewBoxNoRightDrag(pg.ViewBox):
|
| 577 |
-
|
| 578 |
-
def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
|
| 579 |
-
invertY=False, enableMenu=True, name=None, invertX=False):
|
| 580 |
-
pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
|
| 581 |
-
enableMenu, name, invertX)
|
| 582 |
-
self.parent = parent
|
| 583 |
-
self.axHistoryPointer = -1
|
| 584 |
-
|
| 585 |
-
def keyPressEvent(self, ev):
|
| 586 |
-
"""
|
| 587 |
-
This routine should capture key presses in the current view box.
|
| 588 |
-
The following events are implemented:
|
| 589 |
-
+/= : moves forward in the zooming stack (if it exists)
|
| 590 |
-
- : moves backward in the zooming stack (if it exists)
|
| 591 |
-
|
| 592 |
-
"""
|
| 593 |
-
ev.accept()
|
| 594 |
-
if ev.text() == "-":
|
| 595 |
-
self.scaleBy([1.1, 1.1])
|
| 596 |
-
elif ev.text() in ["+", "="]:
|
| 597 |
-
self.scaleBy([0.9, 0.9])
|
| 598 |
-
else:
|
| 599 |
-
ev.ignore()
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
class ImageDraw(pg.ImageItem):
|
| 603 |
-
"""
|
| 604 |
-
**Bases:** :class:`GraphicsObject <pyqtgraph.GraphicsObject>`
|
| 605 |
-
GraphicsObject displaying an image. Optimized for rapid update (ie video display).
|
| 606 |
-
This item displays either a 2D numpy array (height, width) or
|
| 607 |
-
a 3D array (height, width, RGBa). This array is optionally scaled (see
|
| 608 |
-
:func:`setLevels <pyqtgraph.ImageItem.setLevels>`) and/or colored
|
| 609 |
-
with a lookup table (see :func:`setLookupTable <pyqtgraph.ImageItem.setLookupTable>`)
|
| 610 |
-
before being displayed.
|
| 611 |
-
ImageItem is frequently used in conjunction with
|
| 612 |
-
:class:`HistogramLUTItem <pyqtgraph.HistogramLUTItem>` or
|
| 613 |
-
:class:`HistogramLUTWidget <pyqtgraph.HistogramLUTWidget>` to provide a GUI
|
| 614 |
-
for controlling the levels and lookup table used to display the image.
|
| 615 |
-
"""
|
| 616 |
-
|
| 617 |
-
sigImageChanged = QtCore.Signal()
|
| 618 |
-
|
| 619 |
-
def __init__(self, image=None, viewbox=None, parent=None, **kargs):
|
| 620 |
-
super(ImageDraw, self).__init__()
|
| 621 |
-
self.levels = np.array([0, 255])
|
| 622 |
-
self.lut = None
|
| 623 |
-
self.autoDownsample = False
|
| 624 |
-
self.axisOrder = "row-major"
|
| 625 |
-
self.removable = False
|
| 626 |
-
|
| 627 |
-
self.parent = parent
|
| 628 |
-
self.setDrawKernel(kernel_size=self.parent.brush_size)
|
| 629 |
-
self.parent.current_stroke = []
|
| 630 |
-
self.parent.in_stroke = False
|
| 631 |
-
|
| 632 |
-
def mouseClickEvent(self, ev):
|
| 633 |
-
if (self.parent.masksOn or
|
| 634 |
-
self.parent.outlinesOn) and not self.parent.removing_region:
|
| 635 |
-
is_right_click = ev.button() == QtCore.Qt.RightButton
|
| 636 |
-
if self.parent.loaded \
|
| 637 |
-
and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\
|
| 638 |
-
and not self.parent.deleting_multiple:
|
| 639 |
-
if not self.parent.in_stroke:
|
| 640 |
-
ev.accept()
|
| 641 |
-
self.create_start(ev.pos())
|
| 642 |
-
self.parent.stroke_appended = False
|
| 643 |
-
self.parent.in_stroke = True
|
| 644 |
-
self.drawAt(ev.pos(), ev)
|
| 645 |
-
else:
|
| 646 |
-
ev.accept()
|
| 647 |
-
self.end_stroke()
|
| 648 |
-
self.parent.in_stroke = False
|
| 649 |
-
elif not self.parent.in_stroke:
|
| 650 |
-
y, x = int(ev.pos().y()), int(ev.pos().x())
|
| 651 |
-
if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
|
| 652 |
-
if ev.button() == QtCore.Qt.LeftButton and not ev.double():
|
| 653 |
-
idx = self.parent.cellpix[self.parent.currentZ][y, x]
|
| 654 |
-
if idx > 0:
|
| 655 |
-
if ev.modifiers() & QtCore.Qt.ControlModifier:
|
| 656 |
-
# delete mask selected
|
| 657 |
-
self.parent.remove_cell(idx)
|
| 658 |
-
elif ev.modifiers() & QtCore.Qt.AltModifier:
|
| 659 |
-
self.parent.merge_cells(idx)
|
| 660 |
-
elif self.parent.masksOn and not self.parent.deleting_multiple:
|
| 661 |
-
self.parent.unselect_cell()
|
| 662 |
-
self.parent.select_cell(idx)
|
| 663 |
-
elif self.parent.deleting_multiple:
|
| 664 |
-
if idx in self.parent.removing_cells_list:
|
| 665 |
-
self.parent.unselect_cell_multi(idx)
|
| 666 |
-
self.parent.removing_cells_list.remove(idx)
|
| 667 |
-
else:
|
| 668 |
-
self.parent.select_cell_multi(idx)
|
| 669 |
-
self.parent.removing_cells_list.append(idx)
|
| 670 |
-
|
| 671 |
-
elif self.parent.masksOn and not self.parent.deleting_multiple:
|
| 672 |
-
self.parent.unselect_cell()
|
| 673 |
-
|
| 674 |
-
def mouseDragEvent(self, ev):
|
| 675 |
-
ev.ignore()
|
| 676 |
-
return
|
| 677 |
-
|
| 678 |
-
def hoverEvent(self, ev):
|
| 679 |
-
if self.parent.in_stroke:
|
| 680 |
-
if self.parent.in_stroke:
|
| 681 |
-
# continue stroke if not at start
|
| 682 |
-
self.drawAt(ev.pos())
|
| 683 |
-
if self.is_at_start(ev.pos()):
|
| 684 |
-
self.end_stroke()
|
| 685 |
-
else:
|
| 686 |
-
ev.acceptClicks(QtCore.Qt.RightButton)
|
| 687 |
-
|
| 688 |
-
def create_start(self, pos):
|
| 689 |
-
self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
|
| 690 |
-
pen=pg.mkPen(color=(255, 0, 0),
|
| 691 |
-
width=self.parent.brush_size),
|
| 692 |
-
size=max(3 * 2,
|
| 693 |
-
self.parent.brush_size * 1.8 * 2),
|
| 694 |
-
brush=None)
|
| 695 |
-
self.parent.p0.addItem(self.scatter)
|
| 696 |
-
|
| 697 |
-
def is_at_start(self, pos):
|
| 698 |
-
thresh_out = max(6, self.parent.brush_size * 3)
|
| 699 |
-
thresh_in = max(3, self.parent.brush_size * 1.8)
|
| 700 |
-
# first check if you ever left the start
|
| 701 |
-
if len(self.parent.current_stroke) > 3:
|
| 702 |
-
stroke = np.array(self.parent.current_stroke)
|
| 703 |
-
dist = (((stroke[1:, 1:] -
|
| 704 |
-
stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
|
| 705 |
-
dist = dist.flatten()
|
| 706 |
-
has_left = (dist > thresh_out).nonzero()[0]
|
| 707 |
-
if len(has_left) > 0:
|
| 708 |
-
first_left = np.sort(has_left)[0]
|
| 709 |
-
has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
|
| 710 |
-
if has_returned > 0:
|
| 711 |
-
return True
|
| 712 |
-
else:
|
| 713 |
-
return False
|
| 714 |
-
else:
|
| 715 |
-
return False
|
| 716 |
-
|
| 717 |
-
def end_stroke(self):
|
| 718 |
-
self.parent.p0.removeItem(self.scatter)
|
| 719 |
-
if not self.parent.stroke_appended:
|
| 720 |
-
self.parent.strokes.append(self.parent.current_stroke)
|
| 721 |
-
self.parent.stroke_appended = True
|
| 722 |
-
self.parent.current_stroke = np.array(self.parent.current_stroke)
|
| 723 |
-
ioutline = self.parent.current_stroke[:, 3] == 1
|
| 724 |
-
self.parent.current_point_set.append(
|
| 725 |
-
list(self.parent.current_stroke[ioutline]))
|
| 726 |
-
self.parent.current_stroke = []
|
| 727 |
-
if self.parent.autosave:
|
| 728 |
-
self.parent.add_set()
|
| 729 |
-
if len(self.parent.current_point_set) and len(
|
| 730 |
-
self.parent.current_point_set[0]) > 0 and self.parent.autosave:
|
| 731 |
-
self.parent.add_set()
|
| 732 |
-
self.parent.in_stroke = False
|
| 733 |
-
|
| 734 |
-
def tabletEvent(self, ev):
|
| 735 |
-
pass
|
| 736 |
-
|
| 737 |
-
def drawAt(self, pos, ev=None):
|
| 738 |
-
mask = self.strokemask
|
| 739 |
-
stroke = self.parent.current_stroke
|
| 740 |
-
pos = [int(pos.y()), int(pos.x())]
|
| 741 |
-
dk = self.drawKernel
|
| 742 |
-
kc = self.drawKernelCenter
|
| 743 |
-
sx = [0, dk.shape[0]]
|
| 744 |
-
sy = [0, dk.shape[1]]
|
| 745 |
-
tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
|
| 746 |
-
ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
|
| 747 |
-
kcent = kc.copy()
|
| 748 |
-
if tx[0] <= 0:
|
| 749 |
-
sx[0] = 0
|
| 750 |
-
sx[1] = kc[0] + 1
|
| 751 |
-
tx = sx
|
| 752 |
-
kcent[0] = 0
|
| 753 |
-
if ty[0] <= 0:
|
| 754 |
-
sy[0] = 0
|
| 755 |
-
sy[1] = kc[1] + 1
|
| 756 |
-
ty = sy
|
| 757 |
-
kcent[1] = 0
|
| 758 |
-
if tx[1] >= self.parent.Ly - 1:
|
| 759 |
-
sx[0] = dk.shape[0] - kc[0] - 1
|
| 760 |
-
sx[1] = dk.shape[0]
|
| 761 |
-
tx[0] = self.parent.Ly - kc[0] - 1
|
| 762 |
-
tx[1] = self.parent.Ly
|
| 763 |
-
kcent[0] = tx[1] - tx[0] - 1
|
| 764 |
-
if ty[1] >= self.parent.Lx - 1:
|
| 765 |
-
sy[0] = dk.shape[1] - kc[1] - 1
|
| 766 |
-
sy[1] = dk.shape[1]
|
| 767 |
-
ty[0] = self.parent.Lx - kc[1] - 1
|
| 768 |
-
ty[1] = self.parent.Lx
|
| 769 |
-
kcent[1] = ty[1] - ty[0] - 1
|
| 770 |
-
|
| 771 |
-
ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
|
| 772 |
-
ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
|
| 773 |
-
self.image[ts] = mask[ss]
|
| 774 |
-
|
| 775 |
-
for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
|
| 776 |
-
for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
|
| 777 |
-
iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
|
| 778 |
-
stroke.append([self.parent.currentZ, x, y, iscent])
|
| 779 |
-
self.updateImage()
|
| 780 |
-
|
| 781 |
-
def setDrawKernel(self, kernel_size=3):
|
| 782 |
-
bs = kernel_size
|
| 783 |
-
kernel = np.ones((bs, bs), np.uint8)
|
| 784 |
-
self.drawKernel = kernel
|
| 785 |
-
self.drawKernelCenter = [
|
| 786 |
-
int(np.floor(kernel.shape[0] / 2)),
|
| 787 |
-
int(np.floor(kernel.shape[1] / 2))
|
| 788 |
-
]
|
| 789 |
-
onmask = 255 * kernel[:, :, np.newaxis]
|
| 790 |
-
offmask = np.zeros((bs, bs, 1))
|
| 791 |
-
opamask = 100 * kernel[:, :, np.newaxis]
|
| 792 |
-
self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
|
| 793 |
-
self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
<qt>
|
| 2 |
-
Check out this <a href="https://youtu.be/3Y1VKcxjNy4">video</a> to learn the process.
|
| 3 |
-
<ol>
|
| 4 |
-
<li>Drag and drop an image from a folder of images with a similar style (like similar cell types).</li>
|
| 5 |
-
<li>Run the built-in models on one of the images using the "model zoo" and find the one that works best for your
|
| 6 |
-
data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
|
| 7 |
-
</li>
|
| 8 |
-
<li>Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI
|
| 9 |
-
autosaves any manual changes (but does not autosave after running the model, for that click CTRL+S). The
|
| 10 |
-
segmentation is saved in a "_seg.npy" file.
|
| 11 |
-
</li>
|
| 12 |
-
<li> Go to the "Models" menu in the File bar at the top and click "Train new model..." or use shortcut CTRL+T.
|
| 13 |
-
</li>
|
| 14 |
-
<li> Choose the pretrained model to start the training from (the model you used in #2), and type in the model
|
| 15 |
-
name that you want to use. The other parameters should work well in general for most data types. Then click
|
| 16 |
-
OK.
|
| 17 |
-
</li>
|
| 18 |
-
<li> The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder.
|
| 19 |
-
Next you can repeat #3-#5 as many times as is necessary.
|
| 20 |
-
</li>
|
| 21 |
-
<li> The trained model is available to use in the future in the GUI in the "custom model" section and is saved
|
| 22 |
-
in your image folder.
|
| 23 |
-
</li>
|
| 24 |
-
</ol>
|
| 25 |
-
</qt>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/io.py
DELETED
|
@@ -1,634 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
import os, gc
|
| 5 |
-
import numpy as np
|
| 6 |
-
import cv2
|
| 7 |
-
import fastremap
|
| 8 |
-
|
| 9 |
-
from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois
|
| 10 |
-
from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
|
| 11 |
-
from ..utils import masks_to_outlines, outlines_list
|
| 12 |
-
|
| 13 |
-
try:
|
| 14 |
-
import qtpy
|
| 15 |
-
from qtpy.QtWidgets import QFileDialog
|
| 16 |
-
GUI = True
|
| 17 |
-
except:
|
| 18 |
-
GUI = False
|
| 19 |
-
|
| 20 |
-
try:
|
| 21 |
-
import matplotlib.pyplot as plt
|
| 22 |
-
MATPLOTLIB = True
|
| 23 |
-
except:
|
| 24 |
-
MATPLOTLIB = False
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def _init_model_list(parent):
|
| 28 |
-
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 29 |
-
parent.model_list_path = MODEL_LIST_PATH
|
| 30 |
-
parent.model_strings = get_user_models()
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def _add_model(parent, filename=None, load_model=True):
|
| 34 |
-
if filename is None:
|
| 35 |
-
name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
|
| 36 |
-
filename = name[0]
|
| 37 |
-
add_model(filename)
|
| 38 |
-
fname = os.path.split(filename)[-1]
|
| 39 |
-
parent.ModelChooseC.addItems([fname])
|
| 40 |
-
parent.model_strings.append(fname)
|
| 41 |
-
|
| 42 |
-
for ind, model_string in enumerate(parent.model_strings[:-1]):
|
| 43 |
-
if model_string == fname:
|
| 44 |
-
_remove_model(parent, ind=ind + 1, verbose=False)
|
| 45 |
-
|
| 46 |
-
parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
|
| 47 |
-
if load_model:
|
| 48 |
-
parent.model_choose(custom=True)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def _remove_model(parent, ind=None, verbose=True):
|
| 52 |
-
if ind is None:
|
| 53 |
-
ind = parent.ModelChooseC.currentIndex()
|
| 54 |
-
if ind > 0:
|
| 55 |
-
ind -= 1
|
| 56 |
-
parent.ModelChooseC.removeItem(ind + 1)
|
| 57 |
-
del parent.model_strings[ind]
|
| 58 |
-
# remove model from txt path
|
| 59 |
-
modelstr = parent.ModelChooseC.currentText()
|
| 60 |
-
remove_model(modelstr)
|
| 61 |
-
if len(parent.model_strings) > 0:
|
| 62 |
-
parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
|
| 63 |
-
else:
|
| 64 |
-
parent.ModelChooseC.setCurrentIndex(0)
|
| 65 |
-
else:
|
| 66 |
-
print("ERROR: no model selected to delete")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def _get_train_set(image_names):
|
| 70 |
-
""" get training data and labels for images in current folder image_names"""
|
| 71 |
-
train_data, train_labels, train_files = [], [], []
|
| 72 |
-
restore = None
|
| 73 |
-
normalize_params = normalize_default
|
| 74 |
-
for image_name_full in image_names:
|
| 75 |
-
image_name = os.path.splitext(image_name_full)[0]
|
| 76 |
-
label_name = None
|
| 77 |
-
if os.path.exists(image_name + "_seg.npy"):
|
| 78 |
-
dat = np.load(image_name + "_seg.npy", allow_pickle=True).item()
|
| 79 |
-
masks = dat["masks"].squeeze()
|
| 80 |
-
if masks.ndim == 2:
|
| 81 |
-
fastremap.renumber(masks, in_place=True)
|
| 82 |
-
label_name = image_name + "_seg.npy"
|
| 83 |
-
else:
|
| 84 |
-
print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
|
| 85 |
-
if "img_restore" in dat:
|
| 86 |
-
data = dat["img_restore"].squeeze()
|
| 87 |
-
restore = dat["restore"]
|
| 88 |
-
else:
|
| 89 |
-
data = imread(image_name_full)
|
| 90 |
-
normalize_params = dat[
|
| 91 |
-
"normalize_params"] if "normalize_params" in dat else normalize_default
|
| 92 |
-
if label_name is not None:
|
| 93 |
-
train_files.append(image_name_full)
|
| 94 |
-
train_data.append(data)
|
| 95 |
-
train_labels.append(masks)
|
| 96 |
-
if restore:
|
| 97 |
-
print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
|
| 98 |
-
return train_data, train_labels, train_files, restore, normalize_params
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def _load_image(parent, filename=None, load_seg=True, load_3D=False):
|
| 102 |
-
""" load image with filename; if None, open QFileDialog
|
| 103 |
-
if image is grey change view to default to grey scale
|
| 104 |
-
"""
|
| 105 |
-
|
| 106 |
-
if parent.load_3D:
|
| 107 |
-
load_3D = True
|
| 108 |
-
|
| 109 |
-
if filename is None:
|
| 110 |
-
name = QFileDialog.getOpenFileName(parent, "Load image")
|
| 111 |
-
filename = name[0]
|
| 112 |
-
if filename == "":
|
| 113 |
-
return
|
| 114 |
-
manual_file = os.path.splitext(filename)[0] + "_seg.npy"
|
| 115 |
-
load_mask = False
|
| 116 |
-
if load_seg:
|
| 117 |
-
if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked():
|
| 118 |
-
if filename is not None:
|
| 119 |
-
image = (imread_2D(filename) if not load_3D else
|
| 120 |
-
imread_3D(filename))
|
| 121 |
-
else:
|
| 122 |
-
image = None
|
| 123 |
-
_load_seg(parent, manual_file, image=image, image_file=filename,
|
| 124 |
-
load_3D=load_3D)
|
| 125 |
-
return
|
| 126 |
-
elif parent.autoloadMasks.isChecked():
|
| 127 |
-
mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext(
|
| 128 |
-
filename)[-1]
|
| 129 |
-
mask_file = os.path.splitext(filename)[
|
| 130 |
-
0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file
|
| 131 |
-
load_mask = True if os.path.isfile(mask_file) else False
|
| 132 |
-
try:
|
| 133 |
-
print(f"GUI_INFO: loading image: {filename}")
|
| 134 |
-
if not load_3D:
|
| 135 |
-
image = imread_2D(filename)
|
| 136 |
-
else:
|
| 137 |
-
image = imread_3D(filename)
|
| 138 |
-
parent.loaded = True
|
| 139 |
-
except Exception as e:
|
| 140 |
-
print("ERROR: images not compatible")
|
| 141 |
-
print(f"ERROR: {e}")
|
| 142 |
-
|
| 143 |
-
if parent.loaded:
|
| 144 |
-
parent.reset()
|
| 145 |
-
parent.filename = filename
|
| 146 |
-
filename = os.path.split(parent.filename)[-1]
|
| 147 |
-
_initialize_images(parent, image, load_3D=load_3D)
|
| 148 |
-
parent.loaded = True
|
| 149 |
-
parent.enable_buttons()
|
| 150 |
-
if load_mask:
|
| 151 |
-
_load_masks(parent, filename=mask_file)
|
| 152 |
-
|
| 153 |
-
# check if gray and adjust viewer:
|
| 154 |
-
if len(np.unique(image[..., 1:])) == 1:
|
| 155 |
-
parent.color = 4
|
| 156 |
-
parent.RGBDropDown.setCurrentIndex(4) # gray
|
| 157 |
-
parent.update_plot()
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def _initialize_images(parent, image, load_3D=False):
|
| 161 |
-
""" format image for GUI
|
| 162 |
-
|
| 163 |
-
assumes image is Z x W x H x C
|
| 164 |
-
|
| 165 |
-
"""
|
| 166 |
-
load_3D = parent.load_3D if load_3D is False else load_3D
|
| 167 |
-
|
| 168 |
-
parent.stack = image
|
| 169 |
-
print(f"GUI_INFO: image shape: {image.shape}")
|
| 170 |
-
if load_3D:
|
| 171 |
-
parent.NZ = len(parent.stack)
|
| 172 |
-
parent.scroll.setMaximum(parent.NZ - 1)
|
| 173 |
-
else:
|
| 174 |
-
parent.NZ = 1
|
| 175 |
-
parent.stack = parent.stack[np.newaxis, ...]
|
| 176 |
-
|
| 177 |
-
img_min = image.min()
|
| 178 |
-
img_max = image.max()
|
| 179 |
-
parent.stack = parent.stack.astype(np.float32)
|
| 180 |
-
parent.stack -= img_min
|
| 181 |
-
if img_max > img_min + 1e-3:
|
| 182 |
-
parent.stack /= (img_max - img_min)
|
| 183 |
-
parent.stack *= 255
|
| 184 |
-
|
| 185 |
-
if load_3D:
|
| 186 |
-
print("GUI_INFO: converted to float and normalized values to 0.0->255.0")
|
| 187 |
-
|
| 188 |
-
del image
|
| 189 |
-
gc.collect()
|
| 190 |
-
|
| 191 |
-
parent.imask = 0
|
| 192 |
-
parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
|
| 193 |
-
parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
|
| 194 |
-
parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
|
| 195 |
-
if hasattr(parent, "stack_filtered"):
|
| 196 |
-
parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
|
| 197 |
-
elif parent.restore and "upsample" in parent.restore:
|
| 198 |
-
parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx *
|
| 199 |
-
parent.ratio)
|
| 200 |
-
else:
|
| 201 |
-
parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
|
| 202 |
-
parent.clear_all()
|
| 203 |
-
|
| 204 |
-
if not hasattr(parent, "stack_filtered") and parent.restore:
|
| 205 |
-
print("GUI_INFO: no 'img_restore' found, applying current settings")
|
| 206 |
-
parent.compute_restore()
|
| 207 |
-
|
| 208 |
-
if parent.autobtn.isChecked():
|
| 209 |
-
if parent.restore is None or parent.restore != "filter":
|
| 210 |
-
print(
|
| 211 |
-
"GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
|
| 212 |
-
)
|
| 213 |
-
parent.compute_saturation()
|
| 214 |
-
# elif len(parent.saturation) != parent.NZ:
|
| 215 |
-
# parent.saturation = []
|
| 216 |
-
# for r in range(3):
|
| 217 |
-
# parent.saturation.append([])
|
| 218 |
-
# for n in range(parent.NZ):
|
| 219 |
-
# parent.saturation[-1].append([0, 255])
|
| 220 |
-
# parent.sliders[r].setValue([0, 255])
|
| 221 |
-
parent.compute_scale()
|
| 222 |
-
parent.track_changes = []
|
| 223 |
-
|
| 224 |
-
if load_3D:
|
| 225 |
-
parent.currentZ = int(np.floor(parent.NZ / 2))
|
| 226 |
-
parent.scroll.setValue(parent.currentZ)
|
| 227 |
-
parent.zpos.setText(str(parent.currentZ))
|
| 228 |
-
else:
|
| 229 |
-
parent.currentZ = 0
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
|
| 233 |
-
""" load *_seg.npy with filename; if None, open QFileDialog """
|
| 234 |
-
if filename is None:
|
| 235 |
-
name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
|
| 236 |
-
filename = name[0]
|
| 237 |
-
try:
|
| 238 |
-
dat = np.load(filename, allow_pickle=True).item()
|
| 239 |
-
# check if there are keys in filename
|
| 240 |
-
dat["outlines"]
|
| 241 |
-
parent.loaded = True
|
| 242 |
-
except:
|
| 243 |
-
parent.loaded = False
|
| 244 |
-
print("ERROR: not NPY")
|
| 245 |
-
return
|
| 246 |
-
|
| 247 |
-
parent.reset()
|
| 248 |
-
if image is None:
|
| 249 |
-
found_image = False
|
| 250 |
-
if "filename" in dat:
|
| 251 |
-
parent.filename = dat["filename"]
|
| 252 |
-
if os.path.isfile(parent.filename):
|
| 253 |
-
parent.filename = dat["filename"]
|
| 254 |
-
found_image = True
|
| 255 |
-
else:
|
| 256 |
-
imgname = os.path.split(parent.filename)[1]
|
| 257 |
-
root = os.path.split(filename)[0]
|
| 258 |
-
parent.filename = root + "/" + imgname
|
| 259 |
-
if os.path.isfile(parent.filename):
|
| 260 |
-
found_image = True
|
| 261 |
-
if found_image:
|
| 262 |
-
try:
|
| 263 |
-
print(parent.filename)
|
| 264 |
-
image = (imread_2D(parent.filename) if not load_3D else
|
| 265 |
-
imread_3D(parent.filename))
|
| 266 |
-
except:
|
| 267 |
-
parent.loaded = False
|
| 268 |
-
found_image = False
|
| 269 |
-
print("ERROR: cannot find image file, loading from npy")
|
| 270 |
-
if not found_image:
|
| 271 |
-
parent.filename = filename[:-8]
|
| 272 |
-
print(parent.filename)
|
| 273 |
-
if "img" in dat:
|
| 274 |
-
image = dat["img"]
|
| 275 |
-
else:
|
| 276 |
-
print("ERROR: no image file found and no image in npy")
|
| 277 |
-
return
|
| 278 |
-
else:
|
| 279 |
-
parent.filename = image_file
|
| 280 |
-
|
| 281 |
-
parent.restore = None
|
| 282 |
-
parent.ratio = 1.
|
| 283 |
-
|
| 284 |
-
if "normalize_params" in dat:
|
| 285 |
-
parent.set_normalize_params(dat["normalize_params"])
|
| 286 |
-
|
| 287 |
-
_initialize_images(parent, image, load_3D=load_3D)
|
| 288 |
-
print(parent.stack.shape)
|
| 289 |
-
|
| 290 |
-
if "outlines" in dat:
|
| 291 |
-
if isinstance(dat["outlines"], list):
|
| 292 |
-
# old way of saving files
|
| 293 |
-
dat["outlines"] = dat["outlines"][::-1]
|
| 294 |
-
for k, outline in enumerate(dat["outlines"]):
|
| 295 |
-
if "colors" in dat:
|
| 296 |
-
color = dat["colors"][k]
|
| 297 |
-
else:
|
| 298 |
-
col_rand = np.random.randint(1000)
|
| 299 |
-
color = parent.colormap[col_rand, :3]
|
| 300 |
-
median = parent.add_mask(points=outline, color=color)
|
| 301 |
-
if median is not None:
|
| 302 |
-
parent.cellcolors = np.append(parent.cellcolors,
|
| 303 |
-
color[np.newaxis, :], axis=0)
|
| 304 |
-
parent.ncells += 1
|
| 305 |
-
else:
|
| 306 |
-
if dat["masks"].min() == -1:
|
| 307 |
-
dat["masks"] += 1
|
| 308 |
-
dat["outlines"] += 1
|
| 309 |
-
parent.ncells.set(dat["masks"].max())
|
| 310 |
-
if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
|
| 311 |
-
colors = dat["colors"]
|
| 312 |
-
else:
|
| 313 |
-
colors = parent.colormap[:parent.ncells.get(), :3]
|
| 314 |
-
|
| 315 |
-
_masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)
|
| 316 |
-
|
| 317 |
-
parent.draw_layer()
|
| 318 |
-
|
| 319 |
-
if "manual_changes" in dat:
|
| 320 |
-
parent.track_changes = dat["manual_changes"]
|
| 321 |
-
print("GUI_INFO: loaded in previous changes")
|
| 322 |
-
if "zdraw" in dat:
|
| 323 |
-
parent.zdraw = dat["zdraw"]
|
| 324 |
-
else:
|
| 325 |
-
parent.zdraw = [None for n in range(parent.ncells.get())]
|
| 326 |
-
parent.loaded = True
|
| 327 |
-
else:
|
| 328 |
-
parent.clear_all()
|
| 329 |
-
|
| 330 |
-
parent.ismanual = np.zeros(parent.ncells.get(), bool)
|
| 331 |
-
if "ismanual" in dat:
|
| 332 |
-
if len(dat["ismanual"]) == parent.ncells:
|
| 333 |
-
parent.ismanual = dat["ismanual"]
|
| 334 |
-
|
| 335 |
-
if "current_channel" in dat:
|
| 336 |
-
parent.color = (dat["current_channel"] + 2) % 5
|
| 337 |
-
parent.RGBDropDown.setCurrentIndex(parent.color)
|
| 338 |
-
|
| 339 |
-
if "flows" in dat:
|
| 340 |
-
parent.flows = dat["flows"]
|
| 341 |
-
try:
|
| 342 |
-
if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
|
| 343 |
-
Ly, Lx = dat["masks"].shape[-2:]
|
| 344 |
-
for i in range(len(parent.flows)):
|
| 345 |
-
parent.flows[i] = cv2.resize(
|
| 346 |
-
parent.flows[i].squeeze(), (Lx, Ly),
|
| 347 |
-
interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
|
| 348 |
-
if parent.NZ == 1:
|
| 349 |
-
parent.recompute_masks = True
|
| 350 |
-
else:
|
| 351 |
-
parent.recompute_masks = False
|
| 352 |
-
|
| 353 |
-
except:
|
| 354 |
-
try:
|
| 355 |
-
if len(parent.flows[0]) > 0:
|
| 356 |
-
parent.flows = parent.flows[0]
|
| 357 |
-
except:
|
| 358 |
-
parent.flows = [[], [], [], [], [[]]]
|
| 359 |
-
parent.recompute_masks = False
|
| 360 |
-
|
| 361 |
-
parent.enable_buttons()
|
| 362 |
-
parent.update_layer()
|
| 363 |
-
del dat
|
| 364 |
-
gc.collect()
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
def _load_masks(parent, filename=None):
|
| 368 |
-
""" load zeros-based masks (0=no cell, 1=cell 1, ...) """
|
| 369 |
-
if filename is None:
|
| 370 |
-
name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
|
| 371 |
-
filename = name[0]
|
| 372 |
-
print(f"GUI_INFO: loading masks: {filename}")
|
| 373 |
-
masks = imread(filename)
|
| 374 |
-
outlines = None
|
| 375 |
-
if masks.ndim > 3:
|
| 376 |
-
# Z x nchannels x Ly x Lx
|
| 377 |
-
if masks.shape[-1] > 5:
|
| 378 |
-
parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
|
| 379 |
-
outlines = masks[..., 1]
|
| 380 |
-
masks = masks[..., 0]
|
| 381 |
-
else:
|
| 382 |
-
parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
|
| 383 |
-
masks = masks[..., 0]
|
| 384 |
-
elif masks.ndim == 3:
|
| 385 |
-
if masks.shape[-1] < 5:
|
| 386 |
-
masks = masks[np.newaxis, :, :, 0]
|
| 387 |
-
elif masks.ndim < 3:
|
| 388 |
-
masks = masks[np.newaxis, :, :]
|
| 389 |
-
# masks should be Z x Ly x Lx
|
| 390 |
-
if masks.shape[0] != parent.NZ:
|
| 391 |
-
print("ERROR: masks are not same depth (number of planes) as image stack")
|
| 392 |
-
return
|
| 393 |
-
|
| 394 |
-
_masks_to_gui(parent, masks, outlines)
|
| 395 |
-
if parent.ncells > 0:
|
| 396 |
-
parent.draw_layer()
|
| 397 |
-
parent.toggle_mask_ops()
|
| 398 |
-
del masks
|
| 399 |
-
gc.collect()
|
| 400 |
-
parent.update_layer()
|
| 401 |
-
parent.update_plot()
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
def _masks_to_gui(parent, masks, outlines=None, colors=None):
|
| 405 |
-
""" masks loaded into GUI """
|
| 406 |
-
# get unique values
|
| 407 |
-
shape = masks.shape
|
| 408 |
-
if len(fastremap.unique(masks)) != masks.max() + 1:
|
| 409 |
-
print("GUI_INFO: renumbering masks")
|
| 410 |
-
fastremap.renumber(masks, in_place=True)
|
| 411 |
-
outlines = None
|
| 412 |
-
masks = masks.reshape(shape)
|
| 413 |
-
if masks.ndim == 2:
|
| 414 |
-
outlines = None
|
| 415 |
-
masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype(
|
| 416 |
-
np.uint32)
|
| 417 |
-
if parent.restore and "upsample" in parent.restore:
|
| 418 |
-
parent.cellpix_resize = masks.copy()
|
| 419 |
-
parent.cellpix = parent.cellpix_resize.copy()
|
| 420 |
-
parent.cellpix_orig = cv2.resize(
|
| 421 |
-
masks.squeeze(), (parent.Lx0, parent.Ly0),
|
| 422 |
-
interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
|
| 423 |
-
parent.resize = True
|
| 424 |
-
else:
|
| 425 |
-
parent.cellpix = masks
|
| 426 |
-
if parent.cellpix.ndim == 2:
|
| 427 |
-
parent.cellpix = parent.cellpix[np.newaxis, :, :]
|
| 428 |
-
if parent.restore and "upsample" in parent.restore:
|
| 429 |
-
if parent.cellpix_resize.ndim == 2:
|
| 430 |
-
parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
|
| 431 |
-
if parent.cellpix_orig.ndim == 2:
|
| 432 |
-
parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]
|
| 433 |
-
|
| 434 |
-
print(f"GUI_INFO: {masks.max()} masks found")
|
| 435 |
-
|
| 436 |
-
# get outlines
|
| 437 |
-
if outlines is None: # parent.outlinesOn
|
| 438 |
-
parent.outpix = np.zeros_like(parent.cellpix)
|
| 439 |
-
if parent.restore and "upsample" in parent.restore:
|
| 440 |
-
parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
|
| 441 |
-
for z in range(parent.NZ):
|
| 442 |
-
outlines = masks_to_outlines(parent.cellpix[z])
|
| 443 |
-
parent.outpix[z] = outlines * parent.cellpix[z]
|
| 444 |
-
if parent.restore and "upsample" in parent.restore:
|
| 445 |
-
outlines = masks_to_outlines(parent.cellpix_orig[z])
|
| 446 |
-
parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
|
| 447 |
-
if z % 50 == 0 and parent.NZ > 1:
|
| 448 |
-
print("GUI_INFO: plane %d outlines processed" % z)
|
| 449 |
-
if parent.restore and "upsample" in parent.restore:
|
| 450 |
-
parent.outpix_resize = parent.outpix.copy()
|
| 451 |
-
else:
|
| 452 |
-
parent.outpix = outlines
|
| 453 |
-
if parent.restore and "upsample" in parent.restore:
|
| 454 |
-
parent.outpix_resize = parent.outpix.copy()
|
| 455 |
-
parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
|
| 456 |
-
for z in range(parent.NZ):
|
| 457 |
-
outlines = masks_to_outlines(parent.cellpix_orig[z])
|
| 458 |
-
parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
|
| 459 |
-
if z % 50 == 0 and parent.NZ > 1:
|
| 460 |
-
print("GUI_INFO: plane %d outlines processed" % z)
|
| 461 |
-
|
| 462 |
-
if parent.outpix.ndim == 2:
|
| 463 |
-
parent.outpix = parent.outpix[np.newaxis, :, :]
|
| 464 |
-
if parent.restore and "upsample" in parent.restore:
|
| 465 |
-
if parent.outpix_resize.ndim == 2:
|
| 466 |
-
parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
|
| 467 |
-
if parent.outpix_orig.ndim == 2:
|
| 468 |
-
parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]
|
| 469 |
-
|
| 470 |
-
parent.ncells.set(parent.cellpix.max())
|
| 471 |
-
colors = parent.colormap[:parent.ncells.get(), :3] if colors is None else colors
|
| 472 |
-
print("GUI_INFO: creating cellcolors and drawing masks")
|
| 473 |
-
parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors),
|
| 474 |
-
axis=0).astype(np.uint8)
|
| 475 |
-
if parent.ncells > 0:
|
| 476 |
-
parent.draw_layer()
|
| 477 |
-
parent.toggle_mask_ops()
|
| 478 |
-
parent.ismanual = np.zeros(parent.ncells.get(), bool)
|
| 479 |
-
parent.zdraw = list(-1 * np.ones(parent.ncells.get(), np.int16))
|
| 480 |
-
|
| 481 |
-
if hasattr(parent, "stack_filtered"):
|
| 482 |
-
parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
|
| 483 |
-
print("set denoised/filtered view")
|
| 484 |
-
else:
|
| 485 |
-
parent.ViewDropDown.setCurrentIndex(0)
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
def _save_png(parent):
|
| 489 |
-
""" save masks to png or tiff (if 3D) """
|
| 490 |
-
filename = parent.filename
|
| 491 |
-
base = os.path.splitext(filename)[0]
|
| 492 |
-
if parent.NZ == 1:
|
| 493 |
-
if parent.cellpix[0].max() > 65534:
|
| 494 |
-
print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
|
| 495 |
-
imsave(base + "_cp_masks.tif", parent.cellpix[0])
|
| 496 |
-
else:
|
| 497 |
-
print("GUI_INFO: saving 2D masks to png")
|
| 498 |
-
imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16))
|
| 499 |
-
else:
|
| 500 |
-
print("GUI_INFO: saving 3D masks to tiff")
|
| 501 |
-
imsave(base + "_cp_masks.tif", parent.cellpix)
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
def _save_flows(parent):
|
| 505 |
-
""" save flows and cellprob to tiff """
|
| 506 |
-
filename = parent.filename
|
| 507 |
-
base = os.path.splitext(filename)[0]
|
| 508 |
-
print("GUI_INFO: saving flows and cellprob to tiff")
|
| 509 |
-
if len(parent.flows) > 0:
|
| 510 |
-
imsave(base + "_cp_cellprob.tif", parent.flows[1])
|
| 511 |
-
for i in range(3):
|
| 512 |
-
imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i])
|
| 513 |
-
if len(parent.flows) > 2:
|
| 514 |
-
imsave(base + "_cp_flows.tif", parent.flows[2])
|
| 515 |
-
print("GUI_INFO: saved flows and cellprob")
|
| 516 |
-
else:
|
| 517 |
-
print("ERROR: no flows or cellprob found")
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
def _save_rois(parent):
|
| 521 |
-
""" save masks as rois in .zip file for ImageJ """
|
| 522 |
-
filename = parent.filename
|
| 523 |
-
if parent.NZ == 1:
|
| 524 |
-
print(
|
| 525 |
-
f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
|
| 526 |
-
save_rois(parent.cellpix[0], parent.filename)
|
| 527 |
-
else:
|
| 528 |
-
print("ERROR: cannot save 3D outlines")
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
def _save_outlines(parent):
|
| 532 |
-
filename = parent.filename
|
| 533 |
-
base = os.path.splitext(filename)[0]
|
| 534 |
-
if parent.NZ == 1:
|
| 535 |
-
print(
|
| 536 |
-
"GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ"
|
| 537 |
-
)
|
| 538 |
-
outlines = outlines_list(parent.cellpix[0])
|
| 539 |
-
outlines_to_text(base, outlines)
|
| 540 |
-
else:
|
| 541 |
-
print("ERROR: cannot save 3D outlines")
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
def _save_sets_with_check(parent):
|
| 545 |
-
""" Save masks and update *_seg.npy file. Use this function when saving should be optional
|
| 546 |
-
based on the disableAutosave checkbox. Otherwise, use _save_sets """
|
| 547 |
-
if not parent.disableAutosave.isChecked():
|
| 548 |
-
_save_sets(parent)
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
def _save_sets(parent):
|
| 552 |
-
""" save masks to *_seg.npy. This function should be used when saving
|
| 553 |
-
is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check
|
| 554 |
-
"""
|
| 555 |
-
filename = parent.filename
|
| 556 |
-
base = os.path.splitext(filename)[0]
|
| 557 |
-
flow_threshold = parent.segmentation_settings.flow_threshold
|
| 558 |
-
cellprob_threshold = parent.segmentation_settings.cellprob_threshold
|
| 559 |
-
|
| 560 |
-
if parent.NZ > 1:
|
| 561 |
-
dat = {
|
| 562 |
-
"outlines":
|
| 563 |
-
parent.outpix,
|
| 564 |
-
"colors":
|
| 565 |
-
parent.cellcolors[1:],
|
| 566 |
-
"masks":
|
| 567 |
-
parent.cellpix,
|
| 568 |
-
"current_channel": (parent.color - 2) % 5,
|
| 569 |
-
"filename":
|
| 570 |
-
parent.filename,
|
| 571 |
-
"flows":
|
| 572 |
-
parent.flows,
|
| 573 |
-
"zdraw":
|
| 574 |
-
parent.zdraw,
|
| 575 |
-
"model_path":
|
| 576 |
-
parent.current_model_path
|
| 577 |
-
if hasattr(parent, "current_model_path") else 0,
|
| 578 |
-
"flow_threshold":
|
| 579 |
-
flow_threshold,
|
| 580 |
-
"cellprob_threshold":
|
| 581 |
-
cellprob_threshold,
|
| 582 |
-
"normalize_params":
|
| 583 |
-
parent.get_normalize_params(),
|
| 584 |
-
"restore":
|
| 585 |
-
parent.restore,
|
| 586 |
-
"ratio":
|
| 587 |
-
parent.ratio,
|
| 588 |
-
"diameter":
|
| 589 |
-
parent.segmentation_settings.diameter
|
| 590 |
-
}
|
| 591 |
-
if parent.restore is not None:
|
| 592 |
-
dat["img_restore"] = parent.stack_filtered
|
| 593 |
-
else:
|
| 594 |
-
dat = {
|
| 595 |
-
"outlines":
|
| 596 |
-
parent.outpix.squeeze() if parent.restore is None or
|
| 597 |
-
not "upsample" in parent.restore else parent.outpix_resize.squeeze(),
|
| 598 |
-
"colors":
|
| 599 |
-
parent.cellcolors[1:],
|
| 600 |
-
"masks":
|
| 601 |
-
parent.cellpix.squeeze() if parent.restore is None or
|
| 602 |
-
not "upsample" in parent.restore else parent.cellpix_resize.squeeze(),
|
| 603 |
-
"filename":
|
| 604 |
-
parent.filename,
|
| 605 |
-
"flows":
|
| 606 |
-
parent.flows,
|
| 607 |
-
"ismanual":
|
| 608 |
-
parent.ismanual,
|
| 609 |
-
"manual_changes":
|
| 610 |
-
parent.track_changes,
|
| 611 |
-
"model_path":
|
| 612 |
-
parent.current_model_path
|
| 613 |
-
if hasattr(parent, "current_model_path") else 0,
|
| 614 |
-
"flow_threshold":
|
| 615 |
-
flow_threshold,
|
| 616 |
-
"cellprob_threshold":
|
| 617 |
-
cellprob_threshold,
|
| 618 |
-
"normalize_params":
|
| 619 |
-
parent.get_normalize_params(),
|
| 620 |
-
"restore":
|
| 621 |
-
parent.restore,
|
| 622 |
-
"ratio":
|
| 623 |
-
parent.ratio,
|
| 624 |
-
"diameter":
|
| 625 |
-
parent.segmentation_settings.diameter
|
| 626 |
-
}
|
| 627 |
-
if parent.restore is not None:
|
| 628 |
-
dat["img_restore"] = parent.stack_filtered
|
| 629 |
-
try:
|
| 630 |
-
np.save(base + "_seg.npy", dat)
|
| 631 |
-
print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), base + "_seg.npy"))
|
| 632 |
-
except Exception as e:
|
| 633 |
-
print(f"ERROR: {e}")
|
| 634 |
-
del dat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/make_train.py
DELETED
|
@@ -1,107 +0,0 @@
|
|
| 1 |
-
import os, argparse
|
| 2 |
-
import numpy as np
|
| 3 |
-
from cellpose import io, transforms
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
def main():
|
| 7 |
-
parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')
|
| 8 |
-
|
| 9 |
-
input_img_args = parser.add_argument_group("input image arguments")
|
| 10 |
-
input_img_args.add_argument('--dir', default=[], type=str,
|
| 11 |
-
help='folder containing data to run or train on.')
|
| 12 |
-
input_img_args.add_argument(
|
| 13 |
-
'--image_path', default=[], type=str, help=
|
| 14 |
-
'if given and --dir not given, run on single image instead of folder (cannot train with this option)'
|
| 15 |
-
)
|
| 16 |
-
input_img_args.add_argument(
|
| 17 |
-
'--look_one_level_down', action='store_true',
|
| 18 |
-
help='run processing on all subdirectories of current folder')
|
| 19 |
-
input_img_args.add_argument('--img_filter', default=[], type=str,
|
| 20 |
-
help='end string for images to run on')
|
| 21 |
-
input_img_args.add_argument(
|
| 22 |
-
'--channel_axis', default=-1, type=int,
|
| 23 |
-
help='axis of image which corresponds to image channels')
|
| 24 |
-
input_img_args.add_argument('--z_axis', default=0, type=int,
|
| 25 |
-
help='axis of image which corresponds to Z dimension')
|
| 26 |
-
input_img_args.add_argument(
|
| 27 |
-
'--chan', default=0, type=int, help=
|
| 28 |
-
'Deprecated')
|
| 29 |
-
input_img_args.add_argument(
|
| 30 |
-
'--chan2', default=0, type=int, help=
|
| 31 |
-
'Deprecated'
|
| 32 |
-
)
|
| 33 |
-
input_img_args.add_argument('--invert', action='store_true',
|
| 34 |
-
help='invert grayscale channel')
|
| 35 |
-
input_img_args.add_argument(
|
| 36 |
-
'--all_channels', action='store_true', help=
|
| 37 |
-
'deprecated')
|
| 38 |
-
input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
|
| 39 |
-
help="anisotropy of volume in 3D")
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
# algorithm settings
|
| 43 |
-
algorithm_args = parser.add_argument_group("algorithm arguments")
|
| 44 |
-
algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
|
| 45 |
-
type=float, help='high-pass filtering radius. Default: %(default)s')
|
| 46 |
-
algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
|
| 47 |
-
help='tile normalization block size. Default: %(default)s')
|
| 48 |
-
algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
|
| 49 |
-
help='number of crops in XY to save per tiff. Default: %(default)s')
|
| 50 |
-
algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
|
| 51 |
-
help='size of random crop to save. Default: %(default)s')
|
| 52 |
-
|
| 53 |
-
args = parser.parse_args()
|
| 54 |
-
|
| 55 |
-
# find images
|
| 56 |
-
if len(args.img_filter) > 0:
|
| 57 |
-
imf = args.img_filter
|
| 58 |
-
else:
|
| 59 |
-
imf = None
|
| 60 |
-
|
| 61 |
-
if len(args.dir) > 0:
|
| 62 |
-
image_names = io.get_image_files(args.dir, "_masks", imf=imf,
|
| 63 |
-
look_one_level_down=args.look_one_level_down)
|
| 64 |
-
dirname = args.dir
|
| 65 |
-
else:
|
| 66 |
-
if os.path.exists(args.image_path):
|
| 67 |
-
image_names = [args.image_path]
|
| 68 |
-
dirname = os.path.split(args.image_path)[0]
|
| 69 |
-
else:
|
| 70 |
-
raise ValueError(f"ERROR: no file found at {args.image_path}")
|
| 71 |
-
|
| 72 |
-
np.random.seed(0)
|
| 73 |
-
nimg_per_tif = args.nimg_per_tif
|
| 74 |
-
crop_size = args.crop_size
|
| 75 |
-
os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True)
|
| 76 |
-
pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
|
| 77 |
-
npm = ["YX", "ZY", "ZX"]
|
| 78 |
-
for name in image_names:
|
| 79 |
-
name0 = os.path.splitext(os.path.split(name)[-1])[0]
|
| 80 |
-
img0 = io.imread_3D(name)
|
| 81 |
-
try:
|
| 82 |
-
img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
|
| 83 |
-
z_axis=args.z_axis, do_3D=True)
|
| 84 |
-
except ValueError:
|
| 85 |
-
print('Error converting image. Did you provide the correct --channel_axis and --z_axis ?')
|
| 86 |
-
|
| 87 |
-
for p in range(3):
|
| 88 |
-
img = img0.transpose(pm[p]).copy()
|
| 89 |
-
print(npm[p], img[0].shape)
|
| 90 |
-
Ly, Lx = img.shape[1:3]
|
| 91 |
-
imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
|
| 92 |
-
if args.anisotropy > 1.0 and p > 0:
|
| 93 |
-
imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx)
|
| 94 |
-
for k, img in enumerate(imgs):
|
| 95 |
-
if args.tile_norm:
|
| 96 |
-
img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
|
| 97 |
-
if args.sharpen_radius:
|
| 98 |
-
img = transforms.smooth_sharpen_img(img,
|
| 99 |
-
sharpen_radius=args.sharpen_radius)
|
| 100 |
-
ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
|
| 101 |
-
lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
|
| 102 |
-
io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'),
|
| 103 |
-
img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
if __name__ == '__main__':
|
| 107 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/gui/menus.py
DELETED
|
@@ -1,145 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
from qtpy.QtWidgets import QAction
|
| 5 |
-
from . import io
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def mainmenu(parent):
|
| 9 |
-
main_menu = parent.menuBar()
|
| 10 |
-
file_menu = main_menu.addMenu("&File")
|
| 11 |
-
# load processed data
|
| 12 |
-
loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", parent)
|
| 13 |
-
loadImg.setShortcut("Ctrl+L")
|
| 14 |
-
loadImg.triggered.connect(lambda: io._load_image(parent))
|
| 15 |
-
file_menu.addAction(loadImg)
|
| 16 |
-
|
| 17 |
-
parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent,
|
| 18 |
-
checkable=True)
|
| 19 |
-
parent.autoloadMasks.setChecked(False)
|
| 20 |
-
file_menu.addAction(parent.autoloadMasks)
|
| 21 |
-
|
| 22 |
-
parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent,
|
| 23 |
-
checkable=True)
|
| 24 |
-
parent.disableAutosave.setChecked(False)
|
| 25 |
-
file_menu.addAction(parent.disableAutosave)
|
| 26 |
-
|
| 27 |
-
parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent)
|
| 28 |
-
parent.loadMasks.setShortcut("Ctrl+M")
|
| 29 |
-
parent.loadMasks.triggered.connect(lambda: io._load_masks(parent))
|
| 30 |
-
file_menu.addAction(parent.loadMasks)
|
| 31 |
-
parent.loadMasks.setEnabled(False)
|
| 32 |
-
|
| 33 |
-
loadManual = QAction("Load &processed/labelled image (*_seg.npy)", parent)
|
| 34 |
-
loadManual.setShortcut("Ctrl+P")
|
| 35 |
-
loadManual.triggered.connect(lambda: io._load_seg(parent))
|
| 36 |
-
file_menu.addAction(loadManual)
|
| 37 |
-
|
| 38 |
-
parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent)
|
| 39 |
-
parent.saveSet.setShortcut("Ctrl+S")
|
| 40 |
-
parent.saveSet.triggered.connect(lambda: io._save_sets(parent))
|
| 41 |
-
file_menu.addAction(parent.saveSet)
|
| 42 |
-
parent.saveSet.setEnabled(False)
|
| 43 |
-
|
| 44 |
-
parent.savePNG = QAction("Save masks as P&NG/tif", parent)
|
| 45 |
-
parent.savePNG.setShortcut("Ctrl+N")
|
| 46 |
-
parent.savePNG.triggered.connect(lambda: io._save_png(parent))
|
| 47 |
-
file_menu.addAction(parent.savePNG)
|
| 48 |
-
parent.savePNG.setEnabled(False)
|
| 49 |
-
|
| 50 |
-
parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent)
|
| 51 |
-
parent.saveOutlines.setShortcut("Ctrl+O")
|
| 52 |
-
parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent))
|
| 53 |
-
file_menu.addAction(parent.saveOutlines)
|
| 54 |
-
parent.saveOutlines.setEnabled(False)
|
| 55 |
-
|
| 56 |
-
parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ",
|
| 57 |
-
parent)
|
| 58 |
-
parent.saveROIs.setShortcut("Ctrl+R")
|
| 59 |
-
parent.saveROIs.triggered.connect(lambda: io._save_rois(parent))
|
| 60 |
-
file_menu.addAction(parent.saveROIs)
|
| 61 |
-
parent.saveROIs.setEnabled(False)
|
| 62 |
-
|
| 63 |
-
parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent)
|
| 64 |
-
parent.saveFlows.setShortcut("Ctrl+F")
|
| 65 |
-
parent.saveFlows.triggered.connect(lambda: io._save_flows(parent))
|
| 66 |
-
file_menu.addAction(parent.saveFlows)
|
| 67 |
-
parent.saveFlows.setEnabled(False)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def editmenu(parent):
|
| 71 |
-
main_menu = parent.menuBar()
|
| 72 |
-
edit_menu = main_menu.addMenu("&Edit")
|
| 73 |
-
parent.undo = QAction("Undo previous mask/trace", parent)
|
| 74 |
-
parent.undo.setShortcut("Ctrl+Z")
|
| 75 |
-
parent.undo.triggered.connect(parent.undo_action)
|
| 76 |
-
parent.undo.setEnabled(False)
|
| 77 |
-
edit_menu.addAction(parent.undo)
|
| 78 |
-
|
| 79 |
-
parent.redo = QAction("Undo remove mask", parent)
|
| 80 |
-
parent.redo.setShortcut("Ctrl+Y")
|
| 81 |
-
parent.redo.triggered.connect(parent.undo_remove_action)
|
| 82 |
-
parent.redo.setEnabled(False)
|
| 83 |
-
edit_menu.addAction(parent.redo)
|
| 84 |
-
|
| 85 |
-
parent.ClearButton = QAction("Clear all masks", parent)
|
| 86 |
-
parent.ClearButton.setShortcut("Ctrl+0")
|
| 87 |
-
parent.ClearButton.triggered.connect(parent.clear_all)
|
| 88 |
-
parent.ClearButton.setEnabled(False)
|
| 89 |
-
edit_menu.addAction(parent.ClearButton)
|
| 90 |
-
|
| 91 |
-
parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent)
|
| 92 |
-
parent.remcell.setShortcut("Ctrl+Click")
|
| 93 |
-
parent.remcell.triggered.connect(parent.remove_action)
|
| 94 |
-
parent.remcell.setEnabled(False)
|
| 95 |
-
edit_menu.addAction(parent.remcell)
|
| 96 |
-
|
| 97 |
-
parent.mergecell = QAction("FYI: Merge cells by Alt+Click", parent)
|
| 98 |
-
parent.mergecell.setEnabled(False)
|
| 99 |
-
edit_menu.addAction(parent.mergecell)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def modelmenu(parent):
|
| 103 |
-
main_menu = parent.menuBar()
|
| 104 |
-
io._init_model_list(parent)
|
| 105 |
-
model_menu = main_menu.addMenu("&Models")
|
| 106 |
-
parent.addmodel = QAction("Add custom torch model to GUI", parent)
|
| 107 |
-
#parent.addmodel.setShortcut("Ctrl+A")
|
| 108 |
-
parent.addmodel.triggered.connect(parent.add_model)
|
| 109 |
-
parent.addmodel.setEnabled(True)
|
| 110 |
-
model_menu.addAction(parent.addmodel)
|
| 111 |
-
|
| 112 |
-
parent.removemodel = QAction("Remove selected custom model from GUI", parent)
|
| 113 |
-
#parent.removemodel.setShortcut("Ctrl+R")
|
| 114 |
-
parent.removemodel.triggered.connect(parent.remove_model)
|
| 115 |
-
parent.removemodel.setEnabled(True)
|
| 116 |
-
model_menu.addAction(parent.removemodel)
|
| 117 |
-
|
| 118 |
-
parent.newmodel = QAction("&Train new model with image+masks in folder", parent)
|
| 119 |
-
parent.newmodel.setShortcut("Ctrl+T")
|
| 120 |
-
parent.newmodel.triggered.connect(parent.new_model)
|
| 121 |
-
parent.newmodel.setEnabled(False)
|
| 122 |
-
model_menu.addAction(parent.newmodel)
|
| 123 |
-
|
| 124 |
-
openTrainHelp = QAction("Training instructions", parent)
|
| 125 |
-
openTrainHelp.triggered.connect(parent.train_help_window)
|
| 126 |
-
model_menu.addAction(openTrainHelp)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def helpmenu(parent):
|
| 130 |
-
main_menu = parent.menuBar()
|
| 131 |
-
help_menu = main_menu.addMenu("&Help")
|
| 132 |
-
|
| 133 |
-
openHelp = QAction("&Help with GUI", parent)
|
| 134 |
-
openHelp.setShortcut("Ctrl+H")
|
| 135 |
-
openHelp.triggered.connect(parent.help_window)
|
| 136 |
-
help_menu.addAction(openHelp)
|
| 137 |
-
|
| 138 |
-
openGUI = QAction("&GUI layout", parent)
|
| 139 |
-
openGUI.setShortcut("Ctrl+G")
|
| 140 |
-
openGUI.triggered.connect(parent.gui_window)
|
| 141 |
-
help_menu.addAction(openGUI)
|
| 142 |
-
|
| 143 |
-
openTrainHelp = QAction("Training instructions", parent)
|
| 144 |
-
openTrainHelp.triggered.connect(parent.train_help_window)
|
| 145 |
-
help_menu.addAction(openTrainHelp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/seg_post_model/cellpose/vit_sam_new.py
DELETED
|
@@ -1,197 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from segment_anything import sam_model_registry
|
| 7 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 8 |
-
from torch import nn
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
|
| 11 |
-
class Transformer(nn.Module):
|
| 12 |
-
def __init__(self, backbone="vit_l", ps=16, nout=3, bsize=256, rdrop=0.4,
|
| 13 |
-
checkpoint=None, dtype=torch.float32):
|
| 14 |
-
super(Transformer, self).__init__()
|
| 15 |
-
"""
|
| 16 |
-
print(self.encoder.patch_embed)
|
| 17 |
-
PatchEmbed(
|
| 18 |
-
(proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
|
| 19 |
-
)
|
| 20 |
-
print(self.encoder.neck)
|
| 21 |
-
Sequential(
|
| 22 |
-
(0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
| 23 |
-
(1): LayerNorm2d()
|
| 24 |
-
(2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
| 25 |
-
(3): LayerNorm2d()
|
| 26 |
-
)
|
| 27 |
-
"""
|
| 28 |
-
# instantiate the vit model, default to not loading SAM
|
| 29 |
-
# checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
|
| 30 |
-
if checkpoint is None:
|
| 31 |
-
checkpoint = "sam_vit_l_0b3195.pth"
|
| 32 |
-
self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
|
| 33 |
-
w = self.encoder.patch_embed.proj.weight.detach()
|
| 34 |
-
nchan = w.shape[0]
|
| 35 |
-
|
| 36 |
-
# change token size to ps x ps
|
| 37 |
-
self.ps = ps
|
| 38 |
-
# self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
|
| 39 |
-
# self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps]
|
| 40 |
-
|
| 41 |
-
# adjust position embeddings for new bsize and new token size
|
| 42 |
-
ds = (1024 // 16) // (bsize // ps)
|
| 43 |
-
self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:,::ds,::ds], requires_grad=True)
|
| 44 |
-
|
| 45 |
-
# readout weights for nout output channels
|
| 46 |
-
# if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
|
| 47 |
-
self.nout = nout
|
| 48 |
-
self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
|
| 49 |
-
|
| 50 |
-
# W2 reshapes token space to pixel space, not trainable
|
| 51 |
-
self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps),
|
| 52 |
-
requires_grad=False)
|
| 53 |
-
|
| 54 |
-
# fraction of layers to drop at random during training
|
| 55 |
-
self.rdrop = rdrop
|
| 56 |
-
|
| 57 |
-
# average diameter of ROIs from training images from fine-tuning
|
| 58 |
-
self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
|
| 59 |
-
# average diameter of ROIs during main training
|
| 60 |
-
self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
|
| 61 |
-
|
| 62 |
-
# set attention to global in every layer
|
| 63 |
-
for blk in self.encoder.blocks:
|
| 64 |
-
blk.window_size = 0
|
| 65 |
-
|
| 66 |
-
self.dtype = dtype
|
| 67 |
-
|
| 68 |
-
def forward(self, x, feat=None):
|
| 69 |
-
# same progression as SAM until readout
|
| 70 |
-
x = self.encoder.patch_embed(x)
|
| 71 |
-
if feat is not None:
|
| 72 |
-
feat = self.encoder.patch_embed(feat)
|
| 73 |
-
x = x + x * feat * 0.5
|
| 74 |
-
|
| 75 |
-
if self.encoder.pos_embed is not None:
|
| 76 |
-
x = x + self.encoder.pos_embed
|
| 77 |
-
|
| 78 |
-
if self.training and self.rdrop > 0:
|
| 79 |
-
nlay = len(self.encoder.blocks)
|
| 80 |
-
rdrop = (torch.rand((len(x), nlay), device=x.device) <
|
| 81 |
-
torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
|
| 82 |
-
for i, blk in enumerate(self.encoder.blocks):
|
| 83 |
-
mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
| 84 |
-
x = x * mask + blk(x) * (1-mask)
|
| 85 |
-
else:
|
| 86 |
-
for blk in self.encoder.blocks:
|
| 87 |
-
x = blk(x)
|
| 88 |
-
|
| 89 |
-
x = self.encoder.neck(x.permute(0, 3, 1, 2))
|
| 90 |
-
|
| 91 |
-
# readout is changed here
|
| 92 |
-
x1 = self.out(x)
|
| 93 |
-
x1 = F.conv_transpose2d(x1, self.W2, stride = self.ps, padding = 0)
|
| 94 |
-
|
| 95 |
-
# maintain the second output of feature size 256 for backwards compatibility
|
| 96 |
-
|
| 97 |
-
return x1, torch.randn((x.shape[0], 256), device=x.device)
|
| 98 |
-
|
| 99 |
-
def load_model(self, PATH, device, strict = False):
|
| 100 |
-
state_dict = torch.load(PATH, map_location = device, weights_only=True)
|
| 101 |
-
keys = [k for k in state_dict.keys()]
|
| 102 |
-
if keys[0][:7] == "module.":
|
| 103 |
-
from collections import OrderedDict
|
| 104 |
-
new_state_dict = OrderedDict()
|
| 105 |
-
for k, v in state_dict.items():
|
| 106 |
-
name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
|
| 107 |
-
new_state_dict[name] = v
|
| 108 |
-
self.load_state_dict(new_state_dict, strict = strict)
|
| 109 |
-
else:
|
| 110 |
-
self.load_state_dict(state_dict, strict = strict)
|
| 111 |
-
|
| 112 |
-
if self.dtype != torch.float32:
|
| 113 |
-
self = self.to(self.dtype)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
@property
|
| 117 |
-
def device(self):
|
| 118 |
-
"""
|
| 119 |
-
Get the device of the model.
|
| 120 |
-
|
| 121 |
-
Returns:
|
| 122 |
-
torch.device: The device of the model.
|
| 123 |
-
"""
|
| 124 |
-
return next(self.parameters()).device
|
| 125 |
-
|
| 126 |
-
def save_model(self, filename):
|
| 127 |
-
"""
|
| 128 |
-
Save the model to a file.
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
filename (str): The path to the file where the model will be saved.
|
| 132 |
-
"""
|
| 133 |
-
torch.save(self.state_dict(), filename)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class CPnetBioImageIO(Transformer):
|
| 138 |
-
"""
|
| 139 |
-
A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
|
| 140 |
-
|
| 141 |
-
This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
|
| 142 |
-
allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
|
| 143 |
-
"""
|
| 144 |
-
|
| 145 |
-
def forward(self, x):
|
| 146 |
-
"""
|
| 147 |
-
Perform a forward pass of the CPnet model and return unpacked tensors.
|
| 148 |
-
|
| 149 |
-
Args:
|
| 150 |
-
x (torch.Tensor): Input tensor.
|
| 151 |
-
|
| 152 |
-
Returns:
|
| 153 |
-
tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
|
| 154 |
-
"""
|
| 155 |
-
output_tensor, style_tensor, downsampled_tensors = super().forward(x)
|
| 156 |
-
return output_tensor, style_tensor, *downsampled_tensors
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
def load_model(self, filename, device=None):
|
| 160 |
-
"""
|
| 161 |
-
Load the model from a file.
|
| 162 |
-
|
| 163 |
-
Args:
|
| 164 |
-
filename (str): The path to the file where the model is saved.
|
| 165 |
-
device (torch.device, optional): The device to load the model on. Defaults to None.
|
| 166 |
-
"""
|
| 167 |
-
if (device is not None) and (device.type != "cpu"):
|
| 168 |
-
state_dict = torch.load(filename, map_location=device, weights_only=True)
|
| 169 |
-
else:
|
| 170 |
-
self.__init__(self.nout)
|
| 171 |
-
state_dict = torch.load(filename, map_location=torch.device("cpu"),
|
| 172 |
-
weights_only=True)
|
| 173 |
-
|
| 174 |
-
self.load_state_dict(state_dict)
|
| 175 |
-
|
| 176 |
-
def load_state_dict(self, state_dict):
|
| 177 |
-
"""
|
| 178 |
-
Load the state dictionary into the model.
|
| 179 |
-
|
| 180 |
-
This method overrides the default `load_state_dict` to handle Cellpose's custom
|
| 181 |
-
loading mechanism and ensures compatibility with BioImage.IO Core.
|
| 182 |
-
|
| 183 |
-
Args:
|
| 184 |
-
state_dict (Mapping[str, Any]): A state dictionary to load into the model
|
| 185 |
-
"""
|
| 186 |
-
if state_dict["output.2.weight"].shape[0] != self.nout:
|
| 187 |
-
for name in self.state_dict():
|
| 188 |
-
if "output" not in name:
|
| 189 |
-
self.state_dict()[name].copy_(state_dict[name])
|
| 190 |
-
else:
|
| 191 |
-
super().load_state_dict(
|
| 192 |
-
{name: param for name, param in state_dict.items()},
|
| 193 |
-
strict=False)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|