
Statistics

The statistics module performs cluster-based permutation testing on MNI-space NIfTI volumes produced by the simulation pipeline. It supports two analysis types: group comparison (responders vs. non-responders) and correlation (voxelwise association between field magnitude and a continuous outcome measure, ACES-style).

```mermaid
graph LR
    NIFTI[MNI NIfTI Volumes] --> LOAD[Load Group Data]
    CSV[Subject CSV] --> CFG[Config]
    CFG --> TEST[Voxelwise Test]
    LOAD --> TEST
    TEST --> PERM[Permutation Engine]
    PERM --> SIG[Significant Clusters]
    SIG --> REPORT[Summary Report]
    SIG --> PLOTS[Null Distribution Plots]
    SIG --> MAPS[NIfTI Output Maps]
    style PERM fill:#2d5a27,stroke:#4a8,color:#fff
```

Group Comparison

Compare two groups (e.g., responders vs non-responders) using cluster-based permutation testing with voxelwise t-tests. Supports both unpaired (independent samples) and paired designs.

```python
from tit.stats import GroupComparisonConfig, run_group_comparison

# Load subjects from CSV (columns: subject_id, simulation_name, response)
subjects = GroupComparisonConfig.load_subjects("/data/my_project/subjects.csv")

config = GroupComparisonConfig(
    project_dir="/data/my_project",
    analysis_name="responder_comparison",
    subjects=subjects,
    test_type=GroupComparisonConfig.TestType.UNPAIRED,
    alternative=GroupComparisonConfig.Alternative.TWO_SIDED,
    n_permutations=5000,
    alpha=0.05,
    cluster_threshold=0.05,
    cluster_stat=GroupComparisonConfig.ClusterStat.MASS,
    tissue_type=GroupComparisonConfig.TissueType.GREY,
    group1_name="Responders",
    group2_name="Non-Responders",
)

result = run_group_comparison(config)
print(f"Significant clusters: {result.n_significant_clusters}")
print(f"Significant voxels:   {result.n_significant_voxels}")
print(f"Analysis time:        {result.analysis_time:.1f}s")
print(f"Output directory:     {result.output_dir}")
```

Correlation

Test for voxelwise correlation between electric field magnitude and a continuous outcome measure (e.g., clinical effect size). Supports Pearson and Spearman correlation with optional subject-level weights.

```python
from tit.stats import CorrelationConfig, run_correlation

# Load subjects from CSV (columns: subject_id, simulation_name, effect_size; optional: weight)
subjects = CorrelationConfig.load_subjects("/data/my_project/correlation_subjects.csv")

config = CorrelationConfig(
    project_dir="/data/my_project",
    analysis_name="efield_outcome_correlation",
    subjects=subjects,
    correlation_type=CorrelationConfig.CorrelationType.PEARSON,
    n_permutations=5000,
    alpha=0.05,
    cluster_threshold=0.05,
    cluster_stat=CorrelationConfig.ClusterStat.MASS,
    tissue_type=CorrelationConfig.TissueType.GREY,
    use_weights=True,
    effect_metric="Clinical Improvement",
)

result = run_correlation(config)
print(f"Significant clusters: {result.n_significant_clusters}")
print(f"Significant voxels:   {result.n_significant_voxels}")
```

Configuration

Subject Definitions

Both analysis types define a nested `Subject` dataclass for per-subject metadata. Subjects are typically loaded from a CSV file via the `load_subjects` classmethod.

For group comparison, the CSV must contain the columns `subject_id`, `simulation_name`, and `response` (0 or 1):

```python
subjects = GroupComparisonConfig.load_subjects("/data/my_project/subjects.csv")
```

| Field | Type | Description |
| --- | --- | --- |
| `subject_id` | `str` | Subject identifier (e.g., `"001"`) |
| `simulation_name` | `str` | Name of the simulation to load |
| `response` | `int` | Group assignment: 1 = group 1, 0 = group 2 |
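When building the subject list without `load_subjects`, the same ID normalization applies (the source below strips a `sub-` prefix and a trailing `.0`). The sketch uses a hypothetical local `Subject` stand-in and stdlib `csv` so it runs without `tit` installed:

```python
import csv
import io
from dataclasses import dataclass

# Hypothetical stand-in for GroupComparisonConfig.Subject,
# so this sketch runs standalone.
@dataclass
class Subject:
    subject_id: str
    simulation_name: str
    response: int

def normalize_id(raw: str) -> str:
    """Mirror load_subjects: strip a 'sub-' prefix and a trailing '.0'."""
    sid = str(raw).replace("sub-", "")
    if sid.endswith(".0"):
        sid = sid[:-2]
    return sid

csv_text = """subject_id,simulation_name,response
sub-001,tDCS_sim,1
002.0,tDCS_sim,0
"""

subjects = [
    Subject(normalize_id(r["subject_id"]), r["simulation_name"], int(r["response"]))
    for r in csv.DictReader(io.StringIO(csv_text))
]
print([s.subject_id for s in subjects])  # ['001', '002']
```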

For correlation, the CSV must contain `subject_id`, `simulation_name`, and `effect_size`; a `weight` column is optional:

```python
subjects = CorrelationConfig.load_subjects("/data/my_project/subjects.csv")
```

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| `subject_id` | `str` | required | Subject identifier |
| `simulation_name` | `str` | required | Name of the simulation to load |
| `effect_size` | `float` | required | Continuous outcome measure |
| `weight` | `float` | `1.0` | Subject-level weight |
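Missing or blank weights fall back to `1.0`, matching the optional-weight handling in the `load_subjects` source shown further down. A minimal stdlib sketch of that fallback:

```python
import csv
import io
import math

csv_text = """subject_id,simulation_name,effect_size,weight
001,tDCS_sim,0.42,2.0
002,tDCS_sim,-0.10,
"""

def parse_weight(value) -> float:
    # Mirror load_subjects: an absent, blank, or NaN weight defaults to 1.0.
    try:
        w = float(value)
        return w if math.isfinite(w) else 1.0
    except (TypeError, ValueError):
        return 1.0

rows = list(csv.DictReader(io.StringIO(csv_text)))
weights = [parse_weight(r.get("weight")) for r in rows]
print(weights)  # [2.0, 1.0]
```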

Enums

| Enum | Values | Used By |
| --- | --- | --- |
| `GroupComparisonConfig.TestType` | `UNPAIRED`, `PAIRED` | Group comparison only |
| `GroupComparisonConfig.Alternative` | `TWO_SIDED`, `GREATER`, `LESS` | Group comparison only |
| `CorrelationConfig.CorrelationType` | `PEARSON`, `SPEARMAN` | Correlation only |
| `ClusterStat` | `MASS`, `SIZE` | Both (nested as `GroupComparisonConfig.ClusterStat` / `CorrelationConfig.ClusterStat`) |
| `TissueType` | `GREY`, `WHITE`, `ALL` | Both (nested as `GroupComparisonConfig.TissueType` / `CorrelationConfig.TissueType`) |

Statistical Parameters

These parameters are shared by both GroupComparisonConfig and CorrelationConfig:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `project_dir` | `str` | required | Path to the BIDS project root |
| `analysis_name` | `str` | required | Name for this analysis run |
| `subjects` | `list[Subject]` | required | List of subject definitions |
| `cluster_threshold` | `float` | `0.05` | Uncorrected p-value threshold for cluster formation |
| `cluster_stat` | `ClusterStat` | `MASS` | Cluster statistic: `"mass"` (sum of t-values) or `"size"` (voxel count) |
| `n_permutations` | `int` | `1000` | Number of permutations for the null distribution |
| `alpha` | `float` | `0.05` | Family-wise error rate for cluster significance |
| `n_jobs` | `int` | `-1` | Number of parallel jobs (`-1` = all CPUs) |
| `tissue_type` | `TissueType` | `GREY` | Tissue mask: `"grey"`, `"white"`, or `"all"` |
| `nifti_file_pattern` | `str \| None` | `None` | Custom NIfTI filename pattern (auto-resolved from `tissue_type` if `None`) |
| `atlas_files` | `list[str]` | `[]` | Atlas filenames for overlap analysis |
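The difference between the two `cluster_stat` options can be illustrated on a toy 1-D example (hypothetical numbers; the toolbox works on 3-D volumes and labels clusters with `scipy.ndimage.label`). A cluster is a connected run of voxels whose uncorrected p-value falls below `cluster_threshold`:

```python
import numpy as np

p_values = np.array([0.20, 0.01, 0.03, 0.40, 0.02, 0.30])
t_stats  = np.array([1.0,  3.5,  2.8,  0.5,  3.1,  1.1])
supra = p_values < 0.05  # cluster-forming threshold

# Find runs of supra-threshold voxels (1-D connected components).
padded = np.concatenate(([0], supra.astype(np.int8), [0]))
edges = np.flatnonzero(np.diff(padded))
clusters = list(zip(edges[::2], edges[1::2]))  # (start, stop) pairs

for start, stop in clusters:
    size = stop - start               # ClusterStat.SIZE: voxel count
    mass = t_stats[start:stop].sum()  # ClusterStat.MASS: sum of t-values
    print(f"cluster [{start}:{stop}]  size={size}  mass={mass:.1f}")
```

With `MASS`, a small cluster of strong t-values can outrank a larger cluster of weak ones; with `SIZE`, only extent matters.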

NIfTI File Patterns

The `tissue_type` field determines which NIfTI files are loaded from each subject's simulation output:

| TissueType | Resolved Pattern |
| --- | --- |
| `GREY` | `grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz` |
| `WHITE` | `white_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz` |
| `ALL` | `{simulation_name}_TI_MNI_MNI_TI_max.nii.gz` |

Set `nifti_file_pattern` to override the auto-resolved pattern.
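Pattern resolution amounts to substituting `{simulation_name}` into the template from the table above; a minimal sketch (the `PATTERNS` dict and `tDCS_sim` name are hypothetical, re-created here only to show the substitution):

```python
# Hypothetical re-creation of the pattern table above.
PATTERNS = {
    "GREY": "grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz",
    "WHITE": "white_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz",
    "ALL": "{simulation_name}_TI_MNI_MNI_TI_max.nii.gz",
}

filename = PATTERNS["GREY"].format(simulation_name="tDCS_sim")
print(filename)  # grey_tDCS_sim_TI_MNI_MNI_TI_max.nii.gz
```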

Output

Results are saved to `derivatives/ti-toolbox/stats/<analysis_type>/<analysis_name>/` within the project directory. Both analysis types produce the following files:

| File | Description |
| --- | --- |
| `significant_voxels_mask.nii.gz` | Binary mask of significant voxels |
| `pvalues_map.nii.gz` | Negative log10 p-value map |
| `permutation_null_distribution.pdf` | Null distribution with observed clusters |
| `cluster_size_mass_correlation.pdf` | Cluster size vs. mass scatter plot |
| `analysis_summary.txt` | Text summary of results |
| `permutation_details.txt` | Per-permutation log |
| `*_analysis_*.log` | Timestamped run log |
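Note that `pvalues_map.nii.gz` stores `-log10(p)` rather than raw p-values, so stronger effects appear as larger voxel values. The transform matches the pipeline source shown later on this page:

```python
import numpy as np

p_values = np.array([0.5, 0.05, 0.001])
valid_mask = np.array([True, True, False])  # e.g., third voxel outside the tissue mask

# Same transform as in run_group_comparison / run_correlation:
# a small epsilon avoids log10(0); invalid voxels are zeroed.
log_p = -np.log10(p_values + 1e-10)
log_p[~valid_mask] = 0
print(log_p)  # roughly [0.301, 1.301, 0.0]
```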

Group comparison additionally produces:

| File | Description |
| --- | --- |
| `average_responders.nii.gz` | Mean field map for group 1 |
| `average_non_responders.nii.gz` | Mean field map for group 2 |
| `difference_map.nii.gz` | Group 1 minus group 2 difference map |

Correlation additionally produces:

| File | Description |
| --- | --- |
| `correlation_map.nii.gz` | Full voxelwise correlation map |
| `correlation_map_thresholded.nii.gz` | Correlation map masked to significant voxels |
| `t_statistics_map.nii.gz` | Voxelwise t-statistic map |
| `average_efield.nii.gz` | Mean electric field across all subjects |

Result Dataclasses

GroupComparisonResult

| Field | Type | Description |
| --- | --- | --- |
| `success` | `bool` | Whether the analysis completed |
| `output_dir` | `str` | Path to the output directory |
| `n_responders` | `int` | Number of group 1 subjects |
| `n_non_responders` | `int` | Number of group 2 subjects |
| `n_significant_voxels` | `int` | Count of significant voxels |
| `n_significant_clusters` | `int` | Count of significant clusters |
| `cluster_threshold` | `float` | Cluster statistic threshold from the null distribution |
| `analysis_time` | `float` | Total runtime in seconds |
| `clusters` | `list` | Cluster details (size, MNI center) |
| `log_file` | `str` | Path to the analysis log |

CorrelationResult

| Field | Type | Description |
| --- | --- | --- |
| `success` | `bool` | Whether the analysis completed |
| `output_dir` | `str` | Path to the output directory |
| `n_subjects` | `int` | Number of subjects |
| `n_significant_voxels` | `int` | Count of significant voxels |
| `n_significant_clusters` | `int` | Count of significant clusters |
| `cluster_threshold` | `float` | Cluster statistic threshold from the null distribution |
| `analysis_time` | `float` | Total runtime in seconds |
| `clusters` | `list` | Cluster details (size, MNI center, mean/peak r) |
| `log_file` | `str` | Path to the analysis log |
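The `cluster_threshold` reported in both result dataclasses is the cutoff read off the permutation null distribution. A common way to derive such a cutoff, shown here only to convey the idea (not the toolbox's exact code), is the `1 - alpha` percentile of the max-cluster statistics collected across permutations:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical null distribution: the largest cluster mass from each permutation.
null_dist = rng.gamma(shape=2.0, scale=5.0, size=5000)

alpha = 0.05
# Family-wise threshold: observed clusters beating 95% of permutation maxima.
threshold = np.percentile(null_dist, 100 * (1 - alpha))

observed_masses = np.array([4.2, 18.0, 31.5])  # hypothetical observed clusters
significant = observed_masses > threshold
print(f"threshold={threshold:.2f}, significant={significant.tolist()}")
```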

API Reference

tit.stats.config.GroupComparisonConfig dataclass

GroupComparisonConfig(project_dir: str, analysis_name: str, subjects: list[Subject], test_type: TestType = UNPAIRED, alternative: Alternative = TWO_SIDED, cluster_threshold: float = 0.05, cluster_stat: ClusterStat = MASS, n_permutations: int = 1000, alpha: float = 0.05, n_jobs: int = -1, tissue_type: TissueType = GREY, nifti_file_pattern: str | None = None, group1_name: str = 'Responders', group2_name: str = 'Non-Responders', value_metric: str = 'Current Intensity', atlas_files: list[str] = list())

Configuration for group comparison permutation testing.

Subject dataclass

Subject(subject_id: str, simulation_name: str, response: int)

A single subject in a group comparison analysis.

load_subjects classmethod

load_subjects(csv_path: str) -> list[Subject]

Load group comparison subjects from a CSV file.

Expected columns: subject_id, simulation_name, response (0 or 1).

Source code in tit/stats/config.py
@classmethod
def load_subjects(cls, csv_path: str) -> list["GroupComparisonConfig.Subject"]:
    """Load group comparison subjects from a CSV file.

    Expected columns: subject_id, simulation_name, response (0 or 1).
    """
    import pandas as pd

    df = pd.read_csv(csv_path)
    required = {"subject_id", "simulation_name", "response"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"CSV missing required columns: {missing}")

    subjects = []
    for _, row in df.iterrows():
        sid = str(row["subject_id"]).replace("sub-", "")
        if sid.endswith(".0"):
            sid = sid[:-2]
        subjects.append(
            cls.Subject(
                subject_id=sid,
                simulation_name=str(row["simulation_name"]),
                response=int(row["response"]),
            )
        )
    return subjects

tit.stats.config.GroupComparisonResult dataclass

GroupComparisonResult(success: bool, output_dir: str, n_responders: int, n_non_responders: int, n_significant_voxels: int, n_significant_clusters: int, cluster_threshold: float, analysis_time: float, clusters: list, log_file: str)

Result of a group comparison permutation test.

tit.stats.config.CorrelationConfig dataclass

CorrelationConfig(project_dir: str, analysis_name: str, subjects: list[Subject], correlation_type: CorrelationType = PEARSON, cluster_threshold: float = 0.05, cluster_stat: ClusterStat = MASS, n_permutations: int = 1000, alpha: float = 0.05, n_jobs: int = -1, use_weights: bool = True, tissue_type: TissueType = GREY, nifti_file_pattern: str | None = None, effect_metric: str = 'Effect Size', field_metric: str = 'Electric Field Magnitude', atlas_files: list[str] = list())

Configuration for correlation-based permutation testing.

Subject dataclass

Subject(subject_id: str, simulation_name: str, effect_size: float, weight: float = 1.0)

A single subject in a correlation analysis.

load_subjects classmethod

load_subjects(csv_path: str) -> list[Subject]

Load correlation subjects from a CSV file.

Expected columns: subject_id, simulation_name, effect_size. Optional column: weight.

Source code in tit/stats/config.py
@classmethod
def load_subjects(cls, csv_path: str) -> list["CorrelationConfig.Subject"]:
    """Load correlation subjects from a CSV file.

    Expected columns: subject_id, simulation_name, effect_size.
    Optional column: weight.
    """
    import pandas as pd

    df = pd.read_csv(csv_path)
    required = {"subject_id", "simulation_name", "effect_size"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"CSV missing required columns: {missing}")

    has_weights = "weight" in df.columns
    subjects = []
    for _, row in df.iterrows():
        if pd.isna(row["subject_id"]) or pd.isna(row["effect_size"]):
            continue

        sid = row["subject_id"]
        if isinstance(sid, float):
            sid = str(int(sid))
        else:
            sid = str(sid).replace("sub-", "")
            if sid.endswith(".0"):
                sid = sid[:-2]

        weight = (
            float(row["weight"])
            if has_weights and pd.notna(row.get("weight"))
            else 1.0
        )
        subjects.append(
            cls.Subject(
                subject_id=sid,
                simulation_name=str(row["simulation_name"]),
                effect_size=float(row["effect_size"]),
                weight=weight,
            )
        )

    if not subjects:
        raise ValueError("No valid subjects found in CSV")
    return subjects

tit.stats.config.CorrelationResult dataclass

CorrelationResult(success: bool, output_dir: str, n_subjects: int, n_significant_voxels: int, n_significant_clusters: int, cluster_threshold: float, analysis_time: float, clusters: list, log_file: str)

Result of a correlation permutation test.

tit.stats.permutation.run_group_comparison

run_group_comparison(config: GroupComparisonConfig, callback_handler=None, stop_callback=None) -> GroupComparisonResult

Run cluster-based permutation testing for group comparison.

Parameters

config : GroupComparisonConfig
    Fully specified configuration.
callback_handler : logging.Handler, optional
    GUI log handler.
stop_callback : callable, optional
    Returns True to abort.

Source code in tit/stats/permutation.py
def run_group_comparison(
    config: GroupComparisonConfig,
    callback_handler=None,
    stop_callback=None,
) -> GroupComparisonResult:
    """Run cluster-based permutation testing for group comparison.

    Parameters
    ----------
    config : GroupComparisonConfig
        Fully specified configuration.
    callback_handler : logging.Handler, optional
        GUI log handler.
    stop_callback : callable, optional
        Returns True to abort.
    """
    t0 = time.time()
    output_dir = _resolve_output_dir(
        config.project_dir,
        "group_comparison",
        config.analysis_name,
    )
    log, log_file = _setup_logger(output_dir, "group_comparison", callback_handler)

    log.info("=" * 70)
    log.info("CLUSTER-BASED PERMUTATION TESTING — GROUP COMPARISON")
    log.info("=" * 70)
    log.info("Analysis: %s", config.analysis_name)
    log.info("Output:   %s", output_dir)
    log.info(
        "Config:   test=%s  alt=%s  stat=%s  threshold=%.3f  perms=%d  alpha=%.3f  jobs=%d",
        config.test_type.value,
        config.alternative.value,
        config.cluster_stat.value,
        config.cluster_threshold,
        config.n_permutations,
        config.alpha,
        config.n_jobs,
    )

    # ── 1. Load data ─────────────────────────────────────────────────────
    log.info("[1/8] Loading subject data")
    step = time.time()

    resp_configs = [
        {"subject_id": s.subject_id, "simulation_name": s.simulation_name}
        for s in config.subjects
        if s.response == 1
    ]
    non_resp_configs = [
        {"subject_id": s.subject_id, "simulation_name": s.simulation_name}
        for s in config.subjects
        if s.response == 0
    ]

    responders, template_img, resp_ids = load_group_data_ti_toolbox(
        resp_configs,
        nifti_file_pattern=config.nifti_file_pattern,
        dtype=np.float32,
    )
    non_responders, _, non_resp_ids = load_group_data_ti_toolbox(
        non_resp_configs,
        nifti_file_pattern=config.nifti_file_pattern,
        dtype=np.float32,
    )

    log.info(
        "Loaded %d %s: %s",
        len(resp_ids),
        config.group1_name,
        resp_ids,
    )
    log.info(
        "Loaded %d %s: %s",
        len(non_resp_ids),
        config.group2_name,
        non_resp_ids,
    )
    log.info("Image shape: %s  (%.1fs)", responders.shape[:3], time.time() - step)

    if stop_callback and stop_callback():
        raise KeyboardInterrupt("Stopped by user")

    # ── 2. Voxelwise t-test ──────────────────────────────────────────────
    log.info("[2/8] Voxelwise statistical tests")
    step = time.time()

    p_values, t_statistics, valid_mask = ttest_voxelwise(
        responders,
        non_responders,
        test_type=config.test_type.value,
        alternative=config.alternative.value,
        log=log,
    )

    log.info(
        "Min p=%.2e, p<0.05: %d  (%.1fs)",
        np.min(p_values[valid_mask]),
        np.sum((p_values < 0.05) & valid_mask),
        time.time() - step,
    )

    if stop_callback and stop_callback():
        raise KeyboardInterrupt("Stopped by user")

    # ── 3. Permutation correction ────────────────────────────────────────
    log.info(
        "[3/8] Cluster-based permutation correction (%d perms)", config.n_permutations
    )
    step = time.time()

    perm_log_file = os.path.join(output_dir, "permutation_details.txt")

    engine = PermutationEngine(
        cluster_threshold=config.cluster_threshold,
        n_permutations=config.n_permutations,
        alpha=config.alpha,
        cluster_stat=config.cluster_stat.value,
        alternative=config.alternative.value,
        n_jobs=config.n_jobs,
        log=log,
    )
    sig_mask, cluster_threshold, sig_clusters, null_dist, all_clusters, corr_data = (
        engine.correct_groups(
            responders,
            non_responders,
            p_values=p_values,
            t_statistics=t_statistics,
            valid_mask=valid_mask,
            test_type=config.test_type.value,
            perm_log_file=perm_log_file,
            subject_ids_resp=resp_ids,
            subject_ids_non_resp=non_resp_ids,
        )
    )

    log.info(
        "Significant clusters: %d, voxels: %d  (%.1fs)",
        len(sig_clusters),
        np.sum(sig_mask),
        time.time() - step,
    )

    # ── 4. Cluster analysis ──────────────────────────────────────────────
    log.info("[4/8] Cluster analysis")
    clusters = cluster_analysis(sig_mask, template_img.affine, log=log)

    # ── 5. Plots ─────────────────────────────────────────────────────────
    log.info("[5/8] Generating plots")
    perm_plot = os.path.join(output_dir, "permutation_null_distribution.pdf")
    plot_permutation_null_distribution(
        null_dist,
        cluster_threshold,
        all_clusters,
        perm_plot,
        alpha=config.alpha,
        cluster_stat=config.cluster_stat.value,
    )
    corr_plot = os.path.join(output_dir, "cluster_size_mass_correlation.pdf")
    plot_cluster_size_mass_correlation(
        corr_data["sizes"],
        corr_data["masses"],
        corr_plot,
    )

    # ── 6. Average maps ─────────────────────────────────────────────────
    log.info("[6/8] Average intensity maps")
    avg_resp = np.mean(responders, axis=-1).astype(np.float32)
    _save_nifti(
        avg_resp, template_img, os.path.join(output_dir, "average_responders.nii.gz")
    )

    avg_non = np.mean(non_responders, axis=-1).astype(np.float32)
    _save_nifti(
        avg_non, template_img, os.path.join(output_dir, "average_non_responders.nii.gz")
    )

    diff = (avg_resp - avg_non).astype(np.float32)
    _save_nifti(diff, template_img, os.path.join(output_dir, "difference_map.nii.gz"))

    # ── 7. Atlas overlap ─────────────────────────────────────────────────
    log.info("[7/8] Atlas overlap")
    atlas_results = {}
    if config.atlas_files:
        if _ATLAS_DIR.exists():
            atlas_results = atlas_overlap_analysis(
                sig_mask,
                config.atlas_files,
                str(_ATLAS_DIR),
                reference_img=template_img,
            )

    # ── 8. Save outputs ─────────────────────────────────────────────────
    log.info("[8/8] Saving results")
    _save_nifti(
        sig_mask.astype(np.uint8),
        template_img,
        os.path.join(output_dir, "significant_voxels_mask.nii.gz"),
    )

    log_p = -np.log10(p_values + 1e-10)
    log_p[~valid_mask] = 0
    _save_nifti(log_p, template_img, os.path.join(output_dir, "pvalues_map.nii.gz"))

    summary_path = os.path.join(output_dir, "analysis_summary.txt")
    generate_summary(
        config,
        responders,
        non_responders,
        sig_mask,
        cluster_threshold,
        clusters,
        atlas_results,
        summary_path,
    )

    total = time.time() - t0
    log.info(
        "COMPLETE in %.1fs — %d sig clusters, %d sig voxels",
        total,
        len(sig_clusters),
        np.sum(sig_mask),
    )

    # Cleanup
    del responders, non_responders, p_values, t_statistics
    gc.collect()
    for h in log.handlers[:]:
        h.close()
        log.removeHandler(h)

    return GroupComparisonResult(
        success=True,
        output_dir=output_dir,
        n_responders=len(resp_ids),
        n_non_responders=len(non_resp_ids),
        n_significant_voxels=int(np.sum(sig_mask)),
        n_significant_clusters=len(sig_clusters),
        cluster_threshold=float(cluster_threshold),
        analysis_time=total,
        clusters=clusters,
        log_file=log_file,
    )

tit.stats.permutation.run_correlation

run_correlation(config: CorrelationConfig, callback_handler=None, stop_callback=None) -> CorrelationResult

Run cluster-based permutation testing for correlation (ACES-style).

Parameters

config : CorrelationConfig
    Fully specified configuration.
callback_handler : logging.Handler, optional
    GUI log handler.
stop_callback : callable, optional
    Returns True to abort.

Source code in tit/stats/permutation.py
def run_correlation(
    config: CorrelationConfig,
    callback_handler=None,
    stop_callback=None,
) -> CorrelationResult:
    """Run cluster-based permutation testing for correlation (ACES-style).

    Parameters
    ----------
    config : CorrelationConfig
        Fully specified configuration.
    callback_handler : logging.Handler, optional
        GUI log handler.
    stop_callback : callable, optional
        Returns True to abort.
    """
    t0 = time.time()
    output_dir = _resolve_output_dir(
        config.project_dir,
        "correlation",
        config.analysis_name,
    )
    log, log_file = _setup_logger(output_dir, "correlation", callback_handler)

    log.info("=" * 70)
    log.info("CORRELATION-BASED CLUSTER PERMUTATION TESTING (ACES-style)")
    log.info("=" * 70)
    log.info("Analysis: %s", config.analysis_name)
    log.info("Output:   %s", output_dir)
    log.info(
        "Config:   corr=%s  stat=%s  threshold=%.3f  perms=%d  alpha=%.3f  jobs=%d",
        config.correlation_type.value,
        config.cluster_stat.value,
        config.cluster_threshold,
        config.n_permutations,
        config.alpha,
        config.n_jobs,
    )

    # ── 1. Load data ─────────────────────────────────────────────────────
    log.info("[1/7] Loading subject data")
    step = time.time()

    subject_dicts = [
        {"subject_id": s.subject_id, "simulation_name": s.simulation_name}
        for s in config.subjects
    ]
    subject_data, template_img, subject_ids = load_group_data_ti_toolbox(
        subject_dicts,
        nifti_file_pattern=config.nifti_file_pattern,
        dtype=np.float32,
    )

    # Build effect sizes / weights aligned with loaded subjects
    config_lookup = {s.subject_id: s for s in config.subjects}
    effect_sizes = np.array(
        [config_lookup[sid].effect_size for sid in subject_ids],
        dtype=np.float64,
    )
    weights = None
    if config.use_weights:
        weights = np.array(
            [config_lookup[sid].weight for sid in subject_ids],
            dtype=np.float64,
        )

    n_subjects = len(subject_ids)
    log.info("Loaded %d subjects: %s", n_subjects, subject_ids)
    log.info(
        "Effect sizes: mean=%.3f, std=%.3f, range=[%.3f, %.3f]",
        np.mean(effect_sizes),
        np.std(effect_sizes),
        np.min(effect_sizes),
        np.max(effect_sizes),
    )
    log.info("Data shape: %s  (%.1fs)", subject_data.shape[:3], time.time() - step)

    if stop_callback and stop_callback():
        raise KeyboardInterrupt("Stopped by user")

    # ── 2. Voxelwise correlation ─────────────────────────────────────────
    log.info("[2/7] Voxelwise correlation")
    step = time.time()

    r_values, t_statistics, p_values, valid_mask = correlation_voxelwise(
        subject_data,
        effect_sizes,
        weights=weights,
        correlation_type=config.correlation_type.value,
        log=log,
    )

    log.info("Correlation computed in %.1fs", time.time() - step)

    if stop_callback and stop_callback():
        raise KeyboardInterrupt("Stopped by user")

    # ── 3. Permutation correction ────────────────────────────────────────
    log.info(
        "[3/7] Cluster-based permutation correction (%d perms)", config.n_permutations
    )
    step = time.time()

    perm_log_file = os.path.join(output_dir, "permutation_details.txt")

    engine = PermutationEngine(
        cluster_threshold=config.cluster_threshold,
        n_permutations=config.n_permutations,
        alpha=config.alpha,
        cluster_stat=config.cluster_stat.value,
        alternative="two-sided",
        n_jobs=config.n_jobs,
        log=log,
    )
    sig_mask, cluster_threshold, sig_clusters, null_dist, all_clusters, corr_data = (
        engine.correct_correlation(
            subject_data,
            effect_sizes,
            r_values=r_values,
            t_statistics=t_statistics,
            p_values=p_values,
            valid_mask=valid_mask,
            correlation_type=config.correlation_type.value,
            weights=weights,
            perm_log_file=perm_log_file,
            subject_ids=subject_ids,
        )
    )

    log.info(
        "Significant clusters: %d, voxels: %d  (%.1fs)",
        len(sig_clusters),
        np.sum(sig_mask),
        time.time() - step,
    )

    # ── 4. Cluster analysis ──────────────────────────────────────────────
    log.info("[4/7] Cluster analysis")
    clusters = cluster_analysis(sig_mask, template_img.affine, log=log)

    # Annotate with correlation stats
    from scipy.ndimage import label as scipy_label

    labeled, _ = scipy_label(sig_mask)
    for c in clusters:
        c_mask = labeled == c["cluster_id"]
        c["mean_r"] = float(np.mean(r_values[c_mask]))
        c["peak_r"] = float(np.max(r_values[c_mask]))

    # ── 5. Plots ─────────────────────────────────────────────────────────
    log.info("[5/7] Generating plots")
    perm_plot = os.path.join(output_dir, "permutation_null_distribution.pdf")
    plot_permutation_null_distribution(
        null_dist,
        cluster_threshold,
        all_clusters,
        perm_plot,
        alpha=config.alpha,
        cluster_stat=config.cluster_stat.value,
    )
    if len(corr_data["sizes"]) > 0:
        corr_plot = os.path.join(output_dir, "cluster_size_mass_correlation.pdf")
        plot_cluster_size_mass_correlation(
            corr_data["sizes"],
            corr_data["masses"],
            corr_plot,
        )

    # ── 6. Atlas overlap ─────────────────────────────────────────────────
    log.info("[6/7] Atlas overlap")
    atlas_results = {}
    if config.atlas_files:
        if _ATLAS_DIR.exists():
            atlas_results = atlas_overlap_analysis(
                sig_mask,
                config.atlas_files,
                str(_ATLAS_DIR),
                reference_img=template_img,
            )

    # ── 7. Save outputs ──────────────────────────────────────────────────
    log.info("[7/7] Saving results")

    _save_nifti(
        sig_mask.astype(np.uint8),
        template_img,
        os.path.join(output_dir, "significant_voxels_mask.nii.gz"),
    )
    _save_nifti(
        r_values.astype(np.float32),
        template_img,
        os.path.join(output_dir, "correlation_map.nii.gz"),
    )
    _save_nifti(
        t_statistics.astype(np.float32),
        template_img,
        os.path.join(output_dir, "t_statistics_map.nii.gz"),
    )

    log_p = -np.log10(p_values + 1e-10)
    log_p[~valid_mask] = 0
    _save_nifti(log_p, template_img, os.path.join(output_dir, "pvalues_map.nii.gz"))

    r_thresh = r_values.copy()
    r_thresh[sig_mask == 0] = 0
    _save_nifti(
        r_thresh.astype(np.float32),
        template_img,
        os.path.join(output_dir, "correlation_map_thresholded.nii.gz"),
    )

    avg = np.mean(subject_data, axis=-1).astype(np.float32)
    _save_nifti(avg, template_img, os.path.join(output_dir, "average_efield.nii.gz"))

    summary_path = os.path.join(output_dir, "analysis_summary.txt")
    generate_correlation_summary(
        config,
        subject_data,
        effect_sizes,
        r_values,
        sig_mask,
        cluster_threshold,
        clusters,
        atlas_results,
        summary_path,
        subject_ids=subject_ids,
        weights=weights,
    )

    total = time.time() - t0
    log.info(
        "COMPLETE in %.1fs — %d sig clusters, %d sig voxels",
        total,
        len(sig_clusters),
        np.sum(sig_mask),
    )

    # Cleanup
    del subject_data, effect_sizes, weights, t_statistics, p_values
    gc.collect()
    for h in log.handlers[:]:
        h.close()
        log.removeHandler(h)

    return CorrelationResult(
        success=True,
        output_dir=output_dir,
        n_subjects=n_subjects,
        n_significant_voxels=int(np.sum(sig_mask)),
        n_significant_clusters=len(sig_clusters),
        cluster_threshold=float(cluster_threshold),
        analysis_time=total,
        clusters=clusters,
        log_file=log_file,
    )

tit.stats.engine.PermutationEngine

PermutationEngine(*, cluster_threshold: float = 0.05, n_permutations: int = 1000, alpha: float = 0.05, cluster_stat: str = 'mass', alternative: str = 'two-sided', n_jobs: int = -1, log: Logger | None = None)

Master class for cluster-based permutation testing.

Stores all control parameters as self.* so methods only need data arrays.

Source code in tit/stats/engine.py
def __init__(
    self,
    *,
    cluster_threshold: float = 0.05,
    n_permutations: int = 1000,
    alpha: float = 0.05,
    cluster_stat: str = "mass",
    alternative: str = "two-sided",
    n_jobs: int = -1,
    log: logging.Logger | None = None,
):
    self.cluster_threshold = cluster_threshold
    self.n_permutations = n_permutations
    self.alpha = alpha
    self.cluster_stat = cluster_stat
    self.alternative = alternative
    self.n_jobs = n_jobs
    self._log = log or logger

correct_groups

correct_groups(responders, non_responders, *, p_values, t_statistics, valid_mask, test_type: str = 'unpaired', perm_log_file: str | None = None, subject_ids_resp: list | None = None, subject_ids_non_resp: list | None = None) -> tuple

Cluster-based permutation correction for group comparison.

Returns (sig_mask, threshold, sig_clusters, null_dist, observed_clusters, correlation_data).

Source code in tit/stats/engine.py
def correct_groups(
    self,
    responders,
    non_responders,
    *,
    p_values,
    t_statistics,
    valid_mask,
    test_type: str = "unpaired",
    perm_log_file: str | None = None,
    subject_ids_resp: list | None = None,
    subject_ids_non_resp: list | None = None,
) -> tuple:
    """Cluster-based permutation correction for group comparison.

    Returns ``(sig_mask, threshold, sig_clusters, null_dist, observed_clusters,
    correlation_data)``.
    """
    from .io_utils import save_permutation_details

    if self.cluster_stat == "mass" and t_statistics is None:
        raise ValueError("t_statistics required when cluster_stat='mass'")

    self._log.info(
        "Cluster-based permutation correction (%s, %s)",
        self.cluster_stat,
        self.alternative,
    )

    initial_mask = (p_values < self.cluster_threshold) & valid_mask
    labeled_array, n_clusters = label(initial_mask)
    self._log.info(
        "Clusters at p<%.3f (uncorrected): %d",
        self.cluster_threshold,
        n_clusters,
    )

    empty = {"sizes": np.array([]), "masses": np.array([])}
    if n_clusters == 0:
        self._log.warning(
            "No clusters found. Consider increasing cluster_threshold."
        )
        return (
            np.zeros_like(p_values, dtype=int),
            0,
            [],
            np.array([]),
            [],
            empty,
        )

    # Pre-extract voxel data
    all_data = np.concatenate([responders, non_responders], axis=-1)
    n_resp = responders.shape[-1]
    n_total = n_resp + non_responders.shape[-1]

    test_coords = np.argwhere(valid_mask)
    n_test = len(test_coords)
    self._log.info("Pre-extracting %d voxels, %d subjects", n_test, n_total)

    test_data = np.zeros((n_test, n_total), dtype=np.float32)
    for idx, (i, j, k) in enumerate(test_coords):
        test_data[idx, :] = all_data[i, j, k, :].astype(np.float32)
    del all_data
    gc.collect()

    self._log.info("Test data: %.1f MB", test_data.nbytes / (1024**2))

    actual_jobs = multiprocessing.cpu_count() if self.n_jobs == -1 else self.n_jobs
    self._log.info(
        "Running %d permutations on %d cores",
        self.n_permutations,
        actual_jobs,
    )

    seeds = np.random.randint(0, 2**31, size=self.n_permutations)
    track = (
        perm_log_file is not None
        and subject_ids_resp is not None
        and subject_ids_non_resp is not None
    )

    if actual_jobs == 1:
        results = [
            _run_single_permutation(
                test_data,
                test_coords,
                n_resp,
                n_total,
                self.cluster_threshold,
                valid_mask,
                p_values.shape,
                test_type=test_type,
                alternative=self.alternative,
                cluster_stat=self.cluster_stat,
                seed=seeds[i],
                return_indices=track,
            )
            for i in range(self.n_permutations)
        ]
    else:
        results = Parallel(n_jobs=actual_jobs, verbose=0)(
            delayed(_run_single_permutation)(
                test_data,
                test_coords,
                n_resp,
                n_total,
                self.cluster_threshold,
                valid_mask,
                p_values.shape,
                test_type=test_type,
                alternative=self.alternative,
                cluster_stat=self.cluster_stat,
                seed=seeds[i],
                return_indices=track,
            )
            for i in range(self.n_permutations)
        )
        gc.collect()

    del test_data
    gc.collect()

    if track:
        null_stats = np.array([r[0] for r in results])
        perm_indices = [r[1] for r in results]
        null_sizes = [r[2] for r in results]
        null_masses = [r[3] for r in results]
    else:
        null_stats = np.array([r[0] for r in results])
        perm_indices = None
        null_sizes = [r[1] for r in results]
        null_masses = [r[2] for r in results]

    # Determine threshold
    sorted_null = np.sort(null_stats)[::-1]
    ti = max(1, min(int(self.alpha * self.n_permutations), len(sorted_null)))
    cluster_stat_threshold = sorted_null[ti - 1]

    stat_unit = "voxels" if self.cluster_stat == "size" else "mass units"
    self._log.info(
        "Threshold (p<%.3f): %.2f %s  " "(null min=%.2f, mean=%.2f, max=%.2f)",
        self.alpha,
        cluster_stat_threshold,
        stat_unit,
        np.min(null_stats),
        np.mean(null_stats),
        np.max(null_stats),
    )

    if perm_log_file is not None and perm_indices is not None:
        info = [
            {
                "perm_num": i,
                "perm_idx": perm_indices[i],
                "max_cluster_size": null_stats[i],
            }
            for i in range(self.n_permutations)
        ]
        save_permutation_details(
            info,
            perm_log_file,
            subject_ids_resp,
            subject_ids_non_resp,
        )
        self._log.info("Permutation log saved: %s", perm_log_file)

    # Identify significant clusters (MNE-style per-cluster p-values)
    sig_mask, sig_clusters, all_observed = _identify_significant_clusters(
        labeled_array,
        n_clusters,
        t_statistics,
        null_stats,
        self.cluster_stat,
        self.alpha,
        self.alternative,
        self._log,
    )

    self._log.info(
        "Significant: %d clusters, %d voxels",
        len(sig_clusters),
        np.sum(sig_mask),
    )

    correlation_data = {
        "sizes": np.array(null_sizes),
        "masses": np.array(null_masses),
    }
    return (
        sig_mask,
        cluster_stat_threshold,
        sig_clusters,
        null_stats,
        all_observed,
        correlation_data,
    )

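The corrected threshold above is derived from the distribution of maximum cluster statistics across permutations: sort the null maxima in descending order and read off the value at rank `int(alpha * n_permutations)`. A minimal numpy sketch of that step (the helper name here is illustrative, not part of the API):

```python
import numpy as np

def permutation_threshold(null_stats: np.ndarray, alpha: float) -> float:
    """Critical cluster statistic from a max-statistic null distribution.

    Mirrors the thresholding step in ``correct_groups``: sort the null
    maxima in descending order and take the value at rank
    ``int(alpha * n_permutations)``, clamped to a valid index.
    """
    sorted_null = np.sort(null_stats)[::-1]
    ti = max(1, min(int(alpha * len(null_stats)), len(sorted_null)))
    return float(sorted_null[ti - 1])

# With 1000 permutations and alpha=0.05, the threshold is the
# 50th-largest null maximum.
rng = np.random.default_rng(0)
null = rng.exponential(scale=10.0, size=1000)
thresh = permutation_threshold(null, alpha=0.05)
```

An observed cluster whose size (or mass) exceeds this threshold survives correction at the chosen alpha.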
correct_correlation

correct_correlation(subject_data, effect_sizes, *, r_values, t_statistics, p_values, valid_mask, correlation_type: str = 'pearson', weights=None, perm_log_file: str | None = None, subject_ids: list | None = None) -> tuple

Cluster-based permutation correction for correlation analysis.

Returns (sig_mask, threshold, sig_clusters, null_dist, observed_clusters, correlation_data).

Source code in tit/stats/engine.py
def correct_correlation(
    self,
    subject_data,
    effect_sizes,
    *,
    r_values,
    t_statistics,
    p_values,
    valid_mask,
    correlation_type: str = "pearson",
    weights=None,
    perm_log_file: str | None = None,
    subject_ids: list | None = None,
) -> tuple:
    """Cluster-based permutation correction for correlation analysis.

    Returns ``(sig_mask, threshold, sig_clusters, null_dist, observed_clusters,
    correlation_data)``.
    """
    from .io_utils import save_permutation_details

    effect_sizes = np.asarray(effect_sizes, dtype=np.float64)
    n_subjects = len(effect_sizes)

    self._log.info(
        "Correlation cluster correction (%s, %s, %s)",
        correlation_type,
        self.cluster_stat,
        self.alternative,
    )

    # Form initial clusters based on alternative
    match self.alternative:
        case "greater":
            initial_mask = (
                (p_values < self.cluster_threshold)
                & valid_mask
                & (t_statistics > 0)
            )
        case "less":
            initial_mask = (
                (p_values < self.cluster_threshold)
                & valid_mask
                & (t_statistics < 0)
            )
        case _:
            initial_mask = (p_values < self.cluster_threshold) & valid_mask

    labeled_array, n_clusters = label(initial_mask)
    self._log.info("Clusters at p<%.3f: %d", self.cluster_threshold, n_clusters)

    empty = {"sizes": np.array([]), "masses": np.array([])}
    if n_clusters == 0:
        self._log.warning("No clusters found.")
        return (
            np.zeros_like(p_values, dtype=int),
            0,
            [],
            np.array([]),
            [],
            empty,
        )

    # Pre-extract voxel data
    valid_coords = np.argwhere(valid_mask)
    n_valid = len(valid_coords)
    voxel_data = np.zeros((n_valid, n_subjects), dtype=np.float64)
    for idx, (i, j, k) in enumerate(valid_coords):
        voxel_data[idx, :] = subject_data[i, j, k, :]

    # Pre-rank for Spearman
    preranked = False
    if correlation_type == "spearman":
        self._log.info("Pre-ranking voxel data for Spearman (%d voxels)", n_valid)
        voxel_data = np.apply_along_axis(rankdata, 1, voxel_data)
        preranked = True

    actual_jobs = multiprocessing.cpu_count() if self.n_jobs == -1 else self.n_jobs
    self._log.info(
        "Running %d permutations on %d cores",
        self.n_permutations,
        actual_jobs,
    )

    seeds = np.random.randint(0, 2**31, size=self.n_permutations)
    track = perm_log_file is not None and subject_ids is not None

    if actual_jobs == 1:
        results = [
            _run_single_correlation_permutation(
                voxel_data,
                effect_sizes,
                valid_coords,
                self.cluster_threshold,
                valid_mask,
                p_values.shape,
                correlation_type=correlation_type,
                weights=weights,
                cluster_stat=self.cluster_stat,
                alternative=self.alternative,
                seed=seeds[i],
                return_indices=track,
                voxel_data_preranked=preranked,
            )
            for i in range(self.n_permutations)
        ]
    else:
        results = Parallel(n_jobs=actual_jobs, verbose=0)(
            delayed(_run_single_correlation_permutation)(
                voxel_data,
                effect_sizes,
                valid_coords,
                self.cluster_threshold,
                valid_mask,
                p_values.shape,
                correlation_type=correlation_type,
                weights=weights,
                cluster_stat=self.cluster_stat,
                alternative=self.alternative,
                seed=seeds[i],
                return_indices=track,
                voxel_data_preranked=preranked,
            )
            for i in range(self.n_permutations)
        )
        gc.collect()

    del voxel_data
    gc.collect()

    if track:
        null_stats = np.array([r[0] for r in results])
        null_sizes = [r[2] for r in results]
        null_masses = [r[3] for r in results]
    else:
        null_stats = np.array([r[0] for r in results])
        null_sizes = [r[1] for r in results]
        null_masses = [r[2] for r in results]

    sorted_null = np.sort(null_stats)[::-1]
    ti = max(1, min(int(self.alpha * self.n_permutations), len(sorted_null)))
    cluster_stat_threshold = sorted_null[ti - 1]

    stat_unit = "voxels" if self.cluster_stat == "size" else "mass units"
    self._log.info(
        "Threshold (p<%.3f): %.2f %s  " "(null min=%.2f, mean=%.2f, max=%.2f)",
        self.alpha,
        cluster_stat_threshold,
        stat_unit,
        np.min(null_stats),
        np.mean(null_stats),
        np.max(null_stats),
    )

    # Identify significant clusters
    sig_mask, sig_clusters, all_observed = _identify_significant_clusters(
        labeled_array,
        n_clusters,
        t_statistics,
        null_stats,
        self.cluster_stat,
        self.alpha,
        self.alternative,
        self._log,
        r_values=r_values,
    )

    self._log.info(
        "Significant: %d clusters, %d voxels",
        len(sig_clusters),
        np.sum(sig_mask),
    )

    correlation_data = {
        "sizes": np.array(null_sizes),
        "masses": np.array(null_masses),
    }
    return (
        sig_mask,
        cluster_stat_threshold,
        sig_clusters,
        null_stats,
        all_observed,
        correlation_data,
    )

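For one-sided alternatives, only voxels whose test statistic has the hypothesized sign may enter a cluster, as the `match` block above shows. A small sketch of that cluster-forming mask logic (the function name is illustrative, not part of the API):

```python
import numpy as np

def initial_cluster_mask(p_values, t_statistics, valid_mask,
                         cluster_threshold=0.05, alternative="two-sided"):
    """Sketch of the directional cluster-forming mask in
    ``correct_correlation``: sub-threshold voxels inside the valid mask,
    optionally restricted by the sign of the test statistic."""
    base = (p_values < cluster_threshold) & valid_mask
    if alternative == "greater":
        return base & (t_statistics > 0)   # positive effects only
    if alternative == "less":
        return base & (t_statistics < 0)   # negative effects only
    return base                            # two-sided: either sign

p = np.array([0.01, 0.01, 0.20])
t = np.array([2.5, -2.5, 0.1])
valid = np.ones(3, dtype=bool)
mask = initial_cluster_mask(p, t, valid, alternative="greater")
# -> [True, False, False]
```

The two-sided case keeps both sub-threshold voxels; `"less"` keeps only the negatively correlated one.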
tit.stats.engine.cluster_analysis

cluster_analysis(sig_mask, affine, log=None)

Connected-component analysis with MNI coordinate mapping.

Returns a list of cluster dicts sorted by size (descending).

Source code in tit/stats/engine.py
def cluster_analysis(sig_mask, affine, log=None):
    """Connected-component analysis with MNI coordinate mapping.

    Returns a list of cluster dicts sorted by size (descending).
    """
    import nibabel as nib

    _log = log or logger
    labeled_array, num_clusters = label(sig_mask)

    if num_clusters == 0:
        _log.info("No clusters found in significance mask")
        return []

    clusters = []
    for cid in range(1, num_clusters + 1):
        coords = np.argwhere(labeled_array == cid)
        com_voxel = np.mean(coords, axis=0)
        com_mni = nib.affines.apply_affine(affine, com_voxel)
        clusters.append(
            {
                "cluster_id": cid,
                "size": len(coords),
                "center_voxel": com_voxel,
                "center_mni": com_mni,
            }
        )

    clusters.sort(key=lambda c: c["size"], reverse=True)

    _log.info("Found %d clusters", num_clusters)
    for c in clusters[:10]:
        _log.info(
            "  Cluster %d: %d voxels, MNI (%.1f, %.1f, %.1f)",
            c["cluster_id"],
            c["size"],
            c["center_mni"][0],
            c["center_mni"][1],
            c["center_mni"][2],
        )

    return clusters

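Cluster centers are reported in MNI millimetres by pushing the voxel-space center of mass through the image affine. For a single point, `nib.affines.apply_affine` reduces to a matrix multiply plus translation; a numpy-only sketch (the affine below is a typical 2 mm MNI-style example, not taken from any particular file):

```python
import numpy as np

def voxel_to_mni(affine: np.ndarray, voxel) -> np.ndarray:
    """Map a (possibly fractional) voxel coordinate to world/MNI space.

    Equivalent to ``nibabel.affines.apply_affine`` for one point:
    rotate/scale with the upper-left 3x3 block, then translate.
    """
    return affine[:3, :3] @ np.asarray(voxel, dtype=float) + affine[:3, 3]

# 2 mm isotropic MNI-style affine (origin at voxel (45, 63, 36))
affine = np.array([
    [-2.0, 0.0, 0.0,   90.0],
    [ 0.0, 2.0, 0.0, -126.0],
    [ 0.0, 0.0, 2.0,  -72.0],
    [ 0.0, 0.0, 0.0,    1.0],
])
com_mni = voxel_to_mni(affine, [45.0, 63.0, 36.0])
# -> array([0., 0., 0.])
```

Because the center of mass is a mean over voxel indices, it is generally fractional, which is why the mapping must accept non-integer coordinates.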
tit.stats.nifti.load_subject_nifti_ti_toolbox

load_subject_nifti_ti_toolbox(subject_id: str, simulation_name: str, nifti_file_pattern: str = 'grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz', dtype=float32) -> tuple[ndarray, Nifti1Image, str]

Load a NIfTI file from the TI-Toolbox BIDS structure.

Parameters:

subject_id : str
    Subject ID (e.g., '070')
simulation_name : str
    Simulation name (e.g., 'ICP_RHIPPO')
nifti_file_pattern : str, optional
    Pattern for NIfTI files. Default: 'grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz'.
    Available variables: {subject_id}, {simulation_name}
dtype : numpy dtype, optional
    Data type to load (default: float32)

Returns:

data : ndarray
    NIfTI data
img : nibabel Nifti1Image
    NIfTI image object
filepath : str
    Full path to the loaded file

Source code in tit/stats/nifti.py
def load_subject_nifti_ti_toolbox(
    subject_id: str,
    simulation_name: str,
    nifti_file_pattern: str = "grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz",
    dtype=np.float32,
) -> tuple[np.ndarray, nib.Nifti1Image, str]:
    """
    Load a NIfTI file from TI-Toolbox BIDS structure

    Parameters:
    -----------
    subject_id : str
        Subject ID (e.g., '070')
    simulation_name : str
        Simulation name (e.g., 'ICP_RHIPPO')
    nifti_file_pattern : str, optional
        Pattern for NIfTI files. Default: 'grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz'
        Available variables: {subject_id}, {simulation_name}
    dtype : numpy dtype, optional
        Data type to load (default: float32)

    Returns:
    --------
    data : ndarray
        NIfTI data
    img : nibabel Nifti1Image
        NIfTI image object
    filepath : str
        Full path to the loaded file
    """
    pm = get_path_manager()

    nifti_dir = os.path.join(
        pm.simulation(subject_id, simulation_name),
        "TI",
        "niftis",
    )

    # Format the filename pattern
    filename = nifti_file_pattern.format(
        subject_id=subject_id, simulation_name=simulation_name
    )
    filepath = os.path.join(nifti_dir, filename)

    # Load the file (inline basic loading)
    if not os.path.exists(filepath):
        # Provide extra context to make debugging path/layout issues easier
        if os.path.isdir(nifti_dir):
            try:
                existing = sorted(os.listdir(nifti_dir))
            except OSError:
                existing = []
            preview = existing[:20]
            suffix = ""
            if len(existing) > len(preview):
                suffix = f" (showing first {len(preview)} of {len(existing)})"
            raise FileNotFoundError(
                f"NIfTI file not found: {filepath}. "
                f"Directory exists: {nifti_dir}. "
                f"Files in directory: {preview}{suffix}"
            )
        raise FileNotFoundError(f"NIfTI file not found: {filepath}")

    img = nib.load(filepath)
    data = img.get_fdata(dtype=dtype)

    # Ensure 3D data (squeeze out extra dimensions if present)
    while data.ndim > 3:
        data = np.squeeze(data, axis=-1)

    return data, img, filepath

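The `nifti_file_pattern` is resolved with plain `str.format`, and both documented variables are always supplied, so a pattern may use either, both, or neither of them. A quick sketch (the subject and simulation names are the illustrative values from the docstring):

```python
# Default pattern: only {simulation_name} is used; the unused
# subject_id keyword is simply ignored by str.format.
pattern = "grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz"
filename = pattern.format(subject_id="070", simulation_name="ICP_RHIPPO")
# -> 'grey_ICP_RHIPPO_TI_MNI_MNI_TI_max.nii.gz'

# A custom pattern may also reference {subject_id}.
custom = "{subject_id}_{simulation_name}_field.nii.gz".format(
    subject_id="070", simulation_name="ICP_RHIPPO"
)
# -> '070_ICP_RHIPPO_field.nii.gz'
```

The formatted filename is then joined onto the simulation's `TI/niftis` directory resolved by the path manager.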
tit.stats.nifti.load_group_data_ti_toolbox

load_group_data_ti_toolbox(subject_configs: list[dict], nifti_file_pattern: str = 'grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz', dtype=float32) -> tuple[ndarray, Nifti1Image, list[str]]

Load multiple subjects from the TI-Toolbox BIDS structure.

Parameters:

subject_configs : list of dict
    List of subject configurations with keys:
    - 'subject_id': Subject ID (e.g., '070')
    - 'simulation_name': Simulation name (e.g., 'ICP_RHIPPO')
nifti_file_pattern : str, optional
    Pattern for NIfTI files
dtype : numpy dtype, optional
    Data type to load (default: float32)

Returns:

data_4d : ndarray (x, y, z, n_subjects)
    4D array with all loaded data
template_img : nibabel Nifti1Image
    Template image from first subject
subject_ids : list of str
    List of successfully loaded subject IDs

Source code in tit/stats/nifti.py
def load_group_data_ti_toolbox(
    subject_configs: list[dict],
    nifti_file_pattern: str = "grey_{simulation_name}_TI_MNI_MNI_TI_max.nii.gz",
    dtype=np.float32,
) -> tuple[np.ndarray, nib.Nifti1Image, list[str]]:
    """
    Load multiple subjects from TI-Toolbox BIDS structure

    Parameters:
    -----------
    subject_configs : list of dict
        List of subject configurations with keys:
        - 'subject_id': Subject ID (e.g., '070')
        - 'simulation_name': Simulation name (e.g., 'ICP_RHIPPO')
    nifti_file_pattern : str, optional
        Pattern for NIfTI files
    dtype : numpy dtype, optional
        Data type to load (default: float32)

    Returns:
    --------
    data_4d : ndarray (x, y, z, n_subjects)
        4D array with all loaded data
    template_img : nibabel Nifti1Image
        Template image from first subject
    subject_ids : list of str
        List of successfully loaded subject IDs
    """
    data_list = []
    subject_ids = []
    template_img = None
    template_affine = None
    template_header = None

    for config in subject_configs:
        subject_id = config["subject_id"]
        simulation_name = config["simulation_name"]

        data, img, filepath = load_subject_nifti_ti_toolbox(
            subject_id, simulation_name, nifti_file_pattern, dtype=dtype
        )

        # Store template image from first subject
        if template_img is None:
            template_img = img
            template_affine = img.affine.copy()
            template_header = img.header.copy()

        data_list.append(data)
        subject_ids.append(subject_id)

        # Clear the image object to free memory
        del img

    if len(data_list) == 0:
        raise ValueError("No subjects could be loaded successfully")

    # Stack into 4D array
    data_4d = np.stack(data_list, axis=-1).astype(dtype)

    # Recreate minimal template image
    template_img = nib.Nifti1Image(data_4d[..., 0], template_affine, template_header)

    # Clean up
    del data_list
    gc.collect()

    return data_4d, template_img, subject_ids
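The stacking step at the end is plain `np.stack` on the last axis: each subject contributes one 3D volume, and the result is the `(x, y, z, n_subjects)` array the statistics engine expects. A minimal sketch with synthetic volumes:

```python
import numpy as np

# Three illustrative (4, 5, 6) subject volumes; subject s is filled
# with the constant s so the stacked layout is easy to verify.
volumes = [np.full((4, 5, 6), fill_value=s, dtype=np.float32)
           for s in range(3)]

# axis=-1 appends the subject dimension last: (x, y, z, n_subjects)
data_4d = np.stack(volumes, axis=-1).astype(np.float32)
# data_4d.shape -> (4, 5, 6, 3); data_4d[..., 1] is subject 1's volume
```

Indexing `data_4d[..., s]` recovers subject `s`'s volume, which is how the permutation code slices the group data per subject.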