Skip to content

Commit

Permalink
remove restriction on 3D pixelated grids
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Jun 23, 2023
1 parent 5dee7f2 commit 990441f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adjoint simulations no longer contain unused gradient permittivity monitors, reducing processing time.
- `Batch` prints total estimated cost if `verbose=True`.
- Unified config and authentication.
- Remove restriction that `JaxCustomMedium` must not be a 3D pixelated array.

### Fixed
- Plotting 2D materials in `SimulationData.plot_field` and other circumstances.
Expand Down
3 changes: 2 additions & 1 deletion tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,8 +1187,9 @@ def test_validate_vertices():
poly = JaxPolySlab(vertices=vertices, slab_bounds=(-1, 1))


def test_custom_medium_3D(use_emulated_run):
def _test_custom_medium_3D(use_emulated_run):
"""Ensure custom medium fails if 3D pixelated grid."""
# NOTE: turned off since we relaxed this restriction

jax_box = JaxBox(size=(1, 1, 1), center=(0, 0, 0))

Expand Down
33 changes: 16 additions & 17 deletions tidy3d/plugins/adjoint/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,23 +308,22 @@ def _deprecation_dataset(cls, values):
"""Raise deprecation warning if dataset supplied and convert to dataset."""
return values

@pd.validator("eps_dataset", always=True)
def _is_not_3d(cls, val):
"""Ensure the custom medium pixels contain at least one dimension with only pixel thick."""

for field_dim in "xyz":
field_name = f"eps_{field_dim}{field_dim}"
data_array = val.field_components[field_name]
coord_lens = [len(data_array.coords[key]) for key in "xyz"]
dims_len1 = [val == 1 for val in coord_lens]
if sum(dims_len1) == 0:
raise SetupError(
"For adjoint plugin, the 'JaxCustomMedium' is restricted to a 1D or 2D "
"pixellated grid. It may not contain multiple pixels along all 3 dimensions. "
f"Detected 3D pixelated grid in '{field_name}' component of 'eps_dataset'."
)

return val
# @pd.validator("eps_dataset", always=True)
# def _is_not_3d(cls, val):
# """Ensure the custom medium pixels contain at least one dimension with one pixel thick."""
# for field_dim in "xyz":
# field_name = f"eps_{field_dim}{field_dim}"
# data_array = val.field_components[field_name]
# coord_lens = [len(data_array.coords[key]) for key in "xyz"]
# dims_len1 = [val == 1 for val in coord_lens]
# if sum(dims_len1) == 0:
# raise SetupError(
# "For adjoint plugin, the 'JaxCustomMedium' is restricted to a 1D or 2D "
# "pixellated grid. It may not contain multiple pixels along all 3 dimensions. "
# f"Detected 3D pixelated grid in '{field_name}' component of 'eps_dataset'."
# )

# return val

@pd.validator("eps_dataset", always=True)
def _is_not_too_large(cls, val):
Expand Down

0 comments on commit 990441f

Please sign in to comment.