diff --git a/src/psyclone/domain/lfric/transformations/lfric_loop_fuse_trans.py b/src/psyclone/domain/lfric/transformations/lfric_loop_fuse_trans.py index 488e64f2ee..98db3a8d96 100644 --- a/src/psyclone/domain/lfric/transformations/lfric_loop_fuse_trans.py +++ b/src/psyclone/domain/lfric/transformations/lfric_loop_fuse_trans.py @@ -168,7 +168,6 @@ def validate(self, node1: LFRicLoop, node2: LFRicLoop, # 2.2) If 'same_space' is true check that both function spaces are # the same or that at least one of the nodes is on ANY_SPACE. The # former case is convenient when loop fusion is applied generically. - if same_space: if node1_fs_name == node2_fs_name: pass @@ -182,12 +181,15 @@ def validate(self, node1: LFRicLoop, node2: LFRicLoop, # 2.3.1) Check whether one or more of the function spaces # is ANY_SPACE without the 'same_space' flag if node_on_any_space: - raise TransformationError( - f"Error in {self.name} transformation: One or more of the " - f"iteration spaces is unknown ('ANY_SPACE') so loop fusion" - f" might be invalid. If you know the spaces are the same " - f"then please set the 'same_space' optional argument to " - f"'True'.") + # If the nodes are on ANY_SPACE, but those are the same + # space, we can fuse. + if node1_fs_name != node2_fs_name: + raise TransformationError( + f"Error in {self.name} transformation: One or more of " + f"the iteration spaces is unknown ('ANY_SPACE') so " + f"loop fusion might be invalid. If you know the " + f"spaces are the same then please set the " + f"'same_space' optional argument to 'True'.") # 2.3.2) Check whether specific function spaces are the # same. If they are not, the loop fusion is still possible # but only when both function spaces are discontinuous diff --git a/src/psyclone/tests/domain/lfric/transformations/lfric_transformations_test.py b/src/psyclone/tests/domain/lfric/transformations/lfric_transformations_test.py index 1619d1eac7..e13ec177e8 100644 --- a/src/psyclone/tests/domain/lfric/transformations/lfric_transformations_test.py +++ b/src/psyclone/tests/domain/lfric/transformations/lfric_transformations_test.py @@ -3204,18 +3204,53 @@ def test_multi_builtins_fuse_error(): "reduction") in str(excinfo.value) -def test_loop_fuse_error(dist_mem): - '''Test that we raise an exception in loop fusion if one or more of - the loops has an any_space iteration space.''' - _, invoke = get_invoke("15.14.2_multiple_set_kernels.f90", - TEST_API, idx=0, dist_mem=dist_mem) +def test_loop_fuse_any_space(tmpdir, dist_mem): + '''Test that we correctly fuse two or more of loops that are on + any_space iteration space.''' + psy, invoke = get_invoke("15.14.2_multiple_set_kernels.f90", + TEST_API, idx=0, dist_mem=dist_mem) schedule = invoke.schedule ftrans = LFRicLoopFuseTrans() - with pytest.raises(TransformationError) as excinfo: - ftrans.apply(schedule.children[0], schedule.children[1]) - assert ("One or more of the iteration spaces is unknown " - "('ANY_SPACE') so loop fusion might be " - "invalid") in str(excinfo.value) + + # Fuses the first two loops + ftrans.apply(schedule.children[0], schedule.children[1]) + code = str(psy.gen) + assert ( + "do df = loop0_start, loop0_stop, 1\n" + " ! Built-in: setval_c (set a real-valued field to a real scalar " + "value)\n" + " f1_data(df) = fred\n" + "\n" + " ! Built-in: setval_c (set a real-valued field to a real scalar " + "value)\n" + " f2_data(df) = 3.0_r_def\n" + " enddo\n") in code + assert ( + "do df = loop1_start, loop1_stop, 1\n" + " ! Built-in: setval_c (set a real-valued field to a real scalar " + "value)\n" + " f3_data(df) = ginger\n" + " enddo\n") in code + + # Fuses the combined loop with the third loop + ftrans.apply(schedule.children[0], schedule.children[1]) + code = str(psy.gen) + assert ( + "do df = loop0_start, loop0_stop, 1\n" + " ! Built-in: setval_c (set a real-valued field to a real scalar " + "value)\n" + " f1_data(df) = fred\n" + "\n" + " ! Built-in: setval_c (set a real-valued field to a real scalar " + "value)\n" + " f2_data(df) = 3.0_r_def\n" + "\n" + " ! Built-in: setval_c (set a real-valued field to a real scalar " + "value)\n" + " f3_data(df) = ginger\n" + " enddo\n") in code + + assert LFRicBuild(tmpdir).code_compiles(psy) # Repeat the reduction tests for the reproducible version