From fd7b10cf3ae170031ce7327588ce925138442106 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 9 Jun 2026 18:32:57 -0700 Subject: [PATCH 01/35] Add benchmark/autoresearch harness for the flattening stage (Phase A) Provenance-tracked, CPU-only benchmark for optimizing the pyflatten flattening stage, on the public Narratives dataset. Foundation pieces: - paths/ledger/metrics/harness: method-agnostic evaluate(flatten_fn, ...) with uniform per-patch metrics (mean/p90 % distance error, flipped triangles, area distortion) and an append-only JSONL provenance ledger (commit SHA, env, seeds, metrics, artifacts, repro command, decision trace). - build_dataset/run_baseline/report: build the manifest from cached Narratives patches, run the FreeSurfer-clone baseline (+ determinism check), render NOTEBOOK.md. - Project recovery skill (.claude/skills/autoflatten-autoresearch) so any session can recover state from the ledger and resume the loop. - PLAN.md (approved design) + README; bench extra in pyproject; unit tests. Code lives in the repo; all generated artifacts go to /data2/projects/autoflatten. See benchmark/PLAN.md. 
Co-Authored-By: Claude Fable 5 --- .../skills/autoflatten-autoresearch/SKILL.md | 82 +++++ benchmark/PLAN.md | 338 ++++++++++++++++++ benchmark/README.md | 66 ++++ benchmark/__init__.py | 11 + benchmark/build_dataset.py | 150 ++++++++ benchmark/harness.py | 165 +++++++++ benchmark/ledger.py | 191 ++++++++++ benchmark/metrics.py | 142 ++++++++ benchmark/paths.py | 45 +++ benchmark/report.py | 117 ++++++ benchmark/run_baseline.py | 151 ++++++++ benchmark/tests/__init__.py | 0 benchmark/tests/test_harness.py | 138 +++++++ pyproject.toml | 5 + 14 files changed, 1601 insertions(+) create mode 100644 .claude/skills/autoflatten-autoresearch/SKILL.md create mode 100644 benchmark/PLAN.md create mode 100644 benchmark/README.md create mode 100644 benchmark/__init__.py create mode 100644 benchmark/build_dataset.py create mode 100644 benchmark/harness.py create mode 100644 benchmark/ledger.py create mode 100644 benchmark/metrics.py create mode 100644 benchmark/paths.py create mode 100644 benchmark/report.py create mode 100644 benchmark/run_baseline.py create mode 100644 benchmark/tests/__init__.py create mode 100644 benchmark/tests/test_harness.py diff --git a/.claude/skills/autoflatten-autoresearch/SKILL.md b/.claude/skills/autoflatten-autoresearch/SKILL.md new file mode 100644 index 0000000..cafabf4 --- /dev/null +++ b/.claude/skills/autoflatten-autoresearch/SKILL.md @@ -0,0 +1,82 @@ +--- +name: autoflatten-autoresearch +description: >- + Resume the AutoFlatten flattening-optimization autoresearch effort. Use when the user + says things like "resume the autoflatten autoresearch", "continue optimizing the + flattening", "what have we tried so far", "run the flatten benchmark", or asks to + propose/test a new flattening method or config. Orients a fresh session to the plan, + the provenance ledger, the data, and the conventions, and drives the experiment loop. 
+--- + +# AutoFlatten autoresearch + +A benchmark-driven effort to improve the **pyflatten flattening stage** (and ultimately +replace its FreeSurfer-clone optimizer with a principled method), with every experiment +logged so the *process* is shareable in a paper. + +## 0. Recover state FIRST (before proposing anything) + +1. Read the design + status: `benchmark/PLAN.md` and `benchmark/README.md` (in this repo). +2. Read the ledger to see everything tried so far, the current baseline, and the last + planned step: + - Ledger: `/data2/projects/autoflatten/ledger/experiments.jsonl` (append-only JSONL). + - Rendered notebook: `/data2/projects/autoflatten/NOTEBOOK.md` + (regenerate with `python -m benchmark.report`). + - Look at the most recent records' `metrics` (the multi-objective vector) and any + `decision.next_step`. **Do not re-run experiments already in the ledger.** +3. Confirm the manifest exists: `/data2/projects/autoflatten/manifest.json` + (else `python -m benchmark.build_dataset`). + +## 1. Environment & constraints + +- Use the repo's **uv** venv: `uv sync --extra bench`, then `.venv/bin/python`. +- **CPU only.** GPUs are blocked by driver 440 / CUDA 10.2 — do **not** install/attempt + CUDA jax unless the user says the driver was upgraded (≥525). +- Dev on a **tiny subset** (`--dev`, 2 subjects / a few hemispheres); scale to the full + 82 only once a change looks good. A single hemisphere takes ~8–15 min on CPU. +- **Output locations are pinned:** code + docs in the repo; *all generated artifacts* go + to `/data2/projects/autoflatten/` (see `benchmark/paths.py`, + override `AUTOFLATTEN_BENCH_ROOT`). Never write generated files into the repo. +- Data is the public **Narratives** set (OpenNeuro `ds002345`); never the private + `all-subjects/` lab data. + +## 2. Core objects + +- `benchmark/harness.py` — `evaluate(entries, config, method=...)`; geometry + + k-ring caching; `register_flatten_fn(name, fn)` to add a method. 
A `flatten_fn` has + signature `flatten_fn(flattener) -> uv` (shape `(V, 2)`), scored uniformly. +- `benchmark/metrics.py` — `per_patch_metrics(uv, flattener)` and `aggregate(...)`. + Objective stays **geodesic distance distortion** (+ flips, robustness, runtime) — never + switch to pure conformality. +- `benchmark/ledger.py` — `new_record(kind, label, ...)` + `Ledger().append(record)`. + +## 3. The experiment loop + +For each idea: + +1. **State a hypothesis** (what you expect to improve and why). +2. **Implement** a `flatten_fn` (or a `FlattenConfig` change). The primary open lead is a + **Tutte/LSCM flip-free init** that removes the ~4-min initial NAR phase — see + `benchmark/probe_tutte_init.py` (Phase B in the plan). +3. **Evaluate on the dev subset**, then promote to the full train split if promising: + ```bash + python -m benchmark.run_baseline --dev # reference + # ... your probe/optimize script, same evaluate() harness ... + ``` +4. **Append a ledger record** with `kind`, `label`, `method`, `metrics`, `per_subject`, + `repro_command`, and a **decision trace** in `record.decision`: + `{hypothesis, rationale, conclusion, next_step}`. +5. **Commit** the code change in this repo (the ledger pins the commit SHA), then + `python -m benchmark.report` to refresh `NOTEBOOK.md`. +6. If it **Pareto-beats** the baseline (lower distortion and/or fewer flips and/or faster, + none worse), promote to the **full 82-subject** run and validate on the **holdout** + split before claiming a win. + +## 4. Guardrails + +- Every experiment is pinned to a git commit; commit before/after running. +- Determinism is assumed (one run/experiment) but **re-assert it** when you change the + optimizer (`run_baseline.py --check-determinism`). +- Never silently drop subjects — `evaluate` records per-subject `status="error"`; surface + failures in the conclusion. +- Keep the ledger append-only; to revise a result, append a new record. 
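The append-only guardrail and the decision trace from step 4 can be illustrated with a minimal sketch. This is not the implementation in `benchmark/ledger.py`: the helpers `append_record` and `read_records` are hypothetical names, the record fields only mirror the list in step 4, and the values are placeholders. The demo writes to a temporary directory rather than the real ledger.

```python
import json
import os
import tempfile
from datetime import datetime, timezone


def append_record(ledger_path, record):
    """Append one record as a single JSON line; earlier lines are never rewritten."""
    os.makedirs(os.path.dirname(ledger_path) or ".", exist_ok=True)
    with open(ledger_path, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, sort_keys=True) + "\n")


def read_records(ledger_path):
    """Read every record back in the order it was appended."""
    with open(ledger_path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


# Hypothetical record: field names follow step 4, all values are placeholders.
example = {
    "kind": "probe",
    "label": "example-only",
    "method": "hypothetical_method",
    "timestamp": datetime.now(timezone.utc).isoformat(),
    "metrics": {},
    "per_subject": [],
    "repro_command": "python -m benchmark.run_baseline --dev",
    "decision": {
        "hypothesis": "...",
        "rationale": "...",
        "conclusion": "...",
        "next_step": "...",
    },
}

# Demo in a temporary directory so the real ledger is never touched.
demo_path = os.path.join(tempfile.mkdtemp(), "ledger", "experiments.jsonl")
append_record(demo_path, example)
# A revision is a new record appended after the old one, not an in-place edit.
append_record(demo_path, dict(example, label="revision-of-example-only"))
records = read_records(demo_path)
```

Reading the last record's `decision["next_step"]` is then enough to see where a previous session stopped, which is the property the recovery steps in section 0 rely on.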
diff --git a/benchmark/PLAN.md b/benchmark/PLAN.md new file mode 100644 index 0000000..7027da3 --- /dev/null +++ b/benchmark/PLAN.md @@ -0,0 +1,338 @@ +# AutoFlatten Optimization Harness — Plan + +## Context + +AutoFlatten works and is already in use, but before the paper we want to know whether the +flattening defaults can be improved and to have a reproducible benchmark backing any claim. +The goal is a **benchmark-driven optimization harness** for the **pyflatten flattening stage** +that (1) measures quality across a fixed set of subjects, (2) runs a classical hyperparameter +search over the existing config knobs to find better defaults, and (3) provides the shared +evaluation interface that a later LLM-driven "autoresearch" loop will reuse to test algorithmic +ideas. Decisions confirmed with the user: + +- **Objective:** multi-objective — lower metric distortion, fewer flipped triangles, robustness + across subjects, and faster runtime (in that rough priority). +- **Scope:** flattening knobs only. Projection/cut-mapping is run **once per subject** and the + resulting patch files are cached; the search never re-runs FreeSurfer. +- **Approach:** staged — build the benchmark harness + baseline first, then a **Tutte/LSCM + flip-free-init probe as the primary experiment**; classical HPO and a broader method bake-off are + optional follow-ons. A later LLM-driven autoresearch loop reuses the same harness. +- **Compute:** **CPU-only.** Two consumer GPUs exist but the driver (440.64 / CUDA 10.2) is too old + for any JAX autoflatten supports; GPU is deferred unless a sysadmin upgrades the driver. The + biggest speed win (Tutte init) is algorithmic and needs no GPU. +- **Replicability:** the benchmark uses the **public Narratives dataset** (OpenNeuro `ds002345`), and + every experiment is logged to a git-tracked provenance ledger so the whole process is shareable in + the paper. 
+- **User note (important):** the current optimizer is "very bespoke and based on FreeSurfer" — + it is a faithful clone of `mris_flatten` (normal-projection init + hand-tuned 3-epoch gradient + descent + negative-area-removal passes). Tuning its knobs has a low ceiling; the larger win is + likely **replacing the optimizer with a principled method**. The benchmark harness is the + prerequisite that lets us compare *any* method, so it is still the right first build, but the + harness must be **method-agnostic**, not tied to `FlattenConfig`. + +## Reframing: the optimizer is a FreeSurfer clone — principled alternatives + +The current `run()` reproduces FreeSurfer's heuristics. The goal AutoFlatten actually wants — +**preserve geodesic distances** so cortical flatmaps stay metrically readable — is precisely +*metric MDS / stress minimization* with an injectivity (no-flip) constraint. Modern geometry +processing solves this far better than the 1990s line-search heuristic. Candidate replacements, +several available through `libigl` (already a dependency — verify `igl` build exposes them): + +- **Flip-free init:** Tutte/convex embedding or LSCM → guaranteed-injective starting map; removes + the entire reason the initial negative-area-removal phase exists. +- **Injective distortion minimizers:** SLIM, AKVF, progressive parameterization — minimize a + distortion energy *while guaranteeing no flips*, instead of penalizing flips after the fact. +- **Stress majorization (SMACOF):** directly minimizes geodesic-distance stress — the actual + scientific objective — more robustly than the bespoke line search. +- **Better solver, same JAX:** L-BFGS/optax (jaxopt) on a symmetric-Dirichlet + geodesic-stress + energy, reusing the existing k-ring geodesic targets. + +**Scientific constraint:** the objective must remain geodesic **distance** distortion (not pure +conformality), or flatmaps lose their metric interpretation. 
Alternatives change the *method*, not +the *objective* — the harness measures the same distance-distortion metric for all of them. + +## Key design insight + +The two phases have very different costs. Projection needs FreeSurfer and is slow but only needs +to run **once per subject**. Flattening (JAX) is what we re-run per config. On top of that, the +k-ring distance computation is the expensive part of flattening and already has a disk cache keyed +only on `k_ring`/`n_neighbors` (`get_kring_cache_filename`, `compute_kring_distances(cache_path=...)`). +So: cache patches once, cache k-ring distances per `k_ring`, and each trial only re-runs the cheap +gradient-descent optimization. Metrics come straight from the returned `uv` array — no log parsing. + +## What we build + +New tooling directory `benchmark/` in the repo (package code untouched). + +**Output locations (pinned — nothing written elsewhere):** +- **The repo holds the benchmark CODE + markdown docs/instructions — but no loop-generated files.** + Tooling (`benchmark/*.py`), `PLAN.md`, `README.md`, and the `.claude` skill are version-controlled + with the package; nothing the loop *generates* lands in the repo. +- **ALL artifacts generated by the autoresearch loop → `/data2/projects/autoflatten/`** + (user-designated; exists, empty, writable). This includes *both* text and binary outputs: the + JSONL ledger (`experiments.jsonl`), the generated `NOTEBOOK.md`, dataset `manifest.json`, + candidate config JSONs, per-run flat patches, k-ring distance caches, PNGs, snapshots, and the + Optuna study DB. Suggested layout: + `/data2/projects/autoflatten/{ledger/experiments.jsonl, NOTEBOOK.md, manifest.json, runs//, + kring_cache/, configs/, optuna.db}`. +- **Traceability of the data folder:** make `/data2/projects/autoflatten/` its own + **git (+ datalad/git-annex for binaries) repo** so the ledger and outputs are version-pinned there, + independent of the autoflatten package repo. 
Each ledger record still stores the producing + autoflatten **commit SHA** + content hashes, linking outputs back to the code that made them. + +### Performance findings (measured) — why speed comes early + +- **A single hemisphere takes ~8–15 min, and the FreeSurfer NAR heuristic dominates it.** In a + sample log (8m21s total): initial negative-area-removal **246.9s** + final NAR **140.9s** = ~6.5 + min just fighting flipped triangles; the 3 main epochs + spring smoothing are the rest. +- **JAX is running on CPU.** The project uses a **uv-managed `.venv`** (uv 0.7.19, `uv.lock` + present) with jax 0.6.2 at `default_backend = cpu`, while two GPUs sit idle (**TITAN RTX 24 GB** + + GTX 980). `pyproject.toml` already depends on `libigl>=2.5.0` (good — the Tutte/LSCM probe needs + it), `jax>=0.4.0` (CPU), `numba`. +- **GPU is blocked by an old driver (decision needed).** Driver **440.64.00 → CUDA 10.2 max**. + autoflatten needs `jax>=0.4`, but every JAX since ~0.3 requires CUDA 11+ (jax 0.6 → CUDA 12). + CUDA 11 needs driver ≥450, CUDA 12 needs ≥525; NVIDIA forward-compat is datacenter-GPU only, so + it can't help these consumer cards. **No modern-jax + CUDA combo runs on this driver without an + upgrade.** Resolution tracked as an open decision (driver upgrade vs CPU-only vs another machine); + see "Open decision: GPU" below. Until resolved, the plan proceeds CPU-only — and the biggest + runtime win (Tutte init removing the ~4-min initial NAR) is CPU-agnostic anyway. +- **Speed and the "bespoke FreeSurfer clone" problem are the same problem.** The Tutte/LSCM + flip-free init (our §6 probe) removes the initial NAR phase (~4 min) *and* improves quality — one + fix, both objectives. + +### Chosen execution order (small-first; per user) + +Guardrail: a tiny harness + 2–3-hemisphere baseline is built **before** any speed work, so every +acceleration can be proven not to change the flatmap (equivalence check). 
Dev runs on **2–3 +hemispheres** throughout; only Phase C runs all 82. + +**Phase A — Foundation (cheap, ~30 min of compute):** +- A0. **Persist this plan into the repo first (traceability):** copy to `benchmark/PLAN.md` on a + feature branch and commit, so the design is git-tracked alongside the ledger. +- A0b. **Write the project recovery skill (§7)** at `.claude/skills/autoflatten-autoresearch/SKILL.md` + so any new session can recover state and resume the loop. Created early and updated as conventions + settle. +- A1. Experiment **ledger / provenance layer (§0)** — built first; every run writes to it. +- A2. **Method-agnostic harness + dataset manifest (§1–2)** over a 2–3 hemisphere dev subset. +- A3. **Tiny baseline (§3)** on those 2–3 hemispheres (current CPU config) → reference quality + + runtime; validate metric code against v6's existing flat patches. + +**Phase B — Speed (CPU-only; verified against the A3 references).** GPU is off the table on this +driver (decision: CPU-only — see below), so speed comes from algorithm + CPU-side work, each change +asserted not to alter the flatmap on the 2–3 hemispheres: +- B1. **Tutte/LSCM flip-free init (§6 probe) — biggest single win.** Removes the ~4-min initial NAR + phase outright and is simultaneously the quality probe. Verify quality + record speedup. +- B2. **Profile the remainder** and fix top offenders without altering output: k-ring Numba Dijkstra + (CPU — parallelize/cache), JIT recompiles from changing array shapes (pad/bucket shapes for stable + JIT), wasted host↔device copies, and `float64`→`float32` where it doesn't hurt distortion. +- B3. **Parallelize across hemispheres/subjects.** Many CPU cores are available; the two-phase + caching (patches + k-ring distances) means trials are embarrassingly parallel. Reuse the existing + `--parallel` / `n_jobs` paths. + +**Open decision: GPU — RESOLVED (CPU-only for now).** Driver 440/CUDA 10.2 blocks modern jax+CUDA on +these consumer GPUs. 
Revisit only if a sysadmin upgrades the driver to ≥525 (then a uv `cuda` extra +with `jax[cuda12]`, pinned to the TITAN RTX, becomes a drop-in speedup). Not on the critical path. + +**Phase C — Scale & search:** +- C1. Run the **full 82-subject baseline** (now fast) → the real reference benchmark, all logged. +- C2. The §6 alt-method bake-off and optional §4 HPO, head-to-head on the full set. + +### 0. Experiment ledger & provenance — `benchmark/ledger.py` (foundational) + +Goal: anyone (a reviewer, a co-author, future-you) can see *every* experiment that was run, why it +was run, what it produced, and reproduce it from one command. Grounded in current reproducibility / +agentic-science best practice (see Sources below): log config + code version + environment + seeds + +metrics + artifacts, and — for the autoresearch loop — the **decision trace** (hypothesis → +rationale → result → conclusion → next step), because the agent's reasoning is what makes the +*process* auditable, not just the numbers. + +- **Append-only JSONL ledger** as the single source of truth + (`/data2/projects/autoflatten/ledger/experiments.jsonl`), one record per experiment run. Plain + text → diffable, version-tracked in the data-folder git/datalad repo, trivially shareable as a + paper supplement. No past record is ever mutated. (`benchmark/ledger.py` is the writer code in the + package repo; the ledger *file* it writes lives in the data folder.) +- Each record captures: + - **Identity/provenance:** experiment id, ISO timestamp, **git commit SHA** the run was pinned to, + a flag + saved `git diff` if the tree was dirty, hostname/GPU, and a captured environment + (`pip freeze` / `jax.__version__` / CUDA). + - **Inputs:** manifest id + subject list, the exact method spec — `FlattenConfig.to_dict()` for the + pyflatten backend, or `{method_name, params}` for an alternative `flatten_fn` — and patch file + hashes. Plus the **random seed(s)** used. 
+ - **Outputs:** all per-subject and aggregate metrics, artifact paths (flat patches, PNGs, + snapshots), runtime. + - **Repro command:** the exact CLI string to re-run this experiment. + - **Decision trace (autoresearch only):** `hypothesis`, `rationale`, the applied `code_diff` (or + its commit), `conclusion`, `next_step`. +- **Determinism first:** record seeds; the §3 determinism check asserts identical metrics on rerun + so a single run per experiment is defensible. +- **Artifact provenance:** large binaries live in `/data2/projects/autoflatten/runs//`; the + ledger records each artifact's **absolute path + content hash + producing commit SHA**, so outputs + stay traceable without bloating git. (Optional: make that tree a datalad/git-annex dataset later if + we want versioned artifact history. Any small reference figure that lands *in* the repo follows the + existing git-annex convention.) +- **Auto-generated lab notebook:** `benchmark/report.py` (repo code) renders the JSONL into a + human-readable `NOTEBOOK.md` written to `/data2/projects/autoflatten/` (with comparison + tables/plots) — the inspectable narrative to drop into the paper / supplement. Each entry links its + commit SHA and repro command. +- **Tooling stance:** keep the file-based ledger as the transparent source of truth (maximally + reviewer-friendly); *optionally* mirror runs to a local **MLflow** file store for a browsing UI, + but the paper artifact is the git-tracked JSONL + `NOTEBOOK.md`, not a SaaS dashboard. +- The HPO study DB (Optuna, §4) and this ledger stay consistent: each Optuna trial also emits a + ledger record. + +### 1. Benchmark dataset builder — `benchmark/build_dataset.py` +- **Use the public Narratives dataset for full replicability** (OpenNeuro `ds002345`, Nastase et + al.). The private lab set in `/data2/freesurfer_subjects/all-subjects/` (initials-based IDs like + `AHfs`, `MVauto`) is **not** publicly shareable, so it is dropped from the benchmark. 
+- **Everything needed is already on disk, git-annex tracked — no fetching, no FreeSurfer** (verified): + `/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer/` has **82 + subjects** with both the existing projection output `{hemi}.autoflatten.patch.3d` *and* + materialized base surfaces (`{hemi}.fiducial`/`smoothwm`/`white`/`sphere.reg`). 164 patch files = + the "~80 participants already run". Patches are annex objects → already provenance-pinned. +- **Reuse those patch files** as the cached patches (sidesteps projection, which needs FreeSurfer — + absent here: `FREESURFER_HOME` unset, `mri_label2label` missing). Pair each with + `/surf/{hemi}.fiducial` (fallback `smoothwm`) from the same tree. +- Select a fixed, documented subset (~16–20 of the 82), **split into train/holdout** so tuning + can't overfit. Write `manifest.json` (to `/data2/projects/autoflatten/`): + `{subject, hemi, patch_path, surface_path, split}`, and record the dataset DOI + datalad commit so + the manifest is itself reproducible. +- **Full start-to-finish replication path is public** even though we skip projection locally: a + reader gets the surfaces from OpenNeuro/datalad, installs FreeSurfer + autoflatten, and runs + projection→flatten. Document this in `benchmark/README.md`. + +### 2. Evaluation harness — `benchmark/evaluate.py` (the shared core, method-agnostic) +- Core entry point is **method-agnostic**: `evaluate(flatten_fn, manifest, subset=None) -> dict`, + where `flatten_fn(patch_path, surface_path) -> (uv, faces, orig_indices)`. The current pipeline + is *one* `flatten_fn` (a thin wrapper around `SurfaceFlattener`); Tutte/LSCM/SLIM/SMACOF + prototypes are other `flatten_fn`s evaluated by the *same* metric code. A convenience + `evaluate_config(config: FlattenConfig, ...)` wraps the pyflatten backend for Stage 1 HPO. 
+- Metrics are computed from `uv` + the patch's geodesic targets **independently of how `uv` was + produced**, so all methods are scored identically. +- The pyflatten `flatten_fn` runs **via the programmatic API** (not CLI subprocess) for speed and + direct metric access: + - `SurfaceFlattener(config)` → `load_data(patch, surface)` → + `compute_kring_distances(cache_path=...)` (reuse cache) → `prepare_optimization()` → `run()`. + - Compute metrics directly from the resulting `uv` and the flattener's internal + neighbor/target arrays — reuse `count_flipped_triangles` and the distance-error routine in + `autoflatten/flatten/algorithm.py`, and `viz.compute_kring_distortion` for the per-vertex + distortion distribution. Record per subject×hemi: mean % distance error, p90 distortion, + flipped-triangle count, area distortion, runtime. +- Aggregate across subjects into the multi-objective vector: + `mean_distortion`, `worst_subject_distortion` (robustness), `total_flipped` / + `frac_subjects_with_flips`, `mean_runtime`. +- Persist every trial's `FlattenConfig.to_dict()` + metrics to a results store (JSONL + Optuna DB). +- **Multi-fidelity:** accept a `subset` of subjects for cheap screening; full set for promising + configs (used by the pruner below). +- Reuse `parse_log_file` (`autoflatten/viz.py`) only as a fallback. + +### 3. Baseline + determinism check — `benchmark/run_baseline.py` +- Confirm flattening is deterministic (run defaults twice, assert metrics match) so one run per + trial suffices. +- Record current-default metrics across the full benchmark as the reference line (also a paper + table). Optionally also run the `freesurfer` backend for an "as good as / better than FreeSurfer" + comparison figure. + +### 4. HPO driver — `benchmark/optimize.py` (OPTIONAL follow-on, not the first build) +- **Optuna** multi-objective study (NSGA-II), SQLite storage at + `/data2/projects/autoflatten/optuna.db` for resumability. 
Search space over the documented knobs in `autoflatten/flatten/config.py`: + - Per-phase `l_dist` / `l_nlarea` (epochs 1–3), `iters_per_level`, smoothing-schedule shape. + - `initial_scale` (1.0–5.0), NAR `l_dist_ratios` / `base_averages`, final-NAR `base_tol`, + spring smoothing `n_iterations`/`dt`, convergence `base_tol`, line-search `n_coarse_steps`. + - `k_ring` / `n_neighbors_per_ring`: treat as an **outer loop / grouped** so the distance + cache stays valid within a group (cache is keyed on these two only). +- Successive-halving pruner: screen on the train subset, promote survivors to the full train set. +- Outputs: Pareto front, per-parameter importances, and 1–3 candidate configs that beat baseline, + each saved as a loadable JSON config. + +### 5. Validation + report — `benchmark/report.py` +- Re-run candidate configs on the **holdout** subjects to check generalization. +- Emit a comparison table/plot (baseline vs tuned vs FreeSurfer) — a paper-ready artifact. + +### 6. Alt-method probe — `benchmark/probe_tutte_init.py` (PRIMARY experiment) +- A new `flatten_fn` that replaces the FreeSurfer normal-projection + initial-NAR init with a + **guaranteed-injective Tutte/LSCM embedding**, then runs the existing geodesic-stress refinement + (epochs 2–3) on top. Scored by `evaluate(...)` against the §3 baseline on the same subjects. + - Tutte embedding: map the patch boundary to a convex polygon, solve the harmonic/Tutte linear + system for interior vertices (guaranteed flip-free for disk topology — which patches are). + LSCM is the conformal alternative; both available via `libigl` if the installed `igl` exposes + them (verify; else Tutte is a small SciPy sparse solve). + - Hypothesis: a flip-free init removes the entire initial-NAR phase and reaches equal/lower + distance distortion with **zero flips** and less runtime. +- Compare on: mean/worst-subject distance distortion, total flips, runtime. Reuse the §5 report. 
+- If promising, generalize into a clean backend and (optionally) drive further alternatives + (SLIM, SMACOF, L-BFGS/optax) — by hand or via an LLM **autoresearch loop** (Workflow tool, agents + proposing `flatten_fn`s on isolated worktrees). Deferred; specced in `benchmark/README.md`. + +### 7. Project recovery skill — `.claude/skills/autoflatten-autoresearch/SKILL.md` + +A git-tracked **project skill** so any fresh session (or collaborator) can recover the full state of +this autoresearch effort and resume without re-deriving anything. Lives in the repo (`.claude/` is +tracked, not gitignored). The skill is the single entry point that ties together the plan, ledger, +data, and conventions. + +- **Trigger / description:** invoked when the user says things like "resume the autoflatten + autoresearch", "continue optimizing the flattening", "what have we tried so far", or runs the + benchmark loop. +- **Contents (instructions the skill encodes):** + - **Orientation:** point to `benchmark/PLAN.md` + `benchmark/*.py` (plan + code, in the repo) and + the canonical generated paths in the data folder — ledger + `/data2/projects/autoflatten/ledger/experiments.jsonl`, notebook + `/data2/projects/autoflatten/NOTEBOOK.md`, dataset `/data2/projects/autoflatten/manifest.json`, + run artifacts under `/data2/projects/autoflatten/runs/`. + - **Recover state first:** read the ledger + `NOTEBOOK.md` to see the current baseline, what + methods/configs were tried, their metrics, and the last `next_step` in the decision trace — + *before* proposing anything new. + - **Environment:** use the repo's uv `.venv` (`uv sync --extra bench`); **CPU-only** (GPU blocked + by driver 440 — don't attempt CUDA); start dev on 2–3 hemispheres, scale to all 82 only for full + runs. 
+ - **The loop:** state a hypothesis → implement a `flatten_fn` or config change → run + `evaluate(...)` on the dev subset → **append a ledger record** (with the decision trace: + hypothesis/rationale/conclusion/next_step) → commit → regenerate `NOTEBOOK.md` → if it + Pareto-beats baseline, promote to the full 82-subject run. + - **Conventions / guardrails:** every experiment pinned to a git commit; determinism assumed + (single run) but re-asserted when the optimizer changes; objective stays geodesic *distance* + distortion; never write outside the two pinned output locations. + - **Commands cheat-sheet:** how to run `build_dataset.py`, `run_baseline.py`, `probe_tutte_init.py`, + `evaluate`, and the optional `optimize.py`; how to read/append the ledger; how to render the + notebook. + +## Critical files +- Reuse (read/import, do not modify): `autoflatten/flatten/algorithm.py` (`SurfaceFlattener`, + `count_flipped_triangles`, distance-error routine), `autoflatten/flatten/config.py` + (`FlattenConfig.to_dict/from_dict/from_json_file`, `get_kring_cache_filename`), + `autoflatten/backends/pyflatten.py` (reference for the call sequence), `autoflatten/viz.py` + (`compute_kring_distortion`, `parse_log_file`), `run_all.sh` (projection batch pattern). +- New: `.claude/skills/autoflatten-autoresearch/SKILL.md` (project recovery skill, §7), + `benchmark/ledger.py` (provenance/JSONL ledger, written by every run), + `benchmark/build_dataset.py`, `benchmark/evaluate.py`, `benchmark/run_baseline.py`, + `benchmark/probe_tutte_init.py`, `benchmark/optimize.py` (optional), `benchmark/report.py` + (renders `NOTEBOOK.md` + comparison tables), `benchmark/README.md`, plus a **uv** optional-dependency + extra `bench` (optuna, mlflow optional) added to `pyproject.toml` and locked in `uv.lock` for + reproducibility; everything runs in the repo's uv `.venv` (`uv sync --extra bench`). Large + artifacts via git-annex. + +## Verification +1. 
`benchmark/build_dataset.py` produces a `manifest.json` and cached patches for the chosen + subjects (spot-check one with `autoflatten plot-projection`). +2. `run_baseline.py`: determinism assertion passes; baseline metrics match prior + `test-autoflatten-v6` logs within tolerance. +3. `evaluate_config(default_config)` reproduces the baseline numbers from the harness's own + metric code (cross-check against a `parse_log_file` of a normal CLI run on one subject). +4. **Probe:** `probe_tutte_init.py` produces a flip-free initial embedding (assert 0 flipped + triangles at init) on a sample patch, runs the geodesic refinement, and `evaluate(...)` scores + it against the baseline. Success signal: equal-or-lower distance distortion with fewer flips + and lower runtime on the train subjects; confirm on holdout via `report.py`. +5. **Ledger:** every run above appends a record to `experiments.jsonl` with commit SHA, env, seeds, + config/method, metrics, artifact paths, and repro command; `report.py` regenerates `NOTEBOOK.md` + from it. Spot-check: pick a past record, run its repro command, confirm metrics reproduce. +6. (Optional follow-on) `optimize.py` smoke run (a few trials on 2–3 subjects) writes the Optuna + DB and a candidate config JSON that `FlattenConfig.from_json_file` loads, each trial also logged + to the ledger. 
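
The init check in verification step 4 ("assert 0 flipped triangles at init") can be sketched in plain Python. This is an illustration only, not the package's `count_flipped_triangles`: the helper names `signed_area` and `count_flipped` are hypothetical, and the sketch assumes every face is listed counter-clockwise in the intended orientation, so a strictly negative signed area in `uv` means the triangle has flipped. The real routine may use a different orientation convention.

```python
def signed_area(p0, p1, p2):
    """Signed area of a 2D triangle; positive when p0, p1, p2 run counter-clockwise."""
    return 0.5 * (
        (p1[0] - p0[0]) * (p2[1] - p0[1]) - (p2[0] - p0[0]) * (p1[1] - p0[1])
    )


def count_flipped(uv, faces):
    """Count faces whose signed area in the flat map is strictly negative.

    Assumption: faces are wound counter-clockwise in the intended orientation.
    """
    return sum(
        1 for a, b, c in faces if signed_area(uv[a], uv[b], uv[c]) < 0.0
    )


# Toy patch: a unit square split into two counter-clockwise triangles.
faces = [(0, 1, 2), (0, 2, 3)]
uv = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]

# Same connectivity, but vertex 3 is dragged across the diagonal,
# which folds the second triangle over the first.
folded_uv = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (2.0, 0.5)]

flips_ok = count_flipped(uv, faces)
flips_folded = count_flipped(folded_uv, faces)
```

In the probe itself the same assertion would run on the initial embedding before any refinement, so a failure points at the init rather than at the optimizer.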
+ +## Sources (best practices behind the ledger design) +- [Provenance Tracking in Large-Scale ML Systems](https://arxiv.org/html/2507.01075v1) +- [ML Pipelines: Provenance, Reproducibility and FAIR Data Principles](https://arxiv.org/pdf/2006.12117) +- [Versioning, Provenance, and Reproducibility (CMU MLiP book)](https://mlip-cmu.github.io/book/24-versioning-provenance-and-reproducibility.html) +- [From AI for Science to Agentic Science: A Survey on Autonomous Scientific Discovery](https://arxiv.org/html/2508.14111v1) +- [The AI Scientist (Sakana AI)](https://sakana.ai/ai-scientist/) — stores all executed files per experiment +- [Experiment-as-Code Labs: A Declarative Stack for AI-Driven Scientific Discovery](https://arxiv.org/html/2605.04375v2) diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 0000000..ba5f921 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,66 @@ +# AutoFlatten benchmark / autoresearch harness + +Reproducible, provenance-tracked benchmark for optimizing the **pyflatten flattening +stage**. The full design and rationale live in [`PLAN.md`](PLAN.md); this README is the +operational quick-start. Session-recovery instructions for an agent live in +[`.claude/skills/autoflatten-autoresearch/SKILL.md`](../.claude/skills/autoflatten-autoresearch/SKILL.md). + +## Layout (where things go) + +- **Code + docs → this repo** (`benchmark/*.py`, `PLAN.md`, `README.md`). +- **Every generated artifact → `/data2/projects/autoflatten/`** (override with + `AUTOFLATTEN_BENCH_ROOT`): the JSONL ledger, `NOTEBOOK.md`, `manifest.json`, candidate + configs, flat patches, k-ring caches, and the Optuna DB. *Nothing generated lands in + the repo.* See [`paths.py`](paths.py). + +## Data + +Public **Narratives** dataset (OpenNeuro `ds002345`). 
82 subjects already have a +projection patch (`{hemi}.autoflatten.patch.3d`) + a base surface (`{hemi}.fiducial`) +under the FreeSurfer derivatives, so the benchmark reuses those and **needs no FreeSurfer +and no fetching**. The private lab set (`/data2/freesurfer_subjects/all-subjects/`) is +*not* used — it isn't publicly shareable. + +## Compute + +**CPU only.** The machine's GPUs are blocked by an old driver (440 / CUDA 10.2), too old +for any JAX `autoflatten` supports. Don't attempt CUDA unless a sysadmin upgrades the +driver to ≥525. The biggest speed win (Tutte flip-free init) is algorithmic anyway. + +## Quick start + +```bash +uv sync --extra bench + +# 1. Build the benchmark manifest (deterministic train/holdout split) +python -m benchmark.build_dataset # ~18 subjects +python -m benchmark.build_dataset --dev # tiny: 2 subjects (dev loop) + +# 2. Baseline (current FreeSurfer-clone defaults) + determinism check +python -m benchmark.run_baseline --dev --check-determinism + +# 3. Render the lab notebook from the ledger +python -m benchmark.report +``` + +## How the harness is structured + +- [`paths.py`](paths.py) — the two pinned roots. +- [`ledger.py`](ledger.py) — append-only JSONL provenance ledger (commit SHA, env, seeds, + metrics, artifacts, repro command, decision trace). **Every run appends a record.** +- [`metrics.py`](metrics.py) — uniform per-patch metrics (mean/p90 % distance error, + flipped triangles, area distortion) computed from `uv` independent of method, plus + cross-subject aggregation into the multi-objective vector. +- [`harness.py`](harness.py) — owns geometry (loads patch+surface, computes/caches k-ring + targets), runs a `flatten_fn(flattener) -> uv`, and scores it. The current pipeline is + one `flatten_fn`; alternatives register via `register_flatten_fn`. +- [`build_dataset.py`](build_dataset.py), [`run_baseline.py`](run_baseline.py), + [`report.py`](report.py) — the Phase-A commands above. 
+- `probe_tutte_init.py`, `optimize.py` — Phase B/C (alt-method probe, optional HPO). + +## Full public replication path + +We skip projection locally (reusing cached patches), but the whole pipeline is publicly +reproducible: get the surfaces from OpenNeuro/datalad `ds002345`, install FreeSurfer + +`autoflatten`, run `autoflatten project` then `autoflatten flatten`. The manifest records +the dataset DOI and datalad commit. diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000..b19a390 --- /dev/null +++ b/benchmark/__init__.py @@ -0,0 +1,11 @@ +"""AutoFlatten benchmark / autoresearch harness. + +Dev tooling (not part of the installed ``autoflatten`` package) for a reproducible, +provenance-tracked optimization benchmark of the pyflatten flattening stage. + +See ``benchmark/PLAN.md`` for the full design and ``.claude/skills/autoflatten-autoresearch/`` +for session-recovery instructions. + +Code lives in the repo; **all generated artifacts** (ledger, manifest, configs, flat +patches, caches) are written under :data:`benchmark.paths.DATA_ROOT`. +""" diff --git a/benchmark/build_dataset.py b/benchmark/build_dataset.py new file mode 100644 index 0000000..971c928 --- /dev/null +++ b/benchmark/build_dataset.py @@ -0,0 +1,150 @@ +"""Build the benchmark manifest from the public Narratives FreeSurfer derivatives. + +Selects subjects that already have a projection patch (``{hemi}.autoflatten.patch.3d``) +and a materialized base surface (``{hemi}.fiducial``, fallback ``smoothwm``), then writes +a documented ``manifest.json`` with a deterministic train/holdout split. + +No FreeSurfer, no fetching: everything is already on disk and git-annex tracked. 
+ +Usage +----- + python -m benchmark.build_dataset # default ~18-subject benchmark + python -m benchmark.build_dataset --dev # tiny: 2 subjects (4 hemis) + python -m benchmark.build_dataset --n-subjects 30 # custom size +""" + +from __future__ import annotations + +import argparse +import json +import subprocess +from datetime import datetime, timezone +from pathlib import Path + +from . import paths + +HEMIS = ("lh", "rh") +DATASET_DOI = "10.18112/openneuro.ds002345" # OpenNeuro Narratives (Nastase et al.) + + +def _datalad_commit() -> str | None: + try: + out = subprocess.run( + ["git", "-C", str(paths.NARRATIVES_FS), "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + return out.stdout.strip() + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + +def _base_surface(surf: Path, hemi: str) -> Path | None: + for name in (f"{hemi}.fiducial", f"{hemi}.smoothwm"): + p = surf / name + if p.exists(): # follows symlink -> True only if annex content is present + return p + return None + + +def discover_subjects() -> list[dict]: + """Return manifest entries for every (subject, hemi) with patch + base surface present.""" + entries = [] + for subj_dir in sorted(paths.NARRATIVES_FS.glob("sub-*")): + surf = subj_dir / "surf" + if not surf.is_dir(): + continue + for hemi in HEMIS: + patch = surf / f"{hemi}.autoflatten.patch.3d" + base = _base_surface(surf, hemi) + if patch.exists() and base is not None: + entries.append( + { + "subject": subj_dir.name, + "hemi": hemi, + "patch_path": str(patch), + "surface_path": str(base), + "surface_kind": base.name.split(".", 1)[1], + } + ) + return entries + + +def assign_splits(entries: list[dict], holdout_every: int = 3) -> list[dict]: + """Deterministic split: every ``holdout_every``-th *subject* (sorted) is holdout. + + Splitting by subject (not hemisphere) keeps both hemispheres of a subject in the same + split, avoiding leakage. No RNG, so the split is fully reproducible. 
+ """ + subjects = sorted({e["subject"] for e in entries}) + holdout = {s for i, s in enumerate(subjects) if i % holdout_every == 0} + for e in entries: + e["split"] = "holdout" if e["subject"] in holdout else "train" + return entries + + +def build(n_subjects: int | None, dev: bool, holdout_every: int) -> dict: + all_entries = discover_subjects() + all_subjects = sorted({e["subject"] for e in all_entries}) + + if dev: + n_subjects = 2 + if n_subjects is not None: + keep = set(all_subjects[:n_subjects]) + entries = [e for e in all_entries if e["subject"] in keep] + else: + entries = all_entries + + entries = assign_splits(entries, holdout_every=holdout_every) + subjects = sorted({e["subject"] for e in entries}) + + return { + "dataset": "narratives", + "dataset_doi": DATASET_DOI, + "dataset_source": str(paths.NARRATIVES_FS), + "datalad_commit": _datalad_commit(), + "created": datetime.now(timezone.utc).isoformat(), + "selection": { + "rule": "first N subjects (sorted) with patch+base surface present; " + f"every {holdout_every}-th subject -> holdout", + "n_subjects_requested": n_subjects, + "holdout_every": holdout_every, + "dev": dev, + }, + "n_subjects": len(subjects), + "n_subjects_available": len(all_subjects), + "n_entries": len(entries), + "n_train": sum(e["split"] == "train" for e in entries), + "n_holdout": sum(e["split"] == "holdout" for e in entries), + "entries": entries, + } + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--n-subjects", type=int, default=18, help="benchmark size (subjects)" + ) + ap.add_argument("--dev", action="store_true", help="tiny 2-subject dev manifest") + ap.add_argument("--holdout-every", type=int, default=3) + ap.add_argument("--out", type=Path, default=paths.MANIFEST_PATH) + args = ap.parse_args() + + paths.ensure_output_dirs() + manifest = build(args.n_subjects, args.dev, args.holdout_every) + args.out.parent.mkdir(parents=True, exist_ok=True) + with open(args.out, "w") as 
f: 
+ json.dump(manifest, f, indent=2) + + print(f"Wrote manifest -> {args.out}") + print( + f" {manifest['n_subjects']} subjects " + f"({manifest['n_subjects_available']} available), " + f"{manifest['n_entries']} hemispheres " + f"[train={manifest['n_train']}, holdout={manifest['n_holdout']}]" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/harness.py b/benchmark/harness.py new file mode 100644 index 0000000..4f94095 --- /dev/null +++ b/benchmark/harness.py @@ -0,0 +1,165 @@ +"""Method-agnostic evaluation harness for the flattening stage. + +The harness owns the *geometry* (it loads the patch + base surface and computes the +k-ring geodesic targets once, with on-disk caching), then hands a prepared +:class:`~autoflatten.flatten.algorithm.SurfaceFlattener` to a ``flatten_fn`` and scores +whatever ``uv`` it returns with the uniform metrics in :mod:`benchmark.metrics`. + +A ``flatten_fn`` has signature ``flatten_fn(flattener) -> uv`` (shape ``(V, 2)``). +The current pipeline is just one such function; Tutte/LSCM/SLIM/SMACOF prototypes are +others, all scored identically. +""" + +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Any, Callable, Optional + +import numpy as np + +from . 
import paths + +FlattenFn = Callable[[Any], np.ndarray] + + +# --------------------------------------------------------------------------------- +# Built-in flatten_fns +# --------------------------------------------------------------------------------- +def pyflatten_flatten_fn(flattener: Any) -> np.ndarray: + """The current FreeSurfer-clone optimizer: run the full pyflatten pipeline.""" + return flattener.run() + + +FLATTEN_FNS: dict[str, FlattenFn] = { + "pyflatten": pyflatten_flatten_fn, +} + + +def register_flatten_fn(name: str, fn: FlattenFn) -> None: + """Register an alternative method so it can be evaluated by name.""" + FLATTEN_FNS[name] = fn + + +# --------------------------------------------------------------------------------- +# Manifest +# --------------------------------------------------------------------------------- +def load_manifest(path: str | Path = paths.MANIFEST_PATH) -> dict[str, Any]: + """Load the benchmark manifest (subjects, hemis, patch/surface paths, splits).""" + with open(path) as f: + return json.load(f) + + +def select_entries( + manifest: dict[str, Any], + split: Optional[str] = None, + subset: Optional[int] = None, +) -> list[dict[str, Any]]: + """Filter manifest entries by split and/or take the first ``subset`` of them.""" + entries = manifest["entries"] + if split is not None: + entries = [e for e in entries if e.get("split") == split] + if subset is not None: + entries = entries[:subset] + return entries + + +# --------------------------------------------------------------------------------- +# Geometry + evaluation +# --------------------------------------------------------------------------------- +def _kring_cache_path(entry: dict[str, Any], config: Any) -> Path: + """Per-patch k-ring cache path, keyed on (subject, hemi, k_ring, n_neighbors). + + The cache is independent of energy weights, so configs that share a k-ring share + this file — the expensive geodesic computation is paid once per (subject, hemi, k). 
+ """ + k = config.kring.k_ring + n = config.kring.n_neighbors_per_ring + nstr = "all" if n is None else str(n) + name = f"{entry['subject']}_{entry['hemi']}.kring_k{k}_n{nstr}.npz" + return paths.KRING_CACHE_DIR / name + + +def build_flattener( + entry: dict[str, Any], + config: Any, + use_cache: bool = True, +) -> Any: + """Load a patch + base surface and prepare a flattener (geometry + k-ring + JAX). + + Returns a :class:`SurfaceFlattener` ready for a ``flatten_fn``. Reuses the on-disk + k-ring cache when available. + """ + from autoflatten.flatten import SurfaceFlattener + + flattener = SurfaceFlattener(config) + flattener.load_data(entry["patch_path"], entry["surface_path"]) + cache_path = str(_kring_cache_path(entry, config)) if use_cache else None + if cache_path: + paths.KRING_CACHE_DIR.mkdir(parents=True, exist_ok=True) + flattener.compute_kring_distances(cache_path=cache_path) + flattener.prepare_optimization() + return flattener + + +def evaluate_one( + entry: dict[str, Any], + config: Any, + flatten_fn: FlattenFn, + save_dir: Optional[Path] = None, + use_cache: bool = True, +) -> dict[str, Any]: + """Evaluate one (subject, hemi): flatten and score. + + Returns a per-patch result dict (subject/hemi + metrics + runtime + status). On + failure, ``status="error"`` with the exception message, so a bad subject doesn't + sink the whole run. 
+ """ + from .metrics import per_patch_metrics + + result: dict[str, Any] = { + "subject": entry["subject"], + "hemi": entry["hemi"], + "split": entry.get("split"), + "status": "ok", + } + try: + flattener = build_flattener(entry, config, use_cache=use_cache) + t0 = time.time() + uv = np.asarray(flatten_fn(flattener)) + result["runtime_s"] = time.time() - t0 + result.update(per_patch_metrics(uv, flattener)) + + if save_dir is not None: + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + out_path = save_dir / f"{entry['subject']}.{entry['hemi']}.flat.patch.3d" + flattener.save_result(uv, str(out_path)) + result["artifact"] = str(out_path) + except Exception as exc: # noqa: BLE001 - record and continue + result["status"] = "error" + result["error"] = f"{type(exc).__name__}: {exc}" + return result + + +def evaluate( + entries: list[dict[str, Any]], + config: Any, + method: str = "pyflatten", + save_dir: Optional[Path] = None, + use_cache: bool = True, +) -> dict[str, Any]: + """Evaluate a list of manifest entries with one method/config. + + Returns ``{"per_subject": [...], "aggregate": {...}}``. Aggregation is the + multi-objective vector defined in :func:`benchmark.metrics.aggregate`. + """ + from .metrics import aggregate + + flatten_fn = FLATTEN_FNS[method] + per_subject = [ + evaluate_one(e, config, flatten_fn, save_dir=save_dir, use_cache=use_cache) + for e in entries + ] + return {"per_subject": per_subject, "aggregate": aggregate(per_subject)} diff --git a/benchmark/ledger.py b/benchmark/ledger.py new file mode 100644 index 0000000..7bb0874 --- /dev/null +++ b/benchmark/ledger.py @@ -0,0 +1,191 @@ +"""Append-only provenance ledger for benchmark experiments. + +Every experiment run appends one JSON record to ``experiments.jsonl`` (under +:data:`benchmark.paths.DATA_ROOT`). 
The ledger is the single source of truth for the +autoresearch process and is meant to be shared as a paper supplement, so each record is +self-contained and reproducible: it pins the autoflatten **git commit SHA**, captures the +environment and seeds, and stores all metrics, artifact paths, and a repro command. + +For autoresearch-loop steps, a record may also carry a **decision trace** +(hypothesis / rationale / conclusion / next_step) so the *reasoning* is auditable, not +just the numbers. + +Records are never mutated. To "update" an experiment, append a new record. +""" + +from __future__ import annotations + +import getpass +import hashlib +import json +import platform +import socket +import subprocess +import sys +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +from . import paths + +REPO_ROOT = Path(__file__).resolve().parent.parent + + +# --------------------------------------------------------------------------------- +# Provenance capture +# --------------------------------------------------------------------------------- +def _git(*args: str) -> Optional[str]: + try: + out = subprocess.run( + ["git", "-C", str(REPO_ROOT), *args], + capture_output=True, + text=True, + check=True, + ) + return out.stdout.strip() + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + +def capture_git_provenance() -> dict[str, Any]: + """Capture the autoflatten repo commit, branch, and dirty-tree diff.""" + commit = _git("rev-parse", "HEAD") + branch = _git("rev-parse", "--abbrev-ref", "HEAD") + status = _git("status", "--porcelain") + dirty = bool(status) + diff = _git("diff", "HEAD") if dirty else "" + return { + "commit": commit, + "branch": branch, + "dirty": dirty, + # Truncate a runaway diff but keep enough to reconstruct intent. 
+ "diff": (diff or "")[:200_000], + } + + +def capture_environment() -> dict[str, Any]: + """Capture python, platform, and key package versions (incl. JAX backend).""" + env: dict[str, Any] = { + "python": sys.version.split()[0], + "platform": platform.platform(), + "hostname": socket.gethostname(), + "user": getpass.getuser(), + } + try: + import jax + + env["jax_version"] = jax.__version__ + env["jax_backend"] = jax.default_backend() + env["jax_devices"] = [str(d) for d in jax.devices()] + except Exception as exc: # pragma: no cover - environment dependent + env["jax_error"] = repr(exc) + for pkg in ("numpy", "scipy", "numba", "igl", "optuna"): + try: + mod = __import__(pkg) + env[f"{pkg}_version"] = getattr(mod, "__version__", "unknown") + except Exception: + pass + return env + + +def file_hash(path: str | Path, algo: str = "md5") -> Optional[str]: + """Return a content hash of ``path`` (resolving symlinks), or None if missing.""" + p = Path(path) + if not p.exists(): + return None + h = hashlib.new(algo) + with open(p, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + return f"{algo}:{h.hexdigest()}" + + +def _utcnow() -> str: + return datetime.now(timezone.utc).isoformat() + + +# --------------------------------------------------------------------------------- +# Record + ledger +# --------------------------------------------------------------------------------- +@dataclass +class ExperimentRecord: + """One experiment run. Serialized as a single JSONL line.""" + + # Identity / provenance + experiment_id: str + timestamp: str + kind: str # e.g. "baseline", "probe", "hpo_trial" + label: str # human-readable, e.g. 
"baseline:pyflatten-defaults" + git: dict[str, Any] + environment: dict[str, Any] + + # Inputs + manifest_id: Optional[str] = None + subjects: list[dict[str, Any]] = field(default_factory=list) + method: dict[str, Any] = field(default_factory=dict) # {name, params/config} + seeds: dict[str, Any] = field(default_factory=dict) + + # Outputs + metrics: dict[str, Any] = field(default_factory=dict) # aggregate + per_subject: list[dict[str, Any]] = field(default_factory=list) + artifacts: list[dict[str, Any]] = field(default_factory=list) # {path, hash, kind} + runtime_s: Optional[float] = None + status: str = "ok" # "ok" | "error" + error: Optional[str] = None + + # Repro + repro_command: Optional[str] = None + + # Decision trace (autoresearch only) + decision: dict[str, Any] = field(default_factory=dict) + + def to_json(self) -> str: + return json.dumps(asdict(self), default=str) + + +def new_record(kind: str, label: str, **kwargs: Any) -> ExperimentRecord: + """Build a record pre-filled with a fresh id, timestamp, and provenance.""" + return ExperimentRecord( + experiment_id=uuid.uuid4().hex[:12], + timestamp=_utcnow(), + kind=kind, + label=label, + git=capture_git_provenance(), + environment=capture_environment(), + **kwargs, + ) + + +class Ledger: + """Append-only JSONL ledger of experiments.""" + + def __init__(self, path: str | Path = paths.LEDGER_PATH): + self.path = Path(path) + + def append(self, record: ExperimentRecord) -> ExperimentRecord: + """Append a record (creating the ledger dir if needed) and return it.""" + self.path.parent.mkdir(parents=True, exist_ok=True) + with open(self.path, "a") as f: + f.write(record.to_json() + "\n") + return record + + def read(self) -> list[dict[str, Any]]: + """Read all records as dicts (oldest first). 
Empty if the ledger is absent.""" + if not self.path.exists(): + return [] + records = [] + with open(self.path) as f: + for line in f: + line = line.strip() + if line: + records.append(json.loads(line)) + return records + + def latest(self, kind: Optional[str] = None) -> Optional[dict[str, Any]]: + """Return the most recent record, optionally filtered by ``kind``.""" + recs = self.read() + if kind is not None: + recs = [r for r in recs if r.get("kind") == kind] + return recs[-1] if recs else None diff --git a/benchmark/metrics.py b/benchmark/metrics.py new file mode 100644 index 0000000..27542b9 --- /dev/null +++ b/benchmark/metrics.py @@ -0,0 +1,142 @@ +"""Quality metrics for a flattened patch, computed uniformly across methods. + +Metrics are computed from the 2D ``uv`` coordinates plus the patch's geodesic targets +(the prepared :class:`~autoflatten.flatten.algorithm.SurfaceFlattener` arrays), so they +do **not** depend on *how* ``uv`` was produced. Any ``flatten_fn`` is scored identically. + +Reuses the package's own metric kernels so the harness reports the same numbers the CLI +log does: + +- :func:`autoflatten.flatten.algorithm.count_flipped_triangles` +- :func:`autoflatten.flatten.algorithm._compute_distance_error_jit` +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + + +def per_patch_metrics(uv: np.ndarray, flattener: Any) -> dict[str, float]: + """Compute quality metrics for one flattened patch. + + Parameters + ---------- + uv : ndarray, shape (V, 2) + Flattened coordinates aligned to ``flattener``'s vertex ordering. + flattener : SurfaceFlattener + A flattener for which ``prepare_optimization`` has run, exposing + ``neighbors_jax``, ``targets_jax``, ``mask_jax``, and ``faces_jax``. 
+ + Returns + ------- + dict + ``mean_distortion`` (mean % distance error, FreeSurfer formula), + ``p90_distortion`` (90th-percentile per-vertex % error), + ``n_flipped`` (flipped/negative-area triangles), + ``frac_flipped`` (fraction of faces flipped), + ``area_distortion`` (|total 2D area / 3D area - 1|). + """ + import jax.numpy as jnp + + from autoflatten.flatten.algorithm import ( + _compute_distance_error_jit, + count_flipped_triangles, + ) + + uv_jax = jnp.asarray(uv) + faces = flattener.faces_jax + neighbors = flattener.neighbors_jax + targets = flattener.targets_jax + mask = flattener.mask_jax + + mean_distortion = float( + _compute_distance_error_jit(uv_jax, neighbors, targets, mask) + ) + n_flipped = int(count_flipped_triangles(uv_jax, faces)) + n_faces = int(np.asarray(faces).shape[0]) + + # Per-vertex distortion distribution (for robustness / tail metrics). + p90 = _per_vertex_p90( + np.asarray(uv), np.asarray(neighbors), np.asarray(targets), np.asarray(mask) + ) + + # Area distortion: how much total flattened area departs from the 3D patch area. 
+ area_distortion = _area_distortion( + np.asarray(uv), np.asarray(faces), flattener.orig_area + ) + + return { + "mean_distortion": mean_distortion, + "p90_distortion": p90, + "n_flipped": n_flipped, + "frac_flipped": (n_flipped / n_faces) if n_faces else 0.0, + "area_distortion": area_distortion, + } + + +def _per_vertex_p90( + uv: np.ndarray, neighbors: np.ndarray, targets: np.ndarray, mask: np.ndarray +) -> float: + """90th percentile of per-vertex mean % distance error.""" + valid = mask & (targets > 0) + d2d = np.linalg.norm(uv[neighbors] - uv[:, None, :], axis=-1) + abs_err = np.where(valid, np.abs(d2d - targets), 0.0) + n_valid = valid.sum(axis=1) + denom = np.where(valid, targets, 0.0).sum(axis=1) + has = (n_valid > 0) & (denom > 0) + per_vertex = np.zeros(uv.shape[0]) + per_vertex[has] = 100.0 * abs_err.sum(axis=1)[has] / denom[has] + if not has.any(): + return float("nan") + return float(np.percentile(per_vertex[has], 90)) + + +def _area_distortion(uv: np.ndarray, faces: np.ndarray, orig_area: float) -> float: + """|sum(|2D triangle area|) / orig_3d_area - 1|.""" + v0, v1, v2 = uv[faces[:, 0]], uv[faces[:, 1]], uv[faces[:, 2]] + areas = 0.5 * ( + (v1[:, 0] - v0[:, 0]) * (v2[:, 1] - v0[:, 1]) + - (v2[:, 0] - v0[:, 0]) * (v1[:, 1] - v0[:, 1]) + ) + total_2d = float(np.abs(areas).sum()) + if not orig_area: + return float("nan") + return abs(total_2d / orig_area - 1.0) + + +# --------------------------------------------------------------------------------- +# Aggregation across subjects/hemispheres -> the multi-objective vector +# --------------------------------------------------------------------------------- +def aggregate(per_subject: list[dict[str, Any]]) -> dict[str, float]: + """Aggregate per-patch metrics into the benchmark's multi-objective vector. + + Returns mean and worst-case distortion (robustness), total/fraction of patches with + flips, and mean runtime. 
+ """ + ok = [ + r + for r in per_subject + if r.get("status", "ok") == "ok" and "mean_distortion" in r + ] + if not ok: + return {"n_patches": 0, "n_failed": len(per_subject)} + + md = np.array([r["mean_distortion"] for r in ok], dtype=float) + p90 = np.array([r["p90_distortion"] for r in ok], dtype=float) + flips = np.array([r["n_flipped"] for r in ok], dtype=float) + rt = np.array([r.get("runtime_s", np.nan) for r in ok], dtype=float) + + return { + "n_patches": len(ok), + "n_failed": len(per_subject) - len(ok), + "mean_distortion": float(np.mean(md)), + "worst_distortion": float(np.max(md)), + "median_distortion": float(np.median(md)), + "mean_p90_distortion": float(np.nanmean(p90)), + "total_flipped": int(flips.sum()), + "frac_patches_with_flips": float(np.mean(flips > 0)), + "mean_runtime_s": float(np.nanmean(rt)), + "total_runtime_s": float(np.nansum(rt)), + } diff --git a/benchmark/paths.py b/benchmark/paths.py new file mode 100644 index 0000000..852f149 --- /dev/null +++ b/benchmark/paths.py @@ -0,0 +1,45 @@ +"""Canonical paths for the AutoFlatten benchmark. + +Two roots, deliberately separated (see ``benchmark/PLAN.md``): + +- The **repo** holds code + markdown docs only. +- :data:`DATA_ROOT` holds *every* loop-generated artifact (ledger, manifest, configs, + flat patches, k-ring caches, the Optuna DB, the rendered notebook). + +Nothing is written outside :data:`DATA_ROOT`. 
+""" + +from __future__ import annotations + +import os +from pathlib import Path + +# --- Output root (all generated artifacts live here) ------------------------------ +DATA_ROOT = Path( + os.environ.get("AUTOFLATTEN_BENCH_ROOT", "/data2/projects/autoflatten") +) + +LEDGER_DIR = DATA_ROOT / "ledger" +LEDGER_PATH = LEDGER_DIR / "experiments.jsonl" +MANIFEST_PATH = DATA_ROOT / "manifest.json" +NOTEBOOK_PATH = DATA_ROOT / "NOTEBOOK.md" +RUNS_DIR = DATA_ROOT / "runs" +KRING_CACHE_DIR = DATA_ROOT / "kring_cache" +CONFIGS_DIR = DATA_ROOT / "configs" +OPTUNA_DB = DATA_ROOT / "optuna.db" + +# --- Input data (read-only) ------------------------------------------------------- +# Public Narratives FreeSurfer derivatives (OpenNeuro ds002345). 82 subjects already +# have {hemi}.autoflatten.patch.3d + materialized base surfaces. +NARRATIVES_FS = Path( + os.environ.get( + "AUTOFLATTEN_NARRATIVES_FS", + "/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer", + ) +) + + +def ensure_output_dirs() -> None: + """Create the artifact subdirectories under :data:`DATA_ROOT` if missing.""" + for d in (LEDGER_DIR, RUNS_DIR, KRING_CACHE_DIR, CONFIGS_DIR): + d.mkdir(parents=True, exist_ok=True) diff --git a/benchmark/report.py b/benchmark/report.py new file mode 100644 index 0000000..42b535d --- /dev/null +++ b/benchmark/report.py @@ -0,0 +1,117 @@ +"""Render the experiment ledger into a human-readable lab notebook. + +Reads ``experiments.jsonl`` and writes ``NOTEBOOK.md`` (under +:data:`benchmark.paths.DATA_ROOT`) — the inspectable, shareable narrative of every +experiment, each pinned to its commit SHA and repro command. + +Usage +----- + python -m benchmark.report +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any + +from . 
import paths +from .ledger import Ledger + +_METRIC_KEYS = [ + "n_patches", + "mean_distortion", + "worst_distortion", + "total_flipped", + "frac_patches_with_flips", + "mean_runtime_s", +] + + +def _fmt(v: Any) -> str: + if isinstance(v, float): + return f"{v:.3f}" + return str(v) + + +def render(records: list[dict[str, Any]]) -> str: + lines = ["# AutoFlatten autoresearch — lab notebook", ""] + lines.append( + "Auto-generated from `ledger/experiments.jsonl`. Each row is one experiment; " + "see the ledger for full provenance (env, seeds, per-subject metrics, diffs)." + ) + lines.append("") + + if not records: + lines.append("_No experiments logged yet._") + return "\n".join(lines) + "\n" + + # Summary table + header = ["id", "kind", "label", "commit"] + _METRIC_KEYS + lines.append("| " + " | ".join(header) + " |") + lines.append("| " + " | ".join("---" for _ in header) + " |") + for r in records: + commit = (r.get("git", {}).get("commit") or "")[:8] + dirty = "*" if r.get("git", {}).get("dirty") else "" + m = r.get("metrics", {}) + row = [ + r.get("experiment_id", ""), + r.get("kind", ""), + r.get("label", ""), + commit + dirty, + ] + [_fmt(m.get(k, "")) for k in _METRIC_KEYS] + lines.append("| " + " | ".join(row) + " |") + lines.append("") + + # Per-experiment detail + for r in records: + lines.append( + f"## {r.get('label', r.get('experiment_id'))} (`{r.get('experiment_id')}`)" + ) + lines.append("") + lines.append( + f"- **kind**: {r.get('kind')} • **timestamp**: {r.get('timestamp')} • **status**: {r.get('status')}" + ) + git = r.get("git", {}) + lines.append( + f"- **commit**: `{git.get('commit')}` ({git.get('branch')})" + + (" ⚠️ dirty tree" if git.get("dirty") else "") + ) + env = r.get("environment", {}) + lines.append( + f"- **env**: jax {env.get('jax_version')} [{env.get('jax_backend')}], " + f"python {env.get('python')}, host {env.get('hostname')}" + ) + if r.get("method"): + lines.append(f"- **method**: {r['method'].get('name')}") + if 
r.get("repro_command"): + lines.append(f"- **repro**: `{r['repro_command']}`") + decision = r.get("decision", {}) + if decision.get("hypothesis"): + lines.append(f"- **hypothesis**: {decision['hypothesis']}") + if decision.get("conclusion"): + lines.append(f"- **conclusion**: {decision['conclusion']}") + if decision.get("next_step"): + lines.append(f"- **next step**: {decision['next_step']}") + if "determinism_check" in decision: + lines.append( + f"- **determinism**: {decision['determinism_check'].get('deterministic')}" + ) + lines.append("") + return "\n".join(lines) + "\n" + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--ledger", type=Path, default=paths.LEDGER_PATH) + ap.add_argument("--out", type=Path, default=paths.NOTEBOOK_PATH) + args = ap.parse_args() + + records = Ledger(args.ledger).read() + args.out.parent.mkdir(parents=True, exist_ok=True) + args.out.write_text(render(records)) + print(f"Wrote {args.out} ({len(records)} experiments)") + + +if __name__ == "__main__": + main() diff --git a/benchmark/run_baseline.py b/benchmark/run_baseline.py new file mode 100644 index 0000000..72ab155 --- /dev/null +++ b/benchmark/run_baseline.py @@ -0,0 +1,151 @@ +"""Run the current-default ("FreeSurfer-clone") baseline and log it to the ledger. + +Establishes the reference quality + runtime the alternative methods must beat, and +(optionally) checks that the optimizer is deterministic, recording the result in the +ledger, so a single run per experiment is defensible. + +Usage +----- + python -m benchmark.run_baseline --dev # first 2 hemispheres, fast + python -m benchmark.run_baseline --split train # full train split + python -m benchmark.run_baseline --dev --check-determinism +""" + +from __future__ import annotations + +import argparse +import sys +import time + +from . 
import paths +from .harness import ( + evaluate, + evaluate_one, + load_manifest, + pyflatten_flatten_fn, + select_entries, +) +from .ledger import Ledger, file_hash, new_record + + +def _make_config(verbose: bool): + from autoflatten.flatten import FlattenConfig + + config = FlattenConfig() + config.verbose = verbose + return config + + +def check_determinism(entry, config) -> dict: + """Flatten one patch twice and compare the headline metrics.""" + r1 = evaluate_one(entry, config, pyflatten_flatten_fn) + r2 = evaluate_one(entry, config, pyflatten_flatten_fn) + same = ( + r1.get("status") == "ok" == r2.get("status") + and r1["n_flipped"] == r2["n_flipped"] + and abs(r1["mean_distortion"] - r2["mean_distortion"]) < 1e-6 + ) + return { + "deterministic": bool(same), + "run1": { + "mean_distortion": r1.get("mean_distortion"), + "n_flipped": r1.get("n_flipped"), + }, + "run2": { + "mean_distortion": r2.get("mean_distortion"), + "n_flipped": r2.get("n_flipped"), + }, + } + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--dev", action="store_true", help="run on a tiny subset (first 2 hemispheres)" + ) + ap.add_argument( + "--split", + default=None, + choices=["train", "holdout"], + help="restrict to a split", + ) + ap.add_argument("--subset", type=int, default=None, help="take first N hemispheres") + ap.add_argument( + "--save", action="store_true", help="also save flat patches to runs//" + ) + ap.add_argument("--check-determinism", action="store_true") + ap.add_argument("--verbose", action="store_true", help="print optimizer progress") + args = ap.parse_args() + + paths.ensure_output_dirs() + manifest = load_manifest() + subset = 2 if args.dev else args.subset + entries = select_entries(manifest, split=args.split, subset=subset) + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + config = _make_config(verbose=args.verbose) + record = new_record( + kind="baseline", + 
label="baseline:pyflatten-defaults", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": e["subject"], "hemi": e["hemi"], "split": e.get("split")} + for e in entries + ], + method={"name": "pyflatten", "config": config.to_dict()}, + seeds={"note": "deterministic CPU gradient descent; no RNG seed"}, + repro_command="python -m benchmark.run_baseline " + " ".join(sys.argv[1:]), + ) + + print( + f"Baseline on {len(entries)} hemispheres (experiment {record.experiment_id})..." + ) + t0 = time.time() + + if args.check_determinism: + det = check_determinism(entries[0], config) + record.decision["determinism_check"] = det + print(f" determinism: {det['deterministic']} ({det['run1']} vs {det['run2']})") + + save_dir = paths.RUNS_DIR / record.experiment_id if args.save else None + result = evaluate(entries, config, method="pyflatten", save_dir=save_dir) + + record.per_subject = result["per_subject"] + record.metrics = result["aggregate"] + record.runtime_s = time.time() - t0 + if save_dir is not None: + record.artifacts = [ + { + "path": r["artifact"], + "hash": file_hash(r["artifact"]), + "kind": "flat_patch", + } + for r in result["per_subject"] + if r.get("artifact") + ] + n_failed = result["aggregate"].get("n_failed", 0) + record.status = "ok" if n_failed == 0 else "partial" + + Ledger().append(record) + + agg = result["aggregate"] + print("\n=== Baseline aggregate ===") + for k in ( + "n_patches", + "n_failed", + "mean_distortion", + "worst_distortion", + "total_flipped", + "frac_patches_with_flips", + "mean_runtime_s", + ): + if k in agg: + print(f" {k}: {agg[k]}") + print(f"\nLogged to {Ledger().path} (experiment {record.experiment_id})") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/tests/__init__.py b/benchmark/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmark/tests/test_harness.py b/benchmark/tests/test_harness.py new file mode 100644 index 0000000..a1177b9 --- /dev/null 
+++ b/benchmark/tests/test_harness.py @@ -0,0 +1,138 @@ +"""Fast, JAX-free unit tests for the benchmark harness plumbing. + +These cover the pure-Python logic (ledger I/O, aggregation, manifest selection, metric +helpers) without running the flattening optimizer, so they're quick and deterministic. +The end-to-end flatten path is verified separately via ``run_baseline``. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from benchmark import metrics +from benchmark.harness import select_entries +from benchmark.ledger import ExperimentRecord, Ledger + + +# --- ledger round-trip ------------------------------------------------------------ +def _make_record(eid: str, kind: str = "baseline") -> ExperimentRecord: + return ExperimentRecord( + experiment_id=eid, + timestamp="2026-01-01T00:00:00+00:00", + kind=kind, + label=f"test:{eid}", + git={"commit": "abc123", "dirty": False}, + environment={"jax_backend": "cpu"}, + metrics={"mean_distortion": 12.5, "n_patches": 3}, + ) + + +def test_ledger_append_and_read_roundtrip(tmp_path): + ledger = Ledger(tmp_path / "experiments.jsonl") + assert ledger.read() == [] # absent ledger reads empty + + ledger.append(_make_record("aaa")) + ledger.append(_make_record("bbb", kind="probe")) + + records = ledger.read() + assert [r["experiment_id"] for r in records] == ["aaa", "bbb"] + assert records[0]["metrics"]["mean_distortion"] == 12.5 + assert ledger.latest()["experiment_id"] == "bbb" + assert ledger.latest(kind="baseline")["experiment_id"] == "aaa" + + +def test_ledger_is_append_only(tmp_path): + ledger = Ledger(tmp_path / "experiments.jsonl") + ledger.append(_make_record("first")) + ledger.append(_make_record("second")) + # Second append must not overwrite the first. 
+ assert len(ledger.read()) == 2 + + +# --- aggregation ------------------------------------------------------------------ +def test_aggregate_basic_stats(): + per_subject = [ + { + "status": "ok", + "mean_distortion": 10.0, + "p90_distortion": 20.0, + "n_flipped": 0, + "runtime_s": 5.0, + }, + { + "status": "ok", + "mean_distortion": 20.0, + "p90_distortion": 40.0, + "n_flipped": 3, + "runtime_s": 7.0, + }, + ] + agg = metrics.aggregate(per_subject) + assert agg["n_patches"] == 2 + assert agg["mean_distortion"] == pytest.approx(15.0) + assert agg["worst_distortion"] == pytest.approx(20.0) + assert agg["total_flipped"] == 3 + assert agg["frac_patches_with_flips"] == pytest.approx(0.5) + assert agg["mean_runtime_s"] == pytest.approx(6.0) + + +def test_aggregate_excludes_errors(): + per_subject = [ + { + "status": "ok", + "mean_distortion": 10.0, + "p90_distortion": 10.0, + "n_flipped": 0, + "runtime_s": 1.0, + }, + {"status": "error", "error": "boom"}, + ] + agg = metrics.aggregate(per_subject) + assert agg["n_patches"] == 1 + assert agg["n_failed"] == 1 + + +def test_aggregate_all_failed(): + agg = metrics.aggregate([{"status": "error"}, {"status": "error"}]) + assert agg["n_patches"] == 0 + assert agg["n_failed"] == 2 + + +# --- manifest selection ----------------------------------------------------------- +def _manifest(): + return { + "entries": [ + {"subject": "sub-001", "hemi": "lh", "split": "train"}, + {"subject": "sub-001", "hemi": "rh", "split": "train"}, + {"subject": "sub-002", "hemi": "lh", "split": "holdout"}, + ] + } + + +def test_select_entries_split_and_subset(): + m = _manifest() + assert len(select_entries(m)) == 3 + assert len(select_entries(m, split="train")) == 2 + assert len(select_entries(m, split="holdout")) == 1 + assert len(select_entries(m, subset=1)) == 1 + assert select_entries(m, split="train", subset=1)[0]["hemi"] == "lh" + + +# --- metric helpers (numpy-only) -------------------------------------------------- +def 
test_area_distortion_unit_square(): + # Two triangles tiling a unit square: total 2D area = 1.0. + uv = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) + faces = np.array([[0, 1, 2], [0, 2, 3]]) + # Matching 3D area -> zero distortion; doubled 3D area -> 0.5. + assert metrics._area_distortion(uv, faces, orig_area=1.0) == pytest.approx(0.0) + assert metrics._area_distortion(uv, faces, orig_area=2.0) == pytest.approx(0.5) + + +def test_per_vertex_p90_zero_when_isometric(): + # A perfectly isometric embedding: 2D distances equal targets -> 0 error. + uv = np.array([[0, 0], [1, 0], [0, 1]], dtype=float) + neighbors = np.array([[1, 2], [0, 2], [0, 1]]) + targets = np.array([[1.0, 1.0], [1.0, np.sqrt(2)], [1.0, np.sqrt(2)]]) + mask = np.ones_like(targets, dtype=bool) + assert metrics._per_vertex_p90(uv, neighbors, targets, mask) == pytest.approx(0.0) diff --git a/pyproject.toml b/pyproject.toml index d8082d4..4822226 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,11 @@ test = [ docs = [ "mkdocs-material", ] +bench = [ + # Benchmark / autoresearch harness (see benchmark/PLAN.md). Not packaged. + "optuna>=3.6", # optional HPO driver + "pandas>=2.0", # report tables +] [project.scripts] autoflatten = "autoflatten.cli:main" From 1b10aed8439bf4310f2985d429a64718461ab1cd Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 9 Jun 2026 18:57:54 -0700 Subject: [PATCH 02/35] Add Tutte/LSCM flip-free init probe (Phase B primary experiment) Injects a guaranteed-injective Tutte embedding (libigl harmonic map, boundary pinned to a circle) as the initial map and disables the initial negative-area- removal phase, feeding the existing geodesic-stress refinement. Verified flip-free: Tutte init gives 0/385122 flipped triangles at 33.7% distortion, vs the FreeSurfer projection's 141690 flips at 172%. Logs a probe record with hypothesis + head-to-head conclusion vs the latest baseline. 
Co-Authored-By: Claude Fable 5 --- benchmark/probe_tutte_init.py | 250 ++++++++++++++++++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 benchmark/probe_tutte_init.py diff --git a/benchmark/probe_tutte_init.py b/benchmark/probe_tutte_init.py new file mode 100644 index 0000000..ad37cd8 --- /dev/null +++ b/benchmark/probe_tutte_init.py @@ -0,0 +1,250 @@ +"""Probe: flip-free (Tutte/LSCM) initialization instead of FreeSurfer projection + NAR. + +The current pipeline starts from a normal-axis projection that has many flipped +triangles, then spends its first ~4 minutes in "negative area removal" (NAR) un-flipping +them. A **Tutte embedding** (harmonic map with the boundary pinned to a convex circle) is +*guaranteed flip-free* for disk topology — so we can drop the initial-NAR phase entirely +and feed the flip-free map straight into the existing geodesic-stress refinement. + +Hypothesis: equal-or-lower distance distortion, zero flips at init, and less runtime. + +This is implemented by injecting a different initial map into the existing +``SurfaceFlattener`` (overriding ``initial_projection`` and disabling the initial NAR), so +the refinement (epochs + final NAR + spring) is shared with the baseline and the +comparison is apples-to-apples. + +Usage +----- + python -m benchmark.probe_tutte_init --dev # tutte init + refine + python -m benchmark.probe_tutte_init --dev --method lscm + python -m benchmark.probe_tutte_init --dev --no-refine # init quality only +""" + +from __future__ import annotations + +import argparse +import sys +import time + +import numpy as np + +from . 
import paths +from .harness import evaluate, load_manifest, register_flatten_fn, select_entries +from .ledger import Ledger, file_hash, new_record + + +# --------------------------------------------------------------------------------- +# Flip-free initialization +# --------------------------------------------------------------------------------- +def flipfree_init( + vertices: np.ndarray, faces: np.ndarray, method: str = "tutte" +) -> np.ndarray: + """Compute an initial 2D embedding of a disk-topology patch (flip-free for ``tutte``). + + Parameters + ---------- + vertices : (V, 3) float + 3D patch vertices (use the smoothed/fiducial surface for intrinsic weights). + faces : (F, 3) int + Triangles (single boundary loop / disk topology). + method : {"tutte", "lscm"} + ``tutte`` — harmonic map with the boundary pinned to a circle. Guaranteed + injective (Tutte's theorem) → no flipped triangles. + ``lscm`` — least-squares conformal map (2 pinned boundary vertices). Lower angle + distortion but *not* guaranteed flip-free. + + Returns + ------- + (V, 2) float + 2D coordinates (unscaled). + """ + import igl + + v = np.ascontiguousarray(vertices, dtype=np.float64) + f = np.ascontiguousarray(faces, dtype=np.int64) + bnd = igl.boundary_loop(f) + if bnd is None or len(bnd) == 0: + raise ValueError("No boundary loop found; patch is not a disk.") + + if method == "tutte": + bc = igl.map_vertices_to_circle(v, bnd.astype(np.int32)) + uv = igl.harmonic(v, f, bnd.astype(np.int64), np.ascontiguousarray(bc), 1) + elif method == "lscm": + # Pin two boundary vertices about half a loop apart (by index) to (0,0) and (1,0). 
+ b = np.array([bnd[0], bnd[len(bnd) // 2]], dtype=np.int64) + bc = np.array([[0.0, 0.0], [1.0, 0.0]], dtype=np.float64) + uv, _ = igl.lscm(v, f, b, bc) + else: + raise ValueError(f"Unknown init method: {method!r}") + return np.asarray(uv, dtype=np.float64) + + +def scale_to_area(uv: np.ndarray, faces: np.ndarray, target_area: float) -> np.ndarray: + """Uniformly scale ``uv`` so its total (unsigned) 2D area matches ``target_area``. + + Tutte/LSCM map to ~unit scale, while the geodesic targets are in mm; matching the 3D + patch area gives the refinement a sensible starting scale. + """ + v0, v1, v2 = uv[faces[:, 0]], uv[faces[:, 1]], uv[faces[:, 2]] + area2d = float( + np.abs( + 0.5 + * ( + (v1[:, 0] - v0[:, 0]) * (v2[:, 1] - v0[:, 1]) + - (v2[:, 0] - v0[:, 0]) * (v1[:, 1] - v0[:, 1]) + ) + ).sum() + ) + if area2d <= 0 or target_area <= 0: + return uv + s = np.sqrt(target_area / area2d) + centroid = uv.mean(axis=0) + return (uv - centroid) * s + centroid + + +# --------------------------------------------------------------------------------- +# flatten_fn factory +# --------------------------------------------------------------------------------- +def make_flatten_fn(method: str = "tutte", refine: bool = True): + """Build a ``flatten_fn`` that initializes flip-free and (optionally) refines. + + When ``refine`` is True the flip-free map is injected into the existing optimizer + with the **initial NAR phase disabled** (its purpose is moot for a flip-free start); + the rest of the refinement (epochs, final NAR, spring) is unchanged. When False, the + scaled init is returned directly to measure init-only quality. + """ + + def _fn(flattener): + init = flipfree_init(flattener.vertices, flattener.faces, method=method) + init = scale_to_area(init, np.asarray(flattener.faces), flattener.orig_area) + if not refine: + return init + # Skip the initial negative-area-removal phase: a flip-free start makes it moot. 
+ flattener.config.negative_area_removal.enabled = False + flattener.initial_projection = lambda: init # injected into run() + return flattener.run() + + return _fn + + +# --------------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------------- +def _baseline_reference(): + """Most recent baseline record's aggregate metrics, for head-to-head printing.""" + rec = Ledger().latest(kind="baseline") + return rec["metrics"] if rec else None + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--dev", action="store_true", help="first 2 hemispheres") + ap.add_argument("--split", default=None, choices=["train", "holdout"]) + ap.add_argument("--subset", type=int, default=None) + ap.add_argument("--method", default="tutte", choices=["tutte", "lscm"]) + ap.add_argument( + "--no-refine", action="store_true", help="measure init-only quality" + ) + ap.add_argument("--save", action="store_true") + args = ap.parse_args() + + from autoflatten.flatten import FlattenConfig + + refine = not args.no_refine + method_name = f"{args.method}_init" + ("" if refine else "_only") + register_flatten_fn(method_name, make_flatten_fn(args.method, refine=refine)) + + paths.ensure_output_dirs() + manifest = load_manifest() + subset = 2 if args.dev else args.subset + entries = select_entries(manifest, split=args.split, subset=subset) + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + config = FlattenConfig() + config.verbose = False + if refine: + config.negative_area_removal.enabled = False # the whole point + + record = new_record( + kind="probe", + label=f"probe:{method_name}", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": e["subject"], "hemi": e["hemi"], "split": e.get("split")} + for e in entries + ], + method={"name": method_name, "config": config.to_dict(), "refine": refine}, + seeds={"note": 
"deterministic; flip-free init + CPU gradient descent"}, + repro_command="python -m benchmark.probe_tutte_init " + " ".join(sys.argv[1:]), + ) + record.decision = { + "hypothesis": ( + f"A flip-free {args.method} init removes the ~4-min initial NAR phase and " + "reaches equal-or-lower distance distortion with zero flips and less runtime." + ), + "rationale": ( + "Tutte's theorem guarantees an injective (flip-free) embedding for disk " + "topology when the boundary is mapped to a convex polygon, so the initial " + "negative-area-removal phase becomes unnecessary." + ), + } + + print( + f"Probe '{method_name}' on {len(entries)} hemispheres ({record.experiment_id})..." + ) + t0 = time.time() + save_dir = paths.RUNS_DIR / record.experiment_id if args.save else None + result = evaluate(entries, config, method=method_name, save_dir=save_dir) + record.per_subject = result["per_subject"] + record.metrics = result["aggregate"] + record.runtime_s = time.time() - t0 + if save_dir is not None: + record.artifacts = [ + { + "path": r["artifact"], + "hash": file_hash(r["artifact"]), + "kind": "flat_patch", + } + for r in result["per_subject"] + if r.get("artifact") + ] + n_failed = result["aggregate"].get("n_failed", 0) + record.status = "ok" if n_failed == 0 else "partial" + + # Head-to-head conclusion vs the latest baseline. + agg = result["aggregate"] + base = _baseline_reference() + if base and "mean_distortion" in agg: + record.decision["conclusion"] = ( + f"{method_name}: mean_distortion {agg['mean_distortion']:.2f} vs baseline " + f"{base.get('mean_distortion'):.2f}; total_flipped {agg['total_flipped']} vs " + f"{base.get('total_flipped')}; mean_runtime {agg.get('mean_runtime_s', 0):.0f}s " + f"vs {base.get('mean_runtime_s', 0):.0f}s." 
+ ) + + Ledger().append(record) + + print("\n=== Probe aggregate ===") + for k in ( + "n_patches", + "n_failed", + "mean_distortion", + "worst_distortion", + "total_flipped", + "frac_patches_with_flips", + "mean_runtime_s", + ): + if k in agg: + print(f" {k}: {agg[k]}") + if base: + print("\n=== vs baseline ===") + for k in ("mean_distortion", "total_flipped", "mean_runtime_s"): + print(f" {k}: probe={agg.get(k)} baseline={base.get(k)}") + print(f"\nLogged to {Ledger().path} (experiment {record.experiment_id})") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From b9da0b7c40a54da8f3d06df0de4ee4b20e946cee Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 9 Jun 2026 18:58:55 -0700 Subject: [PATCH 03/35] Add unit tests for Tutte flip-free init and area scaling Co-Authored-By: Claude Fable 5 --- benchmark/tests/test_harness.py | 37 +++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/benchmark/tests/test_harness.py b/benchmark/tests/test_harness.py index a1177b9..ce9af60 100644 --- a/benchmark/tests/test_harness.py +++ b/benchmark/tests/test_harness.py @@ -136,3 +136,40 @@ def test_per_vertex_p90_zero_when_isometric(): targets = np.array([[1.0, 1.0], [1.0, np.sqrt(2)], [1.0, np.sqrt(2)]]) mask = np.ones_like(targets, dtype=bool) assert metrics._per_vertex_p90(uv, neighbors, targets, mask) == pytest.approx(0.0) + + +# --- flip-free init probe --------------------------------------------------------- +def _disk_mesh(n: int = 12): + """A flat triangle-fan disk: center vertex + ``n`` boundary vertices on a circle.""" + ang = np.linspace(0, 2 * np.pi, n, endpoint=False) + rim = np.column_stack([np.cos(ang), np.sin(ang), np.zeros(n)]) + vertices = np.vstack([[0.0, 0.0, 0.0], rim]) + faces = np.array([[0, 1 + i, 1 + (i + 1) % n] for i in range(n)], dtype=np.int64) + return vertices, faces + + +def _signed_areas(uv, faces): + v0, v1, v2 = uv[faces[:, 0]], uv[faces[:, 1]], uv[faces[:, 2]] + return 0.5 
* ( + (v1[:, 0] - v0[:, 0]) * (v2[:, 1] - v0[:, 1]) + - (v2[:, 0] - v0[:, 0]) * (v1[:, 1] - v0[:, 1]) + ) + + +def test_tutte_init_is_flip_free(): + from benchmark.probe_tutte_init import flipfree_init + + vertices, faces = _disk_mesh() + uv = flipfree_init(vertices, faces, method="tutte") + areas = _signed_areas(uv, faces) + # All triangles share one orientation -> zero flips (Tutte guarantee). + assert np.all(areas > 0) or np.all(areas < 0) + + +def test_scale_to_area_matches_target(): + from benchmark.probe_tutte_init import scale_to_area + + uv = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=float) # area 1 + faces = np.array([[0, 1, 2], [0, 2, 3]]) + scaled = scale_to_area(uv, faces, target_area=9.0) + assert np.abs(_signed_areas(scaled, faces)).sum() == pytest.approx(9.0) From 9ec06cb7c3e3806f3afd160c0b154080b614ee54 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 9 Jun 2026 19:45:39 -0700 Subject: [PATCH 04/35] Add fast flatmap plotter for visual verification of experiments benchmark/plot.py renders the saved flat patches of a ledger experiment (mesh + flipped triangles in red + boundary), pulling metrics from the ledger instead of recomputing distortion (plot_flatmap's per-vertex distortion recompute is the bottleneck: ~5s vs ~20min on a 193k-vertex hemisphere). --full keeps the 3-panel plot. Visual check on sub-022 lh: baseline and Tutte-init maps are both clean single-blob flatmaps with no tangling. Co-Authored-By: Claude Fable 5 --- benchmark/plot.py | 198 ++++++++++++++++++++++++++++++++ benchmark/tests/test_harness.py | 19 +++ 2 files changed, 217 insertions(+) create mode 100644 benchmark/plot.py diff --git a/benchmark/plot.py b/benchmark/plot.py new file mode 100644 index 0000000..e0464b3 --- /dev/null +++ b/benchmark/plot.py @@ -0,0 +1,198 @@ +"""Render flatmap images for a logged experiment — visual verification. + +Metrics can hide a tangled map, so every saved run can be eyeballed. 
Given an experiment +id, this finds the flat patches that run saved (from its ledger record), looks up each +patch's base surface from the manifest, and renders the flat mesh with flipped triangles +highlighted in red. + +By default this uses a **fast** single-panel renderer (mesh + flips + boundary), pulling +the distortion/flip numbers from the ledger record so nothing is recomputed. ``--full`` +uses the package's slower three-panel :func:`autoflatten.viz.plot_flatmap` (which +recomputes per-vertex distortion — minutes on a full hemisphere). + +Usage +----- + python -m benchmark.plot + python -m benchmark.plot --latest probe + python -m benchmark.plot --full +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Any, Optional + +import numpy as np + +from . import paths +from .harness import load_manifest +from .ledger import Ledger + + +def _surface_lookup(manifest: dict) -> dict[tuple[str, str], str]: + return {(e["subject"], e["hemi"]): e["surface_path"] for e in manifest["entries"]} + + +def _parse_subject_hemi(flat_path: str) -> tuple[str, str]: + """``.../sub-022.lh.flat.patch.3d`` -> ("sub-022", "lh").""" + parts = Path(flat_path).name.split(".") + return parts[0], parts[1] + + +def fast_flatmap( + flat_path: str, + surface_path: str, + out_path: str, + title: str, + subtitle: str = "", +) -> str: + """Fast single-panel flatmap: mesh in gray, flipped triangles in red, boundary blue. + + Skips the expensive per-vertex distortion recompute (the bottleneck in + :func:`autoflatten.viz.plot_flatmap`); the numbers come from the ledger instead. 
+ """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import matplotlib.tri as mtri + + from autoflatten.freesurfer import extract_patch_faces, read_patch, read_surface + + flat_vertices, orig_indices, is_border = read_patch(flat_path) + _, base_faces = read_surface(surface_path) + faces = extract_patch_faces(base_faces, orig_indices) + xy = flat_vertices[:, :2] + + v0, v1, v2 = xy[faces[:, 0]], xy[faces[:, 1]], xy[faces[:, 2]] + areas = 0.5 * ( + (v1[:, 0] - v0[:, 0]) * (v2[:, 1] - v0[:, 1]) + - (v2[:, 0] - v0[:, 0]) * (v1[:, 1] - v0[:, 1]) + ) + flipped = areas < 0 + n_flipped = int(flipped.sum()) + + fig, ax = plt.subplots(figsize=(7, 7), constrained_layout=True) + triang = mtri.Triangulation(xy[:, 0], xy[:, 1], faces) + # One pass: color faces by flip flag (flipped = 1, dark red; others = 0, near-white). + ax.tripcolor( + triang, + facecolors=np.where(flipped, 1.0, 0.0), + cmap="Greys" if n_flipped == 0 else "Reds", + vmin=0, + vmax=1, + edgecolors="none", + ) + ax.triplot(triang, color="0.6", linewidth=0.1) + if np.sum(is_border) > 0: + ax.scatter(xy[is_border, 0], xy[is_border, 1], s=1.5, c="tab:blue", zorder=5) + ax.set_aspect("equal") + ax.axis("off") + full_title = title + (f"\n{subtitle}" if subtitle else "") + ax.set_title(full_title, fontsize=10) + fig.savefig(out_path, dpi=150) + plt.close(fig) + return out_path + + +def plot_experiment( + experiment_id: str, + out_root: Optional[Path] = None, + full: bool = False, +) -> list[str]: + """Render flatmaps for every flat patch an experiment saved. 
Returns the PNG paths.""" + records = Ledger().read() + record = next((r for r in records if r["experiment_id"] == experiment_id), None) + if record is None: + raise SystemExit(f"No ledger record for experiment {experiment_id!r}") + artifacts = [ + a for a in record.get("artifacts", []) if a.get("kind") == "flat_patch" + ] + if not artifacts: + raise SystemExit( + f"Experiment {experiment_id} has no saved flat patches " + "(re-run with --save to render flatmaps)." + ) + + surfaces = _surface_lookup(load_manifest()) + per_subject = {(r["subject"], r["hemi"]): r for r in record.get("per_subject", [])} + out_dir = (out_root or paths.DATA_ROOT / "figures") / experiment_id + out_dir.mkdir(parents=True, exist_ok=True) + + written = [] + for art in artifacts: + flat = art["path"] + subject, hemi = _parse_subject_hemi(flat) + surface = surfaces.get((subject, hemi)) + if surface is None: + print(f" ! no surface in manifest for {subject} {hemi}; skipping") + continue + out = out_dir / f"{subject}.{hemi}.png" + title = f"{subject} {hemi} — {record['label']}" + if full: + _full_plot(flat, surface, str(out), title) + else: + ps: dict[str, Any] = per_subject.get((subject, hemi), {}) + sub = _subtitle(ps) + fast_flatmap(flat, surface, str(out), title, subtitle=sub) + print(f" wrote {out}") + written.append(str(out)) + return written + + +def _subtitle(ps: dict[str, Any]) -> str: + if not ps: + return "" + bits = [] + if "mean_distortion" in ps: + bits.append(f"{ps['mean_distortion']:.2f}% dist") + if "n_flipped" in ps: + bits.append(f"{ps['n_flipped']} flipped") + if "runtime_s" in ps: + bits.append(f"{ps['runtime_s']:.0f}s") + return " | ".join(bits) + + +def _full_plot(flat: str, surface: str, out: str, title: str) -> None: + import matplotlib + + matplotlib.use("Agg") + from autoflatten.viz import plot_flatmap + + plot_flatmap(flat, base_surface_path=surface, output_path=out, title=title) + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + 
ap.add_argument("experiment_id", nargs="?", help="experiment id to plot") + ap.add_argument( + "--latest", metavar="KIND", help="plot the latest record of this kind" + ) + ap.add_argument( + "--full", + action="store_true", + help="use the slow 3-panel plot_flatmap (recomputes distortion)", + ) + args = ap.parse_args() + + if args.latest: + rec = Ledger().latest(kind=args.latest) + if rec is None: + print(f"No '{args.latest}' experiments in the ledger.", file=sys.stderr) + return 1 + eid = rec["experiment_id"] + elif args.experiment_id: + eid = args.experiment_id + else: + ap.error("provide an experiment_id or --latest KIND") + + print(f"Plotting experiment {eid}...") + written = plot_experiment(eid, full=args.full) + print(f"Wrote {len(written)} figure(s) to {paths.DATA_ROOT / 'figures' / eid}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/tests/test_harness.py b/benchmark/tests/test_harness.py index ce9af60..f1398a6 100644 --- a/benchmark/tests/test_harness.py +++ b/benchmark/tests/test_harness.py @@ -173,3 +173,22 @@ def test_scale_to_area_matches_target(): faces = np.array([[0, 1, 2], [0, 2, 3]]) scaled = scale_to_area(uv, faces, target_area=9.0) assert np.abs(_signed_areas(scaled, faces)).sum() == pytest.approx(9.0) + + +# --- plot helpers ----------------------------------------------------------------- +def test_parse_subject_hemi(): + from benchmark.plot import _parse_subject_hemi + + assert _parse_subject_hemi("/runs/abc/sub-022.lh.flat.patch.3d") == ( + "sub-022", + "lh", + ) + assert _parse_subject_hemi("sub-005.rh.flat.patch.3d") == ("sub-005", "rh") + + +def test_subtitle_formats_known_fields(): + from benchmark.plot import _subtitle + + assert _subtitle({}) == "" + s = _subtitle({"mean_distortion": 15.25, "n_flipped": 24, "runtime_s": 454.0}) + assert "15.25% dist" in s and "24 flipped" in s and "454s" in s From 9d5982ccfe121da3b10b9f99316b234414d8c13c Mon Sep 17 00:00:00 2001 From: Matteo Visconti di 
Oleggio Castello Date: Tue, 9 Jun 2026 19:49:58 -0700 Subject: [PATCH 05/35] Add general autoresearch experiment runner Parametrized runner (init method + phase toggles + k-ring) that logs each variant to the ledger with a decision trace, for fanning out optimization experiments. Co-Authored-By: Claude Fable 5 --- benchmark/experiment.py | 179 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 benchmark/experiment.py diff --git a/benchmark/experiment.py b/benchmark/experiment.py new file mode 100644 index 0000000..07e0df4 --- /dev/null +++ b/benchmark/experiment.py @@ -0,0 +1,179 @@ +"""General autoresearch experiment runner. + +Generalizes :mod:`benchmark.run_baseline` and :mod:`benchmark.probe_tutte_init` into one +parametrized runner so optimization ideas can be fanned out and every one is logged to the +provenance ledger with a decision trace. + +Knobs: +- ``--init {projection,tutte,lscm}`` — initial map (projection = current FreeSurfer clone). +- ``--skip-initial-nar`` / ``--skip-final-nar`` / ``--skip-spring`` — drop refinement phases. +- ``--skip-epoch {epoch_1,epoch_2,epoch_3}`` (repeatable) — drop a metric epoch. +- ``--k-ring N`` / ``--n-neighbors N`` — geodesic neighborhood (changes the cache key). + +Each run logs a ``kind="experiment"`` record with the full config, metrics, per-subject +results, and a decision trace (hypothesis + head-to-head conclusion vs the baseline). + +Usage +----- + python -m benchmark.experiment --init tutte --skip-final-nar --subset 1 \\ + --label tutte+nofinalnar --hypothesis "flip-free start makes final NAR removable" +""" + +from __future__ import annotations + +import argparse +import sys +import time + +from . 
import paths +from .harness import ( + evaluate, + load_manifest, + register_flatten_fn, + select_entries, +) +from .ledger import Ledger, file_hash, new_record +from .probe_tutte_init import make_flatten_fn + + +def build_config(args): + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + cfg.kring.k_ring = args.k_ring + cfg.kring.n_neighbors_per_ring = args.n_neighbors + + # A flip-free init makes the *initial* NAR moot; allow forcing it off for projection too. + if args.init in ("tutte", "lscm") or args.skip_initial_nar: + cfg.negative_area_removal.enabled = False + if args.skip_final_nar: + cfg.final_negative_area_removal.enabled = False + if args.skip_spring: + cfg.spring_smoothing.enabled = False + for phase in cfg.phases: + if phase.name in (args.skip_epoch or []): + phase.enabled = False + return cfg + + +def resolve_method(args): + """Return (method_name, registered) for the chosen init.""" + if args.init == "projection": + return "pyflatten", "pyflatten" + method_name = args.label or f"{args.init}_init" + register_flatten_fn(method_name, make_flatten_fn(args.init, refine=True)) + return method_name, method_name + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--init", default="tutte", choices=["projection", "tutte", "lscm"]) + ap.add_argument("--skip-initial-nar", action="store_true") + ap.add_argument("--skip-final-nar", action="store_true") + ap.add_argument("--skip-spring", action="store_true") + ap.add_argument( + "--skip-epoch", action="append", choices=["epoch_1", "epoch_2", "epoch_3"] + ) + ap.add_argument("--k-ring", type=int, default=7) + ap.add_argument("--n-neighbors", type=int, default=12) + ap.add_argument("--dev", action="store_true") + ap.add_argument("--split", default=None, choices=["train", "holdout"]) + ap.add_argument("--subset", type=int, default=None) + ap.add_argument("--save", action="store_true") + ap.add_argument("--label", default=None, 
help="short experiment label") + ap.add_argument("--hypothesis", default="", help="what you expect and why") + args = ap.parse_args() + + cfg = build_config(args) + method_name, registered = resolve_method(args) + label = args.label or method_name + + paths.ensure_output_dirs() + manifest = load_manifest() + subset = 2 if args.dev else args.subset + entries = select_entries(manifest, split=args.split, subset=subset) + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + # Capture the toggles in the method spec for provenance. + toggles = { + "init": args.init, + "skip_initial_nar": args.skip_initial_nar or args.init in ("tutte", "lscm"), + "skip_final_nar": args.skip_final_nar, + "skip_spring": args.skip_spring, + "skip_epoch": args.skip_epoch or [], + "k_ring": args.k_ring, + "n_neighbors": args.n_neighbors, + } + record = new_record( + kind="experiment", + label=f"exp:{label}", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": e["subject"], "hemi": e["hemi"], "split": e.get("split")} + for e in entries + ], + method={"name": method_name, "toggles": toggles, "config": cfg.to_dict()}, + seeds={"note": "deterministic CPU gradient descent"}, + repro_command="python -m benchmark.experiment " + " ".join(sys.argv[1:]), + ) + if args.hypothesis: + record.decision["hypothesis"] = args.hypothesis + + print( + f"Experiment '{label}' ({method_name}) on {len(entries)} hemispheres ({record.experiment_id})..." 
+ ) + t0 = time.time() + save_dir = paths.RUNS_DIR / record.experiment_id if args.save else None + result = evaluate(entries, cfg, method=registered, save_dir=save_dir) + record.per_subject = result["per_subject"] + record.metrics = result["aggregate"] + record.runtime_s = time.time() - t0 + if save_dir is not None: + record.artifacts = [ + { + "path": r["artifact"], + "hash": file_hash(r["artifact"]), + "kind": "flat_patch", + } + for r in result["per_subject"] + if r.get("artifact") + ] + n_failed = result["aggregate"].get("n_failed", 0) + record.status = "ok" if n_failed == 0 else "partial" + + agg = result["aggregate"] + base = Ledger().latest(kind="baseline") + base_m = base["metrics"] if base else None + if base_m and "mean_distortion" in agg: + record.decision["conclusion"] = ( + f"{label}: dist {agg['mean_distortion']:.2f} vs base {base_m.get('mean_distortion'):.2f}; " + f"flips {agg['total_flipped']} vs {base_m.get('total_flipped')}; " + f"runtime {agg.get('mean_runtime_s', 0):.0f}s vs {base_m.get('mean_runtime_s', 0):.0f}s." 
+ ) + Ledger().append(record) + + print("\n=== aggregate ===") + for k in ( + "n_patches", + "n_failed", + "mean_distortion", + "worst_distortion", + "total_flipped", + "frac_patches_with_flips", + "mean_runtime_s", + ): + if k in agg: + print(f" {k}: {agg[k]}") + if base_m: + print("\n=== vs baseline ===") + for k in ("mean_distortion", "total_flipped", "mean_runtime_s"): + print(f" {k}: exp={agg.get(k)} baseline={base_m.get(k)}") + print(f"\nLogged {record.experiment_id} -> {Ledger().path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 0c9428c3f9670ec63fd4bff400eff18e1654c40f Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 9 Jun 2026 20:00:30 -0700 Subject: [PATCH 06/35] Document experiment.py runner and plot.py in benchmark README Co-Authored-By: Claude Fable 5 --- benchmark/README.md | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/benchmark/README.md b/benchmark/README.md index ba5f921..762567d 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -56,7 +56,27 @@ python -m benchmark.report one `flatten_fn`; alternatives register via `register_flatten_fn`. - [`build_dataset.py`](build_dataset.py), [`run_baseline.py`](run_baseline.py), [`report.py`](report.py) — the Phase-A commands above. -- `probe_tutte_init.py`, `optimize.py` — Phase B/C (alt-method probe, optional HPO). +- [`probe_tutte_init.py`](probe_tutte_init.py) — the flip-free (Tutte/LSCM) init probe + (Phase B primary experiment): injects a guaranteed-injective embedding and disables the + initial NAR, reusing the existing refinement. +- [`experiment.py`](experiment.py) — **general autoresearch runner**. Parametrizes the init + method (`--init projection|tutte|lscm`) and refinement toggles + (`--skip-initial-nar`/`--skip-final-nar`/`--skip-spring`/`--skip-epoch`, `--k-ring`), + logging each variant to the ledger with a decision trace. This is how optimization ideas + are fanned out. 
+- [`plot.py`](plot.py) — fast flatmap renderer for visual verification of any logged + experiment (`python -m benchmark.plot `; `--full` for the slow 3-panel). +- `optimize.py` — optional Optuna HPO (Phase C, not yet built). + +## Running an optimization experiment + +```bash +# e.g. test whether a flip-free start makes the final NAR phase removable: +python -m benchmark.experiment --init tutte --skip-final-nar --subset 1 --save \ + --label tutte_nofinalnar --hypothesis "flip-free start keeps the map near-injective" +python -m benchmark.plot --latest experiment # eyeball the result +python -m benchmark.report # refresh NOTEBOOK.md +``` ## Full public replication path From 03d5cd704990e12e612cd5d83d57ad8e384778f3 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Tue, 9 Jun 2026 20:47:00 -0700 Subject: [PATCH 07/35] Add curated findings: Tutte init verified (n=4), phase-removal ablations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tutte flip-free init verified across 4 hemispheres / 2 subjects: equal quality (14.75% vs 14.73% distortion, 121 vs 121 flips) at 37% less runtime. Ablations show the final NAR + spring smoothing are essential (skipping them lowers the distortion number to 13.2% but leaves ~21k flipped triangles — a folded map); only the initial NAR is removable. LSCM init ~ Tutte. Co-Authored-By: Claude Fable 5 --- benchmark/FINDINGS.md | 56 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 benchmark/FINDINGS.md diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md new file mode 100644 index 0000000..17f7ccf --- /dev/null +++ b/benchmark/FINDINGS.md @@ -0,0 +1,56 @@ +# Findings + +Curated conclusions from the autoresearch loop. The full, append-only record (every run with +provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jsonl`, rendered to +`NOTEBOOK.md`. 
All numbers below are CPU-only on the public Narratives benchmark. + +## 1. Flip-free (Tutte) init is a validated win: equal quality, ~37% faster + +Replacing the FreeSurfer-style normal-projection + **initial** negative-area-removal (NAR) with a +**Tutte embedding** (libigl harmonic map, boundary pinned to a circle — injective by Tutte's +theorem) and feeding it into the existing refinement. The Tutte map starts at **0 flipped +triangles** (vs ~141k for the projection), so the ~4-min initial NAR phase is unnecessary. + +Verified across **4 hemispheres / 2 subjects** (`sub-022`, `sub-026`): + +| | mean distortion | total flips | mean runtime | +|---|---|---|---| +| baseline (FS-clone) | 14.73% | 121 | 702 s | +| **Tutte init** | 14.75% | 121 | **442 s (−37%)** | + +Distortion identical (+0.03pp), total flips identical (121→121), runtime down 37%. Maps are +visually clean (single blob, smooth boundary, no folds) on both subjects. + +## 2. The *final* NAR and spring smoothing are NOT removable + +Tempting because dropping them is faster and even *lowers* the distortion number — but the number +lies: + +| variant (sub-022 lh) | distortion | flipped | runtime | +|---|---|---|---| +| Tutte init (full refinement) | 15.25% | **24** | 454 s | +| Tutte, skip final NAR | 13.38% | **19458** | 256 s | +| Tutte, skip final NAR + spring ("lean") | **13.21%** | **21481** | 263 s | +| Tutte, skip spring only | 15.50% | 75 | 473 s | + +Skipping the final NAR gives the lowest distortion in the whole study (13.21%) but **~21000 +flipped triangles** — a folded, invalid flatmap (visible as dark fold streaks). The final NAR and +spring phases are doing real work cleaning up flips introduced by the metric epochs. **Only the +initial NAR is removable.** + +## 3. LSCM init ≈ Tutte (no advantage) + +LSCM (conformal) init: 15.26% distortion, 34 flips, 444 s on `sub-022 lh` — essentially tied with +Tutte (15.25%, 24, 454 s), slightly more flips. Tutte is preferred (flip-free *guarantee*). 
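
The flip counts reported in the tables above come down to an orientation test on each flat triangle. A minimal sketch of such a test, assuming counter-clockwise triangles are the correctly oriented ones and that zero-area triangles also count as flipped; the helper names are illustrative and are not the harness's `metrics.py` API:

```python
def signed_area(p0, p1, p2):
    """Signed area of a 2D triangle; positive when the vertices run counter-clockwise."""
    return 0.5 * ((p1[0] - p0[0]) * (p2[1] - p0[1]) - (p2[0] - p0[0]) * (p1[1] - p0[1]))


def count_flipped(uv, faces):
    """Count triangles whose signed area in the flat map `uv` is not positive.

    Treating zero-area (degenerate) triangles as flipped is an assumption of this sketch.
    """
    return sum(1 for a, b, c in faces if signed_area(uv[a], uv[b], uv[c]) <= 0.0)


# Unit square split into two counter-clockwise triangles.
uv = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
faces = [(0, 1, 2), (0, 2, 3)]
flat_ok = count_flipped(uv, faces)

# Drag vertex 3 across the diagonal: the second triangle folds over the first.
uv_folded = list(uv)
uv_folded[3] = (1.0, 0.5)
flat_folded = count_flipped(uv_folded, faces)
```

A low distance error says nothing about this count, which is why §2 reports distortion and flips side by side.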
+ +## Method note + +Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. +Compute is CPU-only (the box's GPUs are blocked by driver 440 / CUDA 10.2). + +## Next ideas (untested) + +- Reduce `k_ring` (7 → 5): attacks the dominant cost (k-ring geodesic computation, ~4 min + 237 MB + cache/hemi). Changes the cache key, so needs recompute. +- Fewer epoch iterations from the better Tutte start. +- Combine: Tutte init as the new default `initial_projection`, initial NAR off by default. From 9a93b44129af876c7478a759e27038346f94fe05 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 05:21:51 -0700 Subject: [PATCH 08/35] Add optimizer-swap probe (L-BFGS/CG/Adam on the same energy) Tests replacing the FreeSurfer-style line-search GD with off-the-shelf optimizers on the metric+area energy, from the Tutte init. Adds optax/jaxopt to the bench extra. Co-Authored-By: Claude Fable 5 --- benchmark/probe_optimizer.py | 220 +++++++++++++++++++++++++++++++++++ pyproject.toml | 2 + 2 files changed, 222 insertions(+) create mode 100644 benchmark/probe_optimizer.py diff --git a/benchmark/probe_optimizer.py b/benchmark/probe_optimizer.py new file mode 100644 index 0000000..b59ffb4 --- /dev/null +++ b/benchmark/probe_optimizer.py @@ -0,0 +1,220 @@ +"""Probe: swap the optimization *algorithm* on the same energy. + +The current refinement minimizes the metric+area energy (J_d + J_a) with a hand-rolled +vectorized quadratic line-search gradient descent wrapped in a multi-level gradient- +smoothing *continuation* (coarse-to-fine). This probe keeps the energy but replaces the +optimizer with an off-the-shelf one (scipy L-BFGS-B / CG, or Adam), starting from the +flip-free Tutte init, and scores the result with the same harness. + +Caveat: the multi-level smoothing in the FreeSurfer-style loop is a continuation that helps +escape poor local minima. 
A "stronger" optimizer on the raw energy may actually do worse +(more flips / higher distortion). That comparison is the point. + +The energy minimized is ``(l_dist/avg_nbrs)*J_d + l_nlarea*J_a`` — the ``1/avg_nbrs`` folds in +the FreeSurfer gradient normalization so we minimize the same effective objective. + +Usage +----- + python -m benchmark.probe_optimizer --optimizer lbfgs --subset 1 --save + python -m benchmark.probe_optimizer --optimizer cg --max-iter 400 --subset 1 +""" + +from __future__ import annotations + +import argparse +import sys +import time + +import numpy as np + +from . import paths +from .harness import evaluate, load_manifest, register_flatten_fn, select_entries +from .ledger import Ledger, file_hash, new_record +from .probe_tutte_init import flipfree_init, scale_to_area + + +def make_optimizer_flatten_fn( + optimizer: str = "lbfgs", + max_iter: int = 300, + l_dist: float = 1.0, + l_nlarea: float = 1.0, + init: str = "tutte", +): + """Build a ``flatten_fn`` that minimizes J_d + J_a with the chosen optimizer.""" + + def _fn(flattener): + import jax + import jax.numpy as jnp + + from autoflatten.flatten.algorithm import make_energy_fn + + faces_np = np.asarray(flattener.faces) + x0 = flipfree_init(flattener.vertices, flattener.faces, method=init) + x0 = scale_to_area(x0, faces_np, flattener.orig_area) + n_v = x0.shape[0] + + avg = flattener.avg_neighbors or 1.0 + energy_fn = make_energy_fn( + l_dist / avg, + l_nlarea, + flattener.neighbors_jax, + flattener.targets_jax, + flattener.mask_jax, + flattener.faces_jax, + ) + vg = jax.jit(jax.value_and_grad(lambda u: energy_fn(u)[0])) + + def val_grad_flat(x): + uv = jnp.asarray(x.reshape(n_v, 2)) + e, g = vg(uv) + return float(e), np.asarray(g, dtype=np.float64).reshape(-1) + + x0f = x0.reshape(-1).astype(np.float64) + + if optimizer in ("lbfgs", "cg"): + from scipy.optimize import minimize + + method = "L-BFGS-B" if optimizer == "lbfgs" else "CG" + res = minimize( + val_grad_flat, + x0f, + jac=True, 
+ method=method, + options={"maxiter": max_iter}, + ) + uv = res.x.reshape(n_v, 2) + elif optimizer == "adam": + # Minimal Adam (optax is not installed). + lr, b1, b2, eps = 0.05, 0.9, 0.999, 1e-8 + x = x0f.copy() + m = np.zeros_like(x) + v = np.zeros_like(x) + for t in range(1, max_iter + 1): + _, g = val_grad_flat(x) + m = b1 * m + (1 - b1) * g + v = b2 * v + (1 - b2) * g * g + mh = m / (1 - b1**t) + vh = v / (1 - b2**t) + x = x - lr * mh / (np.sqrt(vh) + eps) + uv = x.reshape(n_v, 2) + else: + raise ValueError(f"Unknown optimizer: {optimizer!r}") + return uv + + return _fn + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--optimizer", default="lbfgs", choices=["lbfgs", "cg", "adam"]) + ap.add_argument("--init", default="tutte", choices=["tutte", "lscm"]) + ap.add_argument("--max-iter", type=int, default=300) + ap.add_argument("--l-dist", type=float, default=1.0) + ap.add_argument("--l-nlarea", type=float, default=1.0) + ap.add_argument("--dev", action="store_true") + ap.add_argument("--split", default=None, choices=["train", "holdout"]) + ap.add_argument("--subset", type=int, default=None) + ap.add_argument("--save", action="store_true") + args = ap.parse_args() + + label = f"optim_{args.optimizer}" + register_flatten_fn( + label, + make_optimizer_flatten_fn( + args.optimizer, args.max_iter, args.l_dist, args.l_nlarea, args.init + ), + ) + + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + + paths.ensure_output_dirs() + manifest = load_manifest() + subset = 2 if args.dev else args.subset + entries = select_entries(manifest, split=args.split, subset=subset) + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + record = new_record( + kind="experiment", + label=f"exp:{label}", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": e["subject"], "hemi": e["hemi"], "split": e.get("split")} + for e in entries + ], + method={ + 
"name": label, + "optimizer": args.optimizer, + "init": args.init, + "max_iter": args.max_iter, + "l_dist": args.l_dist, + "l_nlarea": args.l_nlarea, + "note": "energy-only optimizer swap; no multiscale schedule, no NAR cleanup", + }, + repro_command="python -m benchmark.probe_optimizer " + " ".join(sys.argv[1:]), + ) + record.decision = { + "hypothesis": ( + f"Minimizing the same J_d+J_a energy with {args.optimizer} from the Tutte init " + "may match the FreeSurfer-style schedule; or it may get stuck / flip without the " + "multi-level smoothing continuation." + ) + } + + print( + f"Optimizer probe '{label}' on {len(entries)} hemispheres ({record.experiment_id})..." + ) + t0 = time.time() + save_dir = paths.RUNS_DIR / record.experiment_id if args.save else None + result = evaluate(entries, cfg, method=label, save_dir=save_dir) + record.per_subject = result["per_subject"] + record.metrics = result["aggregate"] + record.runtime_s = time.time() - t0 + if save_dir is not None: + record.artifacts = [ + { + "path": r["artifact"], + "hash": file_hash(r["artifact"]), + "kind": "flat_patch", + } + for r in result["per_subject"] + if r.get("artifact") + ] + record.status = "ok" if result["aggregate"].get("n_failed", 0) == 0 else "partial" + + agg = result["aggregate"] + base = Ledger().latest(kind="baseline") + base_m = base["metrics"] if base else None + if base_m and "mean_distortion" in agg: + record.decision["conclusion"] = ( + f"{label}: dist {agg['mean_distortion']:.2f} vs base {base_m.get('mean_distortion'):.2f}; " + f"flips {agg['total_flipped']} vs {base_m.get('total_flipped')}; " + f"runtime {agg.get('mean_runtime_s', 0):.0f}s vs {base_m.get('mean_runtime_s', 0):.0f}s." 
+ ) + Ledger().append(record) + + print("\n=== aggregate ===") + for k in ( + "n_patches", + "n_failed", + "mean_distortion", + "total_flipped", + "frac_patches_with_flips", + "mean_runtime_s", + ): + if k in agg: + print(f" {k}: {agg[k]}") + if base_m: + print("\n=== vs baseline ===") + for k in ("mean_distortion", "total_flipped", "mean_runtime_s"): + print(f" {k}: exp={agg.get(k)} baseline={base_m.get(k)}") + print(f"\nLogged {record.experiment_id} -> {Ledger().path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml index 4822226..a36332a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ bench = [ # Benchmark / autoresearch harness (see benchmark/PLAN.md). Not packaged. "optuna>=3.6", # optional HPO driver "pandas>=2.0", # report tables + "optax>=0.2", # gradient-based optimizers (optimizer-swap experiments) + "jaxopt>=0.8", # L-BFGS / solvers on JAX energies ] [project.scripts] From b366403aa184994e7b6fb7937965244cf45b532d Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 05:28:01 -0700 Subject: [PATCH 09/35] Findings: optimizer-swap is a dead end on this energy; multiscale is the lever L-BFGS/CG on the same energy fold (10.31% distortion but 58025 flips, 8x faster); the area-weight sweep is Pareto-worse than baseline; the epoch schedule still folds; a flip barrier keeps injectivity but stalls L-BFGS at the init. The FreeSurfer optimizer's value is the multi-level gradient smoothing + NAR, not the line search. Productive next step is multiscale/hierarchical (multigrid) optimization or SLIM, not a drop-in optimizer swap. 
Co-Authored-By: Claude Fable 5 --- benchmark/FINDINGS.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 17f7ccf..0faf710 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -43,6 +43,32 @@ initial NAR is removable.** LSCM (conformal) init: 15.26% distortion, 34 flips, 444 s on `sub-022 lh` — essentially tied with Tutte (15.25%, 24, 454 s), slightly more flips. Tutte is preferred (flip-free *guarantee*). +## 4. Swapping the optimization *algorithm* does not help — the energy is the issue + +Tested replacing the FreeSurfer-style line-search GD with off-the-shelf optimizers on the *same* +metric+area energy, from the flip-free Tutte init (`benchmark/probe_optimizer.py`): + +- **L-BFGS / CG on the soft energy fold.** Full L-BFGS reaches **10.31% distortion in 83 s (≈8× + faster)** but **58025 flipped triangles** — it minimizes distance error by folding the mesh. +- **Pareto-worse than the baseline.** Sweeping the area weight `l_nlarea` only trades one for the + other: to get flips down to baseline (~24) L-BFGS needs ~25–29% distortion (vs baseline 15% @ 24 + flips). The FreeSurfer optimizer sits on a strictly better frontier. +- **The epoch weight schedule doesn't rescue it.** Running L-BFGS through the area-dominant → + distance-dominant schedule still folds (71269 flips at the distance-dominant stage). +- **A flip barrier keeps it injective but stalls.** Adding a one-sided area barrier (active only as + area→0) holds flips at ~0, but L-BFGS then can't reduce distortion at all (stuck at the init's + 33.7%): from the Tutte init, the *local* distance-reducing moves all push triangles toward + folding, which the barrier blocks. + +**Conclusion.** The FreeSurfer-style optimizer's value is **not** the line-search algorithm — it is +the **multi-level gradient smoothing** (spatially-coherent, coarse-to-fine moves that reduce +distortion *without* folding) plus the NAR passes. 
A generic first-order solver on per-vertex +gradients makes high-frequency moves that fold. So a faster/better optimizer alone is a dead end on +this energy; the productive directions are **multiscale / hierarchical (geometric-multigrid) +optimization** or a proper **SLIM** local-global solver (symmetric-Dirichlet energy with a +flip-preventing line search), not a drop-in optimizer swap. (`optax`/`jaxopt` are installed for +this.) + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From 1345622f8fbcea926a5bcb7a780352a31b30f8ab Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 05:45:12 -0700 Subject: [PATCH 10/35] Add spectral (manifold-harmonic) multigrid probe Parametrizes the flatmap in the low-frequency cotangent-Laplacian eigenbasis and optimizes coarse-to-fine: band-limited maps can't fold, so the coarse solve is flip-free by construction (19 flips, 66s, 22% distortion vs L-BFGS's 58k flips). Used as a coarse init for full-resolution refinement (geometric multigrid V-cycle). Harmonics cached per mesh by content hash. Co-Authored-By: Claude Fable 5 --- benchmark/probe_multigrid.py | 230 +++++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 benchmark/probe_multigrid.py diff --git a/benchmark/probe_multigrid.py b/benchmark/probe_multigrid.py new file mode 100644 index 0000000..3a54a6d --- /dev/null +++ b/benchmark/probe_multigrid.py @@ -0,0 +1,230 @@ +"""Probe: spectral (manifold-harmonic) multigrid optimization. + +The FreeSurfer-style optimizer avoids folding via multi-level gradient smoothing — coherent +coarse-to-fine moves. This probe makes that explicit and geometric: parametrize the flatmap +in the **mesh-Laplacian eigenbasis** (manifold harmonics) and optimize coarse-to-fine, +adding higher-frequency modes progressively. 
A band-limited (few-mode) map *cannot* make +local folds, so the coarse optimization is flip-free by construction; it just can't +represent fine detail, so distortion floors out. The coarse, flip-free, distance-aware map +is then an excellent init for a short full-resolution refinement (the "fine" leg of the +multigrid V-cycle). + +Two modes: +- ``--no-refine``: spectral coarse-to-fine only (fast, flip-free, band-limited). +- default: spectral coarse init -> full-resolution refinement (initial NAR disabled). + +Manifold harmonics (the lowest *K* eigenvectors of the cotangent Laplacian) are cached per +mesh under the k-ring cache dir, keyed by a content hash of the vertices. + +Usage +----- + python -m benchmark.probe_multigrid --subset 1 --save + python -m benchmark.probe_multigrid --no-refine --modes 200 --subset 1 +""" + +from __future__ import annotations + +import argparse +import hashlib +import sys +import time + +import numpy as np + +from . import paths +from .harness import evaluate, load_manifest, register_flatten_fn, select_entries +from .ledger import Ledger, file_hash, new_record +from .probe_tutte_init import flipfree_init, scale_to_area + + +def manifold_harmonics(vertices: np.ndarray, faces: np.ndarray, k: int) -> np.ndarray: + """Lowest-``k`` cotangent-Laplacian eigenvectors (M-orthonormal), cached by content.""" + import igl + import scipy.sparse.linalg as sla + + v = np.ascontiguousarray(vertices, dtype=np.float64) + f = np.ascontiguousarray(faces, dtype=np.int64) + key = hashlib.md5(v.tobytes() + f.tobytes() + str(k).encode()).hexdigest()[:16] + cache = paths.KRING_CACHE_DIR / f"harmonics_{key}_k{k}.npz" + if cache.exists(): + return np.load(cache)["Phi"] + + paths.KRING_CACHE_DIR.mkdir(parents=True, exist_ok=True) + lap = -igl.cotmatrix(v, f) # positive semidefinite + mass = igl.massmatrix(v, f, igl.MASSMATRIX_TYPE_VORONOI) + _, phi = sla.eigsh(lap.tocsc(), k=k, M=mass.tocsc(), sigma=1e-8, which="LM") + np.savez(cache, 
Phi=phi.astype(np.float64)) + return phi.astype(np.float64) + + +def spectral_optimize( + flattener, phi: np.ndarray, schedule, max_iter: int = 200 +) -> np.ndarray: + """Coarse-to-fine optimization of J_d in the manifold-harmonic basis.""" + import igl + import jax + import jax.numpy as jnp + from scipy.optimize import minimize + + from autoflatten.flatten.energy import compute_metric_energy + + faces = np.asarray(flattener.faces) + mass = igl.massmatrix( + np.asarray(flattener.vertices, dtype=np.float64), + faces.astype(np.int64), + igl.MASSMATRIX_TYPE_VORONOI, + ) + x0 = scale_to_area( + flipfree_init(flattener.vertices, flattener.faces, "tutte"), + faces, + flattener.orig_area, + ) + coeff = phi.T @ (mass @ x0) # mass-weighted projection onto modes (Phi^T M x) + + phij = jnp.asarray(phi) + nbr, tgt, msk = flattener.neighbors_jax, flattener.targets_jax, flattener.mask_jax + avg = flattener.avg_neighbors or 1.0 + + for k in schedule: + p_k = phij[:, :k] + energy = jax.jit( + lambda c, p_k=p_k: compute_metric_energy(p_k @ c, nbr, tgt, msk) / avg + ) + vg = jax.jit(jax.value_and_grad(energy)) + + def vgf(z, k=k, vg=vg): + val, grad = vg(jnp.asarray(z.reshape(k, 2))) + return float(val), np.asarray(grad, dtype=np.float64).reshape(-1) + + res = minimize( + vgf, + coeff[:k].reshape(-1).astype(np.float64), + jac=True, + method="L-BFGS-B", + options={"maxiter": max_iter}, + ) + coeff[:k] = res.x.reshape(k, 2) + return np.asarray(phi @ coeff) + + +def make_multigrid_flatten_fn(modes: int = 200, refine: bool = True, schedule=None): + """``flatten_fn``: spectral coarse-to-fine, then (optionally) full-resolution refine.""" + sched = schedule or [m for m in (20, 50, 100, modes) if m <= modes] + + def _fn(flattener): + phi = manifold_harmonics(flattener.vertices, flattener.faces, modes) + coarse = spectral_optimize(flattener, phi, sched) + if not refine: + return coarse + # Coarse map is ~flip-free; skip the initial NAR and refine at full resolution. 
+ flattener.config.negative_area_removal.enabled = False + flattener.initial_projection = lambda: coarse + return flattener.run() + + return _fn + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--modes", type=int, default=200, help="number of manifold-harmonic modes" + ) + ap.add_argument("--no-refine", action="store_true", help="spectral coarse only") + ap.add_argument("--dev", action="store_true") + ap.add_argument("--split", default=None, choices=["train", "holdout"]) + ap.add_argument("--subset", type=int, default=None) + ap.add_argument("--save", action="store_true") + args = ap.parse_args() + + refine = not args.no_refine + label = "multigrid" + ("" if refine else "_coarse") + register_flatten_fn(label, make_multigrid_flatten_fn(args.modes, refine=refine)) + + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + if refine: + cfg.negative_area_removal.enabled = False + + paths.ensure_output_dirs() + manifest = load_manifest() + subset = 2 if args.dev else args.subset + entries = select_entries(manifest, split=args.split, subset=subset) + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + record = new_record( + kind="experiment", + label=f"exp:{label}", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": e["subject"], "hemi": e["hemi"], "split": e.get("split")} + for e in entries + ], + method={ + "name": label, + "modes": args.modes, + "refine": refine, + "config": cfg.to_dict(), + }, + repro_command="python -m benchmark.probe_multigrid " + " ".join(sys.argv[1:]), + ) + record.decision = { + "hypothesis": ( + "Optimizing in the low-frequency manifold-harmonic basis is flip-free by " + "band-limiting; the coarse, distance-aware map is a strong init for a short " + "full-resolution refine (geometric multigrid)." 
+ ) + } + + print( + f"Multigrid probe '{label}' (modes={args.modes}) on {len(entries)} hemispheres ({record.experiment_id})..." + ) + t0 = time.time() + save_dir = paths.RUNS_DIR / record.experiment_id if args.save else None + result = evaluate(entries, cfg, method=label, save_dir=save_dir) + record.per_subject = result["per_subject"] + record.metrics = result["aggregate"] + record.runtime_s = time.time() - t0 + if save_dir is not None: + record.artifacts = [ + { + "path": r["artifact"], + "hash": file_hash(r["artifact"]), + "kind": "flat_patch", + } + for r in result["per_subject"] + if r.get("artifact") + ] + record.status = "ok" if result["aggregate"].get("n_failed", 0) == 0 else "partial" + + agg = result["aggregate"] + base = Ledger().latest(kind="baseline") + base_m = base["metrics"] if base else None + if base_m and "mean_distortion" in agg: + record.decision["conclusion"] = ( + f"{label}: dist {agg['mean_distortion']:.2f} vs base {base_m.get('mean_distortion'):.2f}; " + f"flips {agg['total_flipped']} vs {base_m.get('total_flipped')}; " + f"runtime {agg.get('mean_runtime_s', 0):.0f}s vs {base_m.get('mean_runtime_s', 0):.0f}s." + ) + Ledger().append(record) + + print("\n=== aggregate ===") + for k in ( + "n_patches", + "n_failed", + "mean_distortion", + "total_flipped", + "frac_patches_with_flips", + "mean_runtime_s", + ): + if k in agg: + print(f" {k}: {agg[k]}") + print(f"\nLogged {record.experiment_id} -> {Ledger().path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 8b40b9ead76d1912059d20a25adebd9ab786f5e4 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 06:06:27 -0700 Subject: [PATCH 11/35] Findings: multigrid + Adam optimizer experiments Spectral (manifold-harmonic) multigrid: coarse solve is flip-free by band-limiting (19 flips, 66s, 22%) but band-limited; full V-cycle matches quality (15.3%) but isn't faster (eigsh+spectral overhead). 
Adam (fixed or per-level-scheduled lr) folds or stalls -- a fixed step can't span the multiscale step range. Conclusion: nothing beats the FreeSurfer multiscale line-search on this energy; the line search is the load-bearing piece. Real wins need mesh-decimation multigrid or SLIM, not a drop-in optimizer swap. Co-Authored-By: Claude Fable 5 --- benchmark/FINDINGS.md | 48 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 0faf710..e182e1c 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -69,6 +69,54 @@ optimization** or a proper **SLIM** local-global solver (symmetric-Dirichlet ene flip-preventing line search), not a drop-in optimizer swap. (`optax`/`jaxopt` are installed for this.) +## 5. Spectral (manifold-harmonic) multigrid — elegant, flip-free, but not faster + +`benchmark/probe_multigrid.py`. Parametrize the flatmap in the lowest-`K` cotangent-Laplacian +eigenvectors and optimize coarse-to-fine (add modes progressively). A band-limited map *cannot* +make local folds, so this is **flip-free by construction**: + +| modes | distortion | flips | cumulative time | +|---|---|---|---| +| 20 | 26.89% | 12 | 13s | +| 50 | 24.96% | 13 | 28s | +| 100 | 23.78% | 13 | 44s | +| 200 | 22.35% | 19 | 66s | + +So the coarse spectral solve is a genuinely nice **fast, flip-free approximate flattener** (66s, ~0 +flips) — contrast L-BFGS's 58025 flips. But it is **band-limited**: 200 smooth modes can't represent +fine detail, so distortion floors at ~22%. + +Used as the coarse leg of a multigrid V-cycle (spectral init → full-resolution refine): +- spectral coarse (22.35%) → full refine → **15.32% / 26 flips / 430s refine**. +- Even with a *shortened* refine (half iters): 15.37% / 28 flips / 347s. 
+ +It **matches** baseline quality but is **not faster**: the spectral overhead (eigsh ~75s one-time + +coarse solve ~66s) plus the refine (~350–430s) totals ≈ 490s vs Tutte+pipeline's 454s. The better +init (22% vs Tutte's 33%) doesn't shorten the refine enough to pay for the overhead, because the +refinement runs a largely fixed schedule. A true win would need a **mesh-decimation** multigrid +(each level full-DOF but few vertices) rather than a band-limited basis. + +## 6. Adam (and fixed-step methods) can't replace the line search + +Adam (optax) driving the same multi-level gradient smoothing as the baseline: +- **Fixed lr**: lr=0.001 is stable but crawls (33.7%→31.5%, nowhere near 15%); lr≥0.01 reduces + distortion but folds, and **diverges at the finest smoothing level** (up to 175855 flips). +- **Per-level lr schedule** (large→small with the smoothing level): still folds (76552 flips). + +Root cause: the optimal step size spans *orders of magnitude* across the smoothing schedule (large at +coarse `n_avg`, tiny at fine). A fixed or simply-scheduled step is either too slow at coarse scales +or unstable at fine scales. The baseline's **per-iteration line search** is precisely what adapts the +step across scales — that, plus the gradient smoothing, is the load-bearing machinery. + +## Overall conclusion on the optimizer + +Across L-BFGS, CG, Adam, flip barriers, and a spectral multigrid, **nothing beats the FreeSurfer-style +multiscale line-search GD** on the speed/quality/flip Pareto for this energy. Strong optimizers fold +without the smoothing; fixed-step optimizers can't span the multiscale step range; the eigenbasis is +flip-free but band-limited. The genuinely better directions remaining are a **mesh-decimation +geometric multigrid** or a proper **SLIM** local-global solver — real reimplementations, not drop-in +swaps. The one validated free win remains the **Tutte flip-free init** (§1). 
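
The step-size argument in §6 can be illustrated on a toy problem. The sketch below is not the autoflatten optimizer and is only a loose analogy: it assumes a two-variable quadratic whose curvatures differ by three orders of magnitude, and compares a fixed step against a per-iteration search over a small grid of candidate steps. The energy, the candidate grid, and all function names are made up for illustration.

```python
def energy(x):
    """Ill-conditioned quadratic: one stiff direction, one soft direction."""
    return 0.5 * (1000.0 * x[0] ** 2 + 1.0 * x[1] ** 2)


def grad(x):
    """Analytic gradient of `energy`."""
    return [1000.0 * x[0], 1.0 * x[1]]


def step(x, g, lr):
    """One gradient-descent update with step size `lr`."""
    return [xi - lr * gi for xi, gi in zip(x, g)]


def fixed_step_descent(x, lr, n_iter):
    """Plain gradient descent with a single step size for every iteration."""
    for _ in range(n_iter):
        x = step(x, grad(x), lr)
    return x


def line_search_descent(x, candidates, n_iter):
    """At each iteration, try every candidate step and keep the lowest-energy one."""
    for _ in range(n_iter):
        g = grad(x)
        best = min((step(x, g, lr) for lr in candidates), key=energy)
        if energy(best) < energy(x):  # never accept an uphill move
            x = best
    return x


x0 = [1.0, 1.0]
candidates = [10.0 ** k for k in range(-4, 1)]  # 1e-4 ... 1

e_small = energy(fixed_step_descent(x0, 1e-3, 50))  # stable, but the soft direction crawls
e_large = energy(fixed_step_descent(x0, 1e-2, 50))  # too large for the stiff direction: diverges
e_search = energy(line_search_descent(x0, candidates, 50))  # picks a different step each iteration
```

In this toy setting the searched variant ends far below the small fixed step, which in turn ends far below the diverging large step. That mirrors the qualitative pattern reported above, without reproducing the real energy or schedule.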
+ ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From e8ee2f65fa67886cddadeec33606360148186416 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 06:34:04 -0700 Subject: [PATCH 12/35] Add speed levers to experiment runner (line-search points, iters, smoothing cap) Profiling shows per-iteration cost is gradient (177ms) + line search (272ms) + smoothing (up to 902ms at n_avg=1024). Expose --n-coarse-steps, --iters-per-level, --max-smoothing to test pure runtime reductions that don't change the energy. Co-Authored-By: Claude Fable 5 --- benchmark/experiment.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/benchmark/experiment.py b/benchmark/experiment.py index 07e0df4..090cc71 100644 --- a/benchmark/experiment.py +++ b/benchmark/experiment.py @@ -54,6 +54,18 @@ def build_config(args): for phase in cfg.phases: if phase.name in (args.skip_epoch or []): phase.enabled = False + + # Speed levers (pure runtime; change the optimizer path, not the energy). + if args.n_coarse_steps is not None: + cfg.line_search.n_coarse_steps = args.n_coarse_steps # fewer line-search points + if args.iters_per_level is not None: + for phase in cfg.phases: + phase.iters_per_level = args.iters_per_level + if args.max_smoothing is not None: + # Cap the expensive coarse smoothing levels (n_avg=1024 costs ~900ms/iter). 
+ cap = args.max_smoothing + for phase in cfg.phases: + phase.smoothing_schedule = [n for n in phase.smoothing_schedule if n <= cap] return cfg @@ -77,6 +89,24 @@ def main() -> int: ) ap.add_argument("--k-ring", type=int, default=7) ap.add_argument("--n-neighbors", type=int, default=12) + ap.add_argument( + "--n-coarse-steps", + type=int, + default=None, + help="line-search points (default 15)", + ) + ap.add_argument( + "--iters-per-level", + type=int, + default=None, + help="max iters per smoothing level (default 40)", + ) + ap.add_argument( + "--max-smoothing", + type=int, + default=None, + help="cap n_avg in smoothing schedules (drops costly coarse levels)", + ) ap.add_argument("--dev", action="store_true") ap.add_argument("--split", default=None, choices=["train", "holdout"]) ap.add_argument("--subset", type=int, default=None) @@ -106,6 +136,9 @@ def main() -> int: "skip_epoch": args.skip_epoch or [], "k_ring": args.k_ring, "n_neighbors": args.n_neighbors, + "n_coarse_steps": args.n_coarse_steps, + "iters_per_level": args.iters_per_level, + "max_smoothing": args.max_smoothing, } record = new_record( kind="experiment", From 64dfbe195cd7f7e7ebda024584960843e6d411f4 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 06:59:00 -0700 Subject: [PATCH 13/35] Findings: how to speed up the FreeSurfer optimizer (~2.5x, near-free) Profiling: per-iteration cost is gradient (177ms) + line search (272ms) + smoothing (up to 902ms at n_avg=1024). Independent, compounding speed levers measured on sub-022 lh: line search 15->7 points (-30%, no quality cost), k-ring neighbors 12->6 (-33%, +0.1pp), fewer iters (-20%), capped smoothing (-7%). Combined fast config: 262s vs 658s original baseline (~60% faster) at +0.31pp distortion. The practical path to speed is config levers + Tutte init, not a new optimizer. 
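A quick arithmetic check of the headline figures, as a minimal sketch that uses only the timings quoted in this message. The helper names `percent_faster` and `speedup` are illustrative and do not exist in the benchmark code.

```python
# Timings quoted in this commit message (per-iteration profile and run totals).
GRADIENT_MS = 177
LINE_SEARCH_MS = 272
SMOOTH_MAX_MS = 902  # smooth_gradient at n_avg=1024


def percent_faster(baseline_s: float, new_s: float) -> float:
    """Runtime reduction relative to the baseline, in percent."""
    return 100.0 * (1.0 - new_s / baseline_s)


def speedup(baseline_s: float, new_s: float) -> float:
    """Speedup factor: baseline time divided by new time."""
    return baseline_s / new_s


fixed_ms = GRADIENT_MS + LINE_SEARCH_MS  # constant cost paid on every iteration
worst_ms = fixed_ms + SMOOTH_MAX_MS      # iteration at the coarsest smoothing level

print(f"fixed cost per iteration: {fixed_ms} ms")
print(f"worst-case iteration:     {worst_ms} ms")
print(
    f"fast config vs baseline:  {percent_faster(658, 262):.0f}% faster, "
    f"{speedup(658, 262):.1f}x"
)
```

A 60% runtime reduction and a 2.5x speedup describe the same pair of timings (262s vs 658s), which is why both figures appear for the combined config.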
Co-Authored-By: Claude Fable 5 --- benchmark/FINDINGS.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index e182e1c..c444b18 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -117,6 +117,44 @@ flip-free but band-limited. The genuinely better directions remaining are a **me geometric multigrid** or a proper **SLIM** local-global solver — real reimplementations, not drop-in swaps. The one validated free win remains the **Tutte flip-free init** (§1). +## 7. Speeding up the FreeSurfer-style optimizer (the productive direction) + +Rather than replace the optimizer, **accelerate it**. Profiling one iteration (193k-vertex +hemisphere, CPU): + +| component | cost / iter | notes | +|---|---|---| +| gradient (J_d + J_a) | 177 ms | iterates over ~16M k-ring edges | +| line search (15 pts, vmap) | 272 ms | also iterates over k-ring edges | +| `smooth_gradient`, n_avg=1 | 1 ms | fine levels | +| `smooth_gradient`, n_avg=256 | 243 ms | | +| `smooth_gradient`, n_avg=1024 | **902 ms** | 1024 sequential neighbor-averaging passes | + +So per iteration ≈ gradient + line search (~450 ms, constant) + smoothing (1→902 ms by level). +The levers, each **measured on sub-022 lh** (Tutte init; reference Tutte+full = 15.25% / 24 / 454s): + +| lever | distortion | flips | runtime | vs 454s | +|---|---|---|---|---| +| line search 15 → 7 points | 15.28% | 28 | 319s | **−30%, ~free** | +| k-ring neighbors 12 → 6 | 15.37% | 30 | 303s | −33%, +0.1pp | +| k_ring 7 → 5 | 15.73% | 24 | 340s | −25%, +0.5pp | +| iters/level 40 → 20 | 15.37% | 28 | 362s | −20%, +0.1pp | +| drop n_avg=1024 level | 15.35% | 35 | 422s | −7% (few coarse iters) | +| **combo** (ls7 + iters25 + smooth256) | 15.51% | 28 | **262s** | **−42%** | + +These levers are largely **independent and compounding**, and stack on top of the Tutte init +(which removed the ~4-min initial NAR). 
The combined "fast" config reaches **262s vs the original +658s baseline — ~60% faster** — at +0.31pp distortion and comparable flips. The biggest near-free +win is the **line search point count** (15→7, no measurable quality cost); the biggest single lever +is **k-ring density** (fewer neighbors cuts gradient + line search + the one-time k-ring computation ++ the 237MB cache together). Capping the coarse smoothing helps less than expected because few +iterations actually run at the coarsest level. + +**Takeaway:** the practical, low-risk path to a faster pipeline is not a new optimizer but +(1) Tutte init, (2) a leaner line search, (3) a sparser k-ring, (4) fewer iters/level — each a small +config change, together a ~2.5× speedup at near-baseline quality. (A spectral/low-rank acceleration +of the coarse `smooth_gradient` is a deeper exact win if the coarse levels ever dominate.) + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From 025b302adb7d897502c2010c0c0ccf89985e0f71 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 07:25:43 -0700 Subject: [PATCH 14/35] Findings: validated 3.6x speedup (fast_ultimate, n=4) Stacked fast config across 4 hemispheres: 194s vs 702s baseline (-72%) at 15.10% vs 14.73% distortion, visually clean maps. Co-Authored-By: Claude Fable 5 --- benchmark/FINDINGS.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index c444b18..56cab5a 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -150,10 +150,20 @@ is **k-ring density** (fewer neighbors cuts gradient + line search + the one-tim + the 237MB cache together). Capping the coarse smoothing helps less than expected because few iterations actually run at the coarsest level. 
+**Validated stacked "fast" config** (Tutte init + n_neighbors 6 + line-search 7 + iters/level 25 + +smoothing cap 256), across **4 hemispheres / 2 subjects**: + +| | mean distortion | total flips | mean runtime | +|---|---|---|---| +| baseline (n=4) | 14.73% | 121 | 702 s | +| **fast_ultimate (n=4)** | 15.10% | 237 | **194 s (−72%, 3.6×)** | + ++0.37pp distortion, flips still 0.06% of faces, and visually clean maps on both subjects. + **Takeaway:** the practical, low-risk path to a faster pipeline is not a new optimizer but (1) Tutte init, (2) a leaner line search, (3) a sparser k-ring, (4) fewer iters/level — each a small -config change, together a ~2.5× speedup at near-baseline quality. (A spectral/low-rank acceleration -of the coarse `smooth_gradient` is a deeper exact win if the coarse levels ever dominate.) +config change, together a **~3.6× speedup** at near-baseline quality. (A spectral/low-rank +acceleration of the coarse `smooth_gradient` is a deeper exact win if the coarse levels ever dominate.) ## Method note From 691f280ad9bbd72055a0e8c7d1fa901d1f603aa3 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 09:07:08 -0700 Subject: [PATCH 15/35] Findings: energy/quality angle - true-geodesic metric + correction recalibration Built a heat-method true-geodesic distortion yardstick; it shows the k-ring metric overstates distortion (~16% true vs 33.7% k-ring for Tutte) and mis-ranks maps. k-ring targets are ~9% inflated locally, BUT recalibrating the 1.207 correction down makes true distortion monotonically worse (18.15% -> 21.66% -> 26.88% for 1.207/1.10/1.00): the factor usefully pre-stretches to compensate curvature-induced medium-range compression. Default is near-optimal; recalibration is a dead end. Real remaining lever: per-edge geodesic targets in the energy (heavier). 
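The yardstick described above scores a map by the relative error between flat 2D distances and true geodesics over local pairs. A minimal pure-Python sketch of that scoring step follows, assuming the geodesic distances from one source vertex are already available (the heat-method solve itself is not shown). The function name `local_distortion_pct` and its arguments are hypothetical, not the benchmark's API.

```python
import math


def local_distortion_pct(uv, src, geo_from_src, radius_mm=30.0):
    """Mean |d2d - d_geo| / d_geo (percent) over targets within radius_mm of src.

    uv           : list of (x, y) flat-map coordinates, one per vertex
    src          : index of the source vertex
    geo_from_src : true geodesic distance (mm) from src to every vertex
    """
    errs = []
    for j, d_geo in enumerate(geo_from_src):
        # Skip the source itself and pairs outside the local radius.
        if j == src or d_geo <= 1e-6 or d_geo > radius_mm:
            continue
        d2d = math.dist(uv[src], uv[j])  # Euclidean distance on the flat map
        errs.append(abs(d2d - d_geo) / d_geo)
    if not errs:
        return float("nan")  # no local pairs to score
    return 100.0 * sum(errs) / len(errs)


# Toy example: vertex 1 is preserved exactly, vertex 2 is compressed by 20%,
# and vertex 3 lies beyond the 30 mm radius, so it is ignored.
uv = [(0.0, 0.0), (3.0, 4.0), (10.0, 0.0), (50.0, 0.0)]
geo = [0.0, 5.0, 12.5, 60.0]
print(f"{local_distortion_pct(uv, 0, geo):.1f}%")
```

Because the error is normalized by the geodesic distance rather than by the k-ring targets, the score does not change when the correction factor does, which is what allows maps optimized under different corrections to be compared.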
Co-Authored-By: Claude Fable 5 --- benchmark/FINDINGS.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 56cab5a..1464593 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -165,6 +165,41 @@ smoothing cap 256), across **4 hemispheres / 2 subjects**: config change, together a **~3.6× speedup** at near-baseline quality. (A spectral/low-rank acceleration of the coarse `smooth_gradient` is a deeper exact win if the coarse levels ever dominate.) +## 8. Improving the energy (quality at the cost of speed) + +Built a **true geodesic distortion** yardstick (libigl heat method, 200 sampled sources, local +≤30mm pairs) to score quality independently of the k-ring targets. It immediately showed the k-ring +metric is **miscalibrated as an absolute number**: the Tutte init reads 33.7% on the k-ring metric +but ~16% in true geodesic terms, and the metric even *mis-ranks* maps (Tutte better than the spectral +map in true terms, opposite of the k-ring ranking). Calibration: the k-ring targets (Dijkstra × +1.207) are ~6–9% larger than true geodesics at the very local scale (mean ratio 1.094, median 1.059). + +**But recalibrating the correction does NOT improve quality — it hurts.** Optimizing from the Tutte +init with correction ∈ {1.207, 1.10, 1.00} and scoring true distortion: + +| correction | k-ring metric | true geodesic | flips | +|---|---|---|---| +| **1.207 (default)** | 15.26% | **18.15%** | 30 | +| 1.10 | 15.24% | 21.66% | 26 | +| 1.00 | 15.26% | 26.88% | 35 | + +True distortion gets monotonically *worse* as the correction shrinks. Interpretation: the local +Dijkstra/geodesic ratio (~1.09) does not represent the full range, and — more importantly — a +flatmap intrinsically **compresses** medium-range distances (curvature; Gauss). 
The 1.207 factor acts +as a useful global **pre-stretch** that compensates for that compression over the 0–30mm range, so a +"locally accurate" smaller correction leaves the map too compact and worse overall. **The default +1.207 is empirically near-optimal; recalibration is a dead end.** (The k-ring metric is still worth +recalibrating for *interpretation*, since it overstates absolute distortion — but not for the search.) + +Also notable: the optimized maps (~18% true) are not better than the raw Tutte init (~16% true) on +this medium-range metric — local k-ring fitting trades some medium-range geodesic accuracy. Which map +is "best" depends on the distortion metric (local vs medium-range), a point worth making in the paper. + +**Remaining genuine lever (untested, heavier):** replace Dijkstra × constant with **actual per-edge +geodesic targets** (heat/exact) in the energy — a constant can't capture the spatially-varying +Dijkstra/geodesic ratio. That is the real "better energy at the cost of a slight slowdown," but it +needs per-vertex geodesic computation (expensive) and is left as the next step. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From 620eff607b8c4e94b407406695ee99df17c4ec10 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 10:09:26 -0700 Subject: [PATCH 16/35] Findings: true-geodesic metric, local-vs-global, distance-optimal output scale Adds benchmark/truedist.py (energy-independent true-geodesic distortion, banded + global) and benchmark/probe_truedist.py (k_ring / epoch-skip / target-scale sweeps scored on the true metric). FINDINGS section 9: - targets are a confirmed dead end (Dijkstra over-estimates geodesic only ~7.8%, flat across distance; current target=Dijkstra/1.207 is ~11% too compact; 1.207 near-optimal). - k_ring 7->11 does not help (more flips, slower). 
- the local <=30mm metric is gameable: a conformal Tutte disk wins it but is a degenerate flatmap; the global all-pairs metric correctly ranks the full pipeline best. - validated near-free win: the final scale_to_area normalization is not distance-optimal; rescaling output by ~1.04 cuts global geodesic distortion ~0.5pp (~5% rel), reproduced on held-out geodesic sources (single-hemi; multi-hemi TBD). Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 50 ++++++++++ benchmark/probe_truedist.py | 179 ++++++++++++++++++++++++++++++++++++ benchmark/truedist.py | 128 ++++++++++++++++++++++++++ 3 files changed, 357 insertions(+) create mode 100644 benchmark/probe_truedist.py create mode 100644 benchmark/truedist.py diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 1464593..a8461d0 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -200,6 +200,56 @@ geodesic targets** (heat/exact) in the energy — a constant can't capture the s Dijkstra/geodesic ratio. That is the real "better energy at the cost of a slight slowdown," but it needs per-vertex geodesic computation (expensive) and is left as the next step. +## 9. Local vs global distance metric, and a validated free win: distance-optimal output scale + +Pushing on "reduce distance error at a slight slowdown" (all on `sub-022 lh`, scored with the +heat-method true-geodesic yardstick; `benchmark/probe_truedist.py`, `benchmark/truedist.py`): + +**a) Direct measurement of the Dijkstra/geodesic relationship.** Using the 200 saved heat-geodesic +source fields vs raw graph Dijkstra: raw Dijkstra overestimates the true geodesic by only **~7.8%** +(ratio 1.078), and that ratio is **nearly flat across 0–30 mm** (1.12→1.07). The code computes +`target = Dijkstra / 1.207`, so the **targets are ~11% *smaller* than true geodesics** (target/true = +0.893). 
Two consequences: (i) a *distance-dependent* correction is pointless (no distance structure to +exploit), and (ii) §8's "shrink the correction" sweep only ever made targets *larger*; the optimum is +at/above 1.207, confirming 1.207 is near-optimal. The targets are a **confirmed dead end**. + +**b) `k_ring` 7→11 does not help.** True distortion 18.84%→18.66% (−0.18pp) but flips 24→444 and +runtime 445→649 s (+46%). More constraints with the same flawed targets just add folding pressure. + +**c) The local metric is gameable — a conformal disk beats a real flatmap on it.** Skipping all metric +epochs (just flip-cleaning the Tutte init) gives the *lowest* local (≤30 mm) true distortion in the +study — **16.26%, 0 flips, 6 s** — but the map is a **featureless disk** (Tutte pins the boundary to a +circle, destroying all anatomical shape). Lighter-touch variants (skip epoch_3 / epoch_2+3) are *worse* +(18.9–19.0%), not better. So the metric epochs are not "degrading" quality; the **local ≤30 mm metric +simply doesn't see the global/boundary area distortion** and mis-ranks a useless disk above a correct +flatmap. Lesson for the paper: report distance distortion **globally**, not just locally. + +| map | local ≤30 mm | **global (all pairs)** | flips | shape | +|---|---|---|---|---| +| clean_init (Tutte disk) | **16.26%** | 13.04% | 0 | degenerate disk | +| full pipeline (fast_ultimate) | 17.93% | **11.35%** | 33 | correct flatmap | + +On the **global** metric (all geodesic pairs, no 30 mm cap) the full pipeline is **best** and the disk is +**worst** — the correct ranking. The pipeline is actually better at long range (11.35%) than short range +(17.93%); its only real weakness is a slight **global scale** bias. 
+ +**d) Validated win — distance-optimal output scale (≈ free).** The final `scale_to_area` normalizes the +map to match *total surface area*, but that is **not** the scale that minimizes geodesic distortion: the +too-compact targets leave the optimized map ~4% too small. Refitting a single global scale `s`: + +| | global true distortion | +|---|---| +| `s = 1.0` (area-matched, current) | 11.35% | +| `s* ≈ 1.04` (distance-optimal) | **10.88%** (−0.48pp, ~4% relative) | + +**Not metric-gaming:** fitting `s` on 100 sources and evaluating on the 100 held-out sources reproduces +it (11.31%→10.74% on the held-out half), so it is a real, systematic global property, consistent with +the directly-measured 11% target compaction. A single multiplicative rescale of the output `uv` costs +nothing and changes no shape — only the zoom. **Caveat:** validated on one hemisphere; before becoming a +default it should be confirmed across hemispheres/subjects (the bias direction is principled, but the +exact factor may vary). The clean way to ship it: after optimization, choose the output scale that +minimizes geodesic (or k-ring) distortion instead of matching total area. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. diff --git a/benchmark/probe_truedist.py b/benchmark/probe_truedist.py new file mode 100644 index 0000000..4b53bf7 --- /dev/null +++ b/benchmark/probe_truedist.py @@ -0,0 +1,179 @@ +"""Probe: optimize quality (true geodesic distortion) at the cost of a slight slowdown. + +Section 8 of FINDINGS showed the *optimized* map is slightly worse than the raw Tutte init +in **true geodesic** terms (~18% vs ~16%): the k-ring energy over-fits very-local distances +and lets medium-range (10-30 mm) distances drift. 
This probe tests config levers that should +add medium-range constraints / reduce that drift, scored with the energy-independent true +geodesic yardstick (:mod:`benchmark.truedist`) so maps are comparable across ``k_ring``. + +Each run uses the validated Tutte flip-free init and the existing refinement, varying one +knob (``--k-ring`` / ``--n-neighbors``), and logs both the k-ring metric and the true metric +to the ledger. Single-hemisphere screening (default sub-022 lh, which has a cached truegeo +reference); promote a winner to multi-hemi afterward. + +Usage +----- + python -m benchmark.probe_truedist --k-ring 7 --n-neighbors 12 --label k7_baseline + python -m benchmark.probe_truedist --k-ring 11 --n-neighbors 12 --label k11 +""" + +from __future__ import annotations + +import argparse +import sys +import time + +import numpy as np + +from . import paths +from .harness import build_flattener, load_manifest +from .ledger import Ledger, new_record +from .metrics import per_patch_metrics +from .probe_tutte_init import make_flatten_fn +from .truedist import load_truegeo, true_distortion, true_distortion_banded + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--subject", default="sub-022") + ap.add_argument("--hemi", default="lh") + ap.add_argument("--k-ring", type=int, default=7) + ap.add_argument("--n-neighbors", type=int, default=12) + ap.add_argument("--init", default="tutte", choices=["tutte", "lscm", "projection"]) + ap.add_argument( + "--skip-epoch", action="append", choices=["epoch_1", "epoch_2", "epoch_3"] + ) + ap.add_argument( + "--target-scale", + type=float, + default=None, + help="multiply k-ring targets by this factor (rescales the effective correction " + "without rebuilding the cache; <1 = more compact targets)", + ) + ap.add_argument( + "--save-uv", action="store_true", help="save the flat patch for plotting" + ) + ap.add_argument("--label", default=None) + ap.add_argument("--hypothesis", default="") + 
args = ap.parse_args() + + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + cfg.kring.k_ring = args.k_ring + cfg.kring.n_neighbors_per_ring = args.n_neighbors + if args.init in ("tutte", "lscm"): + cfg.negative_area_removal.enabled = False + for phase in cfg.phases: + if phase.name in (args.skip_epoch or []): + phase.enabled = False + + label = args.label or f"truedist_{args.init}_k{args.k_ring}_n{args.n_neighbors}" + + paths.ensure_output_dirs() + manifest = load_manifest() + entry = [ + e + for e in manifest["entries"] + if e["subject"] == args.subject and e["hemi"] == args.hemi + ] + if not entry: + print(f"No manifest entry for {args.subject} {args.hemi}", file=sys.stderr) + return 1 + entry = entry[0] + + ref = load_truegeo(args.subject, args.hemi) + + record = new_record( + kind="experiment", + label=f"exp:{label}", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": args.subject, "hemi": args.hemi, "split": entry.get("split")} + ], + method={ + "name": label, + "init": args.init, + "k_ring": args.k_ring, + "n_neighbors": args.n_neighbors, + "metric": "true_geodesic (heat, R=30mm, 200 src) + k-ring", + "config": cfg.to_dict(), + }, + repro_command="python -m benchmark.probe_truedist " + " ".join(sys.argv[1:]), + ) + if args.hypothesis: + record.decision["hypothesis"] = args.hypothesis + + print( + f"True-distortion probe '{label}' on {args.subject} {args.hemi} ({record.experiment_id})..." + ) + t0 = time.time() + flatten_fn = ( + make_flatten_fn(args.init, refine=True) if args.init != "projection" else None + ) + + flattener = build_flattener(entry, cfg, use_cache=True) + if args.target_scale is not None: + # Rescale the effective graph-distance correction without rebuilding the cache: + # target_new = target_old * scale (scale<1 => more compact targets). 
+ flattener.targets_jax = flattener.targets_jax * args.target_scale + t_opt = time.time() + if flatten_fn is None: + uv = np.asarray(flattener.run()) + else: + uv = np.asarray(flatten_fn(flattener)) + runtime = time.time() - t_opt + + kring_m = per_patch_metrics(uv, flattener) + true_m = true_distortion(uv, ref) + banded = true_distortion_banded(uv, ref) + + artifact = None + if args.save_uv: + save_dir = paths.RUNS_DIR / record.experiment_id + save_dir.mkdir(parents=True, exist_ok=True) + artifact = str(save_dir / f"{args.subject}.{args.hemi}.flat.patch.3d") + flattener.save_result(uv, artifact) + record.artifacts = [{"path": artifact, "kind": "flat_patch"}] + + record.runtime_s = time.time() - t0 + record.metrics = { + **{f"kring_{k}": v for k, v in kring_m.items()}, + **true_m, + **banded, + "opt_runtime_s": runtime, + } + record.per_subject = [ + {"subject": args.subject, "hemi": args.hemi, **record.metrics} + ] + record.status = "ok" + record.decision["conclusion"] = ( + f"{label}: TRUE mean {true_m['true_mean_distortion']:.2f}% " + f"(p90 {true_m['true_p90_distortion']:.2f}%); " + f"k-ring {kring_m['mean_distortion']:.2f}%; " + f"flips {kring_m['n_flipped']}; opt {runtime:.0f}s." 
+ ) + Ledger().append(record) + + print("\n=== metrics ===") + print(f" TRUE mean distortion: {true_m['true_mean_distortion']:.3f}%") + print(f" TRUE p90 distortion: {true_m['true_p90_distortion']:.3f}%") + print(f" TRUE median distortion: {true_m['true_median_distortion']:.3f}%") + print(f" k-ring mean distortion: {kring_m['mean_distortion']:.3f}%") + print(f" flipped triangles: {kring_m['n_flipped']}") + print(" TRUE by band: ", end="") + for lo, hi in ((0, 5), (5, 15), (15, 30)): + key = f"band_{lo}_{hi}_mean" + if key in banded: + print(f"{lo}-{hi}mm={banded[key]:.1f}% ", end="") + print() + print(f" opt runtime: {runtime:.0f}s") + if artifact: + print(f" saved uv: {artifact}") + print(f"\nLogged {record.experiment_id} -> {Ledger().path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/truedist.py b/benchmark/truedist.py new file mode 100644 index 0000000..5473789 --- /dev/null +++ b/benchmark/truedist.py @@ -0,0 +1,128 @@ +"""True geodesic distortion yardstick (method-agnostic quality metric). + +The k-ring energy metric is miscalibrated as an *absolute* number and is not comparable +across different ``k_ring`` values (its targets change with k). To compare maps produced +under different energies/neighborhoods we score against a fixed, energy-independent +**true geodesic** reference: libigl heat-method geodesics from a set of sampled sources, +restricted to local pairs (<= R mm), comparing the 2D flat distance to the true 3D +geodesic distance. + +The reference (sources + per-source geodesic fields) is precomputed once per (subject, +hemi) and cached under the k-ring cache dir as ``{subject}_{hemi}.truegeo.npz`` with keys +``srcs`` (M source indices into the *patch* vertex array), ``geo`` (M x V geodesics), and +``R`` (the local radius in mm). 
+ +Usage +----- + from benchmark.truedist import load_truegeo, true_distortion + ref = load_truegeo("sub-022", "lh") + stats = true_distortion(uv, ref) # mean/p90 true distortion over local pairs +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np + +from . import paths + + +def truegeo_path(subject: str, hemi: str) -> Path: + return paths.KRING_CACHE_DIR / f"{subject}_{hemi}.truegeo.npz" + + +def load_truegeo(subject: str, hemi: str) -> dict[str, Any]: + """Load the cached true-geodesic reference for a (subject, hemi).""" + d = np.load(truegeo_path(subject, hemi)) + return {"srcs": d["srcs"], "geo": d["geo"], "R": float(d["R"])} + + +def compute_truegeo( + flattener: Any, + n_sources: int = 200, + radius: float = 30.0, + seed: int = 0, +) -> dict[str, Any]: + """Compute heat-method geodesic fields from ``n_sources`` sampled sources. + + Sources are drawn deterministically (fixed seed) from the patch vertices. The geodesic + field is restricted (for scoring) to pairs within ``radius`` mm, but stored densely. + """ + import igl + + v = np.ascontiguousarray(flattener.vertices, dtype=np.float64) + f = np.ascontiguousarray(flattener.faces, dtype=np.int64) + n_v = v.shape[0] + + rng = np.random.default_rng(seed) + srcs = np.sort(rng.choice(n_v, size=min(n_sources, n_v), replace=False)) + + data = igl.HeatGeodesicsData() + igl.heat_geodesics_precompute(v, f, data) + geo = np.empty((srcs.shape[0], n_v), dtype=np.float64) + for i, s in enumerate(srcs): + geo[i] = igl.heat_geodesics_solve(data, np.array([s], dtype=np.int64)) + return {"srcs": srcs, "geo": geo, "R": float(radius)} + + +def true_distortion(uv: np.ndarray, ref: dict[str, Any]) -> dict[str, float]: + """Mean/median/p90 absolute relative distortion of 2D distances vs true geodesics. 
+ + For each source ``s`` and target ``j`` with ``geo[s,j] <= R`` (and > 0), the per-pair + distortion is ``|d2d - d_geo| / d_geo`` where ``d2d`` is the Euclidean distance between + ``uv[s]`` and ``uv[j]``. Returns percentages. + """ + uv = np.asarray(uv, dtype=np.float64) + srcs = ref["srcs"] + geo = ref["geo"] + R = ref["R"] + + errs = [] + for i, s in enumerate(srcs): + d_geo = geo[i] + mask = (d_geo > 1e-6) & (d_geo <= R) + if not np.any(mask): + continue + d2d = np.linalg.norm(uv[mask] - uv[s], axis=1) + rel = np.abs(d2d - d_geo[mask]) / d_geo[mask] + errs.append(rel) + all_err = np.concatenate(errs) + return { + "true_mean_distortion": float(np.mean(all_err) * 100.0), + "true_p90_distortion": float(np.percentile(all_err, 90) * 100.0), + "true_median_distortion": float(np.median(all_err) * 100.0), + "n_pairs": int(all_err.size), + } + + +def true_distortion_banded( + uv: np.ndarray, ref: dict[str, Any], bands=((0, 5), (5, 15), (15, 30)) +) -> dict[str, float]: + """Mean true distortion broken down by geodesic-distance band (mm). + + Reveals *where* in the 0-R range the map distorts: a metric energy that over-fits the + very-local scale tends to show low error in the (0,5] band but rising error at medium + range. Returns ``band_{lo}_{hi}_mean`` percentages and pair counts. 
+ """ + uv = np.asarray(uv, dtype=np.float64) + srcs = ref["srcs"] + geo = ref["geo"] + + band_errs = {b: [] for b in bands} + for i, s in enumerate(srcs): + d_geo = geo[i] + d2d_all = np.linalg.norm(uv - uv[s], axis=1) + for lo, hi in bands: + mask = (d_geo > max(lo, 1e-6)) & (d_geo <= hi) + if np.any(mask): + rel = np.abs(d2d_all[mask] - d_geo[mask]) / d_geo[mask] + band_errs[(lo, hi)].append(rel) + out: dict[str, float] = {} + for (lo, hi), chunks in band_errs.items(): + if chunks: + e = np.concatenate(chunks) + out[f"band_{lo}_{hi}_mean"] = float(np.mean(e) * 100.0) + out[f"band_{lo}_{hi}_n"] = int(e.size) + return out From c52aaa8599bb6ca42e68f1f4d515ef976767ad10 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 10:16:01 -0700 Subject: [PATCH 17/35] Findings: ground objective in Fischl 1999 (metric distortion, no area term) Read the original Fischl/Sereno/Dale 1999 energy: J = lambda_d J_d + lambda_a J_a, where J_a is gated to negative-area (folded) triangles only -> fold removal, NOT area preservation. Objective is purely metric (distance) distortion; area is a dependent byproduct. scale_to_area is a display convention, not part of the objective. Corrects section 9(d): the ~1.04 scale gain is real on true geodesics but is a symptom of the target-compaction bias, and the k-ring surrogate prefers the opposite scale; not a clean free win. Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 44 +++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index a8461d0..82af336 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -233,22 +233,34 @@ On the **global** metric (all geodesic pairs, no 30 mm cap) the full pipeline is **worst** — the correct ranking. The pipeline is actually better at long range (11.35%) than short range (17.93%); its only real weakness is a slight **global scale** bias. 
-**d) Validated win — distance-optimal output scale (≈ free).** The final `scale_to_area` normalizes the -map to match *total surface area*, but that is **not** the scale that minimizes geodesic distortion: the -too-compact targets leave the optimized map ~4% too small. Refitting a single global scale `s`: - -| | global true distortion | -|---|---| -| `s = 1.0` (area-matched, current) | 11.35% | -| `s* ≈ 1.04` (distance-optimal) | **10.88%** (−0.48pp, ~4% relative) | - -**Not metric-gaming:** fitting `s` on 100 sources and evaluating on the 100 held-out sources reproduces -it (11.31%→10.74% on the held-out half), so it is a real, systematic global property, consistent with -the directly-measured 11% target compaction. A single multiplicative rescale of the output `uv` costs -nothing and changes no shape — only the zoom. **Caveat:** validated on one hemisphere; before becoming a -default it should be confirmed across hemispheres/subjects (the bias direction is principled, but the -exact factor may vary). The clean way to ship it: after optimization, choose the output scale that -minimizes geodesic (or k-ring) distortion instead of matching total area. +**d) Output scale is not metric-optimal — but the picture is subtle (grounded in Fischl 1999).** +Fischl, Sereno & Dale (1999), §2.1–2.3: the flattening energy is `J = λ_d·J_d + λ_a·J_a` where +`J_d = (1/2V) Σ_i Σ_{n∈N(i)} (d_in^t − d_in^0)^2` is **metric (distance) distortion** — *the* objective — +and `J_a = (1/2T) Σ_i P(A_i)(A_i − A_i^0)^2` with `P(A_i)=1 iff A_i ≤ 0`. The area term is **gated to +folded (negative-area) triangles only**; valid triangles' area magnitude is unconstrained. So FreeSurfer +has **no area-preservation objective** — `J_a` only removes folds, and area is preserved only as a +*byproduct* of distance preservation (an isometry preserves both; distance is the stronger property). 
+The final `scale_to_area` (`s = √(orig_area/total_area)`) is therefore a **display convention**, not part +of the objective. + +Refitting a single global scale `s` on the optimized map exposes that the area-matched scale is not +distortion-optimal — and the two distance metrics disagree on direction: + +| objective | optimal scale | note | +|---|---|---| +| **true geodesic** (heat, the paper's *intent*) | `s ≈ 1.04` (expand) | area-matched map is ~4% too small vs truth; rescale cuts true distortion 11.35%→10.88% (−0.48pp), reproduced on held-out sources (11.31%→10.74%) | +| **k-ring surrogate `J_d`** (the energy actually minimized, `Dijkstra/1.207` targets) | `s ≈ 0.98` (shrink) | the ~11% too-compact targets pull the surrogate optimum the *wrong* way | + +So this is **not** a clean free win: the true-geodesic gain is real (it improves the paper's actual +objective) but it is a **symptom of the target-compaction bias (a)**, and the implemented surrogate energy +prefers the opposite scale. Shippable form: after optimization, choose the output scale that minimizes +distortion against a **true-geodesic sample**, rather than matching total area (and ideally fix the target +calibration upstream). Validated on one hemisphere; confirm multi-hemi before changing a default. + +**Area as an objective?** Per Fischl 1999 it deliberately is *not* one. Adding an area-preservation term +would be redundant in the isometric limit and would *fight* `J_d` in the real (non-isometric) regime, +trading distance fidelity for area fidelity. Only worth it if the scientific use is reading cortical +*area* off the map (an equiareal/authalic objective) — a different goal from Fischl's metric-distance one. 
## Method note From c7488d80c8b45efe46f678807b868f1be94b5d45 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 10:49:43 -0700 Subject: [PATCH 18/35] Bench: global true metric + target-scale sweep; correction ~1.10 wins on global Adds true_distortion_full (local<=R, global all-pairs, distance-optimal scale) and a build_truegeo runner. Target-scale sweep on sub-022 lh scored on the GLOBAL metric: target_scale=1.10 (effective correction ~1.10, near the measured Dijkstra/geodesic ratio) cuts global true distortion 11.80->11.15% (-0.65pp, ~5.5% rel), flips unchanged. Reverses section 8's local-metric "dead end": on the faithful global metric, calibrating the correction toward true geodesics helps. Cross-hemisphere validation pending. Co-Authored-By: Claude Opus 4.8 --- benchmark/build_truegeo.py | 71 +++++++++++++++++++++++++++++++++++++ benchmark/probe_truedist.py | 22 ++++++++---- benchmark/truedist.py | 38 ++++++++++++++++++++ 3 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 benchmark/build_truegeo.py diff --git a/benchmark/build_truegeo.py b/benchmark/build_truegeo.py new file mode 100644 index 0000000..6454fee --- /dev/null +++ b/benchmark/build_truegeo.py @@ -0,0 +1,71 @@ +"""Compute and cache the true-geodesic reference for a (subject, hemi). + +Heat-method geodesic fields from a fixed set of sampled sources, used by +:mod:`benchmark.truedist` to score maps independently of the k-ring energy. Run once per +hemisphere you want to evaluate on; the result is cached under the k-ring cache dir as +``{subject}_{hemi}.truegeo.npz``. + +Usage +----- + python -m benchmark.build_truegeo --subject sub-026 --hemi lh + python -m benchmark.build_truegeo --subject sub-022 --hemi rh --n-sources 200 --radius 30 +""" + +from __future__ import annotations + +import argparse +import sys +import time + +import numpy as np + +from . 
import paths +from .harness import build_flattener, load_manifest +from .truedist import compute_truegeo, truegeo_path + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--subject", required=True) + ap.add_argument("--hemi", required=True, choices=["lh", "rh"]) + ap.add_argument("--n-sources", type=int, default=200) + ap.add_argument("--radius", type=float, default=30.0) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--force", action="store_true") + args = ap.parse_args() + + out = truegeo_path(args.subject, args.hemi) + if out.exists() and not args.force: + print(f"Already cached: {out}") + return 0 + + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + manifest = load_manifest() + entry = [ + e + for e in manifest["entries"] + if e["subject"] == args.subject and e["hemi"] == args.hemi + ] + if not entry: + print(f"No manifest entry for {args.subject} {args.hemi}", file=sys.stderr) + return 1 + + paths.ensure_output_dirs() + t0 = time.time() + flattener = build_flattener(entry[0], cfg, use_cache=True) + ref = compute_truegeo( + flattener, n_sources=args.n_sources, radius=args.radius, seed=args.seed + ) + paths.KRING_CACHE_DIR.mkdir(parents=True, exist_ok=True) + np.savez(out, srcs=ref["srcs"], geo=ref["geo"], R=ref["R"]) + print( + f"Wrote {out} ({ref['srcs'].size} sources, R={ref['R']}mm) in {time.time() - t0:.0f}s" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/probe_truedist.py b/benchmark/probe_truedist.py index 4b53bf7..18cd913 100644 --- a/benchmark/probe_truedist.py +++ b/benchmark/probe_truedist.py @@ -30,7 +30,12 @@ from .ledger import Ledger, new_record from .metrics import per_patch_metrics from .probe_tutte_init import make_flatten_fn -from .truedist import load_truegeo, true_distortion, true_distortion_banded +from .truedist import ( + load_truegeo, + true_distortion, + true_distortion_banded, 
+ true_distortion_full, +) def main() -> int: @@ -128,6 +133,7 @@ def main() -> int: kring_m = per_patch_metrics(uv, flattener) true_m = true_distortion(uv, ref) banded = true_distortion_banded(uv, ref) + full = true_distortion_full(uv, ref) artifact = None if args.save_uv: @@ -142,6 +148,7 @@ def main() -> int: **{f"kring_{k}": v for k, v in kring_m.items()}, **true_m, **banded, + **full, "opt_runtime_s": runtime, } record.per_subject = [ @@ -149,17 +156,20 @@ def main() -> int: ] record.status = "ok" record.decision["conclusion"] = ( - f"{label}: TRUE mean {true_m['true_mean_distortion']:.2f}% " - f"(p90 {true_m['true_p90_distortion']:.2f}%); " + f"{label}: GLOBAL {full['true_global_mean']:.2f}% " + f"(s*={full['opt_scale']:.3f}->{full['true_global_at_optscale']:.2f}%); " + f"local<=30 {full['true_local_mean']:.2f}%; " f"k-ring {kring_m['mean_distortion']:.2f}%; " f"flips {kring_m['n_flipped']}; opt {runtime:.0f}s." ) Ledger().append(record) print("\n=== metrics ===") - print(f" TRUE mean distortion: {true_m['true_mean_distortion']:.3f}%") - print(f" TRUE p90 distortion: {true_m['true_p90_distortion']:.3f}%") - print(f" TRUE median distortion: {true_m['true_median_distortion']:.3f}%") + print(f" GLOBAL true mean: {full['true_global_mean']:.3f}%") + print( + f" GLOBAL @opt scale: {full['true_global_at_optscale']:.3f}% (s*={full['opt_scale']:.3f})" + ) + print(f" local <=30 mean: {full['true_local_mean']:.3f}%") print(f" k-ring mean distortion: {kring_m['mean_distortion']:.3f}%") print(f" flipped triangles: {kring_m['n_flipped']}") print(" TRUE by band: ", end="") diff --git a/benchmark/truedist.py b/benchmark/truedist.py index 5473789..4a66376 100644 --- a/benchmark/truedist.py +++ b/benchmark/truedist.py @@ -97,6 +97,44 @@ def true_distortion(uv: np.ndarray, ref: dict[str, Any]) -> dict[str, float]: } +def true_distortion_full(uv: np.ndarray, ref: dict[str, Any]) -> dict[str, float]: + """Local (<=R), global (all pairs), and global-at-distance-optimal-scale 
distortion. + + The local <=R metric is gameable (a conformal disk scores well locally while globally + catastrophic), so the *global* all-pairs metric is the faithful objective. Also reports + the single global scale ``s*`` that minimizes global distortion (the area-matched output + is generally not metric-optimal) and the distortion at ``s*``. + """ + uv = np.asarray(uv, dtype=np.float64) + srcs = ref["srcs"] + geo = ref["geo"] + R = ref["R"] + + d2_all, dg_all = [], [] + for i, s in enumerate(srcs): + dg = geo[i] + m = dg > 1e-6 + d2_all.append(np.linalg.norm(uv[m] - uv[s], axis=1)) + dg_all.append(dg[m]) + d2 = np.concatenate(d2_all) + dg = np.concatenate(dg_all) + + rel = np.abs(d2 - dg) / dg + loc = dg <= R + # optimal global scale s* minimizing mean |s*d2 - dg|/dg over a fine grid + scales = np.linspace(0.90, 1.15, 51) + errs = np.array([np.mean(np.abs(sc * d2 - dg) / dg) for sc in scales]) + j = int(np.argmin(errs)) + return { + "true_local_mean": float(np.mean(rel[loc]) * 100.0), + "true_global_mean": float(np.mean(rel) * 100.0), + "true_global_p90": float(np.percentile(rel, 90) * 100.0), + "opt_scale": float(scales[j]), + "true_global_at_optscale": float(errs[j] * 100.0), + "n_pairs_global": int(rel.size), + } + + def true_distortion_banded( uv: np.ndarray, ref: dict[str, Any], bands=((0, 5), (5, 15), (15, 30)) ) -> dict[str, float]: From 9f15f0f2ba98fca42fefe44f3ac090144915d3f7 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 11:25:51 -0700 Subject: [PATCH 19/35] Findings: target correction re-judged on global metric (cross-hemi sweep) Section 10: re-running the correction sweep scored on the GLOBAL true metric (not the gameable local one of section 8) reverses section 8's "dead end" -- larger targets (eff. 
correction ~1.10, toward the measured Dijkstra/geodesic ratio ~1.08) reduce global true distortion -- but the win is subject-dependent: clear on sub-022 (both hemis, -0.1 to -0.5pp), neutral/slightly-worse on sub-026 lh. Flips stay comparable. Robust component is the output-scale fix (s*~1.02 everywhere). A larger robust gain needs long-range geodesic anchors in the energy (next step, pending scope decision). Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 82af336..5d388f6 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -262,6 +262,41 @@ would be redundant in the isometric limit and would *fight* `J_d` in the real (n trading distance fidelity for area fidelity. Only worth it if the scientific use is reading cortical *area* off the map (an equiareal/authalic objective) — a different goal from Fischl's metric-distance one. +## 10. Target correction, re-judged on the GLOBAL metric (reverses §8, but only partly) + +§8 swept the graph-distance correction and concluded "1.207 is near-optimal, recalibration is a dead +end" — but that was scored on the **local ≤30 mm** metric, which §9(c) showed is gameable. Re-running +the sweep scored on the faithful **global** metric (`benchmark/probe_truedist.py --target-scale`, which +rescales the cached targets, equivalent to changing the correction without a cache rebuild; eff. +correction = 1.207 / target_scale) flips the conclusion — *larger* targets (smaller correction, toward +the directly-measured Dijkstra/geodesic ratio ~1.08) **reduce** global true distortion, but the size of +the win is **subject-dependent**. + +Global true distortion at each map's distance-optimal scale (removes the global-zoom confound): + +| target_scale (eff. 
correction) | sub-022 lh | sub-022 rh | sub-026 lh | mean | +|---|---|---|---|---| +| 1.00 (1.207, default) | 11.65% | 11.30% | 11.78% | 11.58% | +| 1.05 (1.15) | 11.34% | 11.20% | 11.84% | 11.46% | +| 1.10 (1.10) | **11.13%** | 11.20% | 11.82% | 11.38% | + +- On **sub-022** (both hemispheres) calibrating toward ~1.10 helps clearly (−0.1 to −0.5pp; lh is + monotonic out to 1.10). At ts=1.10 the map beats even the *best post-hoc rescale of the baseline* + (11.13 vs 11.65 on lh), so calibrated targets reshape the map, not just rezoom it. +- On **sub-026 lh** it is neutral-to-slightly-worse (+0.04–0.06pp). The optimal correction is + subject-dependent, and ts=1.10 *overshoots* (its post-hoc s* drops below 1.0). +- Flips stay comparable throughout (24–40), so calibration is not trading validity for distance. + +**Takeaways.** (i) §8's "dead end" was a metric artifact — on the right (global) objective the correction +*does* matter and the FreeSurfer 1.207 is **slightly too large** (too-compact targets). (ii) But the gain +is small (~0.2pp mean) and not universal, so the shippable change is modest: nudge the effective +correction toward ~1.10–1.15 (the measured ratio), expect a small global-distortion reduction on most +subjects, and pair it with a distance-optimal output scale (§9d) rather than `scale_to_area`. (iii) The +consistent, always-safe component is the output-scale fix (`s*≈1.02` on all three hemispheres). A larger, +robust reduction would need **long-range geodesic anchors** in the energy (the paper's own point that +long-range distances are required to unfold) — a real energy change with weight-tuning, not a config +knob; left as the next step pending a decision on scope. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. 
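
Reviewer sketch (not part of the patch): the "distance-optimal output scale" used in §9(d) and §10 can be illustrated in isolation. The block below is a minimal sketch under stated assumptions, not the benchmark's implementation. It reuses only the grid that `true_distortion_full` declares (`np.linspace(0.90, 1.15, 51)`). The helper names `mean_rel_error` and `fit_output_scale` are hypothetical, and the distances are synthetic stand-ins built about 4% too small on purpose, so the printed numbers say nothing about the Narratives data. It fits the scale on half of the sources and scores it on the other half, mirroring the held-out check described in §9(d).

```python
import numpy as np


def mean_rel_error(d_flat, d_geo, scale=1.0):
    """Mean relative distance error (a fraction) after scaling the flat map."""
    return float(np.mean(np.abs(scale * d_flat - d_geo) / d_geo))


def fit_output_scale(d_flat, d_geo, lo=0.90, hi=1.15, n=51):
    """Grid-search the single global scale that minimizes mean relative error."""
    scales = np.linspace(lo, hi, n)
    errs = np.array([mean_rel_error(d_flat, d_geo, s) for s in scales])
    j = int(np.argmin(errs))
    return float(scales[j]), float(errs[j])


# Synthetic stand-in data (NOT benchmark results): flat-map distances are
# constructed ~4% too small relative to the "true" distances, with
# multiplicative noise standing in for local distortion.
rng = np.random.default_rng(0)
n_sources, n_targets = 200, 500
d_geo = rng.uniform(5.0, 150.0, size=(n_sources, n_targets))
noise = rng.normal(1.0, 0.10, size=d_geo.shape)
d_flat = d_geo / 1.04 * noise

# Fit the scale on one half of the sources, evaluate on the held-out half.
perm = rng.permutation(n_sources)
fit_idx, test_idx = perm[:100], perm[100:]
s_star, _ = fit_output_scale(d_flat[fit_idx].ravel(), d_geo[fit_idx].ravel())

err_area = mean_rel_error(d_flat[test_idx].ravel(), d_geo[test_idx].ravel(), 1.0)
err_opt = mean_rel_error(d_flat[test_idx].ravel(), d_geo[test_idx].ravel(), s_star)
print(f"s* = {s_star:.3f}; held-out error {100 * err_area:.2f}% -> {100 * err_opt:.2f}%")
```

Because the objective is a mean of absolute relative errors, the minimizer is a weighted median rather than a least-squares ratio, so a coarse grid is a simple and adequate way to find it. The fitted scale is only meaningful if the optimum lies inside the grid bounds; a result pinned at 0.90 or 1.15 would suggest widening the range.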
From 122db8dbefdb67b3543530b43c84feb98545e271 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 12:22:00 -0700 Subject: [PATCH 20/35] Findings: broadened correction validation (8 hemis/7 subjects) - not a robust default Section 11: extended the correction sweep to 8 hemispheres across 7 subjects. Two robust negatives: (1) the local Dijkstra/geodesic ratio is nearly subject-invariant (~1.08) and does NOT predict the optimal correction, so per-subject auto-calibration won't work; (2) no correction change is universally safe -- 5/8 hemis want ~1.10, 3 prefer the original 1.207, and a milder 1.05 doesn't rescue the non-helpers (sub-075 is worse). Mean gain only -0.19pp with per-subject downside up to +0.35pp. Config-lever avenue exhausted; the only consistently safe op is the distance-optimal output scale (small); a larger robust gain needs long-range geodesic anchors in the energy. Adds 5 validation subjects (lh) to the manifest. Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 5d388f6..e1a8765 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -297,6 +297,43 @@ robust reduction would need **long-range geodesic anchors** in the energy (the p long-range distances are required to unfold) — a real energy change with weight-tuning, not a config knob; left as the next step pending a decision on scope. +## 11. Broadened validation: the correction optimum is subject-specific and unpredictable + +Per the §10 caveat, the correction sweep was broadened to **8 hemispheres across 7 subjects** (added +sub-041/052/059/066/075 lh to the benchmark, each with its own heat-geodesic reference; scored on global +true distortion at each map's own optimal scale). 
+ +| subject | Dijkstra/geo ratio | ts=1.0 (1.207) | ts=1.05 (1.15) | ts=1.10 (1.10) | Δ(1.0→1.10) | +|---|---|---|---|---|---| +| sub-041 lh | 1.087 | 12.33% | — | 11.78% | **−0.55** | +| sub-022 lh | 1.078 | 11.65% | 11.34% | 11.13% | **−0.52** | +| sub-052 lh | 1.089 | 12.02% | — | 11.69% | **−0.33** | +| sub-059 lh | 1.093 | 12.50% | — | 12.21% | **−0.29** | +| sub-022 rh | 1.082 | 11.30% | — | 11.20% | −0.10 | +| sub-026 lh | 1.084 | 11.78% | — | 11.82% | +0.04 | +| sub-066 lh | 1.085 | 11.59% | 11.64% | 11.63% | +0.05 | +| sub-075 lh | 1.076 | 12.28% | 12.64% | 12.50% | +0.22 | + +**Mean Δ = −0.19pp; 5/8 helped, 2 neutral, 1 (sub-075) hurt.** Two robust negatives emerged: + +1. **The local Dijkstra/geodesic ratio is nearly subject-invariant (1.076–1.093)** — so all subjects' + targets are ~8% too compact in the *same* way. It therefore **cannot predict** the optimum: sub-075 + has the *lowest* ratio yet is *hurt* by a smaller correction; sub-059 has the *highest* and is + *helped*. Per-subject **auto-calibration to the measured ratio would not work** (it would apply ~the + same correction to everyone). +2. **No correction change is universally safe.** A milder ts=1.05 does not rescue the non-helpers — on + sub-075 it is *worse* than ts=1.10 (12.64 vs 12.50 vs 12.28 baseline); sub-066 also degrades at any + increase. 5 subjects want ~1.10, 3 genuinely prefer the original 1.207, and nothing local separates + the groups (the difference is global geometry / optimization dynamics). + +**Conclusion.** Changing the default correction is **not a robust win**: ~−0.2pp on average but with +real per-subject downside (up to +0.35pp). FreeSurfer's 1.207 is defensible as a robust compromise even +though it is slightly too large for the average subject on the global metric. 
The config-lever avenue is +now **exhausted**: the only consistently non-harmful operation is the distance-optimal **output scale** +(a 1-parameter minimization, ≥0 by construction, but small — `s*` ranges 0.995–1.020 across subjects), and +the only remaining path to a *larger, robust* reduction is **long-range geodesic anchors in the energy** +(a real energy change, not a config knob). + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From d7176538b66d199450c361f17725545c6ca96ccc Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 12:24:28 -0700 Subject: [PATCH 21/35] Findings: executive summary + refreshed next-ideas (autoresearch wrap-up) Adds a paper-facing executive summary and updates the next-ideas list to reflect the exhausted config-lever search: long-range geodesic anchors remain the one open lever for a larger robust distance-error reduction; distance-optimal output scale is the small safe ship; speed defaults (Tutte init, lean levers) are validated. Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index e1a8765..ed2ec7a 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -4,6 +4,28 @@ Curated conclusions from the autoresearch loop. The full, append-only record (ev provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jsonl`, rendered to `NOTEBOOK.md`. All numbers below are CPU-only on the public Narratives benchmark. +## Executive summary + +- **Speed (validated, shippable):** a Tutte flip-free init removes the ~4-min initial NAR (§1), and + stacked config levers (lean line search, sparser k-ring, fewer iters, capped smoothing) give a + combined **~3.6× speedup at near-baseline quality** (§7). This is the main practical win. 
+- **Quality / distance error:** the FreeSurfer-style multiscale line-search optimizer is **hard to + beat** — optimizer swaps fold (§4), Adam can't span the multiscale step range (§6), spectral + multigrid is elegant but not faster (§5). +- **Metric caveat (important for the paper):** the k-ring energy metric is miscalibrated and a purely + **local** distance metric is *gameable* — a conformal Tutte disk wins it while being a degenerate + flatmap (§9c). Score distance distortion **globally** (all-pairs geodesics), not locally. +- **Objective (grounded in Fischl 1999, but the code diverges from the paper):** the goal is **metric + (distance) distortion**; area is a *dependent byproduct*, not a separate objective (§9, §10). The + implemented area term is a pure fold barrier. +- **Where the implementation drifts from that objective:** the `Dijkstra/1.207` target correction is + slightly too compact and `scale_to_area` is not distance-optimal (§9d) — but recalibrating the + correction is **not a robust default** (subject-specific optimum, unpredictable from local geometry; + §10–11). The only consistently-safe tweak is a distance-optimal output scale (small). +- **Net:** the pipeline is well-tuned; config-lever gains on distance error are small and non-robust. + The one remaining path to a *larger robust* reduction is **long-range geodesic anchors in the energy** + (the paper's own argument) — a real energy change, left as future work. + ## 1. Flip-free (Tutte) init is a validated win: equal quality, ~37% faster Replacing the FreeSurfer-style normal-projection + **initial** negative-area-removal (NAR) with a @@ -339,9 +361,16 @@ the only remaining path to a *larger, robust* reduction is **long-range geodesic Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. Compute is CPU-only (the box's GPUs are blocked by driver 440 / CUDA 10.2). 
-## Next ideas (untested) +## Next ideas + +Remaining, in priority order (config levers are exhausted — see §10–11): -- Reduce `k_ring` (7 → 5): attacks the dominant cost (k-ring geodesic computation, ~4 min + 237 MB - cache/hemi). Changes the cache key, so needs recompute. -- Fewer epoch iterations from the better Tutte start. -- Combine: Tutte init as the new default `initial_projection`, initial NAR off by default. +- **Long-range geodesic anchors in the energy** (the genuine open lever). Add a sparse set of + true-geodesic long-range distance constraints, up-weighted vs the local k-ring, to directly constrain + the global metric (Fischl 1999's own point that long-range distances are needed to unfold). Requires + modifying the energy/optimizer (not a config knob) and weight-tuning; validate with train/test-split + geodesic sources across subjects. Uncertain but the only path to a larger robust reduction. +- **Distance-optimal output scale** (small, safe): replace `scale_to_area` with the scale that minimizes + distance distortion (a 1-parameter minimization, ≥0 by construction; `s*`≈0.995–1.02). Cheap to ship. +- **Ship the validated speed defaults**: Tutte init as default `initial_projection` with initial NAR off + (§1), plus the §7 lean line-search / sparse k-ring levers. From 58457de0edf94875183bcb3833d4db7bd2d6f010 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 15:34:01 -0700 Subject: [PATCH 22/35] Validate fast config on 9 hemispheres: not overfit, ~3.4x at equal-or-better quality Add benchmark/validate_speed.py and FINDINGS section 12. The speed levers were tuned on a single hemisphere (sub-022 lh); this re-validates the stacked fast_ultimate config across all 9 manifest hemispheres, grouped tuned/seen/ held-out, scoring quality with the energy-independent global true-geodesic metric (not the incomparable raw k-ring number). 
Result: 3.36x mean speedup [3.12-3.60], held-out (3.34x) matches tuned (3.42x) => not overfit. Corrects section 7's "+0.37pp worse" (a k-ring metric artifact): on the faithful metric the fast config is -0.73pp BETTER on every scored hemisphere. 10 ledger records logged. Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 50 ++++++- benchmark/validate_speed.py | 280 ++++++++++++++++++++++++++++++++++++ 2 files changed, 329 insertions(+), 1 deletion(-) create mode 100644 benchmark/validate_speed.py diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index ed2ec7a..66a4925 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -8,7 +8,10 @@ provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jso - **Speed (validated, shippable):** a Tutte flip-free init removes the ~4-min initial NAR (§1), and stacked config levers (lean line search, sparser k-ring, fewer iters, capped smoothing) give a - combined **~3.6× speedup at near-baseline quality** (§7). This is the main practical win. + combined **~3.4× speedup** that is **not overfit** — confirmed across **9 hemispheres / 7 subjects**, + with the 5 held-out subjects matching the tuned one (§7, §12). On the *faithful* (global true-geodesic) + metric the fast config is in fact **slightly better** than baseline (−0.73pp), not worse; the earlier + "+0.37pp" was an artifact of the miscalibrated k-ring metric (§12). This is the main practical win. - **Quality / distance error:** the FreeSurfer-style multiscale line-search optimizer is **hard to beat** — optimizer swaps fold (§4), Adam can't span the multiscale step range (§6), spectral multigrid is elegant but not faster (§5). @@ -181,6 +184,8 @@ smoothing cap 256), across **4 hemispheres / 2 subjects**: | **fast_ultimate (n=4)** | 15.10% | 237 | **194 s (−72%, 3.6×)** | +0.37pp distortion, flips still 0.06% of faces, and visually clean maps on both subjects. 
+(**Update — see §12:** that +0.37pp is on the *miscalibrated* k-ring metric; on the faithful global +true-geodesic metric the fast config is actually −0.73pp *better*, validated on 9 hemispheres.) **Takeaway:** the practical, low-risk path to a faster pipeline is not a new optimizer but (1) Tutte init, (2) a leaner line search, (3) a sparser k-ring, (4) fewer iters/level — each a small @@ -356,6 +361,49 @@ now **exhausted**: the only consistently non-harmful operation is the distance-o the only remaining path to a *larger, robust* reduction is **long-range geodesic anchors in the energy** (a real energy change, not a config knob). +## 12. Full-benchmark validation of the fast config — not overfit (corrects §7's quality claim) + +The §7 speed levers were tuned on a **single hemisphere** (`sub-022 lh`) and the stacked +`fast_ultimate` config had only been confirmed on 4 hemispheres / 2 subjects — a real overfit risk. +Re-validated across **all 9 manifest hemispheres** (`benchmark/validate_speed.py`), running baseline +and fast back-to-back per hemisphere (so the speedup ratio is internally valid), and grouping by +whether the hemisphere was used to tune the levers: **tuned** (`sub-022 lh`), **seen** (in the n=4 +stack), **held-out** (`sub-041/052/059/066/075 lh`, never used to tune anything). + +Quality is scored with the **energy-independent global true-geodesic** metric at each map's +distance-optimal scale — *not* raw k-ring distortion, which is incomparable across n12 vs n6 (the §8 +trap). `dq` = fast − baseline (negative = fast better). 
+ +| subject | group | base s | fast s | speedup | base q% | fast q% | dq (pp) | flips b/f | +|---|---|---|---|---|---|---|---|---| +| sub-022 lh | tuned | 641 | 187 | 3.42 | 11.76 | 10.88 | −0.89 | 25/33 | +| sub-022 rh | seen | 667 | 185 | 3.60 | 11.46 | 10.94 | −0.52 | 37/129 | +| sub-026 lh | seen | 626 | 195 | 3.21 | 11.96 | 11.20 | −0.76 | 36/41 | +| sub-026 rh | seen | 696 | 211 | 3.31 | — | — | — | 23/34 | +| sub-041 lh | held-out | 758 | 221 | 3.43 | 12.25 | 11.37 | −0.88 | 81/120 | +| sub-052 lh | held-out | 561 | 180 | 3.12 | 12.22 | 11.20 | −1.02 | 39/44 | +| sub-059 lh | held-out | 759 | 222 | 3.42 | 12.60 | 11.77 | −0.83 | 353/88 | +| sub-066 lh | held-out | 699 | 208 | 3.37 | 11.71 | 11.27 | −0.44 | 54/55 | +| sub-075 lh | held-out | 615 | 183 | 3.35 | 12.45 | 11.97 | −0.48 | 60/65 | + +**By group:** tuned 3.42× (dq −0.89); seen 3.37× [3.21–3.60] (dq −0.64); **held-out 3.34× [3.12–3.43] +(dq −0.73)**. **Overall 9 hemis: 3.36× [3.12–3.60], dq −0.73pp.** + +**Takeaways.** +1. **Not overfit.** The held-out group matches the tuned hemisphere on *both* axes — speedup 3.34× + vs 3.42×, quality −0.73 vs −0.89pp. The win generalizes across subjects. +2. **§7's quality claim was a metric artifact, now corrected.** §7 reported fast as **+0.37pp worse**, + but that was the raw k-ring metric comparing n6 targets to n12 targets (incomparable; §8). On the + faithful global true-geodesic metric the fast config is **−0.73pp better** on every scored + hemisphere (8/8). The sparser k-ring/leaner schedule does not cost quality — if anything the Tutte + init + distance-optimal scoring helps. The honest headline is **~3.4× faster at equal-or-better + distance quality**. +3. **Flips stay negligible.** Fast has slightly more flipped triangles on most hemispheres (max 129 on + `sub-022 rh`) but fewer on `sub-059` (353→88); all are <0.1% of faces either way — visually clean. 
+ +This is the strongest single result for the paper: a 3.4× speedup with no quality cost, validated on +held-out subjects. Logged as 10 ledger records (`exp:validate_speed:*`). + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. diff --git a/benchmark/validate_speed.py b/benchmark/validate_speed.py new file mode 100644 index 0000000..2336e02 --- /dev/null +++ b/benchmark/validate_speed.py @@ -0,0 +1,280 @@ +"""Validate the stacked "fast" config across the full benchmark (overfit check). + +The §7 speed levers were tuned on a *single* hemisphere (``sub-022 lh``) and the stacked +``fast_ultimate`` config was only confirmed on 4 hemispheres / 2 subjects. This runner +re-validates it across **all manifest hemispheres**, including the 5 held-out validation +subjects (sub-041/052/059/066/075) that were never used to tune any speed lever, so we can +see whether the ~3.6x speedup and near-baseline quality generalize or were overfit. + +Key methodological point: the raw k-ring ``mean_distortion`` is **not comparable** across +``n_neighbors`` (baseline n12 vs fast n6) -- that is the §8 miscalibration trap. So quality +is scored with the energy-independent **true-geodesic global** metric (at each map's own +distance-optimal scale) wherever a ``{subject}_{hemi}.truegeo.npz`` reference exists. Runtime +and flips are directly comparable. + +Each hemisphere logs its own ledger record (crash-survivable) with both configs' metrics and +the per-hemisphere speedup; a final summary record carries the aggregate, split by +tuned / seen / held-out. + +Usage +----- + python -m benchmark.validate_speed + python -m benchmark.validate_speed --only sub-041 sub-052 +""" + +from __future__ import annotations + +import argparse +import sys +import time + +import numpy as np + +from . 
import paths +from .harness import build_flattener, load_manifest +from .ledger import Ledger, new_record +from .metrics import per_patch_metrics +from .probe_tutte_init import make_flatten_fn +from .truedist import load_truegeo, true_distortion_full, truegeo_path + +# Hemispheres used (directly or in the n=4 stack) to TUNE the speed levers. +TUNED = {("sub-022", "lh")} +SEEN = {("sub-022", "rh"), ("sub-026", "lh"), ("sub-026", "rh")} +# everything else in the manifest = held-out for the speed claim + + +def fast_config(): + """The validated stacked 'fast_ultimate' config (Tutte init + 4 stacked levers).""" + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + cfg.kring.k_ring = 7 + cfg.kring.n_neighbors_per_ring = 6 # sparser k-ring + cfg.negative_area_removal.enabled = False # Tutte init makes initial NAR moot + cfg.line_search.n_coarse_steps = 7 # leaner line search (15 -> 7) + for phase in cfg.phases: + phase.iters_per_level = 25 # fewer iters/level (40 -> 25) + phase.smoothing_schedule = [n for n in phase.smoothing_schedule if n <= 256] + return cfg + + +def baseline_config(): + """Current shipped defaults (FreeSurfer-clone: projection init, full refinement).""" + from autoflatten.flatten import FlattenConfig + + cfg = FlattenConfig() + cfg.verbose = False + return cfg + + +def _score(uv, flattener, subject, hemi): + """k-ring metrics + true-geodesic global distortion (if a reference exists).""" + m = per_patch_metrics(uv, flattener) + out = { + "kring_mean_distortion": m["mean_distortion"], + "n_flipped": m["n_flipped"], + "frac_flipped": m["frac_flipped"], + } + if truegeo_path(subject, hemi).exists(): + ref = load_truegeo(subject, hemi) + full = true_distortion_full(uv, ref) + out["true_global_mean"] = full["true_global_mean"] + out["true_global_at_optscale"] = full["true_global_at_optscale"] + out["opt_scale"] = full["opt_scale"] + return out + + +def run_hemi(entry): + subject, hemi = entry["subject"], 
entry["hemi"] + group = ( + "tuned" + if (subject, hemi) in TUNED + else ("seen" if (subject, hemi) in SEEN else "held-out") + ) + + # --- baseline (projection init, default config) --- + fl = build_flattener(entry, baseline_config(), use_cache=True) + t0 = time.time() + uv_base = np.asarray(fl.run()) + base_rt = time.time() - t0 + base = _score(uv_base, fl, subject, hemi) + base["runtime_s"] = base_rt + + # --- fast (Tutte init + stacked levers) --- + fl = build_flattener(entry, fast_config(), use_cache=True) + fast_fn = make_flatten_fn("tutte", refine=True) + t0 = time.time() + uv_fast = np.asarray(fast_fn(fl)) + fast_rt = time.time() - t0 + fast = _score(uv_fast, fl, subject, hemi) + fast["runtime_s"] = fast_rt + + speedup = base_rt / fast_rt if fast_rt > 0 else float("nan") + dq = None + if "true_global_at_optscale" in base and "true_global_at_optscale" in fast: + dq = fast["true_global_at_optscale"] - base["true_global_at_optscale"] + + return { + "subject": subject, + "hemi": hemi, + "group": group, + "speedup": speedup, + "true_global_delta": dq, + "baseline": base, + "fast": fast, + } + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--only", nargs="*", default=None, help="restrict to these subjects" + ) + args = ap.parse_args() + + paths.ensure_output_dirs() + manifest = load_manifest() + entries = manifest["entries"] + if args.only: + entries = [e for e in entries if e["subject"] in set(args.only)] + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + ledger = Ledger() + results = [] + print(f"Validating fast config on {len(entries)} hemispheres...\n") + print( + f"{'subject':10} {'hemi':4} {'group':8} {'base_s':>8} {'fast_s':>8} " + f"{'x':>5} {'base_q%':>8} {'fast_q%':>8} {'dq':>6} {'flips b/f':>12}" + ) + for entry in entries: + r = run_hemi(entry) + results.append(r) + b, f = r["baseline"], r["fast"] + bq = b.get("true_global_at_optscale") + fq = 
f.get("true_global_at_optscale") + print( + f"{r['subject']:10} {r['hemi']:4} {r['group']:8} " + f"{b['runtime_s']:8.0f} {f['runtime_s']:8.0f} {r['speedup']:5.2f} " + f"{(bq if bq is not None else float('nan')):8.2f} " + f"{(fq if fq is not None else float('nan')):8.2f} " + f"{(r['true_global_delta'] if r['true_global_delta'] is not None else float('nan')):6.2f} " + f"{b['n_flipped']:5d}/{f['n_flipped']:<5d}" + ) + + rec = new_record( + kind="experiment", + label=f"exp:validate_speed:{r['subject']}.{r['hemi']}", + manifest_id=manifest.get("created"), + subjects=[ + { + "subject": r["subject"], + "hemi": r["hemi"], + "split": entry.get("split"), + } + ], + method={ + "name": "fast_ultimate_vs_baseline", + "group": r["group"], + "fast_config": fast_config().to_dict(), + }, + seeds={"note": "deterministic CPU gradient descent"}, + repro_command="python -m benchmark.validate_speed --only " + r["subject"], + ) + rec.metrics = { + "speedup": r["speedup"], + "true_global_delta": r["true_global_delta"], + **{f"baseline_{k}": v for k, v in b.items()}, + **{f"fast_{k}": v for k, v in f.items()}, + } + rec.per_subject = [{"subject": r["subject"], "hemi": r["hemi"], **rec.metrics}] + rec.status = "ok" + rec.decision["hypothesis"] = ( + "fast_ultimate's ~3.6x speedup + near-baseline quality generalize beyond the " + "tuned sub-022 lh to held-out subjects (overfit check)." + ) + rec.decision["conclusion"] = ( + f"{r['group']}: {r['speedup']:.2f}x faster " + f"({b['runtime_s']:.0f}s->{f['runtime_s']:.0f}s); " + + ( + f"true-global {bq:.2f}%->{fq:.2f}% ({r['true_global_delta']:+.2f}pp); " + if dq_ok(r) + else "no truegeo ref; " + ) + + f"flips {b['n_flipped']}->{f['n_flipped']}." 
+ ) + ledger.append(rec) + + # --- aggregate summary, split by group --- + print("\n=== summary by group ===") + summary = {} + for grp in ("tuned", "seen", "held-out"): + g = [r for r in results if r["group"] == grp] + if not g: + continue + sp = np.array([r["speedup"] for r in g]) + dq = np.array( + [r["true_global_delta"] for r in g if r["true_global_delta"] is not None] + ) + summary[grp] = { + "n": len(g), + "speedup_mean": float(np.mean(sp)), + "speedup_min": float(np.min(sp)), + "speedup_max": float(np.max(sp)), + "quality_delta_mean_pp": float(np.mean(dq)) if dq.size else None, + "quality_delta_max_pp": float(np.max(dq)) if dq.size else None, + } + qd = summary[grp] + print( + f" {grp:9} n={qd['n']} speedup {qd['speedup_mean']:.2f}x " + f"[{qd['speedup_min']:.2f}-{qd['speedup_max']:.2f}] " + f"quality dq {'%+.2f' % qd['quality_delta_mean_pp'] if qd['quality_delta_mean_pp'] is not None else 'n/a'}" + f"{' (max %+.2f)' % qd['quality_delta_max_pp'] if qd['quality_delta_max_pp'] is not None else ''} pp" + ) + + allsp = np.array([r["speedup"] for r in results]) + alldq = np.array( + [r["true_global_delta"] for r in results if r["true_global_delta"] is not None] + ) + print( + f"\n OVERALL n={len(results)} speedup {np.mean(allsp):.2f}x " + f"[{np.min(allsp):.2f}-{np.max(allsp):.2f}] " + + ( + f"quality dq {np.mean(alldq):+.2f} pp (max {np.max(alldq):+.2f})" + if alldq.size # np.max raises ValueError on an empty array + else "quality dq n/a" + ) + ) + + rec = new_record( + kind="experiment", + label="exp:validate_speed:summary", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": r["subject"], "hemi": r["hemi"], "split": None} for r in results + ], + method={"name": "fast_ultimate_vs_baseline_summary"}, + repro_command="python -m benchmark.validate_speed", + ) + rec.metrics = { + "n_hemispheres": len(results), + "overall_speedup_mean": float(np.mean(allsp)), + "overall_quality_delta_mean_pp": float(np.mean(alldq)) if alldq.size else None, + "by_group": summary, + } + rec.status = "ok" + rec.decision["conclusion"] = ( + f"fast_ultimate over 
{len(results)} hemis: {np.mean(allsp):.2f}x mean speedup, " + f"true-global quality {np.mean(alldq):+.2f}pp; held-out matches tuned => not overfit." + if alldq.size + else f"fast_ultimate over {len(results)} hemis: {np.mean(allsp):.2f}x mean speedup." + ) + ledger.append(rec) + print(f"\nLogged {len(results) + 1} records -> {ledger.path}") + return 0 + + +def dq_ok(r): + return r["true_global_delta"] is not None + + +if __name__ == "__main__": + raise SystemExit(main()) From b0320177c8de97808abe4a099125308f6347474c Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 21:54:03 -0700 Subject: [PATCH 23/35] Parallelize angular k-ring computation with a Numba prange kernel (~26x) The per-vertex angular-sampling loop in compute_kring_geodesic_distances_angular was serial Python over all vertices (~4 min/hemi, the one-time cache-build cost), while 31 of 32 cores sat idle. Fold tangent-plane projection + per-ring angular sampling + limited Dijkstra into a single @njit(parallel=True) prange kernel (_angular_kring_kernel) plus an njit port of the angular sampler (_select_angular_samples_njit). Output is bit-identical to the previous numba production path: verified against the existing sub-022 lh cache (0/193174 neighbor-set mismatches, 0 distance mismatches), and pinned by a new regression test (TestAngularKernelParity). Steady-state runtime for sub-022 lh drops from ~230s to 8.7s (~26x); ~17s including cold JIT. The serial loop is kept as the use_numba=False fallback. 
Co-Authored-By: Claude Opus 4.8 --- autoflatten/flatten/distance.py | 255 ++++++++++++++++++++++++------ autoflatten/tests/test_flatten.py | 92 +++++++++++ 2 files changed, 295 insertions(+), 52 deletions(-) diff --git a/autoflatten/flatten/distance.py b/autoflatten/flatten/distance.py index c4150e3..7135780 100644 --- a/autoflatten/flatten/distance.py +++ b/autoflatten/flatten/distance.py @@ -799,6 +799,141 @@ def get_num_threads(): return numba.get_num_threads() +@njit(cache=True) +def _select_angular_samples_njit(angles, n_samples): + """Numba port of :func:`select_angular_samples` (bit-identical selection). + + Returns indices into ``angles`` (length ``<= n_samples``) chosen one per angular + sector, closest-to-center, deduplicated, with the same ``< sector_width`` gate and the + same first-min (``argmin``) tie-breaking as the NumPy version. + """ + m = angles.shape[0] + if m == 0: + return np.empty(0, dtype=np.int64) + if m <= n_samples: + out = np.empty(m, dtype=np.int64) + for i in range(m): + out[i] = i + return out + + two_pi = 2.0 * np.pi + sector_width = two_pi / n_samples + amod = np.empty(m, dtype=np.float64) + for i in range(m): + amod[i] = angles[i] % two_pi + + sel = np.empty(n_samples, dtype=np.int64) + nsel = 0 + for c in range(n_samples): + center = c * sector_width + best_idx = 0 + best_val = np.inf + for i in range(m): + d = abs(amod[i] - center) + d2 = two_pi - d + if d2 < d: + d = d2 + if d < best_val: # strict '<' => first-min, matches np.argmin + best_val = d + best_idx = i + if best_val < sector_width: + found = False + for j in range(nsel): + if sel[j] == best_idx: + found = True + break + if not found: + sel[nsel] = best_idx + nsel += 1 + return sel[:nsel] + + +@njit(parallel=True, cache=True) +def _angular_kring_kernel( + vertices, + normals, + rings_flat, + level_offsets, + indptr, + indices, + data, + k, + n_samples, + correction, + max_nb, +): + """Fused, parallel per-vertex angular sampling + limited Dijkstra. 
+ + Reproduces the serial loop of :func:`compute_kring_geodesic_distances_angular` + bit-for-bit (tangent-plane projection -> per-ring angular sampling -> limited + Dijkstra) but over a ``prange`` so all CPU cores are used. Each ``prange`` iteration + writes only its own output row, so there are no races and the result is deterministic. + + Returns dense ``(n_vertices, max_nb)`` neighbor/distance arrays plus a per-vertex count; + the caller slices each row to ``count`` to rebuild the ragged lists. + """ + n_vertices = vertices.shape[0] + out_nb = np.full((n_vertices, max_nb), -1, dtype=np.int64) + out_dist = np.zeros((n_vertices, max_nb), dtype=np.float64) + out_count = np.zeros(n_vertices, dtype=np.int64) + + for v in prange(n_vertices): + cx = vertices[v, 0] + cy = vertices[v, 1] + cz = vertices[v, 2] + nx = normals[v, 0] + ny = normals[v, 1] + nz = normals[v, 2] + + # Local tangent frame (matches project_to_tangent_plane exactly). + if abs(nx) < 0.9: + rx, ry, rz = 1.0, 0.0, 0.0 + else: + rx, ry, rz = 0.0, 1.0, 0.0 + ux = ny * rz - nz * ry + uy = nz * rx - nx * rz + uz = nx * ry - ny * rx + un = np.sqrt(ux * ux + uy * uy + uz * uz) + ux /= un + uy /= un + uz /= un + vx = ny * uz - nz * uy + vy = nz * ux - nx * uz + vz = nx * uy - ny * ux + + count = 0 + for level in range(k): + start = level_offsets[v, level] + end = level_offsets[v, level + 1] + m = end - start + if m == 0: + continue + angles = np.empty(m, dtype=np.float64) + for i in range(m): + idx = rings_flat[start + i] + px = vertices[idx, 0] - cx + py = vertices[idx, 1] - cy + pz = vertices[idx, 2] - cz + xx = px * ux + py * uy + pz * uz + yy = px * vx + py * vy + pz * vz + angles[i] = np.arctan2(yy, xx) + sel = _select_angular_samples_njit(angles, n_samples) + for s in range(sel.shape[0]): + out_nb[v, count] = rings_flat[start + sel[s]] + count += 1 + + out_count[v] = count + if count > 0: + targets = out_nb[v, :count].copy() + dists = _limited_dijkstra_numba( + indptr, indices, data, v, targets, 
correction + ) + for i in range(count): + out_dist[v, i] = dists[i] + + return out_nb, out_dist, out_count + + def compute_kring_geodesic_distances_angular( vertices, faces, @@ -853,65 +988,81 @@ def compute_kring_geodesic_distances_angular( # Build mesh graph for distance computation graph = build_mesh_graph(vertices, faces) - # Get rings organized by level (Numba version is ~20x faster) - print(f"Computing {k}-ring neighbors by level...") - if use_numba: - rings_by_level = get_rings_by_level_fast(faces, n_vertices, k) - else: - rings_by_level = get_rings_by_level(faces, n_vertices, k) - # Compute vertex normals for tangent plane projection print("Computing vertex normals...") normals = compute_vertex_normals(vertices.astype(np.float64), faces) - # For each vertex, sample from each ring level print(f"Angular sampling ({n_samples_per_ring} per ring)...") - sampled_neighbors = [] - sampled_distances = [] - - for v in tqdm( - range(n_vertices), desc="Sampling neighbors", position=tqdm_position, leave=True - ): - v_neighbors = [] - - center = vertices[v] - normal = normals[v] - - for level in range(k): - ring = rings_by_level[v][level] - if len(ring) == 0: - continue - - # Get positions of ring neighbors - ring_pos = vertices[ring] - - # Project to tangent plane - xy = project_to_tangent_plane(center, normal, ring_pos) - - # Compute angles - angles = np.arctan2(xy[:, 1], xy[:, 0]) - - # Select angularly-spaced samples - sample_idx = select_angular_samples(angles, n_samples_per_ring) - - if len(sample_idx) > 0: - selected = ring[sample_idx] - v_neighbors.extend(selected) - - # Compute distances to all selected neighbors - v_neighbors = np.array(v_neighbors, dtype=np.int64) - if len(v_neighbors) > 0: - if use_numba: - v_distances = _limited_dijkstra_numba( - graph.indptr, graph.indices, graph.data, v, v_neighbors, correction - ) - else: + if use_numba: + # Fused parallel path: build flat rings, then run the prange kernel over all + # vertices (tangent projection + 
angular sampling + limited Dijkstra). Output is + # bit-identical to the previous serial numba path and uses all cores. + print(f"Computing {k}-ring neighbors by level...") + adj = igl.adjacency_list(faces.astype(np.int64)) + adj_flat = np.concatenate([np.array(a, dtype=np.int64) for a in adj]) + adj_offsets = np.zeros(n_vertices + 1, dtype=np.int64) + for i, a in enumerate(adj): + adj_offsets[i + 1] = adj_offsets[i] + len(a) + rings_flat, level_offsets = _get_rings_by_level_numba(adj_flat, adj_offsets, k) + + verts64 = np.ascontiguousarray(vertices, dtype=np.float64) + norms64 = np.ascontiguousarray(normals, dtype=np.float64) + out_nb, out_dist, out_count = _angular_kring_kernel( + verts64, + norms64, + rings_flat, + level_offsets, + graph.indptr, + graph.indices, + graph.data, + k, + n_samples_per_ring, + correction, + k * n_samples_per_ring, + ) + sampled_neighbors = [ + out_nb[v, : out_count[v]].copy() for v in range(n_vertices) + ] + sampled_distances = [ + out_dist[v, : out_count[v]].copy() for v in range(n_vertices) + ] + else: + # Serial fallback for debugging (sorted rings, so not bit-identical). 
+ print(f"Computing {k}-ring neighbors by level...") + rings_by_level = get_rings_by_level(faces, n_vertices, k) + sampled_neighbors = [] + sampled_distances = [] + + for v in tqdm( + range(n_vertices), + desc="Sampling neighbors", + position=tqdm_position, + leave=True, + ): + v_neighbors = [] + center = vertices[v] + normal = normals[v] + + for level in range(k): + ring = rings_by_level[v][level] + if len(ring) == 0: + continue + ring_pos = vertices[ring] + xy = project_to_tangent_plane(center, normal, ring_pos) + angles = np.arctan2(xy[:, 1], xy[:, 0]) + sample_idx = select_angular_samples(angles, n_samples_per_ring) + if len(sample_idx) > 0: + selected = ring[sample_idx] + v_neighbors.extend(selected) + + v_neighbors = np.array(v_neighbors, dtype=np.int64) + if len(v_neighbors) > 0: v_distances = _limited_dijkstra(v, v_neighbors, graph, correction) - else: - v_distances = np.array([]) + else: + v_distances = np.array([]) - sampled_neighbors.append(v_neighbors) - sampled_distances.append(v_distances) + sampled_neighbors.append(v_neighbors) + sampled_distances.append(v_distances) # Summary stats total_neighbors = sum(len(n) for n in sampled_neighbors) diff --git a/autoflatten/tests/test_flatten.py b/autoflatten/tests/test_flatten.py index 5cb6c8b..0f9b3d5 100644 --- a/autoflatten/tests/test_flatten.py +++ b/autoflatten/tests/test_flatten.py @@ -1416,6 +1416,98 @@ def test_empty_input(self): assert len(selected) == 0, f"Expected empty array, got {len(selected)} elements" +class TestAngularKernelParity: + """The parallel Numba k-ring kernel must match the NumPy/serial path bit-for-bit.""" + + def test_njit_sampler_matches_numpy(self): + from autoflatten.flatten.distance import ( + _select_angular_samples_njit, + select_angular_samples, + ) + + rng = np.random.default_rng(0) + for n_samples in (6, 8): + for m in (0, 1, 5, 6, 7, 20, 100): + angles = rng.uniform(-np.pi, np.pi, size=m) + exp = select_angular_samples(angles, n_samples=n_samples) + got = 
_select_angular_samples_njit(angles, n_samples) + assert np.array_equal(np.asarray(exp), np.asarray(got)), ( + f"m={m} n={n_samples}: {exp} vs {got}" + ) + + def _small_mesh(self): + # A subdivided plane with a little z-relief so tangent frames are nontrivial. + n = 12 + xs, ys = np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n)) + zs = 0.05 * np.sin(4 * xs) * np.cos(4 * ys) + verts = np.stack([xs.ravel(), ys.ravel(), zs.ravel()], axis=1).astype( + np.float64 + ) + faces = [] + for i in range(n - 1): + for j in range(n - 1): + a = i * n + j + b = a + 1 + c = a + n + d = c + 1 + faces.append([a, b, d]) + faces.append([a, d, c]) + return verts, np.asarray(faces, dtype=np.int64) + + def test_numba_kernel_matches_reference(self): + """Kernel == the original serial loop over the SAME (discovery-order) rings. + + The ``use_numba=False`` fallback sorts each ring, so it intentionally differs; the + meaningful guarantee is that the parallel kernel reproduces the original numba + production path (``get_rings_by_level_fast`` + serial sampling) bit-for-bit. + """ + from autoflatten.flatten.distance import ( + GRAPH_DISTANCE_CORRECTION, + _limited_dijkstra_numba, + build_mesh_graph, + compute_kring_geodesic_distances_angular as ang, + compute_vertex_normals, + get_rings_by_level_fast, + project_to_tangent_plane, + select_angular_samples, + ) + + verts, faces = self._small_mesh() + k, nspr = 4, 6 + kr_n, td_n = ang(verts, faces, k, n_samples_per_ring=nspr, use_numba=True) + + # Reference: the original numba-path loop (discovery-order rings). 
+ graph = build_mesh_graph(verts, faces) + rings = get_rings_by_level_fast(faces, verts.shape[0], k) + normals = compute_vertex_normals(verts.astype(np.float64), faces) + for v in range(verts.shape[0]): + nb = [] + for level in range(k): + ring = rings[v][level] + if len(ring) == 0: + continue + xy = project_to_tangent_plane(verts[v], normals[v], verts[ring]) + ang_ = np.arctan2(xy[:, 1], xy[:, 0]) + idx = select_angular_samples(ang_, nspr) + if len(idx) > 0: + nb.extend(ring[idx]) + nb = np.array(nb, dtype=np.int64) + d = ( + _limited_dijkstra_numba( + graph.indptr, + graph.indices, + graph.data, + v, + nb, + GRAPH_DISTANCE_CORRECTION, + ) + if len(nb) > 0 + else np.array([]) + ) + assert np.array_equal(np.asarray(kr_n[v]), nb), v + assert np.array_equal(np.asarray(td_n[v], dtype=np.float64), d), v + + class TestThreadConfig: """Tests for set_num_threads and get_num_threads.""" From e3ad610aa7690ec91988a66b711d105001799f41 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Wed, 10 Jun 2026 21:54:19 -0700 Subject: [PATCH 24/35] Findings: record k-ring parallelization (~26x, bit-identical) Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 66a4925..9ec87de 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -404,6 +404,19 @@ trap). `dq` = fast − baseline (negative = fast better). This is the strongest single result for the paper: a 3.4× speedup with no quality cost, validated on held-out subjects. Logged as 10 ledger records (`exp:validate_speed:*`). +## 13. K-ring cache build parallelized (~26×) — removes the first-run penalty + +The one-time k-ring cache build (the "Sampling neighbors" pass, ~4 min/hemi) was a **serial +Python loop** over all ~200–400k vertices while 31 of 32 cores idled. 
Folded the per-vertex work +(tangent-plane projection → per-ring angular sampling → limited Dijkstra) into a single +`@njit(parallel=True)` `prange` kernel (`_angular_kring_kernel` + an njit port of the angular +sampler). **Output is bit-identical** to the previous numba path — verified against the existing +`sub-022 lh` cache (0/193174 neighbor-set mismatches, 0 distance mismatches) and pinned by a +regression test. Steady-state runtime on `sub-022 lh` dropped **~230 s → 8.7 s (~26×)** (≈17 s with +cold JIT). This is pure performance with no numerical change, so it needs no re-validation of any +flatmap. It mostly eliminates the first-run penalty for a new subject (the 180–220 s steady-state +flatten is JAX and unaffected); the warm-cache flatten is unchanged. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From df3d7e5ab7e6e0ca5f99919594fcdcf578ec1a21 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 08:09:21 -0700 Subject: [PATCH 25/35] Findings: authoritative end-to-end timing (4 hemis, ~205s/hemi, ~4.3x cold) Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 9ec87de..9d282e3 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -417,6 +417,26 @@ cold JIT). This is pure performance with no numerical change, so it needs no re- flatmap. It mostly eliminates the first-run penalty for a new subject (the 180–220 s steady-state flatten is JAX and unaffected); the warm-cache flatten is unchanged. 
+**Authoritative end-to-end measurement (4 hemispheres, cold cache).** Timed the full per-hemisphere +pipeline with the shippable fast config — geometry I/O → cold k-ring cache build (parallel kernel) → +fast JAX flatten — into a temp cache so the validated caches were untouched: + +| subject·hemi | I/O (s) | cache build (s) | flatten (s) | total (s) | +|---|---|---|---|---| +| sub-022 lh | 2.0 | 10.4 | 188.5 | 200.9 | +| sub-022 rh | 1.8 | 10.3 | 184.6 | 196.8 | +| sub-026 lh | 1.9 | 11.1 | 188.6 | 201.6 | +| sub-026 rh | 1.9 | 11.5 | 207.4 | 220.8 | +| **mean** | 1.9 | **10.8** | **192.3** | **205.0 (3.4 min)** | + +Against the §12 baseline on the *same* 4 hemispheres (flatten mean 657.5 s) plus the serial cache +build (~230 s), end-to-end per hemisphere goes from **~887 s (~14.8 min) → 205 s (~3.4 min), ~4.3×** +cold-cache (~3.4× warm). The measured total matches the component-sum estimate to the second. Note the +cache build read 10.8 s here (not the ~17 s cold-JIT figure) because numba's compiled kernel was +already cached to disk — the realistic steady state after a machine's first-ever run; the one-time +cold-JIT adds ~7 s once per machine. Flips and quality are the validated fast-config values (the cache +build is bit-identical and the flatten config unchanged), so no re-scoring was needed. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. From 337337d2a89951132d2329c36fa2368fb0d4e734 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 08:38:15 -0700 Subject: [PATCH 26/35] Projection Phase 1: FreeSurfer-free cut mapping, validated bit-identical Reimplement the projection stage's only FreeSurfer dependency (mri_label2label --regmethod surface) as a pure-Python union(push, pull) KDTree mapping on sphere.reg. 
Validated three ways: - vs real mri_label2label (FS 6.0): 0 vertex-ID differences across 6 cuts x 3 hemispheres (validate_mapping_vs_freesurfer.py) - end-to-end patch bit-identical on 9/9 manifest hemispheres incl. all 5 held-out subjects (validate_projection.py) - unit invariants, no FreeSurfer (tests/test_projection.py) Mapping step: 36.6s (FreeSurfer subprocess) -> ~0.3s (KDTree). Projection no longer requires FreeSurfer. Documented in FINDINGS section 14; ledger records logged. Phase 2 (cut-placement quality, scored by downstream flatmap distortion) is next. Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 64 +++++- benchmark/projection.py | 206 ++++++++++++++++++++ benchmark/tests/test_projection.py | 103 ++++++++++ benchmark/validate_mapping_vs_freesurfer.py | 148 ++++++++++++++ benchmark/validate_projection.py | 197 +++++++++++++++++++ 5 files changed, 717 insertions(+), 1 deletion(-) create mode 100644 benchmark/projection.py create mode 100644 benchmark/tests/test_projection.py create mode 100644 benchmark/validate_mapping_vs_freesurfer.py create mode 100644 benchmark/validate_projection.py diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 9d282e3..6f2f382 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -28,6 +28,12 @@ provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jso - **Net:** the pipeline is well-tuned; config-lever gains on distance error are small and non-robust. The one remaining path to a *larger robust* reduction is **long-range geodesic anchors in the energy** (the paper's own argument) — a real energy change, left as future work. +- **Projection (Phase 1, validated):** the projection stage's only FreeSurfer dependency + (`mri_label2label`) is now reimplemented in pure Python (`union(push, pull)` KDTree on `sphere.reg`) + that reproduces FreeSurfer **vertex-for-vertex** (0 ID differences vs real `mri_label2label`; + bit-identical patch on **9/9** hemispheres incl. 
held-out) at **~0.3 s vs 36.6 s** for the mapping + step. Projection no longer requires FreeSurfer (§14). Phase 2 (cut-placement quality, scored by + downstream flatmap distortion) is next. ## 1. Flip-free (Tutte) init is a validated win: equal quality, ~37% faster @@ -437,10 +443,66 @@ already cached to disk — the realistic steady state after a machine's first-ev cold-JIT adds ~7 s once per machine. Flips and quality are the validated fast-config values (the cache build is bit-identical and the flatten config unchanged), so no re-scoring was needed. +# Projection phase + +The sections above optimize the *flattening* stage. The sections below apply the same +benchmark-driven, ledgered process to the *projection* stage (mapping the fsaverage cut +template onto each subject and turning it into a patch). Objective (user-chosen): Phase 1 — +remove the FreeSurfer dependency at equal output; Phase 2 — improve cut placement scored by +downstream flatmap distortion. + +## 14. Projection is now FreeSurfer-free, bit-identical, ~120× faster on the mapping step + +The projection phase's only FreeSurfer dependency is `core.map_cuts_to_subject` → +`mri_label2label --regmethod surface`. Everything downstream (continuity, geodesic +refinement, hole-fill, patch write) is already pure Python. That one call is, underneath, +nearest-neighbour label transfer on the registered sphere (`{hemi}.sphere.reg`, on disk for +every subject), so it is reproducible in pure Python. + +**The exact mapping convention** (`benchmark/projection.py:map_label_surface`) is the +**union of two passes**: +- *pull* (target-driven): for each target vertex, find its nearest source (fsaverage) + vertex on `sphere.reg`; include the target vertex if that source is in the label. +- *push* (source-driven): for each source label vertex, include its nearest target vertex. 
+ +The pull pass alone undercounts by ~5–15 vertices/cut (it drops boundary target vertices +whose own nearest source falls just outside the label); the push pass recovers exactly +those. A forward push *alone* is provably impossible here (124 source calcarine vertices +can hit ≤124 unique targets, but FreeSurfer maps to 156) — individual surfaces (~204k verts) +are denser than fsaverage (~164k), so the map must be target-driven plus a push patch. + +**Validation (three independent levels, all exact):** +1. **vs real `mri_label2label` (FreeSurfer 6.0), vertex-for-vertex:** total vertex-ID + symmetric difference = **0** across all 6 cuts × 3 dev hemispheres + (`benchmark/validate_mapping_vs_freesurfer.py`). Not just counts — identical vertices. +2. **End-to-end patch, full manifest:** the complete FS-free pipeline produces a patch with + an included-vertex set **bit-identical** to the cached FreeSurfer patch on **9/9 + hemispheres**, including all 5 held-out subjects (`benchmark/validate_projection.py`). +3. **Unit invariants** (no FreeSurfer): identity-sphere recovers the label exactly; union ⊇ + both passes; determinism (`benchmark/tests/test_projection.py`). + +**Speed:** the mapping step drops from **36.6 s** (FreeSurfer subprocess; 44–46 s when +including label I/O in the direct comparison) to **~0.3 s** (KDTree) — ~120× on that step. +Full projection is ~33 s/hemi (manifest mean), now entirely the pre-existing pure-Python +NetworkX geodesic refinement, with **no FreeSurfer required at all**. + +| stage | FreeSurfer | FS-free (Python) | +| --- | ---: | ---: | +| cut mapping (the FS call) | 36.6 s | 0.31 s | +| full projection / hemi | (+ refinement) | ~33 s | +| output patch | reference | **bit-identical (9/9)** | + +This both removes the install barrier (projection no longer needs FreeSurfer, only +`sphere.reg` files, which datalad provides) and is the prerequisite that makes the Phase 2 +cut-placement search runnable on this box. 
Mapping is the validated baseline; the geodesic +refinement (~33 s) is now the projection time sink and a future speed target. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. -Compute is CPU-only (the box's GPUs are blocked by driver 440 / CUDA 10.2). +Flattening compute is CPU-only (the box's GPUs are blocked by driver 440 / CUDA 10.2). +FreeSurfer is available (`~/bin/source_freesurfer*.sh`, 6.0 default) and was used only to +*validate* the FS-free projection; the pipeline itself no longer calls it. ## Next ideas diff --git a/benchmark/projection.py b/benchmark/projection.py new file mode 100644 index 0000000..37650fc --- /dev/null +++ b/benchmark/projection.py @@ -0,0 +1,206 @@ +"""FreeSurfer-free cut projection (Phase 1 of the projection autoresearch). + +The only FreeSurfer dependency in the projection phase is +``autoflatten.core.map_cuts_to_subject`` -> ``mri_label2label --regmethod surface``. +Underneath, surface-registration label mapping is a nearest-neighbour lookup on the +registered sphere (``{hemi}.sphere.reg``). The bulk is a **target-driven (pull)** pass: +for each *target* vertex, find the closest *source* (fsaverage) vertex and include the +target vertex if that source vertex is in the label. A forward push alone cannot match +the observed projection-log counts (e.g. calcarine 124 src -> 156 trg: 124 source +vertices can hit at most 124 unique targets); the target-driven map is natural here +because individual surfaces (~204k verts) are denser than fsaverage (~164k). The pull +pass alone misses a few boundary vertices, so ``map_label_surface`` returns the union +of the pull pass and a source-driven (push) pass. + +Both ``sphere.reg`` files are on disk for every benchmark subject, so this removes the +FreeSurfer requirement and is validated against the cached FreeSurfer patches. + +Reference (the FS call this replaces): ``autoflatten/core.py:map_cuts_to_subject``. 
+""" + +from __future__ import annotations + +import os + +import nibabel as nib +import numpy as np +from scipy.spatial import cKDTree + +from autoflatten.core import ( + ensure_continuous_cuts, + fill_holes_in_patch, + refine_cuts_with_geodesic, +) +from autoflatten.freesurfer import create_patch_file, load_surface +from autoflatten.utils import load_json + +# Root of the FreeSurfer derivatives tree (contains fsaverage + all subjects). +DEFAULT_SUBJECTS_DIR = ( + "/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer" +) + + +def sphere_reg_path(subject, hemi, subjects_dir=None): + subjects_dir = subjects_dir or os.environ.get("SUBJECTS_DIR", DEFAULT_SUBJECTS_DIR) + return os.path.join(subjects_dir, subject, "surf", f"{hemi}.sphere.reg") + + +def load_sphere_reg(subject, hemi, subjects_dir=None): + """Return the registered-sphere vertex coordinates (N, 3) for a subject/hemi.""" + coords, _ = nib.freesurfer.read_geometry( + sphere_reg_path(subject, hemi, subjects_dir) + ) + return np.asarray(coords, dtype=np.float64) + + +def map_label_surface(src_sphere, trg_sphere, src_label_idx, nearest_src=None): + """Map a label from source to target, reproducing ``mri_label2label`` surface mode. + + FreeSurfer's surface-registration label mapping is the **union of two passes**: + + - *pull* (target-driven): for every target vertex, find its nearest source vertex on + the registered sphere; include the target vertex if that source is in the label. + - *push* (source-driven): for every source label vertex, include its nearest target + vertex. + + The pull pass handles the bulk; the push pass recovers boundary target vertices whose + own nearest source falls just outside the label. The union reproduces FreeSurfer's + mapped-vertex counts exactly (validated on all dev hemispheres). + + Parameters + ---------- + src_sphere : ndarray (Ns, 3) + Source (fsaverage) ``sphere.reg`` coordinates. 
+ trg_sphere : ndarray (Nt, 3) + Target (subject) ``sphere.reg`` coordinates. + src_label_idx : array-like of int + Source-subject vertex indices in the label. + nearest_src : ndarray (Nt,), optional + Precomputed target->nearest-source index map (shared across cuts for speed). + + Returns + ------- + ndarray of int + Sorted target-subject vertex indices in the mapped label. + """ + src_label = np.unique(np.asarray(src_label_idx, dtype=np.int64)) + if src_label.size == 0: + return np.array([], dtype=np.int64) + + # pull: each target vertex -> nearest source; include if source in label + if nearest_src is None: + _, nearest_src = cKDTree(src_sphere).query(trg_sphere, k=1) + in_label = np.zeros(src_sphere.shape[0], dtype=bool) + in_label[src_label] = True + pull = np.nonzero(in_label[nearest_src])[0] + + # push: each source label vertex -> nearest target vertex + _, push = cKDTree(trg_sphere).query(src_sphere[src_label], k=1) + + return np.union1d(pull, push).astype(np.int64) + + +def map_cuts_to_subject_python( + vertex_dict, + target_subject, + hemi, + source_subject="fsaverage", + subjects_dir=None, + trg_sphere=None, + src_sphere=None, +): + """Drop-in, FreeSurfer-free replacement for ``core.map_cuts_to_subject``. + + Maps every label in ``vertex_dict`` (fsaverage indices) to ``target_subject`` + indices via a single shared KDTree query (one NN lookup reused for all cuts). + """ + if src_sphere is None: + src_sphere = load_sphere_reg(source_subject, hemi, subjects_dir) + if trg_sphere is None: + trg_sphere = load_sphere_reg(target_subject, hemi, subjects_dir) + + # Build both KDTrees once and reuse across every cut. 
+ _, nearest_src = cKDTree(src_sphere).query(trg_sphere, k=1) + trg_tree = cKDTree(trg_sphere) + + mapped = {} + for cut_name, vertices in vertex_dict.items(): + src_label = np.unique(np.asarray(vertices, dtype=np.int64)) + if src_label.size == 0: + mapped[cut_name] = [] + continue + in_label = np.zeros(src_sphere.shape[0], dtype=bool) + in_label[src_label] = True + pull = np.nonzero(in_label[nearest_src])[0] + _, push = trg_tree.query(src_sphere[src_label], k=1) + mapped[cut_name] = np.union1d(pull, push).astype(np.int64) + return mapped + + +def _load_template_vertex_dict(hemi, template_file=None): + """Load fsaverage cut/mwall labels for a hemisphere from the JSON template.""" + if template_file is None: + from autoflatten.config import fsaverage_cut_template + + template_file = fsaverage_cut_template + template_data = load_json(str(template_file)) + prefix = f"{hemi}_" + return { + key[len(prefix) :]: np.array(value) + for key, value in template_data.items() + if key.startswith(prefix) + } + + +def project_python( + subject, + hemi, + subjects_dir=None, + template_file=None, + refine_geodesic=True, + out_patch=None, + verbose=False, +): + """Run the full projection phase **without FreeSurfer**. + + Mirrors ``autoflatten.cli.cmd_project`` (map -> continuity -> geodesic refine -> + hole fill -> patch) but swaps the ``mri_label2label`` mapping for the validated + Python KDTree mapper. Every downstream step is already pure Python. + + Returns + ------- + dict + ``{"vertex_dict": ..., "patch_vertices": ..., "patch_file": ..., "n_surface": ...}``. 
+ """ + subjects_dir = subjects_dir or os.environ.get("SUBJECTS_DIR", DEFAULT_SUBJECTS_DIR) + os.environ.setdefault("SUBJECTS_DIR", subjects_dir) + + vertex_dict = _load_template_vertex_dict(hemi, template_file) + mapped = map_cuts_to_subject_python( + vertex_dict, subject, hemi, subjects_dir=subjects_dir + ) + + fixed = ensure_continuous_cuts(dict(mapped), subject, hemi) + if refine_geodesic: + fixed = refine_cuts_with_geodesic( + fixed, subject, hemi, medial_wall_vertices=fixed.get("mwall") + ) + + pts, polys = load_surface(subject, "inflated", hemi) + excluded = set() + for vertices in fixed.values(): + excluded.update(int(v) for v in vertices) + hole_vertices = fill_holes_in_patch(polys, excluded) + if hole_vertices: + fixed["_hole_fill"] = np.array(list(hole_vertices)) + + if out_patch is None: + out_patch = f"/tmp/{subject}.{hemi}.fsfree.patch.3d" + patch_file, patch_vertices = create_patch_file(out_patch, pts, polys, fixed) + + return { + "vertex_dict": fixed, + "patch_vertices": patch_vertices, + "patch_file": patch_file, + "n_surface": len(pts), + } diff --git a/benchmark/tests/test_projection.py b/benchmark/tests/test_projection.py new file mode 100644 index 0000000..fb6c5ae --- /dev/null +++ b/benchmark/tests/test_projection.py @@ -0,0 +1,103 @@ +"""Tests for the FreeSurfer-free cut projection (benchmark.projection). + +The mapper-logic tests are self-contained (no FreeSurfer, no real data). The end-to-end +exact-match test is gated on the narratives derivatives tree being present. 
+""" + +from __future__ import annotations + +import os + +import numpy as np +import pytest + +from benchmark.projection import ( + DEFAULT_SUBJECTS_DIR, + map_cuts_to_subject_python, + map_label_surface, +) + + +def _fib_sphere(n, radius=100.0, seed=0): + """Deterministic ~evenly spread points on a sphere (Fibonacci lattice).""" + i = np.arange(n) + 0.5 + phi = np.arccos(1 - 2 * i / n) + golden = np.pi * (1 + 5**0.5) + theta = golden * i + xyz = np.stack( + [np.cos(theta) * np.sin(phi), np.sin(theta) * np.sin(phi), np.cos(phi)], axis=1 + ) + return xyz * radius + + +# --- mapper logic (no FreeSurfer) ------------------------------------------------- +def test_identity_sphere_recovers_label_exactly(): + """src_sphere == trg_sphere: union(push, pull) of a label must return the label.""" + sphere = _fib_sphere(500) + label = np.array([3, 7, 42, 128, 499]) + mapped = map_label_surface(sphere, sphere, label) + assert set(mapped.tolist()) == set(label.tolist()) + + +def test_empty_label_maps_to_empty(): + sphere = _fib_sphere(200) + assert map_label_surface(sphere, sphere, np.array([], dtype=int)).size == 0 + + +def test_union_superset_of_both_passes(): + """The union map must contain every vertex from the pull-only and push-only passes.""" + from scipy.spatial import cKDTree + + src = _fib_sphere(300, seed=1) + trg = _fib_sphere(450, seed=2) # denser target + label = np.arange(0, 300, 7) + + # pull-only + _, nearest_src = cKDTree(src).query(trg, k=1) + in_label = np.zeros(len(src), bool) + in_label[label] = True + pull = set(np.nonzero(in_label[nearest_src])[0].tolist()) + # push-only + _, push_idx = cKDTree(trg).query(src[label], k=1) + push = set(push_idx.tolist()) + + mapped = set(map_label_surface(src, trg, label).tolist()) + assert pull <= mapped and push <= mapped + assert mapped == (pull | push) + + +def test_mapper_is_deterministic(): + src = _fib_sphere(300, seed=1) + trg = _fib_sphere(400, seed=2) + vd = {"cut": np.arange(0, 300, 5), "mwall": np.arange(100, 
200)} + a = map_cuts_to_subject_python(vd, "x", "lh", src_sphere=src, trg_sphere=trg) + b = map_cuts_to_subject_python(vd, "x", "lh", src_sphere=src, trg_sphere=trg) + for k in vd: + assert np.array_equal(a[k], b[k]) + + +# --- end-to-end exact match (gated on real data) ---------------------------------- +_HAS_NARRATIVES = os.path.isdir(os.path.join(DEFAULT_SUBJECTS_DIR, "fsaverage", "surf")) + + +@pytest.mark.skipif( + not _HAS_NARRATIVES, reason="narratives derivatives tree not present" +) +def test_end_to_end_patch_matches_cached_freesurfer(): + """Full FS-free projection reproduces the cached FreeSurfer patch exactly.""" + from autoflatten.freesurfer import read_patch + from benchmark.projection import project_python + + subject, hemi = "sub-022", "lh" + cached = os.path.join( + DEFAULT_SUBJECTS_DIR, subject, "surf", f"{hemi}.autoflatten.patch.3d" + ) + if not os.path.exists(cached): + pytest.skip("cached patch not present") + + res = project_python( + subject, hemi, out_patch=f"/tmp/test_{subject}_{hemi}.patch.3d" + ) + _, py_idx, _ = read_patch(res["patch_file"]) + _, fs_idx, _ = read_patch(cached) + assert set(int(i) for i in py_idx) == set(int(i) for i in fs_idx) diff --git a/benchmark/validate_mapping_vs_freesurfer.py b/benchmark/validate_mapping_vs_freesurfer.py new file mode 100644 index 0000000..2df6ab2 --- /dev/null +++ b/benchmark/validate_mapping_vs_freesurfer.py @@ -0,0 +1,148 @@ +"""Gold-standard check: Python cut mapper vs real ``mri_label2label`` (FreeSurfer). + +:mod:`benchmark.validate_projection` proves the *end-to-end* patch is bit-identical to the +cached FreeSurfer patches, but the cut vertices it compares have already been replaced by +geodesic refinement, so that test cannot see the raw mapping. 
This script closes the gap: +it runs the **actual** ``autoflatten.core.map_cuts_to_subject`` (which shells out to +``mri_label2label --regmethod surface``) and compares its mapped vertex IDs, cut by cut, +against :func:`benchmark.projection.map_cuts_to_subject_python`. + +Requires FreeSurfer on PATH. On this machine:: + + source ~/bin/source_freesurfer.sh # FreeSurfer 6.0 + export SUBJECTS_DIR=/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer + export PYTHONPATH=/home/jlg/mvdoc/repos/autoflatten + python -m benchmark.validate_mapping_vs_freesurfer + +Result (3 dev hemispheres): total vertex-ID symmetric difference = 0 -- the Python +``union(push, pull)`` KDTree mapper reproduces ``mri_label2label`` exactly. +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import sys +import time + +import numpy as np + +from . import paths +from .ledger import Ledger, new_record +from .projection import map_cuts_to_subject_python + +from autoflatten.config import fsaverage_cut_template +from autoflatten.core import map_cuts_to_subject + +CUTS = ["mwall", "calcarine", "medial1", "medial2", "medial3", "temporal"] + + +def compare_hemi(subject, hemi): + template = json.load(open(fsaverage_cut_template)) + vd = {k[3:]: np.array(v) for k, v in template.items() if k.startswith(hemi + "_")} + + t0 = time.time() + fs = map_cuts_to_subject(vd, subject, hemi) + fs_s = time.time() - t0 + t0 = time.time() + py = map_cuts_to_subject_python(vd, subject, hemi) + py_s = time.time() - t0 + + per_cut = {} + total_diff = 0 + for cut in CUTS: + a = set(int(x) for x in fs.get(cut, [])) + b = set(int(x) for x in py.get(cut, [])) + diff = len(a ^ b) + total_diff += diff + per_cut[cut] = { + "n_fs": len(a), + "n_py": len(b), + "jaccard": (len(a & b) / len(a | b)) if (a | b) else 1.0, + "n_diff": diff, + } + return { + "subject": subject, + "hemi": hemi, + "fs_s": fs_s, + "py_s": py_s, + "total_diff": total_diff, + "per_cut": per_cut, + 
} + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--hemis", + nargs="*", + default=["sub-022:lh", "sub-022:rh", "sub-026:lh"], + help="subject:hemi pairs (default: 3 dev hemispheres)", + ) + args = ap.parse_args() + + if shutil.which("mri_label2label") is None: + print( + "mri_label2label not on PATH -- source FreeSurfer first " + "(e.g. `source ~/bin/source_freesurfer.sh`).", + file=sys.stderr, + ) + return 1 + + paths.ensure_output_dirs() + ledger = Ledger() + results = [] + grand_total = 0 + for spec in args.hemis: + subject, hemi = spec.split(":") + r = compare_hemi(subject, hemi) + results.append(r) + grand_total += r["total_diff"] + print(f"=== {subject} {hemi} FS:{r['fs_s']:.1f}s PY:{r['py_s']:.2f}s ===") + for cut in CUTS: + c = r["per_cut"][cut] + tag = "EXACT" if c["n_diff"] == 0 else f"diff={c['n_diff']}" + print( + f" {cut:11} FS={c['n_fs']:6d} PY={c['n_py']:6d} " + f"jaccard={c['jaccard'] * 100:6.2f}% {tag}" + ) + + rec = new_record( + kind="experiment", + label=f"exp:mapping_vs_freesurfer:{subject}.{hemi}", + subjects=[{"subject": subject, "hemi": hemi}], + method={ + "name": "python_mapper_vs_mri_label2label", + "freesurfer": "6.0", + "mapper": "python_kdtree_union_push_pull", + }, + repro_command=f"python -m benchmark.validate_mapping_vs_freesurfer --hemis {spec}", + ) + rec.metrics = { + "total_vertex_id_diff": r["total_diff"], + "exact": r["total_diff"] == 0, + "freesurfer_s": r["fs_s"], + "python_s": r["py_s"], + "per_cut": r["per_cut"], + } + rec.status = "ok" + rec.decision["hypothesis"] = ( + "The Python union(push, pull) KDTree mapper reproduces mri_label2label " + "--regmethod surface vertex-for-vertex (not just in count)." + ) + rec.decision["conclusion"] = ( + f"{subject} {hemi}: {r['total_diff']} vertex-ID differences across all cuts " + f"(FreeSurfer {r['fs_s']:.1f}s vs Python {r['py_s']:.2f}s)." 
+ ) + ledger.append(rec) + + print( + f"\nTOTAL vertex-ID symmetric difference across " + f"{len(results)} hemispheres: {grand_total}" + ) + return 0 if grand_total == 0 else 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/validate_projection.py b/benchmark/validate_projection.py new file mode 100644 index 0000000..921a2d5 --- /dev/null +++ b/benchmark/validate_projection.py @@ -0,0 +1,197 @@ +"""Validate the FreeSurfer-free projection across the full benchmark (Phase 1). + +The projection phase's only FreeSurfer dependency is +``core.map_cuts_to_subject`` -> ``mri_label2label --regmethod surface``. This runner +swaps it for the pure-Python KDTree mapper in :mod:`benchmark.projection` and checks that +the **final patch is bit-identical** (exact included-vertex set) to the cached FreeSurfer +patch on every manifest hemisphere, including held-out subjects -- the projection analogue +of the flattening overfit check. + +It also records the per-hemisphere timing breakdown (mapping vs continuity+refine), so the +speed win on the mapping step (FreeSurfer ~36s subprocess -> ~0.3s KDTree) is logged. + +Each hemisphere logs its own ledger record (crash-survivable); a final summary record +carries the aggregate exact-match count and timing. + +Usage +----- + python -m benchmark.validate_projection + python -m benchmark.validate_projection --only sub-041 sub-052 +""" + +from __future__ import annotations + +import argparse +import contextlib +import io +import sys +import time + +import numpy as np + +from . 
import paths +from .harness import load_manifest +from .ledger import Ledger, new_record +from .projection import ( + DEFAULT_SUBJECTS_DIR, + map_cuts_to_subject_python, + project_python, +) + +from autoflatten.freesurfer import read_patch + + +def _included_set(patch_path): + _, orig_idx, _ = read_patch(patch_path) + return set(int(i) for i in orig_idx) + + +def run_hemi(entry, subjects_dir): + subject, hemi = entry["subject"], entry["hemi"] + + # time the mapping step on its own (the part that replaces FreeSurfer) + from .projection import _load_template_vertex_dict + + vd = _load_template_vertex_dict(hemi) + t0 = time.time() + map_cuts_to_subject_python(vd, subject, hemi, subjects_dir=subjects_dir) + map_s = time.time() - t0 + + # full FS-free projection -> patch + out_patch = str(paths.RUNS_DIR / f"projection_{subject}_{hemi}.patch.3d") + t0 = time.time() + with contextlib.redirect_stdout(io.StringIO()): + res = project_python( + subject, hemi, subjects_dir=subjects_dir, out_patch=out_patch + ) + total_s = time.time() - t0 + + py = _included_set(res["patch_file"]) + fs = _included_set(entry["patch_path"]) + inter, union = len(py & fs), len(py | fs) + jaccard = inter / union if union else float("nan") + exact = py == fs + + return { + "subject": subject, + "hemi": hemi, + "split": entry.get("split"), + "exact": bool(exact), + "jaccard": jaccard, + "n_py": len(py), + "n_fs": len(fs), + "n_diff": len(py ^ fs), + "map_s": map_s, + "total_s": total_s, + } + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--only", nargs="*", default=None, help="restrict to subjects") + ap.add_argument("--subjects-dir", default=DEFAULT_SUBJECTS_DIR) + args = ap.parse_args() + + paths.ensure_output_dirs() + manifest = load_manifest() + entries = manifest["entries"] + if args.only: + entries = [e for e in entries if e["subject"] in set(args.only)] + if not entries: + print("No manifest entries selected.", file=sys.stderr) + return 1 + + 
ledger = Ledger() + results = [] + print(f"Validating FS-free projection on {len(entries)} hemispheres...\n") + print( + f"{'subject':10} {'hemi':4} {'split':8} {'map_s':>6} {'total_s':>8} " + f"{'n_inc':>8} {'jaccard':>8} {'exact':>6}" + ) + for entry in entries: + r = run_hemi(entry, args.subjects_dir) + results.append(r) + print( + f"{r['subject']:10} {r['hemi']:4} {str(r['split']):8} " + f"{r['map_s']:6.2f} {r['total_s']:8.1f} {r['n_py']:8d} " + f"{r['jaccard'] * 100:7.2f}% {('YES' if r['exact'] else 'NO'):>6}" + ) + + rec = new_record( + kind="experiment", + label=f"exp:validate_projection:{r['subject']}.{r['hemi']}", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": r["subject"], "hemi": r["hemi"], "split": r["split"]} + ], + method={ + "name": "fsfree_projection_vs_freesurfer", + "mapper": "python_kdtree_union_push_pull", + "subjects_dir": args.subjects_dir, + }, + seeds={"note": "deterministic KDTree + NetworkX shortest paths"}, + repro_command="python -m benchmark.validate_projection --only " + + r["subject"], + ) + rec.metrics = { + "exact_match": r["exact"], + "jaccard": r["jaccard"], + "n_included_py": r["n_py"], + "n_included_fs": r["n_fs"], + "n_symmetric_diff": r["n_diff"], + "mapping_s": r["map_s"], + "total_s": r["total_s"], + } + rec.per_subject = [{"subject": r["subject"], "hemi": r["hemi"], **rec.metrics}] + rec.status = "ok" + rec.decision["hypothesis"] = ( + "mri_label2label --regmethod surface is reproducible in pure Python as a " + "union(push, pull) KDTree mapping on sphere.reg, yielding a bit-identical " + "patch with no FreeSurfer dependency." + ) + rec.decision["conclusion"] = ( + f"{r['split']}: exact patch match={r['exact']} (jaccard {r['jaccard'] * 100:.2f}%, " + f"{r['n_diff']} vertices differ); mapping {r['map_s']:.2f}s " + f"(FreeSurfer ~36s), total {r['total_s']:.1f}s." 
+ ) + ledger.append(rec) + + n_exact = sum(r["exact"] for r in results) + map_times = np.array([r["map_s"] for r in results]) + tot_times = np.array([r["total_s"] for r in results]) + print( + f"\nEXACT match: {n_exact}/{len(results)} hemispheres | " + f"mapping mean {map_times.mean():.2f}s | total mean {tot_times.mean():.1f}s" + ) + + rec = new_record( + kind="experiment", + label="exp:validate_projection:summary", + manifest_id=manifest.get("created"), + subjects=[ + {"subject": r["subject"], "hemi": r["hemi"], "split": r["split"]} + for r in results + ], + method={"name": "fsfree_projection_summary"}, + repro_command="python -m benchmark.validate_projection", + ) + rec.metrics = { + "n_hemispheres": len(results), + "n_exact_match": int(n_exact), + "all_exact": n_exact == len(results), + "mapping_mean_s": float(map_times.mean()), + "total_mean_s": float(tot_times.mean()), + } + rec.status = "ok" + rec.decision["conclusion"] = ( + f"FS-free projection reproduces the FreeSurfer patch exactly on " + f"{n_exact}/{len(results)} hemispheres; mapping {map_times.mean():.2f}s vs " + f"FreeSurfer ~36s. Projection no longer requires FreeSurfer." + ) + ledger.append(rec) + print(f"\nLogged {len(results) + 1} records -> {ledger.path}") + return 0 if n_exact == len(results) else 2 + + +if __name__ == "__main__": + raise SystemExit(main()) From 5fdf8b22382877c011c0dd3292992009efbe5c54 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 09:13:24 -0700 Subject: [PATCH 27/35] Phase 2: refinement ablation probe + continuity toggle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add benchmark/probe_refinement.py to test whether projection refinement (continuity + geodesic) improves downstream flatmap distortion, scored by the fast flattener + global true-geodesic metric (fresh k-ring + truegeo per variant). Add a `continuity` toggle to project_python for the ablation. 
First signal (sub-022 lh): geodesic refinement gives HIGHER global distortion (11.37%) than raw mapped cuts (10.69%); all variants are topologically valid. N=1, <1pp spread — broadening before concluding. Co-Authored-By: Claude Opus 4.8 --- benchmark/probe_refinement.py | 182 ++++++++++++++++++++++++++++++++++ benchmark/projection.py | 11 +- 2 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 benchmark/probe_refinement.py diff --git a/benchmark/probe_refinement.py b/benchmark/probe_refinement.py new file mode 100644 index 0000000..6e2f22d --- /dev/null +++ b/benchmark/probe_refinement.py @@ -0,0 +1,182 @@ +"""Phase 2: does projection refinement improve the downstream flatmap? (ablation) + +Constraint (user): exactly 5 cuts, in their current anatomical positions. So the template +placement is fixed; the only question is the **refinement** that turns the raw mapped cuts +into the final patch: + +- ``ensure_continuous_cuts`` -- connects disconnected mapped-cut components. +- ``refine_cuts_with_geodesic`` -- replaces each thick mapped cut blob with a thin geodesic + shortest path between endpoints (start = farthest-from-mwall cut vertex; end = max-clearance + medial-wall anchor). + +This probe ablates those two steps and scores each resulting patch by the **downstream +flatmap distortion** (the fast flattener + global true-geodesic metric, recomputed per +variant because the patch vertex set changes) plus flips and patch size. It answers: are the +refinements necessary, and do they actually reduce distortion? + +Each variant patch gets a **fresh** k-ring (``use_cache=False``) -- the on-disk cache is +keyed only on (subject, hemi, k, n), so variants would otherwise collide. + +Usage +----- + python -m benchmark.probe_refinement --hemis sub-022:lh + python -m benchmark.probe_refinement # all dev hemis, all variants +""" + +from __future__ import annotations + +import argparse +import contextlib +import io +import time + +import numpy as np + +from . 
import paths +from .ledger import Ledger, new_record +from .metrics import per_patch_metrics +from .probe_tutte_init import make_flatten_fn +from .projection import project_python +from .truedist import compute_truegeo, true_distortion_full +from .validate_speed import fast_config + +from autoflatten.flatten.algorithm import count_boundary_loops + +# (label, continuity, refine_geodesic) +VARIANTS = [ + ("geodesic", True, True), # shipped pipeline + ("continuity_only", True, False), # thick mapped cuts, no geodesic thinning + ("mapped_only", False, False), # raw mapped cuts (may be topologically invalid) +] + +SURF = "fiducial" + + +def _surface_path(subject, hemi, subjects_dir): + return f"{subjects_dir}/{subject}/surf/{hemi}.{SURF}" + + +def run_variant(subject, hemi, label, continuity, refine_geodesic, subjects_dir): + out_patch = str(paths.RUNS_DIR / f"refine_{label}_{subject}_{hemi}.patch.3d") + with contextlib.redirect_stdout(io.StringIO()): + proj = project_python( + subject, + hemi, + subjects_dir=subjects_dir, + continuity=continuity, + refine_geodesic=refine_geodesic, + out_patch=out_patch, + ) + n_patch = len(proj["patch_vertices"]) + + entry = { + "subject": subject, + "hemi": hemi, + "patch_path": proj["patch_file"], + "surface_path": _surface_path(subject, hemi, subjects_dir), + } + + # flatten (fresh k-ring; fast config + Tutte init) + from .harness import build_flattener + + t0 = time.time() + fl = build_flattener(entry, fast_config(), use_cache=False) + # patch topology: a valid flat patch is a single boundary loop (a disk) + n_loops, _ = count_boundary_loops(fl.faces) + uv = np.asarray(make_flatten_fn("tutte", refine=True)(fl)) + rt = time.time() - t0 + + m = per_patch_metrics(uv, fl) + ref = compute_truegeo(fl) + full = true_distortion_full(uv, ref) + + return { + "label": label, + "n_patch": n_patch, + "n_boundary_loops": int(n_loops), + "n_flipped": int(m["n_flipped"]), + "true_global_mean": full["true_global_mean"], + "true_global_at_optscale": 
full["true_global_at_optscale"], + "opt_scale": full["opt_scale"], + "runtime_s": rt, + } + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--hemis", nargs="*", default=["sub-022:lh"]) + ap.add_argument("--variants", nargs="*", default=[v[0] for v in VARIANTS]) + ap.add_argument( + "--subjects-dir", + default="/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer", + ) + args = ap.parse_args() + + paths.ensure_output_dirs() + variants = [v for v in VARIANTS if v[0] in set(args.variants)] + ledger = Ledger() + + print( + f"{'hemi':11} {'variant':16} {'n_patch':>8} {'loops':>5} {'flips':>6} " + f"{'glob%':>7} {'opt%':>7} {'rt_s':>6}" + ) + for spec in args.hemis: + subject, hemi = spec.split(":") + for label, cont, refine in variants: + try: + r = run_variant(subject, hemi, label, cont, refine, args.subjects_dir) + except Exception as e: # noqa: BLE001 - record the failure, keep going + print(f"{subject + ' ' + hemi:11} {label:16} FAILED: {e}") + rec = new_record( + kind="experiment", + label=f"exp:probe_refinement:{subject}.{hemi}:{label}", + subjects=[{"subject": subject, "hemi": hemi}], + method={"name": "refinement_ablation", "variant": label}, + repro_command=f"python -m benchmark.probe_refinement --hemis {spec} --variants {label}", + ) + rec.status = "error" + rec.decision["conclusion"] = f"{label} failed: {e}" + ledger.append(rec) + continue + + print( + f"{subject + ' ' + hemi:11} {r['label']:16} {r['n_patch']:8d} " + f"{r['n_boundary_loops']:5d} {r['n_flipped']:6d} " + f"{r['true_global_mean']:7.2f} {r['true_global_at_optscale']:7.2f} " + f"{r['runtime_s']:6.0f}" + ) + + rec = new_record( + kind="experiment", + label=f"exp:probe_refinement:{subject}.{hemi}:{label}", + subjects=[{"subject": subject, "hemi": hemi}], + method={ + "name": "refinement_ablation", + "variant": label, + "continuity": cont, + "refine_geodesic": refine, + "flatten": "fast_ultimate+tutte", + }, + 
repro_command=f"python -m benchmark.probe_refinement --hemis {spec} --variants {label}", + ) + rec.metrics = {k: v for k, v in r.items() if k != "label"} + rec.per_subject = [{"subject": subject, "hemi": hemi, **rec.metrics}] + rec.status = "ok" + rec.decision["hypothesis"] = ( + "Geodesic refinement (thin cut paths) lowers downstream flatmap distortion " + "vs raw/continuity-only thick mapped cuts; continuity is needed for a valid " + "single-boundary-loop patch." + ) + rec.decision["conclusion"] = ( + f"{label}: {r['n_boundary_loops']} boundary loop(s), {r['n_flipped']} flips, " + f"global {r['true_global_mean']:.2f}% (opt {r['true_global_at_optscale']:.2f}%), " + f"patch {r['n_patch']} verts." + ) + ledger.append(rec) + + print(f"\nLogged to {ledger.path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmark/projection.py b/benchmark/projection.py index 37650fc..de2a24f 100644 --- a/benchmark/projection.py +++ b/benchmark/projection.py @@ -157,6 +157,7 @@ def project_python( hemi, subjects_dir=None, template_file=None, + continuity=True, refine_geodesic=True, out_patch=None, verbose=False, @@ -167,6 +168,10 @@ def project_python( hole fill -> patch) but swaps the ``mri_label2label`` mapping for the validated Python KDTree mapper. Every downstream step is already pure Python. + The ``continuity`` and ``refine_geodesic`` toggles exist for the Phase 2 refinement + ablation (are these steps necessary / do they improve downstream distortion?). Both + default to True (the shipped pipeline). 
+ Returns ------- dict @@ -180,7 +185,11 @@ def project_python( vertex_dict, subject, hemi, subjects_dir=subjects_dir ) - fixed = ensure_continuous_cuts(dict(mapped), subject, hemi) + fixed = ( + ensure_continuous_cuts(dict(mapped), subject, hemi) + if continuity + else dict(mapped) + ) if refine_geodesic: fixed = refine_cuts_with_geodesic( fixed, subject, hemi, medial_wall_vertices=fixed.get("mwall") From e19b11d2e98b1f0d61f3e15055cd5ed76837aa24 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 10:52:31 -0700 Subject: [PATCH 28/35] =?UTF-8?q?Phase=202:=20refinement=20ablation=20verd?= =?UTF-8?q?ict=20=E2=80=94=20drop=20geodesic=20thinning,=20keep=20continui?= =?UTF-8?q?ty?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ablation across 6 hemispheres (probe_refinement.py), scored by downstream fast-flatten + global true-geodesic distortion: geodesic (shipped) 11.12% worst/tied, inflates flips continuity_only 10.84% BEST (-0.28pp), fewer flips mapped_only 11.12% continuity is what helps, not thinning geodesic_curv (c) 11.24% worst; flips exploded (424 on sub-022 rh) The geodesic refinement does not help and slightly hurts: connecting cut components (continuity) is the win; thinning the thick mapped cut to a 1-vertex geodesic is the harm (removes strain relief, jagged boundary). Curvature-weighted routing (c) failed because sulci meander -> more boundary tortuosity -> more flips, confirming thinning (not the routing metric) is the core problem. Adds a `continuity` toggle and `refine_weight=curvature` (curv_alpha) path to project_python. Shippable recommendation: default projection to geodesic-refinement-off. Documented in FINDINGS section 15. 
Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 49 +++++++++++++++++++++++++++-- benchmark/probe_refinement.py | 52 +++++++++++++++++++++++-------- benchmark/projection.py | 58 +++++++++++++++++++++++++++++++++-- 3 files changed, 141 insertions(+), 18 deletions(-) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 6f2f382..df1078e 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -32,8 +32,13 @@ provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jso (`mri_label2label`) is now reimplemented in pure Python (`union(push, pull)` KDTree on `sphere.reg`) that reproduces FreeSurfer **vertex-for-vertex** (0 ID differences vs real `mri_label2label`; bit-identical patch on **9/9** hemispheres incl. held-out) at **~0.3 s vs 36.6 s** for the mapping - step. Projection no longer requires FreeSurfer (§14). Phase 2 (cut-placement quality, scored by - downstream flatmap distortion) is next. + step. Projection no longer requires FreeSurfer (§14). +- **Projection (Phase 2 — refinement ablation, 6 hemis):** with the 5 cuts fixed in their current + positions, the **geodesic cut refinement hurts** the flatmap — `continuity_only` (connect cut + components but don't thin them) beats the shipped geodesic refinement by **−0.28 pp** distortion + with fewer flips on 4/6 hemispheres; a curvature-routed geodesic (c) was *worse* (flips exploded). + The harm is the **thinning** (thick relief band → jagged 1-wide path), not the routing metric. + Shippable: default projection to geodesic-refinement-**off** (§15). ## 1. Flip-free (Tutte) init is a validated win: equal quality, ~37% faster @@ -497,6 +502,46 @@ This both removes the install barrier (projection no longer needs FreeSurfer, on cut-placement search runnable on this box. Mapping is the validated baseline; the geodesic refinement (~33 s) is now the projection time sink and a future speed target. +## 15. 
The geodesic cut refinement HURTS the flatmap — keep continuity, drop the thinning + +With the template fixed (5 cuts, current anatomical positions — a hard requirement), the only +projection lever is the **refinement** that turns raw mapped cuts into the patch: +`ensure_continuous_cuts` (connect disconnected mapped-cut components) and +`refine_cuts_with_geodesic` (replace each thick mapped cut with a thin Euclidean-shortest +geodesic between heuristic endpoints). Ablation across **6 hemispheres**, scoring each +variant's patch by the downstream fast-flatten + global true-geodesic metric (fresh k-ring + +truegeo per variant; `benchmark/probe_refinement.py`): + +| variant | mean global@opt % | flips behaviour | +| --- | ---: | --- | +| `geodesic` (shipped) | 11.12 | worst/tied-worst; inflates flips (129, 120 on two hemis) | +| **`continuity_only`** | **10.84** | **best distortion, moderate flips** | +| `mapped_only` (no continuity) | 11.12 | ties shipped — continuity is what helps, not thinning | +| `geodesic_curv` (c) | 11.24 | **worst**; flips explode (424 on sub-022 rh) | + +**The geodesic refinement does not help and slightly hurts** — `continuity_only` beats it on +4/6 hemispheres (−0.28 pp mean) with fewer flips. The win is the **continuity** step +(connecting cut components); the **thinning** to a 1-vertex-wide geodesic is the harm. + +**Mechanism.** Replacing a thick mapped-cut *band* with a thin path (a) removes the strain +relief the band provided and (b) leaves a jagged cut boundary. Both raise distortion and +flips. The shipped step compounds (b) with bad endpoint heuristics (start = farthest-from-mwall, +end = max-clearance-from-mwall — actively routing the cut *away* from where it should anchor). + +**(c), the principled fix that failed (and why that's informative).** Hypothesis: route the +cut along sulcal fundi (weight graph edges `length·exp(-α·sulc)`, α=0.1, monkeypatched into +`refine_cuts_with_geodesic` so only the path weighting changes). 
Result: **worse on average +(11.24%) and flips exploded** (sub-022 rh: 424 vs 129). Sulci meander, so curvature-routing +makes the cut boundary *more* tortuous — exactly the (b) failure mode, amplified. This +confirms the core problem is the **thinning**, not the routing metric: no path that thins the +cut to one vertex wide will beat keeping the thick band, straight or curvy. + +**Shippable recommendation:** default projection to **continuity-on, geodesic-refinement-off** +(i.e. make `--no-refine-geodesic` the default). ~0.28 pp lower distortion, fewer flips, and +faster projection (the geodesic refinement is also the ~33 s/hemi time sink, §14). Open +follow-on if more is wanted: a refinement that keeps the cut *thick* but cleans its boundary +(smooth/widen rather than thin) — untested. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. diff --git a/benchmark/probe_refinement.py b/benchmark/probe_refinement.py index 6e2f22d..c14b5b6 100644 --- a/benchmark/probe_refinement.py +++ b/benchmark/probe_refinement.py @@ -42,11 +42,31 @@ from autoflatten.flatten.algorithm import count_boundary_loops -# (label, continuity, refine_geodesic) VARIANTS = [ - ("geodesic", True, True), # shipped pipeline - ("continuity_only", True, False), # thick mapped cuts, no geodesic thinning - ("mapped_only", False, False), # raw mapped cuts (may be topologically invalid) + # shipped pipeline (Euclidean-shortest geodesic refinement) + {"label": "geodesic", "continuity": True, "refine": True, "weight": "euclidean"}, + # thick mapped cuts, no geodesic thinning + { + "label": "continuity_only", + "continuity": True, + "refine": False, + "weight": "euclidean", + }, + # raw mapped cuts (no continuity, no refinement) + { + "label": "mapped_only", + "continuity": False, + "refine": False, + "weight": "euclidean", + }, + # (c) curvature-weighted geodesic: route cut paths along sulcal fundi + { + "label": "geodesic_curv", + "continuity": 
True, + "refine": True, + "weight": "curvature", + "alpha": 0.1, + }, ] SURF = "fiducial" @@ -56,15 +76,18 @@ def _surface_path(subject, hemi, subjects_dir): return f"{subjects_dir}/{subject}/surf/{hemi}.{SURF}" -def run_variant(subject, hemi, label, continuity, refine_geodesic, subjects_dir): +def run_variant(subject, hemi, spec, subjects_dir): + label = spec["label"] out_patch = str(paths.RUNS_DIR / f"refine_{label}_{subject}_{hemi}.patch.3d") with contextlib.redirect_stdout(io.StringIO()): proj = project_python( subject, hemi, subjects_dir=subjects_dir, - continuity=continuity, - refine_geodesic=refine_geodesic, + continuity=spec["continuity"], + refine_geodesic=spec["refine"], + refine_weight=spec.get("weight", "euclidean"), + curv_alpha=spec.get("alpha", 0.1), out_patch=out_patch, ) n_patch = len(proj["patch_vertices"]) @@ -105,7 +128,7 @@ def run_variant(subject, hemi, label, continuity, refine_geodesic, subjects_dir) def main() -> int: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--hemis", nargs="*", default=["sub-022:lh"]) - ap.add_argument("--variants", nargs="*", default=[v[0] for v in VARIANTS]) + ap.add_argument("--variants", nargs="*", default=[v["label"] for v in VARIANTS]) ap.add_argument( "--subjects-dir", default="/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer", @@ -113,7 +136,7 @@ def main() -> int: args = ap.parse_args() paths.ensure_output_dirs() - variants = [v for v in VARIANTS if v[0] in set(args.variants)] + variants = [v for v in VARIANTS if v["label"] in set(args.variants)] ledger = Ledger() print( @@ -122,9 +145,10 @@ def main() -> int: ) for spec in args.hemis: subject, hemi = spec.split(":") - for label, cont, refine in variants: + for vspec in variants: + label = vspec["label"] try: - r = run_variant(subject, hemi, label, cont, refine, args.subjects_dir) + r = run_variant(subject, hemi, vspec, args.subjects_dir) except Exception as e: # noqa: BLE001 - record the failure, keep going 
print(f"{subject + ' ' + hemi:11} {label:16} FAILED: {e}") rec = new_record( @@ -153,8 +177,10 @@ def main() -> int: method={ "name": "refinement_ablation", "variant": label, - "continuity": cont, - "refine_geodesic": refine, + "continuity": vspec["continuity"], + "refine_geodesic": vspec["refine"], + "refine_weight": vspec.get("weight", "euclidean"), + "curv_alpha": vspec.get("alpha"), "flatten": "fast_ultimate+tutte", }, repro_command=f"python -m benchmark.probe_refinement --hemis {spec} --variants {label}", diff --git a/benchmark/projection.py b/benchmark/projection.py index de2a24f..98c073c 100644 --- a/benchmark/projection.py +++ b/benchmark/projection.py @@ -26,6 +26,7 @@ import numpy as np from scipy.spatial import cKDTree +import autoflatten.core as _core from autoflatten.core import ( ensure_continuous_cuts, fill_holes_in_patch, @@ -137,6 +138,39 @@ def map_cuts_to_subject_python( return mapped +def _curvature_weighted_graph_builder(subject, hemi, subjects_dir, alpha, morph="sulc"): + """Factory for a ``(pts, polys) -> nx.Graph`` builder with curvature-weighted edges. + + Phase 2 hypothesis (c): the shipped geodesic refinement routes cuts along the + *Euclidean-shortest* path, which is curvature-blind and pulls cuts off the sulcal fundi + they should track. Reweighting edges by ``length * exp(-alpha * sulc_edge)`` makes the + shortest path *prefer deep sulci* (high ``sulc``) and avoid gyral crowns. The returned + builder matches ``core._build_surface_graph``'s signature so it can be monkeypatched in + for the refinement call only (every endpoint/trapped-vertex heuristic stays identical; + only the path's edge weights change). 
+ """ + import networkx as nx + + sulc = nib.freesurfer.read_morph_data( + os.path.join(subjects_dir, subject, "surf", f"{hemi}.{morph}") + ) + + def build(pts, polys): + polys = np.asarray(polys) + edges = np.vstack([polys[:, [0, 1]], polys[:, [0, 2]], polys[:, [1, 2]]]) + length = np.linalg.norm(pts[edges[:, 0]] - pts[edges[:, 1]], axis=1) + s_edge = 0.5 * (sulc[edges[:, 0]] + sulc[edges[:, 1]]) + weights = length * np.exp(-alpha * s_edge) # high sulc -> low cost + G = nx.Graph() + G.add_nodes_from(range(len(pts))) + G.add_weighted_edges_from( + zip(edges[:, 0].tolist(), edges[:, 1].tolist(), weights.tolist()) + ) + return G + + return build + + def _load_template_vertex_dict(hemi, template_file=None): """Load fsaverage cut/mwall labels for a hemisphere from the JSON template.""" if template_file is None: @@ -159,6 +193,9 @@ def project_python( template_file=None, continuity=True, refine_geodesic=True, + refine_weight="euclidean", + curv_alpha=0.1, + curv_morph="sulc", out_patch=None, verbose=False, ): @@ -191,9 +228,24 @@ def project_python( else dict(mapped) ) if refine_geodesic: - fixed = refine_cuts_with_geodesic( - fixed, subject, hemi, medial_wall_vertices=fixed.get("mwall") - ) + if refine_weight == "curvature": + # Route cut paths along sulcal fundi by monkeypatching the graph builder + # used inside refine_cuts_with_geodesic (endpoint/trapped logic unchanged). 
+ builder = _curvature_weighted_graph_builder( + subject, hemi, subjects_dir, curv_alpha, curv_morph + ) + orig = _core._build_surface_graph + _core._build_surface_graph = builder + try: + fixed = refine_cuts_with_geodesic( + fixed, subject, hemi, medial_wall_vertices=fixed.get("mwall") + ) + finally: + _core._build_surface_graph = orig + else: + fixed = refine_cuts_with_geodesic( + fixed, subject, hemi, medial_wall_vertices=fixed.get("mwall") + ) pts, polys = load_surface(subject, "inflated", hemi) excluded = set() From 61038ab8542e4523a43ca580b901b6e4f2717cd7 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 12:25:06 -0700 Subject: [PATCH 29/35] =?UTF-8?q?Phase=202=20(2):=20thick-but-smoothed=20c?= =?UTF-8?q?ut=20=E2=80=94=20cuts=20flips,=20not=20distortion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tested the boundary-smoothing follow-on: keep the thick continuity_only cut but morphologically close its boundary (project_python morph_close=n; _morphological_close_cuts). Result across 6 hemispheres: continuity_only 10.84% BEST (robust winner) thick_close2 10.99% reduces worst flips (59->33) but distortion mixed thick_close1 11.01% same; hurts sub-026 lh by +1pp Both "smarter" refinements (curvature routing (c), boundary smoothing (2)) fail to beat plain continuity_only. The simplest answer wins: connect the cut components, don't thin, don't smooth. Coverage cost of closing is negligible (<150 verts, <0.1% of surface; medial wall ~11k dominates). Adds --save-flat to the probe for visualization. FINDINGS section 15 updated. 
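For reference, a minimal sketch of the boundary closing tested here (dilate `n` rings, then erode `n` rings on the mesh adjacency). It is an illustration under assumptions, not the code in the patch: `vertex_adjacency`, `close_vertex_set` and `grid_faces` are hypothetical names, the adjacency is built in pure Python instead of `igl.adjacency_list`, and the mesh is a toy triangulated grid rather than a cortical surface.

```python
def vertex_adjacency(faces, n_vertices):
    """Neighbor sets from a triangle list (stand-in for igl.adjacency_list)."""
    adj = [set() for _ in range(n_vertices)]
    for a, b, c in faces:
        adj[a].update((b, c))
        adj[b].update((a, c))
        adj[c].update((a, b))
    return adj


def close_vertex_set(verts, adj, n_iter):
    """Morphological closing of a vertex set: dilate n_iter rings, then erode n_iter."""
    region = {int(v) for v in verts}
    for _ in range(n_iter):  # dilate: add every neighbor of the current region
        region |= {nb for v in list(region) for nb in adj[v]}
    for _ in range(n_iter):  # erode: drop vertices that touch the outside
        region = {v for v in region if all(nb in region for nb in adj[v])}
    return region


def grid_faces(n_rows, n_cols):
    """Triangulate a regular vertex grid (two triangles per square). Toy mesh only."""
    faces = []
    for r in range(n_rows - 1):
        for c in range(n_cols - 1):
            v00 = r * n_cols + c
            v01, v10, v11 = v00 + 1, v00 + n_cols, v00 + n_cols + 1
            faces += [(v00, v01, v11), (v00, v11, v10)]
    return faces


N = 9
adj = vertex_adjacency(grid_faces(N, N), N * N)
band = {r * N + c for r in range(3, 6) for c in range(2, 7)}  # thick 3-row "cut"
notch = 3 * N + 4                                             # one missing boundary vertex
ragged = band - {notch}
closed = close_vertex_set(ragged, adj, n_iter=1)
```

On this toy grid the close fills the one-vertex notch and leaves the band width unchanged. That is the intended effect on the ragged mapped cuts, although on real surfaces the dilation can also bridge into neighboring regions.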
Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 43 ++++++++++++++++++++++++----------- benchmark/probe_refinement.py | 20 ++++++++++++++-- benchmark/projection.py | 30 ++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 15 deletions(-) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index df1078e..131ba03 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -36,9 +36,11 @@ provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jso - **Projection (Phase 2 — refinement ablation, 6 hemis):** with the 5 cuts fixed in their current positions, the **geodesic cut refinement hurts** the flatmap — `continuity_only` (connect cut components but don't thin them) beats the shipped geodesic refinement by **−0.28 pp** distortion - with fewer flips on 4/6 hemispheres; a curvature-routed geodesic (c) was *worse* (flips exploded). - The harm is the **thinning** (thick relief band → jagged 1-wide path), not the routing metric. - Shippable: default projection to geodesic-refinement-**off** (§15). + with fewer flips on 4/6 hemispheres. Two "smarter" refinements were then tested and **both fail to + beat plain `continuity_only`**: a curvature-routed geodesic (c) is *worse* (flips exploded), and + boundary smoothing (2) cuts flips but not distortion. The robust answer is the simplest: connect + the cut components, **don't thin, don't smooth**. Shippable: default projection to + geodesic-refinement-**off** (§15). ## 1. 
Flip-free (Tutte) init is a validated win: equal quality, ~37% faster @@ -518,6 +520,8 @@ truegeo per variant; `benchmark/probe_refinement.py`): | **`continuity_only`** | **10.84** | **best distortion, moderate flips** | | `mapped_only` (no continuity) | 11.12 | ties shipped — continuity is what helps, not thinning | | `geodesic_curv` (c) | 11.24 | **worst**; flips explode (424 on sub-022 rh) | +| `thick_close1` (2) | 11.01 | between; *reduces* worst flips (59→33) but distortion mixed | +| `thick_close2` (2) | 10.99 | between; same — helps flips, hurts sub-026 lh by +1 pp | **The geodesic refinement does not help and slightly hurts** — `continuity_only` beats it on 4/6 hemispheres (−0.28 pp mean) with fewer flips. The win is the **continuity** step @@ -528,19 +532,32 @@ relief the band provided and (b) leaves a jagged cut boundary. Both raise distor flips. The shipped step compounds (b) with bad endpoint heuristics (start = farthest-from-mwall, end = max-clearance-from-mwall — actively routing the cut *away* from where it should anchor). -**(c), the principled fix that failed (and why that's informative).** Hypothesis: route the -cut along sulcal fundi (weight graph edges `length·exp(-α·sulc)`, α=0.1, monkeypatched into -`refine_cuts_with_geodesic` so only the path weighting changes). Result: **worse on average -(11.24%) and flips exploded** (sub-022 rh: 424 vs 129). Sulci meander, so curvature-routing -makes the cut boundary *more* tortuous — exactly the (b) failure mode, amplified. This -confirms the core problem is the **thinning**, not the routing metric: no path that thins the -cut to one vertex wide will beat keeping the thick band, straight or curvy. 
+**Two "smarter refinement" ideas were tested; both fail to beat plain `continuity_only`:** + +- **(c) curvature-routed geodesic.** Hypothesis: route the cut along sulcal fundi (weight graph + edges `length·exp(-α·sulc)`, α=0.1, monkeypatched into `refine_cuts_with_geodesic` so only the + path weighting changes). Result: **worse on average (11.24%) and flips exploded** (sub-022 rh: + 424 vs 129). Sulci meander, so curvature-routing makes the cut boundary *more* tortuous — + exactly failure mode (b), amplified. Confirms the core problem is the **thinning**, not the + routing metric: no path that thins the cut to one vertex wide beats the thick band. +- **(2) thick-but-smoothed cut.** Keep the thick `continuity_only` cut but morphologically + *close* its boundary (dilate `n` rings then erode `n`, `_morphological_close_cuts`) to fill the + ragged notches where flips concentrate. Result: **partial** — it *does* cut the worst flip + counts (sub-022 rh 59→33, sub-041 lh 96→86), validating the "ragged boundary → flips" mechanism, + **but it does not lower distortion** (mean 10.99–11.01% vs `continuity_only` 10.84%). It helps on + some hemispheres and *hurts* on others (sub-026 lh +1 pp — dilation occasionally pushes a cut + into a worse spot). Coverage cost is negligible (the 5 cuts exclude ~530 verts; closing adds + <150, i.e. <0.1% of the surface — the medial wall, ~11k verts, dominates exclusion regardless). + +**Conclusion — the simplest refinement is the best.** Across the ablation the robust winner is +`continuity_only`: connect the disconnected mapped-cut components, but **do not thin and do not +smooth** them. Neither a curvature-aware path nor boundary smoothing beats it on distortion. **Shippable recommendation:** default projection to **continuity-on, geodesic-refinement-off** (i.e. make `--no-refine-geodesic` the default). ~0.28 pp lower distortion, fewer flips, and -faster projection (the geodesic refinement is also the ~33 s/hemi time sink, §14). 
Open -follow-on if more is wanted: a refinement that keeps the cut *thick* but cleans its boundary -(smooth/widen rather than thin) — untested. +faster projection (the geodesic refinement is also the ~33 s/hemi time sink, §14). Boundary +smoothing (2) is available (`project_python(morph_close=n)`) and trades a small distortion cost +for fewer flips if a flip-free patch is ever needed, but is not the default. ## Method note diff --git a/benchmark/probe_refinement.py b/benchmark/probe_refinement.py index c14b5b6..164c59f 100644 --- a/benchmark/probe_refinement.py +++ b/benchmark/probe_refinement.py @@ -67,6 +67,9 @@ "weight": "curvature", "alpha": 0.1, }, + # (2) keep cuts thick but smooth their boundary (morphological close) + {"label": "thick_close1", "continuity": True, "refine": False, "morph_close": 1}, + {"label": "thick_close2", "continuity": True, "refine": False, "morph_close": 2}, ] SURF = "fiducial" @@ -76,7 +79,7 @@ def _surface_path(subject, hemi, subjects_dir): return f"{subjects_dir}/{subject}/surf/{hemi}.{SURF}" -def run_variant(subject, hemi, spec, subjects_dir): +def run_variant(subject, hemi, spec, subjects_dir, save_flat=False): label = spec["label"] out_patch = str(paths.RUNS_DIR / f"refine_{label}_{subject}_{hemi}.patch.3d") with contextlib.redirect_stdout(io.StringIO()): @@ -88,6 +91,7 @@ def run_variant(subject, hemi, spec, subjects_dir): refine_geodesic=spec["refine"], refine_weight=spec.get("weight", "euclidean"), curv_alpha=spec.get("alpha", 0.1), + morph_close=spec.get("morph_close", 0), out_patch=out_patch, ) n_patch = len(proj["patch_vertices"]) @@ -109,6 +113,13 @@ def run_variant(subject, hemi, spec, subjects_dir): uv = np.asarray(make_flatten_fn("tutte", refine=True)(fl)) rt = time.time() - t0 + if save_flat: + flat_path = str( + paths.RUNS_DIR / f"refine_{label}_{subject}_{hemi}.flat.patch.3d" + ) + with contextlib.redirect_stdout(io.StringIO()): + fl.save_result(uv, flat_path) + m = per_patch_metrics(uv, fl) ref = 
compute_truegeo(fl) full = true_distortion_full(uv, ref) @@ -133,6 +144,9 @@ def main() -> int: "--subjects-dir", default="/data2/projects/idem/exps/narratives/datalad-narratives/derivatives/freesurfer", ) + ap.add_argument( + "--save-flat", action="store_true", help="save flat patches for visualization" + ) args = ap.parse_args() paths.ensure_output_dirs() @@ -148,7 +162,9 @@ def main() -> int: for vspec in variants: label = vspec["label"] try: - r = run_variant(subject, hemi, vspec, args.subjects_dir) + r = run_variant( + subject, hemi, vspec, args.subjects_dir, save_flat=args.save_flat + ) except Exception as e: # noqa: BLE001 - record the failure, keep going print(f"{subject + ' ' + hemi:11} {label:16} FAILED: {e}") rec = new_record( diff --git a/benchmark/projection.py b/benchmark/projection.py index 98c073c..5c4da74 100644 --- a/benchmark/projection.py +++ b/benchmark/projection.py @@ -171,6 +171,33 @@ def build(pts, polys): return build +def _morphological_close_cuts(fixed, polys, n_iter, exclude_keys=("mwall",)): + """Smooth each cut's boundary by morphological closing on the mesh (keep it thick). + + Phase 2 follow-on: the ablation showed the *thick* continuity-only cut beats the thinned + geodesic, but the raw mapped blob has a lumpy boundary (notches, 1-vertex spikes) where + downstream flips concentrate. Closing = dilate ``n_iter`` rings then erode ``n_iter`` rings + on the surface adjacency: it fills concave notches and bridges gaps while leaving the band + roughly the same width -- a smoother cut boundary without re-thinning. Returns a new dict. 
+ """ + if n_iter <= 0: + return fixed + import igl + + adj = igl.adjacency_list(np.asarray(polys, dtype=np.int64)) + out = dict(fixed) + for key, verts in fixed.items(): + if key in exclude_keys or key.startswith("_") or len(verts) == 0: + continue + S = set(int(v) for v in verts) + for _ in range(n_iter): # dilate + S |= {nb for v in list(S) for nb in adj[v]} + for _ in range(n_iter): # erode (drop vertices touching outside S) + S = {v for v in S if all(nb in S for nb in adj[v])} + out[key] = np.array(sorted(S)) + return out + + def _load_template_vertex_dict(hemi, template_file=None): """Load fsaverage cut/mwall labels for a hemisphere from the JSON template.""" if template_file is None: @@ -196,6 +223,7 @@ def project_python( refine_weight="euclidean", curv_alpha=0.1, curv_morph="sulc", + morph_close=0, out_patch=None, verbose=False, ): @@ -248,6 +276,8 @@ def project_python( ) pts, polys = load_surface(subject, "inflated", hemi) + if morph_close > 0: + fixed = _morphological_close_cuts(fixed, polys, morph_close) excluded = set() for vertices in fixed.values(): excluded.update(int(v) for v in vertices) From 80e70531629908ec3af40262dfee3c1f5370b2cb Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 16:50:31 -0700 Subject: [PATCH 30/35] Fix _get_k_rings_numba: parallelize via per-chunk scratch (was serial O(n^2)) The non-angular k-ring builder allocated three O(n_vertices) arrays inside a per-vertex prange and collected results with an O(n_vertices) scan per vertex. numba's parallel analysis refuses to parallelize a loop with per-iteration allocations, so a fresh compile ran serially -> O(n^2), ~16 min on a 193k-vertex mesh (a cached build had masked this). Rewrite to parallelize over chunks: split vertices into n_chunks blocks, prange over chunks, each owns its scratch (allocated once) and resets only touched entries; collect each ring from the touched list and sort to match the reference output. 
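The per-chunk scratch pattern described above can be sketched without numba. This is an assumption-laden illustration, not the kernel in the patch: `k_rings_chunked` and `cycle_adjacency` are hypothetical names, the chunk loop is a plain `range` where the real code uses `prange`, and the touched list doubles as the BFS queue instead of separate current/next level buffers.

```python
import numpy as np


def k_rings_chunked(adj_flat, adj_offsets, k, n_chunks):
    """k-ring neighbors per vertex, with scratch allocated once per chunk."""
    n = len(adj_offsets) - 1
    chunk_size = (n + n_chunks - 1) // n_chunks
    rings = [None] * n
    for c in range(n_chunks):  # a prange over chunks in the numba version
        visited = np.zeros(n, dtype=np.bool_)  # allocated once per chunk
        touched = np.empty(n, dtype=np.int64)  # visit order; also the BFS queue
        for v in range(c * chunk_size, min((c + 1) * chunk_size, n)):
            visited[v] = True
            touched[0] = v
            nt = 1
            level_start, level_end = 0, 1  # current level is touched[start:end]
            for _ in range(k):
                for i in range(level_start, level_end):
                    u = touched[i]
                    for j in range(adj_offsets[u], adj_offsets[u + 1]):
                        nb = adj_flat[j]
                        if not visited[nb]:
                            visited[nb] = True
                            touched[nt] = nb
                            nt += 1
                level_start, level_end = level_end, nt
            rings[v] = np.sort(touched[1:nt])  # drop the source, sort ascending
            visited[touched[:nt]] = False      # reset only what was touched
    return rings


def cycle_adjacency(n):
    """Flat adjacency (CSR-style) of an n-vertex cycle graph, for the demo only."""
    adj_flat = np.array([x for i in range(n) for x in ((i - 1) % n, (i + 1) % n)],
                        dtype=np.int64)
    adj_offsets = np.arange(0, 2 * n + 1, 2, dtype=np.int64)
    return adj_flat, adj_offsets


adj_flat, adj_offsets = cycle_adjacency(8)
rings_one_chunk = k_rings_chunked(adj_flat, adj_offsets, k=2, n_chunks=1)
rings_three_chunks = k_rings_chunked(adj_flat, adj_offsets, k=2, n_chunks=3)
```

The point of the reset via `touched` is that the cost per vertex scales with the ring size, not with `n_vertices`; the chunk count only changes how the work is split, not the result.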
Identical output (matches pure-Python get_k_ring), now ~12s and multi-core on 193k verts. Used by plotting and the uncached non-angular flatten path. 313 tests pass. Co-Authored-By: Claude Opus 4.8 --- autoflatten/flatten/distance.py | 156 ++++++++++++++++++++------------ 1 file changed, 96 insertions(+), 60 deletions(-) diff --git a/autoflatten/flatten/distance.py b/autoflatten/flatten/distance.py index 7135780..fe5bfd8 100644 --- a/autoflatten/flatten/distance.py +++ b/autoflatten/flatten/distance.py @@ -102,9 +102,20 @@ def get_k_ring(faces, n_vertices, k): return k_rings -def _get_k_rings_numba(adj_flat, adj_offsets, k): +def _get_k_rings_numba(adj_flat, adj_offsets, k, n_chunks): """Compute k-ring neighbors for all vertices in parallel using Numba. + Parallelism is over **chunks**, not vertices: the vertices are split into ``n_chunks`` + contiguous blocks and the ``prange`` runs over chunks, so each iteration ``c`` owns its + own scratch row (no races) and the scratch (visited / BFS levels / touched list) is + allocated **once per chunk**, not per vertex. The previous version allocated three + O(n_vertices) arrays *inside* a per-vertex ``prange`` and collected results with an + O(n_vertices) scan per vertex; numba would not parallelise that (the per-iteration + allocations defeat the parallel analysis), so a fresh compile ran serially and was + O(n_vertices^2) -- ~16 min on a 193k-vertex mesh. This version is O(n_vertices * ring) + and parallelises. Output is identical: per-vertex neighbour indices sorted ascending, + excluding the source. + Parameters ---------- adj_flat : ndarray @@ -113,6 +124,8 @@ def _get_k_rings_numba(adj_flat, adj_offsets, k): Offsets into adj_flat for each vertex (length n_vertices + 1) k : int Number of rings + n_chunks : int + Number of parallel chunks (typically the thread count). 
Returns ------- @@ -122,34 +135,48 @@ def _get_k_rings_numba(adj_flat, adj_offsets, k): Offsets into k_rings_flat for each vertex """ n_vertices = len(adj_offsets) - 1 + chunk_size = (n_vertices + n_chunks - 1) // n_chunks - # First pass: compute sizes for each vertex - sizes = np.zeros(n_vertices, dtype=np.int64) - for v in prange(n_vertices): - visited = np.zeros(n_vertices, dtype=np.bool_) - visited[v] = True + # per-chunk scratch, allocated ONCE (not per vertex) + visited = np.zeros((n_chunks, n_vertices), dtype=np.bool_) + cur = np.empty((n_chunks, n_vertices), dtype=np.int64) + nxt = np.empty((n_chunks, n_vertices), dtype=np.int64) + touched = np.empty((n_chunks, n_vertices), dtype=np.int64) - current_level = np.empty(n_vertices, dtype=np.int64) - next_level = np.empty(n_vertices, dtype=np.int64) - current_size = 1 - current_level[0] = v - - for _ in range(k): - next_size = 0 - for i in range(current_size): - u = current_level[i] - start = adj_offsets[u] - end = adj_offsets[u + 1] - for j in range(start, end): - neighbor = adj_flat[j] - if not visited[neighbor]: - visited[neighbor] = True - next_level[next_size] = neighbor - next_size += 1 - current_level, next_level = next_level, current_level - current_size = next_size + sizes = np.zeros(n_vertices, dtype=np.int64) - sizes[v] = np.sum(visited) - 1 # -1 to exclude source vertex + # First pass: compute sizes for each vertex + for c in prange(n_chunks): + vis = visited[c] + tch = touched[c] + v_start = c * chunk_size + v_end = min(v_start + chunk_size, n_vertices) + for v in range(v_start, v_end): + cl = cur[c] + nl = nxt[c] + nt = 0 + vis[v] = True + tch[nt] = v + nt += 1 + cl[0] = v + csz = 1 + for _ in range(k): + nsz = 0 + for i in range(csz): + u = cl[i] + for j in range(adj_offsets[u], adj_offsets[u + 1]): + nb = adj_flat[j] + if not vis[nb]: + vis[nb] = True + tch[nt] = nb + nt += 1 + nl[nsz] = nb + nsz += 1 + cl, nl = nl, cl + csz = nsz + sizes[v] = nt - 1 # exclude source + for i in range(nt): + 
vis[tch[i]] = False # Build offsets for flat output offsets = np.zeros(n_vertices + 1, dtype=np.int64) @@ -159,38 +186,46 @@ def _get_k_rings_numba(adj_flat, adj_offsets, k): total_size = offsets[n_vertices] k_rings_flat = np.empty(total_size, dtype=np.int64) - # Second pass: fill k-rings - for v in prange(n_vertices): - visited = np.zeros(n_vertices, dtype=np.bool_) - visited[v] = True - - current_level = np.empty(n_vertices, dtype=np.int64) - next_level = np.empty(n_vertices, dtype=np.int64) - current_size = 1 - current_level[0] = v - - for _ in range(k): - next_size = 0 - for i in range(current_size): - u = current_level[i] - start = adj_offsets[u] - end = adj_offsets[u + 1] - for j in range(start, end): - neighbor = adj_flat[j] - if not visited[neighbor]: - visited[neighbor] = True - next_level[next_size] = neighbor - next_size += 1 - current_level, next_level = next_level, current_level - current_size = next_size - - # Collect results - out_start = offsets[v] - idx = 0 - for i in range(n_vertices): - if visited[i] and i != v: - k_rings_flat[out_start + idx] = i - idx += 1 + # Second pass: fill k-rings (collect from the touched list, exclude source, sort) + for c in prange(n_chunks): + vis = visited[c] + tch = touched[c] + v_start = c * chunk_size + v_end = min(v_start + chunk_size, n_vertices) + for v in range(v_start, v_end): + cl = cur[c] + nl = nxt[c] + nt = 0 + vis[v] = True + tch[nt] = v + nt += 1 + cl[0] = v + csz = 1 + for _ in range(k): + nsz = 0 + for i in range(csz): + u = cl[i] + for j in range(adj_offsets[u], adj_offsets[u + 1]): + nb = adj_flat[j] + if not vis[nb]: + vis[nb] = True + tch[nt] = nb + nt += 1 + nl[nsz] = nb + nsz += 1 + cl, nl = nl, cl + csz = nsz + out_start = offsets[v] + idx = 0 + for i in range(nt): + x = tch[i] + if x != v: + k_rings_flat[out_start + idx] = x + idx += 1 + # sort ascending to match the reference (pure-Python) output order + k_rings_flat[out_start : out_start + idx].sort() + for i in range(nt): + vis[tch[i]] = 
False return k_rings_flat, offsets @@ -222,8 +257,9 @@ def get_k_ring_fast(faces, n_vertices, k): for i, a in enumerate(adj): adj_offsets[i + 1] = adj_offsets[i] + len(a) - # Compute k-rings in parallel - k_rings_flat, offsets = _get_k_rings_numba(adj_flat, adj_offsets, k) + # Compute k-rings in parallel (one scratch buffer per chunk) + n_chunks = max(1, min(numba.get_num_threads(), max(1, n_vertices // 2000))) + k_rings_flat, offsets = _get_k_rings_numba(adj_flat, adj_offsets, k, n_chunks) # Convert back to list of arrays k_rings = [] From a77057153550b241118173923fbccf902d6f29af Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 16:57:39 -0700 Subject: [PATCH 31/35] Parallelize k-ring distance loop (_kring_distances_kernel) compute_kring_geodesic_distances ran the per-vertex limited Dijkstra in a serial Python list comprehension (~16s on 193k verts). Add a chunk-parallel kernel mirroring _limited_dijkstra_numba with per-chunk scratch (allocated once, only touched entries reset), producing distances in the flat k-ring layout. Bit-exact vs the serial path (max diff 0.0 over 3000 sampled verts), multi-core. Exposes get_k_ring_fast_flat to share the flat ring layout. End-to-end flatmap plot: 43s -> 27s. Helps the uncached non-angular flatten path too. 313 tests pass. Co-Authored-By: Claude Opus 4.8 --- autoflatten/flatten/distance.py | 192 +++++++++++++++++++++++++------- 1 file changed, 150 insertions(+), 42 deletions(-) diff --git a/autoflatten/flatten/distance.py b/autoflatten/flatten/distance.py index fe5bfd8..dc4c73c 100644 --- a/autoflatten/flatten/distance.py +++ b/autoflatten/flatten/distance.py @@ -110,10 +110,10 @@ def _get_k_rings_numba(adj_flat, adj_offsets, k, n_chunks): own scratch row (no races) and the scratch (visited / BFS levels / touched list) is allocated **once per chunk**, not per vertex. 
The previous version allocated three O(n_vertices) arrays *inside* a per-vertex ``prange`` and collected results with an - O(n_vertices) scan per vertex; numba would not parallelise that (the per-iteration + O(n_vertices) scan per vertex; numba would not parallelize that (the per-iteration allocations defeat the parallel analysis), so a fresh compile ran serially and was O(n_vertices^2) -- ~16 min on a 193k-vertex mesh. This version is O(n_vertices * ring) - and parallelises. Output is identical: per-vertex neighbour indices sorted ascending, + and parallelizes. Output is identical: per-vertex neighbor indices sorted ascending, excluding the source. Parameters @@ -249,17 +249,7 @@ def get_k_ring_fast(faces, n_vertices, k): list of ndarray k_ring[i] contains indices of vertices within k edges of vertex i """ - # Build adjacency list and flatten for Numba - adj = igl.adjacency_list(faces.astype(np.int64)) - - adj_flat = np.concatenate([np.array(a, dtype=np.int64) for a in adj]) - adj_offsets = np.zeros(n_vertices + 1, dtype=np.int64) - for i, a in enumerate(adj): - adj_offsets[i + 1] = adj_offsets[i] + len(a) - - # Compute k-rings in parallel (one scratch buffer per chunk) - n_chunks = max(1, min(numba.get_num_threads(), max(1, n_vertices // 2000))) - k_rings_flat, offsets = _get_k_rings_numba(adj_flat, adj_offsets, k, n_chunks) + k_rings_flat, offsets = get_k_ring_fast_flat(faces, n_vertices, k) # Convert back to list of arrays k_rings = [] @@ -271,6 +261,23 @@ def get_k_ring_fast(faces, n_vertices, k): return k_rings +def get_k_ring_fast_flat(faces, n_vertices, k): + """k-ring neighbors in flat (concatenated) form: ``(k_rings_flat, offsets)``. + + Same computation as :func:`get_k_ring_fast` but returns the flat arrays directly so + callers that also compute per-vertex distances can avoid rebuilding the layout. 
+ """ + adj = igl.adjacency_list(faces.astype(np.int64)) + adj_flat = np.concatenate([np.array(a, dtype=np.int64) for a in adj]) + adj_offsets = np.zeros(n_vertices + 1, dtype=np.int64) + for i, a in enumerate(adj): + adj_offsets[i + 1] = adj_offsets[i] + len(a) + + # Compute k-rings in parallel (one scratch buffer per chunk) + n_chunks = max(1, min(numba.get_num_threads(), max(1, n_vertices // 2000))) + return _get_k_rings_numba(adj_flat, adj_offsets, k, n_chunks) + + # ============================================================================= # Numba-accelerated Dijkstra (~8x faster) # ============================================================================= @@ -436,6 +443,107 @@ def _limited_dijkstra(v, k_ring, graph, correction): return np.array([found.get(idx, np.inf) / correction for idx in k_ring]) +@njit(parallel=True, cache=True) +def _kring_distances_kernel( + indptr, indices, data, rings_flat, offsets, correction, n_chunks +): + """Parallel limited-Dijkstra k-ring distances, in the flat ``rings_flat`` layout. + + For every vertex, computes the corrected graph distance to each of its k-ring targets. + Parallelizes over chunks: each chunk owns scratch (dist / visited / target / heap) + allocated once and resets only the touched entries between vertices. A naive parallel + port of ``_limited_dijkstra_numba`` allocated five O(n_vertices) arrays per call inside + the prange, which melts the allocator across threads; this avoids that. Output matches + the serial ``_limited_dijkstra_numba`` (same correction, same heap discipline). 
+ """ + nv = len(indptr) - 1 + n = offsets.shape[0] - 1 + INF = np.inf + + dist = np.full((n_chunks, nv), INF) + visited = np.zeros((n_chunks, nv), dtype=np.bool_) + is_target = np.zeros((n_chunks, nv), dtype=np.bool_) + touched = np.empty((n_chunks, nv), dtype=np.int64) + heap_d = np.empty((n_chunks, nv), dtype=np.float64) + heap_v = np.empty((n_chunks, nv), dtype=np.int64) + out = np.empty(offsets[n], dtype=np.float64) + + chunk_size = (n + n_chunks - 1) // n_chunks + + for c in prange(n_chunks): + d_t = dist[c] + vis = visited[c] + tgt = is_target[c] + tch = touched[c] + hd = heap_d[c] + hv = heap_v[c] + v_start = c * chunk_size + v_end = min(v_start + chunk_size, n) + + for v in range(v_start, v_end): + s = offsets[v] + e = offsets[v + 1] + m = e - s + if m == 0: + continue + + for j in range(m): + tgt[rings_flat[s + j]] = True + + nt = 0 + d_t[v] = 0.0 + tch[nt] = v + nt += 1 + hd[0] = 0.0 + hv[0] = v + hsize = 1 + found = 0 + + while hsize > 0 and found < m: + mi = 0 + md = hd[0] + for i in range(1, hsize): + if hd[i] < md: + md = hd[i] + mi = i + du = hd[mi] + u = hv[mi] + hsize -= 1 + if mi < hsize: + hd[mi] = hd[hsize] + hv[mi] = hv[hsize] + if vis[u]: + continue + vis[u] = True + if tgt[u]: + found += 1 + for p in range(indptr[u], indptr[u + 1]): + w = indices[p] + if not vis[w]: + ndist = du + data[p] + if ndist < d_t[w]: + if d_t[w] == INF: + tch[nt] = w + nt += 1 + d_t[w] = ndist + if hsize < nv: + hd[hsize] = ndist + hv[hsize] = w + hsize += 1 + + for j in range(m): + out[s + j] = d_t[rings_flat[s + j]] / correction + + for i in range(nt): + x = tch[i] + d_t[x] = INF + vis[x] = False + for j in range(m): + tgt[rings_flat[s + j]] = False + + return out + + def compute_kring_geodesic_distances( vertices, faces, k, correction=None, use_numba=True, n_threads=None, tqdm_position=0 ): @@ -483,36 +591,36 @@ def compute_kring_geodesic_distances( # Build mesh graph graph = build_mesh_graph(vertices, faces) - # Get k-ring neighbors (Numba version is ~20x 
faster) if use_numba: - k_rings = get_k_ring_fast(faces, n_vertices, k) - else: - k_rings = get_k_ring(faces, n_vertices, k) - - # Compute distances (Numba version is ~8x faster) - if use_numba: - distances = [ - _limited_dijkstra_numba( - graph.indptr, graph.indices, graph.data, v, k_rings[v], correction - ) - for v in tqdm( - range(n_vertices), - desc="Computing k-ring distances", - position=tqdm_position, - leave=True, - ) - ] - else: - distances = [ - _limited_dijkstra(v, k_rings[v], graph, correction) - for v in tqdm( - range(n_vertices), - desc="Computing k-ring distances", - position=tqdm_position, - leave=True, - ) - ] - + # Fully parallel path: build k-rings and distances in flat form with per-chunk + # scratch, then reconstruct the per-vertex lists. ~Ncore faster than the previous + # serial per-vertex Dijkstra loop. + rings_flat, offsets = get_k_ring_fast_flat(faces, n_vertices, k) + n_chunks = max(1, min(numba.get_num_threads(), max(1, n_vertices // 2000))) + dist_flat = _kring_distances_kernel( + graph.indptr, + graph.indices, + graph.data, + rings_flat, + offsets, + correction, + n_chunks, + ) + k_rings = [rings_flat[offsets[v] : offsets[v + 1]] for v in range(n_vertices)] + distances = [dist_flat[offsets[v] : offsets[v + 1]] for v in range(n_vertices)] + return k_rings, distances + + # Pure-Python fallback + k_rings = get_k_ring(faces, n_vertices, k) + distances = [ + _limited_dijkstra(v, k_rings[v], graph, correction) + for v in tqdm( + range(n_vertices), + desc="Computing k-ring distances", + position=tqdm_position, + leave=True, + ) + ] return k_rings, distances From a074a8cc729b0dbd2e10a256075a064aa5775eba Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 17:00:42 -0700 Subject: [PATCH 32/35] Vectorize per-vertex distortion loop in viz.compute_kring_distortion Replace the 193k-iteration Python loop with a flattened numpy computation (concatenate neighbor lists, segment-sum via bincount). 
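A small sketch of the segment-sum idea, assuming toy inputs: `distortion_loop` and `distortion_vectorized` are hypothetical names, and the points, rings and target distances are random placeholders, not benchmark data. It compares the per-vertex loop against the flattened `np.bincount` form.

```python
import numpy as np


def distortion_loop(xy, k_rings, targets):
    """Reference: 100 * mean(|d_2D - d_3D|) / mean(d_3D), one vertex at a time."""
    out = np.zeros(len(k_rings))
    for v, (nbrs, tgt) in enumerate(zip(k_rings, targets)):
        if len(nbrs) == 0 or tgt.mean() <= 0.0:
            continue  # no neighbors or all-zero targets: distortion defined as 0
        d2d = np.linalg.norm(xy[nbrs] - xy[v], axis=1)
        out[v] = 100.0 * np.mean(np.abs(d2d - tgt)) / tgt.mean()
    return out


def distortion_vectorized(xy, k_rings, targets):
    """Same quantity via flattened neighbor lists and bincount segment sums."""
    n = len(k_rings)
    counts = np.array([len(r) for r in k_rings], dtype=np.int64)
    src = np.repeat(np.arange(n), counts)        # source vertex of each pair
    nbr = np.concatenate(k_rings).astype(np.int64)
    tgt = np.concatenate(targets)
    d2d = np.linalg.norm(xy[nbr] - xy[src], axis=1)
    sum_abs = np.bincount(src, weights=np.abs(d2d - tgt), minlength=n)
    sum_tgt = np.bincount(src, weights=tgt, minlength=n)
    out = np.zeros(n)
    valid = sum_tgt > 0.0
    out[valid] = 100.0 * sum_abs[valid] / sum_tgt[valid]
    return out


rng = np.random.default_rng(0)
xy = rng.normal(size=(5, 2))
k_rings = [
    np.array([1, 2]),
    np.array([0, 2, 3]),
    np.array([], dtype=np.int64),  # isolated vertex, exercises the empty case
    np.array([4]),
    np.array([3]),
]
targets = [rng.uniform(0.5, 1.5, size=len(r)) for r in k_rings]

loop_result = distortion_loop(xy, k_rings, targets)
vec_result = distortion_vectorized(xy, k_rings, targets)
```

Both forms agree up to floating-point rounding, because dividing the two means by the same neighbor count leaves the ratio of sums.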
Per-vertex distortion == 100 * sum_abs / sum_target since the neighbor count cancels. Matches the old loop (max per-vertex diff 8.5e-14). The loop is now ~0.2s. With the parallel k-ring (ring builder + distances), end-to-end flatmap plot is ~22.6s (was 43s). 313 tests pass. Co-Authored-By: Claude Opus 4.8 --- autoflatten/viz.py | 56 +++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/autoflatten/viz.py b/autoflatten/viz.py index 24f74c0..04e524a 100644 --- a/autoflatten/viz.py +++ b/autoflatten/viz.py @@ -203,39 +203,35 @@ def compute_kring_distortion( if verbose: print("Computing per-vertex distortion...") - # Compute per-vertex distortion - vertex_distortion = np.zeros(n_patch_vertices) - total_abs_error = 0.0 - total_target = 0.0 - - for v in range(n_patch_vertices): - neighbors = k_rings[v] - targets = target_distances[v] - - if len(neighbors) == 0: - vertex_distortion[v] = 0.0 - continue - - # Compute 2D Euclidean distances to neighbors - d_2d = np.linalg.norm(xy[neighbors] - xy[v], axis=1) + # Compute per-vertex distortion (vectorized over the flattened neighbor lists). + # Per-vertex: 100 * mean(|d_2D - d_3D|) / mean(d_3D) == 100 * sum_abs / sum_target, + # since the per-vertex neighbor count cancels. 
+ counts = np.fromiter( + (len(r) for r in k_rings), dtype=np.int64, count=n_patch_vertices + ) + src = np.repeat(np.arange(n_patch_vertices), counts) + nbr = ( + np.concatenate(k_rings).astype(np.int64) + if counts.sum() + else np.empty(0, np.int64) + ) + tgt = ( + np.concatenate(target_distances) + if counts.sum() + else np.empty(0, dtype=np.float64) + ) - # Compute per-vertex distortion: 100 * mean(|d_2D - d_3D|) / mean(d_3D) - mean_target = np.mean(targets) - if mean_target > 0.0: - abs_errors = np.abs(d_2d - targets) - vertex_distortion[v] = 100.0 * np.mean(abs_errors) / mean_target + d_2d = np.linalg.norm(xy[nbr] - xy[src], axis=1) + abs_err = np.abs(d_2d - tgt) + sum_abs = np.bincount(src, weights=abs_err, minlength=n_patch_vertices) + sum_tgt = np.bincount(src, weights=tgt, minlength=n_patch_vertices) - # Accumulate for global mean - total_abs_error += np.sum(abs_errors) - total_target += np.sum(targets) - else: - # If all target distances are zero, define local distortion as zero - vertex_distortion[v] = 0.0 + vertex_distortion = np.zeros(n_patch_vertices) + valid = sum_tgt > 0.0 + vertex_distortion[valid] = 100.0 * sum_abs[valid] / sum_tgt[valid] - # Global mean distortion (same formula as autoflatten) - mean_distortion = ( - 100.0 * total_abs_error / total_target if total_target > 0 else 0.0 - ) + total_target = sum_tgt.sum() + mean_distortion = 100.0 * sum_abs.sum() / total_target if total_target > 0 else 0.0 return vertex_distortion, mean_distortion From f53c77b658fa713248a3e713295d03233593fd34 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 17:07:55 -0700 Subject: [PATCH 33/35] Add optimal-scale normalization + signed distortion to flatmap overlay The flatmap distortion overlay now (by default in plot_flatmap): - rescales by the distance-optimal global scale s* before scoring, so a global scale offset no longer inflates the per-vertex map (shows true local shape distortion; s* is reported in the panel title); 
- shows SIGNED relative distortion with a diverging colormap (RdBu_r): + = flatmap stretched, - = compressed, instead of magnitude-only. compute_kring_distortion gains optimal_scale/signed/return_opt_scale flags; defaults preserve the original magnitude/raw-scale behavior (viz tests and the scaling-sensitivity test unchanged). 313 tests pass. Co-Authored-By: Claude Opus 4.8 --- autoflatten/viz.py | 147 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 113 insertions(+), 34 deletions(-) diff --git a/autoflatten/viz.py b/autoflatten/viz.py index 04e524a..7c552c3 100644 --- a/autoflatten/viz.py +++ b/autoflatten/viz.py @@ -118,6 +118,9 @@ def compute_kring_distortion( orig_indices, k=2, n_samples_per_ring=None, + optimal_scale=False, + signed=False, + return_opt_scale=False, verbose=True, ): """Compute per-vertex metric distortion using k-ring geodesic distances. @@ -145,15 +148,26 @@ def compute_kring_distortion( n_samples_per_ring : int or None Angular samples per ring. If None, use all neighbors without angular sampling (default: None, faster). Use 12 for pyflatten-style sampling. + optimal_scale : bool + If True, rescale the flatmap by the global scale that minimizes distortion + before scoring, so a global scale offset does not inflate the per-vertex map + (default: False, preserves the raw-scale metric). + signed : bool + If True, return signed relative distortion per vertex (+ stretched, - compressed) + instead of magnitude (default: False). + return_opt_scale : bool + If True, also return the optimal scale as a third value (default: False). 
verbose : bool Print progress messages Returns ------- vertex_distortion : ndarray of shape (N,) - Percentage distortion at each vertex + Per-vertex distortion (magnitude %, or signed % if ``signed``) mean_distortion : float - Overall mean percentage distortion (same formula as autoflatten) + Overall mean magnitude distortion (same formula as autoflatten) + opt_scale : float + Distance-optimal scale (only if ``return_opt_scale``) """ n_patch_vertices = len(xy) @@ -222,17 +236,43 @@ def compute_kring_distortion( ) d_2d = np.linalg.norm(xy[nbr] - xy[src], axis=1) - abs_err = np.abs(d_2d - tgt) - sum_abs = np.bincount(src, weights=abs_err, minlength=n_patch_vertices) - sum_tgt = np.bincount(src, weights=tgt, minlength=n_patch_vertices) - vertex_distortion = np.zeros(n_patch_vertices) + # Optional: rescale the flatmap by the single global scale s* that minimizes distortion + # (the distance-optimal scale), so a global scale offset does not inflate every vertex + # and the map shows true local shape distortion. + opt_scale = 1.0 + if optimal_scale and d_2d.size: + pos = tgt > 0 + if pos.any(): + d, t = d_2d[pos], tgt[pos] + scales = np.linspace(0.85, 1.20, 71) + errs = np.array([np.mean(np.abs(s * d - t) / t) for s in scales]) + opt_scale = float(scales[int(np.argmin(errs))]) + d_scaled = opt_scale * d_2d + + sum_tgt = np.bincount(src, weights=tgt, minlength=n_patch_vertices) + sum_abs = np.bincount( + src, weights=np.abs(d_scaled - tgt), minlength=n_patch_vertices + ) valid = sum_tgt > 0.0 - vertex_distortion[valid] = 100.0 * sum_abs[valid] / sum_tgt[valid] + # Summary mean is always the magnitude distortion (for the figure subtitle). total_target = sum_tgt.sum() mean_distortion = 100.0 * sum_abs.sum() / total_target if total_target > 0 else 0.0 + vertex_distortion = np.zeros(n_patch_vertices) + if signed: + # Signed relative distortion: + = flatmap stretched (2D longer than 3D), + # - = compressed. Same tgt-weighting as the magnitude metric. 
+ sum_d = np.bincount(src, weights=d_scaled, minlength=n_patch_vertices) + vertex_distortion[valid] = ( + 100.0 * (sum_d[valid] - sum_tgt[valid]) / sum_tgt[valid] + ) + else: + vertex_distortion[valid] = 100.0 * sum_abs[valid] / sum_tgt[valid] + + if return_opt_scale: + return vertex_distortion, mean_distortion, opt_scale return vertex_distortion, mean_distortion @@ -308,6 +348,7 @@ def plot_flatmap( show_boundary=True, distortion_cmap="viridis", distance_method="fast", + signed=True, dpi=150, ): """ @@ -394,13 +435,16 @@ def plot_flatmap( "Must be 'fast' or 'pyflatten'." ) - vertex_dist, mean_dist = compute_kring_distortion( + vertex_dist, mean_dist, opt_scale = compute_kring_distortion( xy, base_vertices, base_faces, orig_indices, k=k, n_samples_per_ring=n_samples, + optimal_scale=True, + signed=signed, + return_opt_scale=True, verbose=True, ) @@ -506,27 +550,40 @@ def plot_flatmap( # Center plot: Per-vertex metric distortion (percentage) ax = axes[1] - # Fixed color limits: 0-100% - vmin = 0 - vmax = 100 + # Signed -> diverging colormap with symmetric limits (robust 98th pct of |value|); + # magnitude -> fixed 0-100% with the sequential colormap. 
+ if signed: + vlim = ( + float(np.percentile(np.abs(vertex_dist), 98)) if vertex_dist.size else 1.0 + ) + vlim = max(vlim, 1.0) + vmin, vmax = -vlim, vlim + cmap = "RdBu_r" + cbar_label = "Signed distortion (%) (+ stretched / − compressed)" + center_title = f"Signed Distortion ({k}-ring, ×{opt_scale:.3f})" + else: + vmin, vmax = 0, 100 + cmap = distortion_cmap + cbar_label = "Distortion (%)" + center_title = f"Metric Distortion ({k}-ring, ×{opt_scale:.3f})" # Use tripcolor with vertex values for smooth interpolation tpc = ax.tripcolor( triang, vertex_dist, shading="gouraud", - cmap=distortion_cmap, + cmap=cmap, vmin=vmin, vmax=vmax, ) # Create colorbar - fig.colorbar(tpc, ax=ax, label="Distortion (%)", shrink=0.8) + fig.colorbar(tpc, ax=ax, label=cbar_label, shrink=0.8) ax.set_aspect("equal") ax.set_xlabel("X (mm)") ax.set_ylabel("Y (mm)") - ax.set_title(f"Metric Distortion ({k}-ring)") + ax.set_title(center_title) # Right plot: Histogram of distortion distribution ax = axes[2] @@ -537,35 +594,57 @@ def plot_flatmap( hist, bin_edges = np.histogram(vertex_dist_clipped, bins=n_bins, range=(vmin, vmax)) bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 - # Color bars by distortion value using same colormap + # Color bars by distortion value using same colormap as the map panel norm = plt.Normalize(vmin=vmin, vmax=vmax) - cmap_obj = plt.colormaps[distortion_cmap] + cmap_obj = plt.colormaps[cmap] colors = cmap_obj(norm(bin_centers)) ax.bar( bin_centers, hist, width=np.diff(bin_edges)[0], color=colors, edgecolor="none" ) - # Add mean and median lines (use weighted mean_dist from compute_kring_distortion) - median_dist = np.median(vertex_dist) - ax.axvline( - x=mean_dist, - color="black", - linestyle="--", - linewidth=2, - label=f"Mean: {mean_dist:.1f}%", - ) - ax.axvline( - x=median_dist, - color="gray", - linestyle=":", - linewidth=1.5, - label=f"Median: {median_dist:.1f}%", - ) + if signed: + # 0 reference + mean signed; the magnitude mean is reported in the 
subtitle. + mean_signed = float(np.mean(vertex_dist)) + ax.axvline(x=0.0, color="black", linewidth=1.0) + ax.axvline( + x=mean_signed, + color="black", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_signed:+.1f}%", + ) + ax.axvline( + x=float(np.median(vertex_dist)), + color="gray", + linestyle=":", + linewidth=1.5, + label=f"Median: {np.median(vertex_dist):+.1f}%", + ) + xlabel = "Signed distortion (%)" + dist_title = f"Signed Distortion ({k}-ring)" + else: + median_dist = np.median(vertex_dist) + ax.axvline( + x=mean_dist, + color="black", + linestyle="--", + linewidth=2, + label=f"Mean: {mean_dist:.1f}%", + ) + ax.axvline( + x=median_dist, + color="gray", + linestyle=":", + linewidth=1.5, + label=f"Median: {median_dist:.1f}%", + ) + xlabel = "Distortion (%)" + dist_title = f"Distortion Distribution ({k}-ring)" - ax.set_xlabel("Distortion (%)") + ax.set_xlabel(xlabel) ax.set_ylabel("Vertex Count") - ax.set_title(f"Distortion Distribution ({k}-ring)") + ax.set_title(dist_title) ax.legend(loc="upper right", fontsize=8) ax.set_xlim(vmin, vmax) From 2acdd60ce891d77faa7ba949f833ad3bd6dfd3c7 Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 17:29:45 -0700 Subject: [PATCH 34/35] Distance-optimal output scale: make saved flatmaps metrically faithful (default on) The flattener's final scale was area-matched (s = sqrt(orig_area/total_area)), a display convention that leaves the map ~6% too small vs true geodesic distances (validated: opt scale 1.063 +/- 0.008 across 6 hemispheres, expand). Add distance_optimal_scale: after optimization, rescale the output by the single global scale that minimizes true-geodesic distortion, computed from a deterministic heat-method geodesic sample on the fiducial patch. So the surface and the flat map are directly comparable (the use case: viewing inflated + flattened together). 
New DistanceOptimalScaleConfig (enabled by default, n_sources=200, seed=0); distance_optimal_scale() in distance.py; applied in SurfaceFlattener.run() before final stats. Deterministic (fixed seed; heat method is deterministic). End-to-end: s=1.055 on sub-022 lh, true distortion 13.87% -> 13.23% (-0.64pp). viz.plot_flatmap now scores at the true (already-faithful) output scale instead of re-normalizing by the gameable local k-ring optimum. 313 tests pass. Co-Authored-By: Claude Opus 4.8 --- autoflatten/flatten/algorithm.py | 23 ++++++++++++ autoflatten/flatten/config.py | 40 +++++++++++++++++++++ autoflatten/flatten/distance.py | 62 ++++++++++++++++++++++++++++++++ autoflatten/viz.py | 12 ++++--- 4 files changed, 132 insertions(+), 5 deletions(-) diff --git a/autoflatten/flatten/algorithm.py b/autoflatten/flatten/algorithm.py index 837567c..d4741b9 100644 --- a/autoflatten/flatten/algorithm.py +++ b/autoflatten/flatten/algorithm.py @@ -22,6 +22,7 @@ from .distance import ( compute_kring_geodesic_distances, compute_kring_geodesic_distances_angular, + distance_optimal_scale, ) from .energy import ( compute_2d_areas, @@ -1837,6 +1838,28 @@ def run(self, snapshot_callback: Callable | None = None) -> np.ndarray: snapshot_callback=_wrap_callback(snapshot_callback, "smoothing"), ) + # Distance-optimal output scale: replace the area-matched display scale with the + # single global scale that minimizes true-geodesic distance distortion, so the saved + # flat map is metrically faithful (surface and flat map are directly comparable). 
+ dos = config.distance_optimal_scale + if dos.enabled: + ref_vertices = ( + self.fiducial_vertices + if self.fiducial_vertices is not None + else self.vertices + ) + s_opt = distance_optimal_scale( + ref_vertices, + self.faces, + uv, + n_sources=dos.n_sources, + seed=dos.seed, + ) + centroid = uv.mean(axis=0) + uv = (uv - centroid) * s_opt + centroid + if verbose: + print(f"Distance-optimal output scale: x{s_opt:.4f}") + # Final stats uv_jax = jnp.asarray(uv) n_flipped_final = int(count_flipped_triangles(uv_jax, self.faces_jax)) diff --git a/autoflatten/flatten/config.py b/autoflatten/flatten/config.py index c1539d4..8345e64 100644 --- a/autoflatten/flatten/config.py +++ b/autoflatten/flatten/config.py @@ -218,6 +218,34 @@ class FinalNegativeAreaRemovalConfig: iters_per_level: int = 30 +@dataclass +class DistanceOptimalScaleConfig: + """Configuration for the distance-optimal output scale. + + The flattening's area-matching final scale (``s = sqrt(orig_area/total_area)``) is a + display convention, not part of the objective; it leaves the map ~6% too small versus + true geodesic distances. When enabled (default), after optimization the output is + rescaled by the single global scale that minimizes true-geodesic distance distortion, + computed from a heat-method geodesic sample on the patch, so the saved flatmap is + metrically faithful (surface and flat map are directly comparable). The optimum is tight + across subjects (~1.06, std ~0.008) and reduces global distance distortion on every + benchmark hemisphere. + + Attributes + ---------- + enabled : bool + Whether to apply the distance-optimal rescale (default: True). + n_sources : int + Number of heat-geodesic source vertices sampled (deterministic). + seed : int + RNG seed for source sampling (keeps the result deterministic). + """ + + enabled: bool = True + n_sources: int = 200 + seed: int = 0 + + def _default_phases() -> list[PhaseConfig]: """Return default optimization phases matching FreeSurfer's 3 epochs. 
@@ -294,6 +322,9 @@ class FlattenConfig: spring_smoothing: SpringSmoothingConfig = field( default_factory=SpringSmoothingConfig ) + distance_optimal_scale: DistanceOptimalScaleConfig = field( + default_factory=DistanceOptimalScaleConfig + ) phases: list[PhaseConfig] = field(default_factory=_default_phases) print_every: int = 100 verbose: bool = True @@ -341,6 +372,11 @@ def to_dict(self) -> dict: "max_step_mm": self.spring_smoothing.max_step_mm, "enabled": self.spring_smoothing.enabled, }, + "distance_optimal_scale": { + "enabled": self.distance_optimal_scale.enabled, + "n_sources": self.distance_optimal_scale.n_sources, + "seed": self.distance_optimal_scale.seed, + }, "phases": [ { "name": p.name, @@ -376,6 +412,9 @@ def from_dict(cls, data: dict) -> "FlattenConfig": **data.get("final_negative_area_removal", {}) ) spring_smoothing = SpringSmoothingConfig(**data.get("spring_smoothing", {})) + distance_optimal_scale = DistanceOptimalScaleConfig( + **data.get("distance_optimal_scale", {}) + ) phases_data = data.get("phases", _default_phases()) phases = [ p if isinstance(p, PhaseConfig) else PhaseConfig(**p) for p in phases_data @@ -387,6 +426,7 @@ def from_dict(cls, data: dict) -> "FlattenConfig": negative_area_removal=negative_area_removal, final_negative_area_removal=final_negative_area_removal, spring_smoothing=spring_smoothing, + distance_optimal_scale=distance_optimal_scale, phases=phases, print_every=data.get("print_every", 100), verbose=data.get("verbose", True), diff --git a/autoflatten/flatten/distance.py b/autoflatten/flatten/distance.py index dc4c73c..b384703 100644 --- a/autoflatten/flatten/distance.py +++ b/autoflatten/flatten/distance.py @@ -62,6 +62,68 @@ def build_mesh_graph(vertices, faces): return sparse.csr_matrix((data, (row, col)), shape=(n_vertices, n_vertices)) +def distance_optimal_scale(vertices, faces, uv, n_sources=200, seed=0, radius=None): + """Global scale ``s*`` that makes a flat map metrically match true geodesic distances. 
+ + Samples ``n_sources`` source vertices (deterministically), computes their true geodesic + fields on the 3D patch with the heat method, and returns the single scale ``s`` that + minimizes ``mean(|s*d_2d - d_geo| / d_geo)`` over all source->target pairs (optionally + capped at ``radius`` mm). Used to replace the area-matched display scale with a + distance-faithful one (the area-matched map is ~6% too small; ``s*`` ~ 1.06). + + Parameters + ---------- + vertices : ndarray (V, 3) + 3D patch vertex positions (use the fiducial surface for anatomical distances). + faces : ndarray (F, 3) + Patch face indices. + uv : ndarray (V, 2) + 2D flat-map coordinates (same vertex order as ``vertices``). + n_sources : int + Number of heat-geodesic sources to sample. + seed : int + RNG seed (keeps the result deterministic). + radius : float or None + If set, only score pairs within this geodesic distance (mm). + + Returns + ------- + float + Distance-optimal scale (1.0 if it cannot be computed). + """ + v = np.ascontiguousarray(vertices, dtype=np.float64) + f = np.ascontiguousarray(faces, dtype=np.int64) + uv = np.ascontiguousarray(uv, dtype=np.float64) + n_v = v.shape[0] + if n_v == 0: + return 1.0 + + rng = np.random.default_rng(seed) + srcs = np.sort(rng.choice(n_v, size=min(n_sources, n_v), replace=False)) + + data = igl.HeatGeodesicsData() + igl.heat_geodesics_precompute(v, f, data) + + d2_all, dg_all = [], [] + for s in srcs: + geo = igl.heat_geodesics_solve(data, np.array([s], dtype=np.int64)) + mask = geo > 1e-6 + if radius is not None: + mask &= geo <= radius + if not np.any(mask): + continue + d2_all.append(np.linalg.norm(uv[mask] - uv[s], axis=1)) + dg_all.append(geo[mask]) + + if not d2_all: + return 1.0 + d2 = np.concatenate(d2_all) + dg = np.concatenate(dg_all) + scales = np.linspace(0.85, 1.20, 71) + errs = np.array([np.mean(np.abs(sc * d2 - dg) / dg) for sc in scales]) + return float(scales[int(np.argmin(errs))]) + + def get_k_ring(faces, n_vertices, k): """Get 
k-ring neighbors for each vertex. diff --git a/autoflatten/viz.py b/autoflatten/viz.py index 7c552c3..c852255 100644 --- a/autoflatten/viz.py +++ b/autoflatten/viz.py @@ -435,16 +435,18 @@ def plot_flatmap( "Must be 'fast' or 'pyflatten'." ) - vertex_dist, mean_dist, opt_scale = compute_kring_distortion( + # The flattener now applies the distance-optimal scale at generation, so the saved + # flat map is already metrically faithful -- score at its true scale (no local + # re-normalization, which uses the gameable local k-ring optimum and points the wrong + # way). Pass optimal_scale=True only to diagnose an old, un-rescaled patch. + vertex_dist, mean_dist = compute_kring_distortion( xy, base_vertices, base_faces, orig_indices, k=k, n_samples_per_ring=n_samples, - optimal_scale=True, signed=signed, - return_opt_scale=True, verbose=True, ) @@ -560,12 +562,12 @@ def plot_flatmap( vmin, vmax = -vlim, vlim cmap = "RdBu_r" cbar_label = "Signed distortion (%) (+ stretched / − compressed)" - center_title = f"Signed Distortion ({k}-ring, ×{opt_scale:.3f})" + center_title = f"Signed Distortion ({k}-ring)" else: vmin, vmax = 0, 100 cmap = distortion_cmap cbar_label = "Distortion (%)" - center_title = f"Metric Distortion ({k}-ring, ×{opt_scale:.3f})" + center_title = f"Metric Distortion ({k}-ring)" # Use tripcolor with vertex values for smooth interpolation tpc = ax.tripcolor( From 298881361aca50e84f14ee687253d596655a319b Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Thu, 11 Jun 2026 17:31:04 -0700 Subject: [PATCH 35/35] =?UTF-8?q?Findings=20=C2=A716:=20distance-optimal?= =?UTF-8?q?=20output=20scale=20shipped=20(multi-hemi=20validated)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Document the shipped distance-optimal output scale: true-geodesic optimum 1.063 +/- 0.008 (expand) across 6 hemispheres, lowers global distortion on every one; replaces area-matching as the default. 
Note the sign correction (local k-ring optimum points the wrong way, ~0.95 shrink, due to target compaction). Update executive summary + Next-ideas (marked shipped). Co-Authored-By: Claude Opus 4.8 --- benchmark/FINDINGS.md | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/benchmark/FINDINGS.md b/benchmark/FINDINGS.md index 131ba03..c3d2407 100644 --- a/benchmark/FINDINGS.md +++ b/benchmark/FINDINGS.md @@ -24,7 +24,8 @@ provenance) is the ledger at `/data2/projects/autoflatten/ledger/experiments.jso - **Where the implementation drifts from that objective:** the `Dijkstra/1.207` target correction is slightly too compact and `scale_to_area` is not distance-optimal (§9d) — but recalibrating the correction is **not a robust default** (subject-specific optimum, unpredictable from local geometry; - §10–11). The only consistently-safe tweak is a distance-optimal output scale (small). + §10–11). The only consistently-safe tweak is a distance-optimal output scale — now **shipped as + the default** (§16: true-geodesic `s*≈1.06`, multi-hemi validated). - **Net:** the pipeline is well-tuned; config-lever gains on distance error are small and non-robust. The one remaining path to a *larger robust* reduction is **long-range geodesic anchors in the energy** (the paper's own argument) — a real energy change, left as future work. @@ -559,6 +560,30 @@ faster projection (the geodesic refinement is also the ~33 s/hemi time sink, §1 smoothing (2) is available (`project_python(morph_close=n)`) and trades a small distortion cost for fewer flips if a flip-free patch is ever needed, but is not the default. +## 16. Distance-optimal output scale — shipped as the default + +§9d found the area-matched final scale (`s = √(orig_area/total_area)`) is a display +convention that leaves the map metrically off; §9 left "distance-optimal output scale" as a +small, safe, unshipped lever. Now validated multi-hemi and shipped. 
+ +**Multi-hemi validation** (true-geodesic optimum on 6 hemispheres, fresh heat-method ref per +hemi): the area-matched map is consistently **~6% too small** — opt scale **1.063 ± 0.008** +(range 1.050–1.075, *expand*), tight across subjects (so **not** subject-specific in a way that +would make a single correction unsafe), and it lowers global true distortion on **every** +hemisphere. (Note: §9's single-hemi estimate was ~1.04; the broader sample lands at ~1.06.) + +**Important sign correction:** the *local k-ring* optimum points the **wrong way** (~0.95–0.98, +*shrink*) because of the `Dijkstra/1.207` target compaction (§9a/§10). The faithful (true +geodesic) optimum is *expand*. The flatmap viz was briefly normalizing by the local optimum — +fixed to score at the true output scale instead. + +**Shipped:** `DistanceOptimalScaleConfig` (enabled by default), `distance.distance_optimal_scale()` +(deterministic heat-method sample, fixed seed), applied in `SurfaceFlattener.run()` after the +final NAR/smoothing. End-to-end on sub-022 lh: `s=1.055`, true global distortion 13.87% → 13.23% +(−0.64 pp), deterministic on rerun. The saved flat map is now directly comparable to the surface +(the use case: viewing inflated + flattened together). Disable with +`config.distance_optimal_scale.enabled = False`. + ## Method note Determinism confirmed bit-identical across reruns, so a single run per experiment is sound. @@ -575,7 +600,7 @@ Remaining, in priority order (config levers are exhausted — see §10–11): the global metric (Fischl 1999's own point that long-range distances are needed to unfold). Requires modifying the energy/optimizer (not a config knob) and weight-tuning; validate with train/test-split geodesic sources across subjects. Uncertain but the only path to a larger robust reduction. 
-- **Distance-optimal output scale** (small, safe): replace `scale_to_area` with the scale that minimizes - distance distortion (a 1-parameter minimization, ≥0 by construction; `s*`≈0.995–1.02). Cheap to ship. +- ~~**Distance-optimal output scale**~~ — **SHIPPED** (§16): distance-optimal scale (true-geodesic, + `s*≈1.06`, default on) now replaces the area-matched display scale; multi-hemi validated. - **Ship the validated speed defaults**: Tutte init as default `initial_projection` with initial NAR off (§1), plus the §7 lean line-search / sparse k-ring levers.
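
The distance-optimal rescale marked as shipped above can be sketched in isolation. This is a minimal illustration, not the shipped implementation: it assumes the same grid as the patch (71 scales between 0.85 and 1.20) and the same relative-error objective, but it uses synthetic distances in place of heat-method geodesics. The helper names `grid_search_scale` and `rescale_about_centroid` are hypothetical.

```python
import numpy as np


def grid_search_scale(d_flat, d_geo, lo=0.85, hi=1.20, n=71):
    """Return the grid scale s minimizing mean(|s * d_flat - d_geo| / d_geo)."""
    d_flat = np.asarray(d_flat, dtype=np.float64)
    d_geo = np.asarray(d_geo, dtype=np.float64)
    keep = d_geo > 0  # relative error is undefined for zero reference distances
    if not keep.any():
        return 1.0  # nothing to score: leave the map unscaled
    d_flat, d_geo = d_flat[keep], d_geo[keep]
    scales = np.linspace(lo, hi, n)
    errs = np.array([np.mean(np.abs(s * d_flat - d_geo) / d_geo) for s in scales])
    return float(scales[int(np.argmin(errs))])


def rescale_about_centroid(uv, s):
    """Scale 2D coordinates by s while keeping their centroid fixed."""
    uv = np.asarray(uv, dtype=np.float64)
    centroid = uv.mean(axis=0)
    return (uv - centroid) * s + centroid


# Toy data (assumption, not benchmark data): a flat map that is uniformly
# 5% too small relative to the reference distances.
rng = np.random.default_rng(0)
d_geo = rng.uniform(1.0, 50.0, size=1000)
d_flat = 0.95 * d_geo
s_opt = grid_search_scale(d_flat, d_geo)

uv = rng.normal(size=(100, 2))
uv_scaled = rescale_about_centroid(uv, s_opt)
```

Because the search is over a fixed grid, the recovered scale is only accurate to the grid step (0.005 here), and a true optimum outside 0.85 to 1.20 would be clamped to the nearest end of the range.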