"""Base class for segmentation models (module ``cyto.segmentation.base``)."""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import numpy as np
import dask.array as da
from tqdm import tqdm

from cyto.tasks.base import PipelineTask


class SegmentationBase(PipelineTask):
    """Base class for segmentation models.

    Subclasses provide the model-specific pieces (``_setup_model``,
    ``_segment_frame`` and ``_cleanup_resources``); this base class handles
    extracting the image from the input dict, looping over frames with a
    progress bar, logging, and dispatching to a container runner.

    NOTE(review): ``@abstractmethod`` is only enforced when the metaclass is
    ``abc.ABCMeta`` (e.g. ``PipelineTask`` derives from ``abc.ABC``); this
    class does not list ``ABC`` as a base itself — confirm in
    ``cyto.tasks.base``.
    """

    def __init__(
        self,
        name: str,
        verbose: bool = True,
        execution_config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Base class for segmentation models

        Args:
            name (str): Name of the segmentation method
            verbose (bool): Turn on or off the processing printout
            execution_config (Optional[Dict[str, Any]]): Configuration for execution
        """
        super().__init__(execution_config)
        self.name = name
        self.verbose = verbose

    def _log_message(self, message: str) -> None:
        """Log ``message`` via ``tqdm.write`` if verbose is enabled.

        ``tqdm.write`` is used rather than ``print`` so the message does not
        break an active progress bar.
        """
        if self.verbose:
            tqdm.write(message)

    def _prepare_image(self, data: Dict[str, Any]) -> np.ndarray:
        """Extract the image from ``data``, materialising dask arrays.

        Args:
            data (Dict[str, Any]): Mapping containing an ``"image"`` key.

        Returns:
            np.ndarray: The image; a dask array is computed into memory,
            any other value is returned unchanged.

        Raises:
            KeyError: If ``data`` has no ``"image"`` key.
        """
        image = data["image"]
        if isinstance(image, da.Array):
            image = image.compute()
        return image

    def _process_frames(self, image: np.ndarray) -> np.ndarray:
        """Segment every frame of ``image`` using ``_segment_frame``.

        Frames are taken along axis 2 (``image[:, :, t]``), so ``image``
        must have at least three dimensions (presumably (Y, X, T) — confirm
        with callers).

        Args:
            image (np.ndarray): Image stack to segment.

        Returns:
            np.ndarray: ``np.uint16`` label array with the same shape as
            ``image``.
        """
        # Allocate the output directly as uint16. ``np.zeros_like(image)``
        # inherited the image dtype, so label ids > 255 silently wrapped
        # for uint8 images before the final uint16 conversion; it also
        # returns a dask array (not an in-memory buffer) for dask input.
        label = np.zeros(image.shape, dtype=np.uint16)
        for t in tqdm(range(image.shape[2]), disable=not self.verbose):
            frame = image[:, :, t]
            # Only relevant when a dask array is passed in directly;
            # run_baremetal already materialises the image.
            if isinstance(frame, da.Array):
                frame = frame.compute()
            # Process individual frame using the specific method
            label[:, :, t] = self._segment_frame(frame, t)
        return label

    @abstractmethod
    def _segment_frame(self, frame: np.ndarray, frame_idx: int) -> np.ndarray:
        """
        Segment a single frame - must be implemented by subclasses

        Args:
            frame (np.ndarray): Single 2D image frame
            frame_idx (int): Frame index

        Returns:
            np.ndarray: Segmented labels for the frame
        """
        pass

    @abstractmethod
    def _setup_model(self) -> Any:
        """Setup and return the segmentation model - must be implemented by subclasses"""
        pass

    @abstractmethod
    def _cleanup_resources(self) -> None:
        """Cleanup resources after segmentation - must be implemented by subclasses"""
        pass

    def run_container(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run segmentation in a container using the configured runner.

        This method handles the container execution by:
        1. Getting the appropriate runner (Docker/Singularity) from execution_config.
        2. Passing the task (self) and data to the runner.
        3. The runner serializes the task and data, runs them in a container,
           and deserializes the result.
        4. Inside the container, the container_worker calls run_baremetal().

        Args:
            data (Dict[str, Any]): Dictionary containing 'image' key with image data

        Returns:
            Dict[str, Any]: Dictionary with 'image' and 'label' keys from container execution
        """
        self._log_message(f"Starting {self.name} segmentation in container")

        # _get_runner is not defined in this class; presumably provided by
        # PipelineTask based on execution_config — confirm in cyto.tasks.base.
        runner = self._get_runner()

        # Per the description above, the runner serialises this task and
        # the input, runs container_worker inside the container (which calls
        # self.run_baremetal(data)), and deserialises the results.
        results = runner.run(self, data)

        self._log_message(f"{self.name} container segmentation complete")
        return results

    def run_baremetal(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Main segmentation pipeline for baremetal execution.

        Calls ``_setup_model``, segments every frame, and always calls
        ``_cleanup_resources`` once setup has succeeded — even when
        segmentation raises.

        Args:
            data (Dict): Dictionary containing 'image' key with image data

        Returns:
            Dict: Dictionary with 'image' (the prepared image) and 'label'
            (``np.uint16`` labels) keys
        """
        image = self._prepare_image(data)

        # The return value is not used by this base class; presumably
        # subclasses keep the model on ``self`` for ``_segment_frame``.
        self._setup_model()
        try:
            self._log_message(f"{self.name} Segmentation 2D started")
            label = self._process_frames(image)
        finally:
            # Release model resources even if a frame fails to segment.
            self._cleanup_resources()

        self._log_message(f"{self.name} segmentation complete")
        return {"image": image, "label": label}