Skip to content

opt

tit.opt

TI-Toolbox optimization package.

Public API

  • run_flex_search(config) -> FlexResult -- differential-evolution optimization
  • run_ex_search(config) -> ExResult -- exhaustive / grid search

FlexConfig dataclass

FlexConfig(subject_id: str, project_dir: str, goal: OptGoal, postproc: FieldPostproc, current_mA: float, electrode: ElectrodeConfig, roi: SphericalROI | AtlasROI | SubcorticalROI, anisotropy_type: str = 'scalar', aniso_maxratio: float = 10.0, aniso_maxcond: float = 2.0, non_roi_method: NonROIMethod | None = None, non_roi: SphericalROI | AtlasROI | SubcorticalROI | None = None, thresholds: str | None = None, eeg_net: str | None = None, enable_mapping: bool = False, disable_mapping_simulation: bool = False, output_folder: str | None = None, run_final_electrode_simulation: bool = False, n_multistart: int = 1, max_iterations: int | None = None, population_size: int | None = None, tolerance: float | None = None, mutation: str | None = None, recombination: float | None = None, cpus: int | None = None, detailed_results: bool = False, visualize_valid_skin_region: bool = False, skin_visualization_net: str | None = None)

Full configuration for flex-search optimization.

OptGoal

Bases: StrEnum

Optimization goal.

FieldPostproc

Bases: StrEnum

Field post-processing method.

NonROIMethod

Bases: StrEnum

Non-ROI specification method for focality optimization.

SphericalROI dataclass

SphericalROI(x: float, y: float, z: float, radius: float = 10.0, use_mni: bool = False)

Spherical region of interest defined by center + radius.

AtlasROI dataclass

AtlasROI(atlas_path: str, label: int, hemisphere: str = 'lh')

Cortical surface ROI from a FreeSurfer annotation atlas.

SubcorticalROI dataclass

SubcorticalROI(atlas_path: str, label: int, tissues: str = 'GM')

Subcortical volume ROI from a volumetric atlas.

ElectrodeConfig dataclass

ElectrodeConfig(shape: str = 'ellipse', dimensions: list[float] = field(default_factory=lambda: [8.0, 8.0]), gel_thickness: float = 4.0)

Electrode geometry for flex-search.

Only gel_thickness is needed here — the optimization leadfield uses point electrodes; gel_thickness is recorded in the manifest for downstream simulation.

FlexResult dataclass

FlexResult(success: bool, output_folder: str, function_values: list[float], best_value: float, best_run_index: int)

Result from a flex-search optimization run.

ExConfig dataclass

ExConfig(subject_id: str, project_dir: str, leadfield_hdf: str, roi_name: str, electrodes: BucketElectrodes | PoolElectrodes, total_current: float = 2.0, current_step: float = 0.5, channel_limit: float | None = None, roi_radius: float = 3.0, run_name: str | None = None)

Full configuration for exhaustive search optimization.

BucketElectrodes dataclass

BucketElectrodes(e1_plus: list[str], e1_minus: list[str], e2_plus: list[str], e2_minus: list[str])

Separate electrode lists for each bipolar channel position.

PoolElectrodes dataclass

PoolElectrodes(electrodes: list[str])

Single electrode pool — all positions draw from the same set.

ExResult dataclass

ExResult(success: bool, output_dir: str, n_combinations: int, results_csv: str | None = None, config_json: str | None = None)

Result from an exhaustive search run.

run_ex_search(config: ExConfig) -> ExResult

Run exhaustive search from a typed config object.

Source code in tit/opt/ex/ex.py
def run_ex_search(config: ExConfig) -> ExResult:
    """Run exhaustive search from a typed config object.

    Steps:
        1. Attach a timestamped log file under the subject's logs directory.
        2. Resolve the ROI file and leadfield through the project path manager.
        3. Initialize ``ExSearchEngine`` and evaluate every electrode /
           current-ratio combination.
        4. Persist the results via ``process_and_save``.

    Args:
        config: Exhaustive-search configuration. When ``config.electrodes``
            is a pool, the same electrode list is used for all four channel
            positions and ``all_combinations=True`` is passed to the engine.
            Otherwise, separate per-position (bucket) lists are used.

    Returns:
        ExResult holding the output directory, the number of evaluated
        combinations (``len(results)``), and the CSV / config-JSON paths
        reported by ``process_and_save``. A path is ``None`` when
        ``process_and_save`` did not report it.
    """
    pm = get_path_manager(config.project_dir)

    # Take a single timestamp for both the log file and the default run
    # name. Two separate strftime calls could straddle a second boundary
    # and yield mismatched names.
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    logs_dir = pm.logs(config.subject_id)
    os.makedirs(logs_dir, exist_ok=True)
    log_file = os.path.join(logs_dir, f"ex_search_{timestamp}.log")
    logger_name = f"tit.opt.ex_search.{config.subject_id}"
    # NOTE(review): the handler is never removed here. If add_file_handler
    # does not de-duplicate, repeated runs for the same subject will also
    # write into earlier log files. Confirm against add_file_handler.
    add_file_handler(log_file, logger_name=logger_name)
    logger = logging.getLogger(logger_name)

    logger.info(f"{'=' * 60}\nTI Exhaustive Search\n{'=' * 60}")
    logger.info(f"Project: {config.project_dir}")
    logger.info(f"Subject: {config.subject_id}")

    run_name = config.run_name or timestamp
    output_dir = pm.ex_search_run(config.subject_id, run_name)
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Output: {output_dir}")

    roi_file = os.path.join(pm.rois(config.subject_id), config.roi_name)

    if isinstance(config.electrodes, ExConfig.PoolElectrodes):
        # Pool mode: every position draws from the same electrode set.
        pool = config.electrodes.electrodes
        e1_plus = e1_minus = e2_plus = e2_minus = pool
        all_combinations = True
    else:
        # Bucket mode: each channel position has its own candidate list.
        e1_plus = config.electrodes.e1_plus
        e1_minus = config.electrodes.e1_minus
        e2_plus = config.electrodes.e2_plus
        e2_minus = config.electrodes.e2_minus
        all_combinations = False

    leadfield_path = os.path.join(pm.leadfields(config.subject_id), config.leadfield_hdf)

    engine = ExSearchEngine(leadfield_path, roi_file, config.roi_name, logger)
    engine.initialize(roi_radius=config.roi_radius)

    # Use `is not None` rather than `or`, so that an explicitly configured
    # limit of 0.0 is honored instead of being silently replaced by the
    # default.
    if config.channel_limit is not None:
        channel_limit = config.channel_limit
    else:
        channel_limit = config.total_current - config.current_step

    ratios = generate_current_ratios(
        config.total_current,
        config.current_step,
        channel_limit,
    )

    logger.info(f"Generated {len(ratios)} current ratio combinations")

    results = engine.run(
        e1_plus, e1_minus, e2_plus, e2_minus, ratios, all_combinations, output_dir
    )

    output_info = process_and_save(results, config, output_dir, logger)
    # Read both paths once with .get(), matching how they are exposed on
    # ExResult, so a missing key cannot raise during logging.
    csv_path = output_info.get("csv_path")
    config_json_path = output_info.get("config_json_path")
    logger.info(f"Config: {config_json_path}")
    logger.info(f"CSV: {csv_path}")

    return ExResult(
        success=True,
        output_dir=output_dir,
        n_combinations=len(results),
        results_csv=csv_path,
        config_json=config_json_path,
    )
run_flex_search(config: FlexConfig) -> FlexResult

Run flex-search optimization from a typed FlexConfig.

Source code in tit/opt/flex/flex.py
def _promote_run_folder(src_folder: str, dst_folder: str) -> None:
    """Copy every entry of *src_folder* into *dst_folder*.

    Sub-directories that already exist in *dst_folder* are deleted before
    being re-copied, so no stale files survive inside them. Plain files are
    copied with ``shutil.copy2``, overwriting any same-named file.
    """
    for item in os.listdir(src_folder):
        src = os.path.join(src_folder, item)
        dst = os.path.join(dst_folder, item)
        if os.path.isdir(src):
            if os.path.exists(dst):
                shutil.rmtree(dst)
            shutil.copytree(src, dst)
        else:
            shutil.copy2(src, dst)


def run_flex_search(config: FlexConfig) -> FlexResult:
    """Run flex-search optimization from a typed FlexConfig.

    Performs ``config.n_multistart`` independent optimization runs, each in
    a numbered sub-folder (``00``, ``01``, ...) of the base output folder.
    The run with the lowest valid objective value is then copied up into
    the base folder, the per-run sub-folders are removed, a report is
    generated, and a manifest is written.

    Args:
        config: Flex-search configuration. When ``config.output_folder`` is
            falsy, a new run directory is created under the subject's
            flex-search root.

    Returns:
        FlexResult. A run counts as valid when its objective value is
        ``< inf`` (NaN and ``inf`` are treated as failures). If no run is
        valid, ``success`` is False, ``best_value`` is ``inf``, and
        ``best_run_index`` is -1. In that case the per-run sub-folders are
        left in place. A manifest is written in both cases.
    """
    from .manifest import write_manifest
    from .utils import generate_label, generate_run_dirname

    pm = get_path_manager(config.project_dir)

    logger = logging.getLogger(__name__)

    n = config.n_multistart

    # Resolve base output folder
    if config.output_folder:
        base_folder = config.output_folder
    else:
        flex_root = pm.flex_search(config.subject_id)
        os.makedirs(flex_root, exist_ok=True)
        dirname = generate_run_dirname(flex_root)
        base_folder = os.path.join(flex_root, dirname)

    os.makedirs(base_folder, exist_ok=True)
    # Pre-filled with inf so any run that does not report a usable value
    # counts as failed.
    fvals = np.full(n, float("inf"))

    logger.info(
        f"Flex-search ({config.subject_id}): "
        f"goal={config.goal}, postproc={config.postproc}, runs={n}"
    )

    folders = [os.path.join(base_folder, f"{i:02d}") for i in range(n)]

    # -- Run optimizations --
    for i in range(n):
        opt = builder.build_optimization(config)
        opt.output_folder = folders[i]
        os.makedirs(opt.output_folder, exist_ok=True)
        builder.configure_optimizer_options(opt, config, logger)

        step = f"Run {i + 1}/{n}" if n > 1 else "Optimization"
        logger.info(f"├─ {step}: started")

        opt.run(cpus=config.cpus)
        fvals[i] = opt.optim_funvalue
        logger.info(f"├─ {step}: value={fvals[i]:.6f}")

    # -- Select best --
    # NaN compares False against inf, so NaN runs are excluded here as well.
    valid_mask = fvals < float("inf")
    if not valid_mask.any():
        logger.error("All optimization runs failed")
        result = FlexResult(
            success=False,
            output_folder=base_folder,
            function_values=fvals.tolist(),
            best_value=float("inf"),
            best_run_index=-1,
        )
    else:
        # Restrict argmin to valid runs. A plain np.argmin(fvals) returns
        # the index of the first NaN, which would promote a diverged run
        # as "best".
        best_idx = int(np.argmin(np.where(valid_mask, fvals, np.inf)))
        logger.info(f"Best run: #{best_idx + 1} (value={fvals[best_idx]:.6f})")

        # -- Promote best to base folder --
        _promote_run_folder(folders[best_idx], base_folder)

        # -- Cleanup temp subdirs --
        for folder in folders:
            if os.path.isdir(folder):
                shutil.rmtree(folder)

        # -- Report --
        builder.generate_report(config, n, fvals, best_idx, base_folder, logger)

        result = FlexResult(
            success=True,
            output_folder=base_folder,
            function_values=fvals.tolist(),
            best_value=float(fvals[best_idx]),
            best_run_index=best_idx,
        )

    # -- Write manifest (both success and failure paths) --
    label = generate_label(config)
    write_manifest(base_folder, config, result, label)

    return result