diff --git a/hyperit/hyperit.py b/hyperit/hyperit.py index ca46731..a7d9ed5 100644 --- a/hyperit/hyperit.py +++ b/hyperit/hyperit.py @@ -447,9 +447,6 @@ def __which_estimator(self, measure: str) -> None: self._Calc.setProperty(key, value) if initialise_parameter: - if self._measure == MeasureType.TE: - dataDim = self._data1.ndim - self._initialise_parameter = (dataDim, dataDim, *initialise_parameter) self._initialise_parameter = initialise_parameter if self._measure == MeasureType.TE and self._estimator_name == 'kernel': @@ -495,19 +492,24 @@ def __setup_matrix(self) -> None: self._it_matrix = np.zeros((1, self._loop_range, self._loop_range, 16)) if self._epoch_average else np.zeros((self._n_epo, self._loop_range, self._loop_range, 16)) - - def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray: - """ Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. """ - - # Initialise parameter describes the dimensions of the data + def __initialise_estimator(self, s1: np.ndarray, s2: np.ndarray) -> None: if not self._initialise_parameter: self._Calc.initialise() - else: - if self._measure == MeasureType.TE and self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled. + return + + if self._measure == MeasureType.TE: + if self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled. self._Calc.initialise(self._initialise_parameter) - else: - self._Calc.initialise(*self._initialise_parameter) + return + + self._Calc.initialise(s1.ndim, s2.ndim, *self._initialise_parameter) + return + + self._Calc.initialise(*self._initialise_parameter) + + def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray: + """ Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. """ self._Calc.setObservations(setup_JArray(s1), setup_JArray(s2)) result = self._Calc.computeAverageLocalOfObservations() * np.log(2) @@ -522,14 +524,6 @@ def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray: def __estimate_it_epoch_average(self, s1: np.ndarray, s2: np.ndarray) -> float: """ Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. s1 and s2 should have shape (epochs, samples) referring to a pairwise comparison of two channels.""" - - if not self._initialise_parameter: - self._Calc.initialise() - else: - if self._measure == MeasureType.TE and self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled. - self._Calc.initialise(self._initialise_parameter) - else: - self._Calc.initialise(*self._initialise_parameter) self._Calc.startAddObservations() for epoch in range(self._n_epo): @@ -574,6 +568,8 @@ def __filter_estimation(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndar elif self._measure == MeasureType.PhyID: return self.__estimate_atoms(s1, s2) + __initialise_estimator(s1, s2) + return self.__estimate_it_epoch_average(s1, s2) if self._epoch_average else self.__estimate_it(s1, s2)