Skip to content

Commit

Permalink
Merge pull request #31 from jeertmans/cover-none
Browse files Browse the repository at this point in the history
tests: cover when approx is None
  • Loading branch information
jeertmans authored Oct 2, 2023
2 parents f13baf7 + 0b50cbf commit 4bdbf15
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 33 deletions.
26 changes: 13 additions & 13 deletions differt2d/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def logical_or(x: Array, y: Array, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.maximum(x, y) if approx else jnp.logical_or(x, y)

Expand All @@ -277,7 +277,7 @@ def logical_and(x: Array, y: Array, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.minimum(x, y) if approx else jnp.logical_and(x, y)

Expand All @@ -295,7 +295,7 @@ def logical_not(x: Array, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.subtract(1.0, x) if approx else jnp.logical_not(x)

Expand All @@ -321,7 +321,7 @@ def greater(
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y)

Expand All @@ -344,7 +344,7 @@ def greater_equal(
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return (
activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y)
Expand All @@ -367,7 +367,7 @@ def less(x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any) -> Ar
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y)

Expand All @@ -393,7 +393,7 @@ def less_equal(
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y)

Expand All @@ -416,7 +416,7 @@ def logical_all(
:param approx: Whether approximation is enabled or not.
:return: Output array.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
arr = jnp.asarray(x)
return jnp.min(arr, axis=axis) if approx else jnp.all(arr, axis=axis)
Expand All @@ -440,7 +440,7 @@ def logical_any(
:param approx: Whether approximation is enabled or not.
:return: Output array.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
arr = jnp.asarray(x)
return jnp.max(arr, axis=axis) if approx else jnp.any(arr, axis=axis)
Expand All @@ -460,7 +460,7 @@ def is_true(x: Array, tol: float = 0.5, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: True array if the value is considered to be true.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.greater(x, 1.0 - tol) if approx else jnp.asarray(x)

Expand All @@ -479,7 +479,7 @@ def is_false(x: Array, tol: float = 0.5, approx: Optional[bool] = None) -> Array
:param approx: Whether approximation is enabled or not.
:return: True if the value is considered to be false.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.less(x, tol) if approx else jnp.logical_not(x)

Expand All @@ -494,7 +494,7 @@ def true_value(approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: A value that evaluates to true.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.array(1.0) if approx else jnp.array(True)

Expand All @@ -509,6 +509,6 @@ def false_value(approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: A value that evaluates to false.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.array(0.0) if approx else jnp.array(False)
22 changes: 19 additions & 3 deletions differt2d/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,12 @@ def random_uniform_scene(
return cls(emitters=emitters, receivers=receivers, objects=walls)

@classmethod
def basic_scene(cls, *, tx_coords: Array = jnp.array([0.1, 0.1]), rx_coords: Array = jnp.array([0.302, 0.2147])) -> "Scene":
def basic_scene(
cls,
*,
tx_coords: Array = jnp.array([0.1, 0.1]),
rx_coords: Array = jnp.array([0.302, 0.2147]),
) -> "Scene":
"""
Instantiates a basic scene with a main room, and a second inner room in the
lower left corner, with a small entrance.
Expand Down Expand Up @@ -425,7 +430,12 @@ def basic_scene(cls, *, tx_coords: Array = jnp.array([0.1, 0.1]), rx_coords: Arr
return cls(emitters=dict(tx=tx), receivers=dict(rx=rx), objects=walls)

@classmethod
def square_scene(cls, *, tx_coords: Array = jnp.array([0.2, 0.2]), rx_coords: Array = jnp.array([0.5, 0.6])) -> "Scene":
def square_scene(
cls,
*,
tx_coords: Array = jnp.array([0.2, 0.2]),
rx_coords: Array = jnp.array([0.5, 0.6]),
) -> "Scene":
"""
Instantiates a square scene with one main room.
Expand Down Expand Up @@ -469,7 +479,13 @@ def square_scene(cls, *, tx_coords: Array = jnp.array([0.2, 0.2]), rx_coords: Ar
return Scene(emitters=dict(tx=tx), receivers=dict(rx=rx), objects=walls)

@classmethod
def square_scene_with_wall(cls, ratio: float = 0.1, *, tx_coords: Array = jnp.array([0.2, 0.5]), rx_coords: Array = jnp.array([0.8, 0.5])) -> "Scene":
def square_scene_with_wall(
cls,
ratio: float = 0.1,
*,
tx_coords: Array = jnp.array([0.2, 0.5]),
rx_coords: Array = jnp.array([0.8, 0.5]),
) -> "Scene":
"""
Instantiates a square scene with one main room, and vertical wall in the middle.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ If you are intersted in contributing to this tool, please checkout the
ax = plt.gca()
scene = Scene.basic_scene()
X, Y = scene.grid(n=300)
Z = scene.accumulate_on_receivers_grid_over_paths(
X,
Y,
Expand Down
8 changes: 4 additions & 4 deletions docs/source/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pretty useful for optimization problems.
from differt2d.geometry import Wall
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
wall = Wall(points=jnp.array([[0.8, 0.2], [0.8, 0.8]]))
scene.add_objects([wall])
Plotting utils
Expand Down Expand Up @@ -96,7 +96,7 @@ The easiest way to trace all paths from every emitter to every receiver is to us
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
scene.add_objects([wall])
scene.plot(ax)
for _, _, path, _ in scene.all_valid_paths():
path.plot(ax, zorder=-1) # -1 to draw below the scene objects
Expand Down Expand Up @@ -125,7 +125,7 @@ like the received power, on a grid and plot it:
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
scene.add_objects([wall])
scene.plot(ax)
X, Y = scene.grid(n=300)
Z = scene.accumulate_on_receivers_grid_over_paths(
X,
Expand Down Expand Up @@ -157,7 +157,7 @@ if we were to simulate a higher order of interacion, e.g.:
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
scene.add_objects([wall])
scene.plot(ax)
X, Y = scene.grid(n=300)
Z = scene.accumulate_on_receivers_grid_over_paths(
X,
Expand Down
25 changes: 13 additions & 12 deletions tests/test_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
true_value,
)

approx = pytest.mark.parametrize(("approx",), [(True,), (False,)])
approx = pytest.mark.parametrize(("approx",), [(True,), (False,), (None,)])
alpha = pytest.mark.parametrize(
("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)]
)
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_invalid_activation(function):
@approx
def test_logical_or(xy, approx):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.maximum(x, y)
else:
expected = jnp.logical_or(x, y)
Expand All @@ -246,7 +246,7 @@ def test_logical_or(xy, approx):
@approx
def test_logical_and(xy, approx):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.minimum(x, y)
else:
expected = jnp.logical_and(x, y)
Expand All @@ -258,7 +258,7 @@ def test_logical_and(xy, approx):

@approx
def test_logical_not(x, approx):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.subtract(1.0, x)
else:
expected = jnp.logical_not(x)
Expand All @@ -270,7 +270,7 @@ def test_logical_not(x, approx):

@approx
def test_logical_all(x, approx):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
x = jnp.array([0.8, 0.2, 0.3])
expected = jnp.min(x)
else:
Expand All @@ -284,7 +284,7 @@ def test_logical_all(x, approx):

@approx
def test_logical_any(x, approx):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
x = jnp.array([0.8, 0.2, 0.3])
expected = jnp.max(x)
else:
Expand All @@ -301,7 +301,7 @@ def test_logical_any(x, approx):
@function
def test_greater(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(x - y, alpha=alpha, function=function)
else:
expected = jnp.greater(x, y)
Expand All @@ -316,7 +316,7 @@ def test_greater(xy, approx, alpha, function):
@function
def test_greater_equal(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(x - y, alpha=alpha, function=function)
else:
expected = jnp.greater_equal(x, y)
Expand All @@ -331,7 +331,7 @@ def test_greater_equal(xy, approx, alpha, function):
@function
def test_less(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(y - x, alpha=alpha, function=function)
else:
expected = jnp.less(x, y)
Expand All @@ -346,7 +346,7 @@ def test_less(xy, approx, alpha, function):
@function
def test_less_equal(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(y - x, alpha=alpha, function=function)
else:
expected = jnp.less_equal(x, y)
Expand All @@ -359,7 +359,7 @@ def test_less_equal(xy, approx, alpha, function):
@approx
@tol
def test_is_true(x, approx, tol):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.greater(x, 1.0 - tol)
else:
x = jnp.greater(x, 1.0 - tol)
Expand All @@ -373,7 +373,7 @@ def test_is_true(x, approx, tol):
@approx
@tol
def test_is_false(x, approx, tol):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.less(x, tol)
else:
x = jnp.greater(x, tol)
Expand All @@ -394,5 +394,6 @@ def test_true_value(approx, tol):
@approx
@tol
def test_false_value(approx, tol):
print("approx", approx, jax.config.jax_enable_approx)
x = false_value(approx=approx)
assert is_false(x, tol=tol, approx=approx)

0 comments on commit 4bdbf15

Please sign in to comment.