Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow transformed groups to be flattened #2050

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.
- Function `components.geometry.utils.flatten_groups` now also flattens transformed groups when requested.

### Fixed
- Regression in local field projection leading to incorrect results for `far_field_approx=True`.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_components/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,33 @@ def test_flattening():
for g in flat
)

t0 = np.array([[2, 0, 0, 0], [3, 2, 0, 0], [1, 0, 2, 0], [0, 0, 0, 1.0]])
g0 = td.Sphere(radius=1)
t1 = np.array([[2, 0, 5, 0], [0, 1, 0, 0], [-1, 0, 1, 0], [0, 0, 0, 1.0]])
g1 = td.Box(size=(1, 2, 3))
flat = list(
flatten_groups(
td.Transformed(
transform=t0,
geometry=td.ClipOperation(
operation="union",
geometry_a=g0,
geometry_b=td.Transformed(transform=t1, geometry=g1),
),
),
flatten_transformed=True,
)
)
assert len(flat) == 2

assert isinstance(flat[0], td.Transformed)
assert flat[0].geometry == g0
assert np.allclose(flat[0].transform, t0)

assert isinstance(flat[1], td.Transformed)
assert flat[1].geometry == g1
assert np.allclose(flat[1].transform, t0 @ t1)


def test_geometry_traversal():
geometries = list(traverse_geometries(td.Box(size=(1, 1, 1))))
Expand Down
33 changes: 29 additions & 4 deletions tidy3d/components/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from math import isclose
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import numpy as np

Expand All @@ -24,17 +24,25 @@
]


def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = False) -> GeometryType:
def flatten_groups(
*geometries: GeometryType,
flatten_nonunion_type: bool = False,
flatten_transformed: bool = False,
transform: Optional[MatrixReal4x4] = None,
) -> GeometryType:
"""Iterates over all geometries, flattening groups and unions.

Parameters
----------
*geometries : GeometryType
Geometries to flatten.

flatten_nonunion_type : bool = False
If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten
all clip operations.
flatten_transformed : bool = False
If ``True``, ``Transformed`` groups are flattened into individual transformed geometries.
transform : Optional[MatrixReal4x4]
Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``.

Yields
------
Expand All @@ -44,7 +52,10 @@ def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = Fals
for geometry in geometries:
if isinstance(geometry, base.GeometryGroup):
yield from flatten_groups(
*geometry.geometries, flatten_nonunion_type=flatten_nonunion_type
*geometry.geometries,
flatten_nonunion_type=flatten_nonunion_type,
flatten_transformed=flatten_transformed,
transform=transform,
)
elif isinstance(geometry, base.ClipOperation) and (
flatten_nonunion_type or geometry.operation == "union"
Expand All @@ -53,7 +64,21 @@ def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = Fals
geometry.geometry_a,
geometry.geometry_b,
flatten_nonunion_type=flatten_nonunion_type,
flatten_transformed=flatten_transformed,
transform=transform,
)
elif flatten_transformed and isinstance(geometry, base.Transformed):
new_transform = geometry.transform
if transform is not None:
new_transform = np.matmul(transform, new_transform)
yield from flatten_groups(
geometry.geometry,
flatten_nonunion_type=flatten_nonunion_type,
flatten_transformed=flatten_transformed,
transform=new_transform,
)
elif flatten_transformed and transform is not None:
yield base.Transformed(geometry=geometry, transform=transform)
else:
yield geometry

Expand Down
2 changes: 1 addition & 1 deletion tidy3d/components/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _validate_num_geometries(cls, val):
return val

for i, structure in enumerate(val):
for geometry in flatten_groups(structure.geometry):
for geometry in flatten_groups(structure.geometry, flatten_transformed=True):
count = sum(
1
for g in traverse_geometries(geometry)
Expand Down
Loading