"""
cyto.postprocessing.plots.timeseries
=====================================
Multi-frame statistical plots for cell network analysis.
All functions consume the ``all_results`` list produced by the spatiomics
batch script (one dict per frame) and write figures to an output directory.
Time windows (early / mid / late) are defined by explicit absolute hour
thresholds rather than implicit frame-count thirds, making analyses
reproducible across datasets of different lengths.
No GPU dependencies — importable without pyclesperanto.
"""
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy import stats
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _period(frame_idx, frame_interval_seconds, early_cutoff_h, mid_cutoff_h):
"""Return the time-period label for a frame index.
Parameters
----------
frame_idx : int
Raw frame index (0-based).
frame_interval_seconds : float
Physical duration of one frame in seconds.
early_cutoff_h : float
Frames with ``time_h < early_cutoff_h`` are labelled 'early'.
mid_cutoff_h : float
Frames with ``early_cutoff_h <= time_h < mid_cutoff_h`` are 'mid';
frames at or beyond ``mid_cutoff_h`` are 'late'.
Returns
-------
str
One of ``'early'``, ``'mid'``, ``'late'``.
"""
time_h = frame_idx * frame_interval_seconds / 3600
if time_h < early_cutoff_h:
return 'early'
elif time_h < mid_cutoff_h:
return 'mid'
return 'late'
def _period_label(period, early_cutoff_h, mid_cutoff_h):
"""Return a display label for *period* showing only absolute hour boundaries."""
if period == 'early':
return f'0–{early_cutoff_h:.0f}h'
elif period == 'mid':
return f'{early_cutoff_h:.0f}–{mid_cutoff_h:.0f}h'
return f'≥{mid_cutoff_h:.0f}h'
def _add_significance_bars(ax, positions, data_groups, y_max):
"""Overlay Mann-Whitney U significance bars on a box-plot axes.
Tests are run pairwise between (early, mid), (mid, late), and
(early, late). Stars are drawn above the data at staggered heights.
Parameters
----------
ax : matplotlib.axes.Axes
positions : list of float
x-positions of the three box-plot groups.
data_groups : list of list
Raw data for each group, in the same order as *positions*.
y_max : float
Upper data limit used to anchor the bar heights.
"""
def _stars(p):
if p < 0.001:
return '***'
elif p < 0.01:
return '**'
elif p < 0.05:
return '*'
return 'ns'
bar_height = y_max * 0.05
comparisons = [(0, 1, 'Early vs Mid'), (1, 2, 'Mid vs Late'), (0, 2, 'Early vs Late')]
for i, (idx1, idx2, _label) in enumerate(comparisons):
if idx1 >= len(data_groups) or idx2 >= len(data_groups):
continue
if not data_groups[idx1] or not data_groups[idx2]:
continue
try:
_, p_value = stats.mannwhitneyu(
data_groups[idx1], data_groups[idx2], alternative='two-sided'
)
stars = _stars(p_value)
y_pos = y_max + bar_height * (i + 1)
x1, x2 = positions[idx1], positions[idx2]
ax.plot(
[x1, x1, x2, x2],
[y_pos, y_pos + bar_height * 0.1, y_pos + bar_height * 0.1, y_pos],
'k-', linewidth=1,
)
ax.text(
(x1 + x2) / 2, y_pos + bar_height * 0.2,
stars, ha='center', va='bottom', fontsize=10, fontweight='bold',
)
except Exception:
pass
# ---------------------------------------------------------------------------
# Public plotting functions
# ---------------------------------------------------------------------------
[docs]
def plot_centrality_histograms(
    all_results, output_dir, frame_interval_seconds,
    early_cutoff_h, mid_cutoff_h, n_bins=15,
):
    """Save per-cell-type centrality histograms, stratified by time period.

    For each cell type (``'cancer'``, ``'tcell'``) a 1x4 figure is drawn
    with one panel per centrality measure (degree, closeness,
    betweenness, eigenvector). Within a panel, the early / mid / late
    samples share one set of bins and are plotted as percentage
    histograms with a dashed vertical line at each period's mean. All
    four panels of a figure share the same x-limits (the pooled min/max
    over every measure and period for that cell type).

    One PNG and one SVG are written per cell type
    (``<cell_type>_centrality_histograms.{png,svg}``).

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict must provide ``'frame'`` (frame
        index) and ``'centrality'``, a mapping
        ``cell_type -> centrality_name -> mapping`` whose *values* are
        the scores that get pooled. A missing cell type or centrality
        name is treated as "no data" for that frame.
    output_dir : str or Path
        Directory to save figures; created if it does not exist.
    frame_interval_seconds : float
        Physical duration of one frame in seconds (used for time conversion).
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.
    n_bins : int, optional
        Number of histogram bins (default 15).
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    centrality_types = ['degree', 'closeness', 'betweenness', 'eigenvector']
    period_names = ('early', 'mid', 'late')
    cell_types = ('cancer', 'tcell')

    # Pool the per-frame scores by (period, cell type, centrality measure).
    data_by_period = {
        p: {cell: {ct: [] for ct in centrality_types} for cell in cell_types}
        for p in period_names
    }
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            # Tolerate frames that lack one of the cell types.
            per_cell = result['centrality'].get(cell_type, {})
            for ct in centrality_types:
                vals = per_cell.get(ct, {})
                if vals:
                    data_by_period[period][cell_type][ct].extend(vals.values())

    os.makedirs(output_dir, exist_ok=True)

    for cell_type in cell_types:
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))

        # Shared x-range across all four panels of this figure.
        pooled = [v for ct in centrality_types for p in period_names
                  for v in data_by_period[p][cell_type][ct]]
        x_limits = (min(pooled), max(pooled)) if pooled else None

        for ax, ct in zip(axes, centrality_types):
            all_data = [v for p in period_names
                        for v in data_by_period[p][cell_type][ct]]
            if all_data:
                lo, hi = min(all_data), max(all_data)
                if lo == hi:
                    # All values identical: widen the range so the bins
                    # (and therefore the bars) have non-zero width.
                    lo, hi = lo - 0.5, hi + 0.5
                bins = np.linspace(lo, hi, n_bins + 1)
                centers = (bins[:-1] + bins[1:]) / 2
                bar_width = bins[1] - bins[0]
                for period in period_names:
                    d = data_by_period[period][cell_type][ct]
                    if not d:
                        continue
                    counts, _ = np.histogram(d, bins=bins)
                    # Normalise per period so periods with different
                    # sample sizes are comparable.
                    pcts = counts / len(d) * 100
                    mean_val = np.mean(d)
                    plabel = _period_label(period, early_cutoff_h, mid_cutoff_h)
                    ax.bar(centers, pcts, width=bar_width, alpha=0.6,
                           label=f'{plabel} (μ={mean_val:.3f})',
                           color=colors[period])
                    ax.axvline(mean_val, color=colors[period], linestyle='--',
                               alpha=0.8, linewidth=2)

            ax.set_title(f'{ct.title()} Centrality')
            ax.set_xlabel('Centrality Value')
            ax.set_ylabel('Percentage (%)')
            ax.set_ylim(0, 100)
            # Skip degenerate limits (lo == hi) and let autoscale handle them.
            if x_limits is not None and x_limits[0] < x_limits[1]:
                ax.set_xlim(*x_limits)
            # Only draw a legend when something was plotted; an empty
            # legend call just produces a warning.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()
            ax.grid(True, alpha=0.3)

        fig.suptitle(f'{cell_type.title()} Cell Centrality Histograms '
                     f'(early<{early_cutoff_h}h, mid<{mid_cutoff_h}h)')
        fig.tight_layout()
        for ext in ('png', 'svg'):
            fig.savefig(os.path.join(output_dir, f'{cell_type}_centrality_histograms.{ext}'),
                        format=ext, bbox_inches='tight', dpi=300)
        plt.close(fig)
[docs]
def plot_clustering_histograms(
    all_results, output_dir, frame_interval_seconds,
    early_cutoff_h, mid_cutoff_h, n_bins=15,
):
    """Save per-cell-type clustering-coefficient histograms by time period.

    For each cell type (``'cancer'``, ``'tcell'``) a single-panel figure
    is drawn in which the early / mid / late samples share one set of
    bins and are plotted as percentage histograms, with a dashed
    vertical line at each period's mean.

    One PNG and one SVG are written per cell type
    (``<cell_type>_clustering_histograms.{png,svg}``).

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict must provide ``'frame'`` (frame
        index) and ``'clustering'``, a mapping ``cell_type -> mapping``
        whose *values* are the coefficients that get pooled. A missing
        cell type is treated as "no data" for that frame.
    output_dir : str or Path
        Directory to save figures; created if it does not exist.
    frame_interval_seconds : float
        Physical duration of one frame in seconds (used for time conversion).
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.
    n_bins : int, optional
        Number of histogram bins (default 15).
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    period_names = ('early', 'mid', 'late')
    cell_types = ('cancer', 'tcell')

    # Pool the per-frame coefficients by (period, cell type).
    data_by_period = {p: {cell: [] for cell in cell_types} for p in period_names}
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            vals = result['clustering'].get(cell_type, {})
            if vals:
                data_by_period[period][cell_type].extend(vals.values())

    os.makedirs(output_dir, exist_ok=True)

    for cell_type in cell_types:
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
        all_data = [v for p in period_names
                    for v in data_by_period[p][cell_type]]
        if all_data:
            lo, hi = min(all_data), max(all_data)
            if lo == hi:
                # All values identical: widen the range so the bins
                # (and therefore the bars) have non-zero width.
                lo, hi = lo - 0.5, hi + 0.5
            bins = np.linspace(lo, hi, n_bins + 1)
            centers = (bins[:-1] + bins[1:]) / 2
            bar_width = bins[1] - bins[0]
            for period in period_names:
                d = data_by_period[period][cell_type]
                if not d:
                    continue
                counts, _ = np.histogram(d, bins=bins)
                # Normalise per period so periods with different sample
                # sizes are comparable.
                pcts = counts / len(d) * 100
                mean_val = np.mean(d)
                plabel = _period_label(period, early_cutoff_h, mid_cutoff_h)
                ax.bar(centers, pcts, width=bar_width, alpha=0.6,
                       label=f'{plabel} (μ={mean_val:.3f})',
                       color=colors[period])
                ax.axvline(mean_val, color=colors[period], linestyle='--',
                           alpha=0.8, linewidth=2)

        ax.set_title(f'{cell_type.title()} Cell Clustering Coefficient')
        ax.set_xlabel('Clustering Coefficient')
        ax.set_ylabel('Percentage (%)')
        ax.set_ylim(0, 100)
        # Only draw a legend when something was plotted; an empty legend
        # call just produces a warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        for ext in ('png', 'svg'):
            fig.savefig(os.path.join(output_dir, f'{cell_type}_clustering_histograms.{ext}'),
                        format=ext, bbox_inches='tight', dpi=300)
        plt.close(fig)
[docs]
def plot_centrality_boxplots_by_timepoint(
    all_results, output_dir, frame_interval_seconds,
    early_cutoff_h, mid_cutoff_h,
):
    """Save box-and-whisker plots of centrality measures grouped by time period.

    For each cell type (``'cancer'``, ``'tcell'``) a 1x4 figure is drawn
    with one panel per centrality measure. Each panel shows one box per
    time period that has data, overlaid with the jittered raw points,
    and Mann-Whitney U significance bars between the plotted groups.
    Panels that contain data share the same y-limits, derived from the
    pooled values of all measures for that cell type.

    One PNG and one SVG are written per cell type
    (``<cell_type>_centrality_boxplots_timepoints.{png,svg}``).

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict must provide ``'frame'`` (frame
        index) and ``'centrality'``, a mapping
        ``cell_type -> centrality_name -> mapping`` whose *values* are
        the scores that get pooled. A missing cell type or centrality
        name is treated as "no data" for that frame.
    output_dir : str or Path
        Directory to save figures; created if it does not exist.
    frame_interval_seconds : float
        Physical duration of one frame in seconds (used for time conversion).
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    centrality_types = ['degree', 'closeness', 'betweenness', 'eigenvector']
    period_names = ['early', 'mid', 'late']
    cell_types = ('cancer', 'tcell')

    # Pool the per-frame scores by (period, cell type, centrality measure).
    data_by_period = {
        p: {cell: {ct: [] for ct in centrality_types} for cell in cell_types}
        for p in period_names
    }
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            # Tolerate frames that lack one of the cell types.
            per_cell = result['centrality'].get(cell_type, {})
            for ct in centrality_types:
                vals = per_cell.get(ct, {})
                if vals:
                    data_by_period[period][cell_type][ct].extend(vals.values())

    os.makedirs(output_dir, exist_ok=True)

    for cell_type in cell_types:
        fig, axes = plt.subplots(1, 4, figsize=(8, 4))

        all_flat = [v for p in period_names for ct in centrality_types
                    for v in data_by_period[p][cell_type][ct]]
        y_lim = None
        if all_flat:
            g_min, g_max = min(all_flat), max(all_flat)
            y_range = g_max - g_min
            if y_range == 0:
                # All values identical: fall back to a non-zero padding
                # so the lower and upper limits differ.
                y_range = abs(g_max) or 1.0
            # Extra headroom above the data leaves space for the
            # significance bars.
            y_lim = (g_min - y_range * 0.05, g_max + y_range * 0.4)

        for ax, ct in zip(axes, centrality_types):
            # Periods without data are omitted; `present_periods` keeps
            # track of which period each plotted group belongs to so the
            # colours stay tied to the period, not to the list index.
            data_groups, labels, positions, present_periods = [], [], [], []
            for j, period in enumerate(period_names):
                d = data_by_period[period][cell_type][ct]
                if d:
                    data_groups.append(d)
                    labels.append(_period_label(period, early_cutoff_h, mid_cutoff_h))
                    positions.append(j)
                    present_periods.append(period)

            if data_groups:
                bp = ax.boxplot(data_groups, positions=positions,
                                widths=0.6, patch_artist=True)
                for patch, period in zip(bp['boxes'], present_periods):
                    patch.set_facecolor(colors[period])
                    patch.set_alpha(0.7)
                for d, pos, period in zip(data_groups, positions, present_periods):
                    # Horizontal jitter so overlapping points remain visible.
                    jitter = np.random.uniform(-0.15, 0.15, size=len(d))
                    ax.scatter(pos + jitter, d,
                               color=colors[period], alpha=0.6, s=20)
                if y_lim is not None:
                    ax.set_ylim(*y_lim)
                local_max = max(v for g in data_groups for v in g)
                _add_significance_bars(ax, positions, data_groups, local_max)

            ax.set_title(f'{ct.title()} Centrality')
            ax.set_xlabel('Time Period')
            ax.set_ylabel('Centrality Value')
            ax.set_xticks(positions)
            ax.set_xticklabels(labels)
            ax.grid(True, alpha=0.3)

        fig.suptitle(f'{cell_type.title()} Cell Centrality – Time Point Comparison '
                     f'(early<{early_cutoff_h}h, mid<{mid_cutoff_h}h)')
        fig.tight_layout()
        for ext in ('png', 'svg'):
            fig.savefig(
                os.path.join(output_dir, f'{cell_type}_centrality_boxplots_timepoints.{ext}'),
                format=ext, bbox_inches='tight', dpi=300,
            )
        plt.close(fig)
[docs]
def plot_clustering_boxplots_by_timepoint(
    all_results, output_dir, frame_interval_seconds,
    early_cutoff_h, mid_cutoff_h,
):
    """Save box-and-whisker plots of clustering coefficients grouped by time period.

    For each cell type (``'cancer'``, ``'tcell'``) a single-panel figure
    is drawn with one box per time period that has data, overlaid with
    the jittered raw points, and Mann-Whitney U significance bars
    between the plotted groups.

    One PNG and one SVG are written per cell type
    (``<cell_type>_clustering_boxplots_timepoints.{png,svg}``).

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict must provide ``'frame'`` (frame
        index) and ``'clustering'``, a mapping ``cell_type -> mapping``
        whose *values* are the coefficients that get pooled. A missing
        cell type is treated as "no data" for that frame.
    output_dir : str or Path
        Directory to save figures; created if it does not exist.
    frame_interval_seconds : float
        Physical duration of one frame in seconds (used for time conversion).
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    period_names = ['early', 'mid', 'late']
    cell_types = ('cancer', 'tcell')

    # Pool the per-frame coefficients by (period, cell type).
    data_by_period = {p: {cell: [] for cell in cell_types} for p in period_names}
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            vals = result['clustering'].get(cell_type, {})
            if vals:
                data_by_period[period][cell_type].extend(vals.values())

    os.makedirs(output_dir, exist_ok=True)

    for cell_type in cell_types:
        fig, ax = plt.subplots(1, 1, figsize=(2, 4))

        # Periods without data are omitted; `present_periods` keeps track
        # of which period each plotted group belongs to so the colours
        # stay tied to the period, not to the list index.
        data_groups, labels, positions, present_periods = [], [], [], []
        for j, period in enumerate(period_names):
            d = data_by_period[period][cell_type]
            if d:
                data_groups.append(d)
                labels.append(_period_label(period, early_cutoff_h, mid_cutoff_h))
                positions.append(j)
                present_periods.append(period)

        if data_groups:
            bp = ax.boxplot(data_groups, positions=positions,
                            widths=0.6, patch_artist=True)
            for patch, period in zip(bp['boxes'], present_periods):
                patch.set_facecolor(colors[period])
                patch.set_alpha(0.7)
            for d, pos, period in zip(data_groups, positions, present_periods):
                # Horizontal jitter so overlapping points remain visible.
                jitter = np.random.uniform(-0.15, 0.15, size=len(d))
                ax.scatter(pos + jitter, d,
                           color=colors[period], alpha=0.6, s=20)

            flat = [v for g in data_groups for v in g]
            y_min, y_max = min(flat), max(flat)
            y_range = y_max - y_min
            if y_range == 0:
                # All values identical: fall back to a non-zero padding
                # so the lower and upper limits differ.
                y_range = abs(y_max) or 1.0
            # Extra headroom above the data leaves space for the
            # significance bars.
            ax.set_ylim(y_min - y_range * 0.05, y_max + y_range * 0.4)
            _add_significance_bars(ax, positions, data_groups, y_max)

        ax.set_title(f'{cell_type.title()} Clustering – Time Point Comparison')
        ax.set_xlabel('Time Period')
        ax.set_ylabel('Clustering Coefficient')
        ax.set_xticks(positions)
        ax.set_xticklabels(labels)
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        for ext in ('png', 'svg'):
            fig.savefig(
                os.path.join(output_dir, f'{cell_type}_clustering_boxplots_timepoints.{ext}'),
                format=ext, bbox_inches='tight', dpi=300,
            )
        plt.close(fig)