Skip to content

ex

tit.opt.ex

TI Exhaustive Search Module.

ExConfig dataclass

ExConfig(subject_id: str, project_dir: str, leadfield_hdf: str, roi_name: str, electrodes: BucketElectrodes | PoolElectrodes, total_current: float = 2.0, current_step: float = 0.5, channel_limit: float | None = None, roi_radius: float = 3.0, run_name: str | None = None)

Full configuration for exhaustive search optimization.

BucketElectrodes dataclass

BucketElectrodes(e1_plus: list[str], e1_minus: list[str], e2_plus: list[str], e2_minus: list[str])

Separate electrode lists for each bipolar channel position.

PoolElectrodes dataclass

PoolElectrodes(electrodes: list[str])

Single electrode pool — all positions draw from the same set.

ExResult dataclass

ExResult(success: bool, output_dir: str, n_combinations: int, results_csv: str | None = None, config_json: str | None = None)

Result from an exhaustive search run.

ExSearchEngine

ExSearchEngine(leadfield_hdf: str, roi_file: str, roi_name: str, logger: Logger)

Exhaustive TI electrode search engine.

Owns the full pipeline: leadfield loading, ROI resolution, the simulation loop, and ROI file management (list, create, read, delete).

Source code in tit/opt/ex/engine.py
def __init__(
    self, leadfield_hdf: str, roi_file: str, roi_name: str, logger: logging.Logger
):
    """Store the engine inputs; no files are read here.

    Args:
        leadfield_hdf: Path to the leadfield HDF file.
        roi_file: Path to the ROI CSV file.
        roi_name: ROI label, used as the prefix of every metric key.
        logger: Logger that receives progress and summary messages.
    """
    # Constructor arguments, kept as given.
    self.leadfield_hdf, self.roi_file = leadfield_hdf, roi_file
    self.roi_name, self.logger = roi_name, logger

    # Data placeholders. They start as None; presumably the loader helpers
    # invoked by initialize() fill them in (those helpers are not shown here).
    self.leadfield = self.mesh = self.idx_lf = None
    self.roi_coords = self.roi_indices = self.roi_volumes = None
    self.gm_indices = self.gm_volumes = None

initialize

initialize(roi_radius: float = 3.0) -> None

Load leadfield, parse ROI CSV, find ROI + GM elements.

Source code in tit/opt/ex/engine.py
def initialize(self, roi_radius: float = 3.0) -> None:
    """Load leadfield, parse ROI CSV, find ROI + GM elements.

    Runs the private loader steps in a fixed order: leadfield, ROI
    coordinates, ROI elements (using ``roi_radius``), then grey-matter
    elements.

    Args:
        roi_radius: Radius forwarded to ``_find_roi_elements``.
    """
    # Loading steps that take no arguments run first, in this order.
    for load_step in (self._load_leadfield, self._load_roi_coordinates):
        load_step()
    # Element lookups come after the loaders.
    self._find_roi_elements(roi_radius)
    self._find_gm_elements()

compute_ti_field

compute_ti_field(e1_plus: str, e1_minus: str, current_ch1_mA: float, e2_plus: str, e2_minus: str, current_ch2_mA: float) -> dict[str, float]

Compute TI field for one montage and return ROI metrics.

Source code in tit/opt/ex/engine.py
def compute_ti_field(
    self,
    e1_plus: str,
    e1_minus: str,
    current_ch1_mA: float,
    e2_plus: str,
    e2_minus: str,
    current_ch2_mA: float,
) -> dict[str, float]:
    """Compute TI field for one montage and return ROI metrics.

    Each channel's field is built with ``TI.get_field`` from its electrode
    pair and current. The two fields are combined with ``TI.get_maxTI``
    (presumably the per-element maximal TI amplitude) and then sampled at
    the ROI and grey-matter element indices.

    Args:
        e1_plus, e1_minus: Electrode pair of channel 1.
        current_ch1_mA: Channel 1 current in mA. It is divided by 1000
            before being passed to ``TI.get_field``.
        e2_plus, e2_minus: Electrode pair of channel 2.
        current_ch2_mA: Channel 2 current in mA (divided by 1000 likewise).

    Returns:
        A dict with these keys, each prefixed with ``roi_name``:

        - ``_TImax_ROI``: maximum over ROI elements.
        - ``_TImean_ROI``: ROI mean weighted by ``roi_volumes``.
        - ``_TImean_GM``: GM mean weighted by ``gm_volumes``.
        - ``_Focality``: ROI mean / GM mean. It is 0.0 when there are no GM
          elements or the GM mean is not positive.
        - ``_n_elements``: ROI element count. This one is an int despite
          the annotation.

        The dict also echoes ``current_ch1_mA`` and ``current_ch2_mA``.
        When the ROI has no elements, every metric is 0.0.
    """
    ti_max = TI.get_maxTI(
        TI.get_field(
            [e1_plus, e1_minus, current_ch1_mA / 1000], self.leadfield, self.idx_lf
        ),
        TI.get_field(
            [e2_plus, e2_minus, current_ch2_mA / 1000], self.leadfield, self.idx_lf
        ),
    )

    roi_values = ti_max[self.roi_indices]
    gm_values = ti_max[self.gm_indices]
    n_roi = int(len(roi_values))

    # Defaults cover both the empty-ROI and the empty/zero-GM cases.
    roi_max = roi_mean = gm_mean = focality = 0.0
    if n_roi:
        roi_max = float(np.max(roi_values))
        roi_mean = float(np.average(roi_values, weights=self.roi_volumes))
        if len(gm_values):
            gm_mean = float(np.average(gm_values, weights=self.gm_volumes))
            if gm_mean > 0:
                focality = roi_mean / gm_mean

    prefix = self.roi_name
    return {
        f"{prefix}_TImax_ROI": roi_max,
        f"{prefix}_TImean_ROI": roi_mean,
        f"{prefix}_TImean_GM": gm_mean,
        f"{prefix}_Focality": focality,
        f"{prefix}_n_elements": n_roi,
        "current_ch1_mA": current_ch1_mA,
        "current_ch2_mA": current_ch2_mA,
    }

run

run(e1_plus: list[str], e1_minus: list[str], e2_plus: list[str], e2_minus: list[str], current_ratios: list[tuple[float, float]], all_combinations: bool, output_dir: str) -> dict[str, dict[str, float]]

Run the full simulation loop. Returns {mesh_key: metrics}.

Source code in tit/opt/ex/engine.py
def run(
    self,
    e1_plus: list[str],
    e1_minus: list[str],
    e2_plus: list[str],
    e2_minus: list[str],
    current_ratios: list[tuple[float, float]],
    all_combinations: bool,
    output_dir: str,
) -> dict[str, dict[str, float]]:
    """Run the full simulation loop. Returns {mesh_key: metrics}.

    Every montage/current-ratio combination yielded by
    ``generate_montage_combinations`` is evaluated with
    ``compute_ti_field``, and progress is logged after each one.

    SIGINT/SIGTERM request a graceful stop. The loop exits before the next
    montage and the partial results are returned. The signal handlers that
    were active before the call are restored on exit, including when an
    exception is raised. When called outside the main thread, handlers
    cannot be installed (``signal.signal`` raises ValueError there). The
    search then runs without graceful interruption instead of failing.

    Args:
        e1_plus, e1_minus, e2_plus, e2_minus: Candidate electrodes for each
            channel position.
        current_ratios: ``(ch1_mA, ch2_mA)`` pairs to try per montage.
        all_combinations: Forwarded to the combination counter/generator.
        output_dir: Used only in the final log message.

    Returns:
        A mapping from ``"TI_field_<montage>.msh"`` keys to the metrics
        dict returned by ``compute_ti_field``.
    """
    stop = False

    def _on_signal(sig, frame):
        nonlocal stop
        stop = True

    # Remember the handlers we replace so the process-wide signal state is
    # restored afterwards. Otherwise Ctrl+C would stay hijacked after run().
    previous_handlers = {}
    for signum in (signal.SIGINT, signal.SIGTERM):
        try:
            previous_handlers[signum] = signal.signal(signum, _on_signal)
        except ValueError:
            # signal.signal() only works in the main thread of the main
            # interpreter (e.g. not inside a GUI worker thread).
            self.logger.debug(
                "Cannot install handler for signal %s; "
                "graceful interruption is disabled",
                signum,
            )

    try:
        total = count_combinations(
            e1_plus, e1_minus, e2_plus, e2_minus, current_ratios, all_combinations
        )
        self._log_config_summary(
            e1_plus,
            e1_minus,
            e2_plus,
            e2_minus,
            current_ratios,
            all_combinations,
            total,
        )

        results: dict[str, dict[str, float]] = {}
        start_time = time.time()

        for i, (ep1, em1, ep2, em2, (ch1, ch2)) in enumerate(
            generate_montage_combinations(
                e1_plus, e1_minus, e2_plus, e2_minus, current_ratios, all_combinations
            ),
            1,
        ):
            if stop:
                self.logger.warning("Interrupted")
                break

            name = f"{ep1}_{em1}_and_{ep2}_{em2}_I1-{ch1:.1f}mA_I2-{ch2:.1f}mA"
            key = f"TI_field_{name}.msh"

            # ``i`` is the montage about to be simulated, so only ``i - 1``
            # have finished. Basing the rate on ``i`` overstated throughput.
            done = i - 1
            elapsed = time.time() - start_time
            rate = done / elapsed if elapsed > 0 else 0
            eta = (total - done) / rate if rate > 0 else 0

            self.logger.info(f"[{i}/{total}] {name}")
            self.logger.info(
                f"  {100 * i / total:.1f}% | {rate:.2f}/s | ETA {eta / 60:.1f}min"
            )

            sim_start = time.time()
            data = self.compute_ti_field(ep1, em1, ch1, ep2, em2, ch2)
            results[key] = data

            self.logger.info(
                f"  {time.time() - sim_start:.2f}s | "
                f"TImax={data[f'{self.roi_name}_TImax_ROI']:.4f} "
                f"TImean={data[f'{self.roi_name}_TImean_ROI']:.4f} "
                f"Foc={data[f'{self.roi_name}_Focality']:.4f}"
            )

        if results:
            t = time.time() - start_time
            self.logger.info(f"\n{'=' * 60}")
            self.logger.info(
                f"Done: {len(results)}/{total} in {t / 60:.1f}min "
                f"({t / len(results):.2f}s each)"
            )
            self.logger.info(f"Output: {output_dir}")

        return results
    finally:
        for signum, handler in previous_handlers.items():
            # signal.signal() returns None when the old handler was not
            # installed from Python. Fall back to the default action then.
            signal.signal(signum, signal.SIG_DFL if handler is None else handler)

get_available_rois staticmethod

get_available_rois(subject_id: str) -> list[str]

List ROI CSV files for a subject.

Source code in tit/opt/ex/engine.py
@staticmethod
def get_available_rois(subject_id: str) -> list[str]:
    """List ROI CSV files for a subject.

    Returns:
        The file names (not full paths) of every ``*.csv`` file directly
        inside the subject's ROI directory, sorted alphabetically.
    """
    from tit.paths import get_path_manager

    rois_path = Path(get_path_manager().rois(subject_id))
    names = [csv_path.name for csv_path in rois_path.glob("*.csv")]
    names.sort()
    return names

create_roi staticmethod

create_roi(subject_id: str, roi_name: str, x: float, y: float, z: float) -> tuple[bool, str]

Create an ROI CSV from coordinates.

Source code in tit/opt/ex/engine.py
@staticmethod
def create_roi(
    subject_id: str,
    roi_name: str,
    x: float,
    y: float,
    z: float,
) -> tuple[bool, str]:
    """Create an ROI CSV from coordinates.

    Writes a single ``x,y,z`` row to ``<roi_dir>/<roi_name>.csv``, creating
    the directory if needed and overwriting any existing file of that name.
    The file name is then registered in ``roi_list.txt`` if it is not
    already listed.

    Args:
        subject_id: Subject whose ROI directory is used.
        roi_name: ROI file name; ``.csv`` is appended when missing.
        x, y, z: Coordinates written to the CSV row.

    Returns:
        ``(True, message)``. Filesystem errors are not caught and propagate
        to the caller.
    """
    from tit.paths import get_path_manager

    roi_dir = Path(get_path_manager().rois(subject_id))
    roi_dir.mkdir(parents=True, exist_ok=True)

    if not roi_name.endswith(".csv"):
        roi_name += ".csv"

    roi_file = roi_dir / roi_name
    with open(roi_file, "w", newline="") as f:
        csv.writer(f).writerow([x, y, z])

    roi_list = roi_dir / "roi_list.txt"
    content = roi_list.read_text() if roi_list.exists() else ""
    existing = {ln.strip() for ln in content.splitlines() if ln.strip()}
    if roi_name not in existing:
        # A hand-edited list may lack a trailing newline. Appending directly
        # would then glue the new name onto the last existing entry.
        separator = "\n" if content and not content.endswith("\n") else ""
        with open(roi_list, "a") as f:
            f.write(f"{separator}{roi_name}\n")

    return True, f"ROI '{roi_name}' created at ({x:.2f}, {y:.2f}, {z:.2f})"

delete_roi staticmethod

delete_roi(subject_id: str, roi_name: str) -> tuple[bool, str]

Delete an ROI file and remove from roi_list.txt.

Source code in tit/opt/ex/engine.py
@staticmethod
def delete_roi(subject_id: str, roi_name: str) -> tuple[bool, str]:
    """Delete an ROI file and remove from roi_list.txt."""
    from tit.paths import get_path_manager

    roi_dir = Path(get_path_manager().rois(subject_id))

    if not roi_name.endswith(".csv"):
        roi_name += ".csv"

    roi_file = roi_dir / roi_name
    if roi_file.exists():
        roi_file.unlink()

    roi_list = roi_dir / "roi_list.txt"
    if roi_list.exists():
        lines = [
            ln.strip() for ln in roi_list.read_text().splitlines() if ln.strip()
        ]
        if roi_name in lines:
            lines.remove(roi_name)
            roi_list.write_text(("\n".join(lines) + "\n") if lines else "")

    return True, f"ROI '{roi_name}' deleted"

get_roi_coordinates staticmethod

get_roi_coordinates(subject_id: str, roi_name: str) -> tuple[float, float, float] | None

Read ROI center coordinates from CSV.

Source code in tit/opt/ex/engine.py
@staticmethod
def get_roi_coordinates(
    subject_id: str,
    roi_name: str,
) -> tuple[float, float, float] | None:
    """Read ROI center coordinates from CSV.

    Returns the first three numbers of the first row that has at least
    three numeric values. Blank cells are ignored. Rows that cannot be
    parsed as numbers, such as an ``x,y,z`` header, are skipped instead of
    raising.

    Args:
        subject_id: Subject whose ROI directory is used.
        roi_name: ROI file name; ``.csv`` is appended when missing.

    Returns:
        ``(x, y, z)``, or ``None`` when the file does not exist or contains
        no usable row.
    """
    from tit.paths import get_path_manager

    roi_dir = Path(get_path_manager().rois(subject_id))

    if not roi_name.endswith(".csv"):
        roi_name += ".csv"

    roi_file = roi_dir / roi_name
    if not roi_file.exists():
        return None

    # The csv module requires files to be opened with newline="".
    with open(roi_file, newline="") as f:
        for row in csv.reader(f):
            cells = [v.strip() for v in row if v.strip()]
            try:
                coords = [float(v) for v in cells]
            except ValueError:
                # Non-numeric row (e.g. a header): try the next one.
                continue
            if len(coords) >= 3:
                return (coords[0], coords[1], coords[2])
    return None
run_ex_search(config: ExConfig) -> ExResult

Run exhaustive search from a typed config object.

Source code in tit/opt/ex/ex.py
def run_ex_search(config: ExConfig) -> ExResult:
    """Run exhaustive search from a typed config object.

    Steps:

    1. Attach a timestamped log file for the subject and create the run
       output directory.
    2. Resolve the ROI CSV and leadfield paths.
    3. Evaluate every montage / current-ratio combination with
       ``ExSearchEngine``.
    4. Save the results through ``process_and_save``.

    Args:
        config: Search configuration. With ``PoolElectrodes`` every channel
            position draws from the same list and all combinations are
            searched. Otherwise the per-position bucket lists are used.

    Returns:
        An ``ExResult`` with ``success=True``, the output directory, the
        number of evaluated combinations, and the CSV / config JSON paths
        reported by ``process_and_save``. Each path is ``None`` if that key
        is absent.
    """
    pm = get_path_manager(config.project_dir)

    # One timestamp for both the log file and the default run name, so
    # they always match.
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    logs_dir = pm.logs(config.subject_id)
    os.makedirs(logs_dir, exist_ok=True)
    log_file = os.path.join(logs_dir, f"ex_search_{timestamp}.log")
    logger_name = f"tit.opt.ex_search.{config.subject_id}"
    # NOTE(review): a file handler is attached on every call. Confirm that
    # add_file_handler de-duplicates; otherwise repeated runs for the same
    # subject in one process also write into earlier runs' log files.
    add_file_handler(log_file, logger_name=logger_name)
    logger = logging.getLogger(logger_name)

    logger.info(f"{'=' * 60}\nTI Exhaustive Search\n{'=' * 60}")
    logger.info(f"Project: {config.project_dir}")
    logger.info(f"Subject: {config.subject_id}")

    run_name = config.run_name or timestamp
    output_dir = pm.ex_search_run(config.subject_id, run_name)
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Output: {output_dir}")

    # ROI files are always stored as "<name>.csv" (the engine's ROI helpers
    # append the extension), so accept the name with or without it. The
    # engine still receives config.roi_name unchanged for its metric keys.
    roi_filename = (
        config.roi_name
        if config.roi_name.endswith(".csv")
        else f"{config.roi_name}.csv"
    )
    roi_file = os.path.join(pm.rois(config.subject_id), roi_filename)

    if isinstance(config.electrodes, ExConfig.PoolElectrodes):
        pool = config.electrodes.electrodes
        e1_plus = e1_minus = e2_plus = e2_minus = pool
        all_combinations = True
    else:
        e1_plus = config.electrodes.e1_plus
        e1_minus = config.electrodes.e1_minus
        e2_plus = config.electrodes.e2_plus
        e2_minus = config.electrodes.e2_minus
        all_combinations = False

    leadfield_path = os.path.join(pm.leadfields(config.subject_id), config.leadfield_hdf)

    engine = ExSearchEngine(leadfield_path, roi_file, config.roi_name, logger)
    engine.initialize(roi_radius=config.roi_radius)

    # A falsy channel_limit (None, or 0) falls back to total - step.
    ratios = generate_current_ratios(
        config.total_current,
        config.current_step,
        config.channel_limit or config.total_current - config.current_step,
    )

    logger.info(f"Generated {len(ratios)} current ratio combinations")

    results = engine.run(
        e1_plus, e1_minus, e2_plus, e2_minus, ratios, all_combinations, output_dir
    )

    output_info = process_and_save(results, config, output_dir, logger)
    csv_path = output_info.get("csv_path")
    config_json_path = output_info.get("config_json_path")
    logger.info(f"Config: {config_json_path}")
    logger.info(f"CSV: {csv_path}")

    return ExResult(
        success=True,
        output_dir=output_dir,
        n_combinations=len(results),
        results_csv=csv_path,
        config_json=config_json_path,
    )