diff --git a/benchmark/test/test_framework.py.in b/benchmark/test/test_framework.py.in index fff93548ad6..5f983d22255 100644 --- a/benchmark/test/test_framework.py.in +++ b/benchmark/test/test_framework.py.in @@ -7,50 +7,59 @@ import re import pathlib import sys -sourcepath = pathlib.Path("@CMAKE_CURRENT_SOURCE_DIR@") -binpath = pathlib.Path("@PROJECT_BINARY_DIR@") -generate = False -if len(sys.argv) > 2 and sys.argv[2] == "--generate": - generate = True - - -denumberify_paths = [ - "time", - "bandwidth", - "flops", - "components", - "residual_norm", - "rhs_norm", - "max_relative_norm2", -] -empty_string_paths = ["error"] -empty_array_paths = [ - "recurrent_residuals", - "true_residuals", - "implicit_residuals", - "iteration_timestamps", -] - - -def sanitize_json_single(key, value, sanitize_all): +if __name__ == "test_framework": + sourcepath = pathlib.Path("@CMAKE_CURRENT_SOURCE_DIR@") + binpath = pathlib.Path("@PROJECT_BINARY_DIR@") + generate = False + if len(sys.argv) > 2 and sys.argv[2] == "--generate": + generate = True + denumberify_paths = [ + "time", + "bandwidth", + "flops", + "components", + "residual_norm", + "rhs_norm", + "max_relative_norm2", + ] + empty_string_paths = ["error"] + empty_array_paths = [ + "recurrent_residuals", + "true_residuals", + "implicit_residuals", + "iteration_timestamps", + ] + + +def sanitize_json_key_value(key: str, value, sanitize_all: bool): + """Applies sanitation to a single key-value pair. + + Strings with a key in empty_string_paths will be emptied + Numbers with a key in denumberify_paths will be set to 1.0 + + """ if key in empty_string_paths and isinstance(value, str): return "" if key in denumberify_paths and isinstance(value, float): return 1.0 - if key in denumberify_paths and isinstance(value, typing.Dict): + if key in denumberify_paths and isinstance(value, dict): return sanitize_json(value, True) - if key in empty_array_paths and isinstance(value, typing.List): + if key in empty_array_paths and isinstance(value, list): return [] return sanitize_json(value, sanitize_all) def sanitize_json(parsed_input, sanitize_all=False): - if isinstance(parsed_input, typing.Dict): + """Removes non-deterministic parts of a parsed JSON input. + + If sanitize_all is set to True, all nested float values will be set to 0. + Otherwise, the sanitation""" + if isinstance(parsed_input, dict): return { - key: sanitize_json_single(key, value, sanitize_all) + key: sanitize_json_key_value(key, value, sanitize_all) for key, value in parsed_input.items() } - elif isinstance(parsed_input, typing.List): + elif isinstance(parsed_input, list): return [sanitize_json(e, sanitize_all) for e in parsed_input] elif sanitize_all and isinstance(parsed_input, float): return 1.0 @@ -58,7 +67,15 @@ def sanitize_json(parsed_input, sanitize_all=False): return parsed_input -def sanitize_text(lines): +def sanitize_json_in_text(lines: list[str]) -> list[str]: + """Sanitizes all occurrences of JSON content inside text input. + + Takes a list of text lines and detects any pretty-printed JSON output inside + (recognized by a single [, {, } or ] in an otherwise empty line). + The JSON output will be parsed and sanitized through sanitize_json(...) + and pretty-printed to replace the original JSON input. + The function returns the resulting output""" + json_begins = [i for i, l in enumerate(lines) if l in ["[", "{"]] json_ends = [i + 1 for i, l in enumerate(lines) if l in ["]", "}"]] json_pairs = list(zip(json_begins, json_ends)) @@ -86,12 +103,20 @@ def sanitize_text(lines): def determinize_text( - input, - denumberify_paths=[], - remove_paths=[], - ignore_patterns=[], - replace_patterns=[], -): + input: str, + ignore_patterns: list[str], + replace_patterns: list[(str, str)], +) -> list[str]: + """Sanitizes the given input string. + + Every input line matching an entry from ignore_patterns will be removed. + Every line matching the first string in an entry from replace_patterns + will be replaced by the second string. + Finally, the text will be passed to sanitize_json_in_text, which removes + nondeterministic parts from JSON objects/arrays in the input, + if it can be parsed correctly. + The output is guaranteed to end with an empty line. + """ lines = input.splitlines() output_lines = [] patterns = [re.compile(pattern) for pattern in ignore_patterns] @@ -108,12 +133,12 @@ def determinize_text( if output_lines[-1] != "": output_lines.append("") try: - return sanitize_text(output_lines) + return sanitize_json_in_text(output_lines) except json.decoder.JSONDecodeError: return output_lines -def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_flags=[]): +def compare_output_impl(args: list[str], expected_stdout: str, expected_stderr: str, stdin: str, launcher_flags: list[str]): args = [sys.argv[1]] + args expected_stdout = str(sourcepath / "reference" / expected_stdout) expected_stderr = str(sourcepath / "reference" / expected_stderr) @@ -139,7 +164,9 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl open(expected_stdout, "w").write( "\n".join( determinize_text( - result.stdout.decode(), replace_patterns=typename_patterns + result.stdout.decode(), + ignore_patterns=[], + replace_patterns=typename_patterns, ) ) ) @@ -155,7 +182,7 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl print("GENERATED") return result_stdout_processed = determinize_text( - result.stdout.decode(), replace_patterns=typename_patterns + result.stdout.decode(), ignore_patterns=[], replace_patterns=typename_patterns ) result_stderr_processed = determinize_text( result.stderr.decode(), @@ -163,7 +190,9 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl replace_patterns=typename_patterns, ) expected_stdout_processed = determinize_text( - open(expected_stdout).read(), replace_patterns=typename_patterns + open(expected_stdout).read(), + ignore_patterns=[], + replace_patterns=typename_patterns, ) expected_stderr_processed = determinize_text( open(expected_stderr).read(), @@ -192,6 +221,10 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl print("PASS") +def compare_output(args: list[str], expected_stdout: str, expected_stderr: str, stdin: str = ""): + compare_output_impl(args, expected_stdout=expected_stdout, expected_stderr=expected_stderr, stdin=stdin, launcher_flags=[]) + + def compare_output_distributed( args, expected_stdout, expected_stderr, num_procs, stdin="" ):