diff --git a/examples/nemo/scripts/insert_loop_parallelism.py_ b/examples/nemo/scripts/insert_loop_parallelism.py_ new file mode 100755 index 0000000000..035e6e0714 --- /dev/null +++ b/examples/nemo/scripts/insert_loop_parallelism.py_ @@ -0,0 +1,345 @@ +#!/usr/bin/env python +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2021-2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Authors: S. Siso, STFC Daresbury Lab + +''' PSyclone transformation script showing the introduction of OpenMP for GPU +directives into Nemo code. ''' + +import os +import sys +from utils import ( + add_profiling, inline_calls, insert_explicit_loop_parallelism, + normalise_loops, NEMO_MODULES_TO_IMPORT) +from psyclone.psyir.nodes import Routine, Loop +from psyclone.psyir.transformations import ( + OMPTargetTrans, OMPDeclareTargetTrans) +from psyclone.transformations import ( + OMPLoopTrans, TransformationError) +from psyclone.transformations import ( + ACCParallelTrans, ACCLoopTrans, ACCRoutineTrans) + + +# This environment variable informs if this is targeting NEMOv4 +NEMOV4 = os.environ.get('NEMOV4', False) + +# This environment variable informs which parallelisation directives to use +# It supports acc_offloading, omp_offloading and omp_threading +# They can be combined, e.g PARALLEL_DIRECTIVES='omp_offloading+omp_threading', +# or use none to just apply the serial transformations +PARALLEL_DIRECTIVES = os.environ.get('PARALLEL_DIRECTIVES', '') + +# By default, allow optimisations that may change the results, e.g. reductions, +# offloading intrinsics without math_uniform, ... +REPRODUCIBLE = os.environ.get('REPRODUCIBLE', False) + +# This environment variable informs if profiling hooks have to be inserted. +PROFILING_ENABLED = os.environ.get('ENABLE_PROFILING', False) + +# By default, we don't do module inlining as it's still under development. +INLINING_ENABLED = os.environ.get('ENABLE_INLINING', False) + +# This environment variable informs if we're enabling asynchronous +# parallelism. +ASYNC_PARALLEL = os.environ.get('ASYNC_PARALLEL', False) + +# Whether to chase the imported modules to improve symbol information (it can +# also be a list of module filenames to limit the chasing to only specific +# modules). This has to be used in combination with '-I' command flag in order +# to point to the module location directory. We also strongly recommend using +# the '--enable-cache' flag to reduce the performance overhead. +RESOLVE_IMPORTS = NEMO_MODULES_TO_IMPORT + +# List of all files that psyclone will skip processing +FILES_TO_SKIP = [] + +# There files are skipped because transforming them degrade the performance +SKIP_FOR_PERFORMANCE = [ + "iom.f90", + "iom_nf90.f90", + "iom_def.f90", + "timing.f90", + "histcom.f90", +] + +# These files change the results from the baseline when psyclone adds +# parallelisation directives +PARALLELISATION_ISSUES = [] + +# These files change the results from the baseline when psyclone adds +# offloading directives +OFFLOADING_ISSUES = [] + +if not NEMOV4: + FILES_TO_SKIP.extend([ + # Fail in nvfortran when enabling seaice + "icefrm.f90", # Has unsupported implicit symbol declaration + ]) + + SKIP_FOR_PERFORMANCE.extend([ + "lbclnk.f90", + ]) + + PARALLELISATION_ISSUES.extend([ + "ldfc1d_c2d.f90", + "tramle.f90", + "traqsr.f90", + ]) + + OFFLOADING_ISSUES.extend([ + # Produces different output results + "zdftke.f90", + # The following issues only affect BENCH (because ice is enabled?) + # Runtime Error: Illegal address during kernel execution + "trcrad.f90", + # nvhpc > 24.11 - Signal 11 issues + "icerst.f90", # When enabling ice* parallelisation + "trcbbl.f90", + "trabbc.f90", + "bdyice.f90", + "sedfunc.f90", + "stpmlf.f90", + "trddyn.f90", + "trczdf.f90", + "trcice_pisces.f90", + "dtatsd.f90", + "trcatf.f90", + "stp2d.f90", + ]) + + # if "acc_offloading" in PARALLEL_DIRECTIVES: + # OFFLOADING_ISSUES.extend([ + # # Fail in OpenACC ORCA2_ICE_PISCES + # "dynzdf.f90", + # "trabbl.f90", + # "trazdf.f90", + # "zdfsh2.f90", + # ]) + +ASYNC_ISSUES = [ + # Runtime Error: (CUDA_ERROR_LAUNCH_FAILED): Launch failed + # (often invalid pointer dereference) in get_cstrgsurf + "sbcclo.f90", + "trcldf.f90", + # Runtime Error: Illegal address during kernel execution with + # asynchronicity. + "zdfiwm.f90", + "zdfsh2.f90", + # Diverging results with asynchronicity + "traadv_fct.f90", + "bdy_oce.f90", +] + + +def select_transformations(): + ''' + Use the PARALLEL_DIRECTIVES global to select what specific transformations + to apply to insert the desired directives. + ''' + process_directives = PARALLEL_DIRECTIVES + + if 'omp_offloading' in process_directives: + offload_region_trans = OMPTargetTrans() + mark_for_gpu_trans = OMPDeclareTargetTrans() + if NEMOV4: + # TODO #2895: Explore why loop/teams loop diverge for NEMOv4 + gpu_loop_trans = OMPLoopTrans(omp_schedule="none") + gpu_loop_trans.omp_directive = "loop" + else: + gpu_loop_trans = OMPLoopTrans(omp_schedule="none") + gpu_loop_trans.omp_directive = "teamsloop" + process_directives = process_directives.replace('omp_offloading', '') + elif 'acc_offloading' in process_directives: + offload_region_trans = ACCParallelTrans(default_present=False) + mark_for_gpu_trans = ACCRoutineTrans() + gpu_loop_trans = ACCLoopTrans() + process_directives = process_directives.replace('acc_offloading', '') + else: + offload_region_trans = None + mark_for_gpu_trans = None + gpu_loop_trans = None + + if 'omp_threading' in process_directives: + cpu_loop_trans = OMPLoopTrans(omp_schedule="static") + cpu_loop_trans.omp_directive = "paralleldo" + process_directives = process_directives.replace('omp_threading', '') + else: + cpu_loop_trans = None + + process_directives = process_directives.replace('+', '') + if process_directives != '': + sys.exit(f"Unknown PARALLEL_DIRECTIVES: {process_directives}") + + return (offload_region_trans, mark_for_gpu_trans, + gpu_loop_trans, cpu_loop_trans) + + +def filter_files_by_name(name: str) -> bool: + ''' + :returns: whether to transform a file with the given name. Contrary to + FILES_TO_SKIP, this will still run the files through psyclone. + ''' + # The two options below are useful for file-by-file exhaustive tests. + # If the environment has ONLY_FILE defined, only process that one file and + # known-good files that need a "declare target" inside. + only_file = os.environ.get('ONLY_FILE', False) + if only_file: + files_to_do = [only_file] + if "offloading" in PARALLEL_DIRECTIVES: + files_to_do.extend( + ["lib_fortran.f90", "solfrac_mod.f90", "sbc_phy.f90"]) + if name not in files_to_do: + return True + # If the environment has ALL_BUT_FILE defined, process all files but + # the one named file. + all_but_file = os.environ.get('ALL_BUT_FILE', False) + if all_but_file and name == all_but_file: + return True + + # These work but are skipped to improve performance, they could be in the + # FILES_TO_SKIP global parameter, but in this script, for testing purposes, + # we exclude them here so the PSyclone frontend and backend are still + # tested and it also allows to insert profiling hooks later on. + if name in SKIP_FOR_PERFORMANCE: + return True + + # Parallelising ICE or ICB currently causes a noticeable slowdown + # On nemo_main it can be just: if name.startswith("icethd"): + if not NEMOV4 and name.startswith("ice"): + return True + if name.startswith("icb"): + return True + + # This file fails for gcc NEMOv5 BENCH + if not NEMOV4 and name == "icedyn_rhg_evp.f90": + return True + + return False + + +def trans(psyir): + ''' Normalise and add directives to all possible loops, including the + implicit ones. + + :param psyir: the PSyIR of the provided file. + :type psyir: :py:class:`psyclone.psyir.nodes.FileContainer` + + ''' + if filter_files_by_name(psyir.name): + return + + (offload_region_trans, mark_for_gpu_trans, gpu_loop_trans, + cpu_loop_trans) = select_transformations() + + disable_profiling_for = [] + enable_async = ASYNC_PARALLEL and psyir.name not in ASYNC_ISSUES + privatise_arrays = not (NEMOV4 or "acc" in PARALLEL_DIRECTIVES) + + for subroutine in psyir.walk(Routine): + + # Skip initialisation and diagnostic subroutines + if (subroutine.name.endswith('_alloc') or + subroutine.name.endswith('_init') or + subroutine.name.startswith('init_') or + subroutine.name.startswith('Agrif') or + subroutine.name.startswith('dia_') or + subroutine.name == 'dom_msk' or + subroutine.name == 'dom_zgr' or + subroutine.name == 'dom_ngb'): + continue + + normalise_loops( + subroutine, + hoist_local_arrays=False, + convert_array_notation=True, + # See issue #3022 + loopify_array_intrinsics=psyir.name != "getincom.f90", + convert_range_loops=True, + increase_array_ranks=not NEMOV4, + hoist_expressions=True + ) + + # Perform module-inlining of called routines. + if INLINING_ENABLED: + inline_calls(subroutine) + + # These are functions that are called from inside parallel regions, + # annotate them with 'omp declare target' + if ( + mark_for_gpu_trans and + (subroutine.name.lower().startswith("sign_") + or subroutine.name.lower() == "solfrac" + or (psyir.name == "sbc_phy.f90" and not subroutine.walk(Loop))) + ): + try: + mark_for_gpu_trans.apply(subroutine) + print(f"Marked {subroutine.name} as GPU-enabled") + except TransformationError as err: + print(err) + # We continue parallelising inside the routine, but this could + # change if the parallelisation directives added below are not + # nestable, in that case we could add a 'continue' here + disable_profiling_for.append(subroutine.name) + + elif (psyir.name not in PARALLELISATION_ISSUES + OFFLOADING_ISSUES + and gpu_loop_trans): + print( + f"Adding offload directives to subroutine: {subroutine.name}") + insert_explicit_loop_parallelism( + subroutine, + region_directive_trans=offload_region_trans, + loop_directive_trans=gpu_loop_trans, + collapse=True, + privatise_arrays=privatise_arrays, + enable_reductions=not REPRODUCIBLE, + uniform_intrinsics_only=REPRODUCIBLE, + asynchronous_parallelism=enable_async, + ) + elif psyir.name not in PARALLELISATION_ISSUES and cpu_loop_trans: + # These have issues offloading, but we can still do threading + print(f"Adding OpenMP threading to subroutine: {subroutine.name}") + insert_explicit_loop_parallelism( + subroutine, + loop_directive_trans=cpu_loop_trans, + collapse=False, + privatise_arrays=privatise_arrays, + enable_reductions=not REPRODUCIBLE, + asynchronous_parallelism=enable_async, + ) + + # Iterate again and add profiling hooks when needed + for subroutine in psyir.walk(Routine): + if PROFILING_ENABLED and subroutine.name not in disable_profiling_for: + print(f"Adding profiling hooks to subroutine: {subroutine.name}") + add_profiling(subroutine.children) diff --git a/external/fparser b/external/fparser index d93d18de65..e557461d10 160000 --- a/external/fparser +++ b/external/fparser @@ -1 +1 @@ -Subproject commit d93d18de6526a09b5bb301d55d3ba6f52a55e6b9 +Subproject commit e557461d105a2bec3379a2986e580adee58d974b diff --git a/src/psyclone/psyir/backend/fortran.py b/src/psyclone/psyir/backend/fortran.py index ce4cadf042..2b7cb6e1ff 100644 --- a/src/psyclone/psyir/backend/fortran.py +++ b/src/psyclone/psyir/backend/fortran.py @@ -53,11 +53,11 @@ Literal, Member, Node, OMPDependClause, OMPReductionClause, Operation, Range, Routine, Schedule, UnaryOperation, UnknownDirective, IfBlock) from psyclone.psyir.symbols import ( - ArgumentInterface, ArrayType, ContainerSymbol, DataSymbol, DataType, - DataTypeSymbol, GenericInterfaceSymbol, IntrinsicSymbol, - PreprocessorInterface, RoutineSymbol, ScalarType, StructureType, Symbol, - SymbolTable, UnresolvedInterface, UnresolvedType, UnsupportedFortranType, - UnsupportedType, TypedSymbol) + ArgumentInterface, ArrayType, CommonBlockSymbol, ContainerSymbol, + DataSymbol, DataType, DataTypeSymbol, GenericInterfaceSymbol, + IntrinsicSymbol, PreprocessorInterface, RoutineSymbol, ScalarType, + StructureType, Symbol, SymbolTable, UnresolvedInterface, UnresolvedType, + UnsupportedFortranType, UnsupportedType, TypedSymbol) # Mapping from PSyIR types to Fortran data types. Simply reverse the @@ -513,6 +513,26 @@ def gen_use(self, symbol, symbol_table): f"{renames}\n") return f"{self._nindent}use{intrinsic_str}{symbol.name}\n" + def gen_commonblock(self, sym): + '''Create and return the Fortran COMMON statement for a + :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol`. + + The variables are listed in the order they appear in + ``sym.variables``. Array dimensions are *not* repeated here (they + appear in the variable declarations generated by + :py:meth:`gen_vardecl`), which is valid Fortran. + + :param sym: the CommonBlockSymbol to generate a declaration for. + :type sym: :py:class:`psyclone.psyir.symbols.CommonBlockSymbol` + + :returns: the Fortran COMMON statement as a string. + :rtype: str + + ''' + var_names = ", ".join(var.name for var in sym.variables) + block_id = f"/{sym.name}/" if sym.name else "//" + return f"{self._nindent}COMMON {block_id} {var_names}\n" + def gen_vardecl(self, symbol: Union[DataSymbol, Member], include_visibility: bool = False) -> str: @@ -964,7 +984,7 @@ def gen_decls(self, for sym in all_symbols[:]: # Everything that is a container or imported (because it should # already be done by the gen_use() method before) - if isinstance(sym, ContainerSymbol): + if isinstance(sym, (ContainerSymbol, CommonBlockSymbol)): all_symbols.remove(sym) continue if sym.is_import: @@ -1071,6 +1091,10 @@ def gen_decls(self, declarations += self.gen_vardecl( symbol, include_visibility=is_module_scope) + # 6: COMMON block declarations + for cb_sym in symbol_table.common_block_symbols: + declarations += self.gen_commonblock(cb_sym) + return declarations def filecontainer_node(self, node): diff --git a/src/psyclone/psyir/frontend/fparser2.py b/src/psyclone/psyir/frontend/fparser2.py index 17891534f8..d95b73e138 100644 --- a/src/psyclone/psyir/frontend/fparser2.py +++ b/src/psyclone/psyir/frontend/fparser2.py @@ -68,9 +68,9 @@ from psyclone.psyir.nodes.array_mixin import ArrayMixin from psyclone.psyir.symbols import ( ArgumentInterface, ArrayType, AutomaticInterface, ScalarType, - CommonBlockInterface, ContainerSymbol, DataSymbol, DataTypeSymbol, - DefaultModuleInterface, GenericInterfaceSymbol, ImportInterface, - NoType, RoutineSymbol, StaticInterface, + CommonBlockInterface, CommonBlockSymbol, ContainerSymbol, DataSymbol, + DataTypeSymbol, DefaultModuleInterface, GenericInterfaceSymbol, + ImportInterface, NoType, RoutineSymbol, StaticInterface, StructureType, Symbol, SymbolError, UnknownInterface, UnresolvedInterface, UnresolvedType, UnsupportedFortranType, UnsupportedType, SymbolTable) @@ -2873,12 +2873,48 @@ def _process_data_statements(nodes, psyir_parent): datatype=UnresolvedType()) sym.interface = StaticInterface() + @staticmethod + def _get_common_block_groups(node): + '''Return the COMMON-block groups in a ``Common_Stmt`` node as a list + of ``(block_name, variable_names)`` pairs. + + The *block_name* is the name of the COMMON block as a string, or + ``None`` for the blank common. The *variable_names* list contains + the bare variable names (without any array-dimension suffixes) as + strings, preserving the original case from the source. + + Example: ``COMMON /a/ x, y(3) /b/ z`` returns + ``[('a', ['x', 'y']), ('b', ['z'])]``. + + :param node: a fparser2 Common_Stmt node. + :type node: :py:class:`fparser.two.Fortran2003.Common_Stmt` + + :returns: ordered list of ``(block_name, variable_names)`` pairs. + :rtype: list[tuple[str | None, list[str]]] + ''' + groups = [] + for name, lst in node.items[0]: + block_name = str(name) if name is not None else None + var_names = [] + for item in lst.items: + # item is either a Name/StringBase (bare variable) or a + # Common_Block_Object/CallBase (variable with an explicit + # shape spec). For the latter, items[0] holds the name. + if isinstance(item, utils.CallBase): + var_names.append(str(item.items[0])) + else: + var_names.append(str(item)) + groups.append((block_name, var_names)) + return groups + @staticmethod def _process_common_blocks(nodes, psyir_parent): - ''' Process the fparser2 common block declaration statements. This is - done after the other declarations and it will keep the statement - as a UnsupportedFortranType and update the referenced symbols to a - CommonBlockInterface. + ''' Process the fparser2 common block declaration statements. + For each COMMON block found, a + :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol` is created (or + reused if already present) and each variable listed in the block has + its interface updated to + :py:class:`~psyclone.psyir.symbols.CommonBlockInterface`. :param nodes: fparser2 AST nodes containing declaration statements. :type nodes: List[:py:class:`fparser.two.utils.Base`] @@ -2896,33 +2932,29 @@ def _process_common_blocks(nodes, psyir_parent): ''' for node in nodes: if isinstance(node, Fortran2003.Common_Stmt): - # Place the declaration statement into a UnsupportedFortranType - # (for now we just want to reproduce it). The name of the - # commonblock is not in the same namespace as the variable - # symbols names (and there may be multiple of them in a - # single statement). So we use an internal symbol name. - psyir_parent.symbol_table.new_symbol( - root_name="_PSYCLONE_INTERNAL_COMMONBLOCK", - symbol_type=DataSymbol, - datatype=UnsupportedFortranType(str(node))) - - # Get the names of the symbols accessed with the commonblock, - # they are already defined in the symbol table but they must - # now have a common-block interface. + table = psyir_parent.symbol_table try: - # Loop over every COMMON block defined in this Common_Stmt - for cb_object in node.children[0]: - for symbol_name in cb_object[1].items: - sym = psyir_parent.symbol_table.lookup( - str(symbol_name)) + for block_name, var_names in \ + Fparser2Reader._get_common_block_groups(node): + actual_name = (block_name + if block_name is not None else "") + # Find or create the CommonBlockSymbol. + cb_sym = table.lookup(actual_name, otherwise=None) + if cb_sym is None: + cb_sym = CommonBlockSymbol(actual_name) + table.add(cb_sym) + # Update each variable's interface. + for var_name in var_names: + sym = table.lookup(var_name) if sym.initial_value: # This is C506 of the F2008 standard. raise NotImplementedError( f"Symbol '{sym.name}' has an initial value" - f" ({sym.initial_value.debug_string()}) " - f"but appears in a common block. This is " - f"not valid Fortran.") - sym.interface = CommonBlockInterface() + f" ({sym.initial_value.debug_string()}" + f") but appears in a common block. This is" + f" not valid Fortran.") + sym.interface = CommonBlockInterface(cb_sym) + cb_sym.add_variable(sym) except KeyError as error: raise NotImplementedError( f"The symbol interface of a common block variable " diff --git a/src/psyclone/psyir/symbols/__init__.py b/src/psyclone/psyir/symbols/__init__.py index 4073219adc..183b71ec15 100644 --- a/src/psyclone/psyir/symbols/__init__.py +++ b/src/psyclone/psyir/symbols/__init__.py @@ -40,6 +40,7 @@ from psyclone.psyir.symbols.datasymbol import DataSymbol from psyclone.psyir.symbols.containersymbol import ContainerSymbol +from psyclone.psyir.symbols.commonblocksymbol import CommonBlockSymbol from psyclone.psyir.symbols.data_type_symbol import DataTypeSymbol from psyclone.psyir.symbols.generic_interface_symbol import ( GenericInterfaceSymbol) @@ -61,6 +62,7 @@ 'ArrayType', 'AutomaticInterface', 'CommonBlockInterface', + 'CommonBlockSymbol', 'ContainerSymbol', 'DataSymbol', 'DataType', diff --git a/src/psyclone/psyir/symbols/commonblocksymbol.py b/src/psyclone/psyir/symbols/commonblocksymbol.py new file mode 100644 index 0000000000..7ab0e1e775 --- /dev/null +++ b/src/psyclone/psyir/symbols/commonblocksymbol.py @@ -0,0 +1,149 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2017-2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Authors R. W. Ford, A. R. Porter, S. Siso and N. Nobre, STFC Daresbury Lab +# I. Kavcic, Met Office +# J. Henrichs, Bureau of Meteorology +# ----------------------------------------------------------------------------- + +'''This module contains the CommonBlockSymbol class.''' + +from psyclone.psyir.symbols.symbol import Symbol + + +class CommonBlockSymbol(Symbol): + '''Symbol representing a Fortran COMMON block. + + Named COMMON blocks use the block name as the symbol name (e.g. + ``CommonBlockSymbol("myblock")``). The *blank* common (``COMMON x, y`` + or ``COMMON // x, y``) is represented by using the empty string as the + name: ``CommonBlockSymbol("")``. + + The ordered list of :py:class:`~psyclone.psyir.symbols.DataSymbol`\\ s + that belong to the block is maintained in insertion order, which mirrors + the order of variables in the original COMMON statement. This order is + significant because it determines the shared memory layout. + + :param str name: name of the COMMON block, or ``""`` for blank common. + :param kwargs: additional keyword arguments forwarded to \ + :py:class:`psyclone.psyir.symbols.Symbol`. + :type kwargs: unwrapped dict + + :raises TypeError: if *name* is not a ``str``. + + ''' + + def __init__(self, name, **kwargs): + # Allow the empty string (blank COMMON) but still validate type. + if not isinstance(name, str): + raise TypeError( + f"CommonBlockSymbol 'name' attribute should be of type " + f"'str' but '{type(name).__name__}' found.") + self._variables: list = [] + # Call Symbol.__init__ directly to bypass the non-empty name check + # that Symbol enforces for ordinary symbols (Symbol itself accepts + # any str including ""). + super().__init__(name, **kwargs) + + @property + def variables(self) -> list: + '''Return the ordered list of DataSymbols that belong to this block. + + :returns: the ordered variable list. + :rtype: list[:py:class:`psyclone.psyir.symbols.DataSymbol`] + ''' + return list(self._variables) + + def add_variable(self, symbol) -> None: + '''Append *symbol* to the end of the ordered variable list. + + :param symbol: the DataSymbol to add. + :type symbol: :py:class:`psyclone.psyir.symbols.DataSymbol` + + :raises TypeError: if *symbol* is not a \ + :py:class:`psyclone.psyir.symbols.DataSymbol`. + :raises ValueError: if *symbol* is already in the variable list. + + ''' + # pylint: disable=import-outside-toplevel + from psyclone.psyir.symbols.datasymbol import DataSymbol + if not isinstance(symbol, DataSymbol): + raise TypeError( + f"CommonBlockSymbol.add_variable: expected a DataSymbol " + f"but got '{type(symbol).__name__}'.") + if symbol in self._variables: + raise ValueError( + f"CommonBlockSymbol '{self.name}': variable " + f"'{symbol.name}' is already in the variable list.") + self._variables.append(symbol) + + def replace_variable(self, old_sym, new_sym) -> None: + '''Replace *old_sym* with *new_sym* in the ordered variable list. + Used by :py:meth:`SymbolTable.replace_symbols_using` when a symbol + object is replaced (e.g. on deep-copy). + + :param old_sym: the symbol to replace. + :param new_sym: the replacement symbol. + + :raises ValueError: if *old_sym* is not in the variable list. + + ''' + try: + idx = self._variables.index(old_sym) + except ValueError as exc: + raise ValueError( + f"CommonBlockSymbol '{self.name}': variable " + f"'{old_sym.name}' is not in the variable list.") from exc + self._variables[idx] = new_sym + + def copy(self): + '''Create and return a copy of this symbol. + + The ``_variables`` list is intentionally left empty in the copy; it + will be repopulated when the copied table's + :py:meth:`~psyclone.psyir.symbols.SymbolTable.replace_symbols_using` + is called. + + :returns: a copy of this CommonBlockSymbol with an empty variable list. + :rtype: :py:class:`psyclone.psyir.symbols.CommonBlockSymbol` + + ''' + new_sym = type(self)(self.name, visibility=self.visibility, + interface=self.interface.copy()) + new_sym.preceding_comment = self.preceding_comment + new_sym.inline_comment = self.inline_comment + return new_sym + + def __str__(self): + block_label = f"/{self.name}/" if self.name else "//" + return f"{block_label}: CommonBlockSymbol" diff --git a/src/psyclone/psyir/symbols/interfaces.py b/src/psyclone/psyir/symbols/interfaces.py index 833bbdd712..fca57bde06 100644 --- a/src/psyclone/psyir/symbols/interfaces.py +++ b/src/psyclone/psyir/symbols/interfaces.py @@ -96,11 +96,69 @@ def __str__(self): class CommonBlockInterface(SymbolInterface): - ''' A symbol declared in the local scope but acts as a global that - can be accessed by any scope referencing the same CommonBlock name.''' + '''Describes a symbol that belongs to a Fortran COMMON block. + + :param common_block_symbol: the :py:class:`CommonBlockSymbol` that \ + represents the COMMON block this variable is part of. + :type common_block_symbol: \ + :py:class:`psyclone.psyir.symbols.CommonBlockSymbol` + + :raises TypeError: if *common_block_symbol* is not a \ + :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol`. + + ''' + + def __init__(self, common_block_symbol): + super().__init__() + # Use setter for validation. + self.common_block_symbol = common_block_symbol + + @property + def common_block_symbol(self): + ''' + :returns: the CommonBlockSymbol for the COMMON block that owns this \ + variable. + :rtype: :py:class:`psyclone.psyir.symbols.CommonBlockSymbol` + ''' + return self._common_block_symbol + + @common_block_symbol.setter + def common_block_symbol(self, value): + ''' + :param value: the CommonBlockSymbol for this interface. + :type value: :py:class:`psyclone.psyir.symbols.CommonBlockSymbol` + + :raises TypeError: if *value* is not a \ + :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol`. + + ''' + # pylint: disable=import-outside-toplevel + from psyclone.psyir.symbols.commonblocksymbol import CommonBlockSymbol + if not isinstance(value, CommonBlockSymbol): + raise TypeError( + f"CommonBlockInterface common_block_symbol parameter must be " + f"of type CommonBlockSymbol, but found " + f"'{type(value).__name__}'.") + self._common_block_symbol = value def __str__(self): - return "CommonBlock" + block_label = (f"/{self.common_block_symbol.name}/" + if self.common_block_symbol.name else "//") + return f"CommonBlock({block_label})" + + def __eq__(self, other): + if not super().__eq__(other): + return False + return (self.common_block_symbol.name.lower() == + other.common_block_symbol.name.lower()) + + def copy(self): + ''' + :returns: a copy of this interface (sharing the same \ + :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol` instance). + :rtype: :py:class:`psyclone.psyir.symbols.CommonBlockInterface` + ''' + return self.__class__(self.common_block_symbol) class UnresolvedInterface(SymbolInterface): diff --git a/src/psyclone/psyir/symbols/symbol.py b/src/psyclone/psyir/symbols/symbol.py index a80c74dae8..1e2f125516 100644 --- a/src/psyclone/psyir/symbols/symbol.py +++ b/src/psyclone/psyir/symbols/symbol.py @@ -555,21 +555,39 @@ def replace_symbols_using(self, table_or_symbol): :py:class:`psyclone.psyir.symbols.Symbol` ''' - if not isinstance(self.interface, ImportInterface): - return - name = self.interface.container_symbol.name - orig_name = self.interface.orig_name - if isinstance(table_or_symbol, Symbol): - if name.lower() == table_or_symbol.name.lower(): - self.interface = ImportInterface(table_or_symbol, - orig_name=orig_name) - else: - try: - new_container = table_or_symbol.lookup(name) - self.interface = ImportInterface(new_container, - orig_name=orig_name) - except KeyError: - pass + if isinstance(self.interface, ImportInterface): + name = self.interface.container_symbol.name + orig_name = self.interface.orig_name + if isinstance(table_or_symbol, Symbol): + if name.lower() == table_or_symbol.name.lower(): + self.interface = ImportInterface(table_or_symbol, + orig_name=orig_name) + else: + try: + new_container = table_or_symbol.lookup(name) + self.interface = ImportInterface(new_container, + orig_name=orig_name) + except KeyError: + pass + + elif isinstance(self.interface, CommonBlockInterface): + # pylint: disable=import-outside-toplevel + from psyclone.psyir.symbols.commonblocksymbol import ( + CommonBlockSymbol) + name = self.interface.common_block_symbol.name + if isinstance(table_or_symbol, Symbol): + if (isinstance(table_or_symbol, CommonBlockSymbol) and + name.lower() == table_or_symbol.name.lower()): + self.interface = CommonBlockInterface(table_or_symbol) + else: + # table_or_symbol is a SymbolTable — look up by name. + # Blank-common symbols have name ""; use find_or_create so + # we tolerate tables that have not yet added the block sym. + try: + new_cb = table_or_symbol.lookup_commonblock(name) + self.interface = CommonBlockInterface(new_cb) + except KeyError: + pass def get_all_accessed_symbols(self) -> set["Symbol"]: ''' diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 82dbc6b8da..36caa9a9e7 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -53,8 +53,9 @@ from psyclone.configuration import Config from psyclone.errors import InternalError from psyclone.psyir.symbols import ( - DataSymbol, ContainerSymbol, DataTypeSymbol, - ImportInterface, RoutineSymbol, Symbol, SymbolError, UnresolvedInterface) + DataSymbol, ContainerSymbol, CommonBlockSymbol, + DataTypeSymbol, ImportInterface, RoutineSymbol, Symbol, SymbolError, + UnresolvedInterface) from psyclone.psyir.symbols.intrinsic_symbol import IntrinsicSymbol from psyclone.psyir.symbols.typed_symbol import TypedSymbol @@ -299,11 +300,15 @@ def deep_copy(self, attached_to: "ScopingNode" = None) -> SymbolTable: new_st._node = attached_to # Make a copy of each symbol in the symbol table ensuring we do any - # ContainerSymbols first as there may be imports from them. + # ContainerSymbols first as there may be imports from them, then + # CommonBlockSymbols so that CommonBlockInterface references can be + # updated before the variables are added. for symbol in self.containersymbols: new_st.add(symbol.copy()) + for symbol in self.common_block_symbols: + new_st.add(symbol.copy()) for symbol in self.symbols: - if isinstance(symbol, ContainerSymbol): + if isinstance(symbol, (ContainerSymbol, CommonBlockSymbol)): continue new_sym = symbol.copy() if new_sym.is_import: @@ -332,10 +337,17 @@ def deep_copy(self, attached_to: "ScopingNode" = None) -> SymbolTable: pass # Update any references to Symbols within Symbols (initial values, - # precision etc.) + # precision etc.). This also re-points any CommonBlockInterface + # references to the copied CommonBlockSymbol instances. for symbol in new_st.symbols: symbol.replace_symbols_using(new_st) + # Reconstruct the ordered variable list for each CommonBlockSymbol. + for old_cb_sym in self.common_block_symbols: + new_cb_sym = new_st.lookup_commonblock(old_cb_sym.name) + for old_var in old_cb_sym.variables: + new_cb_sym.add_variable(new_st.lookup(old_var.name)) + # Set the default visibility new_st._default_visibility = self.default_visibility @@ -608,6 +620,8 @@ def add(self, new_symbol: Symbol, tag: Optional[str] = None): :raises InternalError: if the new_symbol argument is not a symbol. :raises KeyError: if the symbol name is already in use. :raises KeyError: if a tag is supplied and it is already in use. + :raises KeyError: if the symbol is a COMMON-block marker and an + identical declaration is already present under another marker name. :raises SymbolError: if the supplied symbol has an ImportInterface that refers to a ContainerSymbol that is not in scope. @@ -704,6 +718,15 @@ def check_for_clashes(self, other_table, symbols_to_skip=()): isinstance(other_sym, IntrinsicSymbol)): continue + # If both symbols are CommonBlockSymbols (same block name) or + # both variables in a COMMON block, they represent the same + # shared COMMON-block data. They cannot (and do not need to) + # be renamed, so treat this as a benign clash. + if ((isinstance(this_sym, CommonBlockSymbol) and + isinstance(other_sym, CommonBlockSymbol)) or + (this_sym.is_commonblock and other_sym.is_commonblock)): + continue + if other_sym.is_import and this_sym.is_import: # Both symbols are imported. That's fine as long as they have # the same import interface (are imported from the same @@ -945,11 +968,14 @@ def _add_symbols_from_table(self, other_table, symbols_to_skip=()): already been updated to refer to a Container in this table. ''' + for old_sym in other_table.symbols: - if old_sym in symbols_to_skip or isinstance(old_sym, - ContainerSymbol): - # We've dealt with Container symbols in _add_container_symbols. + if old_sym in symbols_to_skip or isinstance( + old_sym, (ContainerSymbol, CommonBlockSymbol)): + # ContainerSymbols are handled by _add_container_symbols; + # CommonBlockSymbols are handled by + # _add_commonblock_symbols_from_table. continue try: @@ -959,11 +985,68 @@ def _add_symbols_from_table(self, other_table, symbols_to_skip=()): # We have a clash with a symbol in this table. self._handle_symbol_clash(old_sym, other_table) + def get_common_block_groups(self) -> dict: + '''Return a dict mapping lower-cased COMMON-block name to the + lower-cased list of variable names for every COMMON block in this + table. + + The blank COMMON is mapped with the key ``""``. + + :returns: mapping of COMMON-block name to ordered list of variable + names. + :rtype: dict[str, list[str]] + ''' + return {s.name.lower(): [v.name.lower() for v in s.variables] + for s in self.common_block_symbols} + + def _handle_symbol_clash_common_block(self, old_sym: Symbol) -> bool: + ''' + Handles a name clash for COMMON-block related symbols. Called from + :py:meth:`_handle_symbol_clash` as soon as a COMMON-block symbol is + detected. Returns ``True`` if the clash has been fully resolved + (nothing more to do) or ``False`` if the generic rename-and-add path + should be followed instead. + + Two kinds of COMMON-block symbol are handled: + + * :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol` (block + objects themselves): the clash has already been approved by + ``check_for_clashes``; the block is already declared — skip. + * Variables with a + :py:class:`~psyclone.psyir.symbols.CommonBlockInterface`: + the clash has already been approved by ``check_for_clashes``; + the variable is already in scope — skip. + + :param old_sym: the Symbol being added to self. + + :returns: ``True`` if the clash is resolved; ``False`` if the + generic rename-and-add path should be followed. + + ''' + try: + self_sym = self.lookup(old_sym.name) + except KeyError: + self_sym = None + + if self_sym is None: + return True + + if (isinstance(old_sym, CommonBlockSymbol) and + isinstance(self_sym, CommonBlockSymbol)): + # Same-named COMMON block already declared — skip. + return True + + if old_sym.is_commonblock and self_sym.is_commonblock: + # check_for_clashes has already approved this variable clash. + return True + + return False + def _handle_symbol_clash(self, old_sym, other_table): ''' - Adds the supplied Symbol to the current table in the presence - of a name clash. `check_for_clashes` MUST have been called - prior to this method in order to check for any unresolvable cases. + Adds the supplied Symbol to the current table in the presence of a + name clash. ``check_for_clashes`` MUST have been called prior to this + method in order to check for any unresolvable cases. :param old_sym: the Symbol to be added to self. :type old_sym: :py:class:`psyclone.psyir.symbols.Symbol` @@ -976,6 +1059,10 @@ def _handle_symbol_clash(self, old_sym, other_table): check_for_clashes()). ''' + if old_sym.is_commonblock or isinstance(old_sym, CommonBlockSymbol): + if self._handle_symbol_clash_common_block(old_sym): + return + self_sym = self.lookup(old_sym.name) if old_sym.is_import: # The clashing symbol is imported from a Container and the table @@ -1055,8 +1142,15 @@ def merge(self, other_table, symbols_to_skip=()): # Deal with any Container symbols first. self._add_container_symbols_from_table(other_table) + # Deal with any CommonBlockSymbols next. + for cb_sym in other_table.common_block_symbols: + outer_sym = self.lookup(cb_sym.name, otherwise=None) + if not outer_sym: + self.add(cb_sym) + # Copy each Symbol from the supplied table into this one, excluding - # ContainerSymbols and any listed in `symbols_to_skip`. + # ContainerSymbols, CommonBlockSymbols, and any listed in + # `symbols_to_skip`. self._add_symbols_from_table(other_table, symbols_to_skip=symbols_to_skip) @@ -1221,6 +1315,28 @@ def lookup_with_tag(self, tag, scope_limit=None): raise KeyError(f"Could not find the tag '{tag}' in the Symbol " f"Table.") from err + def lookup_commonblock(self, name: str): + '''Look up a :py:class:`~psyclone.psyir.symbols.CommonBlockSymbol` + by COMMON-block name. Use ``""`` for the blank COMMON. + + :param str name: the COMMON-block name (empty string for blank COMMON). + + :returns: the CommonBlockSymbol with the given name. + :rtype: :py:class:`psyclone.psyir.symbols.CommonBlockSymbol` + + :raises KeyError: if no CommonBlockSymbol with the given name exists \ + in scope. + :raises KeyError: if a symbol with that name exists but is not a \ + CommonBlockSymbol. + + ''' + sym = self.lookup(name) + if not isinstance(sym, CommonBlockSymbol): + raise KeyError( + f"'{name}' is present in the Symbol Table but is not a " + f"CommonBlockSymbol.") + return sym + def __contains__(self, key): '''Check if the given key is part of the Symbol Table. @@ -1658,6 +1774,15 @@ def containersymbols(self): return [sym for sym in self.symbols if isinstance(sym, ContainerSymbol)] + @property + def common_block_symbols(self): + ''' + :returns: a list of the CommonBlockSymbols present in the Symbol Table. + :rtype: List[:py:class:`psyclone.psyir.symbols.CommonBlockSymbol`] + ''' + return [sym for sym in self.symbols if isinstance(sym, + CommonBlockSymbol)] + @property def datatypesymbols(self): ''' diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 904d689067..3a1981acc2 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -38,7 +38,7 @@ ''' -from typing import Dict, List, Optional +from typing import Dict, Optional from psyclone.core import SymbolicMaths from psyclone.errors import LazyString, InternalError @@ -55,6 +55,7 @@ DataSymbol, StructureType, SymbolError, + SymbolTable, UnresolvedType, UnsupportedType, UnsupportedFortranType, @@ -147,6 +148,7 @@ def apply(self, use_first_callee_and_no_arg_check: bool = False, permit_codeblocks: bool = False, permit_unsupported_type_args: bool = False, + parameter_cloning: bool = True, **kwargs ): ''' @@ -163,6 +165,13 @@ def apply(self, if the target routine contains a CodeBlock. :param permit_unsupported_type_args: If `True` then the target routine is permitted to have arguments of UnsupportedType. + :param parameter_cloning: if `True` (the default), constant + (PARAMETER) symbols from the routine being inlined are always + copied into the call-site symbol table, potentially being renamed + to avoid clashes. If `False`, a constant from the routine is + skipped when an identical constant (same name, same type, and same + value) already exists at the call site, so no duplicate is + created. :raises InternalError: if the merge of the symbol tables fails. In theory this should never happen because validate() should @@ -219,6 +228,20 @@ def apply(self, # just delete the if statement. self._optional_arg_eliminate_ifblock_if_const_condition(routine) + # If parameter_cloning is disabled, identify duplicate constant + # (PARAMETER) symbols and redirect their references *before* the + # routine body is extracted, so that the extracted statements already + # carry references to the call-site symbols. + extra_skip: list[DataSymbol] = [] + if not parameter_cloning: + extra_skip = self._redirect_duplicate_parameters( + table, routine) + + # Redirect references to COMMON-block variables that are aliased + # (same block position, different name) to the caller's symbol, + # and exclude the now-unreferenced callee symbols from the merge. + extra_skip += self._redirect_common_block_aliases(table, routine) + # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] @@ -231,7 +254,8 @@ def apply(self, # call site. This preserves any references to them. try: table.merge(routine_table, - symbols_to_skip=routine_table.argument_list[:]) + symbols_to_skip=routine_table.argument_list[:] + + extra_skip) except SymbolError as err: raise InternalError( f"Error copying routine symbols to call site. This should " @@ -329,9 +353,160 @@ def apply(self, idx += 1 parent.addchild(child, idx) + def _redirect_duplicate_parameters( + self, + table, + routine: Routine, + ) -> list[DataSymbol]: + ''' + Identifies constant (PARAMETER) symbols in ``routine_table`` that + are identical to constants already present in ``table`` (same name, + same type, and same initial value). For each such symbol, every + :py:class:`~psyclone.psyir.nodes.Reference` to it inside ``routine`` + and inside the datatypes / initial-value expressions of other symbols + in ``routine_table`` is redirected to point to the corresponding + symbol in ``table``. + + Only constants whose initial value is represented as a PSyIR node + (i.e. ``initial_value is not None``) are considered; constants of + ``UnsupportedFortranType`` with an embedded value string are left + unchanged. + + A constant is only considered a duplicate when every routine-local + symbol referenced inside its initial-value expression is itself a + confirmed duplicate. This prevents false positives for expressions + like ``negflag = .NOT. flag`` when ``flag`` has different values in + the caller and the callee (the names would match but the semantics + would differ). + + :param table: the call-site symbol table. + :param routine: the (copy of the) routine being inlined. + + :returns: the list of symbols that are duplicates of + call-site constants and should be excluded from the subsequent + table merge. + + ''' + routine_table: SymbolTable = routine.symbol_table + # The names of all local data symbols in the routine table (used to + # identify references that point to routine-local constants). + routine_local_names = { + s.name.lower() for s in routine_table.datasymbols + if not s.is_automatic + } + + # First pass: collect all constants from the routine whose name, + # datatype, and initial-value tree match a constant in the call-site + # table. The structural comparison uses __eq__, which compares + # Reference nodes by symbol name. This is correct for leaf constants + # (Literals) and is refined for dependent constants in the second + # pass below. + candidates: dict = {} + for rsym in routine_table.datasymbols: + if not rsym.is_constant or rsym.initial_value is None: + # Skip constants whose value is not represented as a PSyIR + # node (e.g. UnsupportedFortranType with embedded value). + continue + tsym = table.lookup(rsym.name, otherwise=None) + if not isinstance(tsym, DataSymbol): + continue + if not tsym.is_constant or tsym.initial_value is None: + continue + if rsym.datatype != tsym.datatype: + continue + if rsym.initial_value != tsym.initial_value: + continue + candidates[rsym.name.lower()] = rsym + + # Second pass: iteratively remove candidates whose initial-value + # expression references a routine-local symbol that is NOT itself + # a confirmed duplicate. Without this step, an expression like + # ``negflag = .NOT. flag`` would compare as equal by name even when + # ``flag`` has different values in the two routines. + changed = True + while changed: + changed = False + to_remove = [ + name for name, rsym in candidates.items() + if any( + dep.name.lower() in routine_local_names + and dep.name.lower() not in candidates + for dep in rsym.initial_value.get_all_accessed_symbols() + ) + ] + for name in to_remove: + del candidates[name] + if to_remove: + changed = True + + duplicates: list[DataSymbol] = list(candidates.values()) + + # Redirect all references from duplicate symbols in the routine to + # their call-site counterparts. + for rsym in duplicates: + tsym = table.lookup(rsym.name) + # Update all References in the routine body. + routine.replace_symbols_using(tsym) + # Update any references to rsym embedded in the datatypes or + # initial-value expressions of other symbols in routine_table. + for sym in routine_table.symbols: + if sym is rsym: + continue + sym.replace_symbols_using(tsym) + + return duplicates + + def _redirect_common_block_aliases( + self, + table: SymbolTable, + routine: Routine, + ) -> list[DataSymbol]: + '''Redirect references to COMMON-block variables in *routine* that are + aliased to differently-named variables in the caller *table* (same + block, same position). + + For each such pair the caller's symbol is substituted for the + callee's symbol in every :py:class:`~psyclone.psyir.nodes.Reference` + inside *routine*. The callee symbols that have been redirected are + returned so they can be excluded from the subsequent symbol-table + merge (they no longer have any live references). + + The types of each aliased pair must already have been verified to be + compatible by :py:meth:`validate`. + + :param table: the call-site symbol table. + :param routine: the (copy of the) routine being inlined. + + :returns: callee symbols whose references have been redirected and + that should therefore be skipped during the table merge. + ''' + routine_table = routine.symbol_table + caller_blocks = table.get_common_block_groups() + callee_blocks = routine_table.get_common_block_groups() + + symbols_to_skip = [] + for block_name, callee_vars in callee_blocks.items(): + if block_name not in caller_blocks: + continue + caller_vars = caller_blocks[block_name] + for caller_var_name, callee_var_name in zip( + caller_vars, callee_vars): + if caller_var_name.lower() == callee_var_name.lower(): + continue + # Replace all References to the callee's alias with the + # corresponding caller's symbol. + caller_sym = table.lookup(caller_var_name) + callee_sym = routine_table.lookup(callee_var_name) + for ref in routine.walk(Reference): + if ref.symbol is callee_sym: + ref.symbol = caller_sym + symbols_to_skip.append(callee_sym) + + return symbols_to_skip + def _optional_arg_resolve_present_intrinsics(self, routine_node: Routine, - arg_match_list: List = []): + arg_match_list: list = []): """Replace PRESENT(some_argument) intrinsics in routine with constant booleans depending on whether `some_argument` has been provided (`True`) or not (`False`). @@ -437,7 +612,7 @@ def _replace_formal_args_in_expr( self, expression: Node, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], routine_node: Routine, use_first_callee_and_no_arg_check: bool = False, ) -> Reference: @@ -515,7 +690,7 @@ def _replace_formal_args_in_expr( def _create_inlined_idx( self, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], local_idx: DataNode, decln_start: DataNode, actual_start: DataNode, @@ -615,10 +790,10 @@ def _update_actual_indices( actual_arg: ArrayMixin, local_ref: Reference, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], routine_node: Routine, use_first_callee_and_no_arg_check: bool = False, - ) -> List[Node]: + ) -> list[Node]: ''' Create a new list of indices for the supplied actual argument (ArrayMixin) by replacing any Ranges with the appropriate expressions @@ -731,7 +906,7 @@ def _generate_formal_arg_replacement( actual_arg: Reference, ref: Reference, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], routine_node: Routine, use_first_callee_and_no_arg_check: bool = False, ) -> Reference: @@ -964,6 +1139,9 @@ def validate( does not match that of the corresponding actual argument. :raises TransformationError: if one of the declarations in the routine depends on an argument that is written to prior to the call. + :raises TransformationError: if a COMMON block is declared in both + the caller and the routine being inlined with different variable + names and incompatible types. :raises InternalError: if an unhandled Node type is returned by Reference.previous_accesses(). @@ -1118,6 +1296,33 @@ def validate( f"{err.value}") from err routine_table = routine.symbol_table + # Check that COMMON blocks shared between the caller and the callee + # that use *different* variable names are still type-compatible. + # Different names at the same block position mean the two variables + # are memory aliases; that is acceptable as long as their types + # match (the actual reference-redirection happens in apply()). + caller_blocks = parent_routine.symbol_table.get_common_block_groups() + callee_blocks = routine_table.get_common_block_groups() + for block_name, callee_vars in callee_blocks.items(): + if block_name not in caller_blocks: + continue + caller_vars = caller_blocks[block_name] + for caller_var_name, callee_var_name in zip( + caller_vars, callee_vars): + if caller_var_name.lower() == callee_var_name.lower(): + continue + # Different names – check that the types are compatible. + caller_sym = parent_routine.symbol_table.lookup( + caller_var_name) + callee_sym = routine_table.lookup(callee_var_name) + if caller_sym.datatype != callee_sym.datatype: + raise TransformationError( + f"Cannot inline '{routine.name}' because COMMON " + f"block '/{block_name}/' maps '{caller_var_name}' " + f"(type '{caller_sym.datatype}') in the caller to " + f"'{callee_var_name}' (type '{callee_sym.datatype}')" + f" in the routine being inlined - the types are " + f"incompatible.") # Create a list of routine arguments that is actually used routine_arg_list = [ routine_table.argument_list[i] for i in arg_match_list @@ -1127,7 +1332,6 @@ def validate( routine_arg_list, node.arguments ): self._validate_inline_of_call_and_routine_argument_pairs( - call_node=node, call_arg=actual_arg, routine_node=routine, routine_arg=routine_arg @@ -1184,7 +1388,6 @@ def validate( def _validate_inline_of_call_and_routine_argument_pairs( self, - call_node: Call, call_arg: DataNode, routine_node: Routine, routine_arg: DataSymbol diff --git a/src/psyclone/tests/psyir/backend/fortran_common_block_test.py b/src/psyclone/tests/psyir/backend/fortran_common_block_test.py index 9a1b8fd084..a1ddb71fcf 100644 --- a/src/psyclone/tests/psyir/backend/fortran_common_block_test.py +++ b/src/psyclone/tests/psyir/backend/fortran_common_block_test.py @@ -61,6 +61,8 @@ def test_fw_common_blocks(fortran_reader, fortran_writer, tmpdir): routine = psyir.walk(Routine)[0] assert routine.symbol_table.lookup("a").is_commonblock # Sanity check + assert routine.symbol_table.lookup("d").is_commonblock # Sanity check + assert routine.symbol_table.lookup("e").is_commonblock # Sanity check code = fortran_writer(routine) assert code == ( @@ -71,8 +73,8 @@ def test_fw_common_blocks(fortran_reader, fortran_writer, tmpdir): " real :: d\n" " real :: e\n" " real :: f\n" - " COMMON /name1/ a, b\n" - " COMMON /name1/ c /name2/ d\n" + " COMMON /name1/ a, b, c\n" + " COMMON /name2/ d\n" " COMMON // e, f\n\n\n" "end subroutine sub\n") assert Compile(tmpdir).string_compiles(fortran_writer(psyir)) diff --git a/src/psyclone/tests/psyir/frontend/fparser2_common_block_test.py b/src/psyclone/tests/psyir/frontend/fparser2_common_block_test.py index 227dd6b999..0b2174aca1 100644 --- a/src/psyclone/tests/psyir/frontend/fparser2_common_block_test.py +++ b/src/psyclone/tests/psyir/frontend/fparser2_common_block_test.py @@ -38,11 +38,11 @@ import pytest from fparser.common.readfortran import FortranStringReader -from fparser.two.Fortran2003 import Specification_Part +from fparser.two.Fortran2003 import Common_Stmt, Specification_Part from psyclone.psyir.frontend.fparser2 import Fparser2Reader from psyclone.psyir.nodes import Routine from psyclone.psyir.symbols import ( - CommonBlockInterface, ScalarType, UnsupportedFortranType) + CommonBlockInterface, CommonBlockSymbol, ScalarType) @pytest.mark.usefixtures("f2008_parser") @@ -62,10 +62,10 @@ def test_named_common_block(): fparser2spec = Specification_Part(reader) processor.process_declarations(routine, fparser2spec.content, []) - # There is a name1 commonblock symbol - commonblock = symtab.lookup("_PSYCLONE_INTERNAL_COMMONBLOCK") - assert isinstance(commonblock.datatype, UnsupportedFortranType) - assert commonblock.datatype.declaration == "COMMON /name1/ a, b, c" + # There is a name1 CommonBlockSymbol + cb_sym = symtab.lookup_commonblock("name1") + assert isinstance(cb_sym, CommonBlockSymbol) + assert [v.name for v in cb_sym.variables] == ["a", "b", "c"] # The variables have been updated to a common block interface assert isinstance(symtab.lookup("a").interface, CommonBlockInterface) @@ -73,7 +73,7 @@ def test_named_common_block(): assert isinstance(symtab.lookup("c").interface, CommonBlockInterface) # The same common block can also bring other variables in a separate - # statement + # statement; the existing CommonBlockSymbol is reused. reader = FortranStringReader(''' real :: d, e real(kind=wp) :: f @@ -81,10 +81,8 @@ def test_named_common_block(): fparser2spec = Specification_Part(reader) processor.process_declarations(routine, fparser2spec.content, []) - # This is stored in a separate symbol, but the declaration has the right - # text - commonblock_2 = symtab.lookup("_PSYCLONE_INTERNAL_COMMONBLOCK_1") - assert commonblock_2.datatype.declaration == "COMMON /name1/ d, e, f" + # The same CommonBlockSymbol now also includes d, e, f + assert [v.name for v in cb_sym.variables] == ["a", "b", "c", "d", "e", "f"] assert isinstance(symtab.lookup("d").interface, CommonBlockInterface) assert isinstance(symtab.lookup("e").interface, CommonBlockInterface) fsym = symtab.lookup("f") @@ -108,10 +106,11 @@ def test_unnamed_commonblock(): fparser2spec = Specification_Part(reader) processor.process_declarations(routine, fparser2spec.content, []) - # There is an UnsupportedFortranType symbol containing the commonblock - commonblock = symtab.lookup("_PSYCLONE_INTERNAL_COMMONBLOCK") - assert isinstance(commonblock.datatype, UnsupportedFortranType) - assert commonblock.datatype.declaration == "COMMON // a, b, c" + # There is a blank-name CommonBlockSymbol for the blank common block + cb_sym = symtab.lookup_commonblock("") + assert isinstance(cb_sym, CommonBlockSymbol) + assert cb_sym.name == "" + assert [v.name for v in cb_sym.variables] == ["a", "b", "c"] # The variables have been updated to a common block interface assert isinstance(symtab.lookup("a").interface, CommonBlockInterface) @@ -137,13 +136,15 @@ def test_multiple_commonblocks_in_statement(): fparser2spec = Specification_Part(reader) processor.process_declarations(routine, fparser2spec.content, []) - # There is a UnsupportedFortranType symbol containing each the commonblock - commonblock = symtab.lookup("_PSYCLONE_INTERNAL_COMMONBLOCK") - assert isinstance(commonblock.datatype, UnsupportedFortranType) - assert commonblock.datatype.declaration == "COMMON /name1/ a, b /name2/ c" - commonblock = symtab.lookup("_PSYCLONE_INTERNAL_COMMONBLOCK_1") - assert isinstance(commonblock.datatype, UnsupportedFortranType) - assert commonblock.datatype.declaration == "COMMON /name2/ d" + # There is a CommonBlockSymbol for each block name + cb1 = symtab.lookup_commonblock("name1") + assert isinstance(cb1, CommonBlockSymbol) + assert [v.name for v in cb1.variables] == ["a", "b"] + + cb2 = symtab.lookup_commonblock("name2") + assert isinstance(cb2, CommonBlockSymbol) + # name2 was extended by the second COMMON statement + assert [v.name for v in cb2.variables] == ["c", "d"] # The variables have been updated to a common block interface assert isinstance(symtab.lookup("a").interface, CommonBlockInterface) @@ -155,7 +156,8 @@ def test_multiple_commonblocks_in_statement(): @pytest.mark.usefixtures("f2008_parser") def test_named_commonblock_with_posterior_declaration(): ''' Test that commonblocks with symbols that are declared after the - commonblock statement are handled correctly.''' + commonblock statement are handled correctly (process_declarations processes + COMMON blocks in a second pass, after all declarations).''' # Create a dummy test routine routine = Routine.create("test_routine") @@ -169,10 +171,10 @@ def test_named_commonblock_with_posterior_declaration(): fparser2spec = Specification_Part(reader) processor.process_declarations(routine, fparser2spec.content, []) - # There is an UnsupportedFortranType symbol containing the commonblock - commonblock = symtab.lookup("_PSYCLONE_INTERNAL_COMMONBLOCK") - assert isinstance(commonblock.datatype, UnsupportedFortranType) - assert commonblock.datatype.declaration == "COMMON /name1/ a, b" + # There is a CommonBlockSymbol for name1 + cb_sym = symtab.lookup_commonblock("name1") + assert isinstance(cb_sym, CommonBlockSymbol) + assert [v.name for v in cb_sym.variables] == ["a", "b"] # The variables have been updated to a common block interface assert isinstance(symtab.lookup("a").interface, CommonBlockInterface) @@ -202,24 +204,28 @@ def test_undeclared_symbol(): @pytest.mark.usefixtures("f2008_parser") def test_commonblock_with_explicit_array_shape_symbol(): - ''' Test that commonblocks with an explicit-shape-spec-list - produce NotImplementedError.''' + ''' Test that commonblocks with an explicit-shape-spec-list in the COMMON + statement are handled correctly: the array dimension is stripped by + get_block_groups() and the bare variable name is looked up. ''' # Create a dummy test routine routine = Routine.create("test_routine") + symtab = routine.symbol_table processor = Fparser2Reader() - # This is also valid Fortran, but currently not supported + # The Fortran standard allows specifying array bounds in the COMMON + # statement itself; get_block_groups() strips the "(10, 4)" suffix. reader = FortranStringReader(''' integer :: a common /name1/ a (10, 4)''') - fparser2spec = Specification_Part(reader) - with pytest.raises(NotImplementedError) as err: - processor.process_declarations(routine, fparser2spec.content, []) - assert ("The symbol interface of a common block variable could not be " - "updated because of \"Could not find 'a(10, 4)' in the Symbol " - "Table.\"." in str(err.value)) + processor.process_declarations(routine, fparser2spec.content, []) + + # The variable 'a' is found and its interface is set correctly. + cb_sym = symtab.lookup_commonblock("name1") + assert isinstance(cb_sym, CommonBlockSymbol) + assert [v.name for v in cb_sym.variables] == ["a"] + assert isinstance(symtab.lookup("a").interface, CommonBlockInterface) @pytest.mark.usefixtures("f2008_parser") @@ -240,3 +246,35 @@ def test_commonblock_with_explicit_init_symbol(): processor.process_declarations(routine, fparser2spec.content, []) assert ("Symbol 'a' has an initial value (10) but appears in a common " "block." in str(err.value)) + + +@pytest.mark.usefixtures("f2008_parser") +def test_get_common_block_groups(): + '''Tests for Fparser2Reader._get_common_block_groups().''' + + # Blank common (no block name) with a single variable. + obj = Common_Stmt("common a") + assert Fparser2Reader._get_common_block_groups(obj) == [(None, ['a'])] + + # Blank common with an explicit // and multiple variables. + obj = Common_Stmt("common // a, b") + assert Fparser2Reader._get_common_block_groups(obj) == [(None, ['a', 'b'])] + + # Named common block. + obj = Common_Stmt("common /myblock/ x, y") + assert Fparser2Reader._get_common_block_groups(obj) == [ + ('myblock', ['x', 'y'])] + + # Array variables: the dimension spec is stripped, only the bare name + # is returned. + obj = Common_Stmt("common /name/ a, b(4,5)") + assert Fparser2Reader._get_common_block_groups(obj) == [ + ('name', ['a', 'b'])] + + # Multiple block groups in a single statement. + obj = Common_Stmt("common /name/ a, b(4,5) // c, /ljuks/ g(2)") + assert Fparser2Reader._get_common_block_groups(obj) == [ + ('name', ['a', 'b']), + (None, ['c']), + ('ljuks', ['g']), + ] diff --git a/src/psyclone/tests/psyir/symbols/commonblocksymbol_test.py b/src/psyclone/tests/psyir/symbols/commonblocksymbol_test.py new file mode 100644 index 0000000000..5ffe1f0e46 --- /dev/null +++ b/src/psyclone/tests/psyir/symbols/commonblocksymbol_test.py @@ -0,0 +1,211 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2017-2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Authors R. W. Ford, A. R. Porter, S. Siso and N. Nobre, STFC Daresbury Lab +# I. Kavcic, Met Office +# J. Henrichs, Bureau of Meteorology +# ----------------------------------------------------------------------------- + +'''Tests for the CommonBlockSymbol class.''' + +import pytest + +from psyclone.psyir.symbols import ( + CommonBlockSymbol, DataSymbol, ScalarType) +from psyclone.psyir.symbols.interfaces import CommonBlockInterface + +REAL_TYPE = ScalarType.real_type() + + +# --------------------------------------------------------------------------- +# __init__ +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_init_named(): + '''A CommonBlockSymbol can be created with a non-empty name.''' + sym = CommonBlockSymbol("myblock") + assert sym.name == "myblock" + assert sym.variables == [] + + +def test_commonblocksymbol_init_blank(): + '''The blank common block is represented by the empty string.''' + sym = CommonBlockSymbol("") + assert sym.name == "" + assert sym.variables == [] + + +def test_commonblocksymbol_init_type_error(): + '''A non-str name raises TypeError.''' + with pytest.raises(TypeError) as err: + CommonBlockSymbol(42) + assert "CommonBlockSymbol 'name' attribute should be of type 'str'" in \ + str(err.value) + assert "'int' found" in str(err.value) + + +# --------------------------------------------------------------------------- +# variables property +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_variables_returns_copy(): + '''variables returns a *copy* of the internal list, so mutating it does + not affect the symbol.''' + cb = CommonBlockSymbol("blk") + var = DataSymbol("x", REAL_TYPE) + cb.add_variable(var) + lst = cb.variables + lst.clear() + # Original must be unchanged. + assert cb.variables == [var] + + +# --------------------------------------------------------------------------- +# add_variable +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_add_variable(): + '''add_variable appends DataSymbols in insertion order.''' + cb = CommonBlockSymbol("blk") + x = DataSymbol("x", REAL_TYPE) + y = DataSymbol("y", REAL_TYPE) + cb.add_variable(x) + cb.add_variable(y) + assert cb.variables == [x, y] + + +def test_commonblocksymbol_add_variable_type_error(): + '''add_variable rejects non-DataSymbol values.''' + cb = CommonBlockSymbol("blk") + with pytest.raises(TypeError) as err: + cb.add_variable("not_a_datasymbol") + assert "expected a DataSymbol" in str(err.value) + assert "'str'" in str(err.value) + + +def test_commonblocksymbol_add_variable_duplicate_error(): + '''add_variable raises ValueError when the same DataSymbol is added + twice.''' + cb = CommonBlockSymbol("blk") + var = DataSymbol("x", REAL_TYPE) + cb.add_variable(var) + with pytest.raises(ValueError) as err: + cb.add_variable(var) + assert "already in the variable list" in str(err.value) + + +# --------------------------------------------------------------------------- +# replace_variable +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_replace_variable(): + '''replace_variable substitutes the old symbol at the same position.''' + cb = CommonBlockSymbol("blk") + x = DataSymbol("x", REAL_TYPE) + y = DataSymbol("y", REAL_TYPE) + x2 = DataSymbol("x", REAL_TYPE) + cb.add_variable(x) + cb.add_variable(y) + cb.replace_variable(x, x2) + assert cb.variables == [x2, y] + + +def test_commonblocksymbol_replace_variable_not_found(): + '''replace_variable raises ValueError when old_sym is not present.''' + cb = CommonBlockSymbol("blk") + x = DataSymbol("x", REAL_TYPE) + other = DataSymbol("z", REAL_TYPE) + cb.add_variable(x) + with pytest.raises(ValueError) as err: + cb.replace_variable(other, DataSymbol("z2", REAL_TYPE)) + assert "not in the variable list" in str(err.value) + + +# --------------------------------------------------------------------------- +# copy +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_copy(): + '''copy produces a new CommonBlockSymbol with the same name but an + empty variable list (to be repopulated by replace_symbols_using).''' + cb = CommonBlockSymbol("blk") + var = DataSymbol("x", REAL_TYPE) + cb.add_variable(var) + + copy = cb.copy() + assert isinstance(copy, CommonBlockSymbol) + assert copy is not cb + assert copy.name == cb.name + # The copy starts with an empty variable list. + assert copy.variables == [] + + +def test_commonblocksymbol_copy_blank(): + '''copy also works for the blank common block (empty name).''' + cb = CommonBlockSymbol("") + copy = cb.copy() + assert copy.name == "" + assert copy.variables == [] + + +# --------------------------------------------------------------------------- +# __str__ +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_str_named(): + '''__str__ wraps named blocks in /.../.''' + cb = CommonBlockSymbol("myblock") + assert str(cb) == "/myblock/: CommonBlockSymbol" + + +def test_commonblocksymbol_str_blank(): + '''__str__ uses // for the blank common block.''' + cb = CommonBlockSymbol("") + assert str(cb) == "//: CommonBlockSymbol" + + +# --------------------------------------------------------------------------- +# Integration: CommonBlockInterface linkage +# --------------------------------------------------------------------------- + +def test_commonblocksymbol_with_commonblockinterface(): + '''DataSymbols can reference a CommonBlockSymbol via CommonBlockInterface, + and the back-link is consistent.''' + cb = CommonBlockSymbol("myblock") + x = DataSymbol("x", REAL_TYPE) + x.interface = CommonBlockInterface(cb) + cb.add_variable(x) + + assert x.interface.common_block_symbol is cb + assert x in cb.variables + assert str(x.interface) == "CommonBlock(/myblock/)" diff --git a/src/psyclone/tests/psyir/symbols/interfaces_test.py b/src/psyclone/tests/psyir/symbols/interfaces_test.py index d0a2f316e8..edfa1f6d6b 100644 --- a/src/psyclone/tests/psyir/symbols/interfaces_test.py +++ b/src/psyclone/tests/psyir/symbols/interfaces_test.py @@ -45,7 +45,7 @@ AutomaticInterface, ArgumentInterface, CommonBlockInterface, DefaultModuleInterface, ImportInterface, PreprocessorInterface, StaticInterface, SymbolInterface, UnknownInterface, UnresolvedInterface) -from psyclone.psyir.symbols import ContainerSymbol +from psyclone.psyir.symbols import CommonBlockSymbol, ContainerSymbol def test_symbolinterface(): @@ -91,12 +91,36 @@ def test_static_interface(): def test_commonblockinterface(): - '''Test we can create an CommonBlockInterface instance and check its - __str__ value + '''Test we can create a CommonBlockInterface instance and check its + __str__ value, equality, and copy. ''' - interface = CommonBlockInterface() - assert str(interface) == "CommonBlock" + cb_sym = CommonBlockSymbol("myblock") + interface = CommonBlockInterface(cb_sym) + assert str(interface) == "CommonBlock(/myblock/)" + assert interface.common_block_symbol is cb_sym + + # Test __eq__ + cb_sym2 = CommonBlockSymbol("myblock") + iface2 = CommonBlockInterface(cb_sym2) + assert interface == iface2 + + cb_sym3 = CommonBlockSymbol("other") + iface3 = CommonBlockInterface(cb_sym3) + assert interface != iface3 + + # Test copy + iface_copy = interface.copy() + assert iface_copy.common_block_symbol is cb_sym + + # Blank common + blank = CommonBlockSymbol("") + blank_iface = CommonBlockInterface(blank) + assert str(blank_iface) == "CommonBlock(//)" + + # TypeError on bad argument + with pytest.raises(TypeError, match="CommonBlockSymbol"): + CommonBlockInterface("not_a_symbol") def test_unresolvedinterface(): diff --git a/src/psyclone/tests/psyir/symbols/symbol_table_test.py b/src/psyclone/tests/psyir/symbols/symbol_table_test.py index 6fc11d2c85..3ece5ab3b1 100644 --- a/src/psyclone/tests/psyir/symbols/symbol_table_test.py +++ b/src/psyclone/tests/psyir/symbols/symbol_table_test.py @@ -1337,6 +1337,125 @@ def test_handle_symbol_clash_imported_symbols(): "of the same name imported from 'Ridcully'" in str(err.value)) +def test_handle_symbol_clash_commonblock_same_declaration(): + '''Test that _handle_symbol_clash() ignores a duplicate CommonBlockSymbol + (same block name already present in the table).''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + + cb1 = symbols.CommonBlockSymbol("keep_me") + table1.add(cb1) + cb2 = symbols.CommonBlockSymbol("keep_me") + table2.add(cb2) + + # _handle_symbol_clash should silently ignore the duplicate block symbol. + table1._handle_symbol_clash(cb2, table2) + + cb_syms = [s for s in table1.symbols + if isinstance(s, symbols.CommonBlockSymbol)] + assert len(cb_syms) == 1 + assert cb_syms[0].name == "keep_me" + + +def test_add_symbols_from_table_commonblock_same_name(): + '''Test that _add_symbols_from_table() skips CommonBlockSymbols (they + are handled by merge() via a separate explicit step).''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + + table1.add(symbols.CommonBlockSymbol("ocean")) + table2.add(symbols.CommonBlockSymbol("ocean")) + + # _add_symbols_from_table skips CommonBlockSymbols entirely. + table1._add_symbols_from_table(table2) + + cb_syms = [s for s in table1.symbols + if isinstance(s, symbols.CommonBlockSymbol)] + assert len(cb_syms) == 1 + + +def test_add_symbols_from_table_commonblock_distinct_blocks(): + '''Test that _add_symbols_from_table() skips CommonBlockSymbols (they + are handled by merge() via a separate explicit step), even when the + two tables have CommonBlockSymbols with different names.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + + table1.add(symbols.CommonBlockSymbol("first")) + table2.add(symbols.CommonBlockSymbol("second")) + + # _add_symbols_from_table skips CommonBlockSymbols. + table1._add_symbols_from_table(table2) + + # "second" was NOT added by _add_symbols_from_table. + assert table1.lookup("second", otherwise=None) is None + + +def test_get_common_block_groups(): + '''Tests for SymbolTable.get_common_block_groups().''' + stype = symbols.ScalarType( + symbols.ScalarType.Intrinsic.REAL, symbols.ScalarType.Precision.SINGLE) + + table = symbols.SymbolTable() + + # An empty table returns an empty dict. + assert table.get_common_block_groups() == {} + + # Single named block with two variables. + cb = symbols.CommonBlockSymbol("myblock") + table.add(cb) + x = symbols.DataSymbol("x", stype) + x.interface = symbols.CommonBlockInterface(cb) + table.add(x) + cb.add_variable(x) + y = symbols.DataSymbol("y", stype) + y.interface = symbols.CommonBlockInterface(cb) + table.add(y) + cb.add_variable(y) + result = table.get_common_block_groups() + assert result == {"myblock": ["x", "y"]} + + # Blank COMMON is mapped to the empty-string key. + table2 = symbols.SymbolTable() + blank_cb = symbols.CommonBlockSymbol("") + table2.add(blank_cb) + a = symbols.DataSymbol("a", stype) + a.interface = symbols.CommonBlockInterface(blank_cb) + table2.add(a) + blank_cb.add_variable(a) + b = symbols.DataSymbol("b", stype) + b.interface = symbols.CommonBlockInterface(blank_cb) + table2.add(b) + blank_cb.add_variable(b) + result2 = table2.get_common_block_groups() + assert result2 == {"": ["a", "b"]} + + # Non-COMMON symbols are ignored. + table5 = symbols.SymbolTable() + table5.add(symbols.DataSymbol( + "z", symbols.UnsupportedFortranType("real :: z"))) + assert table5.get_common_block_groups() == {} + + +def test_handle_symbol_clash_unsupported_fortran_non_commonblock_name(): + '''Test that a clash between UnsupportedFortranType symbols with names + unrelated to common-block markers takes the standard rename-and-add path. + ''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + table1.add(symbols.DataSymbol( + "clash", symbols.UnsupportedFortranType("type(t1) :: clash"))) + table2.add(symbols.DataSymbol( + "clash", symbols.UnsupportedFortranType("type(t2) :: clash"))) + + old_sym = table2.lookup("clash") + table1._handle_symbol_clash(old_sym, table2) + + assert old_sym.name != "clash" + assert any(sym.datatype.declaration == "type(t2) :: clash" + for sym in table1.symbols) + + def test_swap_symbol_properties(): ''' Test the symboltable swap_properties method ''' # pylint: disable=too-many-statements @@ -2917,8 +3036,10 @@ def test_rename_symbol_errors(): "and as such may be named in a Call." in str(err.value)) # Cannot rename a common block symbol + cb_sym = symbols.CommonBlockSymbol("myblock") + table.add(cb_sym) asym = symbols.DataSymbol("a", symbols.ScalarType.integer_type(), - interface=symbols.CommonBlockInterface()) + interface=symbols.CommonBlockInterface(cb_sym)) table.add(asym) with pytest.raises(symbols.SymbolError) as err: table.rename_symbol(asym, "b") diff --git a/src/psyclone/tests/psyir/symbols/symbol_test.py b/src/psyclone/tests/psyir/symbols/symbol_test.py index 6e42a08e12..644dfbd3f1 100644 --- a/src/psyclone/tests/psyir/symbols/symbol_test.py +++ b/src/psyclone/tests/psyir/symbols/symbol_test.py @@ -52,7 +52,7 @@ from psyclone.errors import InternalError from psyclone.psyir.nodes import Container, Literal, KernelSchedule, Reference from psyclone.psyir.symbols import ( - ArgumentInterface, ContainerSymbol, + ArgumentInterface, ContainerSymbol, CommonBlockSymbol, DataSymbol, ImportInterface, DefaultModuleInterface, StaticInterface, ScalarType, AutomaticInterface, CommonBlockInterface, NoType, RoutineSymbol, Symbol, SymbolError, UnknownInterface, @@ -163,7 +163,7 @@ def test_symbol_interface_setter_and_is_properties(): assert not symbol.is_commonblock assert not symbol.is_unknown_interface - symbol.interface = CommonBlockInterface() + symbol.interface = CommonBlockInterface(CommonBlockSymbol("myblock")) assert not symbol.is_automatic assert not symbol.is_import assert not symbol.is_argument diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index f305374efc..6371b677cf 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -2223,12 +2223,43 @@ def test_validate_array_reshape(fortran_reader): sub_s = psyir.walk(Routine)[1] with pytest.raises(TransformationError) as err: inline_trans._validate_inline_of_call_and_routine_argument_pairs( - call, call.arguments[0], + call.arguments[0], sub_s, sub_s.symbol_table.lookup("x")) assert ("actual argument 'a(:,:)' has rank 2 but the corresponding formal " "argument, 'x', has rank 1" in str(err.value)) +def test_validate_unknown_type_array_arg(fortran_reader): + '''Test that _validate_inline_of_call_and_routine_argument_pairs rejects + an attempt to inline a call when the actual argument has an unknown type + but the corresponding formal argument is an array.''' + code = """\ +module test_mod +contains +subroutine main + use some_mod, only: mystery + call sub(mystery) +end subroutine +subroutine sub(x) + real, dimension(10), intent(inout) :: x + x(:) = 0.0 +end subroutine +end module +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + sub = psyir.walk(Routine)[1] + inline_trans = InlineTrans() + with pytest.raises(TransformationError) as err: + inline_trans._validate_inline_of_call_and_routine_argument_pairs( + call.arguments[0], sub, sub.symbol_table.lookup("x")) + assert ( + "Routine 'sub' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown." in str(err.value) + ) + + def test_validate_array_arg_expression(fortran_reader): ''' Check that validate rejects a call if an argument corresponding to @@ -2875,3 +2906,702 @@ def test_apply_array_access_check_unresolved_override_option( inline_trans.apply( call, use_first_callee_and_no_arg_check=True) # TODO check results + + +def test_apply_common_block_no_duplicate( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining two routines that share a COMMON block does not + produce duplicate COMMON declarations (which would cause a Fortran compile + error "Symbol X is already in a COMMON block").''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + call sub1() + call sub2() + end subroutine caller + subroutine sub1() + real :: volume, lmmpi + COMMON /blk/ volume, lmmpi + volume = 1.0 + end subroutine sub1 + subroutine sub2() + real :: volume, lmmpi + COMMON /blk/ volume, lmmpi + lmmpi = 2.0 + end subroutine sub2 +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # Exactly one COMMON declaration must appear. + assert result.count("COMMON /blk/") == 1 + # Both variables must still be present. + assert "volume" in result + assert "lmmpi" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_common_block_no_duplicate_three_routines( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining three routines that all share the same COMMON block + does not produce duplicate COMMON declarations. This mirrors the real-world + case of inlining zetabc_tile, u2dbc_tile and v2dbc_tile (each of which + includes the same set of COMMON-block headers) into step2D_FB_tile.''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + call sub1() + call sub2() + call sub3() + end subroutine caller + subroutine sub1() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + zeta = 1.0 + end subroutine sub1 + subroutine sub2() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + ubar = 2.0 + end subroutine sub2 + subroutine sub3() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + vbar = 3.0 + end subroutine sub3 +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # Each COMMON block must appear exactly once. + assert result.count("COMMON /ocean_zeta/") == 1 + assert result.count("COMMON /ocean_ubar/") == 1 + assert result.count("COMMON /ocean_vbar/") == 1 + # All three variables must still be present. + assert "zeta" in result + assert "ubar" in result + assert "vbar" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_common_block_caller_has_extra_block( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining a routine whose only COMMON block is already present + in the caller does not produce a duplicate COMMON declaration, even when + the caller also has an *additional* COMMON block that the inlined routine + does not declare. This is a regression test derived from the real-world + test.f file: the presence of the extra /comm_setup_mpi1/ block in the + caller was enough to confuse the earlier deduplication logic and caused + "Symbol 'zeta' at (1) is already in a COMMON block".''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + integer :: lmmpi + COMMON /comm_setup_mpi1/ lmmpi + real :: zeta + COMMON /ocean_zeta/ zeta + call subfoo() + end subroutine caller + subroutine subfoo() + real :: zeta + COMMON /ocean_zeta/ zeta + zeta = zeta + 1.0 + end subroutine subfoo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # /ocean_zeta/ must appear exactly once – not duplicated. + assert result.count("COMMON /ocean_zeta/") == 1 + # The extra block from the caller must be preserved. + assert result.count("COMMON /comm_setup_mpi1/") == 1 + assert "zeta" in result + assert "lmmpi" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_common_block_accept_different_names( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining is accepted when the same COMMON block is declared + in both the caller and the callee with different variable names but the + same type. The callee's variable ('height') is an alias of the caller's + variable ('depth') at the same block position, so all references to + 'height' inside the inlined body must be replaced by 'depth'. + ''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + COMMON /ocean/ depth + real :: depth + integer :: b + call subfoo(b) + end subroutine caller + subroutine subfoo(a) + COMMON /ocean/ height + real :: height + integer :: a + + a = height + end subroutine subfoo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(psyir) + # 'height' must have been replaced by 'depth' (the caller's alias). + assert """\ + subroutine caller() + real :: depth + integer :: b + COMMON /ocean/ depth + + b = depth + + end subroutine caller +""" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_common_block_reject_due_to_different_types( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining is rejected when the same COMMON block is declared + in both the caller and the callee with different variable names and + different types. + ''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + COMMON /ocean/ depth + real(kind=4) :: depth + integer :: b + call subfoo(b) + end subroutine caller + subroutine subfoo(a) + COMMON /ocean/ height + real(kind=8) :: height + integer :: a + + a = 3 + end subroutine subfoo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + with pytest.raises(TransformationError) as err: + # Should raise a TransformationError because + # the types of the common-block variables differ. + trans.apply(calls[0]) + assert ("Cannot inline 'subfoo' because COMMON block '/ocean/' maps" + " 'depth' (type 'Scalar]>') in the caller to 'height'" + " (type 'Scalar]>') in the routine being inlined - the types" + " are incompatible.") in str(err.value) + + +def test_apply_parameter_cloning_default( + fortran_reader, fortran_writer, tmp_path): + '''Test that the default behaviour (parameter_cloning=True) clones a + constant from the inlined routine into the call-site table, even when + an identical constant already exists there, potentially renaming it.''' + code = """\ +module test_mod +contains + subroutine bar(b) + real, parameter :: constval = 123.0 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo) + + result = fortran_writer(bar) + # With cloning enabled the inlined constant must appear at least once; + # it may be renamed to avoid the clash. + assert "constval" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_identical(fortran_reader, + fortran_writer, tmp_path): + '''Test that parameter_cloning=False suppresses the duplicate when the + call-site already has an identical constant (same name, type, value). + This is the main use-case from the user request.''' + code = """\ +module test_mod +contains + subroutine bar(b) + real, parameter :: constval = 123.0 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # constval should be declared exactly once (no duplicate parameter). + assert result.count("parameter :: constval") == 1 + # The inlined assignment should still use constval correctly. + assert "constval" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_different_value( + fortran_reader, fortran_writer, tmp_path): + '''Test that parameter_cloning=False does NOT suppress a parameter when + the values differ between the call site and the inlined routine.''' + code = """\ +module test_mod +contains + subroutine bar(b) + real, parameter :: constval = 42.0 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # Both constant declarations must survive since they have different values. + assert result.count("constval") >= 2 + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_no_match_in_caller( + fortran_reader, fortran_writer, tmp_path): + '''Test that parameter_cloning=False still adds a constant that does not + exist at the call site.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 123.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # constval from foo must be added to bar because bar didn't have it. + assert "constval" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_used_in_array_dim( + fortran_reader, fortran_writer, tmp_path): + '''Test that parameter_cloning=False correctly handles a constant that + is used as an array-dimension bound inside the inlined routine.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: n = 5 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: n = 5 + real, dimension(n) :: tmp + integer :: a + tmp(1) = real(a) + a = int(tmp(1)) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # n should appear only once as a parameter declaration. + assert result.count(", parameter ::", result.lower().find("n =")) <= 1 \ + or result.count("n = 5") == 1 + # The inlined array tmp should still be present and use n. + assert "tmp" in result + assert "n" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_multiple_params( + fortran_reader, fortran_writer, tmp_path): + '''Test parameter_cloning=False with multiple constants, some matching + and some not.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: shared = 10 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: shared = 10 + integer, parameter :: local_only = 99 + integer :: a + a = shared + local_only + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # shared must be declared exactly once (no duplicate parameter). + assert result.count("parameter :: shared") == 1 + # local_only is unique to foo, so it must be added to bar. + assert "local_only" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_complex_rhs_identical( + fortran_reader, fortran_writer, tmp_path): + '''Test parameter_cloning=False with constants whose value is a complex + PSyIR expression (BinaryOperation) that is identical in the caller and the + routine. The duplicate should be suppressed.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 123 + base_val + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 123 + base_val + integer :: a + a = constval + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # Neither base_val nor constval should be duplicated. + assert result.count("parameter :: constval") == 1 + assert result.count("parameter :: base_val") == 1 + # The inlined body should still reference constval. + assert "constval" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_complex_rhs_different( + fortran_reader, fortran_writer, tmp_path): + '''Test parameter_cloning=False with constants that have identical names + but different complex RHS expressions. Both declarations must be kept.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 100 + base_val + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: base_val = 10 + integer, parameter :: constval = 123 + base_val + integer :: a + a = constval + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # constval has different values in bar and foo, so both must appear. + assert result.count("parameter :: constval") >= 2 or ( + "constval" in result and "constval_1" in result) + # base_val is identical and should be deduplicated. + assert result.count("parameter :: base_val") == 1 + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_unary_op_different_base( + fortran_reader, fortran_writer, tmp_path): + '''Test parameter_cloning=False where .NOT. parameters share a name but + their base parameter differs. The derived constant must NOT be deduplicated + because the structural match is only nominal (the base has different + values), and using the caller's copy would produce wrong semantics.''' + code = """\ +module test_mod +contains + subroutine bar(b) + logical, parameter :: flag = .true. + logical, parameter :: negflag = .not. flag + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + logical, parameter :: flag = .false. + logical, parameter :: negflag = .not. flag + integer :: a + if (negflag) a = 42 + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # flag has different values so both must appear (foo's renamed). + assert result.count("parameter :: flag") >= 2 or "flag_1" in result + # negflag depends on flag which differs, so foo's negflag must also + # appear (renamed), and the inlined if must use foo's (renamed) negflag. + assert "negflag_1" in result + assert "if (negflag_1)" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_type_clash( + fortran_reader, fortran_writer): + '''Test with parameter types that don't match.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: wp = kind(1.0d0) + real(kind=wp), parameter :: pi = 3.14592 + real :: tmp + integer :: b + + tmp = wp + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: wp = kind(1.0) + integer :: a + + a = 42 * wp + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + assert """\ +subroutine bar(b) + integer, parameter :: wp = KIND(1.0d0) + real(kind=wp), parameter :: pi = 3.14592 + integer, parameter :: wp_1 = KIND(1.0) + integer :: b + real :: tmp + + tmp = wp + b = 42 * wp_1 + +end subroutine bar""" in result + + +def test_apply_parameter_cloning_false_caller_has_non_constant( + fortran_reader, fortran_writer, tmp_path): + '''Test that parameter_cloning=False does NOT suppress a routine constant + when the call-site has a symbol with the same name that is not a constant + (i.e. tsym.is_constant is False). This exercises the + ``if not tsym.is_constant or tsym.initial_value is None`` branch in + _redirect_duplicate_parameters.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer :: constval + integer :: b + constval = 7 + call foo(b) + end subroutine bar + subroutine foo(a) + integer, parameter :: constval = 10 + integer :: a + a = constval + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # bar's constval is a variable; foo's is a parameter. They are not + # duplicates, so foo's parameter constant must appear (possibly renamed). + assert """\ +subroutine bar(b) + integer, parameter :: constval_1 = 10 + integer :: b + integer :: constval + + constval = 7 + b = constval_1 + +end subroutine bar""" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_parameter_cloning_false_different_datatype( + fortran_reader, fortran_writer, tmp_path): + '''Test that parameter_cloning=False does NOT suppress a routine constant + when the call-site has a constant with the same name but a different + datatype. This exercises the ``if rsym.datatype != tsym.datatype`` + branch in _redirect_duplicate_parameters.''' + code = """\ +module test_mod +contains + subroutine bar(b) + integer, parameter :: constval = 10 + integer :: b + call foo(b) + end subroutine bar + subroutine foo(a) + real, parameter :: constval = 10.0 + integer :: a + a = int(constval) + end subroutine foo +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(code) + bar = psyir.walk(Routine)[0] + foo = psyir.walk(Routine)[1] + call = bar.walk(Call)[0] + + InlineTrans().apply(call, routine=foo, parameter_cloning=False) + + result = fortran_writer(bar) + # bar has integer constval=10, foo has real constval=10.0. Different + # types so the routine's parameter must be added (renamed) rather than + # deduplicated. + assert """\ +subroutine bar(b) + integer, parameter :: constval = 10 + real, parameter :: constval_1 = 10.0 + integer :: b + + b = INT(constval_1) + +end subroutine bar""" in result + assert Compile(tmp_path).string_compiles(result)