Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: cover when approx is None #31

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions differt2d/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def logical_or(x: Array, y: Array, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.maximum(x, y) if approx else jnp.logical_or(x, y)

Expand All @@ -277,7 +277,7 @@ def logical_and(x: Array, y: Array, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.minimum(x, y) if approx else jnp.logical_and(x, y)

Expand All @@ -295,7 +295,7 @@ def logical_not(x: Array, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.subtract(1.0, x) if approx else jnp.logical_not(x)

Expand All @@ -321,7 +321,7 @@ def greater(
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y)

Expand All @@ -344,7 +344,7 @@ def greater_equal(
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return (
activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y)
Expand All @@ -367,7 +367,7 @@ def less(x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any) -> Ar
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y)

Expand All @@ -393,7 +393,7 @@ def less_equal(
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y)

Expand All @@ -416,7 +416,7 @@ def logical_all(
:param approx: Whether approximation is enabled or not.
:return: Output array.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
arr = jnp.asarray(x)
return jnp.min(arr, axis=axis) if approx else jnp.all(arr, axis=axis)
Expand All @@ -440,7 +440,7 @@ def logical_any(
:param approx: Whether approximation is enabled or not.
:return: Output array.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
arr = jnp.asarray(x)
return jnp.max(arr, axis=axis) if approx else jnp.any(arr, axis=axis)
Expand All @@ -460,7 +460,7 @@ def is_true(x: Array, tol: float = 0.5, approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: True array if the value is considered to be true.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.greater(x, 1.0 - tol) if approx else jnp.asarray(x)

Expand All @@ -479,7 +479,7 @@ def is_false(x: Array, tol: float = 0.5, approx: Optional[bool] = None) -> Array
:param approx: Whether approximation is enabled or not.
:return: True if the value is considered to be false.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.less(x, tol) if approx else jnp.logical_not(x)

Expand All @@ -494,7 +494,7 @@ def true_value(approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: A value that evaluates to true.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.array(1.0) if approx else jnp.array(True)

Expand All @@ -509,6 +509,6 @@ def false_value(approx: Optional[bool] = None) -> Array:
:param approx: Whether approximation is enabled or not.
:return: A value that evaluates to false.
"""
if approx is None: # pragma: no cover
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.array(0.0) if approx else jnp.array(False)
22 changes: 19 additions & 3 deletions differt2d/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,12 @@ def random_uniform_scene(
return cls(emitters=emitters, receivers=receivers, objects=walls)

@classmethod
def basic_scene(cls, *, tx_coords: Array = jnp.array([0.1, 0.1]), rx_coords: Array = jnp.array([0.302, 0.2147])) -> "Scene":
def basic_scene(
cls,
*,
tx_coords: Array = jnp.array([0.1, 0.1]),
rx_coords: Array = jnp.array([0.302, 0.2147]),
) -> "Scene":
"""
Instantiates a basic scene with a main room, and a second inner room in the
lower left corner, with a small entrance.
Expand Down Expand Up @@ -425,7 +430,12 @@ def basic_scene(cls, *, tx_coords: Array = jnp.array([0.1, 0.1]), rx_coords: Arr
return cls(emitters=dict(tx=tx), receivers=dict(rx=rx), objects=walls)

@classmethod
def square_scene(cls, *, tx_coords: Array = jnp.array([0.2, 0.2]), rx_coords: Array = jnp.array([0.5, 0.6])) -> "Scene":
def square_scene(
cls,
*,
tx_coords: Array = jnp.array([0.2, 0.2]),
rx_coords: Array = jnp.array([0.5, 0.6]),
) -> "Scene":
"""
Instantiates a square scene with one main room.
Expand Down Expand Up @@ -469,7 +479,13 @@ def square_scene(cls, *, tx_coords: Array = jnp.array([0.2, 0.2]), rx_coords: Ar
return Scene(emitters=dict(tx=tx), receivers=dict(rx=rx), objects=walls)

@classmethod
def square_scene_with_wall(cls, ratio: float = 0.1, *, tx_coords: Array = jnp.array([0.2, 0.5]), rx_coords: Array = jnp.array([0.8, 0.5])) -> "Scene":
def square_scene_with_wall(
cls,
ratio: float = 0.1,
*,
tx_coords: Array = jnp.array([0.2, 0.5]),
rx_coords: Array = jnp.array([0.8, 0.5]),
) -> "Scene":
"""
Instantiates a square scene with one main room, and vertical wall in the middle.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ If you are intersted in contributing to this tool, please checkout the
ax = plt.gca()
scene = Scene.basic_scene()
X, Y = scene.grid(n=300)
Z = scene.accumulate_on_receivers_grid_over_paths(
X,
Y,
Expand Down
8 changes: 4 additions & 4 deletions docs/source/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pretty useful for optimization problems.
from differt2d.geometry import Wall
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
wall = Wall(points=jnp.array([[0.8, 0.2], [0.8, 0.8]]))
scene.add_objects([wall])
Plotting utils
Expand Down Expand Up @@ -96,7 +96,7 @@ The easiest way to trace all paths from every emitter to every receiver is to us
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
scene.add_objects([wall])
scene.plot(ax)
for _, _, path, _ in scene.all_valid_paths():
path.plot(ax, zorder=-1) # -1 to draw below the scene objects
Expand Down Expand Up @@ -125,7 +125,7 @@ like the received power, on a grid and plot it:
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
scene.add_objects([wall])
scene.plot(ax)
X, Y = scene.grid(n=300)
Z = scene.accumulate_on_receivers_grid_over_paths(
X,
Expand Down Expand Up @@ -157,7 +157,7 @@ if we were to simulate a higher order of interacion, e.g.:
wall = Wall(points=jnp.array([[.8, .2], [.8, .8]]))
scene.add_objects([wall])
scene.plot(ax)
X, Y = scene.grid(n=300)
Z = scene.accumulate_on_receivers_grid_over_paths(
X,
Expand Down
25 changes: 13 additions & 12 deletions tests/test_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
true_value,
)

approx = pytest.mark.parametrize(("approx",), [(True,), (False,)])
approx = pytest.mark.parametrize(("approx",), [(True,), (False,), (None,)])
alpha = pytest.mark.parametrize(
("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)]
)
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_invalid_activation(function):
@approx
def test_logical_or(xy, approx):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.maximum(x, y)
else:
expected = jnp.logical_or(x, y)
Expand All @@ -246,7 +246,7 @@ def test_logical_or(xy, approx):
@approx
def test_logical_and(xy, approx):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.minimum(x, y)
else:
expected = jnp.logical_and(x, y)
Expand All @@ -258,7 +258,7 @@ def test_logical_and(xy, approx):

@approx
def test_logical_not(x, approx):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.subtract(1.0, x)
else:
expected = jnp.logical_not(x)
Expand All @@ -270,7 +270,7 @@ def test_logical_not(x, approx):

@approx
def test_logical_all(x, approx):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
x = jnp.array([0.8, 0.2, 0.3])
expected = jnp.min(x)
else:
Expand All @@ -284,7 +284,7 @@ def test_logical_all(x, approx):

@approx
def test_logical_any(x, approx):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
x = jnp.array([0.8, 0.2, 0.3])
expected = jnp.max(x)
else:
Expand All @@ -301,7 +301,7 @@ def test_logical_any(x, approx):
@function
def test_greater(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(x - y, alpha=alpha, function=function)
else:
expected = jnp.greater(x, y)
Expand All @@ -316,7 +316,7 @@ def test_greater(xy, approx, alpha, function):
@function
def test_greater_equal(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(x - y, alpha=alpha, function=function)
else:
expected = jnp.greater_equal(x, y)
Expand All @@ -331,7 +331,7 @@ def test_greater_equal(xy, approx, alpha, function):
@function
def test_less(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(y - x, alpha=alpha, function=function)
else:
expected = jnp.less(x, y)
Expand All @@ -346,7 +346,7 @@ def test_less(xy, approx, alpha, function):
@function
def test_less_equal(xy, approx, alpha, function):
x, y = xy
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = activation(y - x, alpha=alpha, function=function)
else:
expected = jnp.less_equal(x, y)
Expand All @@ -359,7 +359,7 @@ def test_less_equal(xy, approx, alpha, function):
@approx
@tol
def test_is_true(x, approx, tol):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.greater(x, 1.0 - tol)
else:
x = jnp.greater(x, 1.0 - tol)
Expand All @@ -373,7 +373,7 @@ def test_is_true(x, approx, tol):
@approx
@tol
def test_is_false(x, approx, tol):
if approx:
if approx or (approx is None and jax.config.jax_enable_approx):
expected = jnp.less(x, tol)
else:
x = jnp.greater(x, tol)
Expand All @@ -394,5 +394,6 @@ def test_true_value(approx, tol):
@approx
@tol
def test_false_value(approx, tol):
print("approx", approx, jax.config.jax_enable_approx)
x = false_value(approx=approx)
assert is_false(x, tol=tol, approx=approx)