def main():
"""CLI entry point for DTI registration debugging diagnostics."""
if len(sys.argv) < 3:
print(f"Usage: python -m tit.pre.qsi.debug_dti <project_dir> <subject_id>")
sys.exit(1)
project_dir = Path(sys.argv[1])
subject_id = sys.argv[2]
m2m_dir = (
project_dir
/ "derivatives"
/ "SimNIBS"
/ f"sub-{subject_id}"
/ f"m2m_{subject_id}"
)
t1_path = m2m_dir / "T1.nii.gz"
registered_path = m2m_dir / "DTI_coregT1_tensor.nii.gz"
intermediate_path = m2m_dir / "DTI_ACPC_tensor.nii.gz"
dsi_dir = (
project_dir
/ "derivatives"
/ "qsirecon"
/ "derivatives"
/ "qsirecon-DSIStudio"
/ f"sub-{subject_id}"
/ "dwi"
)
acpc_t1_path = (
project_dir
/ "derivatives"
/ "qsiprep"
/ f"sub-{subject_id}"
/ "anat"
/ f"sub-{subject_id}_space-ACPC_desc-preproc_T1w.nii.gz"
)
print("=" * 60)
print("DTI REGISTRATION DIAGNOSTIC")
print("=" * 60)
# 1. Print all affines
if t1_path.exists():
t1_img = nib.load(str(t1_path))
_print_affine("SimNIBS T1", t1_img)
else:
print(f"\nSimNIBS T1 NOT FOUND: {t1_path}")
sys.exit(1)
if intermediate_path.exists():
acpc_img = nib.load(str(intermediate_path))
_print_affine("ACPC tensor (intermediate)", acpc_img)
else:
print(f"\nACPC tensor NOT FOUND: {intermediate_path}")
if registered_path.exists():
reg_img = nib.load(str(registered_path))
_print_affine("Registered tensor (final)", reg_img)
else:
print(f"\nRegistered tensor NOT FOUND: {registered_path}")
if acpc_t1_path.exists():
acpc_t1_img = nib.load(str(acpc_t1_path))
_print_affine("QSIPrep ACPC T1", acpc_t1_img)
# Load a single source component for reference
src_files = list(
dsi_dir.glob(f"sub-{subject_id}_space-ACPC_model-tensor_param-txx_dwimap.nii*")
)
if src_files:
src_img = nib.load(str(src_files[0]))
_print_affine("Source tensor component (txx)", src_img)
# 2. Centers of mass
print(f"\n{'=' * 60}")
print("CENTERS OF MASS (world coordinates)")
t1_com = _brain_com(t1_img.get_fdata(dtype=np.float32), t1_img.affine, "SimNIBS T1")
if intermediate_path.exists():
acpc_com = _brain_com(
acpc_img.get_fdata(dtype=np.float32), acpc_img.affine, "ACPC tensor"
)
if t1_com is not None and acpc_com is not None:
diff = t1_com - acpc_com
print(f" Shift needed: [{diff[0]:.1f}, {diff[1]:.1f}, {diff[2]:.1f}] mm")
print(f" Distance: {np.linalg.norm(diff):.1f} mm")
# 3. Overlap check
if registered_path.exists():
print(f"\n{'=' * 60}")
print("OVERLAP CHECK")
reg_data = reg_img.get_fdata(dtype=np.float32)
t1_data = t1_img.get_fdata(dtype=np.float32)
if reg_data.shape[:3] == t1_data.shape[:3]:
_check_overlap(reg_data, t1_data, "Registered tensor vs T1 (same grid)")
else:
print(
f" Shape mismatch: tensor={reg_data.shape[:3]}, T1={t1_data.shape[:3]}"
)
# 4. Eigenvector check
if registered_path.exists():
print(f"\n{'=' * 60}")
print("EIGENVECTOR CHECK (stored tensor, pre-FSL-compensation)")
_check_eigenvectors(reg_data, reg_img.affine, "Registered tensor")
# Also check what SimNIBS would see after correct_FSL
print(f"\n{'=' * 60}")
print("EIGENVECTOR CHECK (after simulated correct_FSL = world space)")
aff = reg_img.affine
M = aff[:3, :3] / np.linalg.norm(aff[:3, :3], axis=0)[:, None]
R = np.eye(3)
if np.linalg.det(M) > 0:
R[0, 0] = -1
M_fsl = M.dot(R)
# Rotate the stored tensor to world space
mask = np.any(reg_data != 0, axis=-1)
vox = reg_data[mask].copy()
T = np.zeros((vox.shape[0], 3, 3), dtype=np.float32)
T[:, 0, 0] = vox[:, 0]
T[:, 0, 1] = T[:, 1, 0] = vox[:, 1]
T[:, 0, 2] = T[:, 2, 0] = vox[:, 2]
T[:, 1, 1] = vox[:, 3]
T[:, 1, 2] = T[:, 2, 1] = vox[:, 4]
T[:, 2, 2] = vox[:, 5]
M32 = M_fsl.astype(np.float32)
T_world = np.einsum("ij,njk,lk->nil", M32, T, M32)
world_data = np.zeros_like(reg_data)
world_vox = np.zeros((vox.shape[0], 6), dtype=np.float32)
world_vox[:, 0] = T_world[:, 0, 0]
world_vox[:, 1] = T_world[:, 0, 1]
world_vox[:, 2] = T_world[:, 0, 2]
world_vox[:, 3] = T_world[:, 1, 1]
world_vox[:, 4] = T_world[:, 1, 2]
world_vox[:, 5] = T_world[:, 2, 2]
world_data[mask] = world_vox
_check_eigenvectors(world_data, aff, "World-space tensor (after correct_FSL)")
# 5. Compare FA: source vs registered (is our processing destroying FA?)
if intermediate_path.exists() and registered_path.exists():
print(f"\n{'=' * 60}")
print("FA COMPARISON: SOURCE vs REGISTERED")
acpc_data = acpc_img.get_fdata(dtype=np.float32)
acpc_mask = np.any(acpc_data != 0, axis=-1)
acpc_vox = acpc_data[acpc_mask]
T_acpc = np.zeros((acpc_vox.shape[0], 3, 3), dtype=np.float32)
T_acpc[:, 0, 0] = acpc_vox[:, 0]
T_acpc[:, 0, 1] = T_acpc[:, 1, 0] = acpc_vox[:, 1]
T_acpc[:, 0, 2] = T_acpc[:, 2, 0] = acpc_vox[:, 2]
T_acpc[:, 1, 1] = acpc_vox[:, 3]
T_acpc[:, 1, 2] = T_acpc[:, 2, 1] = acpc_vox[:, 4]
T_acpc[:, 2, 2] = acpc_vox[:, 5]
evals_acpc = np.linalg.eigvalsh(T_acpc)
fa_acpc = _compute_fa(evals_acpc)
reg_mask2 = np.any(reg_data != 0, axis=-1)
reg_vox = reg_data[reg_mask2]
T_reg = np.zeros((reg_vox.shape[0], 3, 3), dtype=np.float32)
T_reg[:, 0, 0] = reg_vox[:, 0]
T_reg[:, 0, 1] = T_reg[:, 1, 0] = reg_vox[:, 1]
T_reg[:, 0, 2] = T_reg[:, 2, 0] = reg_vox[:, 2]
T_reg[:, 1, 1] = reg_vox[:, 3]
T_reg[:, 1, 2] = T_reg[:, 2, 1] = reg_vox[:, 4]
T_reg[:, 2, 2] = reg_vox[:, 5]
evals_reg = np.linalg.eigvalsh(T_reg)
fa_reg = _compute_fa(evals_reg)
# Also print raw tensor component ranges
print(f" ACPC tensor (source, 2mm):")
print(
f" FA: mean={fa_acpc.mean():.4f}, median={np.median(fa_acpc):.4f}, "
f"max={fa_acpc.max():.4f}, p95={np.percentile(fa_acpc, 95):.4f}"
)
print(
f" Eigenvalues: min={evals_acpc.min():.6f}, max={evals_acpc.max():.6f}"
)
print(f" Component ranges:")
for i, name in enumerate(["Dxx", "Dxy", "Dxz", "Dyy", "Dyz", "Dzz"]):
vals = acpc_vox[:, i]
print(
f" {name}: [{vals.min():.6f}, {vals.max():.6f}], mean={vals.mean():.6f}"
)
print(f" Registered tensor (0.5mm, after resampling+rotation):")
print(
f" FA: mean={fa_reg.mean():.4f}, median={np.median(fa_reg):.4f}, "
f"max={fa_reg.max():.4f}, p95={np.percentile(fa_reg, 95):.4f}"
)
print(f" Eigenvalues: min={evals_reg.min():.6f}, max={evals_reg.max():.6f}")
print(f"\n{'=' * 60}")
print("DONE")