def run_correlation(
    config: CorrelationConfig,
    callback_handler=None,
    stop_callback=None,
) -> CorrelationResult:
    """Run cluster-based permutation testing for correlation (ACES-style).

    The pipeline runs seven logged steps: load subject data, compute the
    voxelwise correlation, apply the cluster-based permutation correction,
    analyse the significant clusters, generate plots, compute the atlas
    overlap and save the outputs.

    Outputs are placed in the directory resolved from ``config.project_dir``,
    ``"correlation"`` and ``config.analysis_name``. The paths built here are:

    * ``permutation_details.txt`` (handed to the permutation engine)
    * ``permutation_null_distribution.pdf``
    * ``cluster_size_mass_correlation.pdf`` (only when cluster sizes exist)
    * ``significant_voxels_mask.nii.gz``
    * ``correlation_map.nii.gz``
    * ``t_statistics_map.nii.gz``
    * ``pvalues_map.nii.gz`` (``-log10(p + 1e-10)``, zero outside the valid
      mask)
    * ``correlation_map_thresholded.nii.gz``
    * ``average_efield.nii.gz`` (mean of the subject data over its last axis)
    * ``analysis_summary.txt``

    Parameters
    ----------
    config : CorrelationConfig
        Fully specified configuration.
    callback_handler : logging.Handler, optional
        GUI log handler, forwarded to the logger setup.
    stop_callback : callable, optional
        Returns True to abort. It is polled after data loading and after the
        voxelwise correlation.

    Returns
    -------
    CorrelationResult
        Result with ``success=True``, the output directory, subject and
        significant voxel/cluster counts, the cluster threshold, the total
        run time in seconds, the annotated cluster list and the log file.

    Raises
    ------
    KeyboardInterrupt
        If ``stop_callback`` returns a truthy value at one of its two
        polling points.

    Notes
    -----
    Every handler attached to the run logger is closed and removed before
    this function exits, on success as well as on failure or abort.
    """
    t0 = time.time()
    output_dir = _resolve_output_dir(
        config.project_dir,
        "correlation",
        config.analysis_name,
    )
    log, log_file = _setup_logger(output_dir, "correlation", callback_handler)

    try:
        log.info("=" * 70)
        log.info("CORRELATION-BASED CLUSTER PERMUTATION TESTING (ACES-style)")
        log.info("=" * 70)
        log.info("Analysis: %s", config.analysis_name)
        log.info("Output: %s", output_dir)
        log.info(
            "Config: corr=%s stat=%s threshold=%.3f perms=%d alpha=%.3f jobs=%d",
            config.correlation_type.value,
            config.cluster_stat.value,
            config.cluster_threshold,
            config.n_permutations,
            config.alpha,
            config.n_jobs,
        )

        # ── 1. Load data ─────────────────────────────────────────────────
        log.info("[1/7] Loading subject data")
        step = time.time()
        subject_dicts = [
            {"subject_id": s.subject_id, "simulation_name": s.simulation_name}
            for s in config.subjects
        ]
        subject_data, template_img, subject_ids = load_group_data_ti_toolbox(
            subject_dicts,
            nifti_file_pattern=config.nifti_file_pattern,
            dtype=np.float32,
        )

        # Build effect sizes / weights in the order of the subjects that were
        # actually loaded (subject_ids), not the order given in the config.
        config_lookup = {s.subject_id: s for s in config.subjects}
        effect_sizes = np.array(
            [config_lookup[sid].effect_size for sid in subject_ids],
            dtype=np.float64,
        )
        weights = None
        if config.use_weights:
            weights = np.array(
                [config_lookup[sid].weight for sid in subject_ids],
                dtype=np.float64,
            )

        n_subjects = len(subject_ids)
        log.info("Loaded %d subjects: %s", n_subjects, subject_ids)
        log.info(
            "Effect sizes: mean=%.3f, std=%.3f, range=[%.3f, %.3f]",
            np.mean(effect_sizes),
            np.std(effect_sizes),
            np.min(effect_sizes),
            np.max(effect_sizes),
        )
        log.info(
            "Data shape: %s (%.1fs)", subject_data.shape[:3], time.time() - step
        )
        if stop_callback and stop_callback():
            raise KeyboardInterrupt("Stopped by user")

        # ── 2. Voxelwise correlation ─────────────────────────────────────
        log.info("[2/7] Voxelwise correlation")
        step = time.time()
        r_values, t_statistics, p_values, valid_mask = correlation_voxelwise(
            subject_data,
            effect_sizes,
            weights=weights,
            correlation_type=config.correlation_type.value,
            log=log,
        )
        log.info("Correlation computed in %.1fs", time.time() - step)
        if stop_callback and stop_callback():
            raise KeyboardInterrupt("Stopped by user")

        # ── 3. Permutation correction ────────────────────────────────────
        log.info(
            "[3/7] Cluster-based permutation correction (%d perms)",
            config.n_permutations,
        )
        step = time.time()
        perm_log_file = os.path.join(output_dir, "permutation_details.txt")
        engine = PermutationEngine(
            cluster_threshold=config.cluster_threshold,
            n_permutations=config.n_permutations,
            alpha=config.alpha,
            cluster_stat=config.cluster_stat.value,
            alternative="two-sided",
            n_jobs=config.n_jobs,
            log=log,
        )
        (
            sig_mask,
            cluster_threshold,
            sig_clusters,
            null_dist,
            all_clusters,
            corr_data,
        ) = engine.correct_correlation(
            subject_data,
            effect_sizes,
            r_values=r_values,
            t_statistics=t_statistics,
            p_values=p_values,
            valid_mask=valid_mask,
            correlation_type=config.correlation_type.value,
            weights=weights,
            perm_log_file=perm_log_file,
            subject_ids=subject_ids,
        )
        log.info(
            "Significant clusters: %d, voxels: %d (%.1fs)",
            len(sig_clusters),
            np.sum(sig_mask),
            time.time() - step,
        )

        # ── 4. Cluster analysis ──────────────────────────────────────────
        log.info("[4/7] Cluster analysis")
        clusters = cluster_analysis(sig_mask, template_img.affine, log=log)

        # Annotate each cluster with correlation stats.
        # NOTE(review): this assumes cluster_analysis numbers its clusters the
        # same way scipy.ndimage.label does with default connectivity — confirm.
        from scipy.ndimage import label as scipy_label

        labeled, _ = scipy_label(sig_mask)
        for c in clusters:
            c_mask = labeled == c["cluster_id"]
            c["mean_r"] = float(np.mean(r_values[c_mask]))
            # NOTE(review): np.max selects the most positive r; the test above
            # is two-sided, so for a negative cluster this is the weakest
            # value — confirm this is the intended "peak".
            c["peak_r"] = float(np.max(r_values[c_mask]))

        # ── 5. Plots ─────────────────────────────────────────────────────
        log.info("[5/7] Generating plots")
        perm_plot = os.path.join(output_dir, "permutation_null_distribution.pdf")
        plot_permutation_null_distribution(
            null_dist,
            cluster_threshold,
            all_clusters,
            perm_plot,
            alpha=config.alpha,
            cluster_stat=config.cluster_stat.value,
        )
        # Explicit length check: "sizes" may be an array, whose truth value
        # would be ambiguous.
        if len(corr_data["sizes"]) > 0:
            corr_plot = os.path.join(
                output_dir, "cluster_size_mass_correlation.pdf"
            )
            plot_cluster_size_mass_correlation(
                corr_data["sizes"],
                corr_data["masses"],
                corr_plot,
            )

        # ── 6. Atlas overlap ─────────────────────────────────────────────
        log.info("[6/7] Atlas overlap")
        atlas_results = {}
        if config.atlas_files:
            if _ATLAS_DIR.exists():
                atlas_results = atlas_overlap_analysis(
                    sig_mask,
                    config.atlas_files,
                    str(_ATLAS_DIR),
                    reference_img=template_img,
                )
            else:
                # Atlases were requested but cannot be read; say so instead
                # of silently producing an empty overlap section.
                log.warning(
                    "Atlas directory not found, skipping atlas overlap: %s",
                    _ATLAS_DIR,
                )

        # ── 7. Save outputs ──────────────────────────────────────────────
        log.info("[7/7] Saving results")
        _save_nifti(
            sig_mask.astype(np.uint8),
            template_img,
            os.path.join(output_dir, "significant_voxels_mask.nii.gz"),
        )
        _save_nifti(
            r_values.astype(np.float32),
            template_img,
            os.path.join(output_dir, "correlation_map.nii.gz"),
        )
        _save_nifti(
            t_statistics.astype(np.float32),
            template_img,
            os.path.join(output_dir, "t_statistics_map.nii.gz"),
        )

        # The small offset keeps log10 finite when p == 0; voxels outside the
        # valid mask are reported as 0.
        log_p = -np.log10(p_values + 1e-10)
        log_p[~valid_mask] = 0
        _save_nifti(
            log_p, template_img, os.path.join(output_dir, "pvalues_map.nii.gz")
        )

        # Correlation map restricted to the significant voxels.
        r_thresh = r_values.copy()
        r_thresh[sig_mask == 0] = 0
        _save_nifti(
            r_thresh.astype(np.float32),
            template_img,
            os.path.join(output_dir, "correlation_map_thresholded.nii.gz"),
        )

        avg = np.mean(subject_data, axis=-1).astype(np.float32)
        _save_nifti(
            avg, template_img, os.path.join(output_dir, "average_efield.nii.gz")
        )

        summary_path = os.path.join(output_dir, "analysis_summary.txt")
        generate_correlation_summary(
            config,
            subject_data,
            effect_sizes,
            r_values,
            sig_mask,
            cluster_threshold,
            clusters,
            atlas_results,
            summary_path,
            subject_ids=subject_ids,
            weights=weights,
        )

        total = time.time() - t0
        log.info(
            "COMPLETE in %.1fs — %d sig clusters, %d sig voxels",
            total,
            len(sig_clusters),
            np.sum(sig_mask),
        )

        # Drop the large arrays before building the result object.
        del subject_data, effect_sizes, weights, t_statistics, p_values
        gc.collect()

        return CorrelationResult(
            success=True,
            output_dir=output_dir,
            n_subjects=n_subjects,
            n_significant_voxels=int(np.sum(sig_mask)),
            n_significant_clusters=len(sig_clusters),
            cluster_threshold=float(cluster_threshold),
            analysis_time=total,
            clusters=clusters,
            log_file=log_file,
        )
    finally:
        # Always detach and close the handlers, including when a step fails
        # or the user aborts. Otherwise the log file stays open and the
        # handlers remain attached to the logger after this call.
        for h in log.handlers[:]:
            h.close()
            log.removeHandler(h)