Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Sep 3, 2024
1 parent 31a2a1b commit 52b3b7f
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 42 deletions.
2 changes: 1 addition & 1 deletion farfield_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def objective(x, *, is_2d=False, local=False, emulated=False, seed=None):
monitor_far = [m for m in sim.monitors if "far_field" in m.name][0]
projected_fields = projector.project_fields(monitor_far)

return projected_fields.power.values.item()
return np.sum(projected_fields.power.values)


def main():
Expand Down
32 changes: 15 additions & 17 deletions gradcheck_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
thickness_sub = 100 * nm

# side length of entire metalens (um)
side_length = 12
side_length = 7

sim_buffer_xy = 2 * wavelength
sim_buffer_xy = 1 * wavelength

# Number of unit cells in each x and y direction (NxN grid)
N = int(side_length / S)
Expand All @@ -57,7 +57,8 @@
Si = td.Medium(permittivity=n_Si**2)

# define symmetry
symmetry = (-1, 1, 0)
# symmetry = (-1, 1, 0)
symmetry = (0, 0, 0)

# using the wavelength in microns, one can use td.C_0 (um/s) to get frequency in Hz
# wavelength_meters = wavelength * meters
Expand Down Expand Up @@ -127,21 +128,20 @@ def make_structures(params, apply_symmetry: bool = True):
if apply_symmetry and symmetry[1] != 0 and y0 < -S / 2:
continue

geometry = td.Box(center=(x0, y0, center_z), size=(size, size, H))
# geometry = td.Cylinder(center=(x0, y0, center_z), length=H, radius=size / 2)
# geometry = td.Box(center=(x0, y0, center_z), size=(size, size, H))
geometry = td.Cylinder(center=(x0, y0, center_z), length=H, radius=size / 2)

geometries.append(geometry)
geo_group = td.GeometryGroup(geometries=geometries)
medium = td.Medium(permittivity=n_Si**2)

return [td.Structure(medium=medium, geometry=geo_group)]
# return [td.Structure(medium=medium, geometry=geo) for geo in geometries]


structures = make_structures(params0)

# steps per unit cell along x and y
grids_per_unit_length = 10
grids_per_unit_length = 16

# uniform mesh in x and y
grid_x = td.UniformGrid(dl=S / grids_per_unit_length)
Expand Down Expand Up @@ -234,33 +234,31 @@ def measure_focal_intensity(sim_data: td.SimulationData) -> float:
)

far_fields = n2f.project_fields(monitor_far)
# return anp.sum(anp.abs(far_fields.Etheta.values))
return far_fields.power.values.item()


def J(params) -> float:
"""Objective function, returns intensity at focal point as a function of params."""
sim = make_sim(params)
# sim.plot(z=0)
# sim.plot_grid(x=0)
# plt.show()
# exit()
# sim = make_sim(np.zeros(x_centers.shape))
sim_data = run_adj(sim, task_name="metalens_invdes_dbg", verbose=True)
# sim_data = make_sim_data(params, fp="simulation_data.hdf5")
# sim_data = td.SimulationData.from_file("simulation_data.hdf5")
return measure_focal_intensity(sim_data)


# val = J(params0)
val = J(params0)

# # params0 = np.random.uniform(-10, 10, 2 * 452)
# # params0 = np.array([1.0, 1.0])
# # print(J(params0))
dJ = ag.value_and_grad(J)
val, grad = dJ(params0)
print("val: ", val)
print("grad: ", grad)
print("|grad|: ", np.linalg.norm(grad))
# dJ = ag.value_and_grad(J)
# val, grad = dJ(params0)
# print("val: ", val)
# print("grad: ", grad)
# print("|grad|: ", np.linalg.norm(grad))
# exit()

# check_grads(J, modes=["rev"], order=1)(params0)
Expand Down Expand Up @@ -293,7 +291,7 @@ def J_normalized(params):

for i in range(num_steps):
# compute gradient and current objective function value
value, gradient = dJ_normalized(params)
value, gradient = dJ_normalized(np.copy(params))

# outputs
print(f"step = {i + 1}")
Expand Down
196 changes: 196 additions & 0 deletions grating_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
#!/usr/bin/env -S poetry run python
# ruff: noqa: F401

import autograd.numpy as np
import jax
import matplotlib.pyplot as plt
import optax
from autograd import value_and_grad

import tidy3d as td
from tidy3d.web import run

jax.config.update("jax_enable_x64", True)


def make_sim(
widths,
gaps,
*,
wavelength=1.28,
d_si: float = 0.161,
d_etch: float = 0.106,
n_si: float = 3.507,
n_sio2: float = 1.45,
buffer_left: float = 6.0,
buffer_right: float = 3.0,
buffer_top: float = 1.0,
buffer_bot: float = 1.0,
monitor_buffer: float = 0.1,
mode_monitor_size: float = 2.0,
resolution: int = 20,
shutoff: float = 1e-6,
run_time: float = 1e-12,
sim_size: tuple[float, float, float] = (10, 0, 1.5),
sim_center: tuple[float, float, float] = (3, 0, 0),
):
fcen = td.C_0 / wavelength
si = td.Medium(permittivity=n_si**2, name="Si")
sio2 = td.Medium(permittivity=n_sio2**2, name="SiO2")
etch_cz = d_si / 2 - d_etch / 2

grid_spec = td.GridSpec.auto(
wavelength=wavelength,
min_steps_per_wvl=resolution,
)
boundary_spec = td.BoundarySpec(
x=td.Boundary.pml(),
y=td.Boundary.periodic(),
z=td.Boundary.pml(),
)

waveguide = td.Structure(geometry=td.Box(size=(td.inf, td.inf, d_si)), medium=si)

cx = widths[0] / 2
etch = [
td.Structure(
geometry=td.Box(center=(widths[0] / 2, 0, etch_cz), size=(widths[0], td.inf, d_etch)),
medium=sio2,
)
]
for w, g in zip(widths[1:], gaps, strict=True):
cx = cx + g + w / 2
etch.append(
td.Structure(
geometry=td.Box(center=(cx, 0, etch_cz), size=(w, td.inf, d_etch)), medium=sio2
)
)

near_monitor = td.FieldMonitor(
center=(0, 0, d_si / 2 + monitor_buffer),
size=(td.inf, td.inf, 0),
freqs=(fcen,),
name="near_fields",
colocate=False,
)

far_monitor = td.FieldProjectionAngleMonitor(
center=near_monitor.center,
size=near_monitor.size,
freqs=near_monitor.freqs,
normal_dir="+",
theta=np.linspace(-np.pi / 2, np.pi / 2, 180),
phi=(0.0,),
far_field_approx=True,
name="far_field",
)

mode_source = td.ModeSource(
center=(sim_center[0] - sim_size[0] / 2 + monitor_buffer, 0, 0),
size=(0, td.inf, td.inf),
mode_spec=td.ModeSpec(num_modes=1, filter_pol="te", target_neff=n_si),
source_time=td.GaussianPulse(freq0=fcen, fwidth=fcen / 10),
direction="+",
)

return td.Simulation(
center=sim_center,
size=sim_size,
structures=(waveguide, *etch),
sources=(mode_source,),
monitors=(near_monitor, far_monitor),
grid_spec=grid_spec,
boundary_spec=boundary_spec,
medium=sio2,
shutoff=shutoff,
run_time=run_time,
)


def objective(x):
s = x.size
widths = x[: s // 2 + 1]
gaps = x[widths.size :]

sim = make_sim(widths, gaps)
sim_data = run(sim, task_name="gc_dbg", verbose=False)

near_monitor = sim.monitors[0]

projector = td.FieldProjector.from_near_field_monitors(
sim_data=sim_data,
near_monitors=(near_monitor,),
normal_dirs=("+",),
)

monitor_far = td.FieldProjectionAngleMonitor(
center=near_monitor.center,
size=near_monitor.size,
freqs=near_monitor.freqs,
normal_dir="+",
theta=(np.deg2rad(-10),),
phi=(0.0,),
far_field_approx=True,
name="far_field",
)

projected_fields = projector.project_fields(monitor_far)
power = projected_fields.power.values.ravel()

return np.sum(power) / power.size


def main():
num_steps = 20

widths = np.full(12, 0.3128)
gaps = np.full(widths.size - 1, 0.4068)
x0 = np.concatenate([widths, gaps])

# sim = make_sim(widths, gaps)
# sim.plot(y=0)
# plt.show()
# exit()

vg_fun = value_and_grad(objective)

hist = []
x_opt = np.copy(x0)
lr_schedule = optax.linear_schedule(init_value=1e-2, end_value=1e-4, transition_steps=num_steps)
opt = optax.chain(
optax.adamw(lr_schedule),
optax.scale(-1),
)
opt_state = opt.init(x_opt)

for ii in range(num_steps):
value, gradient = vg_fun(x_opt)

print(f"step = {ii + 1}")
print(f"\tJ = {value:.4e}")
print(f"\tgrad_norm = {np.linalg.norm(gradient):.4e}")

updates, opt_state = opt.update(gradient, opt_state, x_opt)
x_opt = np.array(optax.apply_updates(x_opt, updates))
x_opt = np.clip(x_opt, 0.1, 1.0)

hist.append(value)

fig, ax = plt.subplots(1, 1, tight_layout=True)
ax.plot(hist)
ax.set_xlabel("iterations")
ax.set_ylabel("objective")
plt.show()

fig, ax = plt.subplots(2, 1)
for axi, x in zip(ax, (x0, x_opt)):
s = x_opt.size
widths = x[: s // 2 + 1]
gaps = x[widths.size :]
sim = make_sim(widths, gaps)
sim.plot(y=0, ax=axi)
plt.show()


if __name__ == "__main__":
main()
Loading

0 comments on commit 52b3b7f

Please sign in to comment.