Source code for cyto.tracking.trackmate_in_mem

#!/usr/bin/env python3
"""
TrackMate In-Memory Processing Example

This module demonstrates how to:
1. Load pre-segmented label data and corresponding images
2. Extract cell data to CSV format from labels
3. Use TrackMate for particle tracking via pyimagej scripting

Requirements:
- pyimagej
- imagej with TrackMate plugin
- numpy
- pandas
- scikit-image
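- tifffile (for loading TIFF files)
- scyjava (for importing the TrackMate Java classes)
- tqdm (progress bar while adding spots to the model)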
"""

import numpy as np
import pandas as pd
import imagej
import scyjava
from pathlib import Path
import logging
from typing import Tuple, List, Dict, Any, Optional
from skimage import measure
import tifffile
import tempfile
import os
from tqdm import tqdm
from cyto.utils.label_to_table import label_to_sparse

# Configure logging.
# NOTE(review): basicConfig runs at import time, so importing this module
# sets up root logging at INFO level for the importing program as well --
# confirm that this is intended for a library module.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the class and functions defined below.
logger = logging.getLogger(__name__)


class TrackMateInMemoryProcessor:
    """
    Handles tracking using pyimagej and TrackMate with pre-segmented data.

    Typical flow: ``load_tiff_data`` -> ``labels_to_spots_csv`` ->
    ``create_trackmate_model`` -> ``run_tracking`` ->
    ``export_tracking_results``.
    """

    # Instance attribute name -> fully qualified Java class; every entry is
    # imported by _import_java_classes and stored on the instance.
    _JAVA_CLASSES = {
        # TrackMate classes
        'Model': 'fiji.plugin.trackmate.Model',
        'Settings': 'fiji.plugin.trackmate.Settings',
        'TrackMate': 'fiji.plugin.trackmate.TrackMate',
        'SelectionModel': 'fiji.plugin.trackmate.SelectionModel',
        # Detectors
        'ManualDetectorFactory': 'fiji.plugin.trackmate.detection.ManualDetectorFactory',
        # Trackers
        # NOTE(review): an earlier revision of this file imported
        # 'fiji.plugin.trackmate.tracking.sparselap.SparseLAPTrackerFactory'
        # instead -- presumably the path used by other TrackMate releases.
        'SparseLAPTrackerFactory': 'fiji.plugin.trackmate.tracking.jaqaman.SparseLAPTrackerFactory',
        # Analyzers
        'SpotAnalyzerProvider': 'fiji.plugin.trackmate.providers.SpotAnalyzerProvider',
        'EdgeAnalyzerProvider': 'fiji.plugin.trackmate.providers.EdgeAnalyzerProvider',
        'TrackAnalyzerProvider': 'fiji.plugin.trackmate.providers.TrackAnalyzerProvider',
        # Spot and Track classes
        'Spot': 'fiji.plugin.trackmate.Spot',
        'SpotCollection': 'fiji.plugin.trackmate.SpotCollection',
        # ImageJ classes
        'ImagePlus': 'ij.ImagePlus',
        'Calibration': 'ij.measure.Calibration',
    }

    def __init__(self, imagej_path: Optional[str] = None, ij=None):
        """
        Initialize the processor with ImageJ/Fiji.

        Args:
            imagej_path: Path to an ImageJ/Fiji installation, opened in
                headless mode. If falsy (and ``ij`` is None) ImageJ is
                initialized from the ``sc.fiji:fiji`` endpoint instead.
            ij: Existing ImageJ instance (optional). Takes precedence over
                ``imagej_path``.
        """
        logger.info("Initializing ImageJ...")
        if ij is not None:
            self.ij = ij
            logger.info("Using provided ImageJ instance.")
        elif imagej_path:
            self.ij = imagej.init(imagej_path, mode="headless")
            logger.info(f"Initialized ImageJ from path: {imagej_path}")
        else:
            self.ij = imagej.init('sc.fiji:fiji')
            logger.info("Initialized ImageJ using Fiji (sc.fiji:fiji).")
        logger.info(f"ImageJ version: {self.ij.getVersion()}")

        # Import necessary Java classes
        self._import_java_classes()

    def _import_java_classes(self):
        """
        Import required Java classes for TrackMate.

        Each class listed in ``_JAVA_CLASSES`` is bound to ``self.<name>``.

        Raises:
            Exception: whatever ``scyjava.jimport`` raised, after logging it.
        """
        try:
            for attribute, java_path in self._JAVA_CLASSES.items():
                setattr(self, attribute, scyjava.jimport(java_path))
            logger.info("Successfully imported TrackMate Java classes")
        except Exception as e:
            logger.error(f"Failed to import Java classes: {e}")
            raise

    @staticmethod
    def _crop_to_common_shape(image_data: np.ndarray,
                              label_data: np.ndarray
                              ) -> Tuple[np.ndarray, np.ndarray, Tuple[int, ...]]:
        """Crop two same-rank arrays to their element-wise minimum shape."""
        common = tuple(min(a, b) for a, b in zip(image_data.shape, label_data.shape))
        window = tuple(slice(0, extent) for extent in common)
        return image_data[window], label_data[window], common

    def load_tiff_data(self, image_path: Path, label_path: Path) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load image and label data from TIFF files.

        Args:
            image_path: Path to the original image TIFF file
            label_path: Path to the segmented label TIFF file

        Returns:
            Tuple of (image_array, label_array). When both arrays have the
            same number of dimensions but different shapes, both are cropped
            to the common (minimum) extent along every axis.
        """
        logger.info(f"Loading image from: {image_path}")
        logger.info(f"Loading labels from: {label_path}")

        # Load image data
        image_data = tifffile.imread(str(image_path))
        logger.info(f"Image shape: {image_data.shape}, dtype: {image_data.dtype}")

        # Only take the first channel. Axis 1 is assumed to be the channel
        # axis of a multi-channel stack; arrays with fewer than 4 dimensions
        # have no such axis and are used as-is (previously an IndexError).
        if image_data.ndim >= 4:
            image_data = image_data[:, 0]

        # Load label data
        label_data = tifffile.imread(str(label_path))
        logger.info(f"Label shape: {label_data.shape}, dtype: {label_data.dtype}")

        # Ensure both arrays have compatible shapes
        if image_data.shape != label_data.shape:
            logger.warning(f"Shape mismatch: image {image_data.shape} vs labels {label_data.shape}")
            # Try to align shapes if possible (only when ranks agree)
            if image_data.ndim == label_data.ndim:
                image_data, label_data, min_shape = self._crop_to_common_shape(image_data, label_data)
                logger.info(f"Aligned to shape: {min_shape}")

        return image_data, label_data

    @staticmethod
    def _stack_as_hwt(labels: np.ndarray, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Rearrange labels and image so that the frame axis comes last.

        2-D arrays get a trailing singleton axis; 3-D arrays have axis 0
        moved to the end. A 2-D image paired with 3-D labels is repeated
        (as a read-only broadcast view) for every frame.

        Raises:
            ValueError: for label ranks other than 2/3, or image ranks other
                than 2/3 when the labels are 3-D.
        """
        if labels.ndim == 2:
            # Single frame: add time dimension
            return labels[:, :, np.newaxis], image[:, :, np.newaxis]
        if labels.ndim != 3:
            raise ValueError(f"labels must be 2-D or 3-D, got {labels.ndim}-D")

        if image.ndim == 3:
            image_frames = image
        elif image.ndim == 2:
            image_frames = np.broadcast_to(image, labels.shape)
        else:
            raise ValueError(f"image must be 2-D or 3-D, got {image.ndim}-D")
        return labels.transpose(1, 2, 0), image_frames.transpose(1, 2, 0)

    @staticmethod
    def _features_to_spots(features_df: pd.DataFrame) -> pd.DataFrame:
        """Map a label_to_sparse feature table onto TrackMate spot columns."""
        def as_float(column: str) -> np.ndarray:
            return features_df[column].astype(float).to_numpy()

        columns = {
            'ID': features_df['label'].astype(int).to_numpy(),
            'FRAME': features_df['frame'].astype(int).to_numpy(),
            'POSITION_X': as_float('x'),
            'POSITION_Y': as_float('y'),
            'POSITION_Z': np.zeros(len(features_df)),  # 2D case
            'RADIUS': as_float('feret_radius'),
            'QUALITY': as_float('mean'),
            'MEAN_INTENSITY': as_float('mean'),
            'TOTAL_INTENSITY': as_float('mass'),
            'AREA': as_float('size'),
            'PERIMETER': as_float('perimeter'),
            'ECCENTRICITY': as_float('elongation'),  # elongation used as eccentricity proxy
            'SOLIDITY': as_float('roundness'),  # roundness used as solidity proxy
        }
        # Add optional features if available
        for source, target in (('median', 'MEDIAN_INTENSITY'),
                               ('sd', 'STD_INTENSITY'),
                               ('flatness', 'FLATNESS')):
            if source in features_df.columns:
                columns[target] = as_float(source)
        return pd.DataFrame(columns)

    def labels_to_spots_csv(self, labels: np.ndarray, image: np.ndarray,
                            time_points: Optional[List[int]] = None) -> pd.DataFrame:
        """
        Convert label image to a spot table suitable for TrackMate using
        cyto.utils.label_to_table.label_to_sparse.

        Args:
            labels: Label image from segmentation. 2-D input is one frame;
                for 3-D input axis 0 is treated as the frame axis.
            image: Original intensity image, laid out like ``labels`` (a 2-D
                image is reused for every frame of 3-D labels).
            time_points: List of time points (for time series). Currently
                unused; kept for interface compatibility.

        Returns:
            DataFrame with one row per spot in TrackMate column naming, or an
            empty DataFrame when no spots are found.

        Raises:
            ValueError: if ``labels`` is neither 2-D nor 3-D.
        """
        logger.info("Converting labels to spots CSV format using label_to_sparse...")

        # label_to_sparse is called with the frame axis last
        labels_hwt, image_hwt = self._stack_as_hwt(labels, image)

        # Use label_to_sparse to extract comprehensive features
        features_df = label_to_sparse(labels_hwt, image_hwt, channel_name="", processes=1)

        if features_df.empty:
            logger.warning("No spots found in the data")
            return pd.DataFrame()

        # Convert to TrackMate spots format
        df = self._features_to_spots(features_df)

        logger.info(f"Extracted {len(df)} spots across {df['FRAME'].nunique()} frames using label_to_sparse")
        return df

    def _extract_spots_from_frame(self, labels: np.ndarray, image: np.ndarray,
                                  frame: int) -> List[Dict[str, Any]]:
        """
        Extract spot information from a single frame with skimage regionprops.

        Args:
            labels: Label image of one frame (0 is background).
            image: Intensity image of the same frame.
            frame: Frame index stored in every returned spot.

        Returns:
            One dict per labelled region; empty list if the frame has no objects.
        """
        # Nothing but background (label 0) in this frame
        if not np.any(labels):
            logger.warning(f"No objects found in frame {frame}")
            return []

        spots = []
        for region in measure.regionprops(labels, intensity_image=image):
            # Note: regionprops returns the centroid as (row, col)
            y_center, x_center = region.centroid
            area = region.area
            mean_intensity = region.mean_intensity

            spots.append({
                'ID': region.label,
                'FRAME': frame,
                'POSITION_X': x_center,
                'POSITION_Y': y_center,
                'POSITION_Z': 0.0,  # 2D case
                # Estimate radius from area (assuming circular objects)
                'RADIUS': np.sqrt(area / np.pi),
                # Quality score: mean intensity
                'QUALITY': mean_intensity,
                'MEAN_INTENSITY': mean_intensity,
                'MAX_INTENSITY': region.max_intensity,
                'TOTAL_INTENSITY': mean_intensity * area,
                'AREA': area,
                'PERIMETER': region.perimeter,
                'ECCENTRICITY': region.eccentricity,
                'SOLIDITY': region.solidity,
            })

        return spots

    def create_trackmate_model(self,
                               spots_df: pd.DataFrame,
                               image_shape: Tuple[int, ...],
                               pixel_size: float = 1.0,
                               time_interval: float = 1.0,
                               linking_max_distance: float = 15.0,
                               gap_closing_max_distance: float = 15.0,
                               max_frame_gap: int = 2
                               ) -> Tuple[Any, Any]:
        """
        Create TrackMate model from spots DataFrame.

        Args:
            spots_df: DataFrame containing spot information
            image_shape: Shape of the original image; the last two entries
                are used as (height, width).
            pixel_size: Pixel size written to the calibration (width, height
                and depth); the unit is set to micrometers.
            time_interval: Frame interval written to the calibration. The
                original docstring states minutes; no time unit is set here.
            linking_max_distance: 'LINKING_MAX_DISTANCE' tracker setting.
            gap_closing_max_distance: 'GAP_CLOSING_MAX_DISTANCE' tracker setting.
            max_frame_gap: 'MAX_FRAME_GAP' tracker setting.

        Returns:
            Tuple of (model, settings)
        """
        logger.info("Creating TrackMate model...")

        # Create model
        model = self.Model()

        # Only the spatial extent is needed for the calibration image
        height, width = image_shape[-2:]

        # Create dummy ImagePlus for calibration
        dummy_image = np.zeros((height, width), dtype=np.uint16)
        imp = self.ij.py.to_imageplus(dummy_image)

        # Set calibration
        cal = self.Calibration()
        cal.pixelWidth = pixel_size
        cal.pixelHeight = pixel_size
        cal.pixelDepth = pixel_size
        cal.frameInterval = time_interval
        cal.setUnit("\u00b5m")  # micrometers; escaped so the literal survives re-encoding
        imp.setCalibration(cal)

        # Create settings
        settings = self.Settings(imp)

        # Configure detector (Manual detector since we already have spots)
        detector_factory = self.ManualDetectorFactory()
        settings.detectorFactory = detector_factory
        settings.detectorSettings = detector_factory.getDefaultSettings()

        # Configure tracker
        tracker_factory = self.SparseLAPTrackerFactory()
        settings.trackerFactory = tracker_factory

        # Set tracker settings
        # NOTE(review): the original comment called the linking distance a
        # placeholder "overridden by TrackMate.__call__"; that caller is not
        # part of this file.
        tracker_settings = tracker_factory.getDefaultSettings()
        tracker_settings.put('LINKING_MAX_DISTANCE', float(linking_max_distance))
        tracker_settings.put('GAP_CLOSING_MAX_DISTANCE', float(gap_closing_max_distance))
        tracker_settings.put('MAX_FRAME_GAP', self.ij.py.to_java(int(max_frame_gap)))
        settings.trackerSettings = tracker_settings

        # Analyzers are intentionally not added (self._configure_analyzers is
        # left disabled, as in the original).

        # Create spots from DataFrame
        self._add_spots_to_model(model, spots_df)

        return model, settings

    def _configure_analyzers(self, settings: Any):
        """
        Configure spot, edge, and track analyzers.

        NOTE(review): not called anywhere in this class. The provider
        constructors and ``add*`` method names should be checked against the
        installed TrackMate version before enabling it.
        """
        # Add spot analyzers
        spot_analyzer_provider = self.SpotAnalyzerProvider()
        for analyzer_key in spot_analyzer_provider.getKeys():
            settings.addSpotAnalyzerFactory(spot_analyzer_provider.getFactory(analyzer_key))

        # Add edge analyzers
        edge_analyzer_provider = self.EdgeAnalyzerProvider()
        for analyzer_key in edge_analyzer_provider.getKeys():
            settings.addEdgeAnalyzer(edge_analyzer_provider.getFactory(analyzer_key))

        # Add track analyzers
        track_analyzer_provider = self.TrackAnalyzerProvider()
        for analyzer_key in track_analyzer_provider.getKeys():
            settings.addTrackAnalyzer(track_analyzer_provider.getFactory(analyzer_key))

    def _add_spots_to_model(self, model: Any, spots_df: pd.DataFrame):
        """
        Add spots from DataFrame to TrackMate model.

        Every row becomes one Spot; the row's 'ID' is stored as the
        'ORIGINAL_LABEL_ID' feature so exports can map back to label IDs.
        """
        logger.info("Adding spots to TrackMate model...")

        spots_collection = self.SpotCollection()

        # Column presence is the same for every row, so resolve it once
        feature_names = ('MEAN_INTENSITY', 'TOTAL_INTENSITY', 'AREA') + tuple(
            name for name in ('MAX_INTENSITY', 'PERIMETER', 'ECCENTRICITY', 'SOLIDITY')
            if name in spots_df.columns
        )

        for _, row in tqdm(spots_df.iterrows(), total=len(spots_df), desc="Adding spots"):
            # Create spot
            spot = self.Spot(
                float(row['POSITION_X']),
                float(row['POSITION_Y']),
                float(row['POSITION_Z']),
                float(row['RADIUS']),
                float(row['QUALITY'])
            )

            # Store original label ID as a custom feature to preserve the mapping
            spot.putFeature('ORIGINAL_LABEL_ID', float(row['ID']))

            # Set additional features
            for name in feature_names:
                spot.putFeature(name, float(row[name]))

            # Add to collection
            spots_collection.add(spot, self.ij.py.to_java(int(row['FRAME'])))

        model.setSpots(spots_collection, False)
        logger.info(f"Added {spots_collection.getNSpots(False)} spots to model")

    def run_tracking(self, model: Any, settings: Any) -> Dict[str, Any]:
        """
        Run TrackMate tracking algorithm.

        Runs, in order: input check, detection, initial spot filtering, spot
        filtering, tracking and track filtering, stopping at the first
        failing step.

        Args:
            model: TrackMate model
            settings: TrackMate settings

        Returns:
            ``{"success": True, "model", "n_tracks", "n_spots"}`` on success,
            otherwise ``{"success": False, "error": <TrackMate message>}``.
        """
        logger.info("Running TrackMate tracking...")

        # Create TrackMate instance
        trackmate = self.TrackMate(model, settings)

        # (log line, step, arguments, name used in the failure message).
        # The two filtering steps were previously logged as "feature
        # computation", which is not what these calls are.
        stages = (
            (None, trackmate.checkInput, (), "input check"),
            ("Processing detection step...", trackmate.execDetection, (), "detection"),
            ("Processing initial filtering...", trackmate.execInitialSpotFiltering, (), "initial filtering"),
            ("Applying spot filters...", trackmate.execSpotFiltering, (True,), "spot filtering"),
            ("Performing track linking...", trackmate.execTracking, (), "tracking"),
            ("Applying track filters...", trackmate.execTrackFiltering, (True,), "track filtering"),
        )
        for announcement, step, args, stage_name in stages:
            if announcement is not None:
                logger.info(announcement)
            if not step(*args):
                error_msg = str(trackmate.getErrorMessage())
                logger.error(f"TrackMate {stage_name} failed: {error_msg}")
                return {"success": False, "error": error_msg}

        # Extract results
        n_tracks = int(model.getTrackModel().trackIDs(True).size())

        logger.info(f"Tracking completed successfully. Generated {n_tracks} tracks")

        return {
            "success": True,
            "model": model,
            "n_tracks": n_tracks,
            "n_spots": model.getSpots().getNSpots(True)
        }

    def export_tracking_results(self, model: Any, output_dir: Path) -> Dict[str, Path]:
        """
        Export tracking results to CSV format compatible with TrackMate CSV importer.

        Writes ``spots.csv`` with the columns ID, TRACK_ID, POSITION_X,
        POSITION_Y, FRAME, sorted by track and frame. Spots that belong to no
        track get TRACK_ID -1. A header-only file is written when the model
        has no visible spots.

        Args:
            model: TrackMate model with tracking results
            output_dir: Directory to save results (created if missing)

        Returns:
            Dictionary with paths to exported files (key ``'spots'``).
        """
        logger.info("Exporting tracking results in TrackMate CSV importer format...")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        exported_files = {}

        # Create mapping from spot to track ID
        track_model = model.getTrackModel()
        spot_to_track = {}
        for track_id in track_model.trackIDs(True):
            for spot in track_model.trackSpots(track_id):
                spot_to_track[spot.ID()] = track_id

        # Export spots with track information for TrackMate CSV importer
        spots_file = output_dir / "spots.csv"
        rows = [
            {
                'ID': int(spot.getFeature('ORIGINAL_LABEL_ID')),  # idCol - original label IDs
                'TRACK_ID': int(spot_to_track.get(spot.ID(), -1)),  # trackCol - -1 for untracked spots
                'POSITION_X': float(spot.getFeature('POSITION_X')),  # xCol - X positions
                'POSITION_Y': float(spot.getFeature('POSITION_Y')),  # yCol - Y positions
                'FRAME': int(spot.getFeature('FRAME')),  # frameCol - frame numbers
            }
            for spot in model.getSpots().iterable(True)
        ]

        # Explicit columns keep sort_values valid when there are no spots
        spots_df = pd.DataFrame(
            rows, columns=['ID', 'TRACK_ID', 'POSITION_X', 'POSITION_Y', 'FRAME']
        )

        # Sort by track ID and frame for better organization
        spots_df = spots_df.sort_values(['TRACK_ID', 'FRAME'])

        spots_df.to_csv(spots_file, index=False)
        exported_files['spots'] = spots_file

        logger.info(f"Exported tracking results in TrackMate CSV importer format to {output_dir}")
        logger.info(f"CSV columns: ID (col 0), TRACK_ID (col 1), POSITION_X (col 2), POSITION_Y (col 3), FRAME (col 4)")

        return exported_files
def run_trackmate_on_data(
    image_path: Optional[Path] = None,
    label_path: Optional[Path] = None,
    imagej_path: Optional[str] = "/Applications/Fiji.app",
    output_dir: Path = Path("./tracking_results"),
    pixel_size: float = 0.1,
    time_interval: float = 1.0,
) -> Dict[str, Any]:
    """
    Main function to run TrackMate tracking on a pair of image/label TIFF files.

    Called without arguments it behaves like the original example and uses
    the example data paths and Fiji location baked into this module.

    Args:
        image_path: Intensity image TIFF. Defaults to the example crop.
        label_path: Segmentation label TIFF. Defaults to the example labels.
        imagej_path: Fiji installation handed to TrackMateInMemoryProcessor.
        output_dir: Directory that receives the exported CSV.
        pixel_size: Pixel size passed to create_trackmate_model (adjust to
            your microscopy setup).
        time_interval: Frame interval passed to create_trackmate_model.

    Returns:
        On success a dict with ``success=True``, ``model``, ``spots_df``,
        ``exported_files``, ``n_tracks`` and ``n_spots``; otherwise
        ``{"success": False, "error": <message>}``. Exceptions raised during
        processing are logged and reported through the same error dict.
    """
    # Fall back to the example data shipped with the original script
    if image_path is None:
        image_path = Path("/Users/jacky/Projects/Cytotoxicity/Cytotoxicity-Pipeline/data/confocal_vesicle_paper/crop.tif")
    if label_path is None:
        label_path = Path("/Users/jacky/Projects/Cytotoxicity/Cytotoxicity-Pipeline/data/confocal_vesicle_paper/filtered_stack_crop.tif")
    image_path = Path(image_path)
    label_path = Path(label_path)

    # Check if files exist
    if not image_path.exists():
        logger.error(f"Image file not found: {image_path}")
        return {"success": False, "error": f"Image file not found: {image_path}"}

    if not label_path.exists():
        logger.error(f"Label file not found: {label_path}")
        return {"success": False, "error": f"Label file not found: {label_path}"}

    try:
        # Initialize processor
        logger.info("Initializing TrackMate processor...")
        processor = TrackMateInMemoryProcessor(imagej_path=imagej_path)

        # Load data
        logger.info("Loading image and label data...")
        image_data, label_data = processor.load_tiff_data(image_path, label_path)

        # Convert labels to spots
        logger.info("Converting labels to spots...")
        spots_df = processor.labels_to_spots_csv(label_data, image_data)

        if len(spots_df) == 0:
            logger.warning("No spots found in the data")
            return {"success": False, "error": "No spots found in the data"}

        logger.info(f"Found {len(spots_df)} spots")
        print("\nSpot statistics:")
        print(spots_df.describe())
        # Denominator is the number of frames in the label stack (the
        # original printed the number of spots here).
        total_frames = label_data.shape[0] if label_data.ndim == 3 else 1
        print(f"\nFrames with spots: {spots_df['FRAME'].nunique()} / {total_frames}")

        # Create TrackMate model
        logger.info("Creating TrackMate model...")
        model, settings = processor.create_trackmate_model(
            spots_df,
            image_data.shape,
            pixel_size=pixel_size,
            time_interval=time_interval
        )

        # Run tracking
        logger.info("Running TrackMate tracking...")
        tracking_results = processor.run_tracking(model, settings)

        if not tracking_results["success"]:
            logger.error(f"Tracking failed: {tracking_results.get('error', 'Unknown error')}")
            return {"success": False, "error": tracking_results.get('error')}

        logger.info("Tracking successful!")
        logger.info(f"Generated {tracking_results['n_tracks']} tracks from {tracking_results['n_spots']} spots")

        # Export results
        output_dir = Path(output_dir)
        exported_files = processor.export_tracking_results(model, output_dir)

        # "\u2713" is the check mark the garbled literal originally encoded
        print("\n\u2713 Tracking completed successfully!")
        print(f"Results exported to: {output_dir.absolute()}")
        print("\nExported files:")
        for file_type, file_path in exported_files.items():
            print(f"  {file_type}: {file_path.name}")

        return {
            "success": True,
            "model": model,
            "spots_df": spots_df,
            "exported_files": exported_files,
            "n_tracks": tracking_results['n_tracks'],
            "n_spots": tracking_results['n_spots']
        }

    except Exception as e:
        # Top-level boundary of the example: log with traceback and report
        logger.exception(f"Error during processing: {str(e)}")
        return {"success": False, "error": str(e)}
if __name__ == "__main__":
    # Run the tracking on the example data and print a one-line summary.
    # The emoji are written as escapes because the literals in this file had
    # been garbled by a bad re-encoding.
    results = run_trackmate_on_data()

    if results["success"]:
        print(f"\n\U0001F389 Successfully processed {results['n_spots']} spots into {results['n_tracks']} tracks!")
    else:
        print(f"\n\u274c Processing failed: {results.get('error', 'Unknown error')}")