Skip to content

Commit

Permalink
clean up serialized functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kbhargava-jump committed Oct 22, 2024
1 parent 37cadfb commit d905004
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 37 deletions.
44 changes: 10 additions & 34 deletions src/test_suite/multiprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def process_target(
harness_ctx: HarnessCtx, library: ctypes.CDLL, serialized_instruction_context: str
harness_ctx: HarnessCtx, library: ctypes.CDLL, context: ContextType
) -> invoke_pb.InstrEffects | None:
"""
Process an instruction through a provided shared library and return the result.
Expand All @@ -26,6 +26,11 @@ def process_target(
Returns:
- invoke_pb.InstrEffects | None: Result of instruction execution.
"""

serialized_instruction_context = context.SerializeToString(deterministic=True)
if serialized_instruction_context is None:
return None

# Prepare input data and output buffers
in_data = serialized_instruction_context
in_ptr = (ctypes.c_uint8 * len(in_data))(*in_data)
Expand Down Expand Up @@ -143,25 +148,6 @@ def read_context(harness_ctx: HarnessCtx, test_file: Path) -> message.Message |
return context


def read_fixture_serialized(fixture_file: Path) -> str | None:
"""
Same as read_instr, but for InstrFixture protobuf messages.
DOES NOT SUPPORT HUMAN READABLE MESSAGES!!!
Args:
- fixture_file (Path): Path to the instruction fixture message.
Returns:
- str | None: Serialized instruction fixture, or None if reading failed.
"""
fixture = read_fixture(fixture_file)
if fixture is None:
return None
# Serialize instruction fixture to string (pickleable)
return fixture.SerializeToString(deterministic=True)


def read_fixture(fixture_file: Path) -> message.Message | None:
"""
Reads in test files and generates an Fixture Protobuf object for a test case.
Expand Down Expand Up @@ -209,10 +195,12 @@ def decode_single_test_case(test_file: Path) -> int:
if test_file.suffix == ".fix":
fn_entrypoint = extract_metadata(test_file).fn_entrypoint
harness_ctx = ENTRYPOINT_HARNESS_MAP[fn_entrypoint]
serialized_protobuf = read_fixture_serialized(test_file)
fixture = read_fixture(test_file)
serialized_protobuf = fixture.SerializeToString(deterministic=True)
else:
harness_ctx = globals.default_harness_ctx
serialized_protobuf = read_context_serialized(harness_ctx, test_file)
context = read_context(harness_ctx, test_file)
serialized_protobuf = context.SerializeToString(deterministic=True)

# Skip if input is invalid
if serialized_protobuf is None:
Expand Down Expand Up @@ -378,18 +366,6 @@ def initialize_process_output_buffers(randomize_output_buffer=False):
)


def serialize_context(harness_ctx: HarnessCtx, file: Path) -> str | None:
if file.suffix == ".fix":
fixture = harness_ctx.fixture_type()
fixture.ParseFromString(file.open("rb").read())
serialized_instr_context = fixture.input.SerializeToString(deterministic=True)
else:
serialized_instr_context = read_context_serialized(harness_ctx, file)

assert serialized_instr_context is not None, f"Unable to read {file.name}"
return serialized_instr_context


def run_test(test_file: Path) -> tuple[str, int, dict | None]:
"""
Runs a single test from start to finish.
Expand Down
6 changes: 3 additions & 3 deletions src/test_suite/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
process_target,
run_test,
read_context,
serialize_context,
)
import test_suite.globals as globals
from test_suite.util import set_ld_preload_asan
Expand Down Expand Up @@ -98,13 +97,14 @@ def execute(
if file.suffix == ".fix":
fn_entrypoint = extract_metadata(file).fn_entrypoint
harness_ctx = ENTRYPOINT_HARNESS_MAP[fn_entrypoint]
context = read_fixture(file).input
else:
harness_ctx = HARNESS_MAP[default_harness_ctx]
context = read_context(harness_ctx, file)

# Execute and cleanup
context = read_context(harness_ctx, file)
start = time.time()
effects = process_target(harness_ctx, lib, serialize_context(harness_ctx, file))
effects = process_target(harness_ctx, lib, context)
end = time.time()

print(f"Total time taken for {file}: {(end - start) * 1000} ms\n------------")
Expand Down

0 comments on commit d905004

Please sign in to comment.