Skip to content

Commit

Permalink
Improve inheritance of profile flags with linkers and wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
hiker committed Feb 15, 2025
1 parent 21ba96c commit fee307c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 3 deletions.
7 changes: 6 additions & 1 deletion source/fab/tools/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def profile_flags(self) -> ProfileFlags:
''':returns; the ProfileFlags for this compiler.'''
return self._profile_flags

def get_profile_flags(self, profile: str) -> List[str]:
''':returns; the ProfileFlags for the given profile.
:param profile: the profile to use.'''
return self._profile_flags[profile]

@property
def mpi(self) -> bool:
''':returns: whether this compiler supports MPI or not.'''
Expand Down Expand Up @@ -132,7 +137,7 @@ def compile_file(self, input_file: Path,
f"instead.")
params += add_flags

params.extend(self.profile_flags[config.profile])
params.extend(self.get_profile_flags(config.profile))

params.extend([input_file.name,
self._output_flag, str(output_file)])
Expand Down
6 changes: 6 additions & 0 deletions source/fab/tools/compiler_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def has_syntax_only(self) -> bool:
raise RuntimeError(f"Compiler '{self._compiler.name}' has "
f"no has_syntax_only.")

def get_profile_flags(self, profile: str) -> List[str]:
''':returns; the ProfileFlags for the given profile, combined
from the wrapped compiler and this wrapper.
:param profile: the profile to use.'''
return self._compiler.get_profile_flags(profile) + self._profile_flags[profile]

def set_module_output_path(self, path: Path):
'''Sets the output path for modules.
Expand Down
12 changes: 11 additions & 1 deletion source/fab/tools/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ def output_flag(self) -> str:
'''
return self._compiler.output_flag

def get_profile_flags(self, profile: str) -> List[str]:
''':returns; the ProfileFlags for the given profile, combined
from the wrapped compiler and this wrapper.
:param profile: the profile to use.'''
if self._linker:
flags = self._linker.get_profile_flags(profile)[:]
else:
flags = []
return flags + self._compiler.get_profile_flags(profile)

def get_lib_flags(self, lib: str) -> List[str]:
'''Gets the standard flags for a standard library
Expand Down Expand Up @@ -201,7 +211,7 @@ def link(self, input_files: List[Path], output_file: Path,
# TODO: For now we pick up both flags from ProfileFlags and the
# standard ones from Tool.
params.extend(self._compiler.flags)
params.extend(self._compiler.profile_flags[config.profile])
params.extend(self._compiler.get_profile_flags(config.profile))

if config.openmp:
params.append(self._compiler.openmp_flag)
Expand Down
25 changes: 24 additions & 1 deletion tests/unit_tests/tools/test_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import pytest

from fab.tools import (Category, Linker, ToolRepository)
from fab.tools import Category, CompilerWrapper, Linker, ToolRepository


def test_linker(mock_c_compiler, mock_fortran_compiler):
Expand Down Expand Up @@ -369,3 +369,26 @@ def test_linker_inheriting():
with pytest.raises(RuntimeError) as err:
linker_mpif90.get_lib_flags("does_not_exist")
assert "Unknown library name: 'does_not_exist'" in str(err.value)


def test_linker_profile_flags_inheriting(mock_c_compiler):
'''Test nested compiler and nested linker with inherited profiling flags.
'''
mock_c_compiler_wrapper = CompilerWrapper(name="mock_c_compiler_wrapper",
compiler=mock_c_compiler,
exec_name="exec_name")
linker = Linker(mock_c_compiler_wrapper)
linker_wrapper = Linker(mock_c_compiler_wrapper, linker=linker)
count = 0
for pf in [mock_c_compiler._profile_flags,
mock_c_compiler_wrapper._profile_flags]:
pf.define_profile("base")
pf.define_profile("derived", "base")
pf.add_flags("base", f"-f{count}")
pf.add_flags("derived", f"-f{count+1}")
count += 2

# One set f1-f4 from the compiler wrapper, one from the wrapped linker
assert (linker_wrapper.get_profile_flags("derived") ==
["-f0", "-f1", "-f2", "-f3", "-f0", "-f1", "-f2", "-f3"])

0 comments on commit fee307c

Please sign in to comment.