Skip to content

Commit

Permalink
feat[next]: Add memory and disk-based caching to more workflow steps (#…
Browse files Browse the repository at this point in the history
…1690)

Add memory and disk-based caching to more workflow steps, thereby
removing unnecessary overhead of Program calls and significantly
improving the time to first computed value.

Changes:

- setting `cached = True` for `func_to_past_factory`
- wrapping the GTFN code generation in a `CachedStep` (backed by
`diskcache`), which is activated by setting
`otf_workflow__cached_translation=True`, similar to
[PR#1474](#1474) (without
`CachedStep`)
- Fixing hash function of `ProgramDefinition`
- New tests for added functionality

This leads to a runtime decrease of about 25% for PMAP-G in the
advect-uniform testcase (5 hours) after caches are populated.

TODOs: 
- [x] improving hash functions of `fingerprint_stage`

---------

Co-authored-by: Till Ehrengruber <till.ehrengruber@cscs.ch>
Co-authored-by: Enrique G. Paredes <enriqueg@cscs.ch>
  • Loading branch information
3 people authored Nov 7, 2024
1 parent 6873a0e commit 1b9eb5c
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 25 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ repos:
- astunparse==1.6.3
- attrs==24.2.0
- black==24.8.0
- boltons==24.0.0
- boltons==24.1.0
- cached-property==2.0.1
- click==8.1.7
- cmake==3.30.5
- cytoolz==1.0.0
- deepdiff==8.0.1
- devtools==0.12.2
- diskcache==5.6.3
- factory-boy==3.3.1
- frozendict==2.4.6
- gridtools-cpp==2.3.6
Expand Down
9 changes: 5 additions & 4 deletions constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema,
babel==2.16.0 # via sphinx
backcall==0.2.0 # via ipython
black==24.8.0 # via gt4py (pyproject.toml)
boltons==24.0.0 # via gt4py (pyproject.toml)
boltons==24.1.0 # via gt4py (pyproject.toml)
bracex==2.5.post1 # via wcmatch
build==1.2.2.post1 # via pip-tools
bump-my-version==0.28.0 # via -r requirements-dev.in
bump-my-version==0.28.1 # via -r requirements-dev.in
cached-property==2.0.1 # via gt4py (pyproject.toml)
cachetools==5.5.0 # via tox
certifi==2024.8.30 # via requests
Expand All @@ -40,6 +40,7 @@ decorator==5.1.1 # via ipython
deepdiff==8.0.1 # via gt4py (pyproject.toml)
devtools==0.12.2 # via gt4py (pyproject.toml)
dill==0.3.9 # via dace
diskcache==5.6.3 # via gt4py (pyproject.toml)
distlib==0.3.9 # via virtualenv
docutils==0.20.1 # via sphinx, sphinx-rtd-theme
exceptiongroup==1.2.2 # via hypothesis, pytest
Expand Down Expand Up @@ -135,7 +136,7 @@ pyzmq==26.2.0 # via ipykernel, jupyter-client
questionary==2.0.1 # via bump-my-version
referencing==0.35.1 # via jsonschema, jsonschema-specifications
requests==2.32.3 # via sphinx
rich==13.9.3 # via bump-my-version, rich-click, tach
rich==13.9.4 # via bump-my-version, rich-click, tach
rich-click==1.8.3 # via bump-my-version
rpds-py==0.20.1 # via jsonschema, referencing
ruff==0.7.2 # via -r requirements-dev.in
Expand All @@ -158,7 +159,7 @@ stack-data==0.6.3 # via ipython
stdlib-list==0.10.0 # via tach
sympy==1.12.1 # via dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via gt4py (pyproject.toml)
tach==0.14.1 # via -r requirements-dev.in
tach==0.14.2 # via -r requirements-dev.in
tomli==2.0.2 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
tomli-w==1.0.0 # via tach
tomlkit==0.13.2 # via bump-my-version
Expand Down
1 change: 1 addition & 0 deletions min-extra-requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ dace==0.16.1
darglint==1.6
deepdiff==5.6.0
devtools==0.6
diskcache==5.6.3
factory-boy==3.3.0
frozendict==2.3
gridtools-cpp==2.3.6
Expand Down
1 change: 1 addition & 0 deletions min-requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ cytoolz==0.12.1
darglint==1.6
deepdiff==5.6.0
devtools==0.6
diskcache==5.6.3
factory-boy==3.3.0
frozendict==2.3
gridtools-cpp==2.3.6
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
'cytoolz>=0.12.1',
'deepdiff>=5.6.0',
'devtools>=0.6',
'diskcache>=5.6.3',
'factory-boy>=3.3.0',
'frozendict>=2.3',
'gridtools-cpp>=2.3.6,==2.*',
Expand Down
9 changes: 5 additions & 4 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypo
babel==2.16.0 # via -c constraints.txt, sphinx
backcall==0.2.0 # via -c constraints.txt, ipython
black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml)
boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml)
boltons==24.1.0 # via -c constraints.txt, gt4py (pyproject.toml)
bracex==2.5.post1 # via -c constraints.txt, wcmatch
build==1.2.2.post1 # via -c constraints.txt, pip-tools
bump-my-version==0.28.0 # via -c constraints.txt, -r requirements-dev.in
bump-my-version==0.28.1 # via -c constraints.txt, -r requirements-dev.in
cached-property==2.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
cachetools==5.5.0 # via -c constraints.txt, tox
certifi==2024.8.30 # via -c constraints.txt, requests
Expand All @@ -40,6 +40,7 @@ decorator==5.1.1 # via -c constraints.txt, ipython
deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml)
devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml)
dill==0.3.9 # via -c constraints.txt, dace
diskcache==5.6.3 # via -c constraints.txt, gt4py (pyproject.toml)
distlib==0.3.9 # via -c constraints.txt, virtualenv
docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme
exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest
Expand Down Expand Up @@ -135,7 +136,7 @@ pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client
questionary==2.0.1 # via -c constraints.txt, bump-my-version
referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications
requests==2.32.3 # via -c constraints.txt, sphinx
rich==13.9.3 # via -c constraints.txt, bump-my-version, rich-click, tach
rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach
rich-click==1.8.3 # via -c constraints.txt, bump-my-version
rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing
ruff==0.7.2 # via -c constraints.txt, -r requirements-dev.in
Expand All @@ -157,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython
stdlib-list==0.10.0 # via -c constraints.txt, tach
sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml)
tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml)
tach==0.14.1 # via -c constraints.txt, -r requirements-dev.in
tach==0.14.2 # via -c constraints.txt, -r requirements-dev.in
tomli==2.0.2 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
tomli-w==1.0.0 # via -c constraints.txt, tach
tomlkit==0.13.2 # via -c constraints.txt, bump-my-version
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG:
)


def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]:
def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]:
"""
Wrap `func_to_past` in a chainable and optionally cached workflow step.
Expand Down
8 changes: 6 additions & 2 deletions src/gt4py/next/ffront/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No

@add_content_to_fingerprint.register(FieldOperatorDefinition)
@add_content_to_fingerprint.register(FoastOperatorDefinition)
@add_content_to_fingerprint.register(ProgramDefinition)
@add_content_to_fingerprint.register(PastProgramDefinition)
@add_content_to_fingerprint.register(toolchain.CompilableProgram)
@add_content_to_fingerprint.register(arguments.CompileTimeArgs)
Expand All @@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo
for item in sourcedef:
add_content_to_fingerprint(item, hasher)

closure_vars = source_utils.get_closure_vars_from_function(obj)
for item in sorted(closure_vars.items(), key=lambda x: x[0]):
add_content_to_fingerprint(item, hasher)


@add_content_to_fingerprint.register
def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None:
for key, value in obj.items():
for key, value in sorted(obj.items()):
add_content_to_fingerprint(key, hasher)
add_content_to_fingerprint(value, hasher)

Expand All @@ -148,4 +153,3 @@ def add_foast_located_node_to_fingerprint(
) -> None:
add_content_to_fingerprint(obj.location, hasher)
add_content_to_fingerprint(str(obj), hasher)
add_content_to_fingerprint(str(obj), hasher)
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait):
closures: List[StencilClosure]
implicit_domain: bool = False

_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS]
_NODE_SYMBOLS_: ClassVar[List[Sym]] = [
Sym(id=name) for name in sorted(BUILTINS)
] # sorted for serialization stability


class Stmt(Node): ...
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/otf/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dataclasses
import functools
import typing
from collections.abc import MutableMapping
from typing import Any, Callable, Generic, Protocol, TypeVar

from typing_extensions import Self
Expand Down Expand Up @@ -253,16 +254,15 @@ class CachedStep(

step: Workflow[StartT, EndT]
hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment]

_cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict)
cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict)

def __call__(self, inp: StartT) -> EndT:
"""Run the step only if the input is not cached, else return from cache."""
hash_ = self.hash_function(inp)
try:
result = self._cache[hash_]
result = self.cache[hash_]
except KeyError:
result = self._cache[hash_] = self.step(inp)
result = self.cache[hash_] = self.step(inp)
return result


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def generate_stencil_source(
generated_code = GTFNIMCodegen.apply(gtfn_im_ir)
else:
generated_code = GTFNCodegen.apply(gtfn_ir)

return codegen.format_source("cpp", generated_code, style="LLVM")

def __call__(
Expand Down
61 changes: 54 additions & 7 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@

import functools
import warnings
from typing import Any
from typing import Any, Optional

import diskcache
import factory
import numpy.typing as npt

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
from gt4py.eve import utils
from gt4py.eve.utils import content_hash
from gt4py.next import backend, common, config
from gt4py.next.iterator import transforms
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.iterator import ir as itir, transforms
from gt4py.next.otf import arguments, recipes, stages, workflow
from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import compiler
Expand Down Expand Up @@ -116,6 +119,37 @@ def compilation_hash(otf_closure: stages.CompilableProgram) -> int:
)


def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str:
    """
    Return a content hash identifying a compilable stencil program.

    The fingerprint is computed over the program IR (`inp.data`), the
    offset provider entries ordered by their key, and the column axis.
    Sorting the offset provider makes the result independent of the
    insertion order of its entries.
    """
    ir_program = inp.data
    offsets = inp.args.offset_provider
    axis = inp.args.column_axis

    # Order entries by key only; the values are never compared.
    ordered_offsets = sorted(offsets.items(), key=lambda entry: entry[0])

    return utils.content_hash((ir_program, ordered_offsets, axis))


class FileCache(diskcache.Cache):
    """
    A `diskcache.Cache` that closes itself when the instance is finalized.

    The only addition over the base class is a `__del__` hook calling
    `close()`, so resources held by the cache are released once the
    object is garbage collected.
    """

    def __del__(self) -> None:
        # Finalizer: close the cache when the instance is being destroyed.
        self.close()


class GTFNCompileWorkflowFactory(factory.Factory):
class Meta:
model = recipes.OTFCompileWorkflow
Expand All @@ -129,10 +163,23 @@ class Params:
lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type)
)

translation = factory.SubFactory(
gtfn_module.GTFNTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
)
cached_translation = factory.Trait(
translation=factory.LazyAttribute(
lambda o: workflow.CachedStep(
o.translation_,
hash_function=fingerprint_compilable_program,
cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")),
)
),
)

translation_ = factory.SubFactory(
gtfn_module.GTFNTranslationStepFactory,
device_type=factory.SelfAttribute("..device_type"),
)

translation = factory.LazyAttribute(lambda o: o.translation_)

bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = (
nanobind.bind_source
)
Expand Down Expand Up @@ -193,7 +240,7 @@ class Params:
name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True
)

run_gtfn_cached = GTFNBackendFactory(cached=True)
run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True)

run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
# SPDX-License-Identifier: BSD-3-Clause

from functools import reduce

from gt4py.next.otf import languages, stages, workflow
from gt4py.next.otf.binding import interface
import numpy as np
import pytest
import diskcache
from gt4py.eve import SymbolName

import gt4py.next as gtx
from gt4py.next import (
Expand Down
Loading

0 comments on commit 1b9eb5c

Please sign in to comment.