# %%
import os
import warnings
from typing import Any, Dict, List, Optional

import dask.array as da
import numpy as np
import torch
from bioio import BioImage
from cellpose import models, io, core
from tqdm import tqdm

try:
    from bioio.writers import OmeTiffWriter
except ImportError:
    OmeTiffWriter = None  # removed in bioio ≥2; use bioio.Writer instead

from .base import SegmentationBase
class Cellpose(SegmentationBase):
    """2D Cellpose segmentation of an image stack.

    Every frame ``image[:, :, t]`` is segmented independently in 2D, but all
    frames are handed to Cellpose in a single batched ``eval`` call (see
    ``_process_frames``) rather than going through ``_segment_frame``.
    """

    def __init__(
        self,
        model_type='cyto',
        cellprob_thresh=-3,
        model_matching_thresh=10.0,
        gpu=True,
        channels: Optional[List[int]] = None,
        batch_size=16,
        diameter=16.18,
        verbose=True,
        execution_config=None,
    ) -> None:
        """
        Perform CellPose segmentation 2D

        Args:
            model_type (str): Registered models for CellPose
            cellprob_thresh (float): Probability threshold between -8 and 8
            model_matching_thresh (float): Non-maximum suppression threshold between 0 and 30.
                NOTE: currently stored but not used; ``flow_threshold`` is fixed at 0.
            gpu (bool): use GPU
            channels (list, optional): [cytoplasm, nucleus]. If a nucleus channel does not
                exist, set the second entry to 0. Defaults to [0, 0] (grayscale).
            batch_size (int): used for GPU
            diameter (float): cell diameter in pixels
            verbose (bool): Turn on or off the processing printout
            execution_config (dict, optional): Configuration for task execution (baremetal vs container)
        """
        super().__init__("CellPose", verbose, execution_config)
        self.model_type = model_type
        self.cellprob_thresh = cellprob_thresh
        self.model_matching_thresh = model_matching_thresh
        self.gpu = gpu
        # None default + copy: each instance owns its own list instead of sharing
        # one mutable default (or the caller's list) across instances.
        self.channels = [0, 0] if channels is None else list(channels)
        self.batch_size = batch_size
        self.diameter = diameter
        # Cellpose treats a flow_threshold of 0 as "do not discard masks by flow error".
        # A former mapping from model_matching_thresh was
        # (31.0 - model_matching_thresh) / 10.0; it is currently not applied.
        self.flow_threshold = 0
        self._model = None  # set by _setup_model(), cleared by _cleanup_resources()

    def _setup_model(self) -> Any:
        """Create the Cellpose model and store it on ``self._model``.

        Returns:
            The ``cellpose.models.CellposeModel`` instance.
        """
        # core.use_gpu() only reports whether a GPU is usable; the GPU is only
        # requested when self.gpu is also set, so the log must reflect both.
        use_gpu = bool(self.gpu) and core.use_gpu()
        self._model = models.CellposeModel(gpu=self.gpu, model_type=self.model_type)
        self._log_message("CellPose Segmentation 2D: {}\nGPU activated? {}".format(self.model_type, use_gpu))
        return self._model

    def _process_frames(self, image: np.ndarray) -> np.ndarray:
        """Segment all frames of ``image`` with one batched Cellpose call.

        Overrides the base-class frame processing so Cellpose can batch frames.

        Args:
            image: Array whose frames are ``image[:, :, t]`` for ``t`` along axis 2
                (presumably (Y, X, T) -- TODO confirm against SegmentationBase).

        Returns:
            ``uint16`` label array with the per-frame masks stacked back along axis 2.

        Raises:
            ValueError: If Cellpose's ``eval`` returns neither 3 nor 4 values.
        """
        frames = [image[:, :, t] for t in range(image.shape[2])]
        self._log_message("Running Cellpose segmentation...")
        eval_result = self._model.eval(
            frames,
            batch_size=self.batch_size,
            diameter=self.diameter,  # in pixel
            cellprob_threshold=self.cellprob_thresh,
            flow_threshold=self.flow_threshold,
            channels=self.channels,
            stitch_threshold=0.0,
            do_3D=False,
            progress=True,
        )
        # Depending on the Cellpose version/model class, eval() returns
        # (masks, flows, styles) or (masks, flows, styles, diams). Only masks are used.
        if len(eval_result) not in (3, 4):
            raise ValueError(f"Unexpected number of return values from Cellpose eval(): {len(eval_result)}")
        masks = np.asarray(eval_result[0])

        # The uint16 cast below wraps label ids above 65535, silently merging cells.
        uint16_max = np.iinfo(np.uint16).max
        max_label = int(masks.max()) if masks.size else 0
        if max_label > uint16_max:
            warnings.warn(
                f"Cellpose produced label ids up to {max_label}, which exceed the "
                f"uint16 range ({uint16_max}); labels will wrap around.",
                RuntimeWarning,
                stacklevel=2,
            )

        # masks is (T, *frame_shape); move the frame axis back to position 2.
        label = np.transpose(masks, axes=[1, 2, 0])
        return label.astype(np.uint16)

    def _segment_frame(self, frame: np.ndarray, frame_idx: int) -> np.ndarray:
        """Not used in CellPose as it processes all frames at once.

        Raises:
            NotImplementedError: Always; use ``_process_frames`` instead.
        """
        raise NotImplementedError("CellPose uses batch processing, this method should not be called")

    def run_container(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run CellPose segmentation in a container.

        This method uses the base class container execution logic which handles
        the container orchestration and calls run_baremetal inside the container.

        Args:
            data (Dict[str, Any]): Dictionary containing 'image' key with image data

        Returns:
            Dict[str, Any]: Dictionary with 'image' and 'label' keys from container execution
        """
        return super().run_container(data)

    def _cleanup_resources(self) -> None:
        """Drop the model and release cached CUDA memory."""
        # Release the model reference first: empty_cache() can only return blocks
        # that are no longer referenced, so the weights must be freed before it runs.
        self._model = None
        torch.cuda.empty_cache()