Source code for cyto.segmentation.cellpose

# %%
import os
import numpy as np
from cellpose import models, io, core

from typing import Any, Dict
from tqdm import tqdm
import dask.array as da
from bioio import BioImage
try:
    from bioio.writers import OmeTiffWriter
except ImportError:
    OmeTiffWriter = None  # removed in bioio ≥2; use bioio.Writer instead
import torch

from .base import SegmentationBase

class Cellpose(SegmentationBase):
    def __init__(self, model_type='cyto', cellprob_thresh=-3, model_matching_thresh=10.0,
                 gpu=True, channels=None, batch_size=16, diameter=16.18, verbose=True,
                 execution_config=None) -> None:
        """
        Perform CellPose segmentation 2D

        Args:
            model_type (str): Registered models for CellPose
            cellprob_thresh (float): Probability threshold between -8 and 8
            model_matching_thresh (float): Non-maximum suppression threshold between 0 and 30.
                NOTE: currently stored but not used; ``flow_threshold`` is fixed at 0
                (see the comment on ``self.flow_threshold``).
            gpu (bool): use GPU
            channels (list, optional): [cytoplasm, nucleus]. If a NUCLEUS channel does not
                exist, set the second channel to 0. Defaults to ``[0, 0]`` (grayscale).
            batch_size (int): used for GPU
            diameter (float): cell diameter, in pixels
            verbose (bool): Turn on or off the processing printout
            execution_config (dict, optional): Configuration for task execution
                (baremetal vs container)
        """
        super().__init__("CellPose", verbose, execution_config)
        self.model_type = model_type
        self.cellprob_thresh = cellprob_thresh
        self.model_matching_thresh = model_matching_thresh
        self.gpu = gpu
        # Store a fresh list per instance so instances never share (and mutate)
        # one default list object.
        self.channels = [0, 0] if channels is None else list(channels)
        self.batch_size = batch_size
        self.diameter = diameter
        # Hard-coded to 0; the previously derived value was
        # (31.0 - self.model_matching_thresh) / 10.0.
        # NOTE(review): in Cellpose a non-positive flow threshold presumably skips the
        # flow-error quality check — confirm this is intended.
        self.flow_threshold = 0
        self._model = None

    def _setup_model(self) -> Any:
        """Create the CellPose model and log whether it will run on the GPU.

        Returns:
            The ``models.CellposeModel`` instance, also stored on ``self._model``.
        """
        # core.use_gpu() only checks GPU availability; it knows nothing about
        # self.gpu. The GPU is used only when it is both requested and available.
        gpu_available = core.use_gpu()
        use_gpu = bool(self.gpu) and bool(gpu_available)
        self._model = models.CellposeModel(gpu=self.gpu, model_type=self.model_type)
        self._log_message(
            "CellPose Segmentation 2D: {}\nGPU activated? {}".format(self.model_type, use_gpu)
        )
        return self._model

    def _process_frames(self, image: np.ndarray) -> np.ndarray:
        """Segment all frames in a single batched CellPose call.

        Overrides the base per-frame processing so CellPose can batch frames.

        Args:
            image (np.ndarray): Image stack whose third axis is iterated as frames
                (``image[:, :, t]`` is passed to CellPose for frame ``t``).

        Returns:
            np.ndarray: Label stack with frames on the third axis (same layout as the
            input). dtype is ``uint16``, promoted to ``uint32`` only if a label value
            exceeds the ``uint16`` range (the old unconditional cast would wrap).

        Raises:
            ValueError: If ``eval()`` returns an unexpected number of values.
        """
        n_frames = image.shape[2]
        if n_frames == 0:
            # Nothing to segment; avoid calling eval() on an empty batch and
            # transposing an empty 1-D result.
            return np.zeros(image.shape[:2] + (0,), dtype=np.uint16)

        images_stack = [image[:, :, t] for t in range(n_frames)]

        self._log_message("Running Cellpose segmentation...")
        eval_result = self._model.eval(
            images_stack,
            batch_size=self.batch_size,
            diameter=self.diameter,  # in pixel
            cellprob_threshold=self.cellprob_thresh,
            flow_threshold=self.flow_threshold,
            channels=self.channels,
            stitch_threshold=0.0,
            do_3D=False,
            progress=True)

        # Handle different return values based on Cellpose version:
        # newer versions (3.0+) return (masks, flows, styles),
        # older versions return (masks, flows, styles, diams).
        # Only the masks are needed here.
        if len(eval_result) not in (3, 4):
            raise ValueError(
                f"Unexpected number of return values from Cellpose eval(): {len(eval_result)}")
        masks = np.asarray(eval_result[0])

        # Stacked per-frame masks have frames first; move them back to the third
        # axis to match the input layout.
        label = np.transpose(masks, axes=[1, 2, 0])

        max_label = int(label.max()) if label.size else 0
        if max_label > np.iinfo(np.uint16).max:
            self._log_message(
                "CellPose produced {} labels, exceeding the uint16 range; "
                "returning uint32 labels".format(max_label))
            return label.astype(np.uint32)
        return label.astype(np.uint16)

    def _segment_frame(self, frame: np.ndarray, frame_idx: int) -> np.ndarray:
        """Not used in CellPose as it processes all frames at once"""
        raise NotImplementedError("CellPose uses batch processing, this method should not be called")

    def run_container(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run CellPose segmentation in a container.

        This method uses the base class container execution logic, which handles
        the container orchestration and calls run_baremetal inside the container.

        Args:
            data (Dict[str, Any]): Dictionary containing 'image' key with image data

        Returns:
            Dict[str, Any]: Dictionary with 'image' and 'label' keys from container execution
        """
        return super().run_container(data)

    def _cleanup_resources(self) -> None:
        """Release the CellPose model, then clear PyTorch's CUDA cache."""
        # Drop the model reference first. Otherwise its tensors are still alive
        # when the cache is emptied, and their memory cannot be released.
        self._model = None
        torch.cuda.empty_cache()