#!/usr/bin/env python3
from collections import OrderedDict
from functools import wraps, total_ordering
from inspect import isfunction
from typing import TYPE_CHECKING, Callable
from numpy import linspace, asarray, floor, ones, array_equal, where, ones_like
from pandas import DataFrame, Series
from .utils.utils import *
def constructor(func):
    """Re-wrap DataFrame results of ``Signal`` methods as ``Signal`` objects."""
    @wraps(func)
    def wrapper(klass, *args, **kwargs):
        res = func(klass, *args, **kwargs)
        if isinstance(res, DataFrame):
            try:
                # A square result whose index matches its columns (e.g. a
                # correlation matrix) keeps its own labelling; otherwise the
                # parent signal's channels and time index are reused.
                if list(res.index) == list(res.columns):
                    return Signal(res, klass.fs, res.columns, res.index)
                return Signal(res, klass.fs, klass.channels, klass.time)
            except ValueError:
                # The parent's labels do not fit the result's shape: fall back
                # to the labels carried by the result itself.
                if array_equal(res.columns, klass.columns):
                    cols = klass.columns
                else:
                    cols = res.columns
                return Signal(res, klass.fs, cols, res.index)
        return res
    return wrapper
class ConstructorMeta(type):
    """Metaclass that wraps every method except ``__init__`` in :func:`constructor`."""
    @classmethod
    def __prepare__(mcs, name, bases):
        # Preserve the definition order of the class namespace.
        return OrderedDict()
    def __new__(mcs, name, bases, namespace, **kwargs):
        for key, value in namespace.items():
            if isfunction(value) and key != '__init__':
                namespace[key] = constructor(value)
        return type.__new__(mcs, name, bases, dict(namespace))
@total_ordering
class Signal(metaclass=ConstructorMeta):
"""
"""
data: DataFrame = DataFrame()
def __init__(self, signal: ArrayOrDataFrame, fs: int, channels: Iterable = None, time: Iterable = None):
"""
:param signal:
:type signal:
:param fs:
:type fs:
:param channels:
:type channels:
:param time:
:type time:
"""
self.data: DataFrame = signal if isinstance(signal, DataFrame) else DataFrame(signal)
self.fs = fs
self._hilbert = None
if channels is not None:
self.data.columns = channels
ind_len = max(signal.shape)
self.data.index = time if time is not None else linspace(0, ind_len / fs, ind_len)
@classmethod
def from_edf(cls, path: str, filter_func: Callable=None,
min_t: RealNumber=0, max_t: RealNumberOrNone=None, all_channels=False):
"""
        :param path: Path to the EDF file.
        :type path: str
        :param filter_func: Predicate applied to each parsed channel name; channels for which it
            returns False are dropped. Defaults to keeping every matched channel.
        :type filter_func: Callable, optional
        :param min_t: Start of the segment to read (passed to ``RawEDF.get_data`` as ``start``).
        :type min_t: RealNumber
        :param max_t: End of the segment to read (passed as ``stop``); ``None`` reads to the end.
        :type max_t: RealNumberOrNone
        :param all_channels: If True, keep every channel rather than only EEG/ECoG/MEG channels.
        :type all_channels: bool
        :return: Signal built from the selected EDF channels.
        :rtype: Signal
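        Example
        -------
        Illustrative only; ``'recording.edf'`` is a hypothetical file and the
        lambda keeps every channel whose name starts with ``'C'``:
        >>> sig = Signal.from_edf('recording.edf', filter_func=lambda name: name.startswith('C'))  # doctest: +SKIP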
"""
from re import compile
from mne.io.edf.edf import RawEDF
edf = RawEDF(path, preload=False, montage=None, stim_channel=None, verbose=False)
edf_channels = edf.ch_names
ch_names, label_filter = list(), ones_like(edf_channels, dtype=bool)
pattern = compile(r'(\w+)?.(\w+\d+)?-?(\w+)?')
_filter_func = lambda x: True
filter_func = filter_func or _filter_func
for index, item in enumerate(edf_channels):
found = pattern.findall(item)[0]
full_name = None
            if all_channels:
                # The pattern above captures three groups: (type, name, reference).
                full_name = f'{found[0].strip()} {found[1].strip() or found[2].strip()}'
elif found[1] and found[0].lower() in ('eeg', 'ecog', 'meg'):
full_name = found[1].strip()
if not full_name or not filter_func(full_name):
label_filter[index] = False
else:
ch_names.append(full_name)
fs = edf.info['sfreq']
data, times = edf.get_data(
picks=where(label_filter)[0],
start=min_t,
stop=max_t,
return_times=True
)
return cls(data.T, time=times, fs=fs, channels=ch_names)
@classmethod
def from_txt(cls, path: str, fs: int, max_t=None, filter_func: Callable=None):
"""
        :param path: Path to a tab-separated text file with one column per channel.
        :type path: str
        :param fs: Sampling frequency in Hz.
        :type fs: int
        :param max_t: Keep only the first ``max_t`` rows; ``None`` keeps the whole file.
        :type max_t: int, optional
        :param filter_func: Predicate applied to each column name; columns for which it returns False are dropped.
        :type filter_func: Callable, optional
        :return: Signal built from the selected columns.
        :rtype: Signal
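        Example
        -------
        Illustrative only; ``'recording.txt'`` is a hypothetical tab-separated file:
        >>> sig = Signal.from_txt('recording.txt', fs=512)  # doctest: +SKIP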
"""
from pandas import read_csv
_filter_func = lambda x: True
filter_func = filter_func or _filter_func
dt = read_csv(path, sep='\t')
dt.columns = [c.replace(' ', '_') for c in dt.columns]
cols = [c for c in dt.columns if filter_func(c)]
        # Slicing with ``[:None]`` keeps every row, so no limit means the whole file.
        dt = dt[cols][:max_t]
return cls(dt, fs, channels=cols)
@classmethod
def from_hdf(cls, path: str, key: str, fs: int):
"""
        :param path: Path to the HDF5 file.
        :type path: str
        :param key: Key of the dataset to load.
        :type key: str
        :param fs: Sampling frequency in Hz.
        :type fs: int
        :return: Signal built from the stored dataset.
        :rtype: Signal
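        Example
        -------
        Illustrative only; ``'recording.h5'`` and ``'eeg'`` are hypothetical:
        >>> sig = Signal.from_hdf('recording.h5', key='eeg', fs=512)  # doctest: +SKIP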
"""
from .analyser import _from_hdf5
dt = _from_hdf5(path, key)[key]
return cls(dt.compute(optimize_graph=True), fs)
def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
return getattr(self.data, item)
def __len__(self):
return self.data.size
def __call__(self):
return self.data
def __repr__(self):
return self.data.__repr__()
def __str__(self):
return self.data.to_string()
def __getitem__(self, item):
return self.data.__getitem__(item)
def __setitem__(self, key, value):
self.data[key] = value
def __delitem__(self, key):
return self.data.__delitem__(key)
def __getstate__(self):
return self.data.__getstate__()
def __setstate__(self, state):
return self.data.__setstate__(state)
def __invert__(self):
return self.data.__invert__()
def __iter__(self):
return self.data.__iter__()
def __reduce__(self):
return self.data.__reduce__()
def __contains__(self, item):
return self.data.__contains__(item)
def __add__(self, other):
return self.data + other
def __iadd__(self, other):
return self.data + other
def __sub__(self, other):
return self.data - other
def __isub__(self, other):
return self.data - other
def __mul__(self, other):
return self.data * other
def __imul__(self, other):
return self.data * other
def __divmod__(self, other):
return self.data // other, self.data % other
def __mod__(self, other):
return self.data % other
def __imod__(self, other):
return self.data % other
def __floordiv__(self, other):
return self.data // other
def __ifloordiv__(self, other):
return self.data // other
def __truediv__(self, other):
return self.data / other
def __itruediv__(self, other):
return self.data / other
def __pow__(self, power, modulo=None):
return self.data ** power if modulo is None else (self.data ** power) % modulo
def __ipow__(self, other):
return self.data ** other
def __abs__(self):
return self.data.abs()
def __lt__(self, other):
return self.data < other
def __eq__(self, other):
return self.data == other
    def __ne__(self, other):
        # Element-wise inequality; ``not (self.data == other)`` is ambiguous for a DataFrame.
        return self.data != other
def __neg__(self):
return -self.data
@property
def len_channels(self):
"""
:return:
:rtype:
"""
return self.data.columns.size
@property
def len_time(self):
"""
:return:
:rtype:
"""
return self.data.index.size
@property
def channels(self):
"""
:return:
:rtype:
"""
return asarray(self.data.columns)
@property
def time(self):
"""
:return:
:rtype:
"""
return asarray(self.data.index)
@property
def shape(self):
"""
:return:
:rtype:
"""
return self.data.shape
@property
def T(self):
return self.transpose()
    def transpose(self):
        return self.data.transpose()
def to_list(self):
"""
:return:
:rtype:
"""
return list(self.data)
def to_matrix(self):
return self.data.as_matrix().T
def to_dict(self):
return self.data.to_dict()
def as_dataframe(self):
"""
:return:
:rtype:
"""
return self.data
def corr(self, method='pearson', min_periods=1):
return self.data.corr(method, min_periods)
def cov(self, min_periods=None):
return self.data.cov(min_periods=min_periods)
def apply(self, func, axis=0, broadcast=False, raw=False, reduce=None, *args, **kwargs):
return self.data.apply(func=func, axis=axis, broadcast=broadcast, raw=raw, reduce=reduce, *args, **kwargs)
def min(self):
return self.data.min()
def max(self):
return self.data.max()
def plot(self, *args, collection=True, ax=None, **kwargs):
"""
        Plot the data as a line collection if ``collection`` is True; otherwise
        the pandas DataFrame plot method is called.
        :param collection: Whether to draw the channels as a single line collection.
        :type collection: bool
        :param ax: Axes to draw on; a new one is created when ``None``.
        :type ax: matplotlib.axes.Axes, optional
        DataFrame plot: :func:`~DataFrame.plot`
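        Example
        -------
        A minimal sketch with random data (import path as in the other examples
        in this module):
        .. plot::
            :context: close-figs

            >>> from NeuroEnsemble.structure import Signal
            >>> from numpy.random import random
            >>> signal = Signal(random([2048, 5]), fs=512, channels=list('abcde'))
            >>> ax = signal.plot()
            >>> ax2 = signal.plot(collection=False)  # defer to DataFrame.plot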
"""
if collection:
from .visuals import plot
ax = plot(self.data, fs=self.fs, ax=ax)
return ax
return self.data.plot(ax=ax, *args, **kwargs)
def spectrogram(self, channel, NFFT=None, detrend=None, window=None,
noverlap=None, pad_to=None, sides=None, scale_by_freq=None,
mode=None):
"""
Compute a spectrogram.
Compute and plot a spectrogram of data in x. Data are split into
NFFT length segments and the spectrum of each section is
computed. The windowing function window is applied to each
segment, and the amount of overlap of each segment is
specified with noverlap.
Parameters
----------
        @param channel: Name of the channel whose samples are used.
        @type channel: str
@param window: A function or a vector of length *NFFT*. To create window
vectors see :func:`window_hanning`, :func:`window_none`,
:func:`numpy.blackman`, :func:`numpy.hamming`,
:func:`numpy.bartlett`, :func:`scipy.signal`,
:func:`scipy.signal.get_window`, etc. The default is
:func:`window_hanning`. If a function is passed as the
argument, it must take a data segment as an argument and
return the windowed version of the segment.
@type window: callable or ndarray
@param sides:
Specifies which sides of the spectrum to return. Default gives the
default behavior, which returns one-sided for real data and both
for complex data. 'onesided' forces the return of a one-sided
spectrum, while 'twosided' forces two-sided.
@type sides: [ 'default' | 'onesided' | 'twosided' ]
@param pad_to:
The number of points to which the data segment is padded when
performing the FFT. This can be different from *NFFT*, which
specifies the number of data points used. While not increasing
the actual resolution of the spectrum (the minimum distance between
resolvable peaks), this can give more points in the plot,
allowing for more detail. This corresponds to the *n* parameter
in the call to ``fft()``. The default is ``None``, which sets ``pad_to``
equal to ``NFFT``.
@type pad_to: int
@param NFFT:
The number of data points used in each block for the FFT.
        A power of 2 is most efficient. The default value is 256.
This should *NOT* be used to get zero padding, or the scaling of the
result will be incorrect. Use ``pad_to`` for this instead.
@type NFFT: int
@param detrend:
The function applied to each segment before fft-ing,
designed to remove the mean or linear trend. Unlike in
MATLAB, where the *detrend* parameter is a vector, in
        matplotlib it is a function. The :mod:`~matplotlib.pylab`
module defines :func:`~matplotlib.pylab.detrend_none`,
:func:`~matplotlib.pylab.detrend_mean`, and
:func:`~matplotlib.pylab.detrend_linear`, but you can use
a custom function as well. You can also use a string to choose
one of the functions. 'default', 'constant', and 'mean' call
:func:`~matplotlib.pylab.detrend_mean`. 'linear' calls
:func:`~matplotlib.pylab.detrend_linear`. 'none' calls
:func:`~matplotlib.pylab.detrend_none`.
@type detrend: {'default', 'constant', 'mean', 'linear', 'none'} or callable
@param scale_by_freq: optional
Specifies whether the resulting density values should be scaled
by the scaling frequency, which gives density in units of Hz^-1.
This allows for integration over the returned frequency values.
The default is True for MATLAB compatibility.
@type scale_by_freq: bool
@param noverlap: optional
The number of points of overlap between blocks. The default
value is 128.
@type noverlap: int
@param mode: optional
What sort of spectrum to use, default is 'psd'.
'psd'
Returns the power spectral density.
'complex'
Returns the complex-valued frequency spectrum.
'magnitude'
Returns the magnitude spectrum.
'angle'
Returns the phase spectrum without unwrapping.
'phase'
Returns the phase spectrum with unwrapping.
@type mode: str
Returns
-------
spectrum : array_like
2-D array, columns are the periodograms of successive segments.
freqs : array_like
1-D array, frequencies corresponding to the rows in *spectrum*.
t : array_like
1-D array, the times corresponding to midpoints of segments
        (i.e. the columns in *spectrum*).
Notes
-----
detrend and scale_by_freq only apply when *mode* is set to 'psd'.
References
----------
MatPlotLib
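        Example
        -------
        A minimal sketch with random data (import path as in the other examples
        in this module):
        >>> from NeuroEnsemble.structure import Signal
        >>> from numpy.random import random
        >>> signal = Signal(random([2048, 5]), fs=512, channels=list('abcde'))
        >>> spectrum, freqs, t = signal.spectrogram('a')
        >>> spectrum.shape[0] == freqs.size
        True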
"""
from matplotlib.mlab import specgram
return specgram(self[channel], Fs=self.fs, NFFT=NFFT, detrend=detrend,
window=window, noverlap=noverlap, pad_to=pad_to,
sides=sides, scale_by_freq=scale_by_freq, mode=mode)
def filter_by_freq(self, band: Band):
"""
        The signal is transformed with a 4th-order Butterworth filter so that it
        only contains frequencies between the lower and the upper thresholds.
        .. Note::
            Butterworth is an analogue method, applied here as a forward-backward
            filter of cascaded second-order sections; the result is returned as a
            digital signal. Digitization may introduce (a) minor inaccuracies at
            both ends of the signal, and (b) minor overlap of frequencies above and
            below the thresholds (Butterworth does not cut off instantaneously).
        :param band: Threshold band. Use ``NeuroEnsemble.bands`` for predefined bands, or define
            custom ones as follows: ``{"name": 'alpha', "thresh": (8, 15)}``
        :type band: bands, dict["name": str, "thresh": (float, float)]
        :return: Filtered signal.
        :rtype: Electrogram
Example
-------
.. plot::
:context: close-figs
>>> from NeuroEnsemble.structure import Signal
        >>> from NeuroEnsemble import bands
>>> from numpy.random import random
>>> signal_raw = random([2048, 5])
>>> SAMPLING_FREQ = 512
        >>> signal = Signal(signal=signal_raw, channels=list('abcde'), fs=SAMPLING_FREQ)
>>> signal.plot()
>>> theta_band = signal.filter_by_freq(bands.THETA)
>>> theta_band.plot()
"""
from .tools import butterworth_filter
filtered = self.data.apply(
butterworth_filter,
thresh=band['thresh'],
rate=self.fs
)
return filtered
columns = channels
index = time
cov.__doc__ = data.cov.__doc__
corr.__doc__ = data.corr.__doc__
min.__doc__ = DataFrame.min.__doc__
max.__doc__ = DataFrame.max.__doc__
to_dict.__doc__ = data.to_dict.__doc__
apply.__doc__ = DataFrame.apply.__doc__
T.__doc__ = DataFrame.transpose.__doc__
to_matrix.__doc__ = data.as_matrix.__doc__
transpose.__doc__ = DataFrame.transpose.__doc__
    __len__.__doc__ = DataFrame.size.__doc__
__str__.__doc__ = DataFrame.__str__.__doc__
__repr__.__doc__ = DataFrame.__repr__.__doc__
__setstate__.__doc__ = DataFrame.__setstate__.__doc__
__getstate__.__doc__ = DataFrame.__getstate__.__doc__
if TYPE_CHECKING:
from .tools import Hilbert
from .phase import Phase
def shuffled_phase_surrogate(self):
"""
See :func:`~ECoG.tools.shuffled_phase_surrogate` for additional information.
:return: Signal with shuffled phases.
        :rtype: Signal
"""
from .surrogates import shuffled_phase_surrogate
return Signal(shuffled_phase_surrogate(self.data), fs=self.fs, channels=self.channels, time=self.time)
def correlated_phase_surrogate(self):
"""
See :func:`~ECoG.tools.correlated_noise_surrogate` for additional information.
:return: Signal with correlated surrogate phases.
        :rtype: Signal
"""
from .surrogates import correlated_noise_surrogate
return Signal(correlated_noise_surrogate(self), fs=self.fs, channels=self.channels, time=self.time)
@property
def hilbert(self) -> 'Hilbert':
"""
See :func:`~ECoG.tools.Hilbert` for additional information.
:return: Hilbert transformed signal.
:rtype: Hilbert
"""
from .tools import Hilbert
if self._hilbert is None:
self._hilbert = Hilbert(self.data, fs=self.fs)
return self._hilbert
def bandpower(self) -> Series:
"""
        Produces the band power for each channel as the trapezoidal integral of
        the absolute values of the Fourier-transformed signals, with the
        resulting spectrum limited to :math:`[0, Fs/4]`.
See :func:`~ECoG.tools.spectral_power` for additional information.
:return: Powers of spectrum for each channel.
:rtype: Series
Example
-------
.. plot::
:context: close-figs
>>> from NeuroEnsemble.structure import Signal
>>> from NeuroEnsemble import bands
>>> from numpy.random import random
>>> signal_raw = random([1024, 5])
>>> SAMPLING_FREQ = 512
        >>> signal = Signal(signal=signal_raw, channels=list('abcde'), fs=SAMPLING_FREQ)
        >>> theta_band = signal.filter_by_freq(bands.THETA)
        >>> powers = theta_band.bandpower()
>>> powers.plot('bar')
"""
from .tools import spectral_power
return spectral_power(self.data, self.fs)
def phases(self, thresholds: Tuple[Band]) -> 'Phase':
"""
        :param thresholds: Frequency bands to extract phases for; each band follows the
            ``{"name": str, "thresh": (low, high)}`` convention used by :meth:`filter_by_freq`.
        :type thresholds: Tuple[Band]
        :return: Phase representation of the signal.
        :rtype: Phase
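        Example
        -------
        A minimal sketch using the custom-band convention documented in
        :meth:`filter_by_freq` (``signal`` as constructed in the other examples):
        >>> ph = signal.phases(({'name': 'alpha', 'thresh': (8, 15)},))  # doctest: +SKIP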
"""
from .phase import Phase
return Phase(self.data, self.fs, thresholds)
def crosscorr_with(self, other: 'Signal', normed=True,
individually: bool=True, mode: str='full',
method: str='auto'):
"""
Cross-correlate two N-dimensional arrays.
Cross-correlate `in1` and `in2`, with the output size determined by the
`mode` argument.
The correlation z of two d-dimensional arrays x and y is defined as::
z[...,k,...] = sum[..., i_l, ...] x[..., i_l,...] * conj(y[..., i_l - k,...])
This way, if x and y are 1-D arrays and ``z = correlate(x, y, 'full')`` then
.. math::
z[k] &= (x * y)(k - N + 1) \\
&= \sum_{l=0}^{||x||-1}x_l y_{l-k+N-1}^{*}
for :math:`k = 0, 1, ..., ||x|| + ||y|| - 2`
where :math:`||x||` is the length of ``x``, :math:`N = \max(||x||,||y||)`,
and :math:`y_m` is 0 when m is outside the range of y.
``method='fft'`` only works for numerical arrays as it relies on
`fftconvolve`. In certain cases (i.e., arrays of objects or when
        rounding integers can lose precision), ``method='direct'`` is always used.
        :param other: Signal to calculate the cross-correlation with. Regarded
            as `in2`.
        :type other: Signal
        :param individually: If True (default), cross-correlate each channel of this
            signal with the matching channel of ``other``; otherwise correlate the two
            signals as whole arrays with :func:`scipy.signal.correlate`.
        :type individually: bool
        :param normed: Currently unused; normalisation of the whole-array result is
            commented out in the implementation.
        :type normed: bool
:param mode: str {'full', 'valid', 'same'}, optional.
A string indicating the size of the output:
``full``
The output is the full discrete linear cross-correlation
of the inputs.
``valid``
The output consists only of those elements that do not
rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
must be at least as large as the other in every dimension.
``same``
The output is the same size as `in1`, centered
with respect to the 'full' output. (Default)
:type mode: str
:param method: str {'auto', 'direct', 'fft'}, optional
A string indicating which method to use to calculate the correlation.
``direct``
The correlation is determined directly from sums, the definition of
correlation.
``fft``
The Fast Fourier Transform is used to perform the correlation more
quickly (only available for numerical arrays.)
``auto``
Automatically chooses direct or Fourier method based on an estimate
of which is faster (default). See `convolve` Notes for more detail.
:type method: str
        :return: Cross-correlation of the two signals.
:rtype: Signal
        Notes
        -----
Function uses the default [SciPy]_ implementation for cross correlation;
see `SciPy documentations`_ for additional information.
.. _Scipy documentations: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html
.. [SciPy] Jones E, Oliphant E, Peterson P, et al. SciPy: Open Source
Scientific Tools for Python, 2001-, http://www.scipy.org/ [Online; accessed 18/06/2017].
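        Example
        -------
        A minimal sketch with two random signals of matching shape (import path
        as in the other examples in this module):
        >>> from NeuroEnsemble.structure import Signal
        >>> from numpy.random import random
        >>> a = Signal(random([1024, 3]), fs=256, channels=list('xyz'))
        >>> b = Signal(random([1024, 3]), fs=256, channels=list('xyz'))
        >>> xc = a.crosscorr_with(b)  # per-channel cross-correlation indexed by lag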
"""
from numpy import arange
if individually:
from .tools import xcorr
res = DataFrame()
lags = None
for channel in self.channels:
res[channel], lags = xcorr(
x=self[channel].as_matrix().ravel(),
y=other[channel].as_matrix().ravel(),
normalize='biased'
)
res.index = lags
else:
from scipy.signal import correlate
res = correlate(in1=self, in2=other, mode=mode, method=method)
dt = DataFrame(res)
# if normed:
# from numpy.linalg import norm
#
# dt = dt/norm(res)
# if mode == 'full':
# dt.index = arange(-max(other.shape), self.time.size + max(other.shape))
#
# # ToDo: Index of mode='same' to be defined.
return dt
    def xcos_with(self, other):
        """Channel-wise cross cosine similarity with ``other``, normalised by its overall norm."""
from .tools import xcosine_similarity
from numpy import vectorize, float64
from numpy.linalg import norm
win = vectorize(xcosine_similarity, signature='(n),(m)->(n)')
xcos = win(self.as_matrix().T.astype(float64), other[list(self.channels)].as_matrix().T.astype(float64)).T
xcos = xcos / norm(xcos)
res = DataFrame(
xcos,
columns=self.channels,
index=self.time
)
return res
    def factorize(self, n_components='auto'):
        """Decompose the signal with :class:`sklearn.decomposition.TruncatedSVD`."""
        from sklearn.decomposition import TruncatedSVD
        if n_components == 'auto':
            # ``self.index`` is the time axis; its length (not the array itself)
            # gives an integer number of components.
            n_components = self.len_time // 16
        tsvd = TruncatedSVD(n_components=n_components, tol=self.as_matrix().std())
        results = tsvd.fit_transform(self.as_matrix())
        return Signal(results, fs=self.fs // 16, channels=self.channels)
Electrogram = Signal
def rec_plot(sig, eps=0.1, steps=10):
    """Recurrence-plot matrix of a 1-D signal: pairwise distances binned into ``steps`` levels of width ``eps``."""
    from scipy.spatial.distance import pdist, squareform
    d = pdist(sig[:, None], metric='cosine')
    d = floor(d / eps)
    d[d > steps] = steps
    z = squareform(d)
    return z
def moving_average(sig, r=5):
    """Moving average of ``sig`` over a window of ``r`` samples ('valid' convolution)."""
    from scipy.signal import convolve
    return convolve(sig, ones((r,)) / r, mode='valid')
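# A minimal usage sketch: build a Signal from random data and smooth one channel
# with the ``moving_average`` helper above. The channel names and sizes are
# illustrative assumptions, not fixtures from this package.
if __name__ == '__main__':
    from numpy.random import random
    demo = Signal(random([1024, 2]), fs=256, channels=['a', 'b'])
    smoothed = moving_average(asarray(demo['a']), r=5)
    print(demo.shape, smoothed.shape)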