Skip to content

FIX: Parallelize DKI Fit/Predict for Usable Multi-Shell Motion Estimation#443

Open
oesteban wants to merge 3 commits into
nipreps:mainfrom
oesteban:fix/dki-speed
Open

FIX: Parallelize DKI Fit/Predict for Usable Multi-Shell Motion Estimation#443
oesteban wants to merge 3 commits into
nipreps:mainfrom
oesteban:fix/dki-speed

Conversation

@oesteban

@oesteban oesteban commented Jun 9, 2026

Copy link
Copy Markdown
Member

Summary

Used as the head-motion/eddy estimator for dMRI, the DKI model was effectively unusable on real multi-shell data (~8.6 vol/h ⇒ ~32 h/series, appearing hung), while DTI on the same data runs in minutes. This PR makes DKI parallelize correctly, recovering a ~5.8× speedup (more with more workers) with numerically identical results.

Resolves: #442

Root cause

DKI was hard-routed onto the serial full-brain fit path (if n_jobs == 1 or is_dki:), so per-volume cost grew linearly with voxel count while DTI's np.array_split + joblib path stayed flat (measured 6×→40×→119× slower as voxels grew 2k→20k→80k). The serial special-case existed for a real reason: DKI's MultiVoxelFit cannot be pickled across loky's process boundary, so naively reusing DTI's fit/predict-split path crashes (BrokenProcessPool / RecursionError). BLAS thread-capping was ruled out as a latency factor (no effect — the fit is GIL-bound per-voxel Python work). Full diagnosis with timings in #442.

Fix

Add an exact, in-worker parallel path for models whose fitted object cannot cross a process boundary:

  • New _exec_fit_predict(...) worker: builds the model, fits its voxel chunk, predicts the held-out gradient, and returns only the predicted ndarray — no model instance is ever serialized.
  • New BaseDWIModel._fit_predict_chunked(index, n_jobs): splits voxels (and the aligned S0) across workers, runs the worker per chunk, and re-assembles the prediction. Because voxel-wise fitting is independent, the result is bit-identical to the serial path.
  • fit_predict routes to this path when not self._picklable_fit and n_jobs > 1. DKIModel sets _picklable_fit = False; all other models keep the existing behavior.
  • Removed the unreachable elif is_dki: (dead model.multi_fit branch) and factored the shared data prep into _lovo_data.

The DTI fit/predict-split path is untouched. DKI at n_jobs == 1, direct _fit calls, and single_fit are unchanged.

Benchmark (synthetic multi-shell, b0 + 3 shells, n_jobs=8)

voxels DKI serial (before) DKI parallel (after) speedup
20,000 19.3 s/vol 3.3 s/vol 5.8×

Per-volume output matches the serial path exactly (max|diff| = 0).

Tests

  • New test_dki_parallel_matches_serial (index ∈ {4,9} × use_mask ∈ {False,True}) asserts fit_predict(index, n_jobs=4) (chunked) equals fit_predict(index, n_jobs=1) (serial).
  • Full test_model_dmri.py DKI/DTI suite green (132 passed).

Notes

  • single_fit (fit-once, predict-each) could cut runtime further but is approximate and currently buggy for n_jobs > 1; left for a follow-up.
  • CHANGES.rst intentionally untouched (release-generated).

DKI was forced onto the serial full-brain fit path, making volume-to-volume
motion estimation grow linearly with voxel count (~120x slower than DTI on
multi-shell data; ~32 h/series, effectively hung). Its fitted MultiVoxelFit
object cannot be pickled across loky's process boundary, so the fit/predict-
split parallelization DTI uses crashes for DKI.

Add an exact in-worker fit+predict path (_fit_predict_chunked) for models
flagged `_picklable_fit = False`: each worker fits its voxel chunk and predicts
the held-out gradient, returning only the (picklable) predicted array. Voxel-
wise fitting is independent, so results are numerically identical to the serial
path (max|diff| = 0) at ~5.8x speedup (n_jobs=8). Also remove the unreachable
`elif is_dki` branch and factor out _lovo_data.

Resolves: nipreps#442

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@codecov

codecov Bot commented Jun 9, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 81.08108% with 7 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@85e3a45). Learn more about missing BASE report.
⚠️ Report is 94 commits behind head on main.

Files with missing lines Patch % Lines
src/nifreeze/model/dmri.py 81.08% 6 Missing and 1 partial ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #443   +/-   ##
=======================================
  Coverage        ?   84.83%           
=======================================
  Files           ?       37           
  Lines           ?     2183           
  Branches        ?      245           
=======================================
  Hits            ?     1852           
  Misses          ?      295           
  Partials        ?       36           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Record (niprepsgh-442) that the in-worker chunking for DKI can be replaced by
DIPY's own multi_voxel engine path once a release forwards `engine` through
`DKIModel.fit` and stops leaking orchestration kwargs into the per-voxel
kernel (both broken in DIPY <= 1.10.0; arokem's suggested route).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@oesteban

oesteban commented Jun 9, 2026

Copy link
Copy Markdown
Member Author

Re: @arokem's suggestion to use DIPY's multi_voxel decorator (Ray/engine) — I tested it against the pinned DIPY (1.10.0) and it isn't usable on our current floor: DKIModel.fit() doesn't accept engine, and the reachable multi_fit(engine=...) path hits a decorator kwargs-leak bug (engine forwarded into ls_fit_dki()); ray also isn't a dependency. Details in #442 (comment). This PR's in-worker chunking is therefore the interim that works on dipy>=1.5; a code comment marks the DIPY-native path as the follow-up once a release supports it.

The fit happens inside the worker and only the predicted array is returned,
so no model instance is ever serialized. See gh-442.
"""
module_name, class_name = model_class.rsplit(".", 1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe

Suggested change
module_name, class_name = model_class.rsplit(".", 1)
module_name, class_name = model_class.rsplit(".", -1)

in case of dipy.reconst.dti and similar?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's right split, should be 1, right?

Comment thread src/nifreeze/model/dmri.py Outdated
Co-authored-by: Ariel Rokem <arokem@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DKI head-motion estimation is impractically slow (~8.6 vol/h, ~32 h/series)

2 participants