From f943d7c848fb84646b6d54b0cb96908063ed5daf Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Sun, 10 May 2026 12:30:07 -0700 Subject: [PATCH 1/2] Treat arg-cache length mismatch as a cache miss in ChunkSizeTuner I encountered this issue while working on https://github.com/aqlaboratory/openfold-3/pull/213, which enables `batch > 1` and chunking to happen together. When you run inference with samples both above and below `per_sample_token_cutoff`, the small inputs use the batched/normal pairformer embedding path and flatten their batch dimensions (i.e. batch and num_samples), but the larger inputs follow the per-sample path, which *doesn't* flatten its batch dimensions and results in higher-rank inputs (https://github.com/aqlaboratory/openfold-3/pull/223 aims to fix this mismatch). `_compare_arg_caches` recurses on tensor shapes (`torch.Size` is a tuple subclass), so when the same `ChunkSizeTuner` instance is invoked with tensors of different ranks across calls, the inner `zip(..., strict=True)` raised a `ValueError`. Treat any length mismatch as a cache miss instead, so that the caller re-tunes. The top-level length assert in `tune_chunk_size` is removed: a top-level length mismatch is now handled the same way, which makes the assert redundant. Also add the dtype element size to the cache key for each tensor argument. We could use the entire dtype, but I think the element size is what matters for chunking. 
--- openfold3/core/utils/chunk_utils.py | 6 ++- openfold3/tests/utils/test_utils.py | 75 +++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/openfold3/core/utils/chunk_utils.py b/openfold3/core/utils/chunk_utils.py index 4037d519d..bf57fb591 100644 --- a/openfold3/core/utils/chunk_utils.py +++ b/openfold3/core/utils/chunk_utils.py @@ -388,6 +388,9 @@ def test_chunk_size(chunk_size): return candidates[lo] def _compare_arg_caches(self, ac1, ac2): + # When recursing this tests that tensors have the same rank + if len(ac1) != len(ac2): + return False consistent = True for a1, a2 in zip(ac1, ac2, strict=True): assert type(a1) is type(a2) @@ -412,12 +415,11 @@ def tune_chunk_size( max_chunk_size=DEFAULT_MAX_CHUNK_SIZE, ) -> int: def remove_tensors(a): - return a.shape if type(a) is torch.Tensor else a + return (a.shape, a.dtype.itemsize) if type(a) is torch.Tensor else a arg_data = tree_map(remove_tensors, args, object) if self.cached_arg_data is not None: # If args have changed shape/value, we need to re-tune - assert len(self.cached_arg_data) == len(arg_data) consistent = self._compare_arg_caches(self.cached_arg_data, arg_data) else: # Otherwise, we can reuse the precomputed value diff --git a/openfold3/tests/utils/test_utils.py b/openfold3/tests/utils/test_utils.py index ed3907dbf..8814041e5 100644 --- a/openfold3/tests/utils/test_utils.py +++ b/openfold3/tests/utils/test_utils.py @@ -244,3 +244,78 @@ def fn(arg, chunk_size, _max=max_viable): fn, args=(None,), min_chunk_size=4, max_chunk_size=1024 ) self.assertEqual(result, expected) + + def test_chunk_size_tuner_handles_arg_rank_change(self): + tuner = ChunkSizeTuner() + + def fn(t, chunk_size): + if chunk_size > 2 ** t.dim() * t.dtype.itemsize: + raise RuntimeError("Chunk size too large") + return t + + first = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5),), + min_chunk_size=4, + max_chunk_size=256, + ) + second = tuner.tune_chunk_size( + 
representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5, 6),), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertNotEqual( + first, second, "Chunk size should have been re-tuned for new arg rank" + ) + + def test_chunk_size_tuner_handles_dtype_bytes_change(self): + tuner = ChunkSizeTuner() + + def fn(t, chunk_size): + if chunk_size > 2 ** t.dim() * t.dtype.itemsize: + raise RuntimeError("Chunk size too large") + return t + + first = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5, dtype=torch.float32),), + min_chunk_size=4, + max_chunk_size=256, + ) + second = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5, dtype=torch.bfloat16),), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertNotEqual( + first, second, "Chunk size should have been re-tuned for new dtype bytes" + ) + + def test_chunk_size_tuner_handles_arg_count_change(self): + tuner = ChunkSizeTuner() + + def fn(*args, chunk_size): + if chunk_size > 2 ** len(args): + raise RuntimeError("Chunk size too large") + return args + + first = tuner.tune_chunk_size( + representative_fn=fn, + args=(1, 2, 3, 4, 5), + min_chunk_size=4, + max_chunk_size=256, + ) + second = tuner.tune_chunk_size( + representative_fn=fn, + args=(1, 2, 3, 4, 5, 6), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertNotEqual( + first, second, "Chunk size should have been re-tuned for new arg count" + ) From 61b4ec09f7f3e3ac3acc3acae6cc3c2b6342fa71 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Mon, 1 Jun 2026 10:59:03 -0700 Subject: [PATCH 2/2] Add chunk size tuner positive caching test --- openfold3/tests/utils/test_utils.py | 36 +++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/openfold3/tests/utils/test_utils.py b/openfold3/tests/utils/test_utils.py index 8814041e5..5dd658dff 100644 --- a/openfold3/tests/utils/test_utils.py +++ b/openfold3/tests/utils/test_utils.py @@ -197,6 +197,42 @@ def 
test_chunk_slice_dict(self): self.assertTrue(torch.all(chunked == chunked_flattened)) + def test_chunk_size_tuner_caches(self): + tuner = ChunkSizeTuner() + + def fn(t, chunk_size): + if chunk_size > 2 ** t.dim() * t.dtype.itemsize: + raise RuntimeError("Chunk size too large") + return t + + spy_fn = unittest.mock.Mock(side_effect=fn) + + first = tuner.tune_chunk_size( + representative_fn=spy_fn, + args=(torch.randn(2, 3, 4, 5),), + min_chunk_size=4, + max_chunk_size=256, + ) + + first_call_count = spy_fn.call_count + second = tuner.tune_chunk_size( + representative_fn=spy_fn, + args=(torch.randn(2, 3, 4, 5),), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertEqual( + first, + second, + "Chunk size should have been cached for identical arg shapes and dtypes", + ) + self.assertEqual( + first_call_count, + spy_fn.call_count, + "Representative function should not have been called again for identical arg shapes and dtypes", + ) + def test_chunk_size_tuner_does_not_retest_candidates(self): # Based on previous bug: the binary search forgot which candidates it # had already proven non-viable and re-tested them.