Skip to content

Commit

Permalink
Update hyperit.py
Browse files Browse the repository at this point in the history
  • Loading branch information
EdoardoChidichimo authored Jun 24, 2024
1 parent f1f382d commit 7b8cfac
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions hyperit/hyperit.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,6 @@ def __which_estimator(self, measure: str) -> None:
self._Calc.setProperty(key, value)

if initialise_parameter:
if self._measure == MeasureType.TE:
dataDim = self._data1.ndim
self._initialise_parameter = (dataDim, dataDim, *initialise_parameter)
self._initialise_parameter = initialise_parameter

if self._measure == MeasureType.TE and self._estimator_name == 'kernel':
Expand Down Expand Up @@ -495,19 +492,24 @@ def __setup_matrix(self) -> None:

self._it_matrix = np.zeros((1, self._loop_range, self._loop_range, 16)) if self._epoch_average else np.zeros((self._n_epo, self._loop_range, self._loop_range, 16))


def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray:
""" Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. """

# Initialise parameter describes the dimensions of the data
def __initialise_estimator(self, s1: np.ndarray, s2: np.ndarray) -> None:

if not self._initialise_parameter:
self._Calc.initialise()
else:
if self._measure == MeasureType.TE and self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled.
return

if self._measure == MeasureType.TE:
if self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled.
self._Calc.initialise(self._initialise_parameter)
else:
self._Calc.initialise(*self._initialise_parameter)
return

self._Calc.initialise(s1.ndim, s2.ndim, *self._initialise_parameter)
return

self._Calc.initialise(*self._initialise_parameter)

def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray:
""" Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. """

self._Calc.setObservations(setup_JArray(s1), setup_JArray(s2))
result = self._Calc.computeAverageLocalOfObservations() * np.log(2)
Expand All @@ -522,14 +524,6 @@ def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray:
def __estimate_it_epoch_average(self, s1: np.ndarray, s2: np.ndarray) -> float:
""" Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators.
s1 and s2 should have shape (epochs, samples) referring to a pairwise comparison of two channels."""

if not self._initialise_parameter:
self._Calc.initialise()
else:
if self._measure == MeasureType.TE and self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled.
self._Calc.initialise(self._initialise_parameter)
else:
self._Calc.initialise(*self._initialise_parameter)

self._Calc.startAddObservations()
for epoch in range(self._n_epo):
Expand Down Expand Up @@ -574,6 +568,8 @@ def __filter_estimation(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndar
elif self._measure == MeasureType.PhyID:
return self.__estimate_atoms(s1, s2)

__initialise_estimator(s1, s2)

return self.__estimate_it_epoch_average(s1, s2) if self._epoch_average else self.__estimate_it(s1, s2)


Expand Down

0 comments on commit 7b8cfac

Please sign in to comment.