Skip to content

Commit

Permalink
feat(lib): add opt. y axis size in Interactable.grid (#82)
Browse files Browse the repository at this point in the history
* feat(lib): add opt. second axis size in `Interactable.grid`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(lib): add `Interactable.sample` method (#80)

* feat(lib): add `Interactable.sample` method

* fix(docs): typo

* feat(lib): add opt. second axis size in `Interactable.grid`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore(docs): add changelog entry

* fix(lib): type hint

* fix: up

* fix: tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jeertmans and pre-commit-ci[bot] authored Sep 6, 2024
1 parent 10e0d65 commit 6b836c3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

+ Added `Interactable.sample` method to randomly sample a point on an object.
[#80](https://github.com/jeertmans/DiffeRT2d/pull/80)
+ Added optional y-axis size in `Interactable.grid`.
[#82](https://github.com/jeertmans/DiffeRT2d/pull/82)

(unreleased-fixed)=
### Fixed
Expand Down
16 changes: 12 additions & 4 deletions differt2d/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,24 @@ def bounding_box(self) -> Float[Array, "2 2"]:
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def grid(
self, n: int = 50
) -> tuple[Float[Array, "{n} {n}"], Float[Array, "{n} {n}"]]:
self,
m: int = 50,
n: Optional[int] = None,
) -> tuple[Float[Array, "n_or_m {m}"], Float[Array, "n_or_m {m}"]]:
"""
Returns a (mesh) grid that overlays the current object.
:param n: The number of sample along one axis.
:param m: The number of sample along x dimension.
:param n: The number of sample along y dimension,
defaults to ``m`` is left unspecified.
:return: A tuple of (X, Y) coordinates.
"""
bounding_box = self.bounding_box()
x = jnp.linspace(bounding_box[0, 0], bounding_box[1, 0], n)

if n is None:
n = m

x = jnp.linspace(bounding_box[0, 0], bounding_box[1, 0], m)
y = jnp.linspace(bounding_box[0, 1], bounding_box[1, 1], n)

X, Y = jnp.meshgrid(x, y)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ def test_grid(self):
assert float(Y.min()) == 0.0
assert float(Y.max()) == 2.0

X, Y = wall.grid(25, 50)

assert X.shape == (50, 25)
assert Y.shape == (50, 25)
assert float(X.min()) == 0.0
assert float(X.max()) == 1.0
assert float(Y.min()) == 0.0
assert float(Y.max()) == 2.0

def test_center(self):
wall = Wall(xys=jnp.array([[0.0, 1.0], [1.0, 2.0]]))

Expand Down

0 comments on commit 6b836c3

Please sign in to comment.