Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: threshold checker interval string #652

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/augment/brightness.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Node(ThresholdCheckerMixin, AbstractNode):
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any) -> None:
super().__init__(config, node_path=__name__, **kwargs)

self.check_bounds("beta", (-100, 100), "within")
self.check_bounds("beta", "[-100, 100]")

def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Adjusts the brightness of an image frame.
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/augment/contrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Node(ThresholdCheckerMixin, AbstractNode):
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any) -> None:
super().__init__(config, node_path=__name__, **kwargs)

self.check_bounds("alpha", (0, 3), "within")
self.check_bounds("alpha", "[0, 3]")

def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Adjusts the contrast of an image frame.
Expand Down
237 changes: 76 additions & 161 deletions peekingduck/pipeline/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,136 +17,93 @@
import hashlib
import operator
import os
import re
import sys
import zipfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
ongtw marked this conversation as resolved.
Show resolved Hide resolved

import requests
from tqdm import tqdm

BASE_URL = "https://storage.googleapis.com/peekingduck/models"
PEEKINGDUCK_WEIGHTS_SUBDIR = "peekingduck_weights"

Number = Union[float, int]


class ThresholdCheckerMixin:
"""Mixin class providing utility methods for checking validity of config
values, typically thresholds.
"""

def check_bounds(
self,
key: Union[str, List[str]],
value: Union[Number, Tuple[Number, Number]],
method: str,
include: Optional[str] = "both",
) -> None:
"""Checks if the configuration value(s) specified by `key` satisties
interval_pattern = re.compile(
r"^[\[\(]\s*[-+]?(inf|\d*\.?\d+)\s*,\s*[-+]?(inf|\d*\.?\d+)\s*[\]\)]$"
ongtw marked this conversation as resolved.
Show resolved Hide resolved
)

def check_bounds(self, key: Union[str, List[str]], interval: str) -> None:
"""Checks if the configuration value(s) specified by `key` satisfies
the specified bounds.

Args:
key (Union[str, List[str]]): The specified key or list of keys.
value (Union[Number, Tuple[Number, Number]]): Either a single
number to specify the upper or lower bound or a tuple of
numbers to specify both the upper and lower bounds.
method (str): The bounds checking methods, one of
{"above", "below", "both"}. If "above", checks if the
configuration value is above the specified `value`. If "below",
checks if the configuration value is below the specified
`value`. If "both", checks if the configuration value is above
`value[0]` and below `value[1]`.
include (Optional[str]): Indicates if the `value` itself should be
included in the bound, one of {"lower", "upper", "both", None}.
Please see Technotes for details.
interval (str): An mathematical interval representing the range of
valid values. The syntax of the `interval` string is:

<value> = <number> | "-inf" | "+inf"
<left_bracket> = "(" | "["
<right_bracket> = ")" | "]"
<interval> = <left_bracket> <value> "," <value> <right_bracket>

See Technotes for more details.

Raises:
TypeError: `key` type is not in (List[str], str).
TypeError: If `value` is not a tuple of only float/int.
TypeError: If `value` is not a tuple with 2 elements.
TypeError: If `value` is not a float, int, or tuple.
TypeError: If `value` type is not a tuple when `method` is
"within".
TypeError: If `value` type is a tuple when `method` is
"above"/"below".
ValueError: If `method` is not one of {"above", "below", "within"}.
ValueError: If `interval` does not match the specified format.
ValueError: If the lower bound is larger than the upper bound.
ValueError: If the configuration value fails the bounds comparison.

Technotes:
The behavior of `include` depends on the specified `method`. The
table below shows the comparison done for various argument
combinations.

+-----------+---------+-------------------------------------+
| method | include | comparison |
+===========+=========+=====================================+
| | "lower" | config[key] >= value |
+ +---------+-------------------------------------+
| | "upper" | config[key] > value |
+ +---------+-------------------------------------+
| | "both" | config[key] >= value |
+ +---------+-------------------------------------+
| "above" | None | config[key] > value |
+-----------+---------+-------------------------------------+
| | "lower" | config[key] < value |
+ +---------+-------------------------------------+
| | "upper" | config[key] <= value |
+ +---------+-------------------------------------+
| | "both" | config[key] <= value |
+ +---------+-------------------------------------+
| "below" | None | config[key] < value |
+-----------+---------+-------------------------------------+
| | "lower" | value[0] <= config[key] < value[1] |
+ +---------+-------------------------------------+
| | "upper" | value[0] < config[key] <= value[1] |
+ +---------+-------------------------------------+
| | "both" | value[0] <= config[key] <= value[1] |
+ +---------+-------------------------------------+
| "within" | None | value[0] < config[key] < value[1] |
+-----------+---------+-------------------------------------+
The table below shows the comparison done for various interval
expressions.

+---------------------+-------------------------------------+
| interval | comparison |
+=====================+=====================================+
| [lower, +inf] | |
+---------------------+ |
| [lower, +inf) | config[key] >= lower |
+---------------------+-------------------------------------+
| (lower, +inf] | |
+---------------------+ |
| (lower, +inf) | config[key] > lower |
+---------------------+-------------------------------------+
| [-inf, upper] | |
+---------------------+ |
| (-inf, upper] | config[key] <= upper |
+---------------------+-------------------------------------+
| [-inf, upper) | |
+---------------------+ |
| (-inf, upper) | config[key] < upper |
+---------------------+-------------------------------------+
| [lower, upper] | lower <= config[key] <= upper |
+---------------------+-------------------------------------+
| (lower, upper] | lower < config[key] <= upper |
+---------------------+-------------------------------------+
| [lower, upper) | lower <= config[key] < upper |
+---------------------+-------------------------------------+
| (lower, upper) | lower < config[key] < upper |
+---------------------+-------------------------------------+
"""
# available checking methods
methods = {"above", "below", "within"}
# available options of lower/upper bound inclusion
lower_includes = {"lower", "both"}
upper_includes = {"upper", "both"}

if method not in methods:
raise ValueError(f"`method` must be one of {methods}")

if isinstance(value, tuple):
if not all(isinstance(val, (float, int)) for val in value):
raise TypeError(
"When using tuple for `value`, it must be a tuple of float/int"
)
if len(value) != 2:
raise ValueError(
"When using tuple for `value`, it must contain only 2 elements"
)
elif isinstance(value, (float, int)):
pass
else:
raise TypeError(
"`value` must be a float/int or tuple, but you passed a "
f"{type(value).__name__}"
)

if method == "within":
if not isinstance(value, tuple):
raise TypeError("`value` must be a tuple when `method` is 'within'")
self._check_within_bounds(
key, value, (include in lower_includes, include in upper_includes)
)
else:
if isinstance(value, tuple):
raise TypeError(
"`value` must be a float/int when `method` is 'above'/'below'"
)
if method == "above":
self._check_above_value(key, value, include in lower_includes)
elif method == "below":
self._check_below_value(key, value, include in upper_includes)
if self.interval_pattern.match(interval) is None:
raise ValueError("Badly formatted interval")

lower_openness = interval[0]
upper_openness = interval[-1]
lower, upper = [float(value.strip()) for value in interval[1:-1].split(",")]

if lower > upper:
raise ValueError("Lower bound cannot be larger than upper bound")

self._check_within_bounds(key, lower, upper, lower_openness, upper_openness)

def check_valid_choice(
self, key: str, choices: Set[Union[int, float, str]]
Expand All @@ -167,78 +124,36 @@ def check_valid_choice(
if self.config[key] not in choices:
raise ValueError(f"{key} must be one of {choices}")

def _check_above_value(
self, key: Union[str, List[str]], value: Number, inclusive: bool
) -> None:
"""Checks that configuration values specified by `key` is more than
(or equal to) the specified `value`.

Args:
key (Union[str, List[str]]): The specified key or list of keys.
value (Number): The specified value.
inclusive (bool): If `True`, compares `config[key] >= value`. If
`False`, compares `config[key] > value`.

Raises:
TypeError: `key` type is not in (List[str], str).
ValueError: If the configuration value is less than (or equal to)
`value`.
"""
method = operator.ge if inclusive else operator.gt
extra_reason = " or equal to" if inclusive else ""
self._compare(key, value, method, reason=f"more than{extra_reason} {value}")

def _check_below_value(
self, key: Union[str, List[str]], value: Number, inclusive: bool
) -> None:
"""Checks that configuration values specified by `key` is more than
(or equal to) the specified `value`.

Args:
key (Union[str, List[str]]): The specified key or list of keys.
value (Number): The specified value.
inclusive (bool): If `True`, compares `config[key] <= value`. If
`False`, compares `config[key] < value`.

Raises:
TypeError: `key` type is not in (List[str], str).
ValueError: If the configuration value is less than (or equal to)
`value`.
"""
method = operator.le if inclusive else operator.lt
extra_reason = " or equal to" if inclusive else ""
self._compare(key, value, method, reason=f"less than{extra_reason} {value}")

def _check_within_bounds(
def _check_within_bounds( # pylint: disable=too-many-arguments
self,
key: Union[str, List[str]],
bounds: Tuple[Number, Number],
includes: Tuple[bool, bool],
lower: float,
upper: float,
lower_openness: str,
upper_openness: str,
ongtw marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Checks that configuration values specified by `key` is within the
specified bounds between `lower` and `upper`.

Args:
key (Union[str, List[str]]): The specified key or list of keys.
(Union[float, int]): The lower bound.
bounds (Tuple[Number, Number]): The lower and upper bounds.
includes (Tuple[bool, bool]): If `True`, compares `config[key] >= value`.
If `False`, compares `config[key] > value`.
inclusive_upper (bool): If `True`, compares `config[key] <= value`.
If `False`, compares `config[key] < value`.
lower (float): The lower bound.
upper (float): The upper bound.
lower_openness (str): Either a "(" for an open lower bound or a "["
for a closed lower bound.
upper_openness (str): Either a ")" for an open upper bound or a "]"
for a closed upper bound.

Raises:
TypeError: `key` type is not in (List[str], str).
ValueError: If the configuration value is not between `lower` and
`upper`.
"""
method_lower = operator.ge if includes[0] else operator.gt
method_upper = operator.le if includes[1] else operator.lt
reason_lower = "[" if includes[0] else "("
reason_upper = "]" if includes[1] else ")"
reason = f"between {reason_lower}{bounds[0]}, {bounds[1]}{reason_upper}"
self._compare(key, bounds[0], method_lower, reason)
self._compare(key, bounds[1], method_upper, reason)
method_lower = operator.ge if lower_openness == "[" else operator.gt
method_upper = operator.le if upper_openness == "]" else operator.lt
reason = f"between {lower_openness}{lower}, {upper}{upper_openness}"
self._compare(key, lower, method_lower, reason)
self._compare(key, upper, method_upper, reason)

def _compare(
self,
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/csrnetv1/csrnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds("width", 0, "above", include=None)
self.check_bounds("width", "(0, +inf]")

model_dir = self.download_weights()
self.predictor = Predictor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.logger = logging.getLogger(__name__)

self.check_valid_choice("model_type", {0, 1, 2, 3, 4})
self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
classes_path = model_dir / self.weights["classes_file"]
Expand Down
6 changes: 2 additions & 4 deletions peekingduck/pipeline/nodes/model/fairmotv1/fairmot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ def __init__(self, config: Dict[str, Any], frame_rate: float) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds(
["K", "min_box_area", "track_buffer"], 0, "above", include=None
)
self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds(["K", "min_box_area", "track_buffer"], "(0, +inf]")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
self.tracker = Tracker(
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/hrnetv1/hrnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
self.detector = Detector(
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/jdev1/jde_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, config: Dict[str, Any], frame_rate: float) -> None:
self.logger = logging.getLogger(__name__)

self.check_bounds(
["iou_threshold", "nms_threshold", "score_threshold"], (0, 1), "within"
["iou_threshold", "nms_threshold", "score_threshold"], "[0, 1]"
)

model_dir = self.download_weights()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
{"singlepose_lightning", "singlepose_thunder", "multipose_lightning"},
)
self.check_bounds(
["bbox_score_threshold", "keypoint_score_threshold"], (0, 1), "within"
["bbox_score_threshold", "keypoint_score_threshold"], "[0, 1]"
)

model_dir = self.download_weights()
Expand Down
4 changes: 2 additions & 2 deletions peekingduck/pipeline/nodes/model/mtcnnv1/mtcnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds("min_size", 0, "above", include=None)
self.check_bounds("min_size", "(0, +inf]")
self.check_bounds(
["network_thresholds", "scale_factor", "score_threshold"], (0, 1), "within"
["network_thresholds", "scale_factor", "score_threshold"], "[0, 1]"
)

model_dir = self.download_weights()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.logger = logging.getLogger(__name__)

self.check_valid_choice("model_type", {50, 75, 100, "resnet"})
self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
self.predictor = Predictor(
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/yolov4/yolo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds(["iou_threshold", "score_threshold"], (0, 1), "within")
self.check_bounds(["iou_threshold", "score_threshold"], "[0, 1]")

model_dir = self.download_weights()
with open(model_dir / self.weights["classes_file"]) as infile:
Expand Down
Loading