Skip to content

Commit

Permalink
Add support for dilation in JaxPolySlab
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Aug 21, 2024
1 parent 9e352f3 commit c576119
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 21 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added value_and_grad function to the autograd plugin, importable via `from tidy3d.plugins.autograd import value_and_grad`. Supports differentiating functions with auxiliary data (`value_and_grad(f, has_aux=True)`).
- Support for `dilation` argument in `JaxPolySlab`.

### Fixed
- `DataArray` interpolation failure due to incorrect ordering of coordinates when interpolating with autograd tracers.
- Error in `CustomSourceTime` when evaluating at a list of times entirely outside of the range of the envelope definition times.

### Changed
- `PolySlab` now raises error when differentiating and dilation causes damage to the polygon.

## [2.7.2] - 2024-08-07

### Added
Expand Down
50 changes: 38 additions & 12 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,17 +2034,21 @@ def test_to_gds(tmp_path):
[(0, 0), (1, 0), (1, 1), (0, 1), (0, 0.9), (0, 0.11)], # notched rectangle
],
)
@pytest.mark.parametrize("subdivide", [0, 1, 5])
@pytest.mark.parametrize("subdivide", [0, 1, 3])
@pytest.mark.parametrize("sidewall_angle_deg", [0, 10])
@pytest.mark.parametrize("dilation", [-0.02, 0.0, 0.02])
class TestJaxComplexPolySlab:
slab_bounds = (-0.25, 0.25)
EPS = 1e-12
RTOL = 1e-2

@staticmethod
def objfun(vertices, slab_bounds, sidewall_angle):
def objfun(vertices, slab_bounds, sidewall_angle, dilation):
p = JaxComplexPolySlab(
vertices=vertices, slab_bounds=slab_bounds, sidewall_angle=sidewall_angle
vertices=vertices,
slab_bounds=slab_bounds,
sidewall_angle=sidewall_angle,
dilation=dilation,
)
obj = 0.0
for s in p.sub_polyslabs:
Expand Down Expand Up @@ -2072,11 +2076,12 @@ def vertices(self, base_vertices, subdivide):
def sidewall_angle(self, sidewall_angle_deg):
return np.deg2rad(sidewall_angle_deg)

def test_matches_complexpolyslab(self, vertices, sidewall_angle):
def test_matches_complexpolyslab(self, vertices, sidewall_angle, dilation):
kwargs = dict(
vertices=vertices,
sidewall_angle=sidewall_angle,
slab_bounds=self.slab_bounds,
dilation=dilation,
axis=POLYSLAB_AXIS,
)
cp = ComplexPolySlab(**kwargs)
Expand All @@ -2087,23 +2092,44 @@ def test_matches_complexpolyslab(self, vertices, sidewall_angle):
for cps, jcps in zip(cp.sub_polyslabs, jcp.sub_polyslabs):
assert_allclose(cps.vertices, jcps.vertices)

def test_vertices_grads(self, vertices, sidewall_angle):
def test_vertices_grads(self, vertices, sidewall_angle, dilation):
check_grads(
lambda x: self.objfun(x, self.slab_bounds, sidewall_angle),
lambda x: self.objfun(x, self.slab_bounds, sidewall_angle, dilation),
(vertices,),
order=1,
rtol=self.RTOL,
eps=self.EPS,
)

@pytest.mark.skip(reason="No VJP implemented yet")
def test_slab_bounds_grads(self, vertices, sidewall_angle):
def test_dilation_grads(self, vertices, sidewall_angle, dilation):
if sidewall_angle != 0:
pytest.xfail("Dilation gradients only work if no sidewall angle.")
check_grads(
lambda x: self.objfun(vertices, self.slab_bounds, sidewall_angle, x),
(dilation,),
order=1,
rtol=self.RTOL,
eps=self.EPS,
)

def test_slab_bounds_grads(self, vertices, sidewall_angle, dilation):
if sidewall_angle != 0:
pytest.xfail("Slab bound gradients only work if no sidewall angle.")
check_grads(
lambda x: self.objfun(vertices, x, sidewall_angle), (self.slab_bounds,), order=1
lambda x: self.objfun(vertices, x, sidewall_angle, dilation),
(self.slab_bounds,),
order=1,
rtol=self.RTOL,
eps=self.EPS,
)

@pytest.mark.skip(reason="No VJP implemented yet")
def test_sidewall_angle_grads(self, vertices, sidewall_angle):
def test_sidewall_angle_grads(self, vertices, sidewall_angle, dilation):
if sidewall_angle != 0:
pytest.xfail("Sidewall gradients only work for small angles.")
check_grads(
lambda x: self.objfun(vertices, self.slab_bounds, x), (sidewall_angle,), order=1
lambda x: self.objfun(vertices, self.slab_bounds, x, dilation),
(sidewall_angle,),
order=1,
rtol=self.RTOL,
eps=self.EPS,
)
2 changes: 1 addition & 1 deletion tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,7 @@ def _heal_polygon(vertices: np.ndarray) -> np.ndarray:
return vertices
elif isbox(vertices):
raise NotImplementedError(
"It looks like the dilation causes damage to the Polygon. "
"The dilation caused damage to the polygon. "
"Automatically healing this is currently not supported when "
"differentiating w.r.t. the vertices. Try increasing the spacing "
"between vertices or reduce the amount of dilation."
Expand Down
31 changes: 23 additions & 8 deletions tidy3d/plugins/adjoint/components/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ class JaxPolySlab(JaxGeometry, PolySlab, JaxObject):
stores_jax_for="sidewall_angle",
)

dilation_jax: JaxFloat = pd.Field(
default=0.0,
title="Dilation (Jax)",
description="Jax-traced float defining the dilation.",
units=MICROMETER,
stores_jax_for="dilation",
)

@pd.validator("sidewall_angle", always=True)
def no_sidewall(cls, val):
"""Warn if sidewall angle present."""
Expand All @@ -320,13 +328,6 @@ def no_sidewall(cls, val):
)
return val

@pd.validator("dilation", always=True)
def no_dilation(cls, val):
"""Don't allow dilation."""
if not np.isclose(val, 0.0):
raise AdjointError("'JaxPolySlab' does not support dilation.")
return val

def _validate_web_adjoint(self) -> None:
"""Run validators for this component, only if using ``tda.web.run()``."""
self._limit_number_of_vertices()
Expand Down Expand Up @@ -526,6 +527,19 @@ def _proper_vertices(vertices: ArrayFloat2D) -> jnp.ndarray:
vertices_np = JaxPolySlab.vertices_to_array(vertices)
return JaxPolySlab._orient(JaxPolySlab._remove_duplicate_vertices(vertices_np))

@staticmethod
def _heal_polygon(vertices: jnp.ndarray) -> jnp.ndarray:
"""heal a self-intersecting polygon."""
shapely_poly = PolySlab.make_shapely_polygon(jax.lax.stop_gradient(vertices))
if shapely_poly.is_valid:
return vertices

raise NotImplementedError(
"The dilation caused damage to the polygon. Automatically healing this is "
"currently not supported for 'JaxPolySlab' objects. Try increasing the spacing "
"between vertices or reduce the amount of dilation."
)

@staticmethod
def vertices_to_array(vertices_tuple: ArrayFloat2D) -> jnp.ndarray:
"""Converts a list of tuples (vertices) to a jax array."""
Expand All @@ -543,7 +557,8 @@ def reference_polygon(self) -> jnp.ndarray:
vertices = JaxPolySlab._proper_vertices(self.vertices_jax)
if jnp.isclose(self.dilation, 0):
return vertices
raise NotImplementedError("JaxPolySlab does not support dilation!")
offset_vertices = self._shift_vertices(vertices, self.dilation)[0]
return self._heal_polygon(offset_vertices)

def edge_contrib(
self,
Expand Down

0 comments on commit c576119

Please sign in to comment.