Skip to content

Commit

Permalink
Simplify interp validator
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Dec 3, 2024
1 parent fdfc30d commit 6fe6bd8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 35 deletions.
36 changes: 36 additions & 0 deletions tests/test_components/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,39 @@ def test_warn_diffraction_monitor_intersection(log_capture):
)
sim.updated_copy(structures=[box])
assert_log_level(log_capture, "WARNING")


@pytest.mark.parametrize(
    "custom_class, data_key",
    [
        (CustomMedium, "permittivity"),
        (td.CustomFieldSource, "field_dataset"),
        (td.CustomCurrentSource, "current_dataset"),
    ],
)
def test_custom_medium_duplicate_coords(custom_class, data_key):
    """Test that creating components with duplicate coordinates raises validation error."""
    # The x axis repeats the value 1.0, which the coordinate validator must reject.
    coords = dict(
        x=np.array([0.0, 1.0, 1.0, 2.0]),
        y=np.array([0.0, 1.0]),
        z=np.array([0.0, 1.0]),
    )

    is_medium = custom_class == CustomMedium
    if not is_medium:
        # Custom sources additionally carry a frequency coordinate.
        coords["f"] = np.array([2e14])

    # Random positive permittivity-like values shaped to match the coordinates.
    values = 1 + np.random.random(tuple(len(axis) for axis in coords.values()))
    spatial_data = td.SpatialDataArray(values, coords=coords)

    if is_medium:
        with pytest.raises(pydantic.ValidationError, match="duplicate coordinates"):
            _ = custom_class(permittivity=spatial_data)
        return

    # Build a full E/H field dataset sharing the duplicated coordinates.
    field_components = {}
    for field in "EH":
        for component in "xyz":
            field_components[f"{field}{component}"] = spatial_data.copy()
    field_dataset = td.FieldDataset(**field_components)

    with pytest.raises(pydantic.ValidationError, match="duplicate coordinates"):
        _ = custom_class(size=SIZE, source_time=ST, **{data_key: field_dataset})
43 changes: 8 additions & 35 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import dask
import h5py
import numpy as np
import pandas
import xarray as xr
from autograd.tracer import isbox
from xarray.core import alignment, missing
Expand Down Expand Up @@ -119,7 +118,7 @@ def assign_data_attrs(cls, val):
return val

def _interp_validator(self, field_name: str = None) -> None:
    """Ensure the data can be interpolated or selected by checking for duplicate coordinates.

    ``interp()`` and ``sel()`` fail on an index that contains duplicate
    entries, so instead of attempting trial interpolations along every
    dimension we directly reject any dimension whose coordinate index has
    duplicates — a single O(n) pass per dimension.

    Parameters
    ----------
    field_name : str = None
        Name used to identify the offending field in the error message.
        Defaults to ``"DataArray"``.

    Raises
    ------
    DataError
        If any coordinate dimension contains duplicate values.
    """
    if field_name is None:
        field_name = "DataArray"

    for dim, coord in self.coords.items():
        # pandas Index.duplicated() flags repeated entries; any() collapses
        # the boolean mask to a single pass/fail per dimension.
        if coord.to_index().duplicated().any():
            raise DataError(
                f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. "
                "Duplicates can be removed by running "
                f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'."
            )

@classmethod
def assign_coord_attrs(cls, val):
Expand Down

0 comments on commit 6fe6bd8

Please sign in to comment.