Source code for cyto.segmentation.stardist
from stardist.models import StarDist2D
from typing import Any
from tqdm import tqdm
import dask.array as da
import tensorflow as tf
import numpy as np
from csbdeep.utils import normalize
import gc
from .base import SegmentationBase
[docs]
class StarDist(SegmentationBase):
    """StarDist2D instance-segmentation backend.

    Wraps a pretrained ``stardist.models.StarDist2D`` model. The private
    ``_setup_model`` / ``_segment_frame`` / ``_cleanup_resources`` methods are
    presumably hooks driven by ``SegmentationBase`` (not visible here — confirm).
    """

    def __init__(self, model_name="2D_versatile_fluo", prob_thresh=0.479071, nms_thresh=0.3, verbose=True) -> None:
        """
        Perform StarDist segmentation model

        Args:
            model_name (str): Registered models for StarDist2D
            prob_thresh (float): Probability threshold between 0 and 1
            nms_thresh (float): Non-maximum suppression threshold between 0 and 1
            verbose (bool): Turn on or off the processing printout
        """
        # NOTE(review): the default thresholds appear to be the tuned values for
        # "2D_versatile_fluo"; they may not suit other models — confirm when changing model_name.
        super().__init__("StarDist", verbose)
        self.model_name = model_name
        self.prob_thresh = prob_thresh
        self.nms_thresh = nms_thresh
        # Loaded lazily by _setup_model and released by _cleanup_resources.
        self._model = None

    def _setup_model(self) -> Any:
        """Enable GPU memory growth and load the pretrained StarDist2D model.

        Returns:
            The loaded ``StarDist2D`` model, which is also stored on ``self._model``.
        """
        # Apply memory growth to every visible GPU rather than only the first:
        # TensorFlow rejects differing memory-growth settings across GPUs, and on a
        # CPU-only machine the loop is simply empty instead of raising IndexError.
        for device in tf.config.list_physical_devices('GPU'):
            try:
                tf.config.experimental.set_memory_growth(device, True)
            except (ValueError, RuntimeError):
                # ValueError: invalid device. RuntimeError: the TF runtime is already
                # initialized, so the setting can no longer change. Best effort: keep
                # whatever configuration is already in place.
                pass
        self._model = StarDist2D.from_pretrained(self.model_name)
        self._log_message(f"StarDist Segmentation 2D: {self.model_name}")
        return self._model

    def _segment_frame(self, frame: np.ndarray, frame_idx: int) -> np.ndarray:
        """Segment a single frame using StarDist.

        Args:
            frame: Image to segment; normalization runs over axes 0 and 1
                (presumably (H, W) or (H, W, C) — confirm against callers).
            frame_idx: Position of the frame in the stack; not used by this backend.

        Returns:
            The label image returned by ``predict_instances`` (its details output
            is discarded).
        """
        # Percentiles 0 and 100 are the frame's min and max, so this is plain
        # min-max scaling of [min, max] to [0, 1].
        img = normalize(frame, pmin=0, pmax=100, axis=(0, 1))
        labels, _ = self._model.predict_instances(
            img,
            prob_thresh=self.prob_thresh,
            nms_thresh=self.nms_thresh,
        )
        return labels

    def _cleanup_resources(self) -> None:
        """Release the StarDist model, reset Keras global state, and collect garbage."""
        # Drop our reference first so the model's objects are unreachable when the
        # session is cleared and the collector runs.
        self._model = None
        tf.keras.backend.clear_session()
        gc.collect()