Skip to content

Commit

Permalink
support types in the typing module (#37)
Browse files Browse the repository at this point in the history
This will be very useful for all downstream projects to set types like
`list[str]`, `list[list[str]]`, etc.

```py
>>> ca = Argument("key1", List[float])
>>> ca.check({"key1": [1, 2.0, 3]})
pass
>>> ca.check({"key1": [1, 2.0, "3"]})
throw ArgumentTypeError
```

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Oct 24, 2023
1 parent 900cf68 commit d97bd20
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 14 deletions.
44 changes: 31 additions & 13 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import re
from copy import deepcopy
from enum import Enum
from numbers import Real
from textwrap import indent
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_origin

import typeguard

INDENT = " " # doc is indented by four spaces
RAW_ANCHOR = False # whether to use raw html anchors or RST ones
Expand Down Expand Up @@ -176,7 +176,7 @@ def __eq__(self, other: "Argument") -> bool:
)

def __repr__(self) -> str:
return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>"
return f"<Argument {self.name}: {' | '.join(self._get_type_name(dd) for dd in self.dtype)}>"

def __getitem__(self, key: str) -> "Argument":
key = key.lstrip("/")
Expand Down Expand Up @@ -205,10 +205,17 @@ def I(self):
return Argument("_", dict, [self])

def _reorg_dtype(self):
if isinstance(self.dtype, type) or self.dtype is None:
if (
isinstance(self.dtype, type)
or isinstance(get_origin(self.dtype), type)
or self.dtype is None
):
self.dtype = [self.dtype]
# remove duplicate
self.dtype = {dt if type(dt) is type else type(dt) for dt in self.dtype}
self.dtype = {
dt if type(dt) is type or type(get_origin(dt)) is type else type(dt)
for dt in self.dtype
}
# check conner cases
if self.sub_fields or self.sub_variants:
self.dtype.add(list if self.repeat else dict)
Expand Down Expand Up @@ -414,16 +421,19 @@ def _check_exist(self, argdict: dict, path=None):
)

def _check_data(self, value: Any, path=None):
if not (
isinstance(value, self.dtype)
or (float in self.dtype and isinstance(value, Real))
):
try:
typeguard.check_type(
value,
self.dtype,
collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS,
)
except typeguard.TypeCheckError as e:
raise ArgumentTypeError(
path,
f"key `{self.name}` gets wrong value type, "
f"requires <{'|'.join(dd.__name__ for dd in self.dtype)}> "
f"but gets <{type(value).__name__}>",
)
f"requires <{'|'.join(self._get_type_name(dd) for dd in self.dtype)}> "
f"but " + str(e),
) from e
if self.extra_check is not None and not self.extra_check(value):
raise ArgumentValueError(
path,
Expand Down Expand Up @@ -586,7 +596,9 @@ def gen_doc(self, path: Optional[List[str]] = None, **kwargs) -> str:
return "\n".join(filter(None, doc_list))

def gen_doc_head(self, path: Optional[List[str]] = None, **kwargs) -> str:
typesig = "| type: " + " | ".join([f"``{dt.__name__}``" for dt in self.dtype])
typesig = "| type: " + " | ".join(
[f"``{self._get_type_name(dt)}``" for dt in self.dtype]
)
if self.optional:
typesig += ", optional"
if self.default == "":
Expand Down Expand Up @@ -632,6 +644,10 @@ def gen_doc_body(self, path: Optional[List[str]] = None, **kwargs) -> str:
body = "\n".join(body_list)
return body

def _get_type_name(self, dd) -> str:
"""Get type name for doc/message generation."""
return str(dd) if isinstance(get_origin(dd), type) else dd.__name__


class Variant:
"""Define multiple choices of possible argument sets.
Expand Down Expand Up @@ -993,6 +1009,8 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]:
"choice_alias": obj.choice_alias,
"doc": obj.doc,
}
elif isinstance(get_origin(obj), type):
return get_origin(obj).__name__
elif isinstance(obj, type):
return obj.__name__
return json.JSONEncoder.default(self, obj)
2 changes: 1 addition & 1 deletion dargs/sphinx.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,5 @@ def _test_arguments() -> List[Argument]:
return [
Argument(name="test1", dtype=int, doc="Argument 1"),
Argument(name="test2", dtype=[float, None], doc="Argument 2"),
Argument(name="test3", dtype=list, doc="Argument 3"),
Argument(name="test3", dtype=List[str], doc="Argument 3"),
]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ classifiers = [
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
]
dependencies = [
"typeguard>=3",
]
requires-python = ">=3.7"
readme = "README.md"
Expand Down
6 changes: 6 additions & 0 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from .context import dargs
import unittest
from dargs import Argument, Variant
Expand Down Expand Up @@ -27,6 +28,11 @@ def test_name_type(self):
# special handel of int and float
ca = Argument("key1", float)
ca.check({"key1": 1})
# list[int]
ca = Argument("key1", List[float])
ca.check({"key1": [1, 2.0, 3]})
with self.assertRaises(ArgumentTypeError):
ca.check({"key1": [1, 2.0, "3"]})
# optional case
ca = Argument("key1", int, optional=True)
ca.check({})
Expand Down
6 changes: 6 additions & 0 deletions tests/test_docgen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .context import dargs
import unittest
import json
from typing import List
from dargs import Argument, Variant, ArgumentEncoder


Expand All @@ -22,6 +23,11 @@ def test_sub_fields(self):
[Argument("subsubsub1", int, doc="subsubsub doc." * 5)],
doc="subsub doc." * 5,
),
Argument(
"list_of_float",
List[float],
doc="Check if List[float] works.",
),
],
doc="sub doc." * 5,
),
Expand Down

0 comments on commit d97bd20

Please sign in to comment.