# Source code for cyto.postprocessing.graph
from typing import Any
import pandas as pd
from tqdm import tqdm
import os
import pyclesperanto_prototype as cle
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from skimage import exposure
import networkx as nx
import SimpleITK as sitk
import time
from ..utils.utils import *
[docs]
class CellTriangulation(object):
    """Build one connected network of cell centroids for every frame.

    For each frame the pairwise centroid distance matrix is thresholded with
    increasing distances until the resulting graph is connected. The graph is
    kept and the corresponding mesh is rendered into an image.
    """

    def __init__(self, base_image=True, verbose=True, min_distance=50,
                 max_distance=10000, distance_step=10) -> None:
        """
        Initialize the CellTriangulation class.

        Args:
            base_image (bool): Stored as ``self.base_image``. ``__call__`` does
                not read it.
            verbose (bool): If True, prints the frame number and the distance
                at which the network became connected.
            min_distance (int): First distance threshold tried by the search.
            max_distance (int): Last distance threshold tried (inclusive).
            distance_step (int): Increment between successive thresholds.

        Raises:
            ValueError: If ``distance_step`` is not positive.
        """
        if distance_step <= 0:
            raise ValueError("distance_step must be a positive integer")
        self.base_image = base_image
        self.verbose = verbose
        self.min_distance = min_distance
        self.max_distance = max_distance
        self.distance_step = distance_step

    def _thresholds(self):
        """Return the increasing distance thresholds tried for every frame."""
        return range(self.min_distance, self.max_distance + 1, self.distance_step)

    def __call__(self, data) -> Any:
        """
        Triangulate the cell centroids of every frame.

        Args:
            data (dict): ``data["image"]`` is indexed as ``image[:, :, frame]``;
                ``data["feature"]`` is a table with ``frame``, ``i`` and ``j``
                columns.

        Returns:
            dict: ``"image"`` holds the per-frame mesh images stacked along the
            last axis, ``"network"`` holds one networkx graph per frame.

        Raises:
            RuntimeError: If no threshold in the search range produces a
                connected network for some frame.
        """
        image = data["image"]
        feature = data["feature"]

        START_FRAME = feature.frame.min()
        END_FRAME = feature.frame.max()
        n_frames = END_FRAME - START_FRAME + 1

        graph = []
        network_image = None

        for frame_idx, CUR_FRAME in enumerate(tqdm(range(START_FRAME, END_FRAME + 1))):
            if self.verbose:
                print("Frame: ", CUR_FRAME)

            frame = image[:, :, CUR_FRAME]
            centroids = feature[feature.frame == CUR_FRAME][['i', 'j']].T.to_numpy()

            # generate distance matrix
            distance_matrix = cle.generate_distance_matrix(centroids, centroids)
            # Entries with distance >= 1. This drops zero-distance pairs such as
            # a centroid paired with itself. It does not depend on the
            # threshold, so it is computed once per frame.
            far_enough = cle.greater_or_equal_constant(distance_matrix, constant=1)

            # Search for the smallest distance that generates one connected
            # network. Per-iteration GPU buffers are released immediately so
            # they do not accumulate over the search.
            connection_matrix = None
            pbar = tqdm(self._thresholds())
            for threshold in pbar:
                pbar.set_description(str(threshold))
                close_enough = cle.smaller_or_equal_constant(distance_matrix, constant=threshold)
                candidate = cle.multiply_images(close_enough, far_enough)
                close_enough.data.release()
                del close_enough

                networkx_graph = cle.to_networkx(candidate, centroids)
                if nx.is_connected(networkx_graph):
                    if self.verbose:
                        print("Cell network is connected with minimum distance {}".format(threshold))
                    graph.append(networkx_graph)
                    connection_matrix = candidate
                    break

                candidate.data.release()
                del candidate

            far_enough.data.release()
            distance_matrix.data.release()
            del far_enough, distance_matrix

            if connection_matrix is None:
                # Previously this path failed with a NameError on the deleted
                # connection matrix; report the actual problem instead.
                raise RuntimeError(
                    "Frame {}: no connected cell network found for distances {}..{} (step {})".format(
                        CUR_FRAME, self.min_distance, self.max_distance, self.distance_step
                    )
                )

            # The mesh is only rendered for the accepted connection matrix.
            mesh_device = cle.create_like(frame)
            cle.touch_matrix_to_mesh(centroids, connection_matrix, mesh_device)
            mesh_host = cle.pull(mesh_device)

            # memory clean up
            connection_matrix.data.release()
            mesh_device.data.release()
            del connection_matrix, mesh_device

            # Pre-allocate on the first frame to avoid repeated concatenation
            if network_image is None:
                network_image = np.empty((*mesh_host.shape, n_frames), dtype=mesh_host.dtype)
            network_image[..., frame_idx] = mesh_host

        return {"image": network_image, "network": graph}
[docs]
def label_centroids_to_pointlist_sitk(label, cell_type=""):
    """
    Tabulate the centroid and bounding box of every label in a 2D label image.

    Args:
        label: SimpleITK label image. Only the first two dimensions of each
            centroid and bounding box are used.
        cell_type (str): Value written to the ``cell_type`` column of every row.

    Returns:
        pandas.DataFrame: One row per label with columns ``label``, ``x``,
        ``y``, ``bbox_xstart``, ``bbox_ystart``, ``bbox_xsize``,
        ``bbox_ysize`` and ``cell_type``. All coordinates are in pixel
        (index) space.
    """
    columns = (
        "label",
        "x",
        "y",
        "bbox_xstart",
        "bbox_ystart",
        "bbox_xsize",
        "bbox_ysize",
        "cell_type",
    )
    labels_centroid = {column: [] for column in columns}

    label_stat = sitk.LabelShapeStatisticsImageFilter()
    label_stat.Execute(label)

    for label_id in tqdm(label_stat.GetLabels()):
        # The centroid is reported in physical space, so convert it to a
        # continuous pixel index (one conversion per label).
        centroid = label.TransformPhysicalPointToContinuousIndex(label_stat.GetCentroid(label_id))

        # The bounding box is already expressed in pixel indices
        # (start..., size...), so it must NOT go through the physical-to-index
        # transform; doing so gives wrong values whenever the image spacing is
        # not 1 or the origin is not 0.
        bbox = label_stat.GetBoundingBox(label_id)

        labels_centroid["label"].append(label_id)
        labels_centroid["x"].append(centroid[0])
        labels_centroid["y"].append(centroid[1])
        # float() keeps the column dtype produced by the previous implementation.
        labels_centroid["bbox_xstart"].append(float(bbox[0]))
        labels_centroid["bbox_ystart"].append(float(bbox[1]))
        labels_centroid["bbox_xsize"].append(float(bbox[2]))
        labels_centroid["bbox_ysize"].append(float(bbox[3]))
        labels_centroid["cell_type"].append(cell_type)

    return pd.DataFrame.from_dict(labels_centroid)
[docs]
class CrossCellContactMeasures(object):
    """Measure contacts between two labelled cell populations, frame by frame."""

    def __init__(self, verbose=True) -> None:
        """
        Perform cross cell type contact measurements using GPU-accelerated OpenCL
        via pyclesperanto. Parallelism across time frames is handled at the SLURM
        array job level — each job receives a disjoint frame range.

        Args:
            verbose (bool): Turn on or off the processing printout
        """
        self.verbose = verbose
        # run_single_frame interpolates a thread count into its timing
        # messages. The attribute was never set, so every verbose run failed
        # with AttributeError.
        self.threads = 1

    def _log_elapsed(self, frame, stage, start_time):
        """Write the elapsed-time message of one processing stage."""
        elapsed_time = time.time() - start_time
        tqdm.write(
            "@frame {}: {} elapsed time for thread count = {}: {:.4f}s".format(
                frame, stage, str(self.threads), elapsed_time
            )
        )

    @staticmethod
    def _build_pivot_matrix(masked_host, n0, n1):
        """Embed the (n1, n0) masked distance block symmetrically in a square matrix.

        The result has ``n0 + n1 + 1`` rows and columns. Index 0 is left empty,
        indices ``1..n0`` address the first cell type and the remaining indices
        address the second cell type.
        """
        size = n0 + n1 + 1
        pivot = np.zeros((size, size))
        pivot[(n0 + 1):, 1:(n0 + 1)] = masked_host
        pivot[1:(n0 + 1), (n0 + 1):] = masked_host.T
        return pivot

    @staticmethod
    def _summarize_contacts(overlap_host, distance_host, n0, cell_label_offset):
        """Extract the per-cell contact results for the first ``n0`` cells.

        Row 0 and column 0 of both matrices are skipped. Columns ``1..n0``
        address the cells of the first type, rows address the second type.

        Returns:
            tuple: ``(contact, contact_label, closest_cell_dist)`` where
            ``contact`` is a list of bool, ``contact_label`` a list of integer
            arrays (row position + ``cell_label_offset`` + 1) and
            ``closest_cell_dist`` a list of float.
        """
        overlap_sub = overlap_host[1:, 1:n0 + 1]  # shape (n1, n0)
        dist_sub = distance_host[1:, 1:n0 + 1]  # shape (n1, n0)
        contact = (overlap_sub.sum(axis=0) > 0).tolist()
        # cell label starts from 1 so need to offset extra 1
        contact_label = [
            np.where(overlap_sub[:, col] == 1)[0] + cell_label_offset + 1
            for col in range(n0)
        ]
        closest_cell_dist = dist_sub.min(axis=0).tolist()
        return contact, contact_label, closest_cell_dist

    def run_single_frame(self, label_0, label_1, centroids_0, centroids_1, features_0, features_1, frame):
        """
        Single frame process run

        Args:
            label_0 (arr): Numpy array label of first cell type
            label_1 (arr): Numpy array label of second cell type
            centroids_0 (arr): Numpy array of centroid coordinates corresponding to label_0, in pixel space ij.
            centroids_1 (arr): Numpy array of centroid coordinates corresponding to label_1, in pixel space ij.
            features_0 (Dataframe): Dataframe of sparse cell info of first cell type
            features_1 (Dataframe): Dataframe of sparse cell info of second cell type
            frame (int): Frame number, used to select rows of the feature tables and in log messages.

        Returns:
            dict: ``graph`` (networkx graph), ``network_image`` (mesh image on
            the host), ``frame``, ``contact`` (list of bool),
            ``contact_label`` (list of integer arrays) and
            ``closest_cell_dist`` (list of float), one entry per cell of the
            first type in this frame.
        """
        if self.verbose:
            tqdm.write("frame {} thread started".format(frame))
        start_time_0 = time.time()

        # get number of centroids in each cell type
        c_count = [centroids_0.shape[1], centroids_1.shape[1]]
        if self.verbose:
            tqdm.write("Distance matrix scale: {} x {}".format(c_count[0], c_count[1]))

        # generate distance matrix on gpu
        start_time = time.time()
        if self.verbose:
            print("Generating distance matrix: {}".format(frame))
        distance_matrix_device = cle.generate_distance_matrix(centroids_0, centroids_1)
        if self.verbose:
            check_gpu_memory()
            self._log_elapsed(frame, "distance", start_time)

        # relabel the input
        # NOTE(review): RelabelComponentImageFilter sorts labels by object size
        # by default, so relabelled ids may no longer follow the centroid order
        # of the feature tables — confirm this is intended.
        start_time = time.time()
        relabelFilter = sitk.RelabelComponentImageFilter()
        label_sitk_0 = relabelFilter.Execute(sitk.GetImageFromArray(label_0))
        label_0 = sitk.GetArrayFromImage(label_sitk_0)
        label_sitk_1 = relabelFilter.Execute(sitk.GetImageFromArray(label_1))
        label_1 = sitk.GetArrayFromImage(label_sitk_1)
        if self.verbose:
            self._log_elapsed(frame, "relabel", start_time)

        # cell shape measurement by SITK
        # NOTE(review): the statistics computed here are not read anywhere in
        # this method; the step only contributes to the timing printout.
        start_time = time.time()
        statFilter_0 = sitk.LabelShapeStatisticsImageFilter()
        statFilter_0.Execute(label_sitk_0)
        statFilter_1 = sitk.LabelShapeStatisticsImageFilter()
        statFilter_1.Execute(label_sitk_1)
        if self.verbose:
            self._log_elapsed(frame, "labelShapeStat", start_time)

        start_time = time.time()
        if self.verbose:
            print("Generating binary overlap matrix: {}".format(frame))
        overlap_matrix_device = cle.generate_binary_overlap_matrix(label_0, label_1)
        if self.verbose:
            check_gpu_memory()
            print("Generating masked distance matrix: {}".format(frame))
        masked_distance_matrix = cle.multiply_images(overlap_matrix_device, distance_matrix_device)
        if self.verbose:
            check_gpu_memory()

        pointlist = np.concatenate([centroids_0, centroids_1], axis=1)
        masked_host = cle.pull(masked_distance_matrix)[1:, 1:]  # pull once
        distance_matrix_pivot = cle.push(
            self._build_pivot_matrix(masked_host, c_count[0], c_count[1])
        )
        distance_mesh_device = cle.create_labels_like(label_0)
        cle.touch_matrix_to_mesh(pointlist, distance_matrix_pivot, distance_mesh_device)
        if self.verbose:
            self._log_elapsed(frame, "contact analysis", start_time)

        start_time = time.time()
        graph = cle.to_networkx(distance_matrix_pivot, pointlist)  # networkx graph
        if self.verbose:
            self._log_elapsed(frame, "network export", start_time)

        start_time = time.time()
        # pulling data from device to host
        if self.verbose:
            print("Cleaning up GPU: {}".format(frame))
        distance_matrix_host = cle.pull(distance_matrix_device)
        overlap_matrix_host = cle.pull(overlap_matrix_device)
        distance_mesh_host = cle.pull(distance_mesh_device)

        # release GPU memory explicitly (the pushed pivot matrix included)
        distance_matrix_device.data.release()
        overlap_matrix_device.data.release()
        masked_distance_matrix.data.release()
        distance_mesh_device.data.release()
        distance_matrix_pivot.data.release()
        del (distance_matrix_device, overlap_matrix_device, masked_distance_matrix,
             distance_mesh_device, distance_matrix_pivot)

        if self.verbose:
            check_gpu_memory()
            self._log_elapsed(frame, "device to host", start_time)
            self._log_elapsed(frame, "threaded loop", start_time_0)

        # combine the host matrices with the features table
        f_0 = features_0[features_0.frame == frame]
        # Number of second-type cells in earlier frames; identical for every
        # cell of this frame, so it is computed once.
        cell_label_offset = int((features_1.frame < frame).sum())
        contact, contact_label, closest_cell_dist = self._summarize_contacts(
            overlap_matrix_host, distance_matrix_host, len(f_0.index), cell_label_offset
        )

        return {
            "graph": graph,
            "network_image": distance_mesh_host,
            "frame": frame,
            "contact": contact,
            "contact_label": contact_label,
            "closest_cell_dist": closest_cell_dist,
        }

    def __call__(self, data) -> Any:
        """
        Run the contact measurement over every frame of the first feature table.

        Args:
            data (dict): ``data["label"]`` holds two label stacks indexed as
                ``stack[:, :, frame]``; ``data["feature"]`` holds the two
                matching feature tables with ``frame``, ``i`` and ``j`` columns.
                The label stacks are not modified.

        Returns:
            dict: ``"image"`` (per-frame mesh images stacked along the last
            axis), ``"feature"`` (copy of the first feature table with
            ``contact``, ``contacting cell labels`` and ``closest cell dist``
            columns added) and ``"network"`` (one networkx graph per frame).
        """
        labels = data["label"]
        features = data["feature"]

        assert len(features) == 2, "Input features must be 2"

        START_FRAME = features[0].frame.min()
        END_FRAME = features[0].frame.max()
        n_frames = END_FRAME - START_FRAME + 1

        features_out = features[0].copy()
        graph = []
        network_image = None

        # Results are written by row position so they stay aligned with
        # features[0] even if its rows are not sorted by frame.
        frame_of_row = features[0].frame.to_numpy()
        n_rows = len(frame_of_row)
        contact = np.zeros(n_rows, dtype=bool)
        contact_label = [None] * n_rows
        closest_cell_dist = np.full(n_rows, np.nan)

        # Cumulative cell-type-1 counts: f1_cells_before[k] is the number of
        # second-type cells in the first k frames (frames sorted ascending).
        f1_counts = features[1].groupby("frame").size()
        f1_frames = f1_counts.index.to_numpy()
        f1_cells_before = np.concatenate(([0], np.cumsum(f1_counts.to_numpy())))

        pbar = tqdm(range(START_FRAME, END_FRAME + 1), desc="Cross cell contact measurements")
        for frame_idx, CUR_FRAME in enumerate(pbar):
            c = []
            c_count = []
            for i in range(2):
                centroids = features[i][features[i].frame == CUR_FRAME][["i", "j"]].to_numpy().T
                c.append(centroids)
                c_count.append(centroids.shape[1])

            # GPU: distance matrix between all centroid pairs (OpenCL via pyclesperanto)
            distance_matrix_device = cle.generate_distance_matrix(c[0], c[1])

            # Materialise each label slice once and relabel only if the label
            # IDs are non-consecutive (pyclesperanto indexes by label value).
            # The slices are kept local so the caller's label stacks stay
            # untouched and lazy arrays are not computed a second time.
            # NOTE(review): RelabelComponentImageFilter sorts labels by object
            # size by default — confirm the resulting ids still match the
            # centroid order of the feature tables.
            label_slices = []
            for li in range(2):
                lslice = labels[li][:, :, CUR_FRAME]
                if hasattr(lslice, "compute"):
                    lslice = lslice.compute()
                lslice = np.asarray(lslice)
                present = np.unique(lslice)
                # Count non-zero ids directly so a slice without any
                # background pixel is not mistaken for a non-consecutive one.
                n_cells = int(np.count_nonzero(present))
                if n_cells > 0 and present[-1] != n_cells:
                    relabelFilter = sitk.RelabelComponentImageFilter()
                    relabeled = relabelFilter.Execute(sitk.GetImageFromArray(lslice))
                    lslice = sitk.GetArrayFromImage(relabeled)
                label_slices.append(lslice)

            # GPU: binary overlap matrix (pixel-level label contact detection)
            overlap_matrix_device = cle.generate_binary_overlap_matrix(label_slices[0], label_slices[1])

            # GPU: mask distances to only contacting cell pairs
            masked_distance_matrix = cle.multiply_images(overlap_matrix_device, distance_matrix_device)

            pointlist = np.concatenate(c, axis=1)
            masked_host = cle.pull(masked_distance_matrix)[1:, 1:]  # pull once
            distance_matrix_pivot = cle.push(
                self._build_pivot_matrix(masked_host, c_count[0], c_count[1])
            )
            distance_mesh_device = cle.create_labels_like(label_slices[0])
            cle.touch_matrix_to_mesh(pointlist, distance_matrix_pivot, distance_mesh_device)

            # note: graph index may have offset if START_FRAME != 0
            graph.append(cle.to_networkx(distance_matrix_pivot, pointlist))

            # pull results to host
            distance_matrix_host = cle.pull(distance_matrix_device)
            overlap_matrix_host = cle.pull(overlap_matrix_device)
            distance_mesh_host = cle.pull(distance_mesh_device)

            # release GPU memory explicitly (the pushed pivot matrix included)
            distance_matrix_device.data.release()
            overlap_matrix_device.data.release()
            masked_distance_matrix.data.release()
            distance_mesh_device.data.release()
            distance_matrix_pivot.data.release()
            del (distance_matrix_device, overlap_matrix_device, masked_distance_matrix,
                 distance_mesh_device, distance_matrix_pivot)

            # Pre-allocate output image on first frame to avoid repeated concatenation
            if network_image is None:
                network_image = np.empty((*distance_mesh_host.shape, n_frames), dtype=distance_mesh_host.dtype)
            network_image[..., frame_idx] = distance_mesh_host

            # Vectorized per-cell contact extraction for the rows of this frame
            rows = np.flatnonzero(frame_of_row == CUR_FRAME)
            cell_label_offset = int(
                f1_cells_before[np.searchsorted(f1_frames, CUR_FRAME, side="left")]
            )
            frame_contact, frame_labels, frame_closest = self._summarize_contacts(
                overlap_matrix_host, distance_matrix_host, rows.size, cell_label_offset
            )
            contact[rows] = frame_contact
            closest_cell_dist[rows] = frame_closest
            for row, cell_labels in zip(rows, frame_labels):
                contact_label[row] = cell_labels

        features_out["contact"] = contact
        features_out["contacting cell labels"] = contact_label
        features_out["closest cell dist"] = closest_cell_dist

        return {"image": network_image, "feature": features_out, "network": graph}