FIX: Parallelize DKI Fit/Predict for Usable Multi-Shell Motion Estimation#443
FIX: Parallelize DKI Fit/Predict for Usable Multi-Shell Motion Estimation#443oesteban wants to merge 3 commits into
Conversation
DKI was forced onto the serial full-brain fit path, making volume-to-volume motion estimation grow linearly with voxel count (~120x slower than DTI on multi-shell data; ~32 h/series, effectively hung). Its fitted MultiVoxelFit object cannot be pickled across loky's process boundary, so the fit/predict- split parallelization DTI uses crashes for DKI. Add an exact in-worker fit+predict path (_fit_predict_chunked) for models flagged `_picklable_fit = False`: each worker fits its voxel chunk and predicts the held-out gradient, returning only the (picklable) predicted array. Voxel- wise fitting is independent, so results are numerically identical to the serial path (max|diff| = 0) at ~5.8x speedup (n_jobs=8). Also remove the unreachable `elif is_dki` branch and factor out _lovo_data. Resolves: nipreps#442 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #443 +/- ##
=======================================
Coverage ? 84.83%
=======================================
Files ? 37
Lines ? 2183
Branches ? 245
=======================================
Hits ? 1852
Misses ? 295
Partials ? 36 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Record (niprepsgh-442) that the in-worker chunking for DKI can be replaced by DIPY's own multi_voxel engine path once a release forwards `engine` through `DKIModel.fit` and stops leaking orchestration kwargs into the per-voxel kernel (both broken in DIPY <= 1.10.0; arokem's suggested route). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Re: @arokem's suggestion to use DIPY's |
| The fit happens inside the worker and only the predicted array is returned, | ||
| so no model instance is ever serialized. See gh-442. | ||
| """ | ||
| module_name, class_name = model_class.rsplit(".", 1) |
There was a problem hiding this comment.
Maybe
| module_name, class_name = model_class.rsplit(".", 1) | |
| module_name, class_name = model_class.rsplit(".", -1) |
in case of dipy.reconst.dti and similar?
There was a problem hiding this comment.
Since it's right split, should be 1, right?
Co-authored-by: Ariel Rokem <arokem@gmail.com>
Summary
Used as the head-motion/eddy estimator for dMRI, the DKI model was effectively unusable on real multi-shell data (~8.6 vol/h ⇒ ~32 h/series, appearing hung), while DTI on the same data runs in minutes. This PR makes DKI parallelize correctly, recovering a ~5.8× speedup (more with more workers) with numerically identical results.
Resolves: #442
Root cause
DKI was hard-routed onto the serial full-brain fit path (
if n_jobs == 1 or is_dki:), so per-volume cost grew linearly with voxel count while DTI'snp.array_split+joblibpath stayed flat (measured 6×→40×→119× slower as voxels grew 2k→20k→80k). The serial special-case existed for a real reason: DKI'sMultiVoxelFitcannot be pickled across loky's process boundary, so naively reusing DTI's fit/predict-split path crashes (BrokenProcessPool/RecursionError). BLAS thread-capping was ruled out as a latency factor (no effect — the fit is GIL-bound per-voxel Python work). Full diagnosis with timings in #442.Fix
Add an exact, in-worker parallel path for models whose fitted object cannot cross a process boundary:
_exec_fit_predict(...)worker: builds the model, fits its voxel chunk, predicts the held-out gradient, and returns only the predicted ndarray — no model instance is ever serialized.BaseDWIModel._fit_predict_chunked(index, n_jobs): splits voxels (and the alignedS0) across workers, runs the worker per chunk, and re-assembles the prediction. Because voxel-wise fitting is independent, the result is bit-identical to the serial path.fit_predictroutes to this path whennot self._picklable_fit and n_jobs > 1.DKIModelsets_picklable_fit = False; all other models keep the existing behavior.elif is_dki:(deadmodel.multi_fitbranch) and factored the shared data prep into_lovo_data.The DTI fit/predict-split path is untouched. DKI at
n_jobs == 1, direct_fitcalls, andsingle_fitare unchanged.Benchmark (synthetic multi-shell, b0 + 3 shells, n_jobs=8)
Per-volume output matches the serial path exactly (
max|diff| = 0).Tests
test_dki_parallel_matches_serial(index ∈ {4,9} × use_mask ∈ {False,True}) assertsfit_predict(index, n_jobs=4)(chunked) equalsfit_predict(index, n_jobs=1)(serial).test_model_dmri.pyDKI/DTI suite green (132 passed).Notes
single_fit(fit-once, predict-each) could cut runtime further but is approximate and currently buggy forn_jobs > 1; left for a follow-up.CHANGES.rstintentionally untouched (release-generated).