Skip to content

Plotting & Visualization

The tit.plotting module provides matplotlib-based visualization functions used by the analysis, optimization, and reporting pipelines. All functions are headless-safe and work in Docker/CI environments without a display server.

graph LR
    ANALYZER[Analyzer] --> PLOTS[tit.plotting]
    STATS[Stats Module] --> PLOTS
    REPORTING[Report Generators] --> PLOTS
    PLOTS --> PDF[PDF Figures]
    PLOTS --> PNG[PNG / base64 Images]
    style PLOTS fill:#2d5a27,stroke:#4a8,color:#fff

Focality Histograms

plot_whole_head_roi_histogram

Generates a whole-head field distribution histogram with per-bin ROI contribution color coding. Includes focality cutoff lines, an optional mean ROI field marker, and a summary statistics box.

from tit.plotting import plot_whole_head_roi_histogram

output_path = plot_whole_head_roi_histogram(
    output_dir="/data/project/derivatives/ti-toolbox/analysis/sub-001",
    whole_head_field_data=whole_head_values,   # np.ndarray
    roi_field_data=roi_values,                 # np.ndarray
    whole_head_element_sizes=wh_sizes,         # optional, np.ndarray
    roi_element_sizes=roi_sizes,               # optional, np.ndarray
    filename="TI_max.nii.gz",                  # optional, used for title and output name
    region_name="M1",                          # optional, ROI label
    roi_field_value=0.152,                     # optional, draws vertical marker
    data_type="element",                       # "element" or "voxel"
    voxel_dims=(1.0, 1.0, 1.0),               # optional, for voxel volume weighting
    n_bins=100,
    dpi=600,
)

Returns the path to the saved PDF, or None if input data is empty.

TI Metric Distributions

plot_montage_distributions

Creates three side-by-side histograms showing TImax, TImean, and Focality distributions across montages.

from tit.plotting import plot_montage_distributions

output_path = plot_montage_distributions(
    timax_values=[0.21, 0.34, 0.18],
    timean_values=[0.12, 0.19, 0.09],
    focality_values=[0.85, 0.72, 0.91],
    output_file="/data/output/montage_distributions.png",
    dpi=300,
)

plot_intensity_vs_focality

Scatter plot of intensity versus focality, optionally colored by a composite index.

from tit.plotting import plot_intensity_vs_focality

output_path = plot_intensity_vs_focality(
    intensity=[0.12, 0.19, 0.09, 0.25],
    focality=[0.85, 0.72, 0.91, 0.68],
    composite=[0.48, 0.55, 0.41, 0.60],  # or None
    output_file="/data/output/intensity_vs_focality.png",
    dpi=300,
)

Statistical Plots

plot_permutation_null_distribution

Plots a permutation null distribution histogram with a significance threshold line and markers for observed clusters. Uses seaborn for styling.

from tit.plotting import plot_permutation_null_distribution

output_path = plot_permutation_null_distribution(
    null_distribution=null_dist_array,       # np.ndarray
    threshold=42.0,                          # significance threshold
    observed_clusters=[                      # list of dicts
        {"stat_value": 55.0, "p_value": 0.01},
        {"stat_value": 30.0, "p_value": 0.12},
    ],
    output_file="/data/output/null_distribution.pdf",
    alpha=0.05,
    cluster_stat="size",                     # "size" or "mass"
    dpi=300,
)

plot_cluster_size_mass_correlation

Scatter plot with regression line showing the correlation between cluster size and cluster mass from permutation testing. Annotates Pearson r and p-value.

from tit.plotting import plot_cluster_size_mass_correlation

output_path = plot_cluster_size_mass_correlation(
    cluster_sizes=sizes_array,     # np.ndarray
    cluster_masses=masses_array,   # np.ndarray
    output_file="/data/output/size_mass_correlation.pdf",
    dpi=300,
)

Returns None if fewer than 2 non-zero data points are available.

Static Overlay Images

generate_static_overlay_images

Generates base64-encoded PNG slice images by overlaying a NIfTI field map on a T1 anatomical image. Produces 7 slices per orientation (axial, sagittal, coronal) with neurological convention labels.

from tit.plotting import generate_static_overlay_images

images = generate_static_overlay_images(
    t1_file="/data/project/sub-001/anat/sub-001_T1w.nii.gz",
    overlay_file="/data/project/derivatives/SimNIBS/sub-001/Simulations/TI_max.nii.gz",
    subject_id="001",          # optional
    montage_name="motor",      # optional
    output_dir=None,           # optional, not used for file output
)

# images is a dict with keys: "axial", "sagittal", "coronal"
# Each value is a list of dicts with: "base64", "slice_num", "overlay_voxels"
for entry in images["axial"]:
    print(f"Slice {entry['slice_num']}: {entry['overlay_voxels']} overlay voxels")

Helpers

The tit.plotting._common module provides shared utilities used by all plotting functions.

SaveFigOptions

Frozen dataclass controlling figure save parameters.

from tit.plotting import SaveFigOptions

opts = SaveFigOptions(
    dpi=600,                # default: 600
    bbox_inches="tight",    # default: "tight"
    facecolor="white",      # default: "white"
    edgecolor="none",       # default: "none"
)

ensure_headless_matplotlib_backend

Sets the matplotlib backend to "Agg" (or a specified backend) for headless environments. Should be called before importing matplotlib.pyplot. If a different backend is already active, it is left unchanged rather than overridden.

from tit.plotting import ensure_headless_matplotlib_backend

ensure_headless_matplotlib_backend()          # defaults to "Agg"
ensure_headless_matplotlib_backend("Cairo")   # or specify another backend

savefig_close

Saves a matplotlib Figure to disk and closes it. Uses fig.savefig (not plt.savefig) to avoid global pyplot state issues.

from tit.plotting import savefig_close, SaveFigOptions

path = savefig_close(
    fig,
    "/data/output/figure.pdf",
    fmt="pdf",                           # optional explicit format
    opts=SaveFigOptions(dpi=300),        # optional overrides
)

Lazy Imports

The tit.plotting package uses lazy imports throughout. Importing tit.plotting does not pull in matplotlib, nibabel, seaborn, or scipy. These dependencies are only loaded when a plot function is actually called.

API Reference

Focality

tit.plotting.focality.plot_whole_head_roi_histogram

plot_whole_head_roi_histogram(*, output_dir: str, whole_head_field_data: ndarray, roi_field_data: ndarray, whole_head_element_sizes: ndarray | None = None, roi_element_sizes: ndarray | None = None, filename: str | None = None, region_name: str | None = None, roi_field_value: float | None = None, data_type: str = 'element', voxel_dims: tuple | None = None, n_bins: int = 100, dpi: int = 600) -> str | None

Generate a whole-head histogram with ROI contribution color coding.

Efficient implementation: ROI contribution per bin is computed via vectorized division (no Python loops).

Source code in tit/plotting/focality.py
def plot_whole_head_roi_histogram(
    *,
    output_dir: str,
    whole_head_field_data: np.ndarray,
    roi_field_data: np.ndarray,
    whole_head_element_sizes: np.ndarray | None = None,
    roi_element_sizes: np.ndarray | None = None,
    filename: str | None = None,
    region_name: str | None = None,
    roi_field_value: float | None = None,
    data_type: str = "element",
    voxel_dims: tuple | None = None,
    n_bins: int = 100,
    dpi: int = 600,
) -> str | None:
    """
    Generate a whole-head histogram with ROI contribution color coding.

    Each histogram bar is colored by the fraction of that bin contributed by
    ROI elements (computed via vectorized division, no Python loops). Dashed
    vertical lines mark focality cutoffs at 50/75/90/95% of the whole-head
    99.9th percentile; an optional solid green line marks ``roi_field_value``.
    A summary statistics box and a colorbar for the ROI fraction are added.

    Returns the path of the saved PDF, or None when either input is missing,
    empty, or all-NaN.
    """
    if whole_head_field_data is None or roi_field_data is None:
        return None

    whole_head_field_data = np.asarray(whole_head_field_data)
    roi_field_data = np.asarray(roi_field_data)

    if whole_head_field_data.size == 0 or roi_field_data.size == 0:
        return None

    # Remove NaN values (masks are reused below to keep element sizes aligned).
    wh_mask = ~np.isnan(whole_head_field_data)
    roi_mask = ~np.isnan(roi_field_data)
    whole_head_field_data = whole_head_field_data[wh_mask]
    roi_field_data = roi_field_data[roi_mask]

    if whole_head_field_data.size == 0 or roi_field_data.size == 0:
        return None

    # Optional volume weighting (only if we can do it consistently for both datasets)
    weights_wh = None
    weights_roi = None
    if data_type == "voxel" and voxel_dims is not None:
        voxel_volume = float(np.prod(voxel_dims[:3]))
        weights_wh = np.full(whole_head_field_data.shape, voxel_volume, dtype=float)
        weights_roi = np.full(roi_field_data.shape, voxel_volume, dtype=float)
    elif whole_head_element_sizes is not None and roi_element_sizes is not None:
        # Robust handling: some callers may pass scalar (0-d) "element sizes" in edge
        # cases (e.g., tiny ROIs). In that case, treat it as a uniform weight.
        wh_sizes = np.asarray(whole_head_element_sizes)
        roi_sizes = np.asarray(roi_element_sizes)

        # Broadcast scalars to match data, otherwise apply NaN masks.
        if wh_sizes.ndim == 0:
            wh_sizes = np.full(
                whole_head_field_data.shape, wh_sizes.item(), dtype=float
            )
        else:
            wh_sizes = wh_sizes[wh_mask]

        if roi_sizes.ndim == 0:
            roi_sizes = np.full(roi_field_data.shape, roi_sizes.item(), dtype=float)
        else:
            roi_sizes = roi_sizes[roi_mask]

        # Only apply weights when shapes line up after masking; otherwise
        # silently fall back to unweighted histograms.
        if (
            wh_sizes.shape == whole_head_field_data.shape
            and roi_sizes.shape == roi_field_data.shape
        ):
            weights_wh = wh_sizes
            weights_roi = roi_sizes

    ensure_headless_matplotlib_backend()
    import matplotlib.pyplot as plt

    # Keep these local to the plotting call (avoid global side effects).
    #
    # Note: In minimal Docker/SimNIBS environments, matplotlib can emit very noisy
    # `findfont:` messages when fonts are missing. We suppress that noise in
    # `ensure_headless_matplotlib_backend()`; here we avoid forcing Helvetica (which
    # may not be installed) and provide a reasonable preference order.
    rc = {
        "pdf.fonttype": 42,  # Embed fonts as text (not paths)
        "pdf.use14corefonts": True,
        "font.family": "sans-serif",
        "font.sans-serif": [
            "DejaVu Sans",
            "Liberation Sans",
            "Bitstream Vera Sans",
            "sans-serif",
        ],
        "text.usetex": False,
        "svg.fonttype": "none",
    }

    with plt.rc_context(rc):
        fig, ax = plt.subplots(figsize=(14, 10))

        # Histogram bins based on whole head data
        hist, bin_edges = np.histogram(
            whole_head_field_data, bins=n_bins, weights=weights_wh
        )
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        bin_width = float(bin_edges[1] - bin_edges[0])

        # ROI data binned on the same edges so bins correspond 1:1.
        roi_hist, _ = np.histogram(roi_field_data, bins=bin_edges, weights=weights_roi)

        # Vectorized ROI contribution (empty bins stay 0 instead of NaN/inf).
        roi_contribution = np.divide(
            roi_hist, hist, out=np.zeros_like(hist, dtype=float), where=hist > 0
        )

        # Normalize colors against the 95th percentile of non-zero contributions
        # so a single dominant bin doesn't wash out the color scale.
        non_zero = roi_contribution[roi_contribution > 0]
        if non_zero.size > 0:
            max_contribution = float(max(np.percentile(non_zero, 95), 0.01))
        else:
            max_contribution = 0.01

        normalized = np.clip(roi_contribution / max_contribution, 0, 1)

        # plt.get_cmap: `plt.cm.get_cmap` was removed in matplotlib 3.9.
        rainbow_cmap = plt.get_cmap("rainbow")
        colors = rainbow_cmap(normalized)
        colors[:, 3] = 0.7  # uniform alpha for all bars

        ax.bar(bin_centers, hist, width=bin_width, color=colors, edgecolor="black")

        # Focality cutoffs based on 99.9 percentile
        focality_cutoffs = np.array([50, 75, 90, 95], dtype=float)
        percentile_99_9 = float(np.percentile(whole_head_field_data, 99.9))
        thresholds = (focality_cutoffs / 100.0) * percentile_99_9
        counts = [int(np.count_nonzero(whole_head_field_data >= t)) for t in thresholds]

        colors_lines = ["red", "darkred", "crimson", "maroon"]
        for i, (threshold, cutoff, count) in enumerate(
            zip(thresholds, focality_cutoffs, counts)
        ):
            # Only draw cutoff lines that fall inside the data range.
            if (
                float(np.min(whole_head_field_data))
                <= threshold
                <= float(np.max(whole_head_field_data))
            ):
                ax.axvline(
                    x=threshold,
                    color=colors_lines[i % len(colors_lines)],
                    linestyle="--",
                    linewidth=2,
                    label=f"{int(cutoff)}% of 99.9%ile\n({threshold:.2f} V/m)\nCount: {count:,} {data_type}s",
                )

        # Optional mean-ROI-field marker, also clipped to the data range.
        if roi_field_value is not None and float(
            np.min(whole_head_field_data)
        ) <= float(roi_field_value) <= float(np.max(whole_head_field_data)):
            ax.axvline(
                x=float(roi_field_value),
                color="green",
                linestyle="-",
                linewidth=3,
                label=f"Mean ROI Field\n({float(roi_field_value):.2f} V/m)",
            )

        if ax.get_legend_handles_labels()[0]:
            ax.legend(
                loc="upper left", bbox_to_anchor=(0.02, 0.98), frameon=True, fontsize=11
            )

        ax.set_xlabel("Field Strength (V/m)", fontsize=14)
        ax.set_ylabel(f"{data_type.capitalize()}s", fontsize=14)
        ax.tick_params(axis="both", which="major", labelsize=12)

        title_parts = ["Whole-Head Field Distribution with ROI Contribution"]
        if region_name:
            title_parts.append(f"ROI: {region_name}")
        if filename:
            # Fix: interpolate the actual filename (the placeholder was missing).
            title_parts.append(f"File: {filename}")
        ax.set_title("\n".join(title_parts), fontsize=14)
        ax.grid(True, alpha=0.3)

        # Colorbar for ROI contribution
        sm = plt.cm.ScalarMappable(cmap=rainbow_cmap, norm=plt.Normalize(0, 1))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, shrink=0.7, pad=0.02, aspect=25)
        # Avoid non-ASCII arrows to keep PDF core fonts (Helvetica) warning-free in minimal containers.
        cbar.set_label(
            f"ROI Contribution Fraction\n(Blue->Green->Red, max={max_contribution:.3f})",
            fontsize=12,
        )

        # Stats box
        stats_text = (
            "Whole Head:\n"
            f"Max: {float(np.max(whole_head_field_data)):.2f} V/m\n"
            f"Mean: {float(np.mean(whole_head_field_data)):.2f} V/m\n"
            f"99.9%ile: {float(np.percentile(whole_head_field_data, 99.9)):.2f} V/m\n"
            f"{data_type.capitalize()}s: {whole_head_field_data.size:,}\n\n"
            "ROI:\n"
            f"Mean: {float(np.mean(roi_field_data)):.2f} V/m\n"
            f"Max: {float(np.max(roi_field_data)):.2f} V/m\n"
            f"{data_type.capitalize()}s: {roi_field_data.size:,}"
        )
        ax.text(
            0.98,
            0.98,
            stats_text,
            transform=ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            horizontalalignment="right",
            bbox=dict(boxstyle="square", facecolor="lightyellow"),
        )

        # Output name priority: field filename > region name > generic default.
        if filename:
            base_name = _stem_no_nii_gz(filename)
        elif region_name:
            base_name = f"{region_name}_whole_head_roi"
        else:
            base_name = "whole_head_roi_histogram"

        os.makedirs(output_dir, exist_ok=True)
        hist_file = os.path.join(output_dir, f"{base_name}_histogram.pdf")
        fig.tight_layout()
        return savefig_close(fig, hist_file, fmt="pdf", opts=SaveFigOptions(dpi=dpi))

TI Metrics

tit.plotting.ti_metrics.plot_montage_distributions

plot_montage_distributions(*, timax_values: Sequence[float], timean_values: Sequence[float], focality_values: Sequence[float], output_file: str, dpi: int = 300) -> str | None

Create 3 side-by-side histograms for TImax, TImean and Focality distributions.

Source code in tit/plotting/ti_metrics.py
def plot_montage_distributions(
    *,
    timax_values: Sequence[float],
    timean_values: Sequence[float],
    focality_values: Sequence[float],
    output_file: str,
    dpi: int = 300,
) -> str | None:
    """
    Create 3 side-by-side histograms for TImax, TImean and Focality distributions.

    Empty metric lists are skipped; returns None only when all three are empty,
    otherwise the path of the saved figure.
    """
    if not (timax_values or timean_values or focality_values):
        return None

    ensure_headless_matplotlib_backend()
    import matplotlib.pyplot as plt

    fig, (ax_timax, ax_timean, ax_focality) = plt.subplots(1, 3, figsize=(15, 4))

    panels = zip(
        (timax_values, timean_values, focality_values),
        (ax_timax, ax_timean, ax_focality),
        ("TImax (V/m)", "TImean (V/m)", "Focality"),
        ("TImax Distribution", "TImean Distribution", "Focality Distribution"),
        ("#2196F3", "#4CAF50", "#FF9800"),
    )

    for data, axis, x_label, panel_title, bar_color in panels:
        if not data:
            continue
        axis.hist(data, bins=20, color=bar_color, edgecolor="black", alpha=0.7)
        axis.set_xlabel(x_label, fontsize=12)
        axis.set_ylabel("Frequency", fontsize=12)
        axis.set_title(panel_title, fontsize=14, fontweight="bold")
        axis.grid(axis="y", alpha=0.3)

    fig.tight_layout()
    return savefig_close(fig, output_file, opts=SaveFigOptions(dpi=dpi))

tit.plotting.ti_metrics.plot_intensity_vs_focality

plot_intensity_vs_focality(*, intensity: Sequence[float], focality: Sequence[float], composite: Sequence[float] | None, output_file: str, dpi: int = 300) -> str | None

Scatter plot of intensity vs focality, optionally colored by composite index.

Source code in tit/plotting/ti_metrics.py
def plot_intensity_vs_focality(
    *,
    intensity: Sequence[float],
    focality: Sequence[float],
    composite: Sequence[float] | None,
    output_file: str,
    dpi: int = 300,
) -> str | None:
    """
    Scatter plot of intensity vs focality, optionally colored by composite index.

    Returns the saved file path, or None when either input sequence is empty.
    """
    if not intensity or not focality:
        return None

    ensure_headless_matplotlib_backend()
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(6, 5))

    # Color by the composite index only when at least one value is provided.
    has_composite = bool(composite) and any(c is not None for c in composite)
    if has_composite:
        points = ax.scatter(
            intensity,
            focality,
            c=composite,
            cmap="viridis",
            s=40,
            edgecolor="black",
            alpha=0.7,
        )
        colorbar = fig.colorbar(points, ax=ax)
        colorbar.set_label("Composite Index", fontsize=12)
    else:
        ax.scatter(intensity, focality, s=40, edgecolor="black", alpha=0.7)

    ax.set_xlabel("TImean_ROI (V/m)", fontsize=12)
    ax.set_ylabel("Focality", fontsize=12)
    ax.set_title("Intensity vs Focality", fontsize=14, fontweight="bold")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    return savefig_close(fig, output_file, opts=SaveFigOptions(dpi=dpi))

Statistical Plots

tit.plotting.stats.plot_permutation_null_distribution

plot_permutation_null_distribution(null_distribution: ndarray, threshold: float, observed_clusters: Sequence[Mapping[str, float]], output_file: str, *, alpha: float = 0.05, cluster_stat: str = 'size', dpi: int = 300) -> str

Plot permutation null distribution with threshold and observed clusters.

Source code in tit/plotting/stats.py
def plot_permutation_null_distribution(
    null_distribution: np.ndarray,
    threshold: float,
    observed_clusters: Sequence[Mapping[str, float]],
    output_file: str,
    *,
    alpha: float = 0.05,
    cluster_stat: str = "size",
    dpi: int = 300,
) -> str:
    """
    Plot permutation null distribution with threshold and observed clusters.

    Draws a histogram of the null distribution, a dashed red line at the
    significance ``threshold``, and one vertical line per observed cluster
    (green if significant, orange otherwise). ``cluster_stat`` selects the
    axis labels ("size" vs anything else, treated as mass).

    Each entry of ``observed_clusters`` must provide "stat_value" and may
    provide "p_value"; when "p_value" is absent, significance falls back to
    comparing the statistic against ``threshold``.

    Returns the path of the saved PDF.
    """
    ensure_headless_matplotlib_backend()
    import matplotlib.pyplot as plt

    # Seaborn is optional: the plain-matplotlib branch below is the fallback.
    # (Previously seaborn was imported unconditionally, which made the
    # fallback dead code and broke seaborn-less environments.)
    try:
        import seaborn as sns
    except ImportError:
        sns = None

    if sns is not None:
        sns.set_style("whitegrid")
        sns.set_context("notebook", font_scale=1.0)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Labels based on cluster statistic
    if cluster_stat == "size":
        x_label = "Maximum Cluster Size (voxels)"
        title = "Permutation Null Distribution of Maximum Cluster Sizes"
        threshold_label = f"Discrete Threshold (p<{alpha}): {threshold:.1f} voxels"
    else:
        x_label = "Maximum Cluster Mass (sum of t-statistics)"
        title = "Permutation Null Distribution of Maximum Cluster Mass"
        threshold_label = f"Discrete Threshold (p<{alpha}): {threshold:.2f}"

    # Histogram
    if sns is not None:
        sns.histplot(
            null_distribution,
            bins=200,
            alpha=0.7,
            color="gray",
            edgecolor="black",
            label="Null Distribution",
            ax=ax,
        )
    else:
        ax.hist(
            null_distribution,
            bins=200,
            alpha=0.7,
            color="gray",
            edgecolor="black",
            label="Null Distribution",
        )

    # Threshold line
    ax.axvline(
        threshold, color="red", linestyle="--", linewidth=2, label=threshold_label
    )

    # Observed clusters: label each significance class only once so the
    # legend does not repeat per cluster.
    sig_plotted = False
    nonsig_plotted = False
    for cluster in observed_clusters:
        stat_value = float(cluster["stat_value"])
        p_value = cluster.get("p_value", None)
        if p_value is not None:
            is_significant = float(p_value) < 0.05
        else:
            is_significant = stat_value > threshold

        color = "green" if is_significant else "orange"
        label = None
        if is_significant and not sig_plotted:
            label = "Significant Clusters (p<0.05)"
            sig_plotted = True
        elif (not is_significant) and (not nonsig_plotted):
            label = "Non-significant Clusters (p≥0.05)"
            nonsig_plotted = True

        ax.axvline(
            stat_value, color=color, linestyle="-", linewidth=2, alpha=0.7, label=label
        )

    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.legend(loc="upper right", fontsize=10)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    return savefig_close(fig, output_file, fmt="pdf", opts=SaveFigOptions(dpi=dpi))

tit.plotting.stats.plot_cluster_size_mass_correlation

plot_cluster_size_mass_correlation(cluster_sizes: ndarray, cluster_masses: ndarray, output_file: str, *, dpi: int = 300) -> str | None

Plot correlation between cluster size and cluster mass from permutation null distribution.

Source code in tit/plotting/stats.py
def plot_cluster_size_mass_correlation(
    cluster_sizes: np.ndarray,
    cluster_masses: np.ndarray,
    output_file: str,
    *,
    dpi: int = 300,
) -> str | None:
    """
    Plot correlation between cluster size and cluster mass from permutation null distribution.

    Scatter plot with regression line, annotated with Pearson r, p-value and
    the linear-fit coefficients. Entries where either size or mass is zero
    are dropped before fitting.

    Returns the path of the saved PDF, or None when fewer than 2 non-zero
    data points remain.
    """
    from scipy.stats import pearsonr

    ensure_headless_matplotlib_backend()
    import matplotlib.pyplot as plt

    # Seaborn is optional: the plain-matplotlib branch below is the fallback.
    # (Previously seaborn was imported unconditionally, which made the
    # fallback dead code and broke seaborn-less environments.)
    try:
        import seaborn as sns
    except ImportError:
        sns = None

    if sns is not None:
        sns.set_style("whitegrid")
        sns.set_context("notebook", font_scale=1.0)

    # Remove zeros
    mask = (cluster_sizes > 0) & (cluster_masses > 0)
    sizes_nonzero = cluster_sizes[mask]
    masses_nonzero = cluster_masses[mask]
    if len(sizes_nonzero) < 2:
        return None

    r_value, p_value = pearsonr(sizes_nonzero, masses_nonzero)

    # Single linear fit, shared by the fallback regression line and the
    # annotation box (was previously computed twice).
    z = np.polyfit(sizes_nonzero, masses_nonzero, 1)

    fig, ax = plt.subplots(figsize=(10, 8))

    if sns is not None:
        sns.regplot(
            x=sizes_nonzero,
            y=masses_nonzero,
            ax=ax,
            scatter_kws={
                "alpha": 0.6,
                "s": 50,
                "color": "steelblue",
                "edgecolors": "black",
                "linewidths": 0.5,
            },
            line_kws={"color": "red", "linewidth": 2},
        )
    else:
        ax.scatter(
            sizes_nonzero,
            masses_nonzero,
            alpha=0.6,
            s=50,
            c="steelblue",
            edgecolors="black",
            linewidths=0.5,
        )
        xs = np.linspace(
            float(np.min(sizes_nonzero)), float(np.max(sizes_nonzero)), 100
        )
        ax.plot(xs, z[0] * xs + z[1], color="red", linewidth=2)

    ax.set_xlabel("Maximum Cluster Size (voxels)", fontsize=12, fontweight="bold")
    ax.set_ylabel(
        "Maximum Cluster Mass (sum of t-statistics)", fontsize=12, fontweight="bold"
    )
    ax.set_title(
        f"Cluster Size vs Cluster Mass Correlation\nPearson r = {r_value:.3f} (p = {p_value:.2e})",
        fontsize=14,
        fontweight="bold",
    )

    # Summary annotation in the upper-left corner.
    textstr = (
        f"n = {len(sizes_nonzero)} permutations\n"
        f"r = {r_value:.3f}\n"
        f"p = {p_value:.2e}\n"
        f"Linear fit: y = {z[0]:.2f}x + {z[1]:.2f}"
    )
    props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
    ax.text(
        0.05,
        0.95,
        textstr,
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment="top",
        bbox=props,
    )

    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    return savefig_close(fig, output_file, fmt="pdf", opts=SaveFigOptions(dpi=dpi))

Static Overlays

tit.plotting.static_overlay.generate_static_overlay_images

generate_static_overlay_images(*, t1_file: str, overlay_file: str, subject_id: str | None = None, montage_name: str | None = None, output_dir: str | None = None) -> dict[str, list[dict[str, Any]]]

Generate static overlay images for axial, sagittal, and coronal views.

Returns a dict with keys: 'axial', 'sagittal', 'coronal'. Each value is a list of dicts:
  • base64: base64-encoded PNG
  • slice_num: 1-based slice index within that orientation
  • overlay_voxels: number of non-zero overlay voxels in that slice
Source code in tit/plotting/static_overlay.py
def generate_static_overlay_images(
    *,
    t1_file: str,
    overlay_file: str,
    subject_id: str | None = None,
    montage_name: str | None = None,
    output_dir: str | None = None,
) -> dict[str, list[dict[str, Any]]]:
    """Generate static overlay images for axial, sagittal, and coronal views.

    Returns a dict with keys: 'axial', 'sagittal', 'coronal'. Each value is a list of dicts:
      - base64: base64-encoded PNG
      - slice_num: 1-based slice index within that orientation
      - overlay_voxels: number of non-zero overlay voxels in that slice

    ``subject_id``, ``montage_name`` and ``output_dir`` are accepted for
    interface compatibility but do not affect the generated images.
    """
    # Kept as a local import so importing tit.plotting doesn't require these deps.
    import base64
    import io

    import nibabel as nib
    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.ndimage import zoom

    # Load NIfTI files
    t1_img = nib.load(t1_file)
    overlay_img = nib.load(overlay_file)

    # Get data arrays
    t1_data = t1_img.get_fdata()
    overlay_data = overlay_img.get_fdata()

    # Handle 4D arrays (take first volume)
    if len(overlay_data.shape) == 4:
        overlay_data = overlay_data[..., 0]

    # Resample overlay onto the T1 grid if the shapes differ.
    if t1_data.shape != overlay_data.shape:
        zoom_factors = [t1_data.shape[i] / overlay_data.shape[i] for i in range(3)]
        overlay_data = zoom(overlay_data, zoom_factors, order=1)

    # Get voxel dimensions (spacing) from header
    voxel_sizes = t1_img.header.get_zooms()[:3]  # x, y, z dimensions in mm

    # Normalize T1 data for display (robust 2nd-98th percentile window of
    # non-zero voxels).
    nonzero = t1_data[t1_data > 0]
    if nonzero.size == 0:
        # Degenerate case; keep as-is
        t1_min, t1_max = float(np.min(t1_data)), float(np.max(t1_data))
    else:
        t1_min, t1_max = np.percentile(nonzero, [2, 98])
    denom = (t1_max - t1_min) if (t1_max - t1_min) != 0 else 1.0
    t1_normalized = np.clip((t1_data - t1_min) / denom, 0, 1)

    # Normalize overlay data
    overlay_max = float(np.max(overlay_data))
    if overlay_max > 0:
        overlay_normalized = overlay_data / overlay_max
    else:
        overlay_normalized = overlay_data

    # Create mask for non-zero overlay values
    overlay_mask = overlay_data > (overlay_max * 0.1)  # Show values above 10% of max

    # Get dimensions for slice planning
    dims = t1_data.shape

    # Define slice positions for each orientation (create 7 slices each)
    num_slices = 7

    def safe_slices(dim_size: int, n: int) -> np.ndarray:
        # n evenly spaced indices across the central half of the axis,
        # clamped so the last index stays in bounds.
        start = dim_size // 4
        end = min((dim_size * 3) // 4, dim_size - 1)
        return np.linspace(start, end, n).astype(int)

    slice_positions = {
        "axial": safe_slices(dims[2], num_slices),
        "sagittal": safe_slices(dims[0], num_slices),
        "coronal": safe_slices(dims[1], num_slices),
    }

    # Colormap for overlay (hot). Copy before customizing: mutating the
    # registered `plt.cm.hot` instance leaks global state and is rejected
    # by recent matplotlib versions.
    cmap = plt.get_cmap("hot").copy()
    cmap.set_bad(color=(0, 0, 0, 0))  # transparent for masked values

    # Calculate aspect ratios for each view based on voxel dimensions
    aspects = {
        "axial": voxel_sizes[1] / voxel_sizes[0],  # y/x ratio
        "sagittal": voxel_sizes[2] / voxel_sizes[1],  # z/y ratio
        "coronal": voxel_sizes[2] / voxel_sizes[0],  # z/x ratio
    }

    # Corner annotations (left corner, right corner) per orientation.
    corner_labels = {
        "axial": ("L", "R"),
        "sagittal": ("A", "P"),
        "coronal": ("R", "L"),
    }

    generated_images: dict[str, list[dict[str, Any]]] = {
        "axial": [],
        "sagittal": [],
        "coronal": [],
    }

    orientations = [
        ("axial", 2, aspects["axial"]),  # slice along z-axis
        ("sagittal", 0, aspects["sagittal"]),  # slice along x-axis
        ("coronal", 1, aspects["coronal"]),  # slice along y-axis
    ]

    for orientation, axis, aspect_ratio in orientations:
        for i, slice_pos in enumerate(slice_positions[orientation]):
            # Extract the 2D slice along the orientation's axis.
            t1_slice = np.take(t1_normalized, slice_pos, axis=axis)
            overlay_slice = np.take(overlay_normalized, slice_pos, axis=axis)
            mask_slice = np.take(overlay_mask, slice_pos, axis=axis)

            # Orientation corrections
            t1_slice = np.rot90(t1_slice, k=1)
            overlay_slice = np.rot90(overlay_slice, k=1)
            mask_slice = np.rot90(mask_slice, k=1)
            if orientation == "coronal":
                # Flip for neurological convention
                t1_slice = np.fliplr(t1_slice)
                overlay_slice = np.fliplr(overlay_slice)
                mask_slice = np.fliplr(mask_slice)

            overlay_masked = np.ma.masked_where(~mask_slice, overlay_slice)

            fig, ax = plt.subplots(1, 1, figsize=(4, 4 * aspect_ratio), dpi=100)
            try:
                ax.imshow(
                    t1_slice,
                    cmap="gray",
                    alpha=1.0,
                    aspect=aspect_ratio,
                    vmin=0,
                    vmax=1,
                )

                overlay_voxels = int(np.sum(mask_slice))
                if overlay_voxels > 0:
                    ax.imshow(
                        overlay_masked,
                        cmap=cmap,
                        alpha=0.6,
                        aspect=aspect_ratio,
                        vmin=0,
                        vmax=1,
                    )

                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title(
                    f"{orientation.title()} {i+1}",
                    fontsize=12,
                    fontweight="bold",
                    pad=10,
                )

                # Compact orientation labels in the upper corners.
                left_label, right_label = corner_labels[orientation]
                for x_pos, ha, text in (
                    (0.05, "left", left_label),
                    (0.95, "right", right_label),
                ):
                    ax.text(
                        x_pos,
                        0.95,
                        text,
                        transform=ax.transAxes,
                        fontsize=10,
                        fontweight="bold",
                        color="white",
                        va="top",
                        ha=ha,
                    )

                buf = io.BytesIO()
                # fig.savefig (not plt.savefig): avoids relying on pyplot's
                # "current figure" global state, matching the rest of the module.
                fig.savefig(
                    buf,
                    dpi=100,
                    bbox_inches="tight",
                    facecolor="white",
                    edgecolor="none",
                    format="png",
                )
                buf.seek(0)
                image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
            finally:
                plt.close(fig)

            generated_images[orientation].append(
                {
                    "base64": image_base64,
                    "slice_num": i + 1,
                    "overlay_voxels": overlay_voxels,
                }
            )

    return generated_images

Common Utilities

tit.plotting._common.SaveFigOptions dataclass

SaveFigOptions(dpi: int = 600, bbox_inches: str = 'tight', facecolor: str = 'white', edgecolor: str = 'none')

tit.plotting._common.ensure_headless_matplotlib_backend

ensure_headless_matplotlib_backend(backend: str = 'Agg') -> None

Best-effort backend setup for headless environments.

Important: - This should be called BEFORE importing matplotlib.pyplot. - If a backend is already active, we do not force-change it.

Source code in tit/plotting/_common.py
def ensure_headless_matplotlib_backend(backend: str = "Agg") -> None:
    """
    Best-effort backend setup for headless environments.

    Important:
    - This should be called BEFORE importing matplotlib.pyplot.
    - If a different backend is already active, it is left alone.
    """
    import os
    import matplotlib

    os.environ.setdefault("MPLBACKEND", backend)

    # Silence noisy `findfont:` chatter (safe even if pyplot was already imported).
    suppress_matplotlib_findfont_noise()

    active = str(matplotlib.get_backend() or "")
    same_backend = active.lower() == backend.lower()

    # Only (re)select the backend when none is reported or the requested one
    # is already in effect; never override a different active backend.
    if not active or same_backend:
        matplotlib.use(backend)  # type: ignore[attr-defined]

tit.plotting._common.savefig_close

savefig_close(fig: Any, output_file: str, *, fmt: str | None = None, opts: SaveFigOptions = SaveFigOptions()) -> str

Save a matplotlib Figure and close it.

Uses fig.savefig (not plt.savefig) to avoid relying on global pyplot state.

Source code in tit/plotting/_common.py
def savefig_close(
    fig: Any,
    output_file: str,
    *,
    fmt: str | None = None,
    opts: SaveFigOptions = SaveFigOptions(),
) -> str:
    """
    Save a matplotlib Figure and close it.

    Uses fig.savefig (not plt.savefig) to avoid relying on global pyplot state.
    The figure is closed even when saving fails, so a failed save can no longer
    leak the figure (pyplot keeps open figures alive globally).

    Returns ``output_file``.
    """
    try:
        fig.savefig(
            output_file,
            dpi=opts.dpi,
            bbox_inches=opts.bbox_inches,
            facecolor=opts.facecolor,
            edgecolor=opts.edgecolor,
            format=fmt,
        )
    finally:
        import matplotlib.pyplot as plt

        plt.close(fig)

    return output_file
    return output_file