Skip to content

Commit

Permalink
changed correlatio to pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Jan 13, 2024
1 parent e693128 commit a9e6c4f
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/nemos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def check_dimensionality(
def convolve_1d_trials(
basis_matrix: ArrayLike,
time_series: Union[Iterable[NDArray], NDArray, Iterable[jnp.ndarray], jnp.ndarray],
) -> List[jnp.ndarray]:
) -> Union[Iterable[NDArray], NDArray, Iterable[jnp.ndarray], jnp.ndarray]:
"""Convolve trial time series with a basis matrix.
This function checks if all trials have the same duration. If they do, it uses a fast method
Expand Down Expand Up @@ -110,7 +110,7 @@ def convolve_1d_trials(

except ValueError:
# convert each trial to array
time_series = [jnp.asarray(trial) for trial in time_series]
time_series = jax.tree_map(jnp.asarray, time_series)
if not check_dimensionality(time_series, 2):
raise ValueError(
"time_series must be an iterable of 2 dimensional array-like objects."
Expand All @@ -123,21 +123,21 @@ def convolve_1d_trials(

# Check window size
ws = len(basis_matrix)
if any(trial.shape[0] < ws for trial in time_series):
if pytree_map_and_reduce(lambda x: x.shape[0] < ws, any, list(time_series)):
raise ValueError(
"Insufficient trial duration. The number of time points in each trial must "
"be greater or equal to the window size."
)

if isinstance(time_series, jnp.ndarray):
# if the conversion to array went through, time_series have trials with equal size
conv_trials = list(_CORR_SAME_TRIAL_DUR(time_series, basis_matrix))
conv_trials = _CORR_SAME_TRIAL_DUR(time_series, basis_matrix)
else:
# trials have different length
conv_trials = [
_CORR_VARIABLE_TRIAL_DUR(jnp.atleast_2d(trial), basis_matrix)
for trial in time_series
]
conv_trials = jax.tree_map(
lambda x: _CORR_VARIABLE_TRIAL_DUR(jnp.atleast_2d(x), basis_matrix),
time_series, basis_matrix
)

return conv_trials

Expand Down

0 comments on commit a9e6c4f

Please sign in to comment.