Vectorize all per-instance batch transforms #1476
Add an assert_vectorized fixture (tests/conftest.py) that checks a per-instance batch transform produces, for each element, the same result as applying that element's own parameters alone -- proving the whole-batch computation is faithful to per-element semantics. Add tests/test_vectorization.py exercising the deterministic-given-params transforms (Ghosting, Spike, Blur, BiasField, Flip, Motion, Swap, Anisotropy), including gated batches. Passes against the current loop-based implementations and must stay green as each is vectorized.
Replace the per-element torch.flip loop with three whole-batch flips (one per spatial axis), selecting which elements are flipped via a broadcast boolean mask. Flips along distinct axes commute, so composing them matches flipping an element's axes at once. Equivalence gate green.
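The masked whole-batch flip described in this commit can be sketched roughly as follows. This is an illustration, not the code in `flip.py`: the helper name `flip_per_instance` and the `(B, 3)` boolean layout of `flip_mask` are assumptions made for the example.

```python
import torch


def flip_per_instance(data: torch.Tensor, flip_mask: torch.Tensor) -> torch.Tensor:
    """Flip each batch element along its own subset of spatial axes.

    data: (B, C, I, J, K) tensor.
    flip_mask: (B, 3) boolean tensor; flip_mask[b, a] says whether element b
        is flipped along spatial axis a (hypothetical layout).
    """
    result = data
    for axis in range(3):  # fixed loop over spatial axes, not over the batch
        flipped = torch.flip(result, dims=(axis + 2,))
        # Broadcast the per-element decision over (C, I, J, K).
        select = flip_mask[:, axis].view(-1, 1, 1, 1, 1)
        result = torch.where(select, flipped, result)
    return result


torch.manual_seed(0)
data = torch.rand(4, 2, 3, 5, 7)
flip_mask = torch.tensor(
    [
        [True, False, True],
        [False, False, False],
        [True, True, True],
        [False, True, False],
    ]
)
batched = flip_per_instance(data, flip_mask)

# Reference: flip every element on its own, all of its axes at once.
reference = []
for b in range(data.shape[0]):
    dims = [a + 1 for a in range(3) if flip_mask[b, a]]  # +1: no batch dim here
    reference.append(torch.flip(data[b], dims=dims) if dims else data[b])
reference = torch.stack(reference)
```

Because a flip only permutes values, the batched and per-element results can be compared with exact equality rather than a tolerance.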
Replace the per-element loop (one FFT per element) with a single batched fftn over (B,C,I,J,K), a per-element (B,1,I,J,K) k-space mask, and a single ifftn. Gated-out elements are restored exactly; input dtype is preserved. Equivalence gate green.
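A minimal sketch of the batched FFT pattern, assuming a multiplicative k-space mask and a boolean `keep` gate. The function name `attenuate_kspace`, the mask values, and taking the real part of the inverse transform are simplifications for illustration and may differ from what `ghosting.py` does.

```python
import torch


def attenuate_kspace(
    data: torch.Tensor, kspace_mask: torch.Tensor, keep: torch.Tensor
) -> torch.Tensor:
    """Apply a per-element k-space mask with one batched FFT round trip.

    data: (B, C, I, J, K); kspace_mask: (B, 1, I, J, K); keep: (B,) bool.
    """
    spatial = (-3, -2, -1)
    spectrum = torch.fft.fftn(data.float(), dim=spatial)
    spectrum = spectrum * kspace_mask  # broadcasts over the channel dim
    result = torch.fft.ifftn(spectrum, dim=spatial).real.to(data.dtype)
    # Gated-out elements get their input back exactly, not a round-tripped copy.
    gate = keep.view(-1, 1, 1, 1, 1)
    return torch.where(gate, result, data)


torch.manual_seed(0)
data = torch.rand(3, 2, 8, 8, 8)
kspace_mask = torch.ones(3, 1, 8, 8, 8)
kspace_mask[0, :, ::2] = 0.5  # element 0: attenuate every 2nd plane along I
kspace_mask[2, :, :, ::4] = 0.25  # element 2: attenuate every 4th plane along J
keep = torch.tensor([True, False, True])

batched = attenuate_kspace(data, kspace_mask, keep)
alone = torch.cat(
    [
        attenuate_kspace(data[b : b + 1], kspace_mask[b : b + 1], keep[b : b + 1])
        for b in range(data.shape[0])
    ]
)
```

The explicit `torch.where` restore is what makes gated-out rows bit-for-bit identical to the input; an FFT round trip alone would only be close.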
Replace the per-element FFT loop with one batched fftn/ifftn and scatter each element's spikes into the batched spectrum. Gated-out elements are restored exactly; input dtype preserved. Equivalence gate green.
Stack per-element seeded coarse noise (cheap, preserves per-element reproducibility), then run interpolate, exp and the multiply/divide once over the whole batch. Zero-std rows restored exactly; dtype preserved. Equivalence gate green.
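One way to picture this split between cheap per-element seeding and whole-batch math is the sketch below. The function name `apply_bias_field`, the coarse grid size, and the trilinear upsampling mode are assumptions for the example, not details taken from `bias_field.py`.

```python
import torch
import torch.nn.functional as F


def apply_bias_field(
    data: torch.Tensor, seeds: list, stds: torch.Tensor, coarse_size: int = 4
) -> torch.Tensor:
    """data: (B, C, I, J, K); seeds: B ints; stds: (B,) float tensor."""
    # Per-element seeded coarse noise: a small loop that keeps each element
    # reproducible from its own seed.
    coarse = torch.stack(
        [
            torch.randn(
                1, coarse_size, coarse_size, coarse_size,
                generator=torch.Generator().manual_seed(seed),
            )
            for seed in seeds
        ]
    )  # (B, 1, c, c, c)
    coarse = coarse * stds.view(-1, 1, 1, 1, 1)
    # Everything below runs once over the whole batch.
    log_field = F.interpolate(
        coarse, size=data.shape[-3:], mode="trilinear", align_corners=False
    )
    result = (data.float() * torch.exp(log_field)).to(data.dtype)
    # Zero-std rows are restored explicitly instead of relying on exp(0) == 1.
    active = (stds > 0).view(-1, 1, 1, 1, 1)
    return torch.where(active, result, data)


torch.manual_seed(0)
data = torch.rand(3, 2, 6, 6, 6)
seeds = [11, 22, 33]
stds = torch.tensor([0.5, 0.0, 0.2])

batched = apply_bias_field(data, seeds, stds)
alone = torch.cat(
    [
        apply_bias_field(data[b : b + 1], seeds[b : b + 1], stds[b : b + 1])
        for b in range(data.shape[0])
    ]
)
```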
Replace the per-element separable-conv loop with one grouped conv3d per spatial axis over a (1, B*C, I, J, K) view (groups=B*C), each element using its own kernel padded to the batch-max size. Zero-sigma rows restored exactly; dtype preserved. Equivalence gate green.
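The grouped-convolution trick can be sketched for a single spatial axis as below. The helper names, the `3 * sigma` kernel radius, and zero padding at the borders are assumptions for illustration; `blur.py` may choose differently.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_1d(sigma: float, radius: int) -> torch.Tensor:
    x = torch.arange(-radius, radius + 1, dtype=torch.float32)
    kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def kernel_radius(sigma: float) -> int:
    return max(1, int(3 * sigma + 0.5))  # hypothetical truncation rule


def blur_axis0_per_element(data: torch.Tensor, sigmas: list) -> torch.Tensor:
    """Blur along axis I only. data: (B, C, I, J, K); sigmas: B positive floats."""
    batch, channels = data.shape[:2]
    radii = [kernel_radius(s) for s in sigmas]
    max_radius = max(radii)
    kernels = []
    for sigma, radius in zip(sigmas, radii):
        pad = max_radius - radius
        # Zero-pad each kernel symmetrically to the batch-max size.
        kernels.append(F.pad(gaussian_kernel_1d(sigma, radius), (pad, pad)))
    weight = torch.stack(kernels)  # (B, 2 * max_radius + 1)
    # One kernel per (element, channel) pair, in the same order as the view below.
    weight = weight.repeat_interleave(channels, dim=0)
    weight = weight.view(batch * channels, 1, -1, 1, 1)
    flat = data.reshape(1, batch * channels, *data.shape[2:])
    out = F.conv3d(
        flat, weight, padding=(max_radius, 0, 0), groups=batch * channels
    )
    return out.view_as(data)


torch.manual_seed(0)
data = torch.rand(3, 2, 9, 4, 4)
sigmas = [0.5, 1.0, 1.5]
batched = blur_axis0_per_element(data, sigmas)
```

With zero-padded borders, padding a kernel with zeros up to the batch-max size does not change the result, which is why each element can keep its own effective kernel width inside a single grouped call.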
Replace the per-element FFT loop with batched fftn/ifftn over the whole batch and apply each segment's per-element rigid transform via a batched resample, mixing k-space segments across the batch at once. Gated-out elements restored exactly; dtype preserved. Equivalence gate green.
Replace the per-element loop with whole-batch generation: broadcast per-element per-label mean/std as (B,1,1,1,1) tensors and draw one batch-shaped random tensor per label. Non-per-instance path unchanged.
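A rough sketch of the broadcast pattern, assuming integer label maps and `(B, L)` tables of per-label statistics. The function name `synthesize` and the argument layout are hypothetical and not taken from `labels_to_image.py`.

```python
import torch


def synthesize(
    label_map: torch.Tensor,
    means: torch.Tensor,
    stds: torch.Tensor,
    generator: torch.Generator,
) -> torch.Tensor:
    """label_map: (B, 1, I, J, K) integer labels; means, stds: (B, L)."""
    batch, num_labels = means.shape
    image = torch.zeros(label_map.shape, dtype=torch.float32)
    for label in range(num_labels):  # loop over labels, not over the batch
        mean = means[:, label].view(batch, 1, 1, 1, 1)
        std = stds[:, label].view(batch, 1, 1, 1, 1)
        # One batch-shaped random tensor per label.
        noise = torch.randn(label_map.shape, generator=generator)
        image = torch.where(label_map == label, mean + std * noise, image)
    return image


label_map = torch.arange(16).view(2, 1, 2, 2, 2) % 2
means = torch.tensor([[0.0, 10.0], [100.0, 1000.0]])
stds = torch.zeros(2, 2)  # zero std makes each voxel equal its own mean
image = synthesize(label_map, means, stds, torch.Generator().manual_seed(0))
```

Setting the standard deviations to zero turns the output into a direct readout of which mean each voxel received, which is a convenient way to check for cross-element contamination.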
Replace the per-element downsample/upsample loop with a fixed loop over the 3 spatial axes, processing each axis's subset of the batch together with per-element factors via batched index/weight maps. Gated factor<=1 elements unchanged; dtype preserved. Equivalence gate green.
Replace the per-element _apply_swaps loop with batched indexed swaps over the whole batch (loop only over the fixed number of swap pairs), each element using its own patch coordinates. Gated-out (empty locations) elements unchanged. Equivalence gate green.
Address GPT 5.5 review: the equivalence gate now asserts gated-out elements are bit-for-bit no-ops (rtol=0, atol=0), not just close. Add a LabelsToImage test verifying each batch element uses its own per-label mean (no cross-element contamination), since that transform samples inside apply_transform and is outside the gate.
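The two kinds of check in the gate can be illustrated with a simplified stand-in. The signature below (a callable taking `data`, `params` and `keep`) is an assumption made for the example and does not reflect the real fixture in `tests/conftest.py`.

```python
import torch


def assert_vectorized(batched_fn, data, params, keep):
    """Simplified, hypothetical stand-in for the equivalence gate."""
    batched = batched_fn(data, params, keep)
    for b in range(data.shape[0]):
        row = slice(b, b + 1)
        if keep[b]:
            # Active element: must match running it alone with its own params.
            alone = batched_fn(data[row], params[row], keep[row])
            torch.testing.assert_close(batched[row], alone)
        else:
            # Gated-out element: must be a bit-for-bit no-op, not merely close.
            torch.testing.assert_close(batched[row], data[row], rtol=0, atol=0)


def scale(data, gains, keep):
    """Toy per-instance transform: multiply each element by its own gain."""
    out = data * gains.view(-1, 1, 1, 1, 1)
    return torch.where(keep.view(-1, 1, 1, 1, 1), out, data)


def leaky(data, gains, keep):
    """Toy transform that wrongly mixes information across the batch."""
    return data - data.mean()


torch.manual_seed(0)
data = torch.rand(3, 1, 4, 4, 4)
gains = torch.tensor([2.0, 5.0, 0.5])
keep = torch.tensor([True, False, True])

assert_vectorized(scale, data, gains, keep)  # passes silently
```

A transform such as `leaky`, whose output for one element depends on the rest of the batch, fails the same check with an `AssertionError`.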
📖 Docs Preview: preview of the documentation for this PR: 🔗 https://smokeshow.helpmanual.io/52046l0f6y2k4k5q6m0y/ (built from 34bad8e)
Pull request overview
This PR enforces the project’s “batch transforms should be vectorized” principle by replacing per-element batch loops in per-instance transforms with whole-batch tensor implementations, and adds tests that gate correctness by comparing batched output vs. per-element replay using recorded params.
Changes:
- Added a vectorization equivalence test gate (`assert_vectorized`) and a dedicated per-element correctness test for `LabelsToImage`.
- Vectorized per-instance batch paths across multiple transforms (e.g., `Flip`, `Ghosting`, `Spike`, `BiasField`, `Blur`, `Motion`, `Swap`, `Anisotropy`, `LabelsToImage`).
- Hardened “gated-out rows are exact no-ops” behavior checks in tests.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_vectorization.py | Adds parametrized equivalence tests ensuring vectorized batched output matches per-element replay. |
| tests/test_labels_to_image.py | Adds a per-element correctness test to catch cross-element contamination in LabelsToImage. |
| tests/conftest.py | Introduces the assert_vectorized fixture used as the correctness gate. |
| src/torchio/transforms/spatial/flip.py | Replaces per-element flip loop with axis-wise whole-batch flips selected via masks. |
| src/torchio/transforms/spatial/anisotropy.py | Vectorizes per-instance anisotropy by batching per-axis processing and indexed resampling. |
| src/torchio/transforms/intensity/swap.py | Adds per-instance batched indexed swapping and type aliases for clarity. |
| src/torchio/transforms/intensity/spike.py | Vectorizes the per-instance path using batched FFTs with per-element spike injection. |
| src/torchio/transforms/intensity/motion.py | Refactors motion simulation to support per-instance batched rigid transforms and segment mixing. |
| src/torchio/transforms/intensity/labels_to_image.py | Vectorizes per-element label-based synthesis via batched masks and broadcast stats. |
| src/torchio/transforms/intensity/ghosting.py | Vectorizes per-instance ghosting using batched FFTs and a per-element k-space mask. |
| src/torchio/transforms/intensity/blur.py | Implements per-instance blur via grouped conv3d with per-element separable kernels. |
| src/torchio/transforms/intensity/bias_field.py | Vectorizes per-instance bias field by stacking coarse fields then doing a batched upsample/exp/apply. |
        >>> transform = tio.Spike(num_spikes=3, intensity=2.0)
        """

        __name__ = "Spike"

[Generated by a coding agent]
Fixed in 32d6042. Removed the redundant __name__ = "Spike" class attribute (a leftover test-selection hack); type already provides Spike.__name__.
    Args:
        tensor: `(B, C, I, J, K)` tensor.
        degrees: Euler angles in degrees, with shape `(B, 3)`.
        translation: Translation in voxels, with shape `(B, 3)`.

[Generated by a coding agent]
Fixed in 32d6042. Aligned the docs with the long-standing behavior: both the previous and the vectorized implementation treat translation as voxels, normalized to affine_grid coordinates via translation / (shape / 2) (a voxel-space approximation), not millimeters. Updated the public Motion.translation docstring accordingly rather than changing the numerical behavior.
…s doc
- Remove the redundant __name__ class attribute from Spike (it was a leftover test-selection hack; type already provides __name__).
- Correct the Motion translation docstring: the implementation treats translation as voxels (normalized to grid coordinates), not mm, so the public docs now match the long-standing behavior.
    result = data.float()
    num_segments = len(segment_parameters) + 1
    spatial_shape = result.shape[-3:]
    segment_size = spatial_shape[0] // num_segments

[Generated by a coding agent]
Fixed in 5fd8d68. Added a clear ValueError when num_transforms + 1 exceeds the first spatial axis size (segment_size would be 0 and the k-space segmentation would silently degenerate to replacing the whole spectrum). Added a regression test.
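The guard amounts to checking the segment count before the integer division. A minimal sketch, with a hypothetical helper name and error message rather than the ones used in `motion.py`:

```python
def segment_size_or_raise(first_axis_size: int, num_transforms: int) -> int:
    """Return the k-space segment size, refusing degenerate configurations."""
    num_segments = num_transforms + 1
    if num_segments > first_axis_size:
        # Without this check, first_axis_size // num_segments would be 0.
        raise ValueError(
            f"num_transforms + 1 ({num_segments}) exceeds the first spatial "
            f"axis size ({first_axis_size})"
        )
    return first_axis_size // num_segments


size = segment_size_or_raise(first_axis_size=8, num_transforms=3)
```

Here 8 voxels split into 4 segments gives a segment size of 2, while asking for 5 segments on a 4-voxel axis raises instead of silently producing a segment size of 0.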
      def _gaussian_smooth(
          data: torch.Tensor,
          sigmas: list[float] | np.ndarray,
      ) -> torch.Tensor:
          """Apply separable Gaussian smoothing to a 5D tensor.

          Args:
              data: `(B, C, I, J, K)` tensor.
    -         sigmas: Per-axis sigma in voxels.
    +         sigmas: Per-axis sigma in voxels, or per-element per-axis sigmas
    +             with shape `(B, 3)`. Zero means skip that axis.

          Returns:
              Smoothed tensor.
          """
    -     if all(s <= 0 for s in sigmas):
    +     sigmas_array = np.asarray(sigmas, dtype=np.float64)
    +     if sigmas_array.ndim == 1:
    +         sigmas_array = np.broadcast_to(sigmas_array, (data.shape[0], 3)).copy()
    +     if np.all(sigmas_array <= 0):
              return data
    +     return _gaussian_smooth_per_element(data, sigmas_array)

[Generated by a coding agent]
Fixed in 5fd8d68. _gaussian_smooth now routes batch-shared and uniform-sigma inputs through a single shared kernel set (_gaussian_smooth_shared) instead of a grouped conv with B identical kernels, restoring the cheaper path for per_instance=False. Verified the shared path is numerically identical to the grouped-conv path (max diff 0.0). Genuinely per-element sigmas still use the grouped conv3d.
- Blur: route batch-shared / uniform-sigma cases through a single shared kernel set instead of a grouped conv with B identical kernels (the result is numerically identical; restores the cheaper path for per_instance=False). Genuinely per-element sigmas still use grouped conv.
- Motion: raise a clear ValueError when num_transforms + 1 exceeds the first spatial axis size (segment_size would be 0 and the segmentation would silently degenerate). Added a regression test.
    def _downsample_sizes(length: int, factors: torch.Tensor) -> torch.Tensor:
        """Return PyTorch-compatible nearest-downsampled sizes.

        Args:
            length: Original length along the degraded axis.
            factors: Per-element downsampling factors.

        Returns:
            Downsampled sizes matching `round(length / factor)`.
        """
        sizes = torch.round(length / factors).clamp_min(1)
        return sizes.to(torch.long)

[Generated by a coding agent]
This one is actually a non-issue: in current PyTorch torch.round uses round-half-to-even (banker's rounding), the same rule as Python's round. Empirically both map 0.5->0, 1.5->2, 2.5->2, 3.5->4, and _downsample_sizes divides in float64, so length/factor matches the scalar original_shape[axis]/factor. The per-instance and scalar paths therefore choose the same downsample size even on .5 ties (e.g. 9/2 = 4.5 -> 4 in both). I added an equivalence test in 34bad8e with an odd length (9) and factor 2.0 that exercises exactly this tie and confirms the batched path matches the scalar path, so no rounding change is needed.
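The tie behavior is easy to check directly. The snippet below is a standalone check, separate from the equivalence test mentioned above, that compares `torch.round` with Python's built-in `round` on exact `.5` ties and on the length-9, factor-2.0 case; the extra factors 1.5 and 3.0 are arbitrary additions for the example.

```python
import torch

ties = [0.5, 1.5, 2.5, 3.5, 4.5]
torch_rounded = (
    torch.round(torch.tensor(ties, dtype=torch.float64)).to(torch.long).tolist()
)
python_rounded = [round(value) for value in ties]

# Same computation as the size rule under discussion, for a batch of factors.
length = 9
factors = torch.tensor([2.0, 1.5, 3.0], dtype=torch.float64)
batched_sizes = torch.round(length / factors).clamp_min(1).to(torch.long).tolist()
scalar_sizes = [max(1, round(length / factor)) for factor in factors.tolist()]
```

Both rounding functions send ties to the nearest even integer, so the batched and scalar size computations agree, including 9 / 2.0 = 4.5 rounding to 4.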
    if not bool(active.any()):
        return data.to(data.dtype)

    output = data.clone()

[Generated by a coding agent]
Fixed in 34bad8e. _simulate_anisotropy_per_instance now raises a clear ValueError when any ACTIVE (factor > 1) per-element axis is outside {0, 1, 2}, matching the scalar path instead of silently no-opping. Gated-out / inactive elements (factor <= 1) with an unused axis are ignored. Added a regression test.
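A sketch of how such a check can stay vectorized while ignoring inactive elements. The helper name `validate_active_axes` and the error message are hypothetical, not the code in `anisotropy.py`.

```python
import torch


def validate_active_axes(axes: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
    """Return the active mask, raising if an active element has a bad axis.

    axes: (B,) integer tensor; factors: (B,) float tensor.
    """
    active = factors > 1
    out_of_range = (axes < 0) | (axes > 2)
    invalid = active & out_of_range  # inactive elements never trigger the error
    if bool(invalid.any()):
        bad = sorted(set(axes[invalid].tolist()))
        raise ValueError(f"axis must be 0, 1 or 2 for active elements, got {bad}")
    return active


axes = torch.tensor([0, 5, 2])
# Element 1 has an out-of-range axis but factor <= 1, so it is ignored.
active = validate_active_axes(axes, torch.tensor([2.0, 1.0, 3.0]))
```

Raising the factor of element 1 above 1 would make its axis count, and the same call would then raise a `ValueError`.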
- Anisotropy: raise a clear ValueError when an active per-element axis is
outside {0,1,2}, matching the scalar path (which would IndexError)
rather than silently no-opping. Inactive (factor<=1) elements are
ignored. Added a regression test.
- Add an equivalence test with an odd length and factor 2.0 (a .5 tie),
demonstrating the per-instance torch.round path and the scalar Python
round path agree (both round-half-to-even), so no rounding fix is
needed.
[Generated by a coding agent]
Stacked on #1473 (base: `per-instance-augmentation`).

Vectorize all per-instance batch transforms

Per the principle that batch transforms should always be vectorized, this PR replaces the per-element `for index in range(B): … torch.cat` loops in the per-instance transforms with whole-batch tensor operations using broadcast per-element parameters.

Transforms vectorized

- … `torch.where`.
- … `fftn`/`ifftn` with a per-element k-space mask (Ghosting) / per-element peak scatter (Spike).
- … `interpolate` → `exp` → multiply.
- … `conv3d` with per-element separable kernels (`groups = B*C`).

Correctness gate

`tests/conftest.py::assert_vectorized` (used by `tests/test_vectorization.py`) asserts that, for transforms whose `apply_transform` is deterministic given the recorded params, the batched result for each element equals applying that element's own parameters alone, proving the whole-batch computation is faithful per-element with no cross-element contamination. It also asserts gated-out elements are bit-for-bit no-ops. LabelsToImage (which samples inside `apply`) has a dedicated per-element correctness test.

Every transform preserves input dtype and leaves per-element gated-out elements (`_keep == False`) exactly unchanged. Motion was verified to match the previous implementation to float32 precision.

Full test suite green (1379 passed), `mise run quality` (ruff + ty) and `prek` (incl. Xenon complexity) green. A GPT 5.5 review was addressed (gate hardened to bit-exact gated rows; LabelsToImage per-element test added).