Source code for cyto.segmentation.base
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import numpy as np
import dask.array as da
from tqdm import tqdm
from cyto.tasks.base import PipelineTask
[docs]
class SegmentationBase(PipelineTask):
    """Common driver for frame-by-frame segmentation methods.

    Subclasses implement ``_segment_frame`` plus the ``_setup_model`` and
    ``_cleanup_resources`` hooks. This base class provides dask-to-numpy
    conversion of the input image, per-frame iteration with an optional
    progress bar, and container execution through the runner returned by
    ``self._get_runner()``.

    NOTE(review): ``ABC`` is imported but not listed among the bases, so the
    ``@abstractmethod`` markers below are only enforced if ``PipelineTask``
    itself uses ``ABCMeta`` — confirm.
    """

    def __init__(
        self,
        name: str,
        verbose: bool = True,
        execution_config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Base class for segmentation models

        Args:
            name (str): Name of the segmentation method, used in log messages
            verbose (bool): Turn on or off the processing printout and the
                per-frame progress bar
            execution_config (Optional[Dict[str, Any]]): Configuration for
                execution, forwarded unchanged to ``PipelineTask.__init__``
        """
        super().__init__(execution_config)
        self.name = name
        self.verbose = verbose

    def _log_message(self, message: str) -> None:
        """Write ``message`` if verbose is enabled.

        ``tqdm.write`` is used instead of ``print`` so the message does not
        break an active tqdm progress bar.
        """
        if self.verbose:
            tqdm.write(message)

    def _prepare_image(self, data: Dict[str, Any]) -> np.ndarray:
        """Extract ``data["image"]``, computing it first if it is a dask array.

        Args:
            data (Dict[str, Any]): Task payload containing an ``"image"`` entry.

        Returns:
            np.ndarray: The image; dask arrays are materialised via ``compute()``,
            any other value is returned as-is.

        Raises:
            KeyError: If ``data`` has no ``"image"`` key.
        """
        image = data["image"]
        if isinstance(image, da.Array):
            image = image.compute()
        return image

    def _process_frames(self, image: np.ndarray) -> np.ndarray:
        """Segment ``image`` frame by frame along its third axis.

        Each slice ``image[:, :, t]`` is passed to ``_segment_frame`` and the
        result is written into the matching slice of the output.

        Args:
            image (np.ndarray): Image stack indexed as ``image[:, :, t]``. May
                also be a dask array, in which case frames are computed one at
                a time.

        Returns:
            np.ndarray: ``uint16`` label array with the same shape as ``image``.
        """
        # Allocate the output directly as uint16. Deriving it from the input
        # via np.zeros_like would inherit the image dtype. That silently wraps
        # label ids above 255 for uint8 images and collapses every id to 1 for
        # bool images, both before any final cast. It would also produce a lazy
        # dask array (via __array_function__ dispatch) when ``image`` is a
        # dask array.
        label = np.zeros(image.shape, dtype=np.uint16)
        for t in tqdm(range(image.shape[2]), disable=not self.verbose):
            frame = image[:, :, t]
            if isinstance(frame, da.Array):
                # Materialise only this frame rather than the whole stack.
                frame = frame.compute()
            # Process individual frame using the specific method
            label[:, :, t] = self._segment_frame(frame, t)
        return label

    @abstractmethod
    def _segment_frame(self, frame: np.ndarray, frame_idx: int) -> np.ndarray:
        """
        Segment a single frame - must be implemented by subclasses

        Args:
            frame (np.ndarray): Single 2D image frame
            frame_idx (int): Frame index along the third axis of the stack

        Returns:
            np.ndarray: Segmented labels for the frame; must be assignable to a
            slice of the stack's shape (it is stored as uint16)
        """
        pass

    @abstractmethod
    def _setup_model(self) -> Any:
        """Setup and return the segmentation model - must be implemented by subclasses"""
        pass

    @abstractmethod
    def _cleanup_resources(self) -> None:
        """Cleanup resources after segmentation - must be implemented by subclasses"""
        pass

    def run_container(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run segmentation in a container using the configured runner.

        This method handles the container execution by:

        1. Getting the appropriate runner (Docker/Singularity) from execution_config.
        2. Passing the task (self) and data to the runner.
        3. The runner serializes the task and data, runs them in a container,
           and deserializes the result.
        4. Inside the container, the container_worker calls run_baremetal().

        Args:
            data (Dict[str, Any]): Dictionary containing 'image' key with image data

        Returns:
            Dict[str, Any]: Dictionary with 'image' and 'label' keys from container execution
        """
        # Log container execution start
        self._log_message(f"Starting {self.name} segmentation in container")
        # Get the configured runner (Docker or Singularity)
        runner = self._get_runner()
        # Execute the task in the container. The runner will:
        # 1. Serialize this task object and the input data
        # 2. Run a container with the specified image
        # 3. Execute container_worker.py inside the container
        # 4. container_worker calls self.run_baremetal(data)
        # 5. Deserialize and return the results
        results = runner.run(self, data)
        # Log completion
        self._log_message(f"{self.name} container segmentation complete")
        return results