Source code for cyto.segmentation.stardist

from stardist.models import StarDist2D
from typing import Any
from tqdm import tqdm
import dask.array as da
import tensorflow as tf
import numpy as np
from csbdeep.utils import normalize
import gc

from .base import SegmentationBase

class StarDist(SegmentationBase):
    """Instance segmentation backend built on a pretrained StarDist2D model.

    The model is loaded by ``_setup_model``, applied frame by frame in
    ``_segment_frame`` and released again in ``_cleanup_resources``.
    """

    def __init__(self, model_name: str = "2D_versatile_fluo",
                 prob_thresh: float = 0.479071, nms_thresh: float = 0.3,
                 verbose: bool = True) -> None:
        """
        Configure a StarDist segmentation model.

        Args:
            model_name (str): Name of a registered pretrained model, loaded
                with ``StarDist2D.from_pretrained``.
            prob_thresh (float): Object probability threshold between 0 and 1,
                forwarded to ``predict_instances``.
            nms_thresh (float): Non-maximum suppression threshold between 0
                and 1, forwarded to ``predict_instances``.
            verbose (bool): Turn on or off the processing printout.

        NOTE(review): the default thresholds presumably mirror the values
        tuned for "2D_versatile_fluo" - confirm before changing model_name.
        """
        super().__init__("StarDist", verbose)
        self.model_name = model_name
        self.prob_thresh = prob_thresh
        self.nms_thresh = nms_thresh
        # Populated by _setup_model and reset to None by _cleanup_resources.
        self._model = None

    def _setup_model(self) -> Any:
        """Load the pretrained StarDist2D model and configure GPU memory.

        Enables TensorFlow memory growth on every visible GPU. This is
        best-effort: with no GPU the loop does nothing, and if TensorFlow has
        already initialised the devices the setting is skipped.

        Returns:
            The loaded ``StarDist2D`` model (also stored on ``self._model``).
        """
        # TensorFlow requires the memory-growth setting to be identical across
        # all GPUs, so configure every device rather than only the first.
        for gpu in tf.config.list_physical_devices('GPU'):
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except (RuntimeError, ValueError):
                # RuntimeError: devices already initialised (growth must be set
                # before first use). ValueError: invalid device. Best-effort.
                pass
        self._model = StarDist2D.from_pretrained(self.model_name)
        self._log_message(f"StarDist Segmentation 2D: {self.model_name}")
        return self._model

    def _segment_frame(self, frame: np.ndarray, frame_idx: int) -> np.ndarray:
        """Segment a single frame using StarDist.

        Args:
            frame: Image data for one frame. NOTE(review): assumed to be laid
                out as the loaded model expects, e.g. (Y, X) - confirm.
            frame_idx: Index of the frame; not used by this implementation.

        Returns:
            The label image returned by ``predict_instances``.
        """
        if self._model is None:
            # Load on demand so calling this before setup (or after cleanup)
            # does not fail with an AttributeError on None.
            self._setup_model()

        # Min/max rescale (0th..100th percentile over axes 0 and 1): maps the
        # frame's [min, max] range onto [0, 1].
        img = normalize(frame, pmin=0, pmax=100, axis=(0, 1))

        labels, _details = self._model.predict_instances(
            img, prob_thresh=self.prob_thresh, nms_thresh=self.nms_thresh
        )
        return labels

    def _cleanup_resources(self) -> None:
        """Release the model and reset TensorFlow/Keras global state."""
        # Drop our reference first so the model is no longer reachable from
        # this object, then clear the Keras session and force a collection so
        # any remaining reference cycles are reclaimed promptly.
        self._model = None
        tf.keras.backend.clear_session()
        gc.collect()