diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 880a422160..f2f5b73613 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -94,13 +94,14 @@ repos: - astunparse==1.6.3 - attrs==24.2.0 - black==24.8.0 - - boltons==24.0.0 + - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - cmake==3.30.5 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 + - diskcache==5.6.3 - factory-boy==3.3.1 - frozendict==2.4.6 - gridtools-cpp==2.3.6 diff --git a/constraints.txt b/constraints.txt index e846d4126c..e7acc466cd 100644 --- a/constraints.txt +++ b/constraints.txt @@ -13,10 +13,10 @@ attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, babel==2.16.0 # via sphinx backcall==0.2.0 # via ipython black==24.8.0 # via gt4py (pyproject.toml) -boltons==24.0.0 # via gt4py (pyproject.toml) +boltons==24.1.0 # via gt4py (pyproject.toml) bracex==2.5.post1 # via wcmatch build==1.2.2.post1 # via pip-tools -bump-my-version==0.28.0 # via -r requirements-dev.in +bump-my-version==0.28.1 # via -r requirements-dev.in cached-property==2.0.1 # via gt4py (pyproject.toml) cachetools==5.5.0 # via tox certifi==2024.8.30 # via requests @@ -40,6 +40,7 @@ decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.9 # via dace +diskcache==5.6.3 # via gt4py (pyproject.toml) distlib==0.3.9 # via virtualenv docutils==0.20.1 # via sphinx, sphinx-rtd-theme exceptiongroup==1.2.2 # via hypothesis, pytest @@ -135,7 +136,7 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx -rich==13.9.3 # via bump-my-version, rich-click, tach +rich==13.9.4 # via bump-my-version, rich-click, tach rich-click==1.8.3 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing ruff==0.7.2 # via -r requirements-dev.in @@ -158,7 +159,7 @@ stack-data==0.6.3 # via ipython 
stdlib-list==0.10.0 # via tach sympy==1.12.1 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.1 # via -r requirements-dev.in +tach==0.14.2 # via -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 7fea11bc3d..f63042906c 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -65,6 +65,7 @@ dace==0.16.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 +diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 gridtools-cpp==2.3.6 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index c20883e25e..666aa79107 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -61,6 +61,7 @@ cytoolz==0.12.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 +diskcache==5.6.3 factory-boy==3.3.0 frozendict==2.3 gridtools-cpp==2.3.6 diff --git a/pyproject.toml b/pyproject.toml index 64f08e671e..c9f7b3b50b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ 'cytoolz>=0.12.1', 'deepdiff>=5.6.0', 'devtools>=0.6', + 'diskcache>=5.6.3', 'factory-boy>=3.3.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.6,==2.*', diff --git a/requirements-dev.txt b/requirements-dev.txt index eb757e0afd..a036307e80 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,10 +13,10 @@ attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypo babel==2.16.0 # via -c constraints.txt, sphinx backcall==0.2.0 # via -c constraints.txt, ipython black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml) -boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml) +boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml) bracex==2.5.post1 # via -c constraints.txt, wcmatch 
build==1.2.2.post1 # via -c constraints.txt, pip-tools -bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in +bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml) cachetools==5.5.0 # via -c constraints.txt, tox certifi==2024.8.30 # via -c constraints.txt, requests @@ -40,6 +40,7 @@ decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) dill==0.3.9 # via -c constraints.txt, dace +diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml) distlib==0.3.9 # via -c constraints.txt, virtualenv docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest @@ -135,7 +136,7 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx -rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach +rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach rich-click==1.8.3 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in @@ -157,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, 
mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index f415c95b63..09f53be600 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG: ) -def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]: +def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]: """ Wrap `func_to_past` in a chainable and optionally cached workflow step. diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index bf3bee4b56..834536ff59 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(toolchain.CompilableProgram) @add_content_to_fingerprint.register(arguments.CompileTimeArgs) @@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo for item in sourcedef: add_content_to_fingerprint(item, hasher) + closure_vars = source_utils.get_closure_vars_from_function(obj) + for item in sorted(closure_vars.items(), key=lambda x: x[0]): + add_content_to_fingerprint(item, hasher) + @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in obj.items(): + for key, value in sorted(obj.items()): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) @@ -148,4 +153,3 @@ def 
add_foast_located_node_to_fingerprint( ) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), hasher) - add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b6f543e9d1..f50d8080eb 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -208,7 +208,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): closures: List[StencilClosure] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(BUILTINS) + ] # sorted for serialization stability class Stmt(Node): ... diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index a63801c97e..ef3a4083b9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -12,6 +12,7 @@ import dataclasses import functools import typing +from collections.abc import MutableMapping from typing import Any, Callable, Generic, Protocol, TypeVar from typing_extensions import Self @@ -253,16 +254,15 @@ class CachedStep( step: Workflow[StartT, EndT] hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - - _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) + cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" hash_ = self.hash_function(inp) try: - result = self._cache[hash_] + result = self.cache[hash_] except KeyError: - result = self._cache[hash_] = self.step(inp) + result = self.cache[hash_] = self.step(inp) return result diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 07eec0b64b..66d74d53cc 
100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -213,6 +213,7 @@ def generate_stencil_source( generated_code = GTFNIMCodegen.apply(gtfn_im_ir) else: generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 2275576081..4a788bf40c 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -8,16 +8,19 @@ import functools import warnings -from typing import Any +from typing import Any, Optional +import diskcache import factory import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators +from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import transforms +from gt4py.next.common import Connectivity, Dimension +from gt4py.next.iterator import ir as itir, transforms from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -116,6 +119,37 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int: ) +def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: + """ + Generates a unique hash string for a stencil source program representing + the program, sorted offset_provider, and column_axis. 
+ """ + program: itir.FencilDefinition | itir.Program = inp.data + offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + column_axis: Optional[common.Dimension] = inp.args.column_axis + + program_hash = utils.content_hash( + ( + program, + sorted(offset_provider.items(), key=lambda el: el[0]), + column_axis, + ) + ) + + return program_hash + + +class FileCache(diskcache.Cache): + """ + This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, + i.e. it ensures that any resources associated with the cache are properly + released when the instance is garbage collected. + """ + + def __del__(self) -> None: + self.close() + + class GTFNCompileWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -129,10 +163,23 @@ class Params: lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) - translation = factory.SubFactory( - gtfn_module.GTFNTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - ) + cached_translation = factory.Trait( + translation=factory.LazyAttribute( + lambda o: workflow.CachedStep( + o.translation_, + hash_function=fingerprint_compilable_program, + cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + ) + ), + ) + + translation_ = factory.SubFactory( + gtfn_module.GTFNTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), + ) + + translation = factory.LazyAttribute(lambda o: o.translation_) + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source ) @@ -193,7 +240,7 @@ class Params: name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) -run_gtfn_cached = GTFNBackendFactory(cached=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) diff --git 
a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7540d52fb3..27f94960dc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -7,9 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import reduce - +from gt4py.next.otf import languages, stages, workflow +from gt4py.next.otf.binding import interface import numpy as np import pytest +import diskcache +from gt4py.eve import SymbolName import gt4py.next as gtx from gt4py.next import ( diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e3e0ee474f..e64bd8a57d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -8,13 +8,25 @@ import numpy as np import pytest +import copy +import diskcache + import gt4py.next as gtx from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import arguments, languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim + +from next_tests.integration_tests.cases import cartesian_case + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) @pytest.fixture @@ -71,3 +83,103 @@ def test_codegen(fencil_example): assert 
module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) assert module.language is languages.CPP + + +def test_hash_and_diskcache(fencil_example, tmp_path): + fencil, parameters = fencil_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + hash = gtfn.fingerprint_compilable_program(compilable_program) + + with diskcache.Cache(tmp_path) as cache: + cache[hash] = compilable_program + + # check content of cache file + with diskcache.Cache(tmp_path) as reopened_cache: + assert hash in reopened_cache + compilable_program_from_cache = reopened_cache[hash] + assert compilable_program == compilable_program_from_cache + del reopened_cache[hash] # delete data + + # hash creation is deterministic + assert hash == gtfn.fingerprint_compilable_program(compilable_program) + assert hash == gtfn.fingerprint_compilable_program(compilable_program_from_cache) + + # hash is different if program changes + altered_program_id = copy.deepcopy(compilable_program) + altered_program_id.data.id = "example2" + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_id) + + altered_program_offset_provider = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_offset_provider.args, "offset_provider", {"Koff": KDim}) + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_offset_provider) + + altered_program_column_axis = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_column_axis.args, "column_axis", KDim) + assert gtfn.fingerprint_compilable_program( + compilable_program + ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) + + +def test_gtfn_file_cache(fencil_example): + fencil, parameters = fencil_example + compilable_program = 
stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + cached_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ).executor.step.translation + + bare_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=False + ).executor.step.translation + + cache_key = gtfn.fingerprint_compilable_program(compilable_program) + + # ensure the actual cached step in the backend generates the cache item for the test + if cache_key in (translation_cache := cached_gtfn_translation_step.cache): + del translation_cache[cache_key] + cached_gtfn_translation_step(compilable_program) + assert bare_gtfn_translation_step(compilable_program) == cached_gtfn_translation_step( + compilable_program + ) + + assert cache_key in cached_gtfn_translation_step.cache + assert ( + bare_gtfn_translation_step(compilable_program) + == cached_gtfn_translation_step.cache[cache_key] + ) + + +# TODO(egparedes): we should switch to use the cached backend by default and then remove this test +def test_gtfn_file_cache_whole_workflow(cartesian_case): + if cartesian_case.backend != gtfn.run_gtfn: + pytest.skip("Skipping backend.") + cartesian_case.backend = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ) + + @gtx.field_operator + def testee(a: cases.IJKField) -> cases.IJKField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + # first call: this generates the cache file + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) + # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again + object.__setattr__(cartesian_case.backend.executor, "cache", {}) + # second call: the cache file is used + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)