From d905004c33f1016490fefff5b3a901ebba408ba9 Mon Sep 17 00:00:00 2001 From: Kunal Bhargava Date: Tue, 22 Oct 2024 21:38:22 +0000 Subject: [PATCH] clean up serialized functions --- src/test_suite/multiprocessing_utils.py | 44 ++++++------------------- src/test_suite/test_suite.py | 6 ++-- 2 files changed, 13 insertions(+), 37 deletions(-) diff --git a/src/test_suite/multiprocessing_utils.py b/src/test_suite/multiprocessing_utils.py index 96cfed3..e462726 100644 --- a/src/test_suite/multiprocessing_utils.py +++ b/src/test_suite/multiprocessing_utils.py @@ -14,7 +14,7 @@ def process_target( - harness_ctx: HarnessCtx, library: ctypes.CDLL, serialized_instruction_context: str + harness_ctx: HarnessCtx, library: ctypes.CDLL, context: ContextType ) -> invoke_pb.InstrEffects | None: """ Process an instruction through a provided shared library and return the result. @@ -26,6 +26,11 @@ def process_target( Returns: - invoke_pb.InstrEffects | None: Result of instruction execution. """ + + serialized_instruction_context = context.SerializeToString(deterministic=True) + if serialized_instruction_context is None: + return None + # Prepare input data and output buffers in_data = serialized_instruction_context in_ptr = (ctypes.c_uint8 * len(in_data))(*in_data) @@ -143,25 +148,6 @@ def read_context(harness_ctx: HarnessCtx, test_file: Path) -> message.Message | return context -def read_fixture_serialized(fixture_file: Path) -> str | None: - """ - Same as read_instr, but for InstrFixture protobuf messages. - - DOES NOT SUPPORT HUMAN READABLE MESSAGES!!! - - Args: - - fixture_file (Path): Path to the instruction fixture message. - - Returns: - - str | None: Serialized instruction fixture, or None if reading failed. - """ - fixture = read_fixture(fixture_file) - if fixture is None: - return None - # Serialize instruction fixture to string (pickleable) - return fixture.SerializeToString(deterministic=True) - - def read_fixture(fixture_file: Path) -> message.Message | None: """ Reads in test files and generates an Fixture Protobuf object for a test case. @@ -209,10 +195,12 @@ def decode_single_test_case(test_file: Path) -> int: if test_file.suffix == ".fix": fn_entrypoint = extract_metadata(test_file).fn_entrypoint harness_ctx = ENTRYPOINT_HARNESS_MAP[fn_entrypoint] - serialized_protobuf = read_fixture_serialized(test_file) + fixture = read_fixture(test_file) + serialized_protobuf = fixture.SerializeToString(deterministic=True) else: harness_ctx = globals.default_harness_ctx - serialized_protobuf = read_context_serialized(harness_ctx, test_file) + context = read_context(harness_ctx, test_file) + serialized_protobuf = context.SerializeToString(deterministic=True) # Skip if input is invalid if serialized_protobuf is None: @@ -378,18 +366,6 @@ def initialize_process_output_buffers(randomize_output_buffer=False): ) -def serialize_context(harness_ctx: HarnessCtx, file: Path) -> str | None: - if file.suffix == ".fix": - fixture = harness_ctx.fixture_type() - fixture.ParseFromString(file.open("rb").read()) - serialized_instr_context = fixture.input.SerializeToString(deterministic=True) - else: - serialized_instr_context = read_context_serialized(harness_ctx, file) - - assert serialized_instr_context is not None, f"Unable to read {file.name}" - return serialized_instr_context - - def run_test(test_file: Path) -> tuple[str, int, dict | None]: """ Runs a single test from start to finish. diff --git a/src/test_suite/test_suite.py b/src/test_suite/test_suite.py index 856835b..0cb3722 100644 --- a/src/test_suite/test_suite.py +++ b/src/test_suite/test_suite.py @@ -22,7 +22,6 @@ process_target, run_test, read_context, - serialize_context, ) import test_suite.globals as globals from test_suite.util import set_ld_preload_asan @@ -98,13 +97,14 @@ def execute( if file.suffix == ".fix": fn_entrypoint = extract_metadata(file).fn_entrypoint harness_ctx = ENTRYPOINT_HARNESS_MAP[fn_entrypoint] + context = read_fixture(file).input else: harness_ctx = HARNESS_MAP[default_harness_ctx] + context = read_context(harness_ctx, file) # Execute and cleanup - context = read_context(harness_ctx, file) start = time.time() - effects = process_target(harness_ctx, lib, serialize_context(harness_ctx, file)) + effects = process_target(harness_ctx, lib, context) end = time.time() print(f"Total time taken for {file}: {(end - start) * 1000} ms\n------------")