Skip to content

Commit

Permalink
improve documentation and function naming
Browse files Browse the repository at this point in the history
Co-authored-by: Gregor Olenik <gregor.olenik@kit.edu>
  • Loading branch information
upsj and Gregor Olenik committed Jun 20, 2023
1 parent c04f273 commit 6ab159f
Showing 1 changed file with 76 additions and 43 deletions.
119 changes: 76 additions & 43 deletions benchmark/test/test_framework.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,75 @@ import re
import pathlib
import sys

sourcepath = pathlib.Path("@CMAKE_CURRENT_SOURCE_DIR@")
binpath = pathlib.Path("@PROJECT_BINARY_DIR@")
generate = False
if len(sys.argv) > 2 and sys.argv[2] == "--generate":
generate = True


denumberify_paths = [
"time",
"bandwidth",
"flops",
"components",
"residual_norm",
"rhs_norm",
"max_relative_norm2",
]
empty_string_paths = ["error"]
empty_array_paths = [
"recurrent_residuals",
"true_residuals",
"implicit_residuals",
"iteration_timestamps",
]


def sanitize_json_single(key, value, sanitize_all):
if __name__ == "test_framework":
sourcepath = pathlib.Path("@CMAKE_CURRENT_SOURCE_DIR@")
binpath = pathlib.Path("@PROJECT_BINARY_DIR@")
generate = False
if len(sys.argv) > 2 and sys.argv[2] == "--generate":
generate = True
denumberify_paths = [
"time",
"bandwidth",
"flops",
"components",
"residual_norm",
"rhs_norm",
"max_relative_norm2",
]
empty_string_paths = ["error"]
empty_array_paths = [
"recurrent_residuals",
"true_residuals",
"implicit_residuals",
"iteration_timestamps",
]


def sanitize_json_key_value(key: str, value, sanitize_all: bool):
"""Applies sanitation to a single key-value pair.
Strings with a key in empty_string_paths will be emptied
Numbers with a key in denumberify_paths will be set to 1.0
"""
if key in empty_string_paths and isinstance(value, str):
return ""
if key in denumberify_paths and isinstance(value, float):
return 1.0
if key in denumberify_paths and isinstance(value, typing.Dict):
if key in denumberify_paths and isinstance(value, dict):
return sanitize_json(value, True)
if key in empty_array_paths and isinstance(value, typing.List):
if key in empty_array_paths and isinstance(value, list):
return []
return sanitize_json(value, sanitize_all)


def sanitize_json(parsed_input, sanitize_all=False):
if isinstance(parsed_input, typing.Dict):
"""Removes non-deterministic parts of a parsed JSON input.
If sanitize_all is set to True, all nested float values will be set to 0.
Otherwise, the sanitation"""
if isinstance(parsed_input, dict):
return {
key: sanitize_json_single(key, value, sanitize_all)
key: sanitize_json_key_value(key, value, sanitize_all)
for key, value in parsed_input.items()
}
elif isinstance(parsed_input, typing.List):
elif isinstance(parsed_input, list):
return [sanitize_json(e, sanitize_all) for e in parsed_input]
elif sanitize_all and isinstance(parsed_input, float):
return 1.0
else:
return parsed_input


def sanitize_text(lines):
def sanitize_json_in_text(lines: list[str]) -> list[str]:
"""Sanitizes all occurrences of JSON content inside text input.
Takes a list of text lines and detects any pretty-printed JSON output inside
(recognized by a single [, {, } or ] in an otherwise empty line).
The JSON output will be parsed and sanitized through sanitize_json(...)
and pretty-printed to replace the original JSON input.
The function returns the resulting output"""

json_begins = [i for i, l in enumerate(lines) if l in ["[", "{"]]
json_ends = [i + 1 for i, l in enumerate(lines) if l in ["]", "}"]]
json_pairs = list(zip(json_begins, json_ends))
Expand Down Expand Up @@ -86,12 +103,20 @@ def sanitize_text(lines):


def determinize_text(
input,
denumberify_paths=[],
remove_paths=[],
ignore_patterns=[],
replace_patterns=[],
):
input: str,
ignore_patterns: list[str],
replace_patterns: list[(str, str)],
) -> list[str]:
"""Sanitizes the given input string.
Every input line matching an entry from ignore_patterns will be removed.
Every line matching the first string in an entry from replace_patterns
will be replaced by the second string.
Finally, the text will be passed to sanitize_json_in_text, which removes
nondeterministic parts from JSON objects/arrays in the input,
if it can be parsed correctly.
The output is guaranteed to end with an empty line.
"""
lines = input.splitlines()
output_lines = []
patterns = [re.compile(pattern) for pattern in ignore_patterns]
Expand All @@ -108,12 +133,12 @@ def determinize_text(
if output_lines[-1] != "":
output_lines.append("")
try:
return sanitize_text(output_lines)
return sanitize_json_in_text(output_lines)
except json.decoder.JSONDecodeError:
return output_lines


def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_flags=[]):
def compare_output_impl(args: list[str], expected_stdout: str, expected_stderr: str, stdin: str, launcher_flags: list[str]):
args = [sys.argv[1]] + args
expected_stdout = str(sourcepath / "reference" / expected_stdout)
expected_stderr = str(sourcepath / "reference" / expected_stderr)
Expand All @@ -139,7 +164,9 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl
open(expected_stdout, "w").write(
"\n".join(
determinize_text(
result.stdout.decode(), replace_patterns=typename_patterns
result.stdout.decode(),
ignore_patterns=[],
replace_patterns=typename_patterns,
)
)
)
Expand All @@ -155,15 +182,17 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl
print("GENERATED")
return
result_stdout_processed = determinize_text(
result.stdout.decode(), replace_patterns=typename_patterns
result.stdout.decode(), ignore_patterns=[], replace_patterns=typename_patterns
)
result_stderr_processed = determinize_text(
result.stderr.decode(),
ignore_patterns=version_patterns,
replace_patterns=typename_patterns,
)
expected_stdout_processed = determinize_text(
open(expected_stdout).read(), replace_patterns=typename_patterns
open(expected_stdout).read(),
ignore_patterns=[],
replace_patterns=typename_patterns,
)
expected_stderr_processed = determinize_text(
open(expected_stderr).read(),
Expand Down Expand Up @@ -192,6 +221,10 @@ def compare_output(args, expected_stdout, expected_stderr, stdin="", launcher_fl
print("PASS")


def compare_output(args: list[str], expected_stdout: str, expected_stderr: str, stdin: str = ""):
compare_output_impl(args, expected_stdout=expected_stdout, expected_stderr=expected_stderr, stdin=stdin, launcher_flags=[])


def compare_output_distributed(
args, expected_stdout, expected_stderr, num_procs, stdin=""
):
Expand Down

0 comments on commit 6ab159f

Please sign in to comment.