diff --git a/tools/ctexplain/BUILD b/tools/ctexplain/BUILD index 02f4c73582b1b7..345a16974e3a00 100644 --- a/tools/ctexplain/BUILD +++ b/tools/ctexplain/BUILD @@ -6,11 +6,32 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +filegroup( + name = "srcs", + srcs = glob(["*"]), +) + py_binary( name = "ctexplain", srcs = ["ctexplain.py"], python_version = "PY3", - deps = [":bazel_api"], + deps = [ + ":analyses", + ":base", + ":bazel_api", + ":lib", + "//third_party/py/abseil" + ], +) + +py_library( + name = "lib", + srcs = ["lib.py"], + srcs_version = "PY3ONLY", + deps = [ + ":base", + ":bazel_api", + ], ) py_library( @@ -20,6 +41,38 @@ py_library( deps = [":base"], ) +py_library( + name = "analyses", + srcs = ["analyses/summary.py"], + srcs_version = "PY3ONLY", + deps = [":base"], +) + +py_library( + name = "base", + srcs = [ + "types.py", + "util.py", + ], + srcs_version = "PY3ONLY", + deps = [ + "//third_party/py/dataclasses", # Backport for Python < 3.7. + "//third_party/py/frozendict", + ], +) + +py_test( + name = "lib_test", + size = "small", + srcs = ["lib_test.py"], + python_version = "PY3", + deps = [ + ":bazel_api", + ":lib", + "//src/test/py/bazel:test_base", + ], +) + py_test( name = "bazel_api_test", size = "small", @@ -31,15 +84,15 @@ py_test( ], ) -py_library( - name = "base", - srcs = [ - "types.py", - ], - srcs_version = "PY3ONLY", +py_test( + name = "analyses_test", + size = "small", + srcs = ["analyses/summary_test.py"], + main = "analyses/summary_test.py", # TODO: generalize this. + python_version = "PY3", deps = [ - "//third_party/py/dataclasses", # Backport for Python < 3.7. - "//third_party/py/frozendict", + ":analyses", + ":base", ], ) @@ -53,8 +106,3 @@ py_test( "//third_party/py/frozendict", ], ) - -filegroup( - name = "srcs", - srcs = glob(["*"]), -) diff --git a/tools/ctexplain/analyses/summary.py b/tools/ctexplain/analyses/summary.py new file mode 100644 index 00000000000000..11d2228c6ea13a --- /dev/null +++ b/tools/ctexplain/analyses/summary.py @@ -0,0 +1,71 @@ +# Lint as: python3 +# Copyright 2020 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Analysis that summarizes basic graph info.""" +from typing import Tuple +# Do not edit this line. Copybara replaces it with PY2 migration helper. +from dataclasses import dataclass + +from tools.ctexplain.types import ConfiguredTarget +import tools.ctexplain.util as util + + +@dataclass(frozen=True) +class _Summary(): + """Analysis result.""" + # Number of configurations in the build's configured target graph. + configurations: int + # Number of unique target labels. + targets: int + # Number of configured targets. + configured_targets: int + # Number of targets that produce multiple configured targets. This is more + # subtle than computing configured_targets - targets. For example, if + # targets=2 and configured_targets=4, that could mean both targets are + # configured twice. Or it could mean the first target is configured 3 times. + repeated_targets: int + + +def analyze(cts: Tuple[ConfiguredTarget, ...]) -> _Summary: + """Runs the analysis on a build's configured targets.""" + configurations = set() + targets = set() + label_count = {} + for ct in cts: + configurations.add(ct.config_hash) + targets.add(ct.label) + label_count[ct.label] = label_count.setdefault(ct.label, 0) + 1 + configured_targets = len(cts) + repeated_targets = sum([1 for count in label_count.values() if count > 1]) + + return _Summary(len(configurations), len(targets), configured_targets, + repeated_targets) + + +def report(result: _Summary) -> None: + """Reports analysis results to the user. + + We intentionally make this its own function to make it easy to support other + output formats (like machine-readable) if we ever want to do that. + + Args: + result: the analysis result + """ + ct_surplus = util.percent_diff(result.targets, result.configured_targets) + print(f""" +Configurations: {result.configurations} +Targets: {result.targets} +Configured targets: {result.configured_targets} ({ct_surplus} vs. targets) +Targets with multiple configs: {result.repeated_targets} +""") diff --git a/tools/ctexplain/analyses/summary_test.py b/tools/ctexplain/analyses/summary_test.py new file mode 100644 index 00000000000000..f9cc19186e86c4 --- /dev/null +++ b/tools/ctexplain/analyses/summary_test.py @@ -0,0 +1,44 @@ +# Lint as: python3 +# Copyright 2020 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for summary.py.""" +import unittest +# Do not edit this line. Copybara replaces it with PY2 migration helper. +from frozendict import frozendict + +import tools.ctexplain.analyses.summary as summary +from tools.ctexplain.types import Configuration +from tools.ctexplain.types import ConfiguredTarget +from tools.ctexplain.types import NullConfiguration + + +class SummaryTest(unittest.TestCase): + + def testAnalysis(self): + config1 = Configuration(None, frozendict({'a': frozendict({'b': 'c'})})) + config2 = Configuration(None, frozendict({'d': frozendict({'e': 'f'})})) + + ct1 = ConfiguredTarget('//foo', config1, 'hash1', None) + ct2 = ConfiguredTarget('//foo', config2, 'hash2', None) + ct3 = ConfiguredTarget('//foo', NullConfiguration(), 'null', None) + ct4 = ConfiguredTarget('//bar', config1, 'hash1', None) + + res = summary.analyze((ct1, ct2, ct3, ct4)) + self.assertEqual(3, res.configurations) + self.assertEqual(2, res.targets) + self.assertEqual(4, res.configured_targets) + self.assertEqual(1, res.repeated_targets) + +if __name__ == '__main__': + unittest.main() diff --git a/tools/ctexplain/bazel_api.py b/tools/ctexplain/bazel_api.py index 4fd776c2a786d2..176c0bd2bfc9c8 100644 --- a/tools/ctexplain/bazel_api.py +++ b/tools/ctexplain/bazel_api.py @@ -47,7 +47,7 @@ def run_bazel_in_client(args: List[str]) -> Tuple[int, List[str], List[str]]: cwd=os.getcwd(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, - check=True) + check=False) return (result.returncode, result.stdout.decode("utf-8").split(os.linesep), result.stderr) @@ -73,17 +73,23 @@ def cquery(self, stderr contains the query's stderr (regardless of success value), and cts is the configured targets found by the query if successful, empty otherwise. + + ct order preserves cquery's output order. This is topologically sorted + with duplicates removed. So no unique configured target appears twice and + if A depends on B, A appears before B. """ base_args = ["cquery", "--show_config_fragments=transitive"] (returncode, stdout, stderr) = self.run_bazel(base_args + args) if returncode != 0: return (False, stderr, ()) - cts = set() + cts = [] for line in stdout: + if not line.strip(): + continue ctinfo = _parse_cquery_result_line(line) if ctinfo is not None: - cts.add(ctinfo) + cts.append(ctinfo) return (True, stderr, tuple(cts)) @@ -97,7 +103,7 @@ def get_config(self, config_hash: str) -> Configuration: The matching configuration or None if no match is found. Raises: - ValueError on any parsing problems. + ValueError: On any parsing problems. """ if config_hash == "HOST": return HostConfiguration() @@ -109,11 +115,13 @@ def get_config(self, config_hash: str) -> Configuration: if returncode != 0: raise ValueError("Could not get config: " + stderr) config_json = json.loads(os.linesep.join(stdout)) - fragments = [ - fragment["name"].split(".")[-1] for fragment in config_json["fragments"] - ] + fragments = frozendict({ + _base_name(entry["name"]): tuple( + _base_name(clazz) for clazz in entry["fragmentOptions"]) + for entry in config_json["fragments"] + }) options = frozendict({ - entry["name"].split(".")[-1]: frozendict(entry["options"]) + _base_name(entry["name"]): frozendict(entry["options"]) for entry in config_json["fragmentOptions"] }) return Configuration(fragments, options) @@ -156,3 +164,20 @@ def _parse_cquery_result_line(line: str) -> ConfiguredTarget: config=None, # Not yet available: we'll need `bazel config` to get this. config_hash=config_hash, transitive_fragments=fragments) + + +def _base_name(full_name: str) -> str: + """Strips a fully qualified Java class name to the file scope. + + Examples: + - "A.B.OuterClass" -> "OuterClass" + - "A.B.OuterClass$InnerClass" -> "OuterClass$InnerClass" + + Args: + full_name: Fully qualified class name. + + Returns: + Stripped name. + """ + return full_name.split(".")[-1] + diff --git a/tools/ctexplain/bazel_api_test.py b/tools/ctexplain/bazel_api_test.py index 3587251f873b5a..d142564ce96ca9 100644 --- a/tools/ctexplain/bazel_api_test.py +++ b/tools/ctexplain/bazel_api_test.py @@ -77,7 +77,7 @@ def testGetTargetConfig(self): config = self._bazel_api.get_config(cts[0].config_hash) expected_fragments = ['PlatformConfiguration', 'JavaConfiguration'] for exp in expected_fragments: - self.assertIn(exp, config.fragments) + self.assertIn(exp, config.fragments.keys()) core_options = config.options['CoreOptions'] self.assertIsNotNone(core_options) self.assertIn(('stamp', 'false'), core_options.items()) @@ -111,6 +111,16 @@ def testGetNullConfig(self): self.assertEqual(len(config.fragments), 0) self.assertEqual(len(config.options), 0) + def testConfigFragmentsMap(self): + self.ScratchFile('testapp/BUILD', [ + 'filegroup(name = "fg", srcs = ["a.file"])', + ]) + cts = self._bazel_api.cquery(['//testapp:fg'])[2] + fragments_map = self._bazel_api.get_config(cts[0].config_hash).fragments + self.assertIn('PlatformOptions', fragments_map['PlatformConfiguration']) + self.assertIn( + 'ShellConfiguration$Options', fragments_map['ShellConfiguration']) + def testConfigWithDefines(self): self.ScratchFile('testapp/BUILD', [ 'filegroup(name = "fg", srcs = ["a.file"])', diff --git a/tools/ctexplain/ctexplain.py b/tools/ctexplain/ctexplain.py index b66d78d0e3dc25..d28712ebe94474 100644 --- a/tools/ctexplain/ctexplain.py +++ b/tools/ctexplain/ctexplain.py @@ -12,14 +12,148 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""ctexplain main entry point. +"""ctexplain: how does configuration affect build graphs? -Currently a stump. +This is a swiss army knife tool that tries to explain why build graphs are the +size they are and how build flags, configuration transitions, and dependency +structures affect that. + +This can help developers use flags and transitions with minimal memory and +maximum build speed. + +Usage: + + $ ctexplain [--analysis=...] -b " [build flags]" + +Example: + + $ ctexplain -b "//mypkg:mybinary --define MY_FEATURE=1" + +Relevant terms in https://docs.bazel.build/versions/master/glossary.html: + "target", "configuration", "analysis phase", "configured target", + "configuration trimming", "transition" + +TODO(gregce): link to proper documentation for full details. """ -from tools.ctexplain.bazel_api import BazelApi +from typing import Callable +from typing import Tuple + +# Do not edit this line. Copybara replaces it with PY2 migration helper. +from absl import app +from absl import flags +from dataclasses import dataclass + +import tools.ctexplain.analyses.summary as summary +import tools.ctexplain.bazel_api as bazel_api +import tools.ctexplain.lib as lib +from tools.ctexplain.types import ConfiguredTarget +import tools.ctexplain.util as util + +FLAGS = flags.FLAGS + + +@dataclass(frozen=True) +class Analysis(): + """Supported analysis type.""" + # The value in --analysis= that triggers this analysis. + key: str + # The function that invokes this analysis. + exec: Callable[[Tuple[ConfiguredTarget, ...]], None] + # User-friendly analysis description. + description: str + +available_analyses = [ + Analysis( + "summary", + lambda x: summary.report(summary.analyze(x)), + "summarizes build graph size and how trimming could help" + ), + Analysis( + "culprits", + lambda x: print("this analysis not yet implemented"), + "shows which flags unnecessarily fork configured targets. These\n" + + "are conceptually mergeable." + ), + Analysis( + "forked_targets", + lambda x: print("this analysis not yet implemented"), + "ranks targets by how many configured targets they\n" + + "create. These may be legitimate forks (because they behave " + + "differently with\n different flags) or identical clones that are " + + "conceptually mergeable." + ), + Analysis( + "cloned_targets", + lambda x: print("this analysis not yet implemented"), + "ranks targets by how many behavior-identical configured\n targets " + + "they produce. These are conceptually mergeable." + ) +] + +# Available analyses, keyed by --analysis= triggers. +analyses = {analysis.key: analysis for analysis in available_analyses} + + +# Command-line flag registration: + + +def _render_analysis_help_text() -> str: + """Pretty-prints help text for available analyses.""" + return "\n".join(f'- "{name}": {analysis.description}' + for name, analysis in analyses.items()) + +flags.DEFINE_list("analysis", ["summary"], f""" +Analyses to run. May be any comma-separated combination of + +{_render_analysis_help_text()} +""") + +flags.register_validator( + "analysis", + lambda flag_value: all(name in analyses for name in flag_value), + message=f'available analyses: {", ".join(analyses.keys())}') + +flags.DEFINE_multi_string( + "build", [], + """command-line invocation of the build to analyze. For example: +"//foo --define a=b". If listed multiple times, this is a "multi-build +analysis" that measures how much distinct builds can share subgraphs""", + short_name="b") + + +# Core program logic: + + +def _get_build_flags(cmdline: str) -> Tuple[Tuple[str, ...], Tuple[str, ...]]: + """Parses a build invocation command line. + + Args: + cmdline: raw build invocation string. For example: "//foo --cpu=x86" + + Returns: + Tuple of ((target labels to build), (build flags)) + """ + cmdlist = cmdline.split() + labels = [arg for arg in cmdlist if arg.startswith("//")] + build_flags = [arg for arg in cmdlist if not arg.startswith("//")] + return (tuple(labels), tuple(build_flags)) + + +def main(argv): + del argv # Satisfy py linter's "unused" warning. + if not FLAGS.build: + exit("ctexplain: build efficiency measurement tool. Add --help " + + "for usage.") + elif len(FLAGS.build) > 1: + exit("TODO(gregce): support multi-build shareability analysis") + + (labels, build_flags) = _get_build_flags(FLAGS.build[0]) + build_desc = ",".join(labels) + with util.ProgressStep(f"Collecting configured targets for {build_desc}"): + cts = lib.analyze_build(bazel_api.BazelApi(), labels, build_flags) + for analysis in FLAGS.analysis: + analyses[analysis].exec(cts) -bazel_api = BazelApi() -# TODO(gregce): move all logic to a _lib library so we can easily include -# end-to-end testing. We'll only handle flag parsing here, which we pass -# into the main invoker as standard Python args. +if __name__ == "__main__": + app.run(main) diff --git a/tools/ctexplain/lib.py b/tools/ctexplain/lib.py new file mode 100644 index 00000000000000..60afc3903c076a --- /dev/null +++ b/tools/ctexplain/lib.py @@ -0,0 +1,60 @@ +# Lint as: python3 +# Copyright 2020 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""General-purpose business logic.""" +from typing import Tuple + +import tools.ctexplain.bazel_api as bazel_api +from tools.ctexplain.types import ConfiguredTarget + + +def analyze_build(bazel: bazel_api.BazelApi, labels: Tuple[str, ...], + build_flags: Tuple[str, ...]) -> Tuple[ConfiguredTarget, ...]: + """Gets a build invocation's configured targets. + + Args: + bazel: API for invoking Bazel. + labels: The targets to build. + build_flags: The build flags to use. + + Returns: + Configured targets representing the build. + + Raises: + RuntimeError: On any invocation errors. + """ + cquery_args = [f'deps({",".join(labels)})'] + cquery_args.extend(build_flags) + (success, stderr, cts) = bazel.cquery(cquery_args) + if not success: + raise RuntimeError("invocation failed: " + "\n".join(stderr)) + + # We have to do separate calls to "bazel config" to get the actual configs + # from their hashes. + hashes_to_configs = {} + cts_with_configs = [] + for ct in cts: + # Don't use dict.setdefault because that unconditionally calls get_config + # as one of its parameters and that's an expensive operation to waste. + if ct.config_hash not in hashes_to_configs: + hashes_to_configs[ct.config_hash] = bazel.get_config(ct.config_hash) + config = hashes_to_configs[ct.config_hash] + cts_with_configs.append( + ConfiguredTarget( + ct.label, + config, + ct.config_hash, + ct.transitive_fragments)) + + return tuple(cts_with_configs) diff --git a/tools/ctexplain/lib_test.py b/tools/ctexplain/lib_test.py new file mode 100644 index 00000000000000..660a603b541061 --- /dev/null +++ b/tools/ctexplain/lib_test.py @@ -0,0 +1,90 @@ +# Lint as: python3 +# Copyright 2020 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for lib.py.""" +import unittest +from src.test.py.bazel import test_base +import tools.ctexplain.bazel_api as bazel_api +import tools.ctexplain.lib as lib +from tools.ctexplain.types import Configuration +from tools.ctexplain.types import HostConfiguration +from tools.ctexplain.types import NullConfiguration + + +class LibTest(test_base.TestBase): + + _bazel: bazel_api.BazelApi = None + + def setUp(self): + test_base.TestBase.setUp(self) + self._bazel = bazel_api.BazelApi(self.RunBazel) + self.ScratchFile('WORKSPACE') + self.CreateWorkspaceWithDefaultRepos('repo/WORKSPACE') + + def tearDown(self): + test_base.TestBase.tearDown(self) + + def testAnalyzeBuild(self): + self.ScratchFile('testapp/defs.bzl', [ + 'def _impl(ctx):', + ' pass', + 'rule_with_host_dep = rule(', + ' implementation = _impl,', + ' attrs = { "host_deps": attr.label_list(cfg = "host") })', + ]) + self.ScratchFile('testapp/BUILD', [ + 'load("//testapp:defs.bzl", "rule_with_host_dep")', + 'rule_with_host_dep(name = "a", host_deps = [":h"])', + 'filegroup(name = "h", srcs = ["h.src"])' + ]) + cts = lib.analyze_build(self._bazel, ('//testapp:a',), ()) + # Remove boilerplate deps to focus on targets declared here. + cts = [ct for ct in cts if ct.label.startswith('//testapp')] + + self.assertListEqual( + [ct.label for ct in cts], + ['//testapp:a', '//testapp:h', '//testapp:h.src']) + # Don't use assertIsInstance because we don't want to match subclasses. + self.assertEqual(Configuration, type(cts[0].config)) + self.assertEqual('HOST', cts[1].config_hash) + self.assertIsInstance(cts[1].config, HostConfiguration) + self.assertEqual('null', cts[2].config_hash) + self.assertIsInstance(cts[2].config, NullConfiguration) + + def testAnalyzeBuildNoRepeats(self): + self.ScratchFile('testapp/defs.bzl', [ + 'def _impl(ctx):', + ' pass', + 'rule_with_host_dep = rule(', + ' implementation = _impl,', + ' attrs = { "host_deps": attr.label_list(cfg = "host") })', + ]) + self.ScratchFile('testapp/BUILD', [ + 'load("//testapp:defs.bzl", "rule_with_host_dep")', + 'rule_with_host_dep(name = "a", host_deps = [":h", ":other"])', + 'rule_with_host_dep(name = "other")', + 'filegroup(name = "h", srcs = ["h.src", ":other"])' + ]) + cts = lib.analyze_build(self._bazel, ('//testapp:a',), ()) + # Remove boilerplate deps to focus on targets declared here. + cts = [ct for ct in cts if ct.label.startswith('//testapp')] + + # Even though the build references //testapp:other twice, it only appears + # once. + self.assertListEqual( + [ct.label for ct in cts], + ['//testapp:a', '//testapp:h', '//testapp:other', '//testapp:h.src']) + +if __name__ == '__main__': + unittest.main() diff --git a/tools/ctexplain/types.py b/tools/ctexplain/types.py index 2ac480f835052c..c47ab9fb4b995c 100644 --- a/tools/ctexplain/types.py +++ b/tools/ctexplain/types.py @@ -26,16 +26,20 @@ @dataclass(frozen=True) class Configuration(): """Stores a build configuration as a collection of fragments and options.""" - # BuildConfiguration.Fragments in this configuration, as base names without - # packages. For example: ["PlatformConfiguration", ...]. - fragments: Tuple[str, ...] + # Mapping of each BuildConfiguration.Fragment in this configuration to the + # FragmentOptions it requires. + # + # All names are qualified up to the base file name, without package prefixes. + # For example, foo.bar.BazConfiguration appears as "BazConfiguration". + # foo.bar.BazConfiguration$Options appears as "BazeConfiguration$Options". + fragments: Mapping[str, Tuple[str, ...]] # Mapping of FragmentOptions to option key/value pairs. For example: # {"CoreOptions": {"action_env": "[]", "cpu": "x86", ...}, ...}. # # Option values are stored as strings of whatever "bazel config" outputs. # # Note that Fragment and FragmentOptions aren't the same thing. - options: [Mapping[str, Mapping[str, str]]] + options: Mapping[str, Mapping[str, str]] @dataclass(frozen=True) @@ -49,7 +53,7 @@ class ConfiguredTarget(): config_hash: str # Fragments required by this configured target and its transitive # dependencies. Stored as base names without packages. For example: - # "PlatformOptions". + # "PlatformOptions" or "FooConfiguration$Options". transitive_fragments: Tuple[str, ...] diff --git a/tools/ctexplain/util.py b/tools/ctexplain/util.py new file mode 100644 index 00000000000000..ed02c3c45b049e --- /dev/null +++ b/tools/ctexplain/util.py @@ -0,0 +1,54 @@ +# Lint as: python3 +# Copyright 2020 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generic utilities.""" +import time + + +class ProgressStep: + """A simple context manager that prints a progress message. + + Forked from a similar project by brandjon@google.com. + """ + + def __init__(self, msg, show_done=True): + self.msg = msg + self.show_done = show_done + + def __enter__(self): + print(self.msg + "...", flush=True, end="") + self.start_time = time.perf_counter() + + def __exit__(self, exc_type, exc_value, traceback): + if self.show_done: + elapsed_sec = time.perf_counter() - self.start_time + if elapsed_sec < 0.1: + time_str = f"{elapsed_sec * 1000:.0f} ms" + else: + time_str = f"{elapsed_sec:.2f} s" + print(f" done in {time_str}.", flush=True) + + +def percent_diff(val1, val2): + """Returns what percentage a change val2 is from val1. + + For example, if val=10 and val2=15, returns '+50%'. + + Args: + val1: Base number. + val2: Number to compare against the base number. + """ + diff = (val2 - val1) / val1 * 100 + return f"+{diff:.1f}%" if diff >= 0 else f"{diff:.1f}%" + \ No newline at end of file