From 4e4fd3109b020aaa51ae75eb070a84dbb1dff2a2 Mon Sep 17 00:00:00 2001 From: Yannick Augenstein Date: Sun, 18 Aug 2024 15:33:20 +0200 Subject: [PATCH] Add support for dilation in JaxPolySlab --- CHANGELOG.md | 4 ++++ tests/test_plugins/test_adjoint.py | 21 ++++++++++------- tidy3d/components/geometry/polyslab.py | 2 +- tidy3d/plugins/adjoint/components/geometry.py | 23 ++++++++++++------- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40a12afb8d..b4720cb1bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,11 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added value_and_grad function to the autograd plugin, importable via `from tidy3d.plugins.autograd import value_and_grad`. Supports differentiating functions with auxiliary data (`value_and_grad(f, has_aux=True)`). +- Support for `dilation` argument in `JaxPolySlab`. ### Fixed - `DataArray` interpolation failure due to incorrect ordering of coordinates when interpolating with autograd tracers. - Error in `CustomSourceTime` when evaluating at a list of times entirely outside of the range of the envelope definition times. +### Changed +- `PolySlab` now raises error when differentiating and dilation causes damage to the polygon. + ## [2.7.2] - 2024-08-07 ### Added diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 9c8807f045..f320164f32 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -2034,17 +2034,21 @@ def test_to_gds(tmp_path): [(0, 0), (1, 0), (1, 1), (0, 1), (0, 0.9), (0, 0.11)], # notched rectangle ], ) -@pytest.mark.parametrize("subdivide", [0, 1, 5]) +@pytest.mark.parametrize("subdivide", [0, 1, 3]) @pytest.mark.parametrize("sidewall_angle_deg", [0, 10]) +@pytest.mark.parametrize("dilation", [-0.02, 0.0, 0.02]) class TestJaxComplexPolySlab: slab_bounds = (-0.25, 0.25) EPS = 1e-12 RTOL = 1e-2 @staticmethod - def objfun(vertices, slab_bounds, sidewall_angle): + def objfun(vertices, slab_bounds, sidewall_angle, dilation): p = JaxComplexPolySlab( - vertices=vertices, slab_bounds=slab_bounds, sidewall_angle=sidewall_angle + vertices=vertices, + slab_bounds=slab_bounds, + sidewall_angle=sidewall_angle, + dilation=dilation, ) obj = 0.0 for s in p.sub_polyslabs: @@ -2072,11 +2076,12 @@ def vertices(self, base_vertices, subdivide): def sidewall_angle(self, sidewall_angle_deg): return np.deg2rad(sidewall_angle_deg) - def test_matches_complexpolyslab(self, vertices, sidewall_angle): + def test_matches_complexpolyslab(self, vertices, sidewall_angle, dilation): kwargs = dict( vertices=vertices, sidewall_angle=sidewall_angle, slab_bounds=self.slab_bounds, + dilation=dilation, axis=POLYSLAB_AXIS, ) cp = ComplexPolySlab(**kwargs) @@ -2087,9 +2092,9 @@ def test_matches_complexpolyslab(self, vertices, sidewall_angle): for cps, jcps in zip(cp.sub_polyslabs, jcp.sub_polyslabs): assert_allclose(cps.vertices, jcps.vertices) - def test_vertices_grads(self, vertices, sidewall_angle): + def test_vertices_grads(self, vertices, sidewall_angle, dilation): check_grads( - lambda x: self.objfun(x, self.slab_bounds, sidewall_angle), + lambda x: self.objfun(x, self.slab_bounds, sidewall_angle, dilation), (vertices,), order=1, rtol=self.RTOL, @@ -2097,13 +2102,13 @@ def test_vertices_grads(self, vertices, sidewall_angle): ) @pytest.mark.skip(reason="No VJP implemented yet") - def test_slab_bounds_grads(self, vertices, sidewall_angle): + def test_slab_bounds_grads(self, vertices, sidewall_angle, dilation): check_grads( lambda x: self.objfun(vertices, x, sidewall_angle), (self.slab_bounds,), order=1 ) @pytest.mark.skip(reason="No VJP implemented yet") - def test_sidewall_angle_grads(self, vertices, sidewall_angle): + def test_sidewall_angle_grads(self, vertices, sidewall_angle, dilation): check_grads( lambda x: self.objfun(vertices, self.slab_bounds, x), (sidewall_angle,), order=1 ) diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index 10cf2aa0c1..ef5e4e4bef 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -1319,7 +1319,7 @@ def _heal_polygon(vertices: np.ndarray) -> np.ndarray: return vertices elif isbox(vertices): raise NotImplementedError( - "It looks like the dilation causes damage to the Polygon. " + "It looks like the dilation causes damage to the polygon. " "Automatically healing this is currently not supported when " "differentiating w.r.t. the vertices. Try increasing the spacing " "between vertices or reduce the amount of dilation." diff --git a/tidy3d/plugins/adjoint/components/geometry.py b/tidy3d/plugins/adjoint/components/geometry.py index d4aa0351f1..1878063da5 100644 --- a/tidy3d/plugins/adjoint/components/geometry.py +++ b/tidy3d/plugins/adjoint/components/geometry.py @@ -320,13 +320,6 @@ def no_sidewall(cls, val): ) return val - @pd.validator("dilation", always=True) - def no_dilation(cls, val): - """Don't allow dilation.""" - if not np.isclose(val, 0.0): - raise AdjointError("'JaxPolySlab' does not support dilation.") - return val - def _validate_web_adjoint(self) -> None: """Run validators for this component, only if using ``tda.web.run()``.""" self._limit_number_of_vertices() @@ -526,6 +519,19 @@ def _proper_vertices(vertices: ArrayFloat2D) -> jnp.ndarray: vertices_np = JaxPolySlab.vertices_to_array(vertices) return JaxPolySlab._orient(JaxPolySlab._remove_duplicate_vertices(vertices_np)) + @staticmethod + def _heal_polygon(vertices: jnp.ndarray) -> jnp.ndarray: + """heal a self-intersecting polygon.""" + shapely_poly = PolySlab.make_shapely_polygon(jax.lax.stop_gradient(vertices)) + if shapely_poly.is_valid: + return vertices + + raise NotImplementedError( + "It looks like the dilation causes damage to the polygon. Automatically " + "healing this is currently not supported for JAX PolySlabs. Try increasing " + "the spacing between vertices or reduce the amount of dilation." + ) + @staticmethod def vertices_to_array(vertices_tuple: ArrayFloat2D) -> jnp.ndarray: """Converts a list of tuples (vertices) to a jax array.""" @@ -543,7 +549,8 @@ def reference_polygon(self) -> jnp.ndarray: vertices = JaxPolySlab._proper_vertices(self.vertices_jax) if jnp.isclose(self.dilation, 0): return vertices - raise NotImplementedError("JaxPolySlab does not support dilation!") + offset_vertices = self._shift_vertices(vertices, self.dilation)[0] + return self._heal_polygon(offset_vertices) def edge_contrib( self,