-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
388 additions
and
299 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from .types import ( | ||
AutogradFieldMap, | ||
AutogradTraced, | ||
TracedCoordinate, | ||
TracedFloat, | ||
TracedSize, | ||
TracedSize1D, | ||
TracedVertices, | ||
) | ||
from .utils import get_static | ||
|
||
__all__ = [ | ||
"TracedFloat", | ||
"TracedSize1D", | ||
"TracedSize", | ||
"TracedCoordinate", | ||
"TracedVertices", | ||
"AutogradTraced", | ||
"AutogradFieldMap", | ||
"get_static", | ||
"integrate_within_bounds", | ||
"DerivativeInfo", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# utilities for autograd derivative passing | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pydantic.v1 as pd | ||
import xarray as xr | ||
|
||
from ..base import Tidy3dBaseModel | ||
from ..data.data_array import ScalarFieldDataArray | ||
from ..types import Bound, tidycomplex | ||
from .types import PathType | ||
from .utils import get_static | ||
|
||
# we do this because importing these creates circular imports | ||
FieldData = dict[str, ScalarFieldDataArray] | ||
PermittivityData = dict[str, ScalarFieldDataArray] | ||
|
||
|
||
class DerivativeInfo(Tidy3dBaseModel): | ||
"""Stores derivative information passed to the ``.compute_derivatives`` methods.""" | ||
|
||
paths: list[PathType] = pd.Field( | ||
..., | ||
title="Paths to Traced Fields", | ||
description="List of paths to the traced fields that need derivatives calculated.", | ||
) | ||
|
||
E_der_map: FieldData = pd.Field( | ||
..., | ||
title="Electric Field Gradient Map", | ||
description='Dataset where the field components ``("Ex", "Ey", "Ez")`` store the ' | ||
"multiplication of the forward and adjoint electric fields. The tangential components " | ||
"of this dataset is used when computing adjoint gradients for shifting boundaries. " | ||
"All components are used when computing volume-based gradients.", | ||
) | ||
|
||
D_der_map: FieldData = pd.Field( | ||
..., | ||
title="Displacement Field Gradient Map", | ||
description='Dataset where the field components ``("Ex", "Ey", "Ez")`` store the ' | ||
"multiplication of the forward and adjoint displacement fields. The normal component " | ||
"of this dataset is used when computing adjoint gradients for shifting boundaries.", | ||
) | ||
|
||
eps_data: PermittivityData = pd.Field( | ||
..., | ||
title="Permittivity Dataset", | ||
description="Dataset of relative permittivity values along all three dimensions. " | ||
"Used for automatically computing permittivity inside or outside of a simple geometry.", | ||
) | ||
|
||
eps_in: tidycomplex = pd.Field( | ||
title="Permittivity Inside", | ||
description="Permittivity inside of the ``Structure``. " | ||
"Typically computed from ``Structure.medium.eps_model``." | ||
"Used when it can not be computed from ``eps_data`` or when ``eps_approx==True``.", | ||
) | ||
|
||
eps_out: tidycomplex = pd.Field( | ||
..., | ||
title="Permittivity Outside", | ||
description="Permittivity outside of the ``Structure``. " | ||
"Typically computed from ``Simulation.medium.eps_model``." | ||
"Used when it can not be computed from ``eps_data`` or when ``eps_approx==True``.", | ||
) | ||
|
||
bounds: Bound = pd.Field( | ||
..., | ||
title="Geometry Bounds", | ||
description="Bounds corresponding to the structure, used in ``Medium`` calculations.", | ||
) | ||
|
||
eps_approx: bool = pd.Field( | ||
False, | ||
title="Use Permittivity Approximation", | ||
description="If ``True``, approximates outside permittivity using ``Simulation.medium``" | ||
"and the inside permittivity using ``Structure.medium``. " | ||
"Only set ``True`` for ``GeometryGroup`` handling where it is difficult to automatically " | ||
"evaluate the inside and outside relative permittivity for each geometry.", | ||
) | ||
|
||
def updated_paths(self, paths: list[PathType]) -> DerivativeInfo: | ||
"""Update this ``DerivativeInfo`` with new set of paths.""" | ||
return self.updated_copy(paths=paths) | ||
|
||
|
||
# TODO: could we move this into a DataArray method? | ||
def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: | ||
"""integrate a data array within bounds, assumes bounds are [2, N] for N dims.""" | ||
|
||
_arr = arr.copy() | ||
|
||
# order bounds with dimension first (N, 2) | ||
bounds = np.array(bounds).T | ||
|
||
all_coords = {} | ||
|
||
# loop over all dimensions | ||
for dim, (bmin, bmax) in zip(dims, bounds): | ||
bmin = get_static(bmin) | ||
bmax = get_static(bmax) | ||
|
||
coord_values = _arr.coords[dim].values | ||
|
||
# reset all coordinates outside of bounds to the bounds, so that dL = 0 in integral | ||
coord_values[coord_values < bmin] = bmin | ||
coord_values[coord_values > bmax] = bmax | ||
|
||
all_coords[dim] = coord_values | ||
|
||
_arr = _arr.assign_coords(**all_coords) | ||
|
||
# uses trapezoidal rule | ||
# https://docs.xarray.dev/en/stable/generated/xarray.DataArray.integrate.html | ||
return _arr.integrate(coord=dims) | ||
|
||
|
||
__all__ = [ | ||
"integrate_within_bounds", | ||
"DerivativeInfo", | ||
] |
Oops, something went wrong.