Add ability to gather and analyse some model metadata #376

Merged (3 commits) on Oct 23, 2024
9 changes: 9 additions & 0 deletions alt_e2eshark/e2e_testing/framework.py
@@ -128,6 +128,14 @@ def update_opset_version_and_overwrite(self):
            og_model, self.opset_version
        )
        onnx.save(model, self.model)

    def get_metadata(self):
        """returns a dict with the model file size (bytes) and op frequency counts"""
        model_size = os.path.getsize(self.model)
        freq = get_op_frequency(self.model)
        metadata = {"model_size": model_size, "op_frequency": freq}
        return metadata



# TODO: extend TestModel to a union, or make TestModel a base class when supporting other frontends
TestModel = OnnxModelInfo
@@ -161,6 +169,7 @@ def benchmark(self, artifact: CompiledOutput, input: TestTensors, repetitions: i
"""returns a float representing inference time in ms"""
pass


class Test(NamedTuple):
"""Used to store the name and TestInfo constructor for a registered test"""

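For reference, a sketch of the dictionary the new get_metadata method returns (and that run.py saves as metadata.json), assuming get_op_frequency maps ONNX op names to occurrence counts; the concrete values below are made up:

    # hypothetical contents of the metadata dict / metadata.json for one model
    {
        "model_size": 14823930,  # size of the .onnx file in bytes, from os.path.getsize
        "op_frequency": {"Conv": 53, "Relu": 49, "Add": 17},
    }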
2 changes: 1 addition & 1 deletion alt_e2eshark/e2e_testing/logging_utils.py
@@ -71,7 +71,7 @@ def scan_dir_del_not_logs(dir):
    for root, dirs, files in os.walk(dir):
        for name in files:
            curr_file = os.path.join(root, name)
-            if not name.endswith(".log") and name != "benchmark.json":
+            if not name.endswith(".log") and not name.endswith(".json"):
                removed_files.append(curr_file)
    for file in removed_files:
        os.remove(file)
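This change widens the cleanup filter from sparing only benchmark.json to sparing every .json file, so the new metadata.json survives a cleanup run. A minimal sketch of the effect, with hypothetical file names:

    # which files the updated filter keeps vs. deletes
    for name in ["stages.log", "benchmark.json", "metadata.json", "model.onnx"]:
        keep = name.endswith(".log") or name.endswith(".json")
        print(name, "kept" if keep else "removed")
    # stages.log kept / benchmark.json kept / metadata.json kept / model.onnx removed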
13 changes: 12 additions & 1 deletion alt_e2eshark/run.py
@@ -133,6 +133,7 @@ def main(args):
        stages,
        args.load_inputs,
        int(args.cleanup),
        args.get_metadata,
    )

    if args.report:
@@ -142,7 +143,7 @@


def run_tests(
-    test_list: List[Test], config: TestConfig, parent_log_dir: str, no_artifacts: bool, verbose: bool, stages: List[str], load_inputs: bool, cleanup: int,
+    test_list: List[Test], config: TestConfig, parent_log_dir: str, no_artifacts: bool, verbose: bool, stages: List[str], load_inputs: bool, cleanup: int, get_metadata: bool = False,
) -> Dict[str, Dict]:
"""runs tests in test_list based on config. Returns a dictionary containing the test statuses."""
# TODO: multi-process
@@ -190,6 +191,10 @@
        # TODO: Figure out how to factor this out of run.py
        if not os.path.exists(inst.model):
            inst.construct_model()
        if get_metadata:
            metadata = inst.get_metadata()
            metadata_file = Path(log_dir) / "metadata.json"
            save_dict(metadata, metadata_file)

        artifact_save_to = None if no_artifacts else log_dir
        # generate mlir from the instance using the config
@@ -449,6 +454,12 @@ def _get_argparse():
default="report.md",
help="output filename for the report summary.",
)
parser.add_argument(
"--get-metadata",
action="store_true",
default=False,
help="save some model metadata to log_dir/metadata.json"
)
# parser.add_argument(
# "-d",
# "--todtype",
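After a run with --get-metadata, each test's log directory should contain a metadata.json next to the usual logs. A minimal sketch of reading it back, assuming the test-run directory layout used by find_duplicate_models.py below and a hypothetical test name:

    import json
    from pathlib import Path

    metadata_path = Path("test-run") / "some_model_test" / "metadata.json"
    with open(metadata_path) as f:
        metadata = json.load(f)
    print(metadata["model_size"])    # model file size in bytes
    print(metadata["op_frequency"])  # op name -> occurrence count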
100 changes: 100 additions & 0 deletions alt_e2eshark/utils/find_duplicate_models.py
@@ -0,0 +1,100 @@
from pathlib import Path
import argparse
from typing import Union, Dict, Any, Optional
import json
import io

ROOT = Path(__file__).parents[1]


class HashableDict(dict):
    """a hashable dictionary, used to invert a dictionary with dictionary values"""

    def __hash__(self):
        return hash(tuple(sorted(self.items())))
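    # note: sorted(self.items()) gives a deterministic key order, so two dicts holding the
    # same key/value pairs hash (and compare) equal; any dict-valued entries must themselves
    # be HashableDict for the tuple to be hashable (see get_groupings below).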


def load_json_dict(filepath: Union[str, Path]) -> Dict[str, Any]:
    with open(filepath) as contents:
        loaded_dict = json.load(contents)
    return loaded_dict


def save_to_json(jsonable_object, name_json: Optional[str] = None):
    """Saves an object to a json file with the given name, or prints the result."""
    dict_str = json.dumps(
        jsonable_object,
        indent=4,
        sort_keys=True,
        separators=(",", ": "),
        ensure_ascii=False,
    )
    if not name_json:
        print(dict_str)
        return
    # name_json is a plain string from argparse, so wrap it in Path before taking .stem
    path_json = ROOT / f"{Path(name_json).stem}.json"
    with io.open(path_json, "w", encoding="utf8") as outfile:
        outfile.write(dict_str)


def get_groupings(metadata_dicts: Dict[str, Dict]) -> Dict:
    """gets a multi-valued inverse of metadata_dicts"""
    groupings = dict()
    for key, value in metadata_dicts.items():
        value["op_frequency"] = HashableDict(value["op_frequency"])
        hashable = HashableDict(value)
        if hashable in groupings.keys():
            groupings[hashable].append(key)
        else:
            groupings[hashable] = [key]
    return groupings
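# The mapping returned above goes from a metadata dict (model size plus op counts) to the
# list of test names that produced it; any list with more than one entry marks models that
# are likely duplicates.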


def main(args):
    run_dir = ROOT / args.rundirectory
    metadata_dicts = dict()
    for x in run_dir.glob("*/*.json"):
        if x.name == "metadata.json":
            test_name = x.parent.name
            metadata_dicts[test_name] = load_json_dict(x)

    groupings = get_groupings(metadata_dicts)
    found_redundancies = []
    for key, value in groupings.items():
        if len(value) > 1:
            found_redundancies.append(
                value if args.simplified else {"models": value, "shared_metadata": key}
            )
    save_to_json(found_redundancies, args.output)


def _get_argparse():
    msg = "After running run.py with the flag --get-metadata, use this tool to find duplicate models."
    parser = argparse.ArgumentParser(
        prog="find_duplicate_models.py", description=msg, epilog=""
    )

    parser.add_argument(
        "-r",
        "--rundirectory",
        default="test-run",
        help="The directory containing run.py results",
    )
    parser.add_argument(
        "-o",
        "--output",
        help="specify an output json file",
    )
    parser.add_argument(
        "-s",
        "--simplified",
        action="store_true",
        default=False,
        help="pass this arg to only print redundant model lists, without the corresponding metadata.",
    )
    return parser


if __name__ == "__main__":
    parser = _get_argparse()
    main(parser.parse_args())
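Taken together, a hedged end-to-end sketch; the invocation paths, extra flags, and model names below are assumptions, not part of this change:

    # hypothetical workflow, shown as shell commands in comments:
    #   python run.py --get-metadata ...            (plus whatever other flags the run needs)
    #   python utils/find_duplicate_models.py -r test-run -o duplicates
    # With -s/--simplified, duplicates.json holds only the grouped test names, e.g.:
    #   [["model_a_test", "model_a_copy_test"]]
    # Without it, each group also records the metadata those models share:
    #   [{"models": ["model_a_test", "model_a_copy_test"],
    #     "shared_metadata": {"model_size": 14823930, "op_frequency": {"Conv": 53}}}]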