diff --git a/docs/developers_notes/01-basis_module.md b/docs/developers_notes/01-basis_module.md index c0bdf6f8..b8d8a925 100644 --- a/docs/developers_notes/01-basis_module.md +++ b/docs/developers_notes/01-basis_module.md @@ -23,7 +23,9 @@ Abstract Class Basis │ │ │ └─ Concrete Subclass RaisedCosineBasisLog │ -└─ Concrete Subclass OrthExponentialBasis +├─ Concrete Subclass OrthExponentialBasis +│ +└─ Concrete Subclass FourierBasis ``` The super-class `Basis` provides two public methods, [`evaluate`](#the-public-method-evaluate) and [`evaluate_on_grid`](#the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the private abstract method `_evaluate` that is specific for each concrete class. See below for more details. diff --git a/docs/examples/plot_1D_basis_function.py b/docs/examples/plot_1D_basis_function.py index 37ce1e1a..97fceb9f 100644 --- a/docs/examples/plot_1D_basis_function.py +++ b/docs/examples/plot_1D_basis_function.py @@ -65,8 +65,10 @@ # ----------------- # Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description, # please refer to the [Code References](../../../reference/nemos/basis). After instantiation, all classes -# share the same syntax for basis evaluation. The following is an example of how to instantiate and -# evaluate a log-spaced cosine raised function basis. +# share the same syntax for basis evaluation. +# +# ### The Log-Spaced Raised Cosine Basis +# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) @@ -81,3 +83,89 @@ plt.plot(samples, eval_basis) plt.show() +# %% +# ### The Fourier Basis +# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and +# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals. +# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and +# interpretation of the model parameters, each of which will represent the relative contribution of a specific +# oscillation frequency to the overall signal. +# +# A Fourier basis can be instantiated with the following syntax: +# the user can provide the maximum frequency of the cosine and negative +# sine pairs by setting the `max_freq` parameter. +# The sinusoidal basis elements will have frequencies from 0 to `max_freq`. + + +fourier_basis = nmo.basis.FourierBasis(max_freq=3) + +# evaluate on equi-spaced samples +samples, eval_basis = fourier_basis.evaluate_on_grid(1000) + +# plot the `sin` and `cos` separately +plt.figure(figsize=(6, 3)) +plt.subplot(121) +plt.title("Cos") +plt.plot(samples, eval_basis[:, :4]) +plt.subplot(122) +plt.title("Sin") +plt.plot(samples, eval_basis[:, 4:]) +plt.tight_layout() + +# %% +# #### Fourier Basis Convolution and Fourier Transform +# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is +# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$ +# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $. +# +# When computing the cross-correlation of a signal with the Fourier basis functions, +# we are measuring how well the signal correlates with sinusoids of different frequencies, +# within a specified temporal window. This process mirrors the operation performed by the Fourier transform. +# Therefore, computing the cross-correlation of a signal with the Fourier basis defined here +# is equivalent to computing the discrete Fourier transform on a sliding window of the same size +# as that of the basis. + + +n_samples = 1000 +max_freq = 20 + +# define a signal +signal = np.random.normal(size=n_samples) + +# evaluate the basis +_, eval_basis = nmo.basis.FourierBasis(max_freq=max_freq).evaluate_on_grid(n_samples) + +# compute the cross-corr with the signal and the basis +# Note that we are inverting the time axis of the basis because we are aiming +# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis. +xcorr = np.array( + [ + np.convolve(eval_basis[::-1, k], signal, mode="valid")[0] + for k in range(2 * max_freq + 1) + ] +) + +# compute the power (add back sin(0 * t) = 0) +fft_complex = np.fft.fft(signal) +fft_amplitude = np.abs(fft_complex[:max_freq + 1]) +fft_phase = np.angle(fft_complex[:max_freq + 1]) +# compute the phase and amplitude from the convolution +xcorr_phase = np.arctan2(np.hstack([[0], xcorr[max_freq+1:]]), xcorr[:max_freq+1]) +xcorr_aplitude = np.sqrt(xcorr[:max_freq+1] ** 2 + np.hstack([[0], xcorr[max_freq+1:]]) ** 2) + +fig, ax = plt.subplots(1, 2) +ax[0].set_aspect("equal") +ax[0].set_title("Signal amplitude") +ax[0].scatter(fft_amplitude, xcorr_aplitude) +ax[0].set_xlabel("FFT") +ax[0].set_ylabel("cross-correlation") + +ax[1].set_aspect("equal") +ax[1].set_title("Signal phase") +ax[1].scatter(fft_phase, xcorr_phase) +ax[1].set_xlabel("FFT") +ax[1].set_ylabel("cross-correlation") +plt.tight_layout() + +print(f"Max Error Amplitude: {np.abs(fft_amplitude - xcorr_aplitude).max()}") +print(f"Max Error Phase: {np.abs(fft_phase - xcorr_phase).max()}") diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 44150e68..8e85f4b1 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -22,6 +22,7 @@ "OrthExponentialBasis", "AdditiveBasis", "MultiplicativeBasis", + "FourierBasis", ] @@ -103,7 +104,7 @@ def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]: # make sure array is at least 1d (so that we succeed when only # passed a scalar) xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) - except TypeError: + except (TypeError, ValueError): raise TypeError("Input samples must be array-like of floats!") # check for non-empty samples @@ -1086,7 +1087,8 @@ def _check_rates(self): "linearly dependent set of function for the basis." ) - def _check_sample_range(self, sample_pts: NDArray): + @staticmethod + def _check_sample_range(sample_pts: NDArray): """ Check if the sample points are all positive. @@ -1177,6 +1179,96 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) +class FourierBasis(Basis): + """Set of 1D Fourier basis. + + This class defines a cosine and negative sine basis (quadrature pair) + with frequencies ranging 0 to max_freq. + + Parameters + ---------- + max_freq + Highest frequency of the cosine, negative sine pairs. + The number of basis function will be 2*max_freq + 1. + """ + + def __init__(self, max_freq: int): + super().__init__(n_basis_funcs=2 * max_freq + 1) + + self._frequencies = np.arange(max_freq + 1, dtype=float) + self._n_input_dimensionality = 1 + + def _check_n_basis_min(self) -> None: + """Check that the user required enough basis elements. + + Checks that the number of basis is at least 1. + + Raises + ------ + ValueError + If an insufficient number of basis element is requested for the basis type. + """ + if self.n_basis_funcs < 0: + raise ValueError( + f"Object class {self.__class__.__name__} requires >= 1 basis elements. " + f"{self.n_basis_funcs} basis elements specified instead" + ) + + def evaluate(self, sample_pts: ArrayLike) -> NDArray: + """Generate basis functions with given spacing. + + Parameters + ---------- + sample_pts + Spacing for basis functions. + + Returns + ------- + basis_funcs + Evaluated Fourier basis, shape (n_samples, n_basis_funcs). + + Notes + ----- + The frequencies are set to np.arange(max_freq+1), convolving a signal + of length n_samples with this basis is equivalent, but slower, + then computing the FFT truncated to the first max_freq components. + + Therefore, convolving a signal with this basis is equivalent + to compute the FFT over a sliding window. + + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> n_samples, max_freq = 1000, 10 + >>> basis = nmo.basis.FourierBasis(max_freq) + >>> eval_basis = basis.evaluate(np.linspace(0, 1, n_samples)) + >>> sinusoid = np.cos(3 * np.arange(0, 1000) * np.pi * 2 / 1000.) + >>> conv = [np.convolve(eval_basis[::-1, k], sinusoid, mode='valid')[0] for k in range(2*max_freq+1)] + >>> fft = np.fft.fft(sinusoid) + >>> print('FFT power: ', np.round(np.real(fft[:max_freq]), 4)) + >>> print('Convolution: ', np.round(conv[:max_freq], 4)) + """ + (sample_pts,) = self._check_evaluate_input(sample_pts) + # assumes equi-spaced samples. + if sample_pts.shape[0] / np.max(self._frequencies) < 2: + raise ValueError("Not enough samples, aliasing likely to occur!") + + # rescale to [0, 2pi) + mn, mx = np.nanmin(sample_pts), np.nanmax(sample_pts) + # first sample in 0, last sample in 2 pi - 2 pi / n_samples. + sample_pts = ( + 2 + * np.pi + * (sample_pts - mn) + / (mx - mn) + * (1.0 - 1.0 / sample_pts.shape[0]) + ) + # create the basis + angles = np.einsum("i,j->ij", sample_pts, self._frequencies) + return np.concatenate([np.cos(angles), -np.sin(angles[:, 1:])], axis=1) + + def mspline(x: NDArray, k: int, i: int, T: NDArray): """Compute M-spline basis function. diff --git a/tests/test_basis.py b/tests/test_basis.py index 96760840..ee059b53 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -846,6 +846,178 @@ def test_decay_rate_size_match_n_basis_func(self, decay_rates, n_basis_func): self.cls(n_basis_funcs=n_basis_func, decay_rates=decay_rates) +class TestFourierBasis(BasisFuncsTesting): + cls = basis.FourierBasis + + @pytest.mark.parametrize("max_freq", [2, 4, 8]) + @pytest.mark.parametrize("sample_size", [20, 1000]) + def test_evaluate_returns_expected_number_of_basis( + self, max_freq, sample_size + ): + """Tests whether the evaluate method returns the expected number of basis functions.""" + basis_obj = self.cls(max_freq=max_freq) + eval_basis = basis_obj.evaluate(np.linspace(0, 1, sample_size)) + assert(eval_basis.shape[1] == 2*max_freq+1) + + @pytest.mark.parametrize("samples, expectation", + [ + ([], pytest.raises(ValueError, match="All sample provided must be non empty")), + (np.zeros(11), does_not_raise()) + ] + ) + def test_non_empty_samples(self, samples, expectation): + with expectation: + self.cls(5).evaluate(samples) + + @pytest.mark.parametrize( + "arraylike, expectation", [ + (["x"], pytest.raises(TypeError, match="Input samples must be array-like of floats")), + ([0]*11, does_not_raise()), + ((0,)*11, does_not_raise()), + (np.array([0]*11), does_not_raise()), + (jax.numpy.array([0]*11), does_not_raise()) + ] + ) + def test_input_to_evaluate_is_arraylike(self, arraylike, expectation): + """ + Checks that the sample size of the output from the evaluate() method matches the input sample size. + """ + basis_obj = self.cls(max_freq=5) + with expectation: + basis_obj.evaluate(arraylike) + + @pytest.mark.parametrize("sample_size", [100, 1000]) + @pytest.mark.parametrize("max_freq", [4, 10]) + def test_sample_size_of_evaluate_matches_that_of_input( + self, max_freq, sample_size + ): + """ + Checks that the sample size of the output from the evaluate() method matches the input sample size. + """ + basis_obj = self.cls(max_freq) + eval_basis = basis_obj.evaluate(np.linspace(0, 1, sample_size)) + if eval_basis.shape[0] != sample_size: + raise ValueError( + f"Dimensions do not agree: The window size should match the second dimension of the evaluated basis." + f"The window size is {sample_size}", + f"The second dimension of the evaluated basis is {eval_basis.shape[0]}", + ) + + @pytest.mark.parametrize("max_freq, expectation", + [ + (-1, pytest.raises(ValueError,match=r"Object class FourierBasis requires >= 1 basis elements")), + (0, does_not_raise()), + (1, does_not_raise()), + (3, does_not_raise()) + ] + ) + def test_minimum_number_of_basis_required_is_matched(self, max_freq, expectation): + """ + Verifies that the minimum number of basis functions and order required (i.e., at least 1) and + order < #basis are enforced. + """ + n_samples = 10 + with expectation: + basis_obj = self.cls(max_freq=max_freq) + basis_obj.evaluate(np.linspace(0, 1, n_samples)) + + @pytest.mark.parametrize("max_freq, expectation", + [ + (3, does_not_raise()), + (6, pytest.raises(ValueError,match=rf"Not enough samples, aliasing likely to occur")), + (7, pytest.raises(ValueError,match=rf"Not enough samples, aliasing likely to occur")), + (10, pytest.raises(ValueError,match=rf"Not enough samples, aliasing likely to occur")) + ] + ) + def test_minimum_aliasing_detection(self, max_freq, expectation): + """ + Verifies that the minimum number of basis functions and order required (i.e., at least 1) and + order < #basis are enforced. + """ + n_samples = 10 + basis_obj = self.cls(max_freq=max_freq) + with expectation: + basis_obj.evaluate(np.linspace(0, 1, n_samples)) + + @pytest.mark.parametrize( + "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] + ) + def test_samples_range_matches_evaluate_requirements(self, sample_range: tuple): + """ + Verifies that the evaluate() method can handle input range. + """ + basis_obj = self.cls(max_freq=5) + basis_obj.evaluate(np.linspace(*sample_range, 100)) + + @pytest.mark.parametrize("n_input, expectation", [ + (0, pytest.raises(TypeError, match=r"FourierBasis\.evaluate\(\) missing 1 required positional argument")), + (1, does_not_raise()), + (2, pytest.raises(TypeError, match=r"FourierBasis\.evaluate\(\) takes 2 positional arguments but")), + (3, pytest.raises(TypeError, match=r"FourierBasis\.evaluate\(\) takes 2 positional arguments but")) + ] + ) + def test_number_of_required_inputs_evaluate(self, n_input, expectation): + """ + Confirms that the evaluate() method correctly handles the number of input samples that are provided. + """ + basis_obj = self.cls(max_freq=5) + inputs = [np.linspace(0, 1, 20)] * n_input + with expectation: + basis_obj.evaluate(*inputs) + + @pytest.mark.parametrize( + "sample_size, expectation", + [ + (-1, pytest.raises(ValueError, match=r"Invalid input data|All sample counts provided must be greater")), + (0, pytest.raises(ValueError, match=r"Invalid input data|All sample counts provided must be greater")), + (10, does_not_raise()), + (11, does_not_raise()), + (100, does_not_raise()) + ] + ) + def test_evaluate_on_grid_meshgrid_valid_size(self, sample_size, expectation): + """ + Checks that the evaluate_on_grid() method returns a grid of the expected size. + """ + basis_obj = self.cls(max_freq=5) + with expectation: + basis_obj.evaluate_on_grid(sample_size) + + @pytest.mark.parametrize( + "sample_size, expectation", + [ + (10, does_not_raise()), + (11, does_not_raise()), + (100, does_not_raise()) + ] + ) + def test_evaluate_on_grid_meshgrid_match_size(self, sample_size, expectation): + """ + Checks that the evaluate_on_grid() method returns a grid of the expected size. + """ + basis_obj = self.cls(max_freq=5) + with expectation: + grid, _ = basis_obj.evaluate_on_grid(sample_size) + assert grid.shape[0] == sample_size + + @pytest.mark.parametrize( + "n_input, expectation", + [ + (0, pytest.raises(TypeError, match="Input dimensionality mismatch. This basis evaluation")), + (1, does_not_raise()), + (2, pytest.raises(TypeError, match="Input dimensionality mismatch. This basis evaluation")) + ] + ) + def test_evaluate_on_grid_input_number(self, n_input, expectation): + """ + Validates that the evaluate_on_grid() method correctly handles the number of input samples that are provided. + """ + basis_obj = self.cls(max_freq=5) + inputs = [10] * n_input + with expectation: + basis_obj.evaluate_on_grid(*inputs) + + class TestBSplineBasis(BasisFuncsTesting): cls = basis.BSplineBasis @@ -1253,6 +1425,8 @@ def instantiate_basis(n_basis, basis_class): basis_obj = basis_class( n_basis_funcs=n_basis, decay_rates=np.arange(1, 1 + n_basis) ) + elif basis_class == basis.FourierBasis: + basis_obj = basis_class(max_freq=n_basis) elif basis_class == basis.BSplineBasis: basis_obj = basis_class(n_basis_funcs=n_basis, order=3) elif basis_class == basis.CyclicBSplineBasis: @@ -1307,7 +1481,7 @@ def test_evaluate_input(self, eval_input): @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) - @pytest.mark.parametrize("sample_size", [10, 1000]) + @pytest.mark.parametrize("sample_size", [15, 1000]) @pytest.mark.parametrize( "basis_a", [class_obj for _, class_obj in utils_testing.get_non_abstract_classes(basis)], @@ -1523,7 +1697,7 @@ def test_evaluate_input(self, eval_input): @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) - @pytest.mark.parametrize("sample_size", [10, 1000]) + @pytest.mark.parametrize("sample_size", [15, 1000]) @pytest.mark.parametrize( "basis_a", [class_obj for _, class_obj in utils_testing.get_non_abstract_classes(basis)],