From 0b50cbf9b3b3bc23e5d73d6b58b3aab2fa230f78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 2 Oct 2023 09:43:14 +0200 Subject: [PATCH] tests: cover when approx is None --- differt2d/logic.py | 26 +++++++++++++------------- differt2d/scene.py | 22 +++++++++++++++++++--- docs/source/index.md | 2 +- docs/source/quickstart.md | 8 ++++---- tests/test_logic.py | 25 +++++++++++++------------ 5 files changed, 50 insertions(+), 33 deletions(-) diff --git a/differt2d/logic.py b/differt2d/logic.py index 8f629f1..d1d1df1 100644 --- a/differt2d/logic.py +++ b/differt2d/logic.py @@ -259,7 +259,7 @@ def logical_or(x: Array, y: Array, approx: Optional[bool] = None) -> Array: :param approx: Whether approximation is enabled or not. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.maximum(x, y) if approx else jnp.logical_or(x, y) @@ -277,7 +277,7 @@ def logical_and(x: Array, y: Array, approx: Optional[bool] = None) -> Array: :param approx: Whether approximation is enabled or not. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.minimum(x, y) if approx else jnp.logical_and(x, y) @@ -295,7 +295,7 @@ def logical_not(x: Array, approx: Optional[bool] = None) -> Array: :param approx: Whether approximation is enabled or not. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.subtract(1.0, x) if approx else jnp.logical_not(x) @@ -321,7 +321,7 @@ def greater( Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y) @@ -344,7 +344,7 @@ def greater_equal( Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return ( activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y) @@ -367,7 +367,7 @@ def less(x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any) -> Ar Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y) @@ -393,7 +393,7 @@ def less_equal( Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y) @@ -416,7 +416,7 @@ def logical_all( :param approx: Whether approximation is enabled or not. :return: Output array. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] arr = jnp.asarray(x) return jnp.min(arr, axis=axis) if approx else jnp.all(arr, axis=axis) @@ -440,7 +440,7 @@ def logical_any( :param approx: Whether approximation is enabled or not. :return: Output array. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] arr = jnp.asarray(x) return jnp.max(arr, axis=axis) if approx else jnp.any(arr, axis=axis) @@ -460,7 +460,7 @@ def is_true(x: Array, tol: float = 0.5, approx: Optional[bool] = None) -> Array: :param approx: Whether approximation is enabled or not. :return: True array if the value is considered to be true. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.greater(x, 1.0 - tol) if approx else jnp.asarray(x) @@ -479,7 +479,7 @@ def is_false(x: Array, tol: float = 0.5, approx: Optional[bool] = None) -> Array :param approx: Whether approximation is enabled or not. :return: True if the value is considered to be false. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.less(x, tol) if approx else jnp.logical_not(x) @@ -494,7 +494,7 @@ def true_value(approx: Optional[bool] = None) -> Array: :param approx: Whether approximation is enabled or not. :return: A value that evaluates to true. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.array(1.0) if approx else jnp.array(True) @@ -509,6 +509,6 @@ def false_value(approx: Optional[bool] = None) -> Array: :param approx: Whether approximation is enabled or not. :return: A value that evaluates to false. """ - if approx is None: # pragma: no cover + if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] return jnp.array(0.0) if approx else jnp.array(False) diff --git a/differt2d/scene.py b/differt2d/scene.py index 5ec4ea6..d3ae792 100644 --- a/differt2d/scene.py +++ b/differt2d/scene.py @@ -375,7 +375,12 @@ def random_uniform_scene( return cls(emitters=emitters, receivers=receivers, objects=walls) @classmethod - def basic_scene(cls, *, tx_coords: Array = jnp.array([0.1, 0.1]), rx_coords: Array = jnp.array([0.302, 0.2147])) -> "Scene": + def basic_scene( + cls, + *, + tx_coords: Array = jnp.array([0.1, 0.1]), + rx_coords: Array = jnp.array([0.302, 0.2147]), + ) -> "Scene": """ Instantiates a basic scene with a main room, and a second inner room in the lower left corner, with a small entrance. @@ -425,7 +430,12 @@ def basic_scene(cls, *, tx_coords: Array = jnp.array([0.1, 0.1]), rx_coords: Arr return cls(emitters=dict(tx=tx), receivers=dict(rx=rx), objects=walls) @classmethod - def square_scene(cls, *, tx_coords: Array = jnp.array([0.2, 0.2]), rx_coords: Array = jnp.array([0.5, 0.6])) -> "Scene": + def square_scene( + cls, + *, + tx_coords: Array = jnp.array([0.2, 0.2]), + rx_coords: Array = jnp.array([0.5, 0.6]), + ) -> "Scene": """ Instantiates a square scene with one main room. @@ -469,7 +479,13 @@ def square_scene(cls, *, tx_coords: Array = jnp.array([0.2, 0.2]), rx_coords: Ar return Scene(emitters=dict(tx=tx), receivers=dict(rx=rx), objects=walls) @classmethod - def square_scene_with_wall(cls, ratio: float = 0.1, *, tx_coords: Array = jnp.array([0.2, 0.5]), rx_coords: Array = jnp.array([0.8, 0.5])) -> "Scene": + def square_scene_with_wall( + cls, + ratio: float = 0.1, + *, + tx_coords: Array = jnp.array([0.2, 0.5]), + rx_coords: Array = jnp.array([0.8, 0.5]), + ) -> "Scene": """ Instantiates a square scene with one main room, and vertical wall in the middle. diff --git a/docs/source/index.md b/docs/source/index.md index 636134f..f11b5fa 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -41,7 +41,7 @@ If you are intersted in contributing to this tool, please checkout the ax = plt.gca() scene = Scene.basic_scene() X, Y = scene.grid(n=300) - + Z = scene.accumulate_on_receivers_grid_over_paths( X, Y, diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index 9d6fe00..7ecc500 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -45,7 +45,7 @@ pretty useful for optimization problems. from differt2d.geometry import Wall - wall = Wall(points=jnp.array([[.8, .2], [.8, .8]])) + wall = Wall(points=jnp.array([[0.8, 0.2], [0.8, 0.8]])) scene.add_objects([wall]) Plotting utils @@ -96,7 +96,7 @@ The easiest way to trace all paths from every emitter to every receiver is to us wall = Wall(points=jnp.array([[.8, .2], [.8, .8]])) scene.add_objects([wall]) scene.plot(ax) - + for _, _, path, _ in scene.all_valid_paths(): path.plot(ax, zorder=-1) # -1 to draw below the scene objects @@ -125,7 +125,7 @@ like the received power, on a grid and plot it: wall = Wall(points=jnp.array([[.8, .2], [.8, .8]])) scene.add_objects([wall]) scene.plot(ax) - + X, Y = scene.grid(n=300) Z = scene.accumulate_on_receivers_grid_over_paths( X, @@ -157,7 +157,7 @@ if we were to simulate a higher order of interacion, e.g.: wall = Wall(points=jnp.array([[.8, .2], [.8, .8]])) scene.add_objects([wall]) scene.plot(ax) - + X, Y = scene.grid(n=300) Z = scene.accumulate_on_receivers_grid_over_paths( X, diff --git a/tests/test_logic.py b/tests/test_logic.py index 14e679b..b71e0b1 100644 --- a/tests/test_logic.py +++ b/tests/test_logic.py @@ -23,7 +23,7 @@ true_value, ) -approx = pytest.mark.parametrize(("approx",), [(True,), (False,)]) +approx = pytest.mark.parametrize(("approx",), [(True,), (False,), (None,)]) alpha = pytest.mark.parametrize( ("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)] ) @@ -233,7 +233,7 @@ def test_invalid_activation(function): @approx def test_logical_or(xy, approx): x, y = xy - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = jnp.maximum(x, y) else: expected = jnp.logical_or(x, y) @@ -246,7 +246,7 @@ def test_logical_or(xy, approx): @approx def test_logical_and(xy, approx): x, y = xy - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = jnp.minimum(x, y) else: expected = jnp.logical_and(x, y) @@ -258,7 +258,7 @@ def test_logical_and(xy, approx): @approx def test_logical_not(x, approx): - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = jnp.subtract(1.0, x) else: expected = jnp.logical_not(x) @@ -270,7 +270,7 @@ def test_logical_not(x, approx): @approx def test_logical_all(x, approx): - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): x = jnp.array([0.8, 0.2, 0.3]) expected = jnp.min(x) else: @@ -284,7 +284,7 @@ def test_logical_all(x, approx): @approx def test_logical_any(x, approx): - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): x = jnp.array([0.8, 0.2, 0.3]) expected = jnp.max(x) else: @@ -301,7 +301,7 @@ def test_logical_any(x, approx): @function def test_greater(xy, approx, alpha, function): x, y = xy - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = activation(x - y, alpha=alpha, function=function) else: expected = jnp.greater(x, y) @@ -316,7 +316,7 @@ def test_greater(xy, approx, alpha, function): @function def test_greater_equal(xy, approx, alpha, function): x, y = xy - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = activation(x - y, alpha=alpha, function=function) else: expected = jnp.greater_equal(x, y) @@ -331,7 +331,7 @@ def test_greater_equal(xy, approx, alpha, function): @function def test_less(xy, approx, alpha, function): x, y = xy - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = activation(y - x, alpha=alpha, function=function) else: expected = jnp.less(x, y) @@ -346,7 +346,7 @@ def test_less(xy, approx, alpha, function): @function def test_less_equal(xy, approx, alpha, function): x, y = xy - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = activation(y - x, alpha=alpha, function=function) else: expected = jnp.less_equal(x, y) @@ -359,7 +359,7 @@ def test_less_equal(xy, approx, alpha, function): @approx @tol def test_is_true(x, approx, tol): - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = jnp.greater(x, 1.0 - tol) else: x = jnp.greater(x, 1.0 - tol) @@ -373,7 +373,7 @@ def test_is_true(x, approx, tol): @approx @tol def test_is_false(x, approx, tol): - if approx: + if approx or (approx is None and jax.config.jax_enable_approx): expected = jnp.less(x, tol) else: x = jnp.greater(x, tol) @@ -394,5 +394,6 @@ def test_true_value(approx, tol): @approx @tol def test_false_value(approx, tol): + print("approx", approx, jax.config.jax_enable_approx) x = false_value(approx=approx) assert is_false(x, tol=tol, approx=approx)