Add stimulus temporal batching to project_stimulus_batched #59
…poral batching
Add a nullable `stimulus_batch_size` parameter that splits the stimulus into overlapping temporal batches to reduce VRAM usage for long stimuli. The overlap (padding) is automatically computed from the temporal filter width (`filter_temporal_width - 1` frames on each side) to avoid edge artefacts. Default is `None` (all frames at once) to preserve existing behavior.
Pull request overview
This PR extends the existing batched motion-energy projection path to optionally batch over stimulus frames (time) as well as over filters, aiming to reduce GPU/VRAM usage for long stimuli while preserving numerical equivalence with the unbatched projection.
Changes:
- Add `stimulus_batch_size` to `MotionEnergyPyramid.project_stimulus_batched(...)` and `core.project_stimulus_batched(...)` to process the stimulus in overlapping temporal chunks.
- Introduce `_compute_temporal_pad(filters)` to determine overlap/padding needed to avoid temporal edge artifacts when chunking.
- Add test coverage asserting temporal batching matches unbatched results (NumPy backend).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `moten/core.py` | Implements temporal chunking in `project_stimulus_batched` and adds `_compute_temporal_pad` to determine overlap. |
| `moten/pyramids.py` | Exposes `stimulus_batch_size` through the public `MotionEnergyPyramid.project_stimulus_batched` API. |
| `moten/tests/test_batched.py` | Adds NumPy tests validating equivalence for several `stimulus_batch_size` values and combinations with `batch_size`. |
    The temporal convolution shifts frames by up to
    ``filter_temporal_width - 1`` positions. To avoid edge artifacts
    when processing temporal batches, each batch must be padded by this
    many frames on each side.
In practice, the filter width is negligible compared to the stimulus duration, so the extra padding imposes little overhead.
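A minimal pure-Python sketch of the overlap idea discussed in this thread, not the code in `moten/core.py`. The names `temporal_filter` and `temporal_filter_chunked` are hypothetical stand-ins that filter a 1-D signal with a centered kernel and implicit zero padding. Assuming each chunk borrows `len(kernel) - 1` frames from each side (clamped at the stimulus boundaries), the trimmed chunk outputs should concatenate to the same result as filtering all frames at once.

```python
def temporal_filter(signal, kernel):
    """Centered correlation of a 1-D signal with implicit zero padding."""
    width = len(kernel)
    half = width // 2
    out = []
    for t in range(len(signal)):
        acc = 0.0
        for k in range(width):
            src = t + k - half
            # Frames outside the signal count as zero and are skipped.
            if 0 <= src < len(signal):
                acc += kernel[k] * signal[src]
        out.append(acc)
    return out


def temporal_filter_chunked(signal, kernel, chunk_size):
    """Same filter, applied chunk by chunk with overlapping context."""
    pad = max(len(kernel) - 1, 0)  # conservative overlap (assumption)
    nframes = len(signal)
    out = []
    for start in range(0, nframes, chunk_size):
        stop = min(start + chunk_size, nframes)
        lo = max(start - pad, 0)       # clamp padding at the first frame
        hi = min(stop + pad, nframes)  # clamp padding at the last frame
        filtered = temporal_filter(signal[lo:hi], kernel)
        # Keep only the frames that belong to this chunk.
        out.extend(filtered[start - lo:stop - lo])
    return out
```

Because every kept frame sees exactly the neighbours it would see in the full signal, the chunked output matches the unchunked one for any `chunk_size >= 1` in this toy setting.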
    # The delay shifting uses indices from -(tdxc-1) to T-tdxc where
    # tdxc = ceil(T/2). The maximum absolute shift is T-1.
    return max_filter_temporal_width - 1 if max_filter_temporal_width > 0 else 0
In practice, the filter width is negligible compared to the stimulus duration, so the extra padding imposes little overhead.
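To put a number on the overhead, here is a hedged sketch. `compute_temporal_pad` mirrors the return expression quoted above but is a hypothetical reimplementation; it assumes each filter is a dict carrying a `filter_temporal_width` entry. `frames_processed` is an illustrative helper, not part of the PR, that counts how many frames are touched once every chunk carries its padding.

```python
def compute_temporal_pad(filters):
    """Conservative overlap: widest temporal filter minus one frame."""
    # Assumption: each filter is a dict with a 'filter_temporal_width' key.
    widths = [int(f["filter_temporal_width"]) for f in filters]
    max_width = max(widths, default=0)
    return max_width - 1 if max_width > 0 else 0


def frames_processed(nframes, chunk_size, pad):
    """Total frames touched when each chunk carries `pad` frames per side."""
    total = 0
    for start in range(0, nframes, chunk_size):
        stop = min(start + chunk_size, nframes)
        # Padding is clamped at the stimulus boundaries.
        total += min(stop + pad, nframes) - max(start - pad, 0)
    return total


pad = compute_temporal_pad([{"filter_temporal_width": 10},
                            {"filter_temporal_width": 7}])
total = frames_processed(nframes=9000, chunk_size=1000, pad=pad)
overhead = (total - 9000) / 9000
```

With these made-up numbers the pad is 9 frames and 9144 frames are processed instead of 9000, an overhead of 1.6%. The extra work scales roughly as `2 * pad / chunk_size`, so it grows if the chunk size approaches the filter width.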
    # Build gabor filter banks for this batch
    sg_sin, sg_cos, tg_sin, tg_cos = mk_3d_gabor_batched(vhsize,
                                                         batch_filters)
    # sg_sin: (B, npixels), tg_sin: (B, T)
Fixed by swapping the loop order so filter batches are outermost and temporal chunks are innermost. The gabor filter banks and masks are now computed once per filter batch and reused across all temporal stimulus chunks.
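The loop order described in this reply can be sketched as follows. This is a toy illustration, not the PR's code: `build_filter` and `project_sketch` are hypothetical names, each "filter" is reduced to a scalar gain, each "frame" to a scalar, and temporal padding is left out so that only the loop nesting is visible.

```python
def build_filter(spec):
    """Stand-in for an expensive filter construction (here just a gain)."""
    return float(spec["gain"])


def project_sketch(stimulus, filters, filter_batch_size, stimulus_batch_size):
    """Return (responses, number of times a filter bank was built)."""
    nframes, nfilters = len(stimulus), len(filters)
    responses = [[0.0] * nfilters for _ in range(nframes)]
    builds = 0
    for f0 in range(0, nfilters, filter_batch_size):
        # Outer loop: build the filter bank once per filter batch.
        bank = [build_filter(s) for s in filters[f0:f0 + filter_batch_size]]
        builds += 1
        for t0 in range(0, nframes, stimulus_batch_size):
            # Inner loop: reuse the same bank for every temporal chunk.
            chunk = stimulus[t0:t0 + stimulus_batch_size]
            for dt, frame in enumerate(chunk):
                for df, gain in enumerate(bank):
                    responses[t0 + dt][f0 + df] = gain * frame
    return responses, builds
```

With filter batches outermost, the number of bank constructions depends only on the number of filter batches, not on how many temporal chunks the stimulus is split into.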
    @pytest.mark.parametrize("stim_batch", [5, 10, 15, 25, 50])
    def test_various_stimulus_batch_sizes(self, stim_batch):
        """Different stimulus_batch_size values match unbatched result."""
        set_backend("numpy")
        stimulus = make_test_stimulus(nimages=50)
        pyramid = moten.pyramids.MotionEnergyPyramid(**SMALL_PYRAMID_KWARGS)

        ref = pyramid.project_stimulus_batched(stimulus)
        result = pyramid.project_stimulus_batched(
            stimulus, stimulus_batch_size=stim_batch)

        np.testing.assert_allclose(
            result, ref, atol=1e-5, rtol=1e-5,
            err_msg=f"stimulus_batch_size={stim_batch} mismatch")
Added a torch-backed parametrized test test_various_stimulus_batch_sizes_torch guarded by has_torch() in commit e8968be.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
…g gabor filters per temporal chunk
- Correct the comment/docstring in `_compute_temporal_pad`: the maximum temporal shift is `T // 2`, not `T - 1`; the `T - 1` pad is kept as a deliberately conservative choice (per review discussion).
- Guard against a zero step in the temporal chunk loop when the stimulus is empty, and use explicit reshape sizes so an empty stimulus returns a `(0, nfilters)` response.
- Reword misleading comment in `test_stimulus_batch_size_one` (padding provides full temporal context).
- Add tests for odd/even `filter_temporal_width` parities and for the empty-stimulus case.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
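Two points from this commit message can be checked with a short sketch. Both helpers are hypothetical and only illustrate the reasoning. `max_temporal_shift` enumerates the delay indices quoted in the review (from `-(tdxc - 1)` to `T - tdxc` inclusive, with `tdxc = ceil(T / 2)`), and `chunk_bounds` shows one plausible way a zero step could arise (an empty stimulus with no batch size given) and how clamping the step avoids it.

```python
def max_temporal_shift(width):
    """Largest absolute delay for a temporal filter of `width` frames."""
    tdxc = -(-width // 2)  # ceil(width / 2) using integer arithmetic
    shifts = range(-(tdxc - 1), width - tdxc + 1)
    return max(abs(s) for s in shifts)


def chunk_bounds(nframes, stimulus_batch_size=None):
    """(start, stop) pairs covering `nframes`, safe for an empty stimulus."""
    # Assumption: a missing batch size means "all frames in one chunk".
    step = nframes if stimulus_batch_size is None else stimulus_batch_size
    step = max(int(step), 1)  # range() raises ValueError for a step of 0
    return [(start, min(start + step, nframes))
            for start in range(0, nframes, step)]
```

For every width from 1 upward the enumerated maximum equals `width // 2`, which is why a pad of `width - 1` is conservative rather than tight. `chunk_bounds(0)` returns an empty list instead of raising.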
LGTM
The current batching only works over filters and doesn't help if the stimulus is long and your GPU is small. This PR adds an optional argument to the batched stimulus projection method to break the stimulus up into batches.
Changes
- `stimulus_batch_size` argument to `project_stimulus_batched` to specify frame batch size. `None` (default) processes all frames at once.
- `_compute_temporal_pad(filters)` to compute filter frame batch padding. Because the motion-energy filters are temporal, naively splitting the stimulus tensor will introduce edge artifacts. This method computes the necessary padding size given the filter temporal parameters.