# NOTE: the following three items are page-scrape residue, not part of the module:
#   VisionLanguageGroup's picture
#   clean up
#   86072ea
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import os, time
# from pathlib import Path
import numpy as np
from tqdm import trange
import torch
from scipy.ndimage import gaussian_filter
# import gc
import cv2
import logging
models_logger = logging.getLogger(__name__)
from . import transforms, dynamics, utils
from .vit import Transformer
from .core import assign_device, run_net
# _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
# _MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
# MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
# MODEL_NAMES = ["cpsam"]
# MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
# Baseline normalization settings. `SegModel.eval` starts from these values,
# overlays any user-supplied options and passes the result as keyword
# arguments to `transforms.normalize_img`.
normalize_default = dict(
    lowhigh=None,
    percentile=None,
    normalize=True,
    norm3D=True,
    sharpen_radius=0,
    smooth_radius=0,
    tile_norm_blocksize=0,
    tile_norm_smooth3D=1,
    invert=False,
)
# def get_user_models():
# model_strings = []
# if os.path.exists(MODEL_LIST_PATH):
# with open(MODEL_LIST_PATH, "r") as textfile:
# lines = [line.rstrip() for line in textfile]
# if len(lines) > 0:
# model_strings.extend(lines)
# return model_strings
class SegModel():
    """
    Segmentation model that wraps a ``Transformer`` network and converts its
    output (flows and cell probability) into masks.

    Attributes:
        device (torch device): Device the network is moved to. Taken from the ``device`` argument when given, otherwise from ``assign_device(gpu=gpu)``.
        gpu (bool): True when ``device.type`` is "cuda" and CUDA is available, or (if CUDA is not available) when it is "mps" and MPS is available; False otherwise.
        pretrained_model (str): The ``pretrained_model`` argument as passed to the constructor ("" when None was given).
        net (Transformer): The network, built with the requested dtype and ``vit_checkpoint``.
        timing (list): Per-image wall-clock seconds; only assigned by ``eval`` when it is called with a list or a 5D array.

    Methods:
        __init__(self, gpu=False, pretrained_model="", model_type=None, diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
            Select the device and build the network.

        eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, flow3D_smooth=0, stitch_threshold=0.0, min_size=15, max_size_fraction=0.4, niter=None, augment=False, tile_overlap=0.1, bsize=256, compute_masks=True, progress=None):
            Run the network on x and, when compute_masks is True, compute masks from its output.
    """
def __init__(self, gpu=False, pretrained_model="", model_type=None,
             diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
    """
    Select the compute device and build the Transformer network.

    Parameters:
        gpu (bool, optional): Forwarded to ``assign_device`` when no explicit ``device`` is given.
        pretrained_model (str, optional): Stored on the instance; None is replaced by "". This constructor only stores the value.
        model_type (optional): Accepted but not used by this constructor.
        diam_mean (optional): Accepted but not used by this constructor.
        device (torch device, optional): Used as-is when given; overrides ``gpu``.
        nchan (optional): Accepted but not used by this constructor.
        use_bfloat16 (bool, optional): Build the network with torch.bfloat16 (True, default) or torch.float32 (False).
        vit_checkpoint (optional): Forwarded to ``Transformer`` as ``checkpoint``.
    """
    # an explicit device wins over the gpu flag
    if device is None:
        device = assign_device(gpu=gpu)[0]
    self.device = device

    # the device only counts as "gpu" if it matches the accelerator backend
    # that is actually available (CUDA is checked before MPS)
    if torch.cuda.is_available():
        accelerator = "cuda"
    elif torch.backends.mps.is_available():
        accelerator = "mps"
    else:
        accelerator = None
    self.gpu = accelerator is not None and self.device.type == accelerator

    self.pretrained_model = "" if pretrained_model is None else pretrained_model

    weights_dtype = torch.bfloat16 if use_bfloat16 else torch.float32
    network = Transformer(dtype=weights_dtype, checkpoint=vit_checkpoint)
    self.net = network.to(self.device)
def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
         z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
         flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
         flow3D_smooth=0, stitch_threshold=0.0,
         min_size=15, max_size_fraction=0.4, niter=None,
         augment=False, tile_overlap=0.1, bsize=256,
         compute_masks=True, progress=None):
    """
    Run the network on ``x`` and compute segmentation masks.

    Parameters:
        x (list or array): A single image array, or a list of images / an array
            that is 5D after squeezing. In the latter case each ``x[i]`` is
            evaluated separately by a recursive call.
        feat (array, optional): Extra input forwarded to the network together
            with ``x``. Indexed per image (``feat[i]``) for list / 5D input.
        batch_size (int, optional): Forwarded to the network runner.
        resample (bool, optional): If True, flows and cell probabilities are
            resized to the pre-scaling image size before masks are computed.
        channels (optional): Accepted for interface compatibility; not used.
        channel_axis (int, optional): Forwarded to ``transforms.convert_image``.
        z_axis (int, optional): Forwarded to ``transforms.convert_image``.
        normalize (bool or dict, optional): A bool switches normalization on/off
            (and ``invert`` is applied); a dict is overlaid on
            ``normalize_default``. Any other type raises ValueError.
        invert (bool, optional): Only used when ``normalize`` is a bool.
        rescale (optional): Accepted for interface compatibility; not used.
        diameter (float, list or array, optional): If not None, the image (and
            ``feat``) is resized by ``30. / diameter`` before the network runs.
            For list / 5D input a list or array supplies one value per image.
        flow_threshold, cellprob_threshold, min_size, max_size_fraction, niter:
            Forwarded to mask computation; ``niter`` of None or 0 becomes 200.
        do_3D (bool, optional): Forwarded to image conversion, normalization
            logic, the network runner and mask computation.
        anisotropy (float, optional): Multiplied by the image scaling factor
            when both are set, then forwarded to the network runner.
        flow3D_smooth, progress: Only passed through to the recursive calls.
        stitch_threshold (float, optional): Values > 0 enable stitching of
            per-plane masks in mask computation.
        augment, tile_overlap, bsize: Forwarded to the network runner.
        compute_masks (bool, optional): If False, no masks are computed and an
            empty array (``np.zeros(0)``) is returned.

    Returns:
        numpy.ndarray: The (squeezed) masks for a single image, or
        list of numpy.ndarray: one masks array per image for list / 5D input.

    Raises:
        ValueError: If ``normalize`` is neither a bool nor a dict.
    """
    if isinstance(x, list) or x.squeeze().ndim == 5:
        # several images: evaluate them one at a time and collect the masks
        self.timing = []
        masks = []
        tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
        nimg = len(x)
        iterator = trange(nimg, file=tqdm_out,
                          mininterval=30) if nimg > 1 else range(nimg)
        diameter_per_image = isinstance(diameter, (list, np.ndarray))
        for i in iterator:
            tic = time.time()
            # The single-image path below returns only the masks, so the
            # result must not be unpacked into (masks, flows, styles).
            maski = self.eval(
                x[i],
                feat=None if feat is None else feat[i],
                batch_size=batch_size,
                channel_axis=channel_axis,
                z_axis=z_axis,
                normalize=normalize,
                invert=invert,
                diameter=diameter[i] if diameter_per_image else diameter,
                do_3D=do_3D,
                anisotropy=anisotropy,
                augment=augment,
                tile_overlap=tile_overlap,
                bsize=bsize,
                resample=resample,
                flow_threshold=flow_threshold,
                cellprob_threshold=cellprob_threshold,
                compute_masks=compute_masks,
                min_size=min_size,
                max_size_fraction=max_size_fraction,
                stitch_threshold=stitch_threshold,
                flow3D_smooth=flow3D_smooth,
                progress=progress,
                niter=niter)
            masks.append(maski)
            self.timing.append(time.time() - tic)
        return masks

    ############# actual eval code ############
    # reshape image
    x = transforms.convert_image(x, channel_axis=channel_axis,
                                 z_axis=z_axis,
                                 do_3D=(do_3D or stitch_threshold > 0))

    # Add batch dimension if not present
    if x.ndim < 4:
        x = x[np.newaxis, ...]
    if feat is not None and feat.ndim < 4:
        feat = feat[np.newaxis, ...]
    nimg = x.shape[0]

    # remember the size before any diameter-based rescaling
    image_scaling = None
    Ly_0 = x.shape[1]
    Lx_0 = x.shape[2]
    Lz_0 = None
    if stitch_threshold > 0:
        Lz_0 = x.shape[0]
    if diameter is not None:
        image_scaling = 30. / diameter
        x = transforms.resize_image(x,
                                    Ly=int(x.shape[1] * image_scaling),
                                    Lx=int(x.shape[2] * image_scaling))
        if feat is not None:
            feat = transforms.resize_image(feat,
                                           Ly=int(feat.shape[1] * image_scaling),
                                           Lx=int(feat.shape[2] * image_scaling))

    # normalize image
    # Work on a copy: the settings are modified below and must not leak into
    # the module-level defaults that later calls start from.
    normalize_params = dict(normalize_default)
    if isinstance(normalize, dict):
        normalize_params.update(normalize)
    elif not isinstance(normalize, bool):
        raise ValueError("normalize parameter must be a bool or a dict")
    else:
        normalize_params["normalize"] = normalize
        normalize_params["invert"] = invert

    # pre-normalize if 3D stack for stitching or do_3D
    do_normalization = bool(normalize_params["normalize"])
    if nimg > 1 and do_normalization and (stitch_threshold or do_3D):
        normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"]
        x = transforms.normalize_img(x, **normalize_params)
        do_normalization = False  # do not normalize again
    else:
        if normalize_params["norm3D"] and nimg > 1 and do_normalization:
            models_logger.warning(
                "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False"
            )
            normalize_params["norm3D"] = False
    if do_normalization:
        x = transforms.normalize_img(x, **normalize_params)

    if feat is not None:
        if feat.shape[-1] > feat.shape[1]:
            # transpose feat to have channels last
            feat = np.moveaxis(feat, 1, -1)

    # adjust the anisotropy when diameter is specified and images are resized:
    if isinstance(anisotropy, (float, int)) and image_scaling:
        anisotropy = image_scaling * anisotropy

    dP, cellprob, styles = self._run_net(
        x,
        feat=feat,
        augment=augment,
        batch_size=batch_size,
        tile_overlap=tile_overlap,
        bsize=bsize,
        do_3D=do_3D,
        anisotropy=anisotropy)

    if resample:
        # upsample flows before computing them:
        dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
        cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)

    if compute_masks:
        niter0 = 200
        niter = niter0 if niter is None or niter == 0 else niter
        masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
                                    cellprob_threshold=cellprob_threshold, min_size=min_size,
                                    max_size_fraction=max_size_fraction, niter=niter,
                                    stitch_threshold=stitch_threshold, do_3D=do_3D)
    else:
        masks = np.zeros(0)  # pass back zeros if not compute_masks

    masks = masks.squeeze()

    # undo resizing:
    if image_scaling is not None or anisotropy is not None:
        if compute_masks:
            masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True,
                                            interpolation=cv2.INTER_NEAREST)
    return masks
def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
    """
    Resize cellprob array to specified dimensions for either 2D or 3D.

    Parameters:
        prob (numpy.ndarray): The cellprobs to resize. After removing
            singleton axes the array must be 2D or 3D.
        to_y_size (int): The target size along the Y-axis.
        to_x_size (int): The target size along the X-axis.
        to_z_size (int, optional): The target size along the Z-axis. Used
            only for 3D cellprobs, where it is passed to the second resize.

    Returns:
        numpy.ndarray: The resized cellprobs. In the 2D case the singleton
            axes removed by squeezing are put back at their original
            positions; in the 3D case the squeezed array is returned.

    Raises:
        ValueError: If the squeezed cellprob array is neither 2D nor 3D.
    """
    original_shape = np.array(prob.shape)
    prob = prob.squeeze()

    if prob.ndim == 2:
        # 2D case:
        prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
        # Put back every singleton axis that squeeze() removed. Using the
        # axis indices directly avoids converting an index array to a scalar
        # and also works when more than one axis was squeezed.
        squeezed_axes = tuple(int(ax) for ax in np.flatnonzero(original_shape == 1))
        if squeezed_axes:
            prob = np.expand_dims(prob, squeezed_axes)  # add back empty axes for compatibility
    elif prob.ndim == 3:
        # 3D case: resize the last two axes, then swap the first two axes
        # so the same 2D resize can be applied along Z.
        prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
        prob = prob.transpose(1, 0, 2)
        prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
        prob = prob.transpose(1, 0, 2)
    else:
        raise ValueError(f'cellprob has incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')

    return prob
def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
    """
    Resize gradient arrays to specified dimensions for either 2D or 3D gradients.

    Parameters:
        grads (np.ndarray): The gradients to resize, with the gradient
            components on the first axis. After removing singleton axes the
            array must be 3D (2D gradients) or 4D (3D gradients).
        to_y_size (int): The target size along the Y-axis.
        to_x_size (int): The target size along the X-axis.
        to_z_size (int, optional): The target size along the Z-axis. Used
            only for 4D input, where it is passed to the second resize.

    Returns:
        numpy.ndarray: The resized gradients with the gradient components on
            the first axis. In the 3D-after-squeeze case the singleton axes
            removed by squeezing are put back at their original positions.

    Raises:
        ValueError: If the squeezed gradient array is neither 3D nor 4D.
    """
    original_shape = np.array(grads.shape)
    grads = grads.squeeze()

    if grads.ndim == 3:
        # 2D case, with XY flows in 2 channels:
        grads = np.moveaxis(grads, 0, -1)  # Put gradients last
        grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
        grads = np.moveaxis(grads, -1, 0)  # Put gradients first
        # Put back every singleton axis that squeeze() removed. Using the
        # axis indices directly avoids converting an index array to a scalar
        # and also works when more than one axis was squeezed.
        squeezed_axes = tuple(int(ax) for ax in np.flatnonzero(original_shape == 1))
        if squeezed_axes:
            grads = np.expand_dims(grads, squeezed_axes)  # add back empty axes for compatibility
    elif grads.ndim == 4:
        # dP has gradients that can be treated as channels:
        grads = grads.transpose(1, 2, 3, 0)  # move gradients last
        grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
        grads = grads.transpose(1, 0, 2, 3)  # switch axes to resize again
        grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
        grads = grads.transpose(3, 1, 0, 2)  # undo transposition
    else:
        raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')

    return grads
def _run_net(self, x, feat=None,
             augment=False,
             batch_size=8, tile_overlap=0.1,
             bsize=224, anisotropy=1.0, do_3D=False):
    """
    Run the network on image x and split its output.

    Parameters:
        x (array): Input passed to ``run_net``; its first axis length decides
            whether the run time is logged.
        feat (array, optional): Forwarded to ``run_net``.
        augment, batch_size, tile_overlap, bsize: Forwarded to ``run_net``.
        anisotropy, do_3D: Accepted but not used by this method.

    Returns:
        tuple: ``(dP, cellprob, styles)`` where ``cellprob`` is the last
        output channel, ``dP`` holds the two channels before it moved to the
        first axis, and ``styles`` is the squeezed styles output of
        ``run_net`` (replaced by the remaining leading channels when the
        output has more than three channels).
    """
    started = time.time()
    n_images = x.shape[0]

    yf, styles = run_net(
        self.net, x, feat=feat, bsize=bsize, augment=augment,
        batch_size=batch_size, tile_overlap=tile_overlap,
    )

    # split the channel axis: [..., extra channels, 2 flow channels, cellprob]
    cellprob = yf[..., -1]
    flow_channels = yf[..., -3:-1]
    dP = flow_channels.transpose((3, 0, 1, 2))

    n_extra_channels = yf.shape[-1] - 3
    if n_extra_channels > 0:
        styles = yf[..., :-3]
    styles = styles.squeeze()

    elapsed = time.time() - started
    if n_images > 1:
        models_logger.info("network run in %2.2fs" % (elapsed))

    return dP, cellprob, styles
def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0,
                   min_size=15, max_size_fraction=0.4, niter=None,
                   do_3D=False, stitch_threshold=0.0):
    """
    Compute masks from flows and cell probability, one plane at a time.

    Parameters:
        shape (tuple): Shape of the network input; the first three entries
            are read as (number of planes, Ly, Lx).
        dP (array): Flows, indexed as ``dP[:, i]`` for plane ``i``.
        cellprob (array): Cell probabilities, indexed as ``cellprob[i]``.
        flow_threshold, cellprob_threshold, max_size_fraction, niter:
            Forwarded to ``dynamics.resize_and_compute_masks``.
        min_size (int, optional): Forwarded per plane, except that it is
            replaced by -1 while stitching several planes and applied once
            after stitching instead.
        do_3D (bool, optional): Only used to decide whether an "mps" device
            must be swapped for the CPU during post-processing.
        stitch_threshold (float, optional): If > 0 and there is more than one
            plane, the per-plane masks are stitched with ``utils.stitch3D``.

    Returns:
        numpy.ndarray: The masks of the single plane, or a
        (planes, Ly, Lx) array when there is more than one plane.
    """
    # MPS cannot run the 3D post-processing: use the CPU for the duration of
    # this call and restore the original device afterwards, even on error.
    original_device = None
    if self.device.type == "mps" and do_3D:
        models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
        original_device = self.device
        self.device = torch.device("cpu")

    try:
        nimg, Ly, Lx = shape[:3]
        tic = time.time()

        # resize the per-plane output only if it does not already match
        Ly0, Lx0 = cellprob[0].shape
        resize = None if Ly0 == Ly and Lx0 == Lx else [Ly, Lx]

        # turn off min_size for 3D stitching (it is applied after stitching)
        min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1

        tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
        iterator = trange(nimg, file=tqdm_out,
                          mininterval=30) if nimg > 1 else range(nimg)
        for i in iterator:
            outputs = dynamics.resize_and_compute_masks(
                dP[:, i], cellprob[i],
                niter=niter, cellprob_threshold=cellprob_threshold,
                flow_threshold=flow_threshold, resize=resize,
                min_size=min_size0, max_size_fraction=max_size_fraction,
                device=self.device)
            if nimg > 1:
                if i == 0:
                    # allocate once the mask dtype is known
                    masks = np.zeros((nimg, Ly, Lx), outputs.dtype)
                masks[i] = outputs
            else:
                masks = outputs

        if stitch_threshold > 0 and nimg > 1:
            models_logger.info(
                f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
            )
            masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
            masks = utils.fill_holes_and_remove_small_masks(
                masks, min_size=min_size)
        elif nimg > 1:
            models_logger.warning(
                "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
            )

        flow_time = time.time() - tic
        if nimg > 1:
            models_logger.info("masks created in %2.2fs" % (flow_time))
        return masks
    finally:
        if original_device is not None:
            # restore first so the log line names the device being switched back to
            self.device = original_device
            models_logger.info("switching back to device %s" % self.device)