Skip to content

Commit

Permalink
Add pruning to only include actually modified accounts
Browse files Browse the repository at this point in the history
Refactoring
  • Loading branch information
mjain-jump committed Apr 19, 2024
1 parent afd6e3c commit 7bc1d9b
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 63 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@ Each target must contain a `sol_compat_instr_execute_v1` function that takes in
Before running tests, `InstrContext` messages may be converted into Protobuf's text format, with all `bytes` fields base58-encoded (for human readability). Run the following command to do this:

```sh
solana-test-suite decode-protobuf --input-dir <input_dir> --output-dir <output_dir> [--check-results] [verbose]
solana-test-suite decode-protobuf --input-dir <input_dir> --output-dir <output_dir> --num-processes <num_processes>
```

| Argument | Description |
|----------------|-----------------------------------------------------------------------------------------------|
| `--input-dir` | Input directory containing instruction context messages in binary format |
| `--output-dir` | Output directory for encoded, human-readable instruction context messages |
| `--check-results` | Validate binary and human readable messages are identical |
| `--verbose` | Enable verbose output |
| `--num-processes` | Number of processes to use |


Optionally, instruction context messages may also be left in the original Protobuf binary-encoded format.
Expand Down
117 changes: 95 additions & 22 deletions src/test_suite/multiprocessing_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from yaml import serialize
from test_suite.constants import OUTPUT_BUFFER_SIZE
import test_suite.invoke_pb2 as pb
from test_suite.codec_utils import encode_input, encode_output, decode_input
from test_suite.validation_utils import is_valid
from test_suite.validation_utils import check_account_unchanged, is_valid
import ctypes
from ctypes import c_uint64, c_int, POINTER
from pathlib import Path
Expand Down Expand Up @@ -56,15 +57,15 @@ def process_instruction(
return output_object


def generate_test_case(test_file: Path) -> tuple[Path, str | None]:
def generate_test_case(test_file: Path) -> tuple[str, str | None]:
"""
Reads in test files and generates a Protobuf object for a test case.
Args:
- test_file (Path): Path to the file containing serialized instruction contexts.
Returns:
- tuple[Path, str | None]: Tuple of file and serialized instruction context, if exists.
- tuple[str, str | None]: Tuple of file stem and serialized instruction context, if exists.
"""
# Try to read in first as binary-encoded Protobuf messages
try:
Expand All @@ -86,18 +87,24 @@ def generate_test_case(test_file: Path) -> tuple[Path, str | None]:

if instruction_context is None:
# Unreadable file, skip it
return test_file, None
return test_file.stem, None

# Discard unknown fields
instruction_context.DiscardUnknownFields()

# Serialize instruction context to string (pickleable)
return test_file, instruction_context.SerializeToString(deterministic=True)
return test_file.stem, instruction_context.SerializeToString(deterministic=True)


def decode_single_test_case(test_file: Path):
def decode_single_test_case(test_file: Path) -> int:
"""
Decode a single test case into a human-readable message
Args:
- test_file (Path): Path to the file containing serialized instruction contexts.
Returns:
- int: 1 if successfully decoded and written, 0 if skipped.
"""
_, serialized_instruction_context = generate_test_case(test_file)

Expand All @@ -118,15 +125,15 @@ def decode_single_test_case(test_file: Path):


def process_single_test_case(
file: Path, serialized_instruction_context: str | None
file_stem: str, serialized_instruction_context: str | None
) -> tuple[str, dict[str, str | None] | None]:
"""
Process a single execution context (file, serialized instruction context) through
all target libraries and returns serialized instruction effects. This
function is called by processes.
Args:
- file (Path): File containing serialized instruction context.
- file_stem (str): Stem of file containing serialized instruction context.
- serialized_instruction_context (str | None): Serialized instruction context.
Returns:
Expand All @@ -135,7 +142,7 @@ def process_single_test_case(
"""
# Mark as skipped if instruction context doesn't exist
if serialized_instruction_context is None:
return file.stem, None
return file_stem, None

# Execute test case on each target library
results = {}
Expand All @@ -150,7 +157,7 @@ def process_single_test_case(
)
results[target] = result

return file.stem, results
return file_stem, results


def merge_results_over_iterations(results: tuple) -> tuple[str, dict]:
Expand Down Expand Up @@ -181,12 +188,72 @@ def merge_results_over_iterations(results: tuple) -> tuple[str, dict]:
return file, merged_results


def check_consistency_in_results(file_stem: Path, results: dict) -> dict[str, bool]:
def prune_execution_result(
    file_serialized_instruction_context: tuple[str, str | None],
    file_serialized_instruction_effects: tuple[str, dict[str, str | None] | None],
) -> tuple[str, dict[str, str | None] | None]:
    """
    Prune execution results to only include accounts that were actually modified.

    Args:
        - file_serialized_instruction_context (tuple[str, str | None]): Tuple of
          file stem and serialized instruction context (None if the test file
          was unreadable).
        - file_serialized_instruction_effects (tuple[str, dict[str, str | None]]):
          Tuple of file stem and a mapping of target library name to serialized
          instruction effects.

    Returns:
        - tuple[str, dict[str, str | None] | None]: Tuple of file stem and
          serialized pruned instruction effects per target, or (file_stem, None)
          when the instruction context is missing.
    """
    file_stem, serialized_instruction_context = file_serialized_instruction_context
    if serialized_instruction_context is None:
        return file_stem, None

    instruction_context = pb.InstrContext()
    instruction_context.ParseFromString(serialized_instruction_context)

    effects_file_stem, targets_to_serialized_instruction_effects = (
        file_serialized_instruction_effects
    )
    # The zipped context and effects entries must refer to the same test case.
    assert file_stem == effects_file_stem, f"{file_stem}, {effects_file_stem}"

    targets_to_serialized_pruned_instruction_effects = {}
    for (
        target,
        serialized_instruction_effects,
    ) in targets_to_serialized_instruction_effects.items():
        if serialized_instruction_effects is None:
            targets_to_serialized_pruned_instruction_effects[target] = None
            continue

        instruction_effects = pb.InstrEffects()
        instruction_effects.ParseFromString(serialized_instruction_effects)

        # Quadratic scan is fine here (small account counts, not performance
        # sensitive); any() short-circuits on the first matching beginning
        # state instead of scanning every account as the |= form did.
        new_modified_accounts: list[pb.AcctState] = [
            modified_account
            for modified_account in instruction_effects.modified_accounts
            if not any(
                check_account_unchanged(modified_account, beginning_account_state)
                for beginning_account_state in instruction_context.accounts
            )
        ]

        # Replace the repeated field in place with only the truly modified accounts.
        del instruction_effects.modified_accounts[:]
        instruction_effects.modified_accounts.extend(new_modified_accounts)
        targets_to_serialized_pruned_instruction_effects[target] = (
            instruction_effects.SerializeToString(deterministic=True)
        )

    return file_stem, targets_to_serialized_pruned_instruction_effects


def check_consistency_in_results(file_stem: str, results: dict) -> dict[str, bool]:
"""
Check consistency for all target libraries over all iterations for a test case.
Args:
- file_stem (Path): File stem of the test case.
- file_stem (str): File stem of the test case.
- execution_results (dict): Dictionary of target library names and serialized instruction effects.
Returns:
Expand Down Expand Up @@ -231,42 +298,48 @@ def check_consistency_in_results(file_stem: Path, results: dict) -> dict[str, bo
return results_per_target


def build_test_results(
    file_stem: str, results: dict[str, str | None] | None
) -> tuple[str, int, dict[str, str] | None]:
    """
    Build the result of a single test execution and determine whether the test
    passed, failed, or was skipped.

    Args:
        - file_stem (str): File stem of the test case.
        - results (dict[str, str | None] | None): Mapping of target library name
          to serialized instruction effects; None when the case was skipped.

    Returns:
        - tuple[str, int, dict | None]: Tuple of:
            File stem;
            1 if passed, -1 if failed, 0 if skipped;
            Mapping of target library names to file-dumpable, human-readable
            instruction effects (None when skipped).
    """
    # If no results, mark case as skipped (0). Guard before building `outputs`:
    # iterating a None `results` would raise TypeError.
    if results is None:
        return file_stem, 0, None

    outputs = {target: "None\n" for target in results}

    # Log execution results per target
    protobuf_structures = {}
    for target, result in results.items():
        # Create a Protobuf struct to compare and output, if applicable
        instruction_effects = None
        if result:
            # Turn bytes into human readable fields
            instruction_effects = pb.InstrEffects()
            instruction_effects.ParseFromString(result)
            encode_output(instruction_effects)
            outputs[target] = text_format.MessageToString(instruction_effects)

        protobuf_structures[target] = instruction_effects

    # The case passes iff every target's effects equal the Solana reference
    # library's effects.
    test_case_passed = all(
        protobuf_structures[globals.solana_shared_library] == result
        for result in protobuf_structures.values()
    )

    # 1 = passed, -1 = failed
    return file_stem, 1 if test_case_passed else -1, outputs


def initialize_process_output_buffers(randomize_output_buffer=False):
Expand Down
86 changes: 48 additions & 38 deletions src/test_suite/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from google.protobuf import text_format
from test_suite.constants import LOG_FILE_SEPARATOR_LENGTH
import test_suite.invoke_pb2 as pb
from test_suite.codec_utils import encode_input, encode_output
from test_suite.codec_utils import decode_input, encode_input, encode_output
from test_suite.multiprocessing_utils import (
check_consistency_in_results,
decode_single_test_case,
Expand All @@ -18,6 +18,7 @@
process_instruction,
process_single_test_case,
build_test_results,
prune_execution_result,
)
import test_suite.globals as globals
from test_suite.debugger import debug_host
Expand Down Expand Up @@ -54,14 +55,27 @@ def execute_single_instruction(
lib.sol_compat_init()

# Execute and cleanup
instruction_effects = process_instruction(lib, instruction_context)
instruction_effects = process_instruction(
lib, instruction_context
).SerializeToString(deterministic=True)

# Prune execution results
_, pruned_instruction_effects = prune_execution_result(
(file.stem, instruction_context),
(file.stem, {shared_library: instruction_effects}),
)
parsed_instruction_effects = pb.InstrEffects()
parsed_instruction_effects.ParseFromString(
pruned_instruction_effects[shared_library]
)

lib.sol_compat_fini()

# Print human-readable output
if instruction_effects:
encode_output(instruction_effects)
if parsed_instruction_effects:
encode_output(parsed_instruction_effects)

print(instruction_effects)
print(parsed_instruction_effects)


@app.command()
Expand Down Expand Up @@ -279,52 +293,54 @@ def run_tests(
) as pool:
execution_results = pool.starmap(process_single_test_case, execution_contexts)

print("Pruning results...")
# Prune modified accounts that were not actually modified
with Pool(processes=num_processes) as pool:
pruned_execution_results = pool.starmap(
prune_execution_result, zip(execution_contexts, execution_results)
)

# Process the test results in parallel
print("Building test results...")
with Pool(processes=num_processes) as pool:
test_case_results = pool.starmap(build_test_results, execution_results)
counts = Counter(test_case_results)
passed = counts[1]
failed = counts[-1]
skipped = counts[0]
test_case_results = pool.starmap(build_test_results, pruned_execution_results)

print("Logging results...")
counter = 0
passed = 0
failed = 0
skipped = 0
target_log_files = {target: None for target in shared_libraries}
for file, result in execution_results:
if result is None:
for file_stem, status, stringified_results in test_case_results:
if stringified_results is None:
skipped += 1
continue

for target, serialized_instruction_effects in result.items():
if counter % log_chunk_size == 0:
for target, string_result in stringified_results.items():
if (passed + failed + skipped) % log_chunk_size == 0:
if target_log_files[target]:
target_log_files[target].close()
target_log_files[target] = open(
globals.output_dir / target.stem / (file + ".txt"), "w"
globals.output_dir / target.stem / (file_stem + ".txt"), "w"
)

target_log_files[target].write(file + ":\n")

if serialized_instruction_effects is None:
target_log_files[target].write(str(None))
else:
instruction_effects = pb.InstrEffects()
instruction_effects.ParseFromString(serialized_instruction_effects)
encode_output(instruction_effects)
target_log_files[target].write(
text_format.MessageToString(instruction_effects)
)
target_log_files[target].write(
"\n" + "-" * LOG_FILE_SEPARATOR_LENGTH + "\n"
file_stem
+ ":\n"
+ string_result
+ "\n"
+ "-" * LOG_FILE_SEPARATOR_LENGTH
+ "\n"
)
counter += 1

for target in shared_libraries:
if target_log_files[target]:
target_log_files[target].close()
if status == 1:
passed += 1
elif status == -1:
failed += 1

print("Cleaning up...")
for target in shared_libraries:
if target_log_files[target]:
target_log_files[target].close()
globals.target_libraries[target].sol_compat_fini()

peak_memory_usage_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
Expand All @@ -348,12 +364,6 @@ def decode_protobuf(
"-o",
help="Output directory for base58-encoded, human-readable instruction context messages",
),
check_decode_results: bool = typer.Option(
False,
"--check-results",
"-c",
help="Validate binary and human readable messages are identical",
),
num_processes: int = typer.Option(
4, "--num-processes", "-p", help="Number of processes to use"
),
Expand Down
Loading

0 comments on commit 7bc1d9b

Please sign in to comment.