Source code for cyto.postprocessing.graph

from typing import Any
import pandas as pd
from tqdm import tqdm
import os
import pyclesperanto_prototype as cle
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from skimage import exposure
import networkx as nx
import SimpleITK as sitk
import time
from ..utils.utils import *

class CellTriangulation(object):
    def __init__(self, base_image=True, verbose=True) -> None:
        """
        Initialize the CellTriangulation class.

        This class performs triangulation of cell centroids by searching for the
        minimum distance required to generate a single connected network for each frame.

        Args:
            base_image (bool): If True, plots the cell networks over the base image.
                Stored on the instance; not read by ``__call__``.
            verbose (bool): If True, enables detailed processing output.
        """
        self.base_image = base_image
        self.verbose = verbose

    def __call__(self, data) -> Any:
        """
        Build one centroid network per frame.

        For every frame between the smallest and largest value of
        ``feature.frame`` the distance threshold is raised from 50 to 10000 in
        steps of 10 until the graph built from the thresholded distance matrix
        is connected.

        Args:
            data (dict): Must provide ``"image"`` (indexed as ``image[:, :, frame]``)
                and ``"feature"`` (a table with ``frame``, ``i`` and ``j`` columns).

        Returns:
            dict: ``"image"`` holds the rendered network meshes stacked along the
            last axis (one slice per frame), ``"network"`` the list of networkx
            graphs in frame order.

        Warns:
            RuntimeWarning: when no threshold in the search range yields a
            connected graph; the network at the largest threshold is kept so
            the outputs stay aligned with the frames.
        """
        import warnings

        image = data["image"]
        feature = data["feature"]

        START_FRAME = feature.frame.min()
        END_FRAME = feature.frame.max()
        n_frames = END_FRAME - START_FRAME + 1

        graph = []
        network_image = None

        for frame_idx, CUR_FRAME in enumerate(tqdm(range(START_FRAME, END_FRAME + 1))):
            if self.verbose:
                print("Frame: ", CUR_FRAME)
            frame = image[:, :, CUR_FRAME]
            centroids = feature[feature.frame == CUR_FRAME][['i', 'j']].T.to_numpy()

            # generate distance matrix
            distance_matrix = cle.generate_distance_matrix(centroids, centroids)

            # The ">= 1" mask does not depend on the search threshold, so it is
            # built once per frame instead of once per search step.
            min_distance_mask = cle.greater_or_equal_constant(distance_matrix, constant=1)

            # grid search for the smallest distance to generate one connected network
            connection_matrix = None
            networkx_graph = None
            connected = False
            pbar = tqdm(range(50, 10001, 10))
            for i in pbar:
                pbar.set_description(str(i))

                # Drop the previous candidate only once a new one is about to be
                # built, so the last candidate is still alive after the loop.
                if connection_matrix is not None:
                    connection_matrix.data.release()

                within_threshold = cle.smaller_or_equal_constant(distance_matrix, constant=i)
                connection_matrix = cle.multiply_images(within_threshold, min_distance_mask)
                within_threshold.data.release()
                del within_threshold

                networkx_graph = cle.to_networkx(connection_matrix, centroids)
                if nx.is_connected(networkx_graph):
                    connected = True
                    break
            pbar.close()

            min_distance_mask.data.release()
            del min_distance_mask

            if connected:
                if self.verbose:
                    print("Cell network is connected with minimum distance {}".format(i))
            else:
                warnings.warn(
                    "Cell network for frame {} is not connected at the maximum "
                    "search distance {}; keeping the network at that distance".format(CUR_FRAME, i),
                    RuntimeWarning,
                )
            graph.append(networkx_graph)

            # Render the selected network once (not on every search step).
            mesh_device = cle.create_like(frame)
            cle.touch_matrix_to_mesh(centroids, connection_matrix, mesh_device)
            mesh_host = cle.pull(mesh_device)

            # memory clean up
            distance_matrix.data.release()
            connection_matrix.data.release()
            mesh_device.data.release()
            del distance_matrix, connection_matrix, mesh_device

            # Allocate the output stack when the first mesh shape/dtype is known.
            if network_image is None:
                network_image = np.empty((*mesh_host.shape, n_frames), dtype=mesh_host.dtype)
            network_image[..., frame_idx] = mesh_host

        return {"image": network_image, "network": graph}
def label_centroids_to_pointlist_sitk(label, cell_type=""):
    """
    Tabulate centroid and bounding box of every label in a SimpleITK label image.

    Args:
        label: SimpleITK label image. Only the first two axes are reported
            (assumes a 2D image — TODO confirm against callers).
        cell_type (str): Value written to the ``cell_type`` column of every row.

    Returns:
        pandas.DataFrame: one row per label with the columns ``label``, ``x``,
        ``y``, ``bbox_xstart``, ``bbox_ystart``, ``bbox_xsize``, ``bbox_ysize``
        and ``cell_type``. Centroids and bounding boxes are both expressed in
        pixel (index) coordinates.
    """
    columns = (
        "label",
        "x",
        "y",
        "bbox_xstart",
        "bbox_ystart",
        "bbox_xsize",
        "bbox_ysize",
        "cell_type",
    )
    labels_centroid = {name: [] for name in columns}

    labelStat = sitk.LabelShapeStatisticsImageFilter()
    labelStat.Execute(label)

    for l in tqdm(labelStat.GetLabels()):
        # GetCentroid is in physical space, so it is converted to a continuous index.
        centroid_index = label.TransformPhysicalPointToContinuousIndex(labelStat.GetCentroid(l))

        # GetBoundingBox is already in index space ([xstart, ystart, xsize, ysize]
        # for a 2D image); it must not be pushed through the physical-to-index
        # transform, which would distort it for any image whose origin, spacing
        # or direction is not the identity. Cast to float to keep the column
        # dtype the transform-based code produced.
        bbox = labelStat.GetBoundingBox(l)

        labels_centroid["label"].append(l)
        labels_centroid["x"].append(centroid_index[0])
        labels_centroid["y"].append(centroid_index[1])
        labels_centroid["bbox_xstart"].append(float(bbox[0]))
        labels_centroid["bbox_ystart"].append(float(bbox[1]))
        labels_centroid["bbox_xsize"].append(float(bbox[2]))
        labels_centroid["bbox_ysize"].append(float(bbox[3]))
        labels_centroid["cell_type"].append(cell_type)

    return pd.DataFrame.from_dict(labels_centroid)
class CrossCellContactMeasures(object):
    def __init__(self, verbose=True) -> None:
        """
        Perform cross cell type contact measurements using GPU-accelerated OpenCL via pyclesperanto.

        Parallelism across time frames is handled at the SLURM array job level — each job
        receives a disjoint frame range.

        Args:
            verbose (bool): Turn on or off the processing printout
        """
        self.verbose = verbose
        # The timing messages of run_single_frame report a thread count. Frames are
        # not parallelised inside this class, so a single thread is reported.
        self.threads = 1

    @staticmethod
    def _build_pivot_matrix(masked_host, n0, n1):
        """
        Embed the cross-type distance block into one square matrix over all cells.

        Args:
            masked_host (ndarray): ``(n1, n0)`` block; rows index cells of the
                second type, columns cells of the first type.
            n0 (int): Number of cells of the first type.
            n1 (int): Number of cells of the second type.

        Returns:
            ndarray: ``(n0 + n1 + 1, n0 + n1 + 1)`` symmetric matrix. Row and
            column 0 stay zero, as do the two same-type blocks.
        """
        size = n0 + n1 + 1
        pivot = np.zeros((size, size))
        pivot[(n0 + 1):, 1:(n0 + 1)] = masked_host
        pivot[1:(n0 + 1), (n0 + 1):] = masked_host.T
        return pivot

    @staticmethod
    def _contact_stats(overlap_matrix_host, distance_matrix_host, n_cells, cell_label_offset):
        """
        Derive the per-cell contact columns from the pulled overlap/distance matrices.

        Row 0 and column 0 of both matrices are skipped; the remaining rows index
        cells of the second type and the columns cells of the first type.

        Args:
            overlap_matrix_host (ndarray): Binary overlap matrix.
            distance_matrix_host (ndarray): Centroid distance matrix.
            n_cells (int): Number of first-type cells to report.
            cell_label_offset (int): Number of second-type cells in earlier frames.

        Returns:
            tuple: ``(contact, contact_label, closest_cell_dist)``. ``contact`` is a
            list of bools, ``contact_label`` a list of index arrays (1-based and
            shifted by ``cell_label_offset``), ``closest_cell_dist`` a list of
            floats. The distance is NaN when there is no second-type cell.
        """
        overlap_sub = np.asarray(overlap_matrix_host)[1:, 1:n_cells + 1]
        dist_sub = np.asarray(distance_matrix_host)[1:, 1:n_cells + 1]

        contact = (overlap_sub.sum(axis=0) > 0).tolist()
        # cell label starts from 1 so need to offset extra 1
        contact_label = [
            np.where(overlap_sub[:, col] == 1)[0] + cell_label_offset + 1
            for col in range(n_cells)
        ]
        if dist_sub.shape[0] == 0:
            # No partner cell at all: a minimum over nothing is undefined.
            closest_cell_dist = [float("nan")] * dist_sub.shape[1]
        else:
            closest_cell_dist = dist_sub.min(axis=0).tolist()
        return contact, contact_label, closest_cell_dist

    def _log_elapsed(self, frame, stage, start_time):
        """Write the elapsed time of one processing stage through tqdm."""
        tqdm.write(
            "@frame {}: {} elapsed time for thread count = {}: {:.4f}s".format(
                frame, stage, str(getattr(self, "threads", 1)), time.time() - start_time
            )
        )

    def run_single_frame(self, label_0, label_1, centroids_0, centroids_1, features_0, features_1, frame):
        """
        Single frame process run

        Args:
            label_0 (arr): Numpy array label of first cell type
            label_1 (arr): Numpy array label of second cell type
            centroids_0 (arr): Numpy array of centroid coordinates corresponding to label_0, in pixel space ij.
            centroids_1 (arr): Numpy array of centroid coordinates corresponding to label_1, in pixel space ij.
            features_0 (Dataframe): Dataframe of sparse cell info of first cell type
            features_1 (Dataframe): Dataframe of sparse cell info of second cell type
            frame (int): Frame number; selects the rows of the feature tables and tags the log output.

        Returns:
            dict: ``graph`` (networkx graph of contacting cells), ``network_image``
            (rendered mesh), ``frame``, ``contact``, ``contact_label`` and
            ``closest_cell_dist`` (one entry per first-type cell of the frame).
        """
        verbose = self.verbose
        start_time_0 = time.time()
        if verbose:
            tqdm.write("frame {} thread started".format(frame))

        # get number of centroids in each cell type
        c_count = [centroids_0.shape[1], centroids_1.shape[1]]
        if verbose:
            tqdm.write("Distance matrix scale: {} x {}".format(c_count[0], c_count[1]))

        # generate distance matrix on gpu
        start_time = time.time()
        if verbose:
            print("Generating distance matrix: {}".format(frame))
        distance_matrix_device = cle.generate_distance_matrix(centroids_0, centroids_1)
        if verbose:
            check_gpu_memory()
            self._log_elapsed(frame, "distance", start_time)

        # relabel the input
        # NOTE(review): RelabelComponentImageFilter sorts labels by object size by
        # default — confirm that the feature tables follow the same ordering.
        start_time = time.time()
        relabelFilter = sitk.RelabelComponentImageFilter()
        label_0 = sitk.GetArrayFromImage(relabelFilter.Execute(sitk.GetImageFromArray(label_0)))
        label_1 = sitk.GetArrayFromImage(relabelFilter.Execute(sitk.GetImageFromArray(label_1)))
        if verbose:
            self._log_elapsed(frame, "relabel", start_time)

        start_time = time.time()
        if verbose:
            print("Generating binary overlap matrix: {}".format(frame))
        overlap_matrix_device = cle.generate_binary_overlap_matrix(label_0, label_1)
        if verbose:
            check_gpu_memory()
            print("Generating masked distance matrix: {}".format(frame))
        masked_distance_matrix = cle.multiply_images(overlap_matrix_device, distance_matrix_device)
        if verbose:
            check_gpu_memory()

        pointlist = np.concatenate([centroids_0, centroids_1], axis=1)
        masked_host = cle.pull(masked_distance_matrix)[1:, 1:]  # pull once
        distance_matrix_pivot = cle.push(
            self._build_pivot_matrix(masked_host, c_count[0], c_count[1])
        )
        distance_mesh_device = cle.create_labels_like(label_0)
        cle.touch_matrix_to_mesh(pointlist, distance_matrix_pivot, distance_mesh_device)
        if verbose:
            self._log_elapsed(frame, "contact analysis", start_time)

        start_time = time.time()
        graph = cle.to_networkx(distance_matrix_pivot, pointlist)  # networkx graph
        if verbose:
            self._log_elapsed(frame, "network export", start_time)

        # pulling data from device to host
        start_time = time.time()
        if verbose:
            print("Cleaning up GPU: {}".format(frame))
        distance_matrix_host = cle.pull(distance_matrix_device)
        overlap_matrix_host = cle.pull(overlap_matrix_device)
        distance_mesh_host = cle.pull(distance_mesh_device)

        # release GPU memory explicitly (including the pushed pivot matrix)
        distance_matrix_device.data.release()
        overlap_matrix_device.data.release()
        masked_distance_matrix.data.release()
        distance_matrix_pivot.data.release()
        distance_mesh_device.data.release()
        del distance_matrix_device, overlap_matrix_device, masked_distance_matrix
        del distance_matrix_pivot, distance_mesh_device
        if verbose:
            check_gpu_memory()
            self._log_elapsed(frame, "device to host", start_time)
            self._log_elapsed(frame, "threaded loop", start_time_0)

        # combine the pulled matrices with the features table
        n_cells = len(features_0[features_0.frame == frame].index)
        cell_label_offset = len(features_1[features_1.frame < frame].index)
        contact, contact_label, closest_cell_dist = self._contact_stats(
            overlap_matrix_host, distance_matrix_host, n_cells, cell_label_offset
        )

        return {
            "graph": graph,
            "network_image": distance_mesh_host,
            "frame": frame,
            "contact": contact,
            "contact_label": contact_label,
            "closest_cell_dist": closest_cell_dist,
        }

    def __call__(self, data) -> Any:
        """
        Measure contacts between two cell types for every frame.

        Args:
            data (dict): ``"label"`` holds two label stacks indexed as
                ``labels[k][:, :, frame]``; ``"feature"`` holds the two matching
                feature tables with ``frame``, ``i`` and ``j`` columns. The label
                stacks are read only; they are not modified.

        Returns:
            dict: ``"image"`` (rendered contact meshes stacked along the last
            axis), ``"feature"`` (copy of the first feature table with the
            columns ``contact``, ``contacting cell labels`` and
            ``closest cell dist`` added) and ``"network"`` (one networkx graph
            per frame).
        """
        labels = data["label"]
        features = data["feature"]

        assert len(features) == 2, "Input features must be 2"

        START_FRAME = features[0].frame.min()
        END_FRAME = features[0].frame.max()
        n_frames = END_FRAME - START_FRAME + 1

        features_out = features[0].copy()
        graph = []
        network_image = None
        contact = []
        contact_label = []
        closest_cell_dist = []

        # Precompute cumulative cell-type-1 counts per frame to avoid per-frame DataFrame scans
        f1_counts = features[1].groupby("frame").size()
        cumulative_offset = {
            f: int(f1_counts[f1_counts.index < f].sum())
            for f in features[0].frame.unique()
        }

        pbar = tqdm(range(START_FRAME, END_FRAME + 1), desc="Cross cell contact measurements")
        for frame_idx, CUR_FRAME in enumerate(pbar):
            c = [
                features[k][features[k].frame == CUR_FRAME][["i", "j"]].to_numpy().T
                for k in range(2)
            ]
            c_count = [centroids.shape[1] for centroids in c]

            # GPU: distance matrix between all centroid pairs (OpenCL via pyclesperanto)
            distance_matrix_device = cle.generate_distance_matrix(c[0], c[1])

            # Materialise the two label slices of this frame once and work on the
            # local copies, so lazy arrays are computed a single time and the
            # caller's label stacks are left untouched.
            frame_labels = []
            for li in range(2):
                lslice = labels[li][:, :, CUR_FRAME]
                if hasattr(lslice, "compute"):
                    lslice = lslice.compute()
                lslice = np.asarray(lslice)

                # Relabel only if label IDs are non-consecutive (pyclesperanto
                # indexes by label value). Background 0 is excluded explicitly so
                # slices without any background pixel are counted correctly.
                present = np.unique(lslice)
                present = present[present != 0]
                n_cells = len(present)
                if n_cells > 0 and present[-1] != n_cells:
                    # NOTE(review): RelabelComponentImageFilter sorts labels by
                    # object size by default — confirm that the feature tables
                    # follow the same ordering.
                    relabelFilter = sitk.RelabelComponentImageFilter()
                    relabeled = relabelFilter.Execute(sitk.GetImageFromArray(lslice))
                    lslice = sitk.GetArrayFromImage(relabeled)
                frame_labels.append(lslice)

            # GPU: binary overlap matrix (pixel-level label contact detection)
            overlap_matrix_device = cle.generate_binary_overlap_matrix(frame_labels[0], frame_labels[1])

            # GPU: mask distances to only contacting cell pairs
            masked_distance_matrix = cle.multiply_images(overlap_matrix_device, distance_matrix_device)

            pointlist = np.concatenate(c, axis=1)
            masked_host = cle.pull(masked_distance_matrix)[1:, 1:]  # pull once
            distance_matrix_pivot = cle.push(
                self._build_pivot_matrix(masked_host, c_count[0], c_count[1])
            )

            distance_mesh_device = cle.create_labels_like(frame_labels[0])
            cle.touch_matrix_to_mesh(pointlist, distance_matrix_pivot, distance_mesh_device)

            # note: graph index may have offset if START_FRAME != 0
            graph.append(cle.to_networkx(distance_matrix_pivot, pointlist))

            # pull results to host
            distance_matrix_host = cle.pull(distance_matrix_device)
            overlap_matrix_host = cle.pull(overlap_matrix_device)
            distance_mesh_host = cle.pull(distance_mesh_device)

            # release GPU memory explicitly (including the pushed pivot matrix)
            distance_matrix_device.data.release()
            overlap_matrix_device.data.release()
            masked_distance_matrix.data.release()
            distance_matrix_pivot.data.release()
            distance_mesh_device.data.release()
            del distance_matrix_device, overlap_matrix_device, masked_distance_matrix
            del distance_matrix_pivot, distance_mesh_device

            # Allocate the output stack when the first mesh shape/dtype is known.
            if network_image is None:
                network_image = np.empty(
                    (*distance_mesh_host.shape, n_frames), dtype=distance_mesh_host.dtype
                )
            network_image[..., frame_idx] = distance_mesh_host

            # Vectorized per-cell contact extraction
            n0 = len(features[0][features[0].frame == CUR_FRAME].index)
            frame_contact, frame_contact_label, frame_closest = self._contact_stats(
                overlap_matrix_host,
                distance_matrix_host,
                n0,
                cumulative_offset.get(CUR_FRAME, 0),
            )
            contact.extend(frame_contact)
            contact_label.extend(frame_contact_label)
            closest_cell_dist.extend(frame_closest)

        features_out["contact"] = contact
        features_out["contacting cell labels"] = contact_label
        features_out["closest cell dist"] = closest_cell_dist

        return {"image": network_image, "feature": features_out, "network": graph}