diff --git a/CHANGELOG.md b/CHANGELOG.md index 08c544238..7ceee11da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Adjoint simulations no longer contain unused gradient permittivity monitors, reducing processing time. - `Batch` prints total estimated cost if `verbose=True`. - Unified config and authentication. +- Remove restriction that `JaxCustomMedium` must not be a 3D pixelated array. ### Fixed - Plotting 2D materials in `SimulationData.plot_field` and other circumstances. diff --git a/tests/test_plugins/test_adjoint.py b/tests/test_plugins/test_adjoint.py index 9535e0daf..4a0bf4390 100644 --- a/tests/test_plugins/test_adjoint.py +++ b/tests/test_plugins/test_adjoint.py @@ -1187,8 +1187,9 @@ def test_validate_vertices(): poly = JaxPolySlab(vertices=vertices, slab_bounds=(-1, 1)) -def test_custom_medium_3D(use_emulated_run): +def _test_custom_medium_3D(use_emulated_run): """Ensure custom medium fails if 3D pixelated grid.""" + # NOTE: turned off since we relaxed this restriction jax_box = JaxBox(size=(1, 1, 1), center=(0, 0, 0)) diff --git a/tidy3d/plugins/adjoint/components/medium.py b/tidy3d/plugins/adjoint/components/medium.py index 1ab99d676..3e7360de4 100644 --- a/tidy3d/plugins/adjoint/components/medium.py +++ b/tidy3d/plugins/adjoint/components/medium.py @@ -308,23 +308,22 @@ def _deprecation_dataset(cls, values): """Raise deprecation warning if dataset supplied and convert to dataset.""" return values - @pd.validator("eps_dataset", always=True) - def _is_not_3d(cls, val): - """Ensure the custom medium pixels contain at least one dimension with only pixel thick.""" - - for field_dim in "xyz": - field_name = f"eps_{field_dim}{field_dim}" - data_array = val.field_components[field_name] - coord_lens = [len(data_array.coords[key]) for key in "xyz"] - dims_len1 = [val == 1 for val in coord_lens] - if sum(dims_len1) == 0: - raise SetupError( - "For adjoint plugin, the 'JaxCustomMedium' is restricted to a 1D or 2D " - "pixellated grid. It may not contain multiple pixels along all 3 dimensions. " - f"Detected 3D pixelated grid in '{field_name}' component of 'eps_dataset'." - ) - - return val + # @pd.validator("eps_dataset", always=True) + # def _is_not_3d(cls, val): + # """Ensure the custom medium pixels contain at least one dimension with one pixel thick.""" + # for field_dim in "xyz": + # field_name = f"eps_{field_dim}{field_dim}" + # data_array = val.field_components[field_name] + # coord_lens = [len(data_array.coords[key]) for key in "xyz"] + # dims_len1 = [val == 1 for val in coord_lens] + # if sum(dims_len1) == 0: + # raise SetupError( + # "For adjoint plugin, the 'JaxCustomMedium' is restricted to a 1D or 2D " + # "pixellated grid. It may not contain multiple pixels along all 3 dimensions. " + # f"Detected 3D pixelated grid in '{field_name}' component of 'eps_dataset'." + # ) + + # return val @pd.validator("eps_dataset", always=True) def _is_not_too_large(cls, val):