Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def validate(self, node1: LFRicLoop, node2: LFRicLoop,
# 2.2) If 'same_space' is true check that both function spaces are
# the same or that at least one of the nodes is on ANY_SPACE. The
# former case is convenient when loop fusion is applied generically.

if same_space:
if node1_fs_name == node2_fs_name:
pass
Expand All @@ -182,12 +181,15 @@ def validate(self, node1: LFRicLoop, node2: LFRicLoop,
# 2.3.1) Check whether one or more of the function spaces
# is ANY_SPACE without the 'same_space' flag
if node_on_any_space:
raise TransformationError(
f"Error in {self.name} transformation: One or more of the "
f"iteration spaces is unknown ('ANY_SPACE') so loop fusion"
f" might be invalid. If you know the spaces are the same "
f"then please set the 'same_space' optional argument to "
f"'True'.")
# If the nodes are on ANY_SPACE, but those are the same
# space, we can fuse.
if node1_fs_name != node2_fs_name:
raise TransformationError(
f"Error in {self.name} transformation: One or more of "
f"the iteration spaces is unknown ('ANY_SPACE') so "
f"loop fusion might be invalid. If you know the "
f"spaces are the same then please set the "
f"'same_space' optional argument to 'True'.")
# 2.3.2) Check whether specific function spaces are the
# same. If they are not, the loop fusion is still possible
# but only when both function spaces are discontinuous
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3204,18 +3204,53 @@ def test_multi_builtins_fuse_error():
"reduction") in str(excinfo.value)


def test_loop_fuse_error(dist_mem):
'''Test that we raise an exception in loop fusion if one or more of
the loops has an any_space iteration space.'''
_, invoke = get_invoke("15.14.2_multiple_set_kernels.f90",
TEST_API, idx=0, dist_mem=dist_mem)
def test_loop_fuse_any_space(tmpdir, dist_mem):
'''Test that we correctly fuse two or more of loops that are on
any_space iteration space.'''
psy, invoke = get_invoke("15.14.2_multiple_set_kernels.f90",
TEST_API, idx=0, dist_mem=dist_mem)
schedule = invoke.schedule
ftrans = LFRicLoopFuseTrans()
with pytest.raises(TransformationError) as excinfo:
ftrans.apply(schedule.children[0], schedule.children[1])
assert ("One or more of the iteration spaces is unknown "
"('ANY_SPACE') so loop fusion might be "
"invalid") in str(excinfo.value)

# Fuses the first two loops
ftrans.apply(schedule.children[0], schedule.children[1])
code = str(psy.gen)
assert (
"do df = loop0_start, loop0_stop, 1\n"
" ! Built-in: setval_c (set a real-valued field to a real scalar "
"value)\n"
" f1_data(df) = fred\n"
"\n"
" ! Built-in: setval_c (set a real-valued field to a real scalar "
"value)\n"
" f2_data(df) = 3.0_r_def\n"
" enddo\n") in code
assert (
"do df = loop1_start, loop1_stop, 1\n"
" ! Built-in: setval_c (set a real-valued field to a real scalar "
"value)\n"
" f3_data(df) = ginger\n"
" enddo\n") in code

# Fuses the combined loop with the third loop
ftrans.apply(schedule.children[0], schedule.children[1])
code = str(psy.gen)
assert (
"do df = loop0_start, loop0_stop, 1\n"
" ! Built-in: setval_c (set a real-valued field to a real scalar "
"value)\n"
" f1_data(df) = fred\n"
"\n"
" ! Built-in: setval_c (set a real-valued field to a real scalar "
"value)\n"
" f2_data(df) = 3.0_r_def\n"
"\n"
" ! Built-in: setval_c (set a real-valued field to a real scalar "
"value)\n"
" f3_data(df) = ginger\n"
" enddo\n") in code

assert LFRicBuild(tmpdir).code_compiles(psy)


# Repeat the reduction tests for the reproducible version
Expand Down
Loading