From 344f3edeb564e07afdf357b6e85935fb5be743e9 Mon Sep 17 00:00:00 2001 From: Tom Pointon Date: Tue, 3 Sep 2024 19:03:49 +0000 Subject: [PATCH] Format feature set nicely when dumping --- src/test_suite/multiprocessing_utils.py | 44 ++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/test_suite/multiprocessing_utils.py b/src/test_suite/multiprocessing_utils.py index acc34df..4b6161c 100644 --- a/src/test_suite/multiprocessing_utils.py +++ b/src/test_suite/multiprocessing_utils.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from test_suite.constants import OUTPUT_BUFFER_SIZE +from test_suite.context_pb2 import FeatureSet import test_suite.invoke_pb2 as invoke_pb import ctypes from ctypes import c_uint64, c_int, POINTER, Structure @@ -131,6 +132,43 @@ def read_fixture(fixture_file: Path) -> str | None: return instruction_fixture.SerializeToString(deterministic=True) +def format_feature_set(message, indent=0, as_one_line=False): + lines = [] + lines.append(" " * indent + f"FeatureSet: {{") + for feature in message.features: + lines.append(" " * (indent + 2) + hex(feature)) + lines.append(" " * indent + "}") + return "\n".join(lines) + + +def recursive_formatter(message, indent=0, as_one_line=False): + """Recursively format the message and its sub-fields""" + lines = [] + for field, value in message.ListFields(): + if field.label == field.LABEL_REPEATED: + lines.append(" " * indent + f"{field.name}: [") + for item in value: + if field.type == field.TYPE_MESSAGE: + lines.append(recursive_formatter(item, indent + 2, as_one_line)) + else: + lines.append(" " * (indent + 2) + str(item)) + lines.append(" " * indent + "]") + elif field.type == field.TYPE_MESSAGE: + # If the field is a message, check if it needs custom formatting + lines.append(" " * indent + f"{field.name}: {{") + if isinstance(value, FeatureSet): + lines.append(format_feature_set(value, indent + 2, as_one_line)) + else: + # Recursively format other sub-messages + lines.append(recursive_formatter(value, indent + 2, as_one_line)) + lines.append(" " * indent + "}") + else: + # Use the default formatting for non-message fields + lines.append(" " * indent + f"{field.name}: {value}") + + return "\n".join(lines) + + def decode_single_test_case(test_file: Path) -> int: """ Decode a single test case into a human-readable message @@ -154,7 +192,11 @@ def decode_single_test_case(test_file: Path) -> int: with open(globals.output_dir / (test_file.stem + ".txt"), "w") as f: f.write( - text_format.MessageToString(instruction_context, print_unknown_fields=False) + text_format.MessageToString( + instruction_context, + print_unknown_fields=False, + message_formatter=recursive_formatter, + ) ) return 1