# Source code for cyto.postprocessing.plots.timeseries

"""
cyto.postprocessing.plots.timeseries
=====================================
Multi-frame statistical plots for cell network analysis.

All functions consume the ``all_results`` list produced by the spatiomics
batch script (one dict per frame) and write figures to an output directory.

Time windows (early / mid / late) are defined by explicit absolute hour
thresholds rather than implicit frame-count thirds, making analyses
reproducible across datasets of different lengths.

No GPU dependencies — importable without pyclesperanto.
"""

import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy import stats


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _period(frame_idx, frame_interval_seconds, early_cutoff_h, mid_cutoff_h):
    """Return the time-period label for a frame index.

    Parameters
    ----------
    frame_idx : int
        Raw frame index (0-based).
    frame_interval_seconds : float
        Physical duration of one frame in seconds.
    early_cutoff_h : float
        Frames with ``time_h < early_cutoff_h`` are labelled 'early'.
    mid_cutoff_h : float
        Frames with ``early_cutoff_h <= time_h < mid_cutoff_h`` are 'mid';
        frames at or beyond ``mid_cutoff_h`` are 'late'.

    Returns
    -------
    str
        One of ``'early'``, ``'mid'``, ``'late'``.
    """
    time_h = frame_idx * frame_interval_seconds / 3600
    if time_h < early_cutoff_h:
        return 'early'
    elif time_h < mid_cutoff_h:
        return 'mid'
    return 'late'


def _period_label(period, early_cutoff_h, mid_cutoff_h):
    """Return a display label for *period* showing only absolute hour boundaries."""
    if period == 'early':
        return f'0–{early_cutoff_h:.0f}h'
    elif period == 'mid':
        return f'{early_cutoff_h:.0f}{mid_cutoff_h:.0f}h'
    return f'≥{mid_cutoff_h:.0f}h'


def _add_significance_bars(ax, positions, data_groups, y_max):
    """Overlay Mann-Whitney U significance bars on a box-plot axes.

    Tests are run pairwise between (early, mid), (mid, late), and
    (early, late). Stars are drawn above the data at staggered heights.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
    positions : list of float
        x-positions of the three box-plot groups.
    data_groups : list of list
        Raw data for each group, in the same order as *positions*.
    y_max : float
        Upper data limit used to anchor the bar heights.
    """
    def _stars(p):
        if p < 0.001:
            return '***'
        elif p < 0.01:
            return '**'
        elif p < 0.05:
            return '*'
        return 'ns'

    bar_height = y_max * 0.05
    comparisons = [(0, 1, 'Early vs Mid'), (1, 2, 'Mid vs Late'), (0, 2, 'Early vs Late')]

    for i, (idx1, idx2, _label) in enumerate(comparisons):
        if idx1 >= len(data_groups) or idx2 >= len(data_groups):
            continue
        if not data_groups[idx1] or not data_groups[idx2]:
            continue
        try:
            _, p_value = stats.mannwhitneyu(
                data_groups[idx1], data_groups[idx2], alternative='two-sided'
            )
            stars = _stars(p_value)
            y_pos = y_max + bar_height * (i + 1)
            x1, x2 = positions[idx1], positions[idx2]
            ax.plot(
                [x1, x1, x2, x2],
                [y_pos, y_pos + bar_height * 0.1, y_pos + bar_height * 0.1, y_pos],
                'k-', linewidth=1,
            )
            ax.text(
                (x1 + x2) / 2, y_pos + bar_height * 0.2,
                stars, ha='center', va='bottom', fontsize=10, fontweight='bold',
            )
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Public plotting functions
# ---------------------------------------------------------------------------

def plot_centrality_histograms(
    all_results,
    output_dir,
    frame_interval_seconds,
    early_cutoff_h,
    mid_cutoff_h,
    n_bins=15,
):
    """Save per-cell-type centrality histograms, stratified by time period.

    One PNG and one SVG are written per cell type
    (``<cell_type>_centrality_histograms.{png,svg}``). Each figure has one
    panel per centrality measure (degree, closeness, betweenness,
    eigenvector). Within a panel, every period is drawn as a percentage
    histogram on shared bins, with a dashed vertical line at its mean.

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict is read as ``result['frame']`` and
        ``result['centrality'][cell_type][measure]``, the latter being a
        mapping whose values are pooled. Missing measures are skipped.
    output_dir : str or Path
        Directory to save figures.
    frame_interval_seconds : float
        Physical duration of one frame in seconds (used for time conversion).
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.
    n_bins : int, optional
        Number of histogram bins (default 15).
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    centrality_types = ['degree', 'closeness', 'betweenness', 'eigenvector']
    period_names = ('early', 'mid', 'late')
    cell_types = ('cancer', 'tcell')

    # data_by_period[period][cell_type][measure] -> flat list of values
    data_by_period = {
        p: {cell: {ct: [] for ct in centrality_types} for cell in cell_types}
        for p in period_names
    }
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            for ct in centrality_types:
                vals = result['centrality'][cell_type].get(ct, {})
                if vals:
                    data_by_period[period][cell_type][ct].extend(vals.values())

    for cell_type in cell_types:
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        try:
            # All four panels share one x-range so they can be compared.
            pooled = [
                v
                for ct in centrality_types
                for p in period_names
                for v in data_by_period[p][cell_type][ct]
            ]
            x_limits = (min(pooled), max(pooled)) if pooled else None

            for ax, ct in zip(axes, centrality_types):
                all_data = [
                    v for p in period_names
                    for v in data_by_period[p][cell_type][ct]
                ]
                if all_data:
                    lo, hi = min(all_data), max(all_data)
                    if lo == hi:
                        # Identical values would collapse every bin edge onto
                        # one point and yield zero-width bars; widen the range
                        # the same way numpy does for a degenerate range.
                        lo, hi = lo - 0.5, hi + 0.5
                    bins = np.linspace(lo, hi, n_bins + 1)
                    centers = (bins[:-1] + bins[1:]) / 2
                    bar_width = bins[1] - bins[0]

                    for period in period_names:
                        d = data_by_period[period][cell_type][ct]
                        if not d:
                            continue
                        counts, _ = np.histogram(d, bins=bins)
                        pcts = counts / len(d) * 100
                        mean_val = np.mean(d)
                        plabel = _period_label(period, early_cutoff_h, mid_cutoff_h)
                        ax.bar(centers, pcts, width=bar_width, alpha=0.6,
                               label=f'{plabel} (μ={mean_val:.3f})',
                               color=colors[period])
                        ax.axvline(mean_val, color=colors[period],
                                   linestyle='--', alpha=0.8, linewidth=2)

                ax.set_title(f'{ct.title()} Centrality')
                ax.set_xlabel('Centrality Value')
                ax.set_ylabel('Percentage (%)')
                ax.set_ylim(0, 100)
                # Identical limits are not a valid range; let matplotlib
                # autoscale in that case.
                if x_limits is not None and x_limits[0] < x_limits[1]:
                    ax.set_xlim(*x_limits)
                if all_data:
                    # A legend on an empty panel only triggers a warning.
                    ax.legend()
                ax.grid(True, alpha=0.3)

            fig.suptitle(f'{cell_type.title()} Cell Centrality Histograms '
                         f'(early<{early_cutoff_h}h, mid<{mid_cutoff_h}h)')
            fig.tight_layout()
            for ext in ('png', 'svg'):
                fig.savefig(
                    os.path.join(output_dir, f'{cell_type}_centrality_histograms.{ext}'),
                    format=ext, bbox_inches='tight', dpi=300,
                )
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
def plot_clustering_histograms(
    all_results,
    output_dir,
    frame_interval_seconds,
    early_cutoff_h,
    mid_cutoff_h,
    n_bins=15,
):
    """Save per-cell-type clustering-coefficient histograms by time period.

    One PNG and one SVG are written per cell type
    (``<cell_type>_clustering_histograms.{png,svg}``). Each period is
    drawn as a percentage histogram on shared bins, with a dashed vertical
    line at its mean.

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict is read as ``result['frame']`` and
        ``result['clustering'][cell_type]``, a mapping whose values are
        pooled. A missing cell type is skipped.
    output_dir : str or Path
        Directory to save figures.
    frame_interval_seconds : float
        Physical duration of one frame in seconds.
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.
    n_bins : int, optional
        Number of histogram bins (default 15).
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    period_names = ('early', 'mid', 'late')
    cell_types = ('cancer', 'tcell')

    data_by_period = {p: {cell: [] for cell in cell_types} for p in period_names}
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            vals = result['clustering'].get(cell_type, {})
            if vals:
                data_by_period[period][cell_type].extend(vals.values())

    for cell_type in cell_types:
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
        try:
            all_data = [v for p in period_names for v in data_by_period[p][cell_type]]
            if all_data:
                lo, hi = min(all_data), max(all_data)
                if lo == hi:
                    # Identical values would collapse every bin edge onto one
                    # point and yield zero-width bars; widen the range the
                    # same way numpy does for a degenerate range.
                    lo, hi = lo - 0.5, hi + 0.5
                bins = np.linspace(lo, hi, n_bins + 1)
                centers = (bins[:-1] + bins[1:]) / 2
                bar_width = bins[1] - bins[0]

                for period in period_names:
                    d = data_by_period[period][cell_type]
                    if not d:
                        continue
                    counts, _ = np.histogram(d, bins=bins)
                    pcts = counts / len(d) * 100
                    mean_val = np.mean(d)
                    plabel = _period_label(period, early_cutoff_h, mid_cutoff_h)
                    ax.bar(centers, pcts, width=bar_width, alpha=0.6,
                           label=f'{plabel} (μ={mean_val:.3f})',
                           color=colors[period])
                    ax.axvline(mean_val, color=colors[period],
                               linestyle='--', alpha=0.8, linewidth=2)

            ax.set_title(f'{cell_type.title()} Cell Clustering Coefficient')
            ax.set_xlabel('Clustering Coefficient')
            ax.set_ylabel('Percentage (%)')
            ax.set_ylim(0, 100)
            if all_data:
                # A legend on an empty axes only triggers a warning.
                ax.legend()
            ax.grid(True, alpha=0.3)

            fig.tight_layout()
            for ext in ('png', 'svg'):
                fig.savefig(
                    os.path.join(output_dir, f'{cell_type}_clustering_histograms.{ext}'),
                    format=ext, bbox_inches='tight', dpi=300,
                )
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
def plot_contact_vs_time(all_results, output_dir, frame_interval_seconds, skip_frames=1):
    """Save a two-panel figure: cross-cell contacts and cell counts vs time.

    The upper panel plots ``contacts['cross_contacts']`` per frame. The
    lower panel plots ``contacts['cancer_count']`` and
    ``contacts['tcell_count']``. The x-axis is
    ``frame * frame_interval_seconds / 3600`` (hours). Output files are
    ``contact_vs_time.png`` and ``contact_vs_time.svg``.

    Parameters
    ----------
    all_results : list of dict
        Frame-level results with keys ``'frame'`` and ``'contacts'``.
    output_dir : str or Path
        Directory to save figures.
    frame_interval_seconds : float
        Physical duration of one frame before any subsampling.
    skip_frames : int, optional
        Subsampling step used when processing frames (default 1).
        Accepted for interface compatibility; this function does not use
        it, because the time axis is derived from each result's own
        ``'frame'`` value.
    """
    hours_per_frame = frame_interval_seconds / 3600
    time_hours = [r['frame'] * frame_interval_seconds / 3600 for r in all_results] \
        if hours_per_frame is not None else []
    cross_contacts = [r['contacts']['cross_contacts'] for r in all_results]
    cancer_counts = [r['contacts']['cancer_count'] for r in all_results]
    tcell_counts = [r['contacts']['tcell_count'] for r in all_results]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(4, 8))
    try:
        ax1.plot(time_hours, cross_contacts, 'o-', color='red', alpha=0.7,
                 linewidth=2, markersize=4)
        ax1.set_xlabel('Time (hours)')
        ax1.set_ylabel('T Cell – Cancer Cell Contacts')
        ax1.set_title('Instantaneous Cross-Cell Contacts vs Time')
        ax1.grid(True, alpha=0.3)

        ax2.plot(time_hours, cancer_counts, 'o-', color='green', alpha=0.7,
                 linewidth=2, markersize=4, label='Cancer Cells')
        ax2.plot(time_hours, tcell_counts, 'o-', color='magenta', alpha=0.7,
                 linewidth=2, markersize=4, label='T Cells')
        ax2.set_xlabel('Time (hours)')
        ax2.set_ylabel('Number of Cells')
        ax2.set_title('Cell Counts vs Time')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        for ext in ('png', 'svg'):
            fig.savefig(os.path.join(output_dir, f'contact_vs_time.{ext}'),
                        format=ext, bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def plot_centrality_boxplots_by_timepoint(
    all_results,
    output_dir,
    frame_interval_seconds,
    early_cutoff_h,
    mid_cutoff_h,
):
    """Save box-and-whisker plots of centrality measures grouped by time period.

    Mann-Whitney U significance bars are added between groups. One PNG and
    one SVG are written per cell type
    (``<cell_type>_centrality_boxplots_timepoints.{png,svg}``). Each
    figure has one panel per centrality measure, and all panels share the
    same y-limits. Individual values are overlaid as jittered points.

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict is read as ``result['frame']`` and
        ``result['centrality'][cell_type][measure]``, the latter being a
        mapping whose values are pooled. Missing measures are skipped.
    output_dir : str or Path
        Directory to save figures.
    frame_interval_seconds : float
        Physical duration of one frame in seconds.
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.

    Notes
    -----
    Periods without data are omitted from a panel. The remaining boxes
    keep their own period's x-position and colour.
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    centrality_types = ['degree', 'closeness', 'betweenness', 'eigenvector']
    period_names = ['early', 'mid', 'late']
    cell_types = ('cancer', 'tcell')

    # data_by_period[period][cell_type][measure] -> flat list of values
    data_by_period = {
        p: {cell: {ct: [] for ct in centrality_types} for cell in cell_types}
        for p in period_names
    }
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            for ct in centrality_types:
                vals = result['centrality'][cell_type].get(ct, {})
                if vals:
                    data_by_period[period][cell_type][ct].extend(vals.values())

    for cell_type in cell_types:
        fig, axes = plt.subplots(1, 4, figsize=(8, 4))
        try:
            all_flat = [
                v
                for p in period_names
                for ct in centrality_types
                for v in data_by_period[p][cell_type][ct]
            ]
            y_lim = None
            if all_flat:
                g_min, g_max = min(all_flat), max(all_flat)
                y_range = g_max - g_min
                # Extra headroom above the data leaves space for the
                # significance brackets.
                y_lim = (g_min - y_range * 0.05, g_max + y_range * 0.4)

            for ax, ct in zip(axes, centrality_types):
                data_groups, labels, positions, group_periods = [], [], [], []
                for j, period in enumerate(period_names):
                    d = data_by_period[period][cell_type][ct]
                    if d:
                        data_groups.append(d)
                        labels.append(_period_label(period, early_cutoff_h, mid_cutoff_h))
                        positions.append(j)
                        # Remember which period each kept group belongs to so
                        # colours stay correct when an earlier period is empty.
                        group_periods.append(period)

                if data_groups:
                    bp = ax.boxplot(data_groups, positions=positions,
                                    widths=0.6, patch_artist=True)
                    for patch, period in zip(bp['boxes'], group_periods):
                        patch.set_facecolor(colors[period])
                        patch.set_alpha(0.7)

                    for d, pos, period in zip(data_groups, positions, group_periods):
                        # Horizontal jitter keeps overlapping points visible.
                        jitter = np.random.uniform(-0.15, 0.15, size=len(d))
                        ax.scatter(np.full_like(d, pos, dtype=float) + jitter, d,
                                   color=colors[period], alpha=0.6, s=20)

                    if y_lim is not None:
                        ax.set_ylim(*y_lim)

                    # Brackets are anchored at this panel's own maximum.
                    local_max = max(v for g in data_groups for v in g)
                    _add_significance_bars(ax, positions, data_groups, local_max)

                ax.set_title(f'{ct.title()} Centrality')
                ax.set_xlabel('Time Period')
                ax.set_ylabel('Centrality Value')
                ax.set_xticks(positions)
                ax.set_xticklabels(labels)
                ax.grid(True, alpha=0.3)

            fig.suptitle(f'{cell_type.title()} Cell Centrality – Time Point Comparison '
                         f'(early<{early_cutoff_h}h, mid<{mid_cutoff_h}h)')
            fig.tight_layout()
            for ext in ('png', 'svg'):
                fig.savefig(
                    os.path.join(output_dir,
                                 f'{cell_type}_centrality_boxplots_timepoints.{ext}'),
                    format=ext, bbox_inches='tight', dpi=300,
                )
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
def plot_clustering_boxplots_by_timepoint(
    all_results,
    output_dir,
    frame_interval_seconds,
    early_cutoff_h,
    mid_cutoff_h,
):
    """Save box-and-whisker plots of clustering coefficients grouped by time period.

    Mann-Whitney U significance bars are added between groups. One PNG and
    one SVG are written per cell type
    (``<cell_type>_clustering_boxplots_timepoints.{png,svg}``). Individual
    values are overlaid as jittered points.

    Parameters
    ----------
    all_results : list of dict
        Frame-level results. Each dict is read as ``result['frame']`` and
        ``result['clustering'][cell_type]``, a mapping whose values are
        pooled. A missing cell type is skipped.
    output_dir : str or Path
        Directory to save figures.
    frame_interval_seconds : float
        Physical duration of one frame in seconds.
    early_cutoff_h, mid_cutoff_h : float
        Absolute hour thresholds separating early / mid / late periods.

    Notes
    -----
    Periods without data are omitted. The remaining boxes keep their own
    period's x-position and colour.
    """
    colors = {'early': '#1f77b4', 'mid': '#ff7f0e', 'late': '#2ca02c'}
    period_names = ['early', 'mid', 'late']
    cell_types = ('cancer', 'tcell')

    data_by_period = {p: {cell: [] for cell in cell_types} for p in period_names}
    for result in all_results:
        period = _period(result['frame'], frame_interval_seconds,
                         early_cutoff_h, mid_cutoff_h)
        for cell_type in cell_types:
            vals = result['clustering'].get(cell_type, {})
            if vals:
                data_by_period[period][cell_type].extend(vals.values())

    for cell_type in cell_types:
        fig, ax = plt.subplots(1, 1, figsize=(2, 4))
        try:
            data_groups, labels, positions, group_periods = [], [], [], []
            for j, period in enumerate(period_names):
                d = data_by_period[period][cell_type]
                if d:
                    data_groups.append(d)
                    labels.append(_period_label(period, early_cutoff_h, mid_cutoff_h))
                    positions.append(j)
                    # Remember which period each kept group belongs to so
                    # colours stay correct when an earlier period is empty.
                    group_periods.append(period)

            if data_groups:
                bp = ax.boxplot(data_groups, positions=positions,
                                widths=0.6, patch_artist=True)
                for patch, period in zip(bp['boxes'], group_periods):
                    patch.set_facecolor(colors[period])
                    patch.set_alpha(0.7)

                for d, pos, period in zip(data_groups, positions, group_periods):
                    # Horizontal jitter keeps overlapping points visible.
                    jitter = np.random.uniform(-0.15, 0.15, size=len(d))
                    ax.scatter(np.full_like(d, pos, dtype=float) + jitter, d,
                               color=colors[period], alpha=0.6, s=20)

                flat = [v for g in data_groups for v in g]
                if flat:
                    y_max, y_min = max(flat), min(flat)
                    y_range = y_max - y_min
                    # Extra headroom above the data leaves space for the
                    # significance brackets.
                    ax.set_ylim(y_min - y_range * 0.05, y_max + y_range * 0.4)
                    _add_significance_bars(ax, positions, data_groups, y_max)

            ax.set_title(f'{cell_type.title()} Clustering – Time Point Comparison')
            ax.set_xlabel('Time Period')
            ax.set_ylabel('Clustering Coefficient')
            ax.set_xticks(positions)
            ax.set_xticklabels(labels)
            ax.grid(True, alpha=0.3)

            fig.tight_layout()
            for ext in ('png', 'svg'):
                fig.savefig(
                    os.path.join(output_dir,
                                 f'{cell_type}_clustering_boxplots_timepoints.{ext}'),
                    format=ext, bbox_inches='tight', dpi=300,
                )
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)