Skip to content

debug_dti

tit.pre.qsi.debug_dti

DTI registration diagnostic — run inside the SimNIBS container.

Usage

python -m tit.pre.qsi.debug_dti /mnt/<project_dir> <subject_id>

Prints affines, orientations, centers of mass, overlap statistics, eigenvalue/eigenvector sanity checks, and a source-vs-registered FA comparison, so that alignment and orientation issues can be diagnosed remotely.

main

main()

CLI entry point for DTI registration debugging diagnostics.

Source code in tit/pre/qsi/debug_dti.py
def _tensor_components_to_matrices(vox):
    """Expand an (N, 6) array of tensor components into (N, 3, 3) symmetric matrices.

    Component order is ``[Dxx, Dxy, Dxz, Dyy, Dyz, Dzz]``, the same order
    used for the component-range printout in :func:`main`.
    """
    vox = np.asarray(vox, dtype=np.float32)
    mats = np.zeros((vox.shape[0], 3, 3), dtype=np.float32)
    mats[:, 0, 0] = vox[:, 0]
    mats[:, 0, 1] = mats[:, 1, 0] = vox[:, 1]
    mats[:, 0, 2] = mats[:, 2, 0] = vox[:, 2]
    mats[:, 1, 1] = vox[:, 3]
    mats[:, 1, 2] = mats[:, 2, 1] = vox[:, 4]
    mats[:, 2, 2] = vox[:, 5]
    return mats


def _matrices_to_tensor_components(mats):
    """Pack (N, 3, 3) symmetric matrices back into (N, 6) upper-triangle components."""
    # triu_indices(3) walks (0,0),(0,1),(0,2),(1,1),(1,2),(2,2), which is
    # exactly Dxx, Dxy, Dxz, Dyy, Dyz, Dzz.
    rows, cols = np.triu_indices(3)
    return mats[:, rows, cols].astype(np.float32)


def _has_tensor_components(data, label):
    """Return True if ``data`` is 4D with 6 components on its last axis.

    Otherwise print why the tensor-specific checks are skipped and return False.
    """
    if data.ndim == 4 and data.shape[-1] == 6:
        return True
    print(
        f"  {label}: expected a 4D volume with 6 tensor components, "
        f"got shape {data.shape}; skipping tensor checks"
    )
    return False


def _voxel_size_label(img):
    """Return a compact voxel-size label such as ``'2mm'`` or ``'1x1x1.2mm'``."""
    zooms = [float(z) for z in img.header.get_zooms()[:3]]
    if np.allclose(zooms, zooms[0]):
        return f"{zooms[0]:g}mm"
    return "x".join(f"{z:g}" for z in zooms) + "mm"


def _simulate_correct_fsl(data, affine):
    """Rotate a stored (X, Y, Z, 6) tensor volume into world space.

    Each nonzero-voxel tensor ``T`` becomes ``M T M^T``, where ``M`` is the
    normalised affine direction matrix with an optional x flip. All-zero
    voxels stay zero.
    """
    aff = np.asarray(affine)
    # NOTE(review): norm(axis=0) yields per-column (voxel-axis) norms, and the
    # [:, None] then divides row i by the norm of column i. That equals true
    # column normalisation only for isotropic voxels. Kept as-is so the
    # simulation matches the original; confirm it mirrors SimNIBS's
    # correct_FSL before trusting it on anisotropic grids.
    M = aff[:3, :3] / np.linalg.norm(aff[:3, :3], axis=0)[:, None]
    # Flip x when the direction matrix has a positive determinant (presumably
    # FSL's radiological voxel convention that correct_FSL compensates for).
    R = np.eye(3)
    if np.linalg.det(M) > 0:
        R[0, 0] = -1
    M32 = M.dot(R).astype(np.float32)

    mask = np.any(data != 0, axis=-1)
    mats = _tensor_components_to_matrices(data[mask])
    mats_world = np.einsum("ij,njk,lk->nil", M32, mats, M32)

    world = np.zeros_like(data)
    world[mask] = _matrices_to_tensor_components(mats_world)
    return world


def _print_tensor_summary(title, vox, show_components=False):
    """Print FA / eigenvalue statistics for an (N, 6) array of nonzero tensor voxels.

    When ``show_components`` is true, per-component min/max/mean are printed too.
    An empty input prints a note instead of raising on empty reductions.
    """
    print(f"  {title}:")
    if vox.shape[0] == 0:
        print("    (no nonzero tensor voxels)")
        return

    evals = np.linalg.eigvalsh(_tensor_components_to_matrices(vox))
    fa = _compute_fa(evals)
    print(
        f"    FA: mean={fa.mean():.4f}, median={np.median(fa):.4f}, "
        f"max={fa.max():.4f}, p95={np.percentile(fa, 95):.4f}"
    )
    print(f"    Eigenvalues: min={evals.min():.6f}, max={evals.max():.6f}")

    if show_components:
        print("    Component ranges:")
        for i, name in enumerate(("Dxx", "Dxy", "Dxz", "Dyy", "Dyz", "Dzz")):
            vals = vox[:, i]
            print(
                f"      {name}: [{vals.min():.6f}, {vals.max():.6f}], mean={vals.mean():.6f}"
            )


def main():
    """CLI entry point for DTI registration debugging diagnostics.

    Usage: ``python -m tit.pre.qsi.debug_dti <project_dir> <subject_id>``

    Under ``<project_dir>/derivatives``, the command reads:

    * the SimNIBS T1 (mandatory);
    * the intermediate ACPC tensor;
    * the final registered tensor;
    * the QSIPrep ACPC T1;
    * one DSI Studio source tensor component (txx).

    It prints affines, centers of mass, an overlap check, eigenvector checks
    (stored and simulated correct_FSL world space), and an FA comparison.
    Missing optional inputs are reported and their sections are skipped.

    Exits with status 1 when fewer than two arguments are given or the
    SimNIBS T1 is missing.
    """
    if len(sys.argv) < 3:
        print("Usage: python -m tit.pre.qsi.debug_dti <project_dir> <subject_id>")
        sys.exit(1)

    project_dir = Path(sys.argv[1])
    subject_id = sys.argv[2]
    derivatives = project_dir / "derivatives"

    m2m_dir = derivatives / "SimNIBS" / f"sub-{subject_id}" / f"m2m_{subject_id}"
    t1_path = m2m_dir / "T1.nii.gz"
    registered_path = m2m_dir / "DTI_coregT1_tensor.nii.gz"
    intermediate_path = m2m_dir / "DTI_ACPC_tensor.nii.gz"

    dsi_dir = (
        derivatives
        / "qsirecon"
        / "derivatives"
        / "qsirecon-DSIStudio"
        / f"sub-{subject_id}"
        / "dwi"
    )
    acpc_t1_path = (
        derivatives
        / "qsiprep"
        / f"sub-{subject_id}"
        / "anat"
        / f"sub-{subject_id}_space-ACPC_desc-preproc_T1w.nii.gz"
    )

    print("=" * 60)
    print("DTI REGISTRATION DIAGNOSTIC")
    print("=" * 60)

    # 1. Print all affines. The SimNIBS T1 is the reference for every later
    # section, so it is mandatory. Optional images are tracked by the loaded
    # object (None when missing) rather than by re-checking the filesystem.
    if not t1_path.exists():
        print(f"\nSimNIBS T1 NOT FOUND: {t1_path}")
        sys.exit(1)
    t1_img = nib.load(str(t1_path))
    _print_affine("SimNIBS T1", t1_img)

    acpc_img = None
    if intermediate_path.exists():
        acpc_img = nib.load(str(intermediate_path))
        _print_affine("ACPC tensor (intermediate)", acpc_img)
    else:
        print(f"\nACPC tensor NOT FOUND: {intermediate_path}")

    reg_img = None
    if registered_path.exists():
        reg_img = nib.load(str(registered_path))
        _print_affine("Registered tensor (final)", reg_img)
    else:
        print(f"\nRegistered tensor NOT FOUND: {registered_path}")

    if acpc_t1_path.exists():
        acpc_t1_img = nib.load(str(acpc_t1_path))
        _print_affine("QSIPrep ACPC T1", acpc_t1_img)
    else:
        print(f"\nQSIPrep ACPC T1 NOT FOUND: {acpc_t1_path}")

    # Load a single source component for reference. Sort so the pick is
    # deterministic when several files match.
    src_files = sorted(
        dsi_dir.glob(f"sub-{subject_id}_space-ACPC_model-tensor_param-txx_dwimap.nii*")
    )
    if src_files:
        src_img = nib.load(str(src_files[0]))
        _print_affine("Source tensor component (txx)", src_img)
    else:
        print(f"\nSource tensor component (txx) NOT FOUND in: {dsi_dir}")

    # 2. Centers of mass
    print(f"\n{'=' * 60}")
    print("CENTERS OF MASS (world coordinates)")
    t1_data = t1_img.get_fdata(dtype=np.float32)
    t1_com = _brain_com(t1_data, t1_img.affine, "SimNIBS T1")

    acpc_data = None
    if acpc_img is not None:
        acpc_data = acpc_img.get_fdata(dtype=np.float32)
        acpc_com = _brain_com(acpc_data, acpc_img.affine, "ACPC tensor")
        if t1_com is not None and acpc_com is not None:
            diff = t1_com - acpc_com
            print(f"  Shift needed: [{diff[0]:.1f}, {diff[1]:.1f}, {diff[2]:.1f}] mm")
            print(f"  Distance: {np.linalg.norm(diff):.1f} mm")

    # 3. Overlap check
    reg_data = None
    if reg_img is not None:
        print(f"\n{'=' * 60}")
        print("OVERLAP CHECK")
        reg_data = reg_img.get_fdata(dtype=np.float32)
        if reg_data.shape[:3] == t1_data.shape[:3]:
            _check_overlap(reg_data, t1_data, "Registered tensor vs T1 (same grid)")
        else:
            print(
                f"  Shape mismatch: tensor={reg_data.shape[:3]}, T1={t1_data.shape[:3]}"
            )

    # 4. Eigenvector checks: first on the stored tensor, then on what SimNIBS
    # would see after correct_FSL (world space).
    if reg_data is not None:
        print(f"\n{'=' * 60}")
        print("EIGENVECTOR CHECK (stored tensor, pre-FSL-compensation)")
        if _has_tensor_components(reg_data, "Registered tensor"):
            _check_eigenvectors(reg_data, reg_img.affine, "Registered tensor")

            print(f"\n{'=' * 60}")
            print("EIGENVECTOR CHECK (after simulated correct_FSL = world space)")
            world_data = _simulate_correct_fsl(reg_data, reg_img.affine)
            _check_eigenvectors(
                world_data, reg_img.affine, "World-space tensor (after correct_FSL)"
            )

    # 5. Compare FA: source vs registered (is our processing destroying FA?)
    if acpc_data is not None and reg_data is not None:
        print(f"\n{'=' * 60}")
        print("FA COMPARISON: SOURCE vs REGISTERED")
        if _has_tensor_components(
            acpc_data, "ACPC tensor"
        ) and _has_tensor_components(reg_data, "Registered tensor"):
            acpc_vox = acpc_data[np.any(acpc_data != 0, axis=-1)]
            reg_vox = reg_data[np.any(reg_data != 0, axis=-1)]
            # Voxel sizes come from the headers rather than being hard-coded.
            _print_tensor_summary(
                f"ACPC tensor (source, {_voxel_size_label(acpc_img)})",
                acpc_vox,
                show_components=True,
            )
            _print_tensor_summary(
                f"Registered tensor ({_voxel_size_label(reg_img)}, "
                f"after resampling+rotation)",
                reg_vox,
            )

    print(f"\n{'=' * 60}")
    print("DONE")