Skip to content

Commit

Permalink
some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
gomezzz committed Jul 26, 2022
1 parent deb5d47 commit 8bf5394
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 55 deletions.
169 changes: 121 additions & 48 deletions nidn/fdtd_integration/calculate_transmission_reflection_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,120 @@ def calculate_transmission_reflection_coefficients(
reflection_signals[0] = _eliminate_transient_part(reflection_signals[0], cfg)
true_reflection = _eliminate_transient_part(true_reflection, cfg)

(
reflection_coefficient,
transmission_coefficient,
) = _peak_based_coefficient_computation(
transmission_signals, reflection_signals[0], true_reflection
)

if transmission_coefficient < 0 or transmission_coefficient > 1:
logger.error(
f"The transmission coefficient is outside of the physical range between 0 and 1. The transmission coefficient is {transmission_coefficient}"
)

if reflection_coefficient < 0 or reflection_coefficient > 1:
logger.error(
f"The reflection coefficient is outside of the physical range between 0 and 1. The reflection coefficient is {reflection_coefficient}"
)
if transmission_coefficient + reflection_coefficient > 1:
logger.warning(
f"The sum of the transmission and reflection coefficient is greater than 1, which is physically impossible"
)
return transmission_coefficient, reflection_coefficient


def _mean_square(tensor):
"""Calculates the mean of the squared signal
Args:
tensor (tensor): signal to perform the calculations on
Returns:
torch.float: The mean square value
"""
return torch.sum(torch.square(tensor)) / len(tensor)


def _check_for_all_zero_signal(signals):
if _mean_square(signals[0]) <= 1e-15:
raise ValueError(
"The free-space signal is all zero. Increase the number of FDTD_niter to ensure that the signal reaches the detector."
)


def _FFT_based_coefficient_computation(
transmission_signals, reflection_signal, true_reflection, plot=True
):
"""Calculates the transmission coefficient and reflection coefficient using the FFT method.
Args:
transmission_signals (tuple[array,array]): Transmission signal from a free-space fdtd simulaiton, and transmission signal from a fdtd simulation with an added object
reflection_signal (array): Reflection signal from a free-space fdtd simulation, and reflection signal from a fdtd simulation with an added object
true_reflection (array): The true reflection signal, which is the material reflection signal minus the detector reflection signal.
plot (bool, optional): If True, plots the FFTs. Defaults to False.
Returns:
tuple[float, float]: Transmission coefficient and reflection coefficient
"""
# Calculate the FFT of the signals and amplitude as magnitude
transmission_fft = torch.fft.rfft(transmission_signals[0], norm="forward").abs()
transmission_free_space_fft = torch.fft.rfft(
transmission_signals[1], norm="forward"
).abs()
reflection_fft = torch.fft.rfft(reflection_signal, norm="forward").abs()
true_reflection_fft = torch.fft.rfft(true_reflection, norm="forward").abs()

if plot:
import matplotlib.pyplot as plt

plt.figure(figsize=(7.5, 5.0), dpi=150)
plt.plot(transmission_fft, lw=3, linestyle="None", marker="o")
plt.plot(transmission_free_space_fft, lw=3, linestyle="None", marker="d")
plt.plot(reflection_fft, lw=3, linestyle="None", marker="s")
plt.plot(true_reflection_fft, lw=3, linestyle="None", marker="x")
plt.xscale("log")
plt.ylabel("Magnitude", fontsize=8)
plt.xlabel("Frequency", fontsize=8)
plt.tick_params(axis="both", which="major", labelsize=8)
plt.legend(
["Transmission", "FreeSpaceTransmission", "Reflection", "True Reflection"],
fontsize=8,
)
plt.show()

logger.debug(
f"Transmission FFT min,max,mean,std: "
+ f"{transmission_fft.min():.2e}, {transmission_fft.max():.2e}, {transmission_fft.mean():.2e}, {transmission_fft.std():.2e}"
)
logger.debug(
f"Transmission Free Space FFT min,max,mean,std: "
+ f"{transmission_free_space_fft.min():.2e}, {transmission_free_space_fft.max():.2e}, {transmission_free_space_fft.mean():.2e}, {transmission_free_space_fft.std():.2e}"
)
logger.debug(
f"Reflection FFT min,max,mean,std: "
+ f"{reflection_fft.min():.2e}, {reflection_fft.max():.2e}, {reflection_fft.mean():.2e}, {reflection_fft.std():.2e}"
)
logger.debug(
f"True Reflection FFT min,max,mean,std: "
+ f"{true_reflection_fft.min():.2e}, {true_reflection_fft.max():.2e}, {true_reflection_fft.mean():.2e}, {true_reflection_fft.std():.2e}"
)

# Calculate the transmission coefficient
transmission_coefficient = (
transmission_fft.max() / transmission_free_space_fft.max()
)
reflection_coefficient = reflection_fft.max() / true_reflection_fft.max()

return reflection_coefficient, transmission_coefficient


def _peak_based_coefficient_computation(
transmission_signals, reflection_signal, true_reflection
):
# find peaks for all signals
peaks_transmission_freespace = _torch_find_peaks(transmission_signals[0])
peaks_transmission_material = _torch_find_peaks(transmission_signals[1])
peaks_reflection_freespace = _torch_find_peaks(reflection_signals[0])
peaks_reflection_freespace = _torch_find_peaks(reflection_signal)
peaks_reflection_material = _torch_find_peaks(true_reflection)
transmission_coefficient = torch.tensor(0.0)

if len(peaks_transmission_material) > 1:
mean_squared_transmission_material = _mean_square(
Expand All @@ -44,15 +152,14 @@ def calculate_transmission_reflection_coefficients(
.item()
]
)

else:
mean_squared_transmission_material = (max(transmission_signals[1]) ** 2) / 2
logger.warning(
"There is not enough timesteps for the transmission signal to have the proper lenght/ or no signal is transmited. The signal should at least contain 2 peaks, but {} is found.The FDTD_niter should be increased, to be sure that the resutls are valid.".format(
"There are not enough timesteps for the transmission signal to have the proper length / or no signal is transmited. The signal should at least contain 2 peaks, but {} is found.The FDTD_niter should be increased, to be sure that the resutls are valid.".format(
len(peaks_transmission_material)
)
)
mean_squared_transmission_free_space = 1

if len(peaks_transmission_freespace) > 1:
mean_squared_transmission_free_space = _mean_square(
transmission_signals[0][
Expand All @@ -64,12 +171,6 @@ def calculate_transmission_reflection_coefficients(
else:
mean_squared_transmission_free_space = (max(transmission_signals[0]) ** 2) / 2

transmission_coefficient = (
mean_squared_transmission_material / mean_squared_transmission_free_space
)

reflection_coefficient = torch.tensor(0.0)

if len(peaks_reflection_material) > 1:
mean_squared_reflection_material = _mean_square(
true_reflection[
Expand All @@ -81,59 +182,31 @@ def calculate_transmission_reflection_coefficients(
else:
mean_squared_reflection_material = (max(true_reflection) ** 2) / 2
logger.warning(
"There is not enough timesteps for the reflected signal to have the proper lenght. The signal should at least contain 2 peaks, but {} is found. The FDTD_niter should be increased, to be sure that the resutls are valid.".format(
"There are not enough timesteps for the reflected signal to have the proper length. The signal should at least contain 2 peaks, but {} is found. The FDTD_niter should be increased, to be sure that the resutls are valid.".format(
len(peaks_reflection_material)
)
)
mean_squared_reflection_free_space = 1

if len(peaks_reflection_freespace) > 1:
mean_squared_reflection_free_space = _mean_square(
reflection_signals[0][
reflection_signal[
peaks_reflection_freespace[0]
.item() : peaks_reflection_freespace[-1]
.item()
]
)
else:
mean_squared_reflection_free_space = (max(reflection_signals[0]) ** 2) / 2
mean_squared_reflection_free_space = (max(reflection_signal) ** 2) / 2

transmission_coefficient = (
mean_squared_transmission_material / mean_squared_transmission_free_space
)

reflection_coefficient = (
mean_squared_reflection_material / mean_squared_reflection_free_space
)

if transmission_coefficient < 0 or transmission_coefficient > 1:
logger.error(
f"The transmission coefficient is outside of the physical range between 0 and 1. The transmission coefficient is {transmission_coefficient}"
)

if reflection_coefficient < 0 or reflection_coefficient > 1:
logger.error(
f"The reflection coefficient is outside of the physical range between 0 and 1. The reflection coefficient is {reflection_coefficient}"
)
if transmission_coefficient + reflection_coefficient > 1:
logger.warning(
f"The sum of the transmission and reflection coefficient is greater than 1, which is physically impossible"
)
return transmission_coefficient, reflection_coefficient


def _mean_square(tensor):
"""Calculates the mean of the squared signal
Args:
tensor (tensor): signal to perform the calculations on
Returns:
torch.float: The mean square value
"""
return torch.sum(torch.square(tensor)) / len(tensor)


def _check_for_all_zero_signal(signals):

if _mean_square(signals[0]) <= 1e-15:
raise ValueError(
"The free-space signal is all zero. Increase the number of FDTD_niter to ensure that the signal reaches the detector."
)
return reflection_coefficient, transmission_coefficient


def _eliminate_transient_part(signal, cfg, plot=False):
Expand Down
24 changes: 17 additions & 7 deletions notebooks/FDTD_RCWA_TMM_Comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@
"outputs": [],
"source": [
"cfg = nidn.load_default_cfg()\n",
"materials = [\"titanium_oxide\", \"germanium\", \"silicon_nitride\"]\n",
"cfg.N_freq = 20 # number of frequency points\n",
"thicknesses = [0.38,0.5,2.0] # thicknesses\n",
"materials = [\"titanium_oxide\", \"tantalum_pentoxide\",\"silicon_nitride\"]\n",
"cfg.N_freq = 16 # number of frequency points\n",
"thicknesses = [0.38,0.1,2.0] # thicknesses\n",
"# wavelengths\n",
"lam_l = [np.linspace(0.4, 0.5, cfg.N_freq),\n",
" np.linspace(0.1, 0.2, cfg.N_freq),\n",
" np.linspace(0.2, 0.4, cfg.N_freq),\n",
" np.linspace(2.0, 3.0, cfg.N_freq)]\n",
"\n",
"cfg.N_layers = 1\n",
"cfg.FDTD_niter = 800\n",
"cfg.FDTD_niter = 2000\n",
"cfg.FDTD_min_gridpoints_per_unit_magnitude = 50\n",
"cfg.FDTD_pulse_type = 'continuous'\n",
"cfg.FDTD_source_type = 'line'"
Expand All @@ -45,7 +45,9 @@
"cell_type": "code",
"execution_count": null,
"id": "acf8495d",
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"R_tmm,T_tmm,R_rcwa,T_rcwa,R_fdtd,T_fdtd = {}, {}, {}, {}, {}, {}\n",
Expand Down Expand Up @@ -157,11 +159,19 @@
" plt.ylim([-0.1,1.1])\n",
" plt.savefig(f'../results/{material}_A.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "140d2626",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('nidn')",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down

0 comments on commit 8bf5394

Please sign in to comment.