Skip to content

Commit

Permalink
jitted tmm function and add npy
Browse files Browse the repository at this point in the history
  • Loading branch information
bahremsd committed Dec 18, 2024
1 parent 0ece266 commit 9731ae0
Show file tree
Hide file tree
Showing 66 changed files with 628 additions and 577 deletions.
2 changes: 1 addition & 1 deletion tmmax/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.2
1.0.0
286 changes: 132 additions & 154 deletions tmmax/angle.py

Large diffs are not rendered by default.

66 changes: 28 additions & 38 deletions tmmax/cascaded_matmul.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,37 @@
import jax
jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX
#jax.config.update('jax_enable_x64', True) # Ensure high precision (64-bit) is enabled in JAX
import jax.numpy as jnp # Import JAX's version of NumPy for differentiable computations

def matmul(carry, phase_r_t):

def _matmul(carry, phase_t_r):
"""
Multiplies two complex matrices in a sequence.

Args:
carry (jax.numpy.ndarray): The accumulated product of the matrices so far.
This is expected to be a 2x2 complex matrix.
phase_t_r (jax.numpy.ndarray): A 3-element array where:
- phase_t_r[0] represents the phase shift delta (a scalar).
- phase_t_r[1] represents the transmission coefficient t or T (a scalar).
- phase_t_r[2] represents the reflection coefficient r or R (a scalar).
transfer_matrix_00 = jnp.exp(jnp.multiply(jnp.array([-1j], dtype = jnp.complex64), phase_r_t.at[0].get()))
transfer_matrix_11 = jnp.exp(jnp.multiply(jnp.array([1j], dtype = jnp.complex64), phase_r_t.at[0].get()))
transfer_matrix_01 = jnp.multiply(phase_r_t.at[1].get(), transfer_matrix_00)
transfer_matrix_10 = jnp.multiply(phase_r_t.at[1].get(), transfer_matrix_11)

Returns:
jax.numpy.ndarray: The updated product after multiplying the carry matrix with the current matrix.
This is also a 2x2 complex matrix.
None: A placeholder required by jax.lax.scan for compatibility.
"""
# Create the diagonal phase matrix based on phase_t_r[0]
# This matrix introduces a phase shift based on the delta value
phase_matrix = jnp.array([[jnp.exp(-1j * phase_t_r[0]), 0], # Matrix with phase shift for the first entry
[0, jnp.exp(1j * phase_t_r[0])]]) # Matrix with phase shift for the second entry
transfer_matrix = jnp.multiply(jnp.true_divide(1, phase_r_t.at[2].get()), jnp.array([[transfer_matrix_00, transfer_matrix_01],
[transfer_matrix_10, transfer_matrix_11]]))

# Create the matrix based on phase_t_r[1] and phase_t_r[2]
# This matrix incorporates the transmission and reflection coefficients
transmission_reflection_matrix = jnp.array([[1, phase_t_r[1]], # Top row with transmission coefficient
[phase_t_r[1], 1]]) # Bottom row with transmission coefficient
result = jnp.matmul(carry, transfer_matrix)

# Compute the current matrix by multiplying the phase_matrix with the transmission_reflection_matrix
# The multiplication is scaled by 1/phase_t_r[2] to account for the reflection coefficient
mat = jnp.array(1 / phase_t_r[2]) * jnp.dot(phase_matrix, transmission_reflection_matrix)


# Multiply the accumulated carry matrix with the current matrix
# This updates the product with the new matrix
result = jnp.dot(carry, mat)
return jnp.squeeze(result), jnp.squeeze(result) # Return the updated matrix and None as a placeholder for jax.lax.scan

return result, None # Return the updated matrix and None as a placeholder for jax.lax.scan

#@jax.jit
def cascaded_matrix_multiplication(phases: ArrayLike, rts: ArrayLike) -> Array:
"""
Calculates the angles of incidence across layers for a set of refractive indices (nk_list_2d)
and an initial angle of incidence (initial_theta) using vectorization.
def _cascaded_matrix_multiplication(phases_ts_rs: jnp.ndarray) -> jnp.ndarray:
Returns:
jnp.ndarray: A 3D JAX array where the [i, j, :] entry represents the angles of incidence
for the j-th initial angle at the i-th wavelength. The size of the third dimension
corresponds to the number of layers.
"""
phase_rt_stack = jnp.concat([jnp.expand_dims(phases, 1), rts], axis=1)
"""
Performs cascaded matrix multiplication on a sequence of complex matrices using scan.
Expand All @@ -53,13 +43,13 @@ def _cascaded_matrix_multiplication(phases_ts_rs: jnp.ndarray) -> jnp.ndarray:
jax.numpy.ndarray: The final result of multiplying all the matrices together in sequence.
This result is a single 2x2 complex matrix representing the accumulated product of all input matrices.
"""
initial_value = jnp.eye(2, dtype=jnp.complex128)
# Initialize with the identity matrix of size 2x2. # The identity matrix acts as the multiplicative identity,
initial_value = jnp.eye(2, dtype=jnp.complex64)
# Initialize with the identity matrix of size 2x2. # The identity matrix acts as the multiplicative identity,
# ensuring that the multiplication starts correctly.

# jax.lax.scan applies a function across the sequence of matrices.
# jax.lax.scan applies a function across the sequence of matrices.
#Here, _matmul is the function applied, starting with the identity matrix.
# `result` will hold the final matrix after processing all input matrices.
result, _ = jax.lax.scan(_matmul, initial_value, phases_ts_rs) # Scan function accumulates results of _matmul over the matrices.

return result # Return the final accumulated matrix product. # The result is the product of all input matrices in the given sequence.
result, _ = jax.lax.scan(matmul, initial_value, phase_rt_stack) # Scan function accumulates results of _matmul over the matrices.
return result
Loading

0 comments on commit 9731ae0

Please sign in to comment.