
Vectorize all per-instance batch transforms #1476

Open
fepegar wants to merge 14 commits into per-instance-augmentation from vectorize-batch-transforms

fepegar commented Jun 16, 2026

Member

[Generated by a coding agent]


Stacked on #1473 (base: per-instance-augmentation).

Vectorize all per-instance batch transforms

Per the principle that batch transforms should always be vectorized, this PR replaces the per-element for index in range(B): … torch.cat loops in the per-instance transforms with whole-batch tensor operations using broadcast per-element parameters.

Transforms vectorized

  • Flip: per-axis whole-batch flips selected with torch.where.
  • Ghosting / Spike: a single batched fftn/ifftn with a per-element k-space mask (Ghosting) / per-element peak scatter (Spike).
  • BiasField: stack per-seed coarse noise, then one batched interpolate → exp → multiply.
  • Blur: grouped conv3d with per-element separable kernels (groups = B*C).
  • Motion: batched FFTs + grid-resampled rigid transforms, mixing k-space segments across the batch.
  • LabelsToImage: per-label batched generation with broadcast per-element stats.
  • Anisotropy: group by spatial axis, batch the resample per group.
  • Swap: batched indexed patch swaps.
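The Flip bullet above can be illustrated with a minimal sketch. It is not the code from this PR: the function name `flip_per_instance` and the `(B, 3)` boolean `flip_mask` layout are assumptions made for illustration. The only loop is over the three spatial axes, never over the batch.

```python
import torch


def flip_per_instance(data: torch.Tensor, flip_mask: torch.Tensor) -> torch.Tensor:
    """Flip each batch element along its own subset of spatial axes.

    data: (B, C, I, J, K) tensor.
    flip_mask: (B, 3) bool tensor (hypothetical layout); True where
        element b should be flipped along spatial axis a.
    """
    result = data
    for axis in range(3):  # fixed loop over spatial axes, not over the batch
        flipped = torch.flip(result, dims=[axis + 2])
        # Reshape to (B, 1, 1, 1, 1) so the mask broadcasts over C, I, J, K.
        select = flip_mask[:, axis].view(-1, 1, 1, 1, 1)
        result = torch.where(select, flipped, result)
    return result


data = torch.arange(2 * 1 * 2 * 3 * 4, dtype=torch.float32).reshape(2, 1, 2, 3, 4)
# Element 0 flips axes I and K; element 1 is left untouched.
flip_mask = torch.tensor([[True, False, True], [False, False, False]])
out = flip_per_instance(data, flip_mask)
```

Because flips along distinct axes commute, applying them one axis at a time gives the same result as flipping all of an element's axes in a single call.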

Correctness gate

tests/conftest.py::assert_vectorized (used by tests/test_vectorization.py) asserts that, for transforms whose apply_transform is deterministic given the recorded params, the batched result for each element equals applying that element's own parameters alone — proving the whole-batch computation is faithful per-element with no cross-element contamination. It also asserts gated-out elements are bit-for-bit no-ops. LabelsToImage (which samples inside apply) has a dedicated per-element correctness test.
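A rough sketch of what such an equivalence gate checks, under stated assumptions: `scale_batch` is a toy stand-in transform and `check_vectorized` is a hypothetical helper, not the actual `assert_vectorized` fixture, whose signature is not shown in this PR description.

```python
import torch


def scale_batch(data: torch.Tensor, gains: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Toy vectorized transform: multiply element b by gains[b] where keep[b]."""
    scaled = data * gains.view(-1, 1, 1, 1, 1)
    return torch.where(keep.view(-1, 1, 1, 1, 1), scaled, data)


def check_vectorized(apply_fn, data, gains, keep) -> bool:
    """Compare the whole-batch result with per-element replay."""
    batched = apply_fn(data, gains, keep)
    for b in range(data.shape[0]):
        # Replay element b alone with its own recorded parameters.
        alone = apply_fn(data[b : b + 1], gains[b : b + 1], keep[b : b + 1])
        torch.testing.assert_close(batched[b : b + 1], alone)
        if not keep[b]:
            # Gated-out elements must be bit-for-bit no-ops.
            torch.testing.assert_close(batched[b], data[b], rtol=0, atol=0)
    return True


torch.manual_seed(0)
data = torch.rand(3, 1, 4, 4, 4)
gains = torch.tensor([2.0, 0.5, 3.0])
keep = torch.tensor([True, False, True])
gate_passed = check_vectorized(scale_batch, data, gains, keep)
```

The loop over the batch lives only in the test, which is the point: the transform stays vectorized while the test replays each element on its own.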

Every transform preserves the input dtype and leaves gated-out elements (_keep == False) exactly unchanged. Motion was verified to match the previous implementation to float32 precision.
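One way to get both properties (dtype preservation and exact no-ops for gated-out elements) in a batched FFT pipeline is sketched below. This is a simplified illustration, not the Ghosting implementation: the function name, the `kspace_mask` contents and the `keep` argument are assumptions.

```python
import torch


def filter_kspace_per_instance(
    data: torch.Tensor,
    kspace_mask: torch.Tensor,
    keep: torch.Tensor,
) -> torch.Tensor:
    """Apply a per-element k-space mask with one batched FFT round trip.

    data: (B, C, I, J, K); kspace_mask: (B, 1, I, J, K); keep: (B,) bool.
    """
    original_dtype = data.dtype
    spectrum = torch.fft.fftn(data.float(), dim=(-3, -2, -1))
    spectrum = spectrum * kspace_mask  # broadcasts over the channel axis
    filtered = torch.fft.ifftn(spectrum, dim=(-3, -2, -1)).real
    filtered = filtered.to(original_dtype)
    # Gated-out elements take the untouched input values, so they are
    # restored exactly rather than passed through the FFT round trip.
    return torch.where(keep.view(-1, 1, 1, 1, 1), filtered, data)


torch.manual_seed(0)
data = torch.rand(2, 1, 8, 8, 8, dtype=torch.float64)
kspace_mask = torch.ones(2, 1, 8, 8, 8)
kspace_mask[0, :, 1::2] = 0.0  # zero every other k-space plane for element 0
keep = torch.tensor([True, False])
out = filter_kspace_per_instance(data, kspace_mask, keep)
```

Selecting the original tensor with `torch.where` avoids relying on the FFT round trip being numerically exact for elements that should not change.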

Full test suite green (1379 passed), mise run quality (ruff + ty) and prek (incl. Xenon complexity) green. A GPT 5.5 review was addressed (gate hardened to bit-exact gated rows; LabelsToImage per-element test added).

fepegar added 11 commits June 16, 2026 21:02
Add an assert_vectorized fixture (tests/conftest.py) that checks a
per-instance batch transform produces, for each element, the same result
as applying that element's own parameters alone -- proving the
whole-batch computation is faithful to per-element semantics. Add
tests/test_vectorization.py exercising the deterministic-given-params
transforms (Ghosting, Spike, Blur, BiasField, Flip, Motion, Swap,
Anisotropy), including gated batches. Passes against the current
loop-based implementations and must stay green as each is vectorized.

Replace the per-element torch.flip loop with three whole-batch flips
(one per spatial axis), selecting which elements are flipped via a
broadcast boolean mask. Flips along distinct axes commute, so composing
them matches flipping an element's axes at once. Equivalence gate green.

Replace the per-element loop (one FFT per element) with a single batched
fftn over (B,C,I,J,K), a per-element (B,1,I,J,K) k-space mask, and a
single ifftn. Gated-out elements are restored exactly; input dtype is
preserved. Equivalence gate green.

Replace the per-element FFT loop with one batched fftn/ifftn and scatter
each element's spikes into the batched spectrum. Gated-out elements are
restored exactly; input dtype preserved. Equivalence gate green.

Stack per-element seeded coarse noise (cheap, preserves per-element
reproducibility), then run interpolate, exp and the multiply/divide once
over the whole batch. Zero-std rows restored exactly; dtype preserved.
Equivalence gate green.

Replace the per-element separable-conv loop with one grouped conv3d per
spatial axis over a (1, B*C, I, J, K) view (groups=B*C), each element
using its own kernel padded to the batch-max size. Zero-sigma rows
restored exactly; dtype preserved. Equivalence gate green.

Replace the per-element FFT loop with batched fftn/ifftn over the whole
batch and apply each segment's per-element rigid transform via a batched
resample, mixing k-space segments across the batch at once. Gated-out
elements restored exactly; dtype preserved. Equivalence gate green.

Replace the per-element loop with whole-batch generation: broadcast
per-element per-label mean/std as (B,1,1,1,1) tensors and draw one
batch-shaped random tensor per label. Non-per-instance path unchanged.

Replace the per-element downsample/upsample loop with a fixed loop over
the 3 spatial axes, processing each axis's subset of the batch together
with per-element factors via batched index/weight maps. Gated factor<=1
elements unchanged; dtype preserved. Equivalence gate green.

Replace the per-element _apply_swaps loop with batched indexed swaps over
the whole batch (loop only over the fixed number of swap pairs), each
element using its own patch coordinates. Gated-out (empty locations)
elements unchanged. Equivalence gate green.

Address GPT 5.5 review: the equivalence gate now asserts gated-out
elements are bit-for-bit no-ops (rtol=0, atol=0), not just close. Add a
LabelsToImage test verifying each batch element uses its own per-label
mean (no cross-element contamination), since that transform samples
inside apply_transform and is outside the gate.
Copilot AI review requested due to automatic review settings June 16, 2026 21:15
github-actions Bot commented Jun 16, 2026

Contributor

📖 Docs Preview

Preview of the documentation for this PR:

🔗 https://smokeshow.helpmanual.io/52046l0f6y2k4k5q6m0y/

Built from 34bad8e

Copilot AI left a comment

Contributor

Pull request overview

This PR enforces the project’s “batch transforms should be vectorized” principle by replacing per-element batch loops in per-instance transforms with whole-batch tensor implementations, and adds tests that gate correctness by comparing batched output vs. per-element replay using recorded params.

Changes:

  • Added a vectorization equivalence test gate (assert_vectorized) and a dedicated per-element correctness test for LabelsToImage.
  • Vectorized per-instance batch paths across multiple transforms (e.g., Flip, Ghosting, Spike, BiasField, Blur, Motion, Swap, Anisotropy, LabelsToImage).
  • Hardened “gated-out rows are exact no-ops” behavior checks in tests.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/test_vectorization.py | Adds parametrized equivalence tests ensuring vectorized batched output matches per-element replay. |
| tests/test_labels_to_image.py | Adds a per-element correctness test to catch cross-element contamination in LabelsToImage. |
| tests/conftest.py | Introduces the assert_vectorized fixture used as the correctness gate. |
| src/torchio/transforms/spatial/flip.py | Replaces per-element flip loop with axis-wise whole-batch flips selected via masks. |
| src/torchio/transforms/spatial/anisotropy.py | Vectorizes per-instance anisotropy by batching per-axis processing and indexed resampling. |
| src/torchio/transforms/intensity/swap.py | Adds per-instance batched indexed swapping and type aliases for clarity. |
| src/torchio/transforms/intensity/spike.py | Vectorizes the per-instance path using batched FFTs with per-element spike injection. |
| src/torchio/transforms/intensity/motion.py | Refactors motion simulation to support per-instance batched rigid transforms and segment mixing. |
| src/torchio/transforms/intensity/labels_to_image.py | Vectorizes per-element label-based synthesis via batched masks and broadcast stats. |
| src/torchio/transforms/intensity/ghosting.py | Vectorizes per-instance ghosting using batched FFTs and a per-element k-space mask. |
| src/torchio/transforms/intensity/blur.py | Implements per-instance blur via grouped conv3d with per-element separable kernels. |
| src/torchio/transforms/intensity/bias_field.py | Vectorizes per-instance bias field by stacking coarse fields then doing a batched upsample/exp/apply. |

>>> transform = tio.Spike(num_spikes=3, intensity=2.0)
"""

__name__ = "Spike"

Member Author

[Generated by a coding agent]


Fixed in 32d6042. Removed the redundant __name__ = "Spike" class attribute (a leftover test-selection hack); type already provides Spike.__name__.

Comment on lines +417 to +420
Args:
    tensor: `(B, C, I, J, K)` tensor.
    degrees: Euler angles in degrees, with shape `(B, 3)`.
    translation: Translation in voxels, with shape `(B, 3)`.

Member Author

[Generated by a coding agent]


Fixed in 32d6042. Aligned the docs with the long-standing behavior: both the previous and the vectorized implementation treat translation as voxels, normalized to affine_grid coordinates via translation / (shape / 2) (a voxel-space approximation), not millimeters. Updated the public Motion.translation docstring accordingly rather than changing the numerical behavior.
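The normalization mentioned in this reply, `translation / (shape / 2)`, can be sketched as follows. The helper name `normalize_translation` is hypothetical, and the sketch deliberately ignores the axis ordering that `affine_grid` expects; it only shows the voxel-to-normalized scaling, where the full extent of an axis spans the interval from -1 to 1.

```python
import torch


def normalize_translation(
    translation_voxels: torch.Tensor,
    spatial_shape: tuple[int, int, int],
) -> torch.Tensor:
    """Convert per-element voxel translations to normalized grid units.

    translation_voxels: (B, 3) translations in voxels.
    spatial_shape: (I, J, K) size of the volume.
    """
    shape = torch.tensor(spatial_shape, dtype=translation_voxels.dtype)
    # Half the axis length in voxels corresponds to one normalized unit.
    return translation_voxels / (shape / 2)


normalized = normalize_translation(
    torch.tensor([[8.0, 0.0, -4.0]]),
    (32, 64, 16),
)
```

For a 32-voxel axis, a shift of 8 voxels becomes 0.5 in normalized units, which is a quarter of the full normalized range of 2.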

…s doc

- Remove the redundant __name__ class attribute from Spike (it was a
  leftover test-selection hack; type already provides __name__).
- Correct the Motion translation docstring: the implementation treats
  translation as voxels (normalized to grid coordinates), not mm, so the
  public docs now match the long-standing behavior.

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.

Comment on lines +363 to +367
result = data.float()
num_segments = len(segment_parameters) + 1
spatial_shape = result.shape[-3:]
segment_size = spatial_shape[0] // num_segments

Member Author

[Generated by a coding agent]


Fixed in 5fd8d68. Added a clear ValueError when num_transforms + 1 exceeds the first spatial axis size (segment_size would be 0 and the k-space segmentation would silently degenerate to replacing the whole spectrum). Added a regression test.
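The guard described in this reply might look roughly like the sketch below. The helper name and the error message wording are assumptions; only the arithmetic (`num_segments = num_transforms + 1` and integer division of the first spatial axis) comes from the excerpt above.

```python
def segment_size_or_raise(num_transforms: int, first_axis_size: int) -> int:
    """Return the k-space segment size, rejecting degenerate segmentations."""
    num_segments = num_transforms + 1
    if num_segments > first_axis_size:
        # Without this check, integer division would give a segment size of 0.
        message = (
            f"num_transforms + 1 ({num_segments}) exceeds the first spatial"
            f" axis size ({first_axis_size})"
        )
        raise ValueError(message)
    return first_axis_size // num_segments


size = segment_size_or_raise(num_transforms=3, first_axis_size=64)

try:
    segment_size_or_raise(num_transforms=8, first_axis_size=4)
    raised = False
except ValueError:
    raised = True
```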

Comment on lines +129 to +149
 def _gaussian_smooth(
     data: torch.Tensor,
     sigmas: list[float] | np.ndarray,
 ) -> torch.Tensor:
     """Apply separable Gaussian smoothing to a 5D tensor.

     Args:
         data: `(B, C, I, J, K)` tensor.
-        sigmas: Per-axis sigma in voxels.
+        sigmas: Per-axis sigma in voxels, or per-element per-axis sigmas
+            with shape `(B, 3)`. Zero means skip that axis.

     Returns:
         Smoothed tensor.
     """
-    if all(s <= 0 for s in sigmas):
+    sigmas_array = np.asarray(sigmas, dtype=np.float64)
+    if sigmas_array.ndim == 1:
+        sigmas_array = np.broadcast_to(sigmas_array, (data.shape[0], 3)).copy()
+    if np.all(sigmas_array <= 0):
         return data
+    return _gaussian_smooth_per_element(data, sigmas_array)

Member Author

[Generated by a coding agent]


Fixed in 5fd8d68. _gaussian_smooth now routes batch-shared and uniform-sigma inputs through a single shared kernel set (_gaussian_smooth_shared) instead of a grouped conv with B identical kernels, restoring the cheaper path for per_instance=False. Verified the shared path is numerically identical to the grouped-conv path (max diff 0.0). Genuinely per-element sigmas still use the grouped conv3d.
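The trade-off in this reply can be illustrated for a single spatial axis. The sketch below is an assumption-laden simplification, not the PR's `_gaussian_smooth_shared` or grouped path: the function names are hypothetical, it uses the zero padding built into `conv3d`, and it assumes all kernels share one odd size.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_1d(sigma: float, radius: int) -> torch.Tensor:
    """Normalized 1D Gaussian kernel of size 2 * radius + 1."""
    x = torch.arange(-radius, radius + 1, dtype=torch.float32)
    kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def blur_axis0_grouped(data: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    """Blur along the first spatial axis with one kernel per batch element.

    data: (B, C, I, J, K); kernels: (B, size) with odd size.
    """
    b, c, i, j, k = data.shape
    size = kernels.shape[1]
    # One filter per (element, channel) pair, in the same order as the view below.
    weight = kernels.repeat_interleave(c, dim=0).reshape(b * c, 1, size, 1, 1)
    flat = data.reshape(1, b * c, i, j, k)
    out = F.conv3d(flat, weight, padding=(size // 2, 0, 0), groups=b * c)
    return out.reshape(b, c, i, j, k)


def blur_axis0_shared(data: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Blur along the first spatial axis with a single kernel for the whole batch."""
    b, c, i, j, k = data.shape
    size = kernel.shape[0]
    flat = data.reshape(b * c, 1, i, j, k)
    out = F.conv3d(flat, kernel.view(1, 1, size, 1, 1), padding=(size // 2, 0, 0))
    return out.reshape(b, c, i, j, k)


torch.manual_seed(0)
data = torch.rand(2, 2, 9, 5, 5)
narrow = gaussian_kernel_1d(1.0, radius=3)
wide = gaussian_kernel_1d(2.0, radius=3)
# Identical kernels: the grouped path should agree with the cheaper shared path.
grouped_same = blur_axis0_grouped(data, torch.stack([narrow, narrow]))
shared = blur_axis0_shared(data, narrow)
# Different kernels: only the grouped path can express this.
grouped_mixed = blur_axis0_grouped(data, torch.stack([narrow, wide]))
```

When every element shares the same sigma, the grouped call builds `B * C` copies of one filter, which is why routing that case through a single shared kernel is cheaper without changing the result.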

- Blur: route batch-shared / uniform-sigma cases through a single shared
  kernel set instead of a grouped conv with B identical kernels (the
  result is numerically identical; restores the cheaper path for
  per_instance=False). Genuinely per-element sigmas still use grouped
  conv.
- Motion: raise a clear ValueError when num_transforms + 1 exceeds the
  first spatial axis size (segment_size would be 0 and the segmentation
  would silently degenerate). Added a regression test.

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 2 comments.

Comment on lines +206 to +217
def _downsample_sizes(length: int, factors: torch.Tensor) -> torch.Tensor:
    """Return PyTorch-compatible nearest-downsampled sizes.

    Args:
        length: Original length along the degraded axis.
        factors: Per-element downsampling factors.

    Returns:
        Downsampled sizes matching `round(length / factor)`.
    """
    sizes = torch.round(length / factors).clamp_min(1)
    return sizes.to(torch.long)

Member Author

[Generated by a coding agent]


This one is actually a non-issue: in current PyTorch torch.round uses round-half-to-even (banker's rounding), the same rule as Python's round. Empirically both map 0.5->0, 1.5->2, 2.5->2, 3.5->4, and _downsample_sizes divides in float64, so length/factor matches the scalar original_shape[axis]/factor. The per-instance and scalar paths therefore choose the same downsample size even on .5 ties (e.g. 9/2 = 4.5 -> 4 in both). I added an equivalence test in 34bad8e with an odd length (9) and factor 2.0 that exercises exactly this tie and confirms the batched path matches the scalar path, so no rounding change is needed.
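The tie-breaking argument in this reply can be checked with a short sketch. The function below mirrors the excerpt but casts the factors to float64 explicitly, which is an assumption about how the real helper ends up dividing in float64; the list comprehension stands in for the scalar path.

```python
import torch


def downsample_sizes(length: int, factors: torch.Tensor) -> torch.Tensor:
    """Batched downsampled sizes using torch.round (round-half-to-even)."""
    # Explicit float64 cast is an assumption made for this sketch.
    sizes = torch.round(length / factors.to(torch.float64)).clamp_min(1)
    return sizes.to(torch.long)


factors = torch.tensor([2.0, 3.0, 1.5, 6.0])
# 9 / 2.0 = 4.5 and 9 / 6.0 = 1.5 are both exact .5 ties.
batched = downsample_sizes(9, factors).tolist()
# Scalar path: Python's built-in round also rounds half to even.
scalar = [round(9 / factor) for factor in factors.tolist()]
ties = torch.round(torch.tensor([0.5, 1.5, 2.5, 3.5])).tolist()
```

Both paths give 4 for the 4.5 tie and 2 for the 1.5 tie, so the per-instance and scalar code agree on these inputs.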

if not bool(active.any()):
    return data.to(data.dtype)

output = data.clone()

Member Author

[Generated by a coding agent]


Fixed in 34bad8e. _simulate_anisotropy_per_instance now raises a clear ValueError when any ACTIVE (factor > 1) per-element axis is outside {0, 1, 2}, matching the scalar path instead of silently no-opping. Gated-out / inactive elements (factor <= 1) with an unused axis are ignored. Added a regression test.
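A minimal sketch of the validation rule described in this reply, assuming per-element `axes` and `factors` are 1D tensors of length `B`. The helper name and error message are hypothetical; the rule itself (only elements with factor > 1 are checked) is taken from the reply.

```python
import torch


def validate_active_axes(axes: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
    """Return the active mask, raising if an active element has a bad axis.

    axes: (B,) integer spatial axis per element.
    factors: (B,) downsampling factor per element; factor <= 1 means inactive.
    """
    active = factors > 1
    invalid = active & ((axes < 0) | (axes > 2))
    if bool(invalid.any()):
        bad_axes = axes[invalid].tolist()
        raise ValueError(f"Active elements need an axis in {{0, 1, 2}}, got {bad_axes}")
    return active


# Element 1 has an out-of-range axis but is inactive, so it is ignored.
active = validate_active_axes(
    torch.tensor([0, 5, 2]),
    torch.tensor([2.0, 1.0, 3.0]),
)

try:
    validate_active_axes(torch.tensor([0, 5]), torch.tensor([2.0, 2.0]))
    raised = False
except ValueError:
    raised = True
```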

- Anisotropy: raise a clear ValueError when an active per-element axis is
  outside {0,1,2}, matching the scalar path (which would IndexError)
  rather than silently no-opping. Inactive (factor<=1) elements are
  ignored. Added a regression test.
- Add an equivalence test with an odd length and factor 2.0 (a .5 tie),
  demonstrating the per-instance torch.round path and the scalar Python
  round path agree (both round-half-to-even), so no rounding fix is
  needed.

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated no new comments.
