Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bug in local field projection for far_field_approx=True #2048

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.

### Fixed
- Regression in local field projection leading to incorrect results for `far_field_approx=True`.

## [2.7.6] - 2024-10-30

### Added
Expand Down
16 changes: 15 additions & 1 deletion tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def objective(args):


@pytest.mark.parametrize("far_field_approx", [True, False])
@pytest.mark.parametrize("projection_type", ["angular", "cartesian"])
@pytest.mark.parametrize("projection_type", ["angular", "cartesian", "kspace"])
@pytest.mark.parametrize("sim_2d", [True, False])
class TestFieldProjection:
@staticmethod
Expand Down Expand Up @@ -1047,6 +1047,20 @@ def setup(far_field_approx, projection_type, sim_2d):
far_field_approx=far_field_approx,
name="far_field",
)
elif projection_type == "kspace":
ux = np.linspace(-0.7, 0.7, 2)
uy = np.linspace(-0.7, 0.7, 3)
monitor_far = td.FieldProjectionKSpaceMonitor(
center=monitor.center,
size=monitor.size,
freqs=monitor.freqs,
ux=ux,
uy=uy,
proj_axis=1,
proj_distance=r_proj,
far_field_approx=far_field_approx,
name="far_field",
)

sim = SIM_BASE.updated_copy(monitors=[monitor])

Expand Down
20 changes: 12 additions & 8 deletions tidy3d/components/field_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,11 @@ def contract(currents):

order = [idx_u, idx_v, idx_w]
zeros = np.zeros(jm[0].shape)
J = anp.array([*jm[:2], zeros])[order]
M = anp.array([*jm[2:], zeros])[order]

# for each index (0, 1, 2), if it’s in the first two elements of order,
# select the corresponding jm element for J or the offset element (+2) for M
J = anp.array([jm[order.index(i)] if i in order[:2] else zeros for i in range(3)])
M = anp.array([jm[order.index(i) + 2] if i in order[:2] else zeros for i in range(3)])

cos_theta_cos_phi = cos_theta[:, None] * cos_phi[None, :]
cos_theta_sin_phi = cos_theta[:, None] * sin_phi[None, :]
Expand Down Expand Up @@ -756,7 +759,7 @@ def _project_fields_kspace(

# compute projected fields for the dataset associated with each monitor
field_names = ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")
fields = [np.zeros((len(ux), len(uy), 1, len(freqs)), dtype=complex) for _ in field_names]
fields = np.zeros((len(field_names), len(ux), len(uy), 1, len(freqs)), dtype=complex)

medium = monitor.medium if monitor.medium else self.medium
k = AbstractFieldProjectionData.wavenumber(medium=medium, frequency=freqs)
Expand Down Expand Up @@ -790,16 +793,17 @@ def _project_fields_kspace(
currents=currents,
medium=medium,
)
for field, _field in zip(fields, _fields):
field = add_at(field, [i, j, 0, idx_f], _field * phase[idx_f])

where = (slice(None), i, j, 0, idx_f)
_fields = anp.reshape(_fields, fields[where].shape)
fields = add_at(fields, where, _fields * phase[idx_f])
else:
_x, _y, _z = monitor.sph_2_car(monitor.proj_distance, theta, phi)
_fields = self._fields_for_surface_exact(
x=_x, y=_y, z=_z, surface=surface, currents=currents, medium=medium
)
for field, _field in zip(fields, _fields):
field = add_at(field, [i, j, 0], _field)
where = (slice(None), i, j, 0)
_fields = anp.reshape(_fields, fields[where].shape)
fields = add_at(fields, where, _fields)

coords = {
"ux": np.array(monitor.ux),
Expand Down
Loading