Skip to content

Commit

Permalink
Rework autograd handling in DataArray
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Oct 24, 2024
1 parent 9a33406 commit 70bd113
Show file tree
Hide file tree
Showing 16 changed files with 646 additions and 196 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Minor gradient direction and normalization fixes for polyslab, field monitors, and diffraction monitors in autograd.
- Resolved an issue where temporary files for adjoint simulations were not being deleted properly.
- Autograd functions can now be called directly on `DataArray` (e.g., `np.sum(data_array)`) in objective functions.

### Changed
- Improved autograd tracer handling in `DataArray`, resulting in significant speedups for differentiation involving large monitors.

### Fixed
- Resolve several edge cases where autograd boxes were incorrectly converted to numpy arrays.


## [2.7.5] - 2024-10-16

Expand Down
57 changes: 54 additions & 3 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
import autograd.numpy as anp
import matplotlib.pylab as plt
import numpy as np
import numpy.testing as npt
import pytest
import tidy3d as td
import tidy3d.web as web
import xarray as xr
from autograd.test_util import check_grads
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
from tidy3d.components.autograd.utils import is_tidy_box
from tidy3d.components.data.data_array import DataArray
from tidy3d.web import run, run_async
from tidy3d.web.api.autograd.utils import FieldMap

from ..utils import SIM_FULL, AssertLogLevel, run_emulated
from ..utils import SIM_FULL, AssertLogLevel, run_emulated, tracer_arr

""" Test configuration """

Expand Down Expand Up @@ -990,7 +994,7 @@ def objective(args):
data = run(sim, task_name="autograd_test", verbose=False)

if objtype == "flux":
return anp.sum(data[monitor.name].flux.values)
return data[monitor.name].flux.item()
elif objtype == "intensity":
return anp.sum(data.get_intensity(monitor.name).values)

Expand Down Expand Up @@ -1494,6 +1498,53 @@ def objective(params):
)
data = run(sim, task_name="extra_field")
amp = data["mode"].amps.sel(direction="+", f=FREQ0 * 0.9, mode_index=0).values
return abs(anp.squeeze(amp.tolist())) ** 2
return abs(amp.item()) ** 2

g = ag.grad(objective)(params0)


class TestTidyArrayBox:
def test_is_tidy_box(self):
da = DataArray(tracer_arr, dims=map(str, range(tracer_arr.ndim)))
assert is_tidy_box(da.data)

def test_real(self):
npt.assert_allclose(tracer_arr.real._value, tracer_arr._value.real)

def test_imag(self):
npt.assert_allclose(tracer_arr.imag._value, tracer_arr._value.imag)

def test_conj(self):
npt.assert_allclose(tracer_arr.conj()._value, tracer_arr._value.conj())

def test_item(self):
assert tracer_arr.item() == tracer_arr._value.item()


class TestDataArrayGrads:
@pytest.mark.parametrize("attr", ["real", "imag", "conj"])
def test_custom_methods_grads(self, attr):
"""Test grads of TidyArrayBox methods implemented in autograd/boxes.py"""

def objective(x, attr):
da = DataArray(x, dims=map(str, range(x.ndim)))
attr_value = getattr(da, attr)
val = attr_value() if callable(attr_value) else attr_value
return val.item()

x = np.array([1.0])
check_grads(objective, modes=["fwd", "rev"], order=2)(x, attr)

def test_multiply_at_grads(self, rng):
"""Test grads of DataArray.multiply_at method"""

def objective(a, b):
coords = {str(i): np.arange(a.shape[i]) for i in range(a.ndim)}
da = DataArray(a, coords=coords, dims=map(str, range(a.ndim)))
da_mult = da.multiply_at(b, "0", [0, 1]) ** 2
return np.sum(da_mult).item()

a = rng.uniform(-1, 1, (3, 3))
b = 1.0
check_grads(lambda x: objective(x, b), modes=["fwd", "rev"], order=1)(a)
check_grads(lambda x: objective(a, x), modes=["fwd", "rev"], order=1)(b)
2 changes: 1 addition & 1 deletion tests/test_components/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def f(x):
structure = td.Structure.from_permittivity_array(
geometry=box, eps_data=eps_data, name="test"
)
return anp.sum(structure.medium.permittivity.attrs["AUTOGRAD"])
return anp.sum(structure.medium.permittivity).item()

grad = ag.grad(f)(1.0)
assert not np.isclose(grad, 0.0)
15 changes: 15 additions & 0 deletions tests/test_data/test_data_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest
import tidy3d as td
import xarray.testing as xrt
from tidy3d.exceptions import DataError

np.random.seed(4)
Expand Down Expand Up @@ -404,3 +405,17 @@ def test_uniform_check():
coords=dict(x=[0, 1], y=[1, 2], z=[2, 3]),
)
assert not arr.is_uniform


@pytest.mark.parametrize("method", ["nearest", "linear"])
@pytest.mark.parametrize("scalar_index", [True, False])
def test_interp(method, scalar_index):
data = make_scalar_field_data_array("Ex")

f = 1.5e14
if not scalar_index:
f = [f]

xr_interp = data.interp(f=f)
ag_interp = data._ag_interp(f=f)
xrt.assert_allclose(xr_interp, ag_interp)
2 changes: 1 addition & 1 deletion tests/test_plugins/autograd/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def test_interpn_val(self, rng, dim, method):
points, values, xi = self.generate_points_values_xi(rng, dim)
xi_grid = np.meshgrid(*xi, indexing="ij")

result_custom = interpn(points, values, xi, method=method)
result_custom = interpn(points, values, tuple(xi_grid), method=method)
result_scipy = scipy.interpolate.interpn(points, values, tuple(xi_grid), method=method)
npt.assert_allclose(result_custom, result_scipy)

Expand Down
11 changes: 10 additions & 1 deletion tidy3d/components/autograd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import autograd.numpy as anp
from autograd.extend import VJPNode, register_notrace

from .boxes import TidyArrayBox
from .functions import interpn
from .types import (
AutogradFieldMap,
Expand All @@ -8,9 +12,12 @@
TracedSize1D,
TracedVertices,
)
from .utils import get_static
from .utils import get_static, is_tidy_box, split_list

register_notrace(VJPNode, anp.full_like)

__all__ = [
"TidyArrayBox",
"TracedFloat",
"TracedSize1D",
"TracedSize",
Expand All @@ -20,4 +27,6 @@
"AutogradFieldMap",
"get_static",
"interpn",
"split_list",
"is_tidy_box",
]
147 changes: 147 additions & 0 deletions tidy3d/components/autograd/boxes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Adds some functionality to the autograd arraybox
# NOTE: this is not a subclass of ArrayBox since that would break autograd's internal checks

import importlib
from typing import Any, Callable, Dict, List, Tuple

import autograd.numpy as anp
from autograd.numpy.numpy_boxes import ArrayBox

TidyArrayBox = ArrayBox # NOT a subclass

_autograd_module_cache = {} # cache for imported autograd modules


@classmethod
def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox:
"""Construct a TidyArrayBox from an ArrayBox."""
return cls(box._value, box._trace, box._node)


def __array_function__(
self: Any,
func: Callable,
types: List[Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
"""
Handle the dispatch of NumPy functions to autograd's numpy implementation.
Parameters
----------
self : Any
The instance of the class.
func : Callable
The NumPy function being called.
types : List[Any]
The types of the arguments that implement __array_function__.
args : Tuple[Any, ...]
The positional arguments to the function.
kwargs : Dict[str, Any]
The keyword arguments to the function.
Returns
-------
Any
The result of the function call, or NotImplemented.
Raises
------
NotImplementedError
If the function is not implemented in autograd.numpy.
See Also
--------
https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_function__
"""
if not all(t in TidyArrayBox.type_mappings for t in types):
return NotImplemented

module_name = func.__module__

if module_name.startswith("numpy"):
anp_module_name = "autograd." + module_name
else:
return NotImplemented

# Use the cached module if available
anp_module = _autograd_module_cache.get(anp_module_name)
if anp_module is None:
try:
anp_module = importlib.import_module(anp_module_name)
_autograd_module_cache[anp_module_name] = anp_module
except ImportError:
return NotImplemented

f = getattr(anp_module, func.__name__, None)
if f is None:
return NotImplemented

if f.__name__ == "nanmean": # somehow xarray always dispatches to nanmean
f = anp.mean
kwargs.pop("dtype", None) # autograd mean vjp doesn't support dtype

return f(*args, **kwargs)


def __array_ufunc__(
self: Any,
ufunc: Callable,
method: str,
*inputs: Any,
**kwargs: Dict[str, Any],
) -> Any:
"""
Handle the dispatch of NumPy ufuncs to autograd's numpy implementation.
Parameters
----------
self : Any
The instance of the class.
ufunc : Callable
The universal function being called.
method : str
The method of the ufunc being called.
inputs : Any
The input arguments to the ufunc.
kwargs : Dict[str, Any]
The keyword arguments to the ufunc.
Returns
-------
Any
The result of the ufunc call, or NotImplemented.
See Also
--------
https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__
"""
if method != "__call__":
return NotImplemented

ufunc_name = ufunc.__name__

anp_ufunc = getattr(anp, ufunc_name, None)
if anp_ufunc is not None:
return anp_ufunc(*inputs, **kwargs)

return NotImplemented


def item(self):
if self.size != 1:
raise ValueError("Can only convert an array of size 1 to a scalar")
return anp.ravel(self)[0]


TidyArrayBox._tidy = True
TidyArrayBox.from_arraybox = from_arraybox
TidyArrayBox.__array_namespace__ = lambda self, *, api_version=None: anp
TidyArrayBox.__array_ufunc__ = __array_ufunc__
TidyArrayBox.__array_function__ = __array_function__
TidyArrayBox.__repr__ = str
TidyArrayBox.real = property(anp.real)
TidyArrayBox.imag = property(anp.imag)
TidyArrayBox.conj = anp.conj
TidyArrayBox.item = item
12 changes: 4 additions & 8 deletions tidy3d/components/autograd/derivative_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,27 +188,23 @@ def project_in_basis(
def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray:
"""integrate a data array within bounds, assumes bounds are [2, N] for N dims."""

_arr = arr.copy()

# order bounds with dimension first (N, 2)
bounds = np.array(bounds).T

bounds = np.asarray(bounds).T
all_coords = {}

# loop over all dimensions
for dim, (bmin, bmax) in zip(dims, bounds):
bmin = get_static(bmin)
bmax = get_static(bmax)

coord_values = _arr.coords[dim].values
coord_values = np.copy(arr.coords[dim].data)

# reset all coordinates outside of bounds to the bounds, so that dL = 0 in integral
coord_values[coord_values < bmin] = bmin
coord_values[coord_values > bmax] = bmax
np.clip(coord_values, bmin, bmax, out=coord_values)

all_coords[dim] = coord_values

_arr = _arr.assign_coords(**all_coords)
_arr = arr.assign_coords(**all_coords)

# uses trapezoidal rule
# https://docs.xarray.dev/en/stable/generated/xarray.DataArray.integrate.html
Expand Down
5 changes: 2 additions & 3 deletions tidy3d/components/autograd/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def interpn(
raise ValueError(f"Unsupported interpolation method: {method}")

itrp = RegularGridInterpolator(points, values, method=method)
grid = anp.meshgrid(*xi, indexing="ij")

# Prepare the grid for interpolation
# This step reshapes the grid, checks for NaNs and out-of-bounds values
Expand All @@ -142,12 +141,12 @@ def interpn(
# - number of dimensions
# - boolean array indicating NaN positions
# - (discarded) boolean array for out-of-bounds values
grid, shape, ndim, nans, _ = itrp._prepare_xi(tuple(grid))
xi, shape, ndim, nans, _ = itrp._prepare_xi(xi)

# Find the indices of the grid cells containing the interpolation points
# and calculate the normalized distances (ranging from 0 at lower grid point to 1
# at upper grid point) within these cells
indices, norm_distances = itrp._find_indices(grid.T)
indices, norm_distances = itrp._find_indices(xi.T)

result = interp_fn(indices, norm_distances, values)
nans = anp.reshape(nans, (-1,) + (1,) * (result.ndim - 1))
Expand Down
6 changes: 6 additions & 0 deletions tidy3d/components/autograd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ def split_list(x: list[typing.Any], index: int) -> (list[typing.Any], list[typin
return x[:index], x[index:]


def is_tidy_box(x: typing.Any) -> bool:
"""Check if a value is a tidy box."""
return getattr(x, "_tidy", False)


__all__ = [
"get_static",
"split_list",
"is_tidy_box",
]
Loading

0 comments on commit 70bd113

Please sign in to comment.