diff --git a/tmmax/VERSION b/tmmax/VERSION index 4e379d2..3eefcb9 100644 --- a/tmmax/VERSION +++ b/tmmax/VERSION @@ -1 +1 @@ -0.0.2 +1.0.0 diff --git a/tmmax/angle.py b/tmmax/angle.py index 54bd578..5645b47 100644 --- a/tmmax/angle.py +++ b/tmmax/angle.py @@ -1,162 +1,140 @@ +import jax.numpy as jnp +from jax import vmap import jax -jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX -import jax.numpy as jnp # Import JAX's version of NumPy for differentiable computations -import sys # Import sys for EPSILON -from typing import Union # Type hints for function signatures +import numpy as np +import pickle +import sys +from jax import Array +from jax.typing import ArrayLike # Define EPSILON as the smallest representable positive number such that 1.0 + EPSILON != 1.0 EPSILON = sys.float_info.epsilon -def is_propagating_wave(n: Union[float, jnp.ndarray], angle_of_incidence: Union[float, jnp.ndarray], polarization: bool) -> Union[float, jnp.ndarray]: - """ - Determines whether a wave is forward-propagating through a multilayer thin film stack based on - the refractive index, angle of incidence, and polarization. - - - Args: - n (Union[float, jnp.ndarray]): Complex refractive index of the medium, which can be an array or scalar. - angle_of_incidence (Union[float, jnp.ndarray]): Angle of incidence of the incoming wave in radians. - polarization (bool): Polarization of the wave: - - False for s-polarization (perpendicular to the plane of incidence). - - True for p-polarization (parallel to the plane of incidence). - - - Returns: - Union[float, jnp.ndarray]: - - A positive value indicates forward propagation for both s and p polarizations. - - A negative or zero value implies backward propagation or evanescent waves (non-propagating). - - - The function evaluates whether the wave, given its angle, refractive index, and polarization, is a - forward-propagating wave (i.e., traveling from the front to the back of the stack). This is crucial - when calculating Snell's law in multilayer structures to ensure light is correctly entering or - exiting the stack. - - - The check considers both real and complex values of the refractive index and angle, ensuring that the - light propagates within the correct angle range for physical interpretation. - """ - - # Multiply the refractive index (n) by the cosine of the angle of incidence - n_cos_theta = n * jnp.cos(angle_of_incidence) # Compute n*cos(theta) for angle propagation - - def define_is_forward_if_bigger_than_eps(_): - """Handle cases where the imaginary part of the refractive index is significant, i.e., - evanescent waves or lossy media.""" - is_forward_s = jnp.imag(n_cos_theta) # Check if the wave decays exponentially in evanescent media - is_forward_p = is_forward_s # Both s- and p-polarizations have the same condition for evanescent decay - return is_forward_s, is_forward_p # Return the imaginary part for determining forward propagation - - def define_is_forward_if_smaller_than_eps(_): - """Handle cases where the real part of the refractive index dominates, - indicating propagating waves.""" - is_forward_s = jnp.real(n_cos_theta) # For s-polarization, check if Re[n * cos(theta)] > 0 - - # For p-polarization, consider n * cos(theta*) where theta* is the complex conjugate of the angle - n_cos_theta_star = n * jnp.cos(jnp.conj(angle_of_incidence)) # Calculate n * cos(conjugate(theta)) - is_forward_p = jnp.real(n_cos_theta_star) # For p-polarization, check if Re[n * cos(theta*)] > 0 - - return is_forward_s, is_forward_p # Return real parts to determine forward propagation - - # Check whether the wave is evanescent or lossy by examining the imaginary part of n * cos(theta) - condition = jnp.abs(jnp.imag(n_cos_theta)) > EPSILON * 1e2 # Set a threshold for significant imaginary part - # Use conditional logic to handle different wave types based on whether the imaginary part is large - is_forward_s, is_forward_p = jax.lax.cond( - condition, - define_is_forward_if_bigger_than_eps, # Handle evanescent/lossy cases - define_is_forward_if_smaller_than_eps, # Handle normal propagating waves - None - ) - - # Return the result based on the polarization type - if polarization is False: - # For s-polarization, return whether the wave is forward-propagating - return jnp.array([is_forward_s]) # s-polarization output as a single-element array - elif polarization is True: - # For p-polarization, return whether the wave is forward-propagating - return jnp.array([is_forward_p]) # p-polarization output as a single-element array - - -def _compute_layer_angles_single_wl_angle_point(nk_list: jnp.ndarray, - angle_of_incidence: Union[float, jnp.ndarray], - wavelength: Union[float, jnp.ndarray], - polarization: bool) -> jnp.ndarray: - """ - Computes the angle of incidence for light in each layer of a multilayer thin film using Snell's law - for a single wavelength and angle of incidence. - - Args: - nk_list (jnp.ndarray): A JAX array containing the refractive index (n) and extinction coefficient (k) values - for each layer, typically as a function of wavelength. - - angle_of_incidence (Union[float, jnp.ndarray]): The angle of incidence (in radians) relative to the normal of - the first layer. Can be a float for single-angle computation - or an array for batch computation. - - wavelength (Union[float, jnp.ndarray]): The wavelength or an array of wavelengths (ndarray) for which the computation - is performed. - - polarization (bool): Determines the polarization state of light: - - False: s-polarization (perpendicular to the plane of incidence). - - True: p-polarization (parallel to the plane of incidence). - - Returns: - jnp.ndarray: A JAX array containing the calculated angles of incidence for each layer in radians. - If `angle_of_incidence` is a float, it returns a 1D array where each element represents - the angle in a specific layer. If `angle_of_incidence` is a 1D array, the return is 2D, - with each row representing the angles in all layers for a specific initial angle. - - Detailed Description: - This function applies Snell's law to calculate the angle of incidence in each layer of a multilayer - thin film structure. Snell's law relates the angles and refractive indices across different media via: - - sin(theta_i) = (n_0 * sin(theta_0)) / n_i - - where: - - theta_i is the angle of incidence in the i-th layer, - - n_0 is the refractive index of the first layer, - - theta_0 is the initial angle of incidence, - - n_i is the refractive index of the i-th layer. - - The function also handles situations where the light may not propagate forward (e.g., due to total internal - reflection or evanescent waves) by flipping angles when needed. - """ - +def is_forward_if_bigger_than_eps_s_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Calculate n * cos(theta) to evaluate propagation direction for s-polarization + n_cos_theta = jnp.multiply(n, jnp.cos(theta)) + # For evanescent or lossy mediums, forward is determined by decay + is_forward_s = jnp.invert(jnp.signbit(jnp.imag(n_cos_theta))) + return is_forward_s + +def is_forward_if_smaller_than_eps_s_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Calculate n * cos(theta) to evaluate propagation direction for s-polarization + n_cos_theta = jnp.multiply(n, jnp.cos(theta)) + # For s-polarization: Re[n cos(theta)] > 0 + is_forward_s = jnp.invert(jnp.signbit(jnp.real(n_cos_theta))) + return is_forward_s + +def is_forward_if_bigger_than_eps_p_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Calculate n * cos(theta) to evaluate propagation direction for s-polarization + n_cos_theta = jnp.multiply(n, jnp.cos(theta)) + # For evanescent or lossy mediums, forward is determined by decay + is_forward_p = jnp.invert(jnp.signbit(jnp.imag(n_cos_theta))) + # The decay condition applies to both polarizations (s and p, so we return s and p as s and s) equally + return is_forward_p + +def is_forward_if_smaller_than_eps_p_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # For p-polarization: Re[n cos(theta*)] > 0 + n_cos_theta_star = jnp.multiply(n, jnp.cos(jnp.conj(theta))) + is_forward_p = jnp.invert(jnp.signbit(jnp.real(n_cos_theta_star))) + return is_forward_p + +def is_forward_if_bigger_than_eps_u_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Calculate n * cos(theta) to evaluate propagation direction for s-polarization + n_cos_theta = jnp.multiply(n, jnp.cos(theta)) + # For evanescent or lossy mediums, forward is determined by decay + is_forward_s = jnp.invert(jnp.signbit(jnp.imag(n_cos_theta))) + # The decay condition applies to both polarizations (s and p, so we return s and p as s and s) equally + return is_forward_s, is_forward_s + +def is_forward_if_smaller_than_eps_u_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Calculate n * cos(theta) to evaluate propagation direction for s-polarization + n_cos_theta = jnp.multiply(n, jnp.cos(theta)) + # For s-polarization: Re[n cos(theta)] > 0 + is_forward_s = jnp.invert(jnp.signbit(jnp.real(n_cos_theta))) + # For p-polarization: Re[n cos(theta*)] > 0 + n_cos_theta_star = jnp.multiply(n, jnp.cos(jnp.conj(theta))) + is_forward_p = jnp.invert(jnp.signbit(jnp.real(n_cos_theta_star))) + return is_forward_s, is_forward_p + +def is_propagating_wave_s_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Handle the evanescent and lossy cases by checking the imaginary part + condition = jnp.squeeze(jnp.greater(jnp.abs(jnp.imag(jnp.multiply(n, jnp.cos(theta)))), jnp.multiply(jnp.array([EPSILON]), jnp.array([1e3])))) + # Return based on polarization argument + is_forward_s = jax.lax.cond(condition, is_forward_if_bigger_than_eps_s_pol, is_forward_if_smaller_than_eps_s_pol, n, theta) + # s-polarization case + return is_forward_s + +def is_propagating_wave_p_pol(n: ArrayLike, theta: ArrayLike) -> Array: + # Handle the evanescent and lossy cases by checking the imaginary part + condition = jnp.squeeze(jnp.greater(jnp.abs(jnp.imag(jnp.multiply(n, jnp.cos(theta)))), jnp.multiply(jnp.array([EPSILON]), jnp.array([1e3])))) + # Return based on polarization argument + is_forward_p = jax.lax.cond(condition, is_forward_if_bigger_than_eps_p_pol, is_forward_if_smaller_than_eps_p_pol, n, theta) + # p-polarization case + return is_forward_p + +def update_theta_arr_incoming(theta_array): + return theta_array.at[0].set(jnp.pi - theta_array.at[0].get()) + +def update_theta_arr_outgoing(theta_array): + return theta_array.at[-1].set(jnp.pi - theta_array.at[-1].get()) + +def return_unchanged_theta(theta_array): + return theta_array + +def compute_layer_angles_s_pol(angle_of_incidence: ArrayLike, + nk_list: ArrayLike) -> Array: + #print("angle_of_incidence : ", jnp.asarray(angle_of_incidence)) + #print("nk_list : ", jnp.asarray(nk_list)) # Calculate the sine of the angles in the first layer using Snell's law - # Here, we are computing sin(theta) for each layer using the ratio of the refractive index of the first layer - # to the refractive index of the current layer (Snell's Law). - sin_theta = jnp.sin(angle_of_incidence) * nk_list[0] / nk_list # Ratio ensures correct angle for each layer + sin_theta = jnp.true_divide(jnp.multiply(jnp.sin(angle_of_incidence), nk_list.at[0].get()), nk_list) + # Compute the angle (theta) in each layer using the arcsin function + # jnp.arcsin is preferred for compatibility with complex values if needed + theta_array = jnp.arcsin(sin_theta) + #print("first theta : ", jnp.asarray(theta_array)) + # If the angle is not forward-facing, we subtract it from pi to flip the orientation. + incoming_props = is_propagating_wave_s_pol(nk_list.at[0].get(), theta_array.at[0].get()) + outgoing_props = is_propagating_wave_s_pol(nk_list.at[-1].get(), theta_array.at[-1].get()) + + # Handle the evanescent and lossy cases by checking the imaginary part + condition_incoming = jnp.array_equal(False, jnp.array([True], dtype=bool)) + condition_outgoing = jnp.array_equal(False, jnp.array([True], dtype=bool)) + + theta_array = jax.lax.cond(condition_incoming, update_theta_arr_incoming, return_unchanged_theta, operand=theta_array) + #print("second theta : ", jnp.asarray(theta_array)) + theta_array = jax.lax.cond(condition_outgoing, update_theta_arr_outgoing, return_unchanged_theta, operand=theta_array) + #print("third theta : ", jnp.asarray(theta_array)) + # Return a 1D theta array for each layer + return theta_array + +def compute_layer_angles_p_pol(angle_of_incidence: ArrayLike, + nk_list: ArrayLike) -> Array: + + # Calculate the sine of the angles in the first layer using Snell's law + sin_theta = jnp.true_divide(jnp.multiply(jnp.sin(angle_of_incidence), nk_list.at[0].get()), nk_list) # Compute the angle (theta) in each layer using the arcsin function - # jnp.arcsin is used here to calculate the inverse sine (arcsine) and is compatible with complex values if needed. - theta_array = jnp.arcsin(sin_theta) # Converts sin(theta) values back to theta (angle in radians) - - # Check if the wave is forward propagating or not by calculating its properties for the first and last layer. - # is_propagating_wave returns a boolean array where True means the wave is propagating and False means evanescent. - is_incoming_props = is_propagating_wave(nk_list[0], theta_array[0], polarization) # First layer propagation check - is_outgoing_props = is_propagating_wave(nk_list[-1], theta_array[-1], polarization) # Last layer propagation check - - # If the wave is evanescent (non-propagating), update the angle by flipping it (subtracting from pi). - def update_theta_arr_incoming(_): - return theta_array.at[0].set(jnp.pi - theta_array[0]) # Flips the angle in the first layer if needed - - # Similarly for the outgoing wave in the last layer. - def update_theta_arr_outgoing(_): - return theta_array.at[-1].set(jnp.pi - theta_array[-1]) # Flips the angle in the last layer if needed - - # If the wave is propagating normally, return the theta_array unchanged. - def return_unchanged_theta(_): - return theta_array # No angle flip if propagation is normal - - # Handle the evanescent and lossy cases by checking the incoming wave's properties. - # If any wave in the first layer is non-propagating, the angle gets flipped. - condition_incoming = jnp.any(is_incoming_props <= 0) # Check if the incoming wave has an evanescent component - condition_outgoing = jnp.any(is_outgoing_props <= 0) # Check if the outgoing wave has an evanescent component - - # Conditionally update the theta_array based on whether the incoming wave is evanescent or not. - # jax.lax.cond is used here to conditionally perform updates based on the given condition. - theta_array = jax.lax.cond(condition_incoming, update_theta_arr_incoming, return_unchanged_theta, operand=None) # Conditionally flip the angle for incoming wave - theta_array = jax.lax.cond(condition_outgoing, update_theta_arr_outgoing, return_unchanged_theta, operand=None) # Conditionally flip the angle for outgoing wave - - # Return the final angles of incidence (theta_array) for each layer, reflecting any necessary flips. - return theta_array # Final output: angles of incidence in each layer after applying Snell's law + # jnp.arcsin is preferred for compatibility with complex values if needed + theta_array = jnp.arcsin(sin_theta) + # If the angle is not forward-facing, we subtract it from pi to flip the orientation. + incoming_props = is_propagating_wave_p_pol(nk_list.at[0].get(), theta_array.at[0].get()) + outgoing_props = is_propagating_wave_p_pol(nk_list.at[-1].get(), theta_array.at[-1].get()) + + # Handle the evanescent and lossy cases by checking the imaginary part + condition_incoming = jnp.array_equal(incoming_props, jnp.array([True], dtype=bool)) + condition_outgoing = jnp.array_equal(outgoing_props, jnp.array([True], dtype=bool)) + + theta_array = jax.lax.cond(condition_incoming, update_theta_arr_incoming, return_unchanged_theta, operand=theta_array) + theta_array = jax.lax.cond(condition_outgoing, update_theta_arr_outgoing, return_unchanged_theta, operand=theta_array) + + # Return a 1D theta array for each layer + return theta_array + +#@jit +def compute_layer_angles(angle_of_incidence: ArrayLike, + nk_list: ArrayLike, + polarization: ArrayLike) -> Array: + + return jnp.select(condlist=[jnp.array_equal(polarization, jnp.array([0], dtype = jnp.int16)), + jnp.array_equal(polarization, jnp.array([1], dtype = jnp.int16))], + choicelist=[compute_layer_angles_s_pol(angle_of_incidence, nk_list), + compute_layer_angles_p_pol(angle_of_incidence, nk_list)]) \ No newline at end of file diff --git a/tmmax/cascaded_matmul.py b/tmmax/cascaded_matmul.py index 1d503aa..d135090 100644 --- a/tmmax/cascaded_matmul.py +++ b/tmmax/cascaded_matmul.py @@ -1,47 +1,37 @@ import jax -jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX +#jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX import jax.numpy as jnp # Import JAX's version of NumPy for differentiable computations +def matmul(carry, phase_r_t): -def _matmul(carry, phase_t_r): - """ - Multiplies two complex matrices in a sequence. - Args: - carry (jax.numpy.ndarray): The accumulated product of the matrices so far. - This is expected to be a 2x2 complex matrix. - phase_t_r (jax.numpy.ndarray): A 3-element array where: - - phase_t_r[0] represents the phase shift delta (a scalar). - - phase_t_r[1] represents the transmission coefficient t or T (a scalar). - - phase_t_r[2] represents the reflection coefficient r or R (a scalar). + transfer_matrix_00 = jnp.exp(jnp.multiply(jnp.array([-1j], dtype = jnp.complex64), phase_r_t.at[0].get())) + transfer_matrix_11 = jnp.exp(jnp.multiply(jnp.array([1j], dtype = jnp.complex64), phase_r_t.at[0].get())) + transfer_matrix_01 = jnp.multiply(phase_r_t.at[1].get(), transfer_matrix_00) + transfer_matrix_10 = jnp.multiply(phase_r_t.at[1].get(), transfer_matrix_11) - Returns: - jax.numpy.ndarray: The updated product after multiplying the carry matrix with the current matrix. - This is also a 2x2 complex matrix. - None: A placeholder required by jax.lax.scan for compatibility. - """ - # Create the diagonal phase matrix based on phase_t_r[0] - # This matrix introduces a phase shift based on the delta value - phase_matrix = jnp.array([[jnp.exp(-1j * phase_t_r[0]), 0], # Matrix with phase shift for the first entry - [0, jnp.exp(1j * phase_t_r[0])]]) # Matrix with phase shift for the second entry + transfer_matrix = jnp.multiply(jnp.true_divide(1, phase_r_t.at[2].get()), jnp.array([[transfer_matrix_00, transfer_matrix_01], + [transfer_matrix_10, transfer_matrix_11]])) - # Create the matrix based on phase_t_r[1] and phase_t_r[2] - # This matrix incorporates the transmission and reflection coefficients - transmission_reflection_matrix = jnp.array([[1, phase_t_r[1]], # Top row with transmission coefficient - [phase_t_r[1], 1]]) # Bottom row with transmission coefficient + result = jnp.matmul(carry, transfer_matrix) - # Compute the current matrix by multiplying the phase_matrix with the transmission_reflection_matrix - # The multiplication is scaled by 1/phase_t_r[2] to account for the reflection coefficient - mat = jnp.array(1 / phase_t_r[2]) * jnp.dot(phase_matrix, transmission_reflection_matrix) + - # Multiply the accumulated carry matrix with the current matrix - # This updates the product with the new matrix - result = jnp.dot(carry, mat) + return jnp.squeeze(result), jnp.squeeze(result) # Return the updated matrix and None as a placeholder for jax.lax.scan - return result, None # Return the updated matrix and None as a placeholder for jax.lax.scan +#@jax.jit +def cascaded_matrix_multiplication(phases: ArrayLike, rts: ArrayLike) -> Array: + """ + Calculates the angles of incidence across layers for a set of refractive indices (nk_list_2d) + and an initial angle of incidence (initial_theta) using vectorization. -def _cascaded_matrix_multiplication(phases_ts_rs: jnp.ndarray) -> jnp.ndarray: + Returns: + jnp.ndarray: A 3D JAX array where the [i, j, :] entry represents the angles of incidence + for the j-th initial angle at the i-th wavelength. The size of the third dimension + corresponds to the number of layers. + """ + phase_rt_stack = jnp.concat([jnp.expand_dims(phases, 1), rts], axis=1) """ Performs cascaded matrix multiplication on a sequence of complex matrices using scan. @@ -53,13 +43,13 @@ def _cascaded_matrix_multiplication(phases_ts_rs: jnp.ndarray) -> jnp.ndarray: jax.numpy.ndarray: The final result of multiplying all the matrices together in sequence. This result is a single 2x2 complex matrix representing the accumulated product of all input matrices. """ - initial_value = jnp.eye(2, dtype=jnp.complex128) - # Initialize with the identity matrix of size 2x2. # The identity matrix acts as the multiplicative identity, + initial_value = jnp.eye(2, dtype=jnp.complex64) + # Initialize with the identity matrix of size 2x2. # The identity matrix acts as the multiplicative identity, # ensuring that the multiplication starts correctly. - # jax.lax.scan applies a function across the sequence of matrices. + # jax.lax.scan applies a function across the sequence of matrices. #Here, _matmul is the function applied, starting with the identity matrix. # `result` will hold the final matrix after processing all input matrices. - result, _ = jax.lax.scan(_matmul, initial_value, phases_ts_rs) # Scan function accumulates results of _matmul over the matrices. - - return result # Return the final accumulated matrix product. # The result is the product of all input matrices in the given sequence. + result, _ = jax.lax.scan(matmul, initial_value, phase_rt_stack) # Scan function accumulates results of _matmul over the matrices. + + return result \ No newline at end of file diff --git a/tmmax/data.py b/tmmax/data.py index bd329d1..71dbaed 100644 --- a/tmmax/data.py +++ b/tmmax/data.py @@ -1,38 +1,37 @@ -from functools import lru_cache # Importing lru_cache to cache the function results +from functools import lru_cache # Importing lru_cache to cache the function results and importing partial decorator import jax -jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX import jax.numpy as jnp # Import JAX's version of NumPy for differentiable computations -from jax import jit, device_put # Import JAX functions for JIT compilation +from jax import jit, device_put, grad, vmap # Import JAX functions for JIT compilation, gradient and vmap +import matplotlib.pyplot as plt # Import matplotlib for plotting import numpy as np # Importing numpy lib for savetxt function for saving arrays to csv files import os # Importing os to handle file paths import pandas as pd # Importing pandas to handle CSV data -from typing import Union, Callable # Type hints for function signatures +from typing import Union, List, Tuple, Optional, Callable # Type hints for function signatures import warnings # Importing the warnings module to handle warnings in the code from . import nk_data_dir -@lru_cache(maxsize=32) -def load_nk_data(material_name: str = '') -> Union[jnp.ndarray, None]: +def load_nk_data_csv(material_name: str = '') -> Union[jnp.ndarray, None]: """ Load the refractive index (n) and extinction coefficient (k) data for a given material: (n + 1j * k). - This function fetches wavelength-dependent refractive index (n) and extinction coefficient (k) - data for a specified material. The data is read from a CSV file located in the 'nk_data/' directory. - The CSV file should be named after the material, e.g., 'Si.csv', and include three columns: wavelength (in micrometers), - refractive index (n), and extinction coefficient (k). These parameters are crucial for optical simulations, + This function fetches wavelength-dependent refractive index (n) and extinction coefficient (k) + data for a specified material. The data is read from a CSV file located in the 'nk_data/' directory. + The CSV file should be named after the material, e.g., 'Si.csv', and include three columns: wavelength (in micrometers), + refractive index (n), and extinction coefficient (k). These parameters are crucial for optical simulations, allowing the user to work with materials' optical properties over a range of wavelengths. Args: - material_name (str): The name of the material for which the data is to be loaded. - This must not be an empty string, and the corresponding CSV file + material_name (str): The name of the material for which the data is to be loaded. + This must not be an empty string, and the corresponding CSV file must exist in the 'nk_data/' directory. Returns: - jnp.ndarray: A 2D array containing the wavelength (first column), + jnp.ndarray: A 2D array containing the wavelength (first column), refractive index (n) (second column), and extinction coefficient (k) (third column). Each row corresponds to a different wavelength. - - None: If the function fails due to any raised exception or if the CSV file is empty, + + None: If the function fails due to any raised exception or if the CSV file is empty, it will return None. Raises: @@ -41,69 +40,123 @@ def load_nk_data(material_name: str = '') -> Union[jnp.ndarray, None]: IOError: If there's an issue reading or parsing the file. """ # Check that the material name is not an empty string - if not material_name: + if not material_name: raise ValueError("Material name cannot be an empty string.") # Raise an error if no material is provided # Construct the file path and check if the file exists - file_path = os.path.join(nk_data_dir, f'{material_name}.csv') # Create the full path to the file - if not os.path.exists(file_path): + file_path = os.path.join(nk_data_dir, f'csv/{material_name}.csv') # Create the full path to the file + if not os.path.exists(file_path): # Raise an error if the file for the material does not exist raise FileNotFoundError(f"No data found for material '{material_name}' in 'nk_data/' folder (library database).") - + # Load the data from the CSV file try: # Load the CSV data as a JAX array (important for using JAX's functionality, like automatic differentiation) - data = jnp.asarray(pd.read_csv(file_path, skiprows=1, header=None).values) + data = jnp.asarray(pd.read_csv(file_path, skiprows=1, header=None).values) except Exception as e: # If an error occurs during file reading or conversion, raise an IOError raise IOError(f"An error occurred while loading data for '{material_name}': {e}") - + # Check if the file is empty or doesn't contain valid data - if data.size == 0: + if data.size == 0: # Raise an error if the data array is empty or incorrectly formatted raise ValueError(f"The file for material '{material_name}' is empty or not in the expected format.") - + + return data # Return the loaded data as a JAX array + + +def load_nk_data_numpy(material_name: str = '') -> Union[jnp.ndarray, None]: + """ + Load the refractive index (n) and extinction coefficient (k) data for a given material: (n + 1j * k). + + This function fetches wavelength-dependent refractive index (n) and extinction coefficient (k) + data for a specified material. The data is read from a CSV file located in the 'nk_data/' directory. + The CSV file should be named after the material, e.g., 'Si.csv', and include three columns: wavelength (in micrometers), + refractive index (n), and extinction coefficient (k). These parameters are crucial for optical simulations, + allowing the user to work with materials' optical properties over a range of wavelengths. + + Args: + material_name (str): The name of the material for which the data is to be loaded. + This must not be an empty string, and the corresponding CSV file + must exist in the 'nk_data/' directory. + + Returns: + jnp.ndarray: A 2D array containing the wavelength (first column), + refractive index (n) (second column), and extinction coefficient (k) (third column). + Each row corresponds to a different wavelength. + + None: If the function fails due to any raised exception or if the CSV file is empty, + it will return None. + + Raises: + ValueError: If the material name is an empty string. + FileNotFoundError: If the file for the given material does not exist in the 'nk_data/' folder. + IOError: If there's an issue reading or parsing the file. + """ + # Check that the material name is not an empty string + if not material_name: + raise ValueError("Material name cannot be an empty string.") # Raise an error if no material is provided + + # Construct the file path and check if the file exists + file_path = os.path.join(nk_data_dir, f'numpy/{material_name}.npy') # Create the full path to the file + if not os.path.exists(file_path): + # Raise an error if the file for the material does not exist + raise FileNotFoundError(f"No data found for material '{material_name}' in 'nk_data/numpy/' folder (library database).") + + # Load the data from the CSV file + try: + # Load the CSV data as a JAX array (important for using JAX's functionality, like automatic differentiation) + data = jnp.load(file_path) + + except Exception as e: + # If an error occurs during file reading or conversion, raise an IOError + raise IOError(f"An error occurred while loading data for '{material_name}': {e}") + + # Check if the file is empty or doesn't contain valid data + if data.size == 0: + # Raise an error if the data array is empty or incorrectly formatted + raise ValueError(f"The file for material '{material_name}' is empty or not in the expected format.") + return data # Return the loaded data as a JAX array def interpolate_1d(x: jnp.ndarray, y: jnp.ndarray) -> Callable[[float], float]: """ Creates a 1D linear interpolation function based on the provided x and y arrays. - + This function returns a callable that performs linear interpolation on the input data points (x, y). - Given an x value, it finds the corresponding y value by assuming a straight line between two closest points - in the x array and using the equation of the line. - + Given an x value, it finds the corresponding y value by assuming a straight line between two closest points + in the x array and using the equation of the line. + Args: x (jnp.ndarray): Array of x values (independent variable). It must be sorted in ascending order. y (jnp.ndarray): Array of y values (dependent variable). It should have the same length as the x array. - + Returns: Callable[[float], float]: A function that, when provided with a single float x value, returns the corresponding interpolated float y value based on the linear interpolation. """ - - @jit # Just-In-Time compilation using JAX, speeds up the execution by compiling the function once. + def interpolate(x_val: float) -> float: # Find the index where x_val would fit in x to maintain the sorted order idx = jnp.searchsorted(x, x_val, side='right') - 1 # Ensure idx is within valid bounds (0 to len(x)-2) to avoid out-of-bounds errors idx = jnp.clip(idx, 0, x.shape[0] - 2) - + # Retrieve the two nearest x values, x_i and x_{i+1}, that surround x_val x_i, x_ip1 = x[idx], x[idx + 1] # Retrieve the corresponding y values, y_i and y_{i+1}, at those x positions y_i, y_ip1 = y[idx], y[idx + 1] - + # Calculate the slope of the line between (x_i, y_i) and (x_{i+1}, y_{i+1}) slope = (y_ip1 - y_i) / (x_ip1 - x_i) - + # Interpolate the y value using the slope formula: y = y_i + slope * (x_val - x_i) return y_i + slope * (x_val - x_i) return interpolate # Return the interpolation function to be used later - +@lru_cache(maxsize=32) def interpolate_nk(material_name: str) -> Callable[[float], complex]: """ Load the nk data for a given material and return a callable function that computes @@ -113,17 +166,17 @@ def interpolate_nk(material_name: str) -> Callable[[float], complex]: material_name (str): Name of the material to load the nk data for. Returns: - Callable[[float], complex]: A function that takes a wavelength (in meters) and + Callable[[float], complex]: A function that takes a wavelength (in meters) and returns the complex refractive index. """ - nk_data = load_nk_data(material_name) # Load the nk data for the specified material - wavelength, refractive_index, extinction_coefficient = nk_data.T # Transpose to get columns as variables + nk_data = load_nk_data_numpy(material_name) # Load the nk data for the specified material + wavelength, refractive_index, extinction_coefficient = nk_data[0,:], nk_data[1,:], nk_data[2,:] # Transpose to get columns as variables # Interpolate refractive index and extinction coefficient compute_refractive_index = interpolate_1d(wavelength * 1e-6, refractive_index) # Convert wavelength to meters for interpolation compute_extinction_coefficient = interpolate_1d(wavelength * 1e-6, extinction_coefficient) # Convert wavelength to meters for interpolation - @jit # Just-in-time compile the function to optimize performance + def compute_nk(wavelength: float) -> complex: """ Compute the complex refractive index for a given wavelength. @@ -132,7 +185,7 @@ def compute_nk(wavelength: float) -> complex: wavelength (float): Wavelength in meters. Returns: - complex: The complex refractive index, n + i*k, where n is the refractive index + complex: The complex refractive index, n + i*k, where n is the refractive index and k is the extinction coefficient. """ n = compute_refractive_index(wavelength) # Get the refractive index at the given wavelength @@ -159,7 +212,7 @@ def add_material_to_nk_database(wavelength_arr, refractive_index_arr, extinction TypeError: If any of the input arrays are not of type jax.numpy.ndarray. ValueError: If the input arrays have different lengths or if the material name is empty. """ - + # Validate input types # Check if all input arrays are of type jax.numpy.ndarray if not all(isinstance(arr, jnp.ndarray) for arr in [wavelength_arr, refractive_index_arr, extinction_coeff_arr]): @@ -194,11 +247,155 @@ def add_material_to_nk_database(wavelength_arr, refractive_index_arr, extinction # Construct the file path # Create a file path for saving the data based on the material name path = os.path.join(nk_data_dir, f'{material_name}.csv') - + # Save the file with a header # Convert the jax.numpy array to a numpy array for file saving and write to CSV np.savetxt(path, np.asarray(data), delimiter=',', header='wavelength_in_um,n,k', comments='') - + # Provide feedback on file creation # Inform the user whether the file was created or recreated successfully - print(f"'{os.path.basename(path)}' {'recreated' if os.path.exists(path) else 'created'} successfully.") \ No newline at end of file + print(f"'{os.path.basename(path)}' {'recreated' if os.path.exists(path) else 'created'} successfully.") + + +def visualize_material_properties(material_name = '', logX = False, logY = False, eV = False, savefig = False, save_path = None): + # Load the data from the .csv file + data = np.array(load_nk_data_csv(material_name)) + # Unpack the columns: wavelength, refractive index, extinction coefficient + wavelength, refractive_index, extinction_coeff = data.T # wavelength is in um + # Custom chart specs + if eV: + eV_arr = 1239.8/(wavelength*1e3) # E(eV) = 1239.8 / wavelength (nm) + # Creating plot for refractive_index + fig, ax1 = plt.subplots(figsize=(10, 6)) + color_n = 'navy' + ax1.set_ylabel('Refractive Index (n)', color=color_n, fontsize=14, fontweight='bold') + if not eV: + ax1.set_xlabel('Wavelength (um)', fontsize=14, fontweight='bold') + ax1.plot(wavelength, refractive_index, color=color_n, linewidth=2, label='Refractive Index (n)') + else: + ax1.set_xlabel('Photon energy (eV)', fontsize=14, fontweight='bold') + ax1.plot(eV_arr, refractive_index, color=color_n, linewidth=2, label='Refractive Index (n)') + ax1.tick_params(axis='y', labelcolor=color_n, labelsize=12) + ax1.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7) + # Creating a second y-axis for the extinction coefficient (k) + ax2 = ax1.twinx() + color_k = 'crimson' + ax2.set_ylabel('Extinction Coefficient (k)', color=color_k, fontsize=14, fontweight='bold') + if not eV: + ax2.plot(wavelength, extinction_coeff, color=color_k, linewidth=2, linestyle='-', label='Extinction Coefficient (k)') + else: + ax2.plot(eV_arr, extinction_coeff, color=color_k, linewidth=2, linestyle='-', label='Extinction Coefficient (k)') + ax2.tick_params(axis='y', labelcolor=color_k, labelsize=12) + if logX: + # Set the x-axis to logarithmic scale + plt.xscale('log') + if logY: + # Set the y-axis to logarithmic scale + plt.yscale('log') + # Adding title + plt.title(f'Refractive Index (n) and Extinction Coefficient (k) vs Wavelength for {material_name}', fontsize=16, fontweight='bold', pad=20) + fig.tight_layout() + # Save the figure as a PNG if savefig True + if savefig: + # Check that save_path is not an empty string or None + if not save_path: + raise ValueError("save_path cannot be an empty string or None") + # Ensure the save directory exists + os.makedirs(save_path, exist_ok=True) + # Construct the full save path with filename + full_save_path = os.path.join(save_path, f'{material_name}_nk_plot.png') + # Save the figure + plt.savefig(full_save_path, dpi=300) + print(f"Figure saved successfully at: {full_save_path}") + plt.show() + +def common_wavelength_band(material_list: List[str]) -> Tuple[float, float]: + """ + Compute the common wavelength band across a list of materials based on their n-k data. + + Args: + ---------- + material_list : Optional[List[str]] + A list of material names for which the common wavelength band is to be calculated. + + Returns: + ------- + Optional[Tuple[float, float]] + A tuple containing the minimum and maximum wavelength of the common band. + Returns None if no common wavelength band exists. + + Raises: + ------ + ValueError: + If the material_list is empty or None. + """ + if not material_list: + raise ValueError("Material list cannot be empty or None.") + + # Initialize wavelength bounds + min_wavelength = -jnp.inf + max_wavelength = jnp.inf + + # Iterate through each material's wavelength range + for material_name in material_list: + wavelength_arr = load_nk_data_csv(material_name)[:, 0] + material_min, material_max = jnp.min(wavelength_arr), jnp.max(wavelength_arr) + + # Update the min_wavelength and max_wavelength to find the common range + min_wavelength = jnp.maximum(min_wavelength, material_min) + max_wavelength = jnp.minimum(max_wavelength, material_max) + + # Early exit if no common range is possible + if min_wavelength > max_wavelength: + return None + + return min_wavelength, max_wavelength + + +def calculate_chromatic_dispersion(material_name: str) -> jnp.ndarray: + """ + Calculate the chromatic dispersion, which is the derivative of the refractive index + with respect to wavelength. + + Args: + material_name (str): Name of the material. + + Returns: + jnp.ndarray: Array containing the chromatic dispersion (d n / d wavelength). + """ + # Fetch the nk data for the material + nk_data = load_nk_data_csv(material_name) + + # Unpack the columns: wavelength, refractive index, extinction coefficient + wavelength, refractive_index, _ = nk_data.T # nk_data.T transposes the matrix to easily unpack columns + + # Define a function to compute the refractive index as a function of wavelength + def n_func(wl: jnp.ndarray) -> jnp.ndarray: + return jnp.interp(wl, wavelength, refractive_index) + + # Compute the derivative of the refractive index function with respect to wavelength + dn_dw = vmap(grad(n_func))(wavelength) + + return dn_dw + +def get_max_absorption_wavelength(material_name: str) -> float: + """ + Calculate the wavelength at which the absorption coefficient is maximized. + + Args: + material_name (str): Name of the material. + + Returns: + float: Wavelength (in μm) corresponding to the maximum absorption coefficient. + """ + # Fetch the nk data for the material + data = load_nk_data_csv(material_name) + # Unpack the columns: wavelength, refractive index (not used), extinction coefficient + wavelength, _, k = data.T # data.T transposes the matrix to easily unpack columns + # Calculate the absorption coefficient: α(λ) = 4 * π * k / λ + absorption_coefficient = 4 * jnp.pi * k / wavelength + # Identify the index of the maximum absorption coefficient + max_absorption_index = jnp.argmax(absorption_coefficient) + + # Return the wavelength corresponding to the maximum absorption + return float(wavelength[max_absorption_index]) \ No newline at end of file diff --git a/tmmax/fresnel.py b/tmmax/fresnel.py index de6bd83..15519e3 100644 --- a/tmmax/fresnel.py +++ b/tmmax/fresnel.py @@ -1,24 +1,26 @@ -from typing import Union, Tuple import jax.numpy as jnp +from jax import Array +from jax import jit +from jax.typing import ArrayLike -def _fresnel_s(_first_layer_n: Union[float, jnp.ndarray], - _second_layer_n: Union[float, jnp.ndarray], - _first_layer_theta: Union[float, jnp.ndarray], - _second_layer_theta: Union[float, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: +def fresnel_s(first_layer_n: ArrayLike, + second_layer_n: ArrayLike, + first_layer_theta: ArrayLike, + second_layer_theta: ArrayLike) -> Array: """ - This function calculates the Fresnel reflection (r_s) and transmission (t_s) coefficients - for s-polarized light (electric field perpendicular to the plane of incidence) at the interface - between two materials. The inputs are the refractive indices and the angles of incidence and + This function calculates the Fresnel reflection (r_s) and transmission (t_s) coefficients + for s-polarized light (electric field perpendicular to the plane of incidence) at the interface + between two materials. The inputs are the refractive indices and the angles of incidence and refraction for the two layers. Args: - _first_layer_n (Union[float, jnp.ndarray]): Refractive index of the first layer (incident medium). + _first_layer_n (Union[float, jnp.ndarray]): Refractive index of the first layer (incident medium). Can be a float or an array if computing for multiple incident angles/materials. - _second_layer_n (Union[float, jnp.ndarray]): Refractive index of the second layer (transmitted medium). + _second_layer_n (Union[float, jnp.ndarray]): Refractive index of the second layer (transmitted medium). Similar to the first argument, this can also be a float or an array. - _first_layer_theta (Union[float, jnp.ndarray]): Angle of incidence in the first layer (in radians). + _first_layer_theta (Union[float, jnp.ndarray]): Angle of incidence in the first layer (in radians). Can be a float or an array. - _second_layer_theta (Union[float, jnp.ndarray]): Angle of refraction in the second layer (in radians). + _second_layer_theta (Union[float, jnp.ndarray]): Angle of refraction in the second layer (in radians). Can be a float or an array. Returns: @@ -26,29 +28,32 @@ def _fresnel_s(_first_layer_n: Union[float, jnp.ndarray], - r_s: The Fresnel reflection coefficient for s-polarized light. - t_s: The Fresnel transmission coefficient for s-polarized light. """ - + cos_first_theta = jnp.cos(first_layer_theta) + cos_second_theta = jnp.cos(second_layer_theta) + first_ncostheta = jnp.multiply(first_layer_n, cos_first_theta) + second_ncostheta = jnp.multiply(second_layer_n, cos_second_theta) + add_ncosthetas = jnp.add(first_ncostheta, second_ncostheta) # Calculate the reflection coefficient for s-polarized light using Fresnel's equations. # The formula: r_s = (n1 * cos(theta1) - n2 * cos(theta2)) / (n1 * cos(theta1) + n2 * cos(theta2)) # This measures how much of the light is reflected at the interface. - r_s = ((_first_layer_n * jnp.cos(_first_layer_theta) - _second_layer_n * jnp.cos(_second_layer_theta)) / - (_first_layer_n * jnp.cos(_first_layer_theta) + _second_layer_n * jnp.cos(_second_layer_theta))) - + + r_s = jnp.true_divide(jnp.subtract(first_ncostheta, second_ncostheta), add_ncosthetas) + # Calculate the transmission coefficient for s-polarized light using Fresnel's equations. # The formula: t_s = 2 * n1 * cos(theta1) / (n1 * cos(theta1) + n2 * cos(theta2)) # This measures how much of the light is transmitted through the interface. - t_s = (2 * _first_layer_n * jnp.cos(_first_layer_theta) / - (_first_layer_n * jnp.cos(_first_layer_theta) + _second_layer_n * jnp.cos(_second_layer_theta))) - + t_s = jnp.true_divide(jnp.multiply(2,first_ncostheta),add_ncosthetas) + # Return the reflection and transmission coefficients as a JAX array - return jnp.array([r_s, t_s]) + return jnp.stack([r_s, t_s]) -def _fresnel_p(_first_layer_n: Union[float, jnp.ndarray], - _second_layer_n: Union[float, jnp.ndarray], - _first_layer_theta: Union[float, jnp.ndarray], - _second_layer_theta: Union[float, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: +def fresnel_p(first_layer_n: ArrayLike, + second_layer_n: ArrayLike, + first_layer_theta: ArrayLike, + second_layer_theta: ArrayLike) -> Array: """ - This function calculates the Fresnel reflection (r_p) and transmission (t_p) coefficients + This function calculates the Fresnel reflection (r_p) and transmission (t_p) coefficients for p-polarized light at the interface between two different media. It uses the refractive indices of the two media (_first_layer_n and _second_layer_n) and the incident and transmitted angles (_first_layer_theta and _second_layer_theta) to compute these values. @@ -64,19 +69,22 @@ def _fresnel_p(_first_layer_n: Union[float, jnp.ndarray], - r_p: The reflection coefficient for p-polarized light. - t_p: The transmission coefficient for p-polarized light. """ - + cos_first_theta = jnp.cos(first_layer_theta) + cos_second_theta = jnp.cos(second_layer_theta) + second_n_first_costheta = jnp.multiply(second_layer_n, cos_first_theta) + first_n_second_costheta = jnp.multiply(first_layer_n, cos_second_theta) + add_ncosthetas = jnp.add(second_n_first_costheta, first_n_second_costheta) # Calculate the reflection coefficient for p-polarized light (r_p) - # This equation is based on the Fresnel equations for p-polarization, where + # This equation is based on the Fresnel equations for p-polarization, where # r_p is the ratio of the reflected and incident electric field amplitudes for p-polarized light. - r_p = ((_second_layer_n * jnp.cos(_first_layer_theta) - _first_layer_n * jnp.cos(_second_layer_theta)) / - (_second_layer_n * jnp.cos(_first_layer_theta) + _first_layer_n * jnp.cos(_second_layer_theta))) + r_p = jnp.true_divide(jnp.subtract(second_n_first_costheta, first_n_second_costheta), add_ncosthetas) + # Calculate the transmission coefficient for p-polarized light (t_p) # This equation is also derived from the Fresnel equations for p-polarization. # t_p represents the ratio of the transmitted and incident electric field amplitudes. - t_p = (2 * _first_layer_n * jnp.cos(_first_layer_theta) / - (_second_layer_n * jnp.cos(_first_layer_theta) + _first_layer_n * jnp.cos(_second_layer_theta))) + t_p = jnp.true_divide(jnp.multiply(2,jnp.multiply(first_layer_n, cos_first_theta)),add_ncosthetas) # Return the reflection and transmission coefficients as a tuple of jnp arrays # Both r_p and t_p are essential for understanding how light interacts with different layers. - return jnp.array([r_p, t_p]) \ No newline at end of file + return jnp.stack([r_p, t_p]) \ No newline at end of file diff --git a/tmmax/nk_data/ALON.csv b/tmmax/nk_data/csv/ALON.csv similarity index 100% rename from tmmax/nk_data/ALON.csv rename to tmmax/nk_data/csv/ALON.csv diff --git a/tmmax/nk_data/Air.csv b/tmmax/nk_data/csv/Air.csv similarity index 100% rename from tmmax/nk_data/Air.csv rename to tmmax/nk_data/csv/Air.csv diff --git a/tmmax/nk_data/Al2O3.csv b/tmmax/nk_data/csv/Al2O3.csv similarity index 100% rename from tmmax/nk_data/Al2O3.csv rename to tmmax/nk_data/csv/Al2O3.csv diff --git a/tmmax/nk_data/CaF2.csv b/tmmax/nk_data/csv/CaF2.csv similarity index 100% rename from tmmax/nk_data/CaF2.csv rename to tmmax/nk_data/csv/CaF2.csv diff --git a/tmmax/nk_data/CdS.csv b/tmmax/nk_data/csv/CdS.csv similarity index 100% rename from tmmax/nk_data/CdS.csv rename to tmmax/nk_data/csv/CdS.csv diff --git a/tmmax/nk_data/CdTe.csv b/tmmax/nk_data/csv/CdTe.csv similarity index 100% rename from tmmax/nk_data/CdTe.csv rename to tmmax/nk_data/csv/CdTe.csv diff --git a/tmmax/nk_data/CeF3.csv b/tmmax/nk_data/csv/CeF3.csv similarity index 100% rename from tmmax/nk_data/CeF3.csv rename to tmmax/nk_data/csv/CeF3.csv diff --git a/tmmax/nk_data/Ge.csv b/tmmax/nk_data/csv/Ge.csv similarity index 100% rename from tmmax/nk_data/Ge.csv rename to tmmax/nk_data/csv/Ge.csv diff --git a/tmmax/nk_data/HfO2.csv b/tmmax/nk_data/csv/HfO2.csv similarity index 100% rename from tmmax/nk_data/HfO2.csv rename to tmmax/nk_data/csv/HfO2.csv diff --git a/tmmax/nk_data/LaF3.csv b/tmmax/nk_data/csv/LaF3.csv similarity index 100% rename from tmmax/nk_data/LaF3.csv rename to tmmax/nk_data/csv/LaF3.csv diff --git a/tmmax/nk_data/LiF.csv b/tmmax/nk_data/csv/LiF.csv similarity index 100% rename from tmmax/nk_data/LiF.csv rename to tmmax/nk_data/csv/LiF.csv diff --git a/tmmax/nk_data/MgF2.csv b/tmmax/nk_data/csv/MgF2.csv similarity index 100% rename from tmmax/nk_data/MgF2.csv rename to tmmax/nk_data/csv/MgF2.csv diff --git a/tmmax/nk_data/MgO.csv b/tmmax/nk_data/csv/MgO.csv similarity index 100% rename from tmmax/nk_data/MgO.csv rename to tmmax/nk_data/csv/MgO.csv diff --git a/tmmax/nk_data/NaF.csv b/tmmax/nk_data/csv/NaF.csv similarity index 100% rename from tmmax/nk_data/NaF.csv rename to tmmax/nk_data/csv/NaF.csv diff --git a/tmmax/nk_data/Nb2O5.csv b/tmmax/nk_data/csv/Nb2O5.csv similarity index 100% rename from tmmax/nk_data/Nb2O5.csv rename to tmmax/nk_data/csv/Nb2O5.csv diff --git a/tmmax/nk_data/PbF2.csv b/tmmax/nk_data/csv/PbF2.csv similarity index 100% rename from tmmax/nk_data/PbF2.csv rename to tmmax/nk_data/csv/PbF2.csv diff --git a/tmmax/nk_data/Sc2O3.csv b/tmmax/nk_data/csv/Sc2O3.csv similarity index 100% rename from tmmax/nk_data/Sc2O3.csv rename to tmmax/nk_data/csv/Sc2O3.csv diff --git a/tmmax/nk_data/Si.csv b/tmmax/nk_data/csv/Si.csv similarity index 100% rename from tmmax/nk_data/Si.csv rename to tmmax/nk_data/csv/Si.csv diff --git a/tmmax/nk_data/Si3N4.csv b/tmmax/nk_data/csv/Si3N4.csv similarity index 100% rename from tmmax/nk_data/Si3N4.csv rename to tmmax/nk_data/csv/Si3N4.csv diff --git a/tmmax/nk_data/SiO.csv b/tmmax/nk_data/csv/SiO.csv similarity index 100% rename from tmmax/nk_data/SiO.csv rename to tmmax/nk_data/csv/SiO.csv diff --git a/tmmax/nk_data/SiO2.csv b/tmmax/nk_data/csv/SiO2.csv similarity index 100% rename from tmmax/nk_data/SiO2.csv rename to tmmax/nk_data/csv/SiO2.csv diff --git a/tmmax/nk_data/SrF2.csv b/tmmax/nk_data/csv/SrF2.csv similarity index 100% rename from tmmax/nk_data/SrF2.csv rename to tmmax/nk_data/csv/SrF2.csv diff --git a/tmmax/nk_data/Ta2O5.csv b/tmmax/nk_data/csv/Ta2O5.csv similarity index 100% rename from tmmax/nk_data/Ta2O5.csv rename to tmmax/nk_data/csv/Ta2O5.csv diff --git a/tmmax/nk_data/Te.csv b/tmmax/nk_data/csv/Te.csv similarity index 100% rename from tmmax/nk_data/Te.csv rename to tmmax/nk_data/csv/Te.csv diff --git a/tmmax/nk_data/TiO2.csv b/tmmax/nk_data/csv/TiO2.csv similarity index 100% rename from tmmax/nk_data/TiO2.csv rename to tmmax/nk_data/csv/TiO2.csv diff --git a/tmmax/nk_data/Y2O3.csv b/tmmax/nk_data/csv/Y2O3.csv similarity index 100% rename from tmmax/nk_data/Y2O3.csv rename to tmmax/nk_data/csv/Y2O3.csv diff --git a/tmmax/nk_data/ZnS.csv b/tmmax/nk_data/csv/ZnS.csv similarity index 100% rename from tmmax/nk_data/ZnS.csv rename to tmmax/nk_data/csv/ZnS.csv diff --git a/tmmax/nk_data/ZnSe.csv b/tmmax/nk_data/csv/ZnSe.csv similarity index 100% rename from tmmax/nk_data/ZnSe.csv rename to tmmax/nk_data/csv/ZnSe.csv diff --git a/tmmax/nk_data/ZrO2.csv b/tmmax/nk_data/csv/ZrO2.csv similarity index 100% rename from tmmax/nk_data/ZrO2.csv rename to tmmax/nk_data/csv/ZrO2.csv diff --git a/tmmax/nk_data/numpy/ALON.npy b/tmmax/nk_data/numpy/ALON.npy new file mode 100644 index 0000000..3ec725c Binary files /dev/null and b/tmmax/nk_data/numpy/ALON.npy differ diff --git a/tmmax/nk_data/numpy/Air.npy b/tmmax/nk_data/numpy/Air.npy new file mode 100644 index 0000000..360aae2 Binary files /dev/null and b/tmmax/nk_data/numpy/Air.npy differ diff --git a/tmmax/nk_data/numpy/Al2O3.npy b/tmmax/nk_data/numpy/Al2O3.npy new file mode 100644 index 0000000..d41e9f5 Binary files /dev/null and b/tmmax/nk_data/numpy/Al2O3.npy differ diff --git a/tmmax/nk_data/numpy/CaF2.npy b/tmmax/nk_data/numpy/CaF2.npy new file mode 100644 index 0000000..23eec20 Binary files /dev/null and b/tmmax/nk_data/numpy/CaF2.npy differ diff --git a/tmmax/nk_data/numpy/CdS.npy b/tmmax/nk_data/numpy/CdS.npy new file mode 100644 index 0000000..3621d0e Binary files /dev/null and b/tmmax/nk_data/numpy/CdS.npy differ diff --git a/tmmax/nk_data/numpy/CdTe.npy b/tmmax/nk_data/numpy/CdTe.npy new file mode 100644 index 0000000..3b0aa2f Binary files /dev/null and b/tmmax/nk_data/numpy/CdTe.npy differ diff --git a/tmmax/nk_data/numpy/CeF3.npy b/tmmax/nk_data/numpy/CeF3.npy new file mode 100644 index 0000000..7b7f46c Binary files /dev/null and b/tmmax/nk_data/numpy/CeF3.npy differ diff --git a/tmmax/nk_data/numpy/Ge.npy b/tmmax/nk_data/numpy/Ge.npy new file mode 100644 index 0000000..3c9c3ed Binary files /dev/null and b/tmmax/nk_data/numpy/Ge.npy differ diff --git a/tmmax/nk_data/numpy/HfO2.npy b/tmmax/nk_data/numpy/HfO2.npy new file mode 100644 index 0000000..4986baf Binary files /dev/null and b/tmmax/nk_data/numpy/HfO2.npy differ diff --git a/tmmax/nk_data/numpy/LaF3.npy b/tmmax/nk_data/numpy/LaF3.npy new file mode 100644 index 0000000..82b5cec Binary files /dev/null and b/tmmax/nk_data/numpy/LaF3.npy differ diff --git a/tmmax/nk_data/numpy/LiF.npy b/tmmax/nk_data/numpy/LiF.npy new file mode 100644 index 0000000..74b2508 Binary files /dev/null and b/tmmax/nk_data/numpy/LiF.npy differ diff --git a/tmmax/nk_data/numpy/MgF2.npy b/tmmax/nk_data/numpy/MgF2.npy new file mode 100644 index 0000000..cba5467 Binary files /dev/null and b/tmmax/nk_data/numpy/MgF2.npy differ diff --git a/tmmax/nk_data/numpy/MgO.npy b/tmmax/nk_data/numpy/MgO.npy new file mode 100644 index 0000000..ce64aac Binary files /dev/null and b/tmmax/nk_data/numpy/MgO.npy differ diff --git a/tmmax/nk_data/numpy/NaF.npy b/tmmax/nk_data/numpy/NaF.npy new file mode 100644 index 0000000..0d6425f Binary files /dev/null and b/tmmax/nk_data/numpy/NaF.npy differ diff --git a/tmmax/nk_data/numpy/Nb2O5.npy b/tmmax/nk_data/numpy/Nb2O5.npy new file mode 100644 index 0000000..c0b0016 Binary files /dev/null and b/tmmax/nk_data/numpy/Nb2O5.npy differ diff --git a/tmmax/nk_data/numpy/PbF2.npy b/tmmax/nk_data/numpy/PbF2.npy new file mode 100644 index 0000000..8a49a9d Binary files /dev/null and b/tmmax/nk_data/numpy/PbF2.npy differ diff --git a/tmmax/nk_data/numpy/Sc2O3.npy b/tmmax/nk_data/numpy/Sc2O3.npy new file mode 100644 index 0000000..53d3acf Binary files /dev/null and b/tmmax/nk_data/numpy/Sc2O3.npy differ diff --git a/tmmax/nk_data/numpy/Si.npy b/tmmax/nk_data/numpy/Si.npy new file mode 100644 index 0000000..583faa3 Binary files /dev/null and b/tmmax/nk_data/numpy/Si.npy differ diff --git a/tmmax/nk_data/numpy/Si3N4.npy b/tmmax/nk_data/numpy/Si3N4.npy new file mode 100644 index 0000000..60b4bb4 Binary files /dev/null and b/tmmax/nk_data/numpy/Si3N4.npy differ diff --git a/tmmax/nk_data/numpy/SiO.npy b/tmmax/nk_data/numpy/SiO.npy new file mode 100644 index 0000000..788f317 Binary files /dev/null and b/tmmax/nk_data/numpy/SiO.npy differ diff --git a/tmmax/nk_data/numpy/SiO2.npy b/tmmax/nk_data/numpy/SiO2.npy new file mode 100644 index 0000000..40430da Binary files /dev/null and b/tmmax/nk_data/numpy/SiO2.npy differ diff --git a/tmmax/nk_data/numpy/SrF2.npy b/tmmax/nk_data/numpy/SrF2.npy new file mode 100644 index 0000000..9840000 Binary files /dev/null and b/tmmax/nk_data/numpy/SrF2.npy differ diff --git a/tmmax/nk_data/numpy/Ta2O5.npy b/tmmax/nk_data/numpy/Ta2O5.npy new file mode 100644 index 0000000..0f12b49 Binary files /dev/null and b/tmmax/nk_data/numpy/Ta2O5.npy differ diff --git a/tmmax/nk_data/numpy/Te.npy b/tmmax/nk_data/numpy/Te.npy new file mode 100644 index 0000000..da92ad9 Binary files /dev/null and b/tmmax/nk_data/numpy/Te.npy differ diff --git a/tmmax/nk_data/numpy/TiO2.npy b/tmmax/nk_data/numpy/TiO2.npy new file mode 100644 index 0000000..70ceb90 Binary files /dev/null and b/tmmax/nk_data/numpy/TiO2.npy differ diff --git a/tmmax/nk_data/numpy/Y2O3.npy b/tmmax/nk_data/numpy/Y2O3.npy new file mode 100644 index 0000000..e9cdb63 Binary files /dev/null and b/tmmax/nk_data/numpy/Y2O3.npy differ diff --git a/tmmax/nk_data/numpy/ZnS.npy b/tmmax/nk_data/numpy/ZnS.npy new file mode 100644 index 0000000..cc6ef56 Binary files /dev/null and b/tmmax/nk_data/numpy/ZnS.npy differ diff --git a/tmmax/nk_data/numpy/ZnSe.npy b/tmmax/nk_data/numpy/ZnSe.npy new file mode 100644 index 0000000..3b07e96 Binary files /dev/null and b/tmmax/nk_data/numpy/ZnSe.npy differ diff --git a/tmmax/nk_data/numpy/ZrO2.npy b/tmmax/nk_data/numpy/ZrO2.npy new file mode 100644 index 0000000..b73ce8c Binary files /dev/null and b/tmmax/nk_data/numpy/ZrO2.npy differ diff --git a/tmmax/reflect_transmit.py b/tmmax/reflect_transmit.py index 25bdcc8..33ebd28 100644 --- a/tmmax/reflect_transmit.py +++ b/tmmax/reflect_transmit.py @@ -1,241 +1,109 @@ +import jax.numpy as jnp +from jax import vmap import jax -jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX -import jax.numpy as jnp # Import JAX's version of NumPy for differentiable computations -from typing import Union, List - -from .fresnel import _fresnel_s, _fresnel_p - -def _compute_rt_at_interface_s(carry, concatenated_nk_list_theta): - """ - This function calculates the reflection (r) and transmission (t) coefficients - for s-polarization at the interface between two layers in a multilayer thin-film system. - It uses the Fresnel equations for s-polarized light. The function is designed to be used - in a JAX `lax.scan` loop, where it processes each interface iteratively. - - Args: - carry: A tuple containing the index (carry_idx) and a matrix (carry_values) - where the reflection and transmission coefficients will be stored. - - carry_idx (int): The current index, indicating which layer interface is being processed. - - carry_values (array): An array to store the r,t coefficients for each interface. - - concatenated_nk_list_theta: A tuple containing two arrays: - - stacked_nk_list (array): The refractive indices (n) of two consecutive layers at the interface. - - stacked_layer_angles (array): The angles of incidence for the two consecutive layers. - - Returns: - A tuple of: - - Updated carry: The new index and updated matrix with the calculated r,t coefficients. - - None: Required to match the JAX `lax.scan` interface, where a second argument is expected. - """ - - # Unpack the concatenated list into refractive index list and angle list - stacked_nk_list, stacked_layer_angles = concatenated_nk_list_theta - # `stacked_nk_list`: contains the refractive indices of two consecutive layers at the interface - # `stacked_layer_angles`: contains the angles of incidence for these two layers - - # Unpack the carry tuple - carry_idx, carry_values = carry - # `carry_idx`: current index in the process, starts from 0 and iterates over layer interfaces - # `carry_values`: the array that stores the reflection and transmission coefficients - - # Compute the reflection and transmission coefficients using the Fresnel equations for s-polarization - r_t_matrix = _fresnel_s(_first_layer_theta=stacked_layer_angles[0], # Incident angle of the first layer - _second_layer_theta=stacked_layer_angles[1], # Incident angle of the second layer - _first_layer_n=stacked_nk_list[0], # Refractive index of the first layer - _second_layer_n=stacked_nk_list[1]) # Refractive index of the second layer - # This line computes r and t coefficients between two consecutive layers - # based on their refractive indices and angles of incidence. - - # Store the computed r,t matrix in the `carry_values` array at the current index - carry_values = carry_values.at[carry_idx, :].set(r_t_matrix) # Set r,t coefficients at the current index - # The `carry_values.at[carry_idx, :].set(r_t_matrix)` updates the array at position `carry_idx` - # with the computed r,t coefficients. - - carry_idx = carry_idx + 1 # Move to the next index for the next iteration - # The carry index is incremented to process the next layer interface in subsequent iterations. - - # Return the updated carry (with new index and r,t coefficients) and None for lax.scan compatibility - return (carry_idx, carry_values), None - - -def _compute_rt_at_interface_p(carry, concatenated_nk_list_theta): - """ - This function computes the reflection and transmission (r, t) coefficients at the interface between two layers - for P-polarized light (parallel polarization). It uses the Fresnel equations to calculate these coefficients - based on the refractive indices and angles of incidence and refraction for the two layers. - - Args: - carry: A tuple (carry_idx, carry_values) where: - - carry_idx: The current index that keeps track of the layer. - - carry_values: A matrix to store the computed reflection and transmission coefficients. - - concatenated_nk_list_theta: A tuple (stacked_nk_list, stacked_layer_angles) where: - - stacked_nk_list: A list of refractive indices of the two consecutive layers. - - stacked_layer_angles: A list of angles of incidence and refraction at the interface between the layers. - - Returns: - A tuple: - - Updated carry containing: - - carry_idx incremented by 1. - - carry_values with the newly computed r, t coefficients at the current interface. - - None (This is used to maintain the structure of a functional-style loop but has no further use). - """ - - # Unpack the concatenated data into two variables: refractive indices (nk) and angles (theta) - stacked_nk_list, stacked_layer_angles = concatenated_nk_list_theta # Extract the refractive indices and angles from the input tuple - carry_idx, carry_values = carry # Unpack carry: carry_idx is the current index, carry_values stores r and t coefficients - - # Compute reflection (r) and transmission (t) coefficients at the interface using Fresnel equations for P-polarized light - r_t_matrix = _fresnel_p(_first_layer_theta = stacked_layer_angles[0], # Incident angle at the first layer - _second_layer_theta = stacked_layer_angles[1], # Refraction angle at the second layer - _first_layer_n = stacked_nk_list[0], # Refractive index of the first layer - _second_layer_n = stacked_nk_list[1]) # Refractive index of the second layer - - # Update carry_values by setting the r,t matrix at the current index (carry_idx) - carry_values = carry_values.at[carry_idx, :].set(r_t_matrix) # Store the computed r,t matrix into the carry_values at the index 'carry_idx' - - carry_idx = carry_idx + 1 # Move to the next index for further iterations - return (carry_idx, carry_values), None # Return the updated carry with incremented index and updated r,t values, and None as a placeholder - -def _compute_rt_one_wl(nk_list: jnp.ndarray, layer_angles: jnp.ndarray, - wavelength: Union[float, jnp.ndarray], polarization: bool) -> jnp.ndarray: - """ - Computes the reflectance and transmittance for a single wavelength - across multiple layers in a stack of materials. The computation - takes into account the refractive index of each layer, the angle of - incidence in each layer, the wavelength of the light, and the - polarization of the light. - - Args: - nk_list (jnp.ndarray): Array of complex refractive indices for each layer. - The shape should be (num_layers,). - layer_angles (jnp.ndarray): Array of angles of incidence for each layer. - The shape should be (num_layers,). - wavelength (float or jnp.ndarray): The wavelength of light, given as either - a scalar or a JAX array. - polarization (bool): Boolean flag that determines the polarization state of the light. - If False, s-polarization is used; if True, p-polarization is used. - - Returns: - jnp.ndarray: A 1D JAX array representing the reflectance and transmittance - coefficients at the specified wavelength and polarization. - """ - - # Initialize the state for `jax.lax.scan`. The first element (0) is a placeholder - # and won't be used. The second element is a 2D array of zeros to hold intermediate - # results, representing the reflectance and transmittance across layers. - init_state = (0, jnp.zeros((len(nk_list) - 2, 2), dtype=jnp.float32)) # Initial state with an array of zeros - # The shape of `jnp.zeros` is (num_layers - 2, 2) because we exclude the first - # and last layers, assuming they are boundary layers. - - # Stack the refractive indices (`nk_list`) for each adjacent pair of layers. - # This creates a new array where each element contains a pair of adjacent refractive indices - # from `nk_list`, which will be used to compute the reflection and transmission at the interface - # between these two layers. - stacked_nk_list = jnp.stack([nk_list[:-2], nk_list[1:-1]], axis=1) # Stack the original and shifted inputs for processing in pairs - # For example, if `nk_list` is [n1, n2, n3, n4], this will create pairs [(n1, n2), (n2, n3), (n3, n4)]. - - # Similarly, stack the angles for adjacent layers. - # The same logic applies to `layer_angles` as for `nk_list`. Each pair of adjacent layers - # will have an associated pair of angles. - stacked_layer_angles = jnp.stack([layer_angles[:-2], layer_angles[1:-1]], axis=1) - # This operation aligns the angles with the corresponding refractive indices. - - # Now we need to compute reflectance and transmittance for each interface. - # This can be done using `jax.lax.scan`, which efficiently loops over the stacked pairs - # of refractive indices and angles. - - # If the light is s-polarized (polarization = False), we call the function `_compute_rt_at_interface_s`. - # This function calculates the reflection and transmission coefficients specifically for s-polarized light. - if polarization == False: - rt_one_wl, _ = jax.lax.scan(_compute_rt_at_interface_s, init_state, (stacked_nk_list, stacked_layer_angles)) # s-polarization case - # `jax.lax.scan` applies the function `_compute_rt_at_interface_s` to each pair of adjacent layers - # along with the corresponding angles. It processes this in a loop, accumulating the results. - - # If the light is p-polarized (polarization = True), we use `_compute_rt_at_interface_p` instead. - # This function handles p-polarized light. - elif polarization == True: - rt_one_wl, _ = jax.lax.scan(_compute_rt_at_interface_p, init_state, (stacked_nk_list, stacked_layer_angles)) # p-polarization case - # The same process as above but with a function specific to p-polarized light. - - # Finally, return the computed reflectance and transmittance coefficients. - # The result is stored in `rt_one_wl[1]` (the second element of `rt_one_wl`), which corresponds - # to the reflectance and transmittance after all layers have been processed. - return rt_one_wl[1] # Return a 1D theta array for each layer - # This output is the desired result: the reflectance and transmittance for the given wavelength. - - -def _calculate_transmittace_from_coeff(t: Union[float, jnp.ndarray], - n_list_first: Union[complex, jnp.ndarray], - n_list_last: Union[complex, jnp.ndarray], - angle_of_incidence: Union[float, jnp.ndarray], - last_layer_angle: Union[complex, jnp.ndarray], - polarization: bool) -> jnp.ndarray: - - """ - Computes the transmittance for light passing through layers with potential polarization effects. - - Args: - t (float or jnp.ndarray): The transmission coefficient or array of coefficients. - n_list_first (complex or jnp.ndarray): The refractive index of the first layer or an array of indices. - n_list_last (complex or jnp.ndarray): The refractive index of the last layer or an array of indices. - angle_of_incidence (float or jnp.ndarray): The angle of incidence in radians or an array of angles. - last_layer_angle (complex or jnp.ndarray): The angle in the last layer, can be complex, in radians or array. - polarization (bool): Indicates if polarization effects should be considered (True) or not (False). - - Returns: - jnp.ndarray: The calculated transmittance, which takes into account the polarization if specified. - """ - - # Check if polarization effect should be considered - if not polarization: # If polarization is False - # Compute transmittance without polarization - return jnp.abs(t)**2 * ( # Square of the magnitude of t - jnp.real(n_list_last * jnp.cos(last_layer_angle)) / # Real part of (n_last * cos(last_layer_angle)) - jnp.real(n_list_first * jnp.cos(angle_of_incidence)) # Real part of (n_first * cos(angle_of_incidence)) - ) - else: # If polarization is True - # Compute transmittance considering polarization - return jnp.abs(t)**2 * ( # Square of the magnitude of t - jnp.real(n_list_last * jnp.conj(jnp.cos(last_layer_angle))) / # Real part of (n_last * conjugate(cos(last_layer_angle))) - jnp.real(n_list_first * jnp.conj(jnp.cos(angle_of_incidence))) # Real part of (n_first * conjugate(cos(angle_of_incidence))) - ) - - -def _create_phases_ts_rs(_trs: jnp.ndarray, _phases: jnp.ndarray) -> jnp.ndarray: +import numpy as np +import pickle +import sys +from jax import Array +from jax.typing import ArrayLike + +from .fresnel import fresnel_s, fresnel_p + +#@jit +def calculate_reflectance_from_coeff(r: ArrayLike) -> Array: + return jnp.square(jnp.abs(r)) + + +def calculate_transmittace_from_coeff_s_pol(t: ArrayLike, + nk_first_layer_of_slab: ArrayLike, + angle_first_layer_of_slab: ArrayLike, + nk_last_layer_of_slab: ArrayLike, + angle_last_layer_of_slab: ArrayLike) -> Array: + + T = jnp.multiply(jnp.square(jnp.abs(t)), jnp.true_divide(jnp.real(jnp.multiply(nk_last_layer_of_slab, jnp.cos(angle_last_layer_of_slab))), + jnp.real(jnp.multiply(nk_first_layer_of_slab, jnp.cos(angle_first_layer_of_slab))))) + return T + +def calculate_transmittace_from_coeff_p_pol(t: ArrayLike, + nk_first_layer_of_slab: ArrayLike, + angle_first_layer_of_slab: ArrayLike, + nk_last_layer_of_slab: ArrayLike, + angle_last_layer_of_slab: ArrayLike) -> Array: + + T = jnp.multiply(jnp.square(jnp.abs(t)), jnp.true_divide(jnp.real(jnp.multiply(nk_last_layer_of_slab, jnp.conj(jnp.cos(angle_last_layer_of_slab)))), + jnp.real(jnp.multiply(nk_first_layer_of_slab, jnp.conj(jnp.cos(angle_first_layer_of_slab)))))) + return T + + +#@jit +def calculate_transmittace_from_coeff(t: ArrayLike, + nk_first_layer_of_slab: ArrayLike, + angle_first_layer_of_slab: ArrayLike, + nk_last_layer_of_slab: ArrayLike, + angle_last_layer_of_slab: ArrayLike, + polarization: ArrayLike) -> Array: + + return jnp.select(condlist=[jnp.array_equal(polarization, jnp.array([0], dtype = jnp.int16)), + jnp.array_equal(polarization, jnp.array([1], dtype = jnp.int16))], + choicelist=[calculate_transmittace_from_coeff_s_pol(t, + nk_first_layer_of_slab, + angle_first_layer_of_slab, + nk_last_layer_of_slab, + angle_last_layer_of_slab), + calculate_transmittace_from_coeff_p_pol(t, + nk_first_layer_of_slab, + angle_first_layer_of_slab, + nk_last_layer_of_slab, + angle_last_layer_of_slab)]) + + +def compute_rt_at_interface_s(layer_idx: ArrayLike, + nk_angles_stack: ArrayLike) -> Array: + rt = fresnel_s(first_layer_n = nk_angles_stack.at[layer_idx,0].get(), + second_layer_n = nk_angles_stack.at[jnp.add(layer_idx, jnp.array([1], dtype = jnp.int32)), 0].get(), + first_layer_theta = nk_angles_stack.at[layer_idx, 1].get(), + second_layer_theta = nk_angles_stack.at[jnp.add(layer_idx, jnp.array([1], dtype = jnp.int32)), 1].get()) + #print("rt shape: ", jnp.shape(nk_angles_stack)) + return rt + +def compute_rt_at_interface_p(layer_idx: ArrayLike, + nk_angles_stack: ArrayLike) -> Array: + rt = fresnel_p(first_layer_n = nk_angles_stack.at[layer_idx,0].get(), + second_layer_n = nk_angles_stack.at[jnp.add(layer_idx, jnp.array([1], dtype = jnp.int32)), 0].get(), + first_layer_theta = nk_angles_stack.at[layer_idx, 1].get(), + second_layer_theta = nk_angles_stack.at[jnp.add(layer_idx, jnp.array([1], dtype = jnp.int32)), 1].get()) + return rt + + +def vectorized_rt_s_pol(): + return vmap(compute_rt_at_interface_s, (0, None)) + +def vectorized_rt_p_pol(): + return vmap(compute_rt_at_interface_p, (0, None)) + +def polarization_based_rt_selection(layer_indices: ArrayLike, nk_angles_stack: ArrayLike, polarization: ArrayLike) -> Array: + + return jnp.select(condlist=[jnp.array_equal(polarization, jnp.array([0], dtype = jnp.int16)), + jnp.array_equal(polarization, jnp.array([1], dtype = jnp.int16))], + choicelist=[vectorized_rt_s_pol()(layer_indices, nk_angles_stack), + vectorized_rt_p_pol()(layer_indices, nk_angles_stack)]) + +#@jit +def compute_rt(nk_list: ArrayLike, angles: ArrayLike, polarization: ArrayLike) -> Array: """ - Create a new array combining phase and ts values. - - Args: - _trs (jnp.ndarray): A 2D array of shape (N, 2) where N is the number of elements. - Each element is a pair of values [t, s]. - _phases (jnp.ndarray): A 1D array of shape (N,) containing phase values for each element. + Calculates the angles of incidence across layers for a set of refractive indices (nk_list_2d) + and an initial angle of incidence (initial_theta) using vectorization. Returns: - jnp.ndarray: A 2D array of shape (N, 3) where each row is [phase, t, s]. - The phase is from _phases, and t, s are from _trs. + jnp.ndarray: A 3D JAX array where the [i, j, :] entry represents the angles of incidence + for the j-th initial angle at the i-th wavelength. The size of the third dimension + corresponds to the number of layers. """ - - N = _phases.shape[0] # Get the number of elements (N) in the _phases array - - def process_element(i: int) -> List[float]: - """ - Process an individual element to create a list of phase and ts values. - - Args: - i (int): Index of the element to process. - - Returns: - List[float]: A list containing [phase, t, s] where: - - phase: The phase value from _phases at index i - - t: The first value of the pair in _trs at index i - - s: The second value of the pair in _trs at index i - """ - return [_phases[i], _trs[i][0], _trs[i][1]] # Return the phase and ts values as a list - - # Apply process_element function across all indices from 0 to N-1 - result = jax.vmap(process_element)(jnp.arange(N)) # jax.vmap vectorizes the process_element function - # to apply it across all indices efficiently - - return result # Return the result as a 2D array of shape (N, 3) - + #print(jnp.shape(nk_list)) + #print(jnp.shape(angles)) + nk_angles_stack = jnp.concat([jnp.expand_dims(nk_list, 1), jnp.expand_dims(angles, 1)], axis=1) + #print(jnp.shape(nk_angles_stack)) + stop_value = int(jnp.size(nk_list)) - 1 # Concrete integer + layer_indices = jnp.arange(stop_value, dtype=jnp.int32) + + return polarization_based_rt_selection(layer_indices, nk_angles_stack, polarization) diff --git a/tmmax/tmm.py b/tmmax/tmm.py index 5b3d624..09daf3b 100644 --- a/tmmax/tmm.py +++ b/tmmax/tmm.py @@ -1,48 +1,76 @@ import jax -jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX import jax.numpy as jnp # Import JAX's version of NumPy for differentiable computations -from jax import vmap -from typing import Union, List, Tuple, Text, Dict, Callable +from jax import jit, vmap -from .angle import _compute_layer_angles_single_wl_angle_point +from .angle import compute_layer_angles +from .wavevector import compute_kz from .cascaded_matmul import _cascaded_matrix_multiplication from .data import interpolate_nk from .reflect_transmit import _compute_rt_one_wl, _create_phases_ts_rs, _calculate_transmittace_from_coeff -def _compute_kz_single_wl_angle_point( - nk_list: jnp.ndarray, # Array of complex refractive indices for different wavelengths - layer_angles: Union[int, jnp.ndarray], # Angle of incidence for each layer, can be a single angle or an array - wavelength: Union[int, jnp.ndarray] # Wavelengths corresponding to the refractive indices, can be a single wavelength or an array -) -> jnp.ndarray: # Returns an array of computed kz values for each wavelength and angle - """ - Computes the z-component of the wave vector (kz) for a given set of refractive indices, layer angles, and wavelengths. - - Args: - nk_list (jnp.ndarray): A 1D array containing the refractive indices (n) or complex indices (n + ik) for different wavelengths. - layer_angles (Union[int, jnp.ndarray]): A scalar or 1D array specifying the angle of incidence for each layer. It should be in radians. - wavelength (Union[int, jnp.ndarray]): A scalar or 1D array of wavelengths corresponding to the refractive indices. +@jit +def tmm_single_wl_angle_point_jit(nk_list: ArrayLike,thickness_list: ArrayLike, + wavelength: ArrayLike, + angle_of_incidence: ArrayLike, + polarization: ArrayLike) -> Array: - Returns: - jnp.ndarray: A 1D array of computed kz values, which represents the z-component of the wave vector for each wavelength and angle. - """ - # Calculate the z-component of the wave vector for each wavelength and angle - return 2 * jnp.pi * nk_list * jnp.cos(layer_angles) / wavelength - # 2 * jnp.pi * nk_list: Scales the refractive index to account for wavelength in radians - # jnp.cos(layer_angles): Computes the cosine of the incident angle for each layer - # / wavelength: Divides by wavelength to get the wave vector component in the z-direction + layer_angles = compute_layer_angles(angle_of_incidence, nk_list, polarization) + # Compute the angles within each layer based on the refractive indices, incidence angle, and wavelength + #print("layer_angles ", layer_angles) + #print("layer_angles shape", jnp.shape(layer_angles)) + kz = compute_kz(nk_list, layer_angles, wavelength) + # Calculate the z-component of the wave vector for each layer + #print("kz", kz) + layer_phases = jnp.multiply(kz.at[1:-1].get(), thickness_list) + # Compute the phase shifts in each layer by multiplying kz by the layer thicknesses + # `jnp.pad(thickness_list, (1), constant_values=0)` adds a leading zero to the thickness_list + #print("layer_phases", layer_phases) + rt = jnp.squeeze(compute_rt(nk_list = nk_list, angles = layer_angles, polarization = polarization)) + # Compute the reflection and transmission matrices for the wavelength + #print("rt", rt) + #print("rt shape", jnp.shape(rt)) + tr_matrix = cascaded_matrix_multiplication(phases = layer_phases, rts = rt.at[1:,:].get()) + # Perform matrix multiplication to obtain the cascaded transfer matrix for the entire stack + #print("tr_matrix", tr_matrix) + tr_matrix = jnp.multiply(jnp.true_divide(1, rt.at[0,1].get()), jnp.dot(jnp.array([[1, rt.at[0,0].get()], [rt.at[0,0].get(), 1]]), tr_matrix)) + #print("tr_matrix", tr_matrix) + # Normalize the transfer matrix and include the boundary conditions + # `jnp.dot` multiplies the transfer matrix by the boundary conditions matrix + r = jnp.true_divide(tr_matrix.at[1,0].get(), tr_matrix.at[0,0].get()) + t = jnp.true_divide(1, tr_matrix.at[0,0].get()) + #print("r", r) + #print("t", t) + # Calculate the reflectance (r) and transmittance (t) from the transfer matrix + # Reflectance is obtained by dividing the (1, 0) element by the (0, 0) element + # Transmittance is obtained by taking the reciprocal of the (0, 0) element + + R = calculate_reflectance_from_coeff(r) + T = calculate_transmittace_from_coeff(t, nk_list.at[0].get(), angle_of_incidence, nk_list.at[-1].get(), layer_angles.at[-1].get(), polarization) -def _tmm_single_wl_angle_point(nk_functions: Dict[int, Callable], material_list: list[int], - thickness_list: jnp.ndarray, wavelength: Union[float, jnp.ndarray], - angle_of_incidence: Union[float, jnp.ndarray], polarization: bool) -> Tuple[jnp.ndarray, jnp.ndarray]: + #print("R", R) + #print("T", T) + # Compute the reflectance (R) and transmittance (T) using their respective formulas + # Reflectance R is the squared magnitude of r + # Transmittance T is calculated using a function `_calculate_transmittace_from_coeff` + return R, T + # Return the reflectance and transmittance values + + + + + +def tmm_single_wl_angle_point(nk_functions: Dict[int, Callable], material_list: ArrayLike, + thickness_list: ArrayLike, wavelength: ArrayLike, + angle_of_incidence: ArrayLike, polarization: ArrayLike) -> Array: """ Computes the reflectance (R) and transmittance (T) of a multi-layer optical film for a given wavelength and angle of incidence using the Transfer Matrix Method (TMM). Args: - nk_functions (Dict[int, Callable]): Dictionary mapping material indices to functions that return + nk_functions (Dict[int, Callable]): Dictionary mapping material indices to functions that return the complex refractive index (n + ik) for a given wavelength. material_list (list[int]): List of indices representing the order of materials in the stack. thickness_list (jnp.ndarray): Array of thicknesses for each layer in the stack. @@ -67,48 +95,11 @@ def get_nk_values(wl): return jnp.array([nk_functions[mat_idx](wl) for mat_idx in material_list]) # Get nk values for each material nk_list = get_nk_values(wavelength) # Call get_nk_values to get refractive index values for all materials - - layer_angles = _compute_layer_angles_single_wl_angle_point(nk_list, angle_of_incidence, wavelength, polarization) - # Compute the angles within each layer based on the refractive indices, incidence angle, and wavelength - - kz = _compute_kz_single_wl_angle_point(nk_list, layer_angles, wavelength) - # Calculate the z-component of the wave vector for each layer - - layer_phases = kz * jnp.pad(thickness_list, (1), constant_values=0) - # Compute the phase shifts in each layer by multiplying kz by the layer thicknesses - # `jnp.pad(thickness_list, (1), constant_values=0)` adds a leading zero to the thickness_list - - rt = _compute_rt_one_wl(nk_list=nk_list, layer_angles=layer_angles, wavelength=wavelength, polarization=polarization) - # Compute the reflection and transmission matrices for the wavelength - - _phases_ts_rs = _create_phases_ts_rs(rt[1:,:], layer_phases[1:-1]) - # Create a list of phase shift matrices from the transmission and reflection matrices and layer phases - # Exclude the first and last reflection matrix, and the first and last phase shift values - - tr_matrix = _cascaded_matrix_multiplication(_phases_ts_rs) - # Perform matrix multiplication to obtain the cascaded transfer matrix for the entire stack - - tr_matrix = (1 / rt[0, 1]) * jnp.dot(jnp.array([[1, rt[0, 0]], [rt[0, 0], 1]]), tr_matrix) - # Normalize the transfer matrix and include the boundary conditions - # `jnp.dot` multiplies the transfer matrix by the boundary conditions matrix - - r = tr_matrix[1, 0] / tr_matrix[0, 0] - t = 1 / tr_matrix[0, 0] - # Calculate the reflectance (r) and transmittance (t) from the transfer matrix - # Reflectance is obtained by dividing the (1, 0) element by the (0, 0) element - # Transmittance is obtained by taking the reciprocal of the (0, 0) element - - R = jnp.abs(r) ** 2 - T = _calculate_transmittace_from_coeff(t, nk_list[0], nk_list[-1], angle_of_incidence, layer_angles[-1], polarization) - # Compute the reflectance (R) and transmittance (T) using their respective formulas - # Reflectance R is the squared magnitude of r - # Transmittance T is calculated using a function `_calculate_transmittace_from_coeff` - + #print("nk_list", nk_list) + R, T = tmm_single_wl_angle_point_jit(nk_list, thickness_list, wavelength, angle_of_incidence, polarization) return R, T # Return the reflectance and transmittance values - - def tmm(material_list: List[str], thickness_list: jnp.ndarray, wavelength_arr: jnp.ndarray, @@ -116,7 +107,7 @@ def tmm(material_list: List[str], polarization: Text) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Perform the Transfer Matrix Method (TMM) for multilayer thin films. - + Args: material_list (List[str]): A list of material names. Each material is identified by a string. thickness_list (jnp.ndarray): An array of thicknesses corresponding to each layer. @@ -137,19 +128,21 @@ def tmm(material_list: List[str], # Create a dictionary of interpolation functions for each material nk_funkcs = {i: interpolate_nk(material) for i, material in enumerate(material_set)} # Interpolate n and k for each material - # Convert polarization type to a boolean flag if polarization == 's': - polarization = False # s-polarized light + # Unpolarized case: Return tuple (s-polarization, p-polarization) + polarization = jnp.array([0], dtype = jnp.int16) elif polarization == 'p': - polarization = True # p-polarized light + # s-polarization case + polarization = jnp.array([1], dtype = jnp.int16) else: raise TypeError("The polarization can be 's' or 'p', not the other parts. Correct it") # Raise an error for invalid polarization input + # Vectorize the _tmm_single_wl_angle_point function across wavelength and angle of incidence - tmm_vectorized = vmap(vmap(_tmm_single_wl_angle_point, (None, None, None, 0, None, None)), (None, None, None, None, 0, None)) # Vectorize _tmm_single_wl_angle_point over wavelengths and angles + tmm_vectorized = vmap(vmap(tmm_single_wl_angle_point, (None, None, None, 0, None, None)), (None, None, None, None, 0, None)) # Vectorize _tmm_single_wl_angle_point over wavelengths and angles # Apply the vectorized TMM function to the arrays result = tmm_vectorized(nk_funkcs, material_list, thickness_list, wavelength_arr, angle_of_incidences, polarization) # Compute the TMM results # Return the computed result - return result # Tuple of transmission and reflection coefficients + return result # Tuple of transmission and reflection coefficients \ No newline at end of file diff --git a/tmmax/wavevector.py b/tmmax/wavevector.py new file mode 100644 index 0000000..f762344 --- /dev/null +++ b/tmmax/wavevector.py @@ -0,0 +1,17 @@ +import jax.numpy as jnp +from jax import vmap +import jax +import numpy as np +import pickle +import sys +from jax import Array +from jax.typing import ArrayLike + +#@jit +def compute_kz(nk_list: ArrayLike, + angles: ArrayLike, + wavelength: ArrayLike) -> Array: + + kz = jnp.true_divide(jnp.multiply(jnp.multiply(jnp.array([2.0]), jnp.pi), jnp.multiply(nk_list, jnp.cos(angles))), wavelength) + + return kz \ No newline at end of file