diff --git a/openfold3/core/utils/chunk_utils.py b/openfold3/core/utils/chunk_utils.py index 4037d519d..bf57fb591 100644 --- a/openfold3/core/utils/chunk_utils.py +++ b/openfold3/core/utils/chunk_utils.py @@ -388,6 +388,9 @@ def test_chunk_size(chunk_size): return candidates[lo] def _compare_arg_caches(self, ac1, ac2): + # When recursing this tests that tensors have the same rank + if len(ac1) != len(ac2): + return False consistent = True for a1, a2 in zip(ac1, ac2, strict=True): assert type(a1) is type(a2) @@ -412,12 +415,11 @@ def tune_chunk_size( max_chunk_size=DEFAULT_MAX_CHUNK_SIZE, ) -> int: def remove_tensors(a): - return a.shape if type(a) is torch.Tensor else a + return (a.shape, a.dtype.itemsize) if type(a) is torch.Tensor else a arg_data = tree_map(remove_tensors, args, object) if self.cached_arg_data is not None: # If args have changed shape/value, we need to re-tune - assert len(self.cached_arg_data) == len(arg_data) consistent = self._compare_arg_caches(self.cached_arg_data, arg_data) else: # Otherwise, we can reuse the precomputed value diff --git a/openfold3/tests/utils/test_utils.py b/openfold3/tests/utils/test_utils.py index ed3907dbf..5dd658dff 100644 --- a/openfold3/tests/utils/test_utils.py +++ b/openfold3/tests/utils/test_utils.py @@ -197,6 +197,42 @@ def test_chunk_slice_dict(self): self.assertTrue(torch.all(chunked == chunked_flattened)) + def test_chunk_size_tuner_caches(self): + tuner = ChunkSizeTuner() + + def fn(t, chunk_size): + if chunk_size > 2 ** t.dim() * t.dtype.itemsize: + raise RuntimeError("Chunk size too large") + return t + + spy_fn = unittest.mock.Mock(side_effect=fn) + + first = tuner.tune_chunk_size( + representative_fn=spy_fn, + args=(torch.randn(2, 3, 4, 5),), + min_chunk_size=4, + max_chunk_size=256, + ) + + first_call_count = spy_fn.call_count + second = tuner.tune_chunk_size( + representative_fn=spy_fn, + args=(torch.randn(2, 3, 4, 5),), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertEqual( + first, + second, + "Chunk size should have been cached for identical arg shapes and dtypes", + ) + self.assertEqual( + first_call_count, + spy_fn.call_count, + "Representative function should not have been called again for identical arg shapes and dtypes", + ) + def test_chunk_size_tuner_does_not_retest_candidates(self): # Based on previous bug: the binary search forgot which candidates it # had already proven non-viable and re-tested them. @@ -244,3 +280,78 @@ def fn(arg, chunk_size, _max=max_viable): fn, args=(None,), min_chunk_size=4, max_chunk_size=1024 ) self.assertEqual(result, expected) + + def test_chunk_size_tuner_handles_arg_rank_change(self): + tuner = ChunkSizeTuner() + + def fn(t, chunk_size): + if chunk_size > 2 ** t.dim() * t.dtype.itemsize: + raise RuntimeError("Chunk size too large") + return t + + first = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5),), + min_chunk_size=4, + max_chunk_size=256, + ) + second = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5, 6),), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertNotEqual( + first, second, "Chunk size should have been re-tuned for new arg rank" + ) + + def test_chunk_size_tuner_handles_dtype_bytes_change(self): + tuner = ChunkSizeTuner() + + def fn(t, chunk_size): + if chunk_size > 2 ** t.dim() * t.dtype.itemsize: + raise RuntimeError("Chunk size too large") + return t + + first = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5, dtype=torch.float32),), + min_chunk_size=4, + max_chunk_size=256, + ) + second = tuner.tune_chunk_size( + representative_fn=fn, + args=(torch.zeros(2, 3, 4, 5, dtype=torch.bfloat16),), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertNotEqual( + first, second, "Chunk size should have been re-tuned for new dtype bytes" + ) + + def test_chunk_size_tuner_handles_arg_count_change(self): + tuner = ChunkSizeTuner() + + def fn(*args, chunk_size): + if chunk_size > 2 ** len(args): + raise RuntimeError("Chunk size too large") + return args + + first = tuner.tune_chunk_size( + representative_fn=fn, + args=(1, 2, 3, 4, 5), + min_chunk_size=4, + max_chunk_size=256, + ) + second = tuner.tune_chunk_size( + representative_fn=fn, + args=(1, 2, 3, 4, 5, 6), + min_chunk_size=4, + max_chunk_size=256, + ) + + self.assertNotEqual( + first, second, "Chunk size should have been re-tuned for new arg count" + )