Fix bug that caused custom schema functions not to be used (#11)
* Move `merge_dicts` and add doctest and types

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add more tests for the `customize` module

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Actually use the `schema_conflict_handlers` dict

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Reorganize test files

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix test issue (caused by doctest!)

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
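
For illustration (not part of the diff; the handler name and the example dicts are hypothetical, and this assumes `merge_dicts` in `hydra_auto_schema.utils` keeps the behaviour of the `_merge_dicts` shown removed further down), a minimal sketch of how a handler registered in `schema_conflict_handlers` now gets picked up:

from hydra_auto_schema.customize import schema_conflict_handlers
from hydra_auto_schema.utils import merge_dicts

# Keep the longer of two conflicting "description" values instead of overwriting.
def keep_longer_description(val_a: str, val_b: str) -> str:
    return val_a if len(val_a) >= len(val_b) else val_b

schema_conflict_handlers["description"] = keep_longer_description

# merge_dicts consults the registered handlers when both dicts set the same key
# to different values; before this fix, the dict was never passed along.
merged = merge_dicts(
    {"title": "A", "description": "short"},
    {"title": "A", "description": "a much longer description"},
    conflict_handlers=schema_conflict_handlers,
)
assert merged["description"] == "a much longer description"
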
lebrice authored Nov 6, 2024
1 parent 83812e5 commit 62ea0bd
Showing 15 changed files with 519 additions and 183 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ cython_debug/
.schemas
outputs
tests/structured_app/.schemas
.vscode
!tests/**/.vscode
4 changes: 2 additions & 2 deletions hydra_auto_schema/__init__.py
@@ -5,14 +5,14 @@
from .auto_schema import (
add_schemas_to_all_hydra_configs,
)
from .customize import custom_enum_schemas, special_handlers
from .customize import custom_enum_schemas, custom_hydra_zen_builds_args
from .filewatcher import AutoSchemaEventHandler

__all__ = [
"add_schemas_to_all_hydra_configs",
# "AutoSchemaPlugin",
# "register_auto_schema_plugin",
"AutoSchemaEventHandler",
"special_handlers",
"custom_hydra_zen_builds_args",
"custom_enum_schemas",
]
110 changes: 33 additions & 77 deletions hydra_auto_schema/auto_schema.py
@@ -21,7 +21,6 @@
import sys
import typing
import warnings
from collections.abc import Callable, MutableMapping
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, TypeVar
@@ -48,14 +47,18 @@
from pydantic_core import core_schema
from tqdm.rich import tqdm_rich

from hydra_auto_schema.customize import custom_enum_schemas, special_handlers
from hydra_auto_schema.customize import (
custom_enum_schemas,
custom_hydra_zen_builds_args,
schema_conflict_handlers,
)
from hydra_auto_schema.hydra_schema import (
HYDRA_CONFIG_SCHEMA,
ObjectSchema,
PropertySchema,
Schema,
)
from hydra_auto_schema.utils import pretty_path
from hydra_auto_schema.utils import merge_dicts, pretty_path

logger = get_logger(__name__)

@@ -459,8 +462,11 @@ def _create_schema_for_config(
assert "properties" in nested_value_schema

if is_top_level:
schema = _merge_dicts(
schema, nested_value_schema, conflict_handler=_overwrite
schema = merge_dicts(
schema,
nested_value_schema,
conflict_handler=_overwrite,
conflict_handlers=schema_conflict_handlers,
)
continue

@@ -481,10 +487,11 @@
assert isinstance(last_key, str)
where_to_set["properties"][last_key] = nested_value_schema # type: ignore
else:
where_to_set["properties"] = _merge_dicts( # type: ignore
where_to_set["properties"] = merge_dicts( # type: ignore
where_to_set["properties"],
{last_key: nested_value_schema}, # type: ignore
conflict_handler=_overwrite,
conflict_handlers=schema_conflict_handlers,
)

return schema
@@ -546,7 +553,7 @@ def _update_schema_from_defaults(
# f"Properties of {default=}: {list(schema_of_default['properties'].keys())}"
# ) # type: ignore

schema = _merge_dicts( # type: ignore
schema = merge_dicts( # type: ignore
schema_of_default, # type: ignore
schema, # type: ignore
conflict_handler=_overwrite,
@@ -557,6 +564,7 @@
# "title": _overwrite,
# "description": _overwrite,
# },
conflict_handlers=schema_conflict_handlers,
)
# todo: deal with this one here.
if schema.get("additionalProperties") is False:
@@ -657,63 +665,6 @@ def _keep_previous(val_a: Any, val_b: Any) -> Any:
return val_a


conflict_handlers: dict[str, Callable[[Any, Any], Any]] = {}

_K = TypeVar("_K")
_V = TypeVar("_V")
_NestedDict = MutableMapping[_K, _V | "_NestedDict[_K, _V]"]

_D1 = TypeVar("_D1", bound=_NestedDict)
_D2 = TypeVar("_D2", bound=_NestedDict)


def _merge_dicts(
a: _D1,
b: _D2,
path: list[str] = [],
conflict_handlers: dict[str, Callable[[Any, Any], Any]] = conflict_handlers,
conflict_handler: Callable[[Any, Any], Any] | None = None,
) -> _D1 | _D2:
"""Merge two nested dictionaries.
>>> x = dict(b=1, c=dict(d=2, e=3))
>>> y = dict(d=3, c=dict(z=2, f=4))
>>> _merge_dicts(x, y)
{'b': 1, 'c': {'d': 2, 'e': 3, 'z': 2, 'f': 4}, 'd': 3}
>>> x
{'b': 1, 'c': {'d': 2, 'e': 3}}
>>> y
{'d': 3, 'c': {'z': 2, 'f': 4}}
"""
out = copy.deepcopy(a)
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
out[key] = _merge_dicts(
a[key],
b[key],
path + [str(key)],
conflict_handlers={
k.removeprefix(f"{key}."): v
for k, v in conflict_handlers.items()
},
conflict_handler=conflict_handler,
)
elif a[key] != b[key]:
if specific_conflict_handler := conflict_handlers.get(key):
out[key] = specific_conflict_handler(a[key], b[key]) # type: ignore
elif conflict_handler:
out[key] = conflict_handler(a[key], b[key]) # type: ignore

# if any(key.split(".")[-1] == handler_name for for prefix in ["_", "description", "title"]):
# out[key] = b[key]
else:
raise Exception("Conflict at " + ".".join(path + [str(key)]))
else:
out[key] = copy.deepcopy(b[key]) # type: ignore
return out


def _has_package_global_line(config_file: Path) -> int | None:
"""Returns whether the config file contains a `@package _global_` directive of hydra.
@@ -941,19 +892,24 @@ def _add_schema_header(config_file: Path, schema_path: Path) -> None:


def _get_dataclass_from_target(target: Any, config: dict | DictConfig) -> type:
if inspect.isclass(target) and target in special_handlers:
special_kwargs = special_handlers[target]
kwargs = _merge_dicts(
dict(
populate_full_signature=True,
hydra_recursive=False,
hydra_convert="all",
zen_dataclass={"cls_name": target.__qualname__},
),
special_kwargs,
)
# Generate the dataclass dynamically with hydra-zen.
return hydra_zen.builds(target, **kwargs)
for target_type, special_kwargs in custom_hydra_zen_builds_args.items():
if target_type is target or (
inspect.isclass(target)
and inspect.isclass(target_type)
and issubclass(target, target_type)
):
kwargs = merge_dicts(
dict(
populate_full_signature=True,
hydra_recursive=False,
hydra_convert="all",
zen_dataclass={"cls_name": target.__qualname__},
),
special_kwargs,
conflict_handler=_overwrite,
)
# Generate the dataclass dynamically with hydra-zen.
return hydra_zen.builds(target, **kwargs)
if dataclasses.is_dataclass(target):
# The target is a dataclass, so the schema is just the schema of the dataclass.
assert inspect.isclass(target)
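
For context, an illustrative sketch (not part of the diff; `MyCallback` and its arguments are hypothetical) of the `hydra_zen.builds` call that the updated `_get_dataclass_from_target` ends up making for a target registered in `custom_hydra_zen_builds_args` with `{"zen_exclude": ["theme"]}`:

import hydra_zen

class MyCallback:
    def __init__(self, theme: str = "default", refresh_rate: int = 1):
        self.theme = theme
        self.refresh_rate = refresh_rate

# The default builds kwargs merged (via `merge_dicts`) with the registered entry:
MyCallbackConfig = hydra_zen.builds(
    MyCallback,
    populate_full_signature=True,
    hydra_recursive=False,
    hydra_convert="all",
    zen_dataclass={"cls_name": MyCallback.__qualname__},
    zen_exclude=["theme"],  # "theme" is left out of the generated dataclass fields
)
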
85 changes: 0 additions & 85 deletions hydra_auto_schema/auto_schema_test.py
@@ -1,28 +1,3 @@
import json
import os
from pathlib import Path

import pytest
import yaml
from hydra.core.config_store import ConfigStore
from pytest_regressions.file_regression import FileRegressionFixture

from .auto_schema import (
_add_schema_header,
_create_schema_for_config,
add_schemas_to_all_hydra_configs,
)

REPO_ROOTDIR = Path.cwd()
IN_GITHUB_CI = "GITHUB_ACTIONS" in os.environ
config_dir = Path(__file__).parent.parent / "tests" / "configs"


@pytest.fixture
def original_datadir():
return config_dir


class Foo:
def __init__(self, bar: str):
"""Description of the `Foo` class.
@@ -44,63 +19,3 @@ def __init__(self, bar: str, baz: int):
# no docstring here.
super().__init__(bar=bar)
self.baz = baz


test_files = list(config_dir.rglob("*.yaml"))


@pytest.mark.parametrize(
"config_file",
[
pytest.param(
p,
marks=pytest.mark.xfail(
IN_GITHUB_CI,
reason="TODO: Does not work on the Github CI for some reason!",
),
)
if "structured" in p.name
else p
for p in test_files
],
ids=[f.name for f in test_files],
)
def test_make_schema(config_file: Path, file_regression: FileRegressionFixture):
"""Test that creates a schema for a config file and saves it next to it.
(in the test folder).
"""
schema_file = config_file.with_suffix(".json")

config = yaml.load(config_file.read_text(), yaml.FullLoader)
if config is None:
config = {}
schema = _create_schema_for_config(
config=config,
config_file=config_file,
configs_dir=config_dir,
repo_root=REPO_ROOTDIR,
config_store=ConfigStore.instance(),
)
_add_schema_header(config_file, schema_path=schema_file)

file_regression.check(
json.dumps(schema, indent=2) + "\n", fullpath=schema_file, extension=".json"
)


def test_warns_when_no_config_files_found(tmp_path: Path):
with pytest.warns(RuntimeWarning, match="No config files were found"):
add_schemas_to_all_hydra_configs(
repo_root=tmp_path, configs_dir=tmp_path, schemas_dir=tmp_path
)


def test_raises_when_no_config_files_found_and_stop_on_error(tmp_path: Path):
with pytest.raises(RuntimeError, match="No config files were found"):
add_schemas_to_all_hydra_configs(
repo_root=tmp_path,
configs_dir=tmp_path,
schemas_dir=tmp_path,
stop_on_error=True,
)
28 changes: 23 additions & 5 deletions hydra_auto_schema/customize.py
@@ -1,31 +1,49 @@
"""Global variables that can be used to customize how schemas are generated."""
import dataclasses
from collections.abc import Callable
import enum
from typing import Any, Callable

from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


special_handlers: dict[type | Callable, dict] = {
custom_hydra_zen_builds_args: dict[type | Callable, dict] = {
# flax.linen.Module: {"zen_exclude": ["parent"]},
# lightning.pytorch.callbacks.RichProgressBar: {"zen_exclude": ["theme"]},
}
"""Keyword arguments that should be passed to `hydra_zen.builds` for a given class or callable.
These arguments overwrite the default values.
"""

custom_enum_schemas: dict[type[enum.Enum], Callable] = {}
"""Dict of functions to be used by pydantic to generate schemas for enum classes.
TODO: This a bit too specific. We could probably use our `GenerateJsonSchema`
subclass to enable more general customizations, following this guide:
https://docs.pydantic.dev/2.9/concepts/json_schema/#customizing-the-json-schema-generation-process
"""

schema_conflict_handlers: dict[str, Callable[[Any, Any], Any]] = {}
"""Functions to be used by the `merge_dicts` function to resolve conflicts between schemas.
See the docstring of `merge_dicts` for more info.
"""

# Conditionally add some common fixes?
# TODO: Should we really import those here? This might make the import pretty slow, no? If not,
# when / how should we import things here?

try:
from flax.linen import Module # type: ignore

special_handlers[Module] = {"zen_exclude": ["parent"]}
custom_hydra_zen_builds_args[Module] = {"zen_exclude": ["parent"]}
except ImportError:
pass

try:
from lightning.pytorch.callbacks import RichProgressBar # type: ignore

special_handlers[RichProgressBar] = {"zen_exclude": ["theme"]}
custom_hydra_zen_builds_args[RichProgressBar] = {"zen_exclude": ["theme"]}
except ImportError:
pass
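
As a usage note, the updated lookup in `_get_dataclass_from_target` also matches subclasses of a registered class, so an entry for a base class (like `flax.linen.Module` above) covers everything derived from it. A minimal sketch with hypothetical `BaseThing` / `MyThing` classes:

from hydra_auto_schema.customize import custom_hydra_zen_builds_args

class BaseThing:
    def __init__(self, internal: object | None = None):
        self.internal = internal

class MyThing(BaseThing):
    def __init__(self, value: int = 3, internal: object | None = None):
        super().__init__(internal=internal)
        self.value = value

# Registering the base class is enough: schema generation for a `MyThing` target
# also picks up `zen_exclude=["internal"]`, because issubclass(MyThing, BaseThing)
# is True in the new lookup.
custom_hydra_zen_builds_args[BaseThing] = {"zen_exclude": ["internal"]}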

(Diffs for the remaining changed files are not shown.)