Skip to content

Commit

Permalink
Merge pull request #1630 from NeuralEnsemble/black-formatting
Browse files Browse the repository at this point in the history
Black formatting
  • Loading branch information
zm711 authored Jan 19, 2025
2 parents a75cd42 + e485357 commit e9a710d
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 86 deletions.
2 changes: 1 addition & 1 deletion neo/core/spiketrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def normalize_times_array(times, units=None, dtype=None, copy=None):
"In order to facilitate the deprecation copy can be set to None but will raise an "
"error if set to True/False since this will silently do nothing. This argument will be completely "
"removed in Neo 0.15.0. Please update your code base as necessary."
)
)

if dtype is None:
if not hasattr(times, "dtype"):
Expand Down
19 changes: 10 additions & 9 deletions neo/rawio/blackrockrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,7 @@ def _get_timestamp_slice(self, timestamp, seg_index, t_start, t_stop):
if t_start is None:
t_start = self._seg_t_starts[seg_index]
if t_stop is None:
t_stop = self._seg_t_stops[seg_index] + 1 / float(
self.__nev_basic_header['timestamp_resolution'])
t_stop = self._seg_t_stops[seg_index] + 1 / float(self.__nev_basic_header["timestamp_resolution"])

if t_start is None:
ind_start = None
Expand Down Expand Up @@ -715,15 +714,16 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start,
)
unit_spikes = all_spikes[mask]

wf_dtype = self.__nev_params('waveform_dtypes')[channel_id]
wf_size = self.__nev_params('waveform_size')[channel_id]
wf_dtype = self.__nev_params("waveform_dtypes")[channel_id]
wf_size = self.__nev_params("waveform_size")[channel_id]
wf_byte_size = np.dtype(wf_dtype).itemsize * wf_size

dt1 = [
('extra', 'S{}'.format(unit_spikes['waveform'].dtype.itemsize - wf_byte_size)),
('ch_waveform', 'S{}'.format(wf_byte_size))]
("extra", "S{}".format(unit_spikes["waveform"].dtype.itemsize - wf_byte_size)),
("ch_waveform", "S{}".format(wf_byte_size)),
]

waveforms = unit_spikes['waveform'].view(dt1)['ch_waveform'].flatten().view(wf_dtype)
waveforms = unit_spikes["waveform"].view(dt1)["ch_waveform"].flatten().view(wf_dtype)

waveforms = waveforms.reshape(int(unit_spikes.size), 1, int(wf_size))

Expand Down Expand Up @@ -1365,7 +1365,9 @@ def __match_nsx_and_nev_segment_ids(self, nsx_nb):

# Show warning if spikes do not fit any segment (+- 1 sampling 'tick')
# Spike should belong to segment before
mask_outside = (ev_ids == i) & (data["timestamp"] < int(seg["timestamp"]) - int(nsx_offset) - int(nsx_period))
mask_outside = (ev_ids == i) & (
data["timestamp"] < int(seg["timestamp"]) - int(nsx_offset) - int(nsx_period)
)

if len(data[mask_outside]) > 0:
warnings.warn(f"Spikes outside any segment. Detected on segment #{i}")
Expand Down Expand Up @@ -1995,7 +1997,6 @@ def __get_nsx_param_variant_a(self, nsx_nb):
else:
units = "uV"


nsx_parameters = {
"nb_data_points": int(
(self.__get_file_size(filename) - bytes_in_headers)
Expand Down
2 changes: 1 addition & 1 deletion neo/rawio/intanrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class IntanRawIO(BaseRawIO):
one long vector, which must be post-processed to extract individual digital channel information.
See the intantech website for more information on performing this post-processing.
Examples
--------
>>> import neo.rawio
Expand Down
38 changes: 16 additions & 22 deletions neo/rawio/micromedrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(self, filename=""):

def _parse_header(self):


with open(self.filename, "rb") as fid:
f = StructFile(fid)

Expand Down Expand Up @@ -99,7 +98,6 @@ def _parse_header(self):
if zname != zname2.decode("ascii").strip(" "):
raise NeoReadWriteError("expected the zone name to match")


# "TRONCA" zone define segments
zname2, pos, length = zones["TRONCA"]
f.seek(pos)
Expand All @@ -114,7 +112,7 @@ def _parse_header(self):
break
else:
self.info_segments.append((seg_start, trace_offset))

if len(self.info_segments) == 0:
# one unique segment = general case
self.info_segments.append((0, 0))
Expand Down Expand Up @@ -152,8 +150,9 @@ def _parse_header(self):
(sampling_rate,) = f.read_f("H")
sampling_rate *= Rate_Min
chan_id = str(c)
signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))

signal_channels.append(
(chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
)

signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

Expand All @@ -166,31 +165,31 @@ def _parse_header(self):
self._sampling_rate = float(np.unique(signal_channels["sampling_rate"])[0])

# memmap traces buffer
full_signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset)
full_signal_shape = get_memmap_shape(
self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset
)
seg_limits = [trace_offset for seg_start, trace_offset in self.info_segments] + [full_signal_shape[0]]
self._t_starts = []
self._buffer_descriptions = {0 :{}}
self._buffer_descriptions = {0: {}}
for seg_index in range(nb_segment):
seg_start, trace_offset = self.info_segments[seg_index]
self._t_starts.append(seg_start / self._sampling_rate)

start = seg_limits[seg_index]
stop = seg_limits[seg_index + 1]

shape = (stop - start, Num_Chan)
file_offset = Data_Start_Offset + ( start * np.dtype(sig_dtype).itemsize * Num_Chan)
file_offset = Data_Start_Offset + (start * np.dtype(sig_dtype).itemsize * Num_Chan)
self._buffer_descriptions[0][seg_index] = {}
self._buffer_descriptions[0][seg_index][buffer_id] = {
"type" : "raw",
"file_path" : str(self.filename),
"dtype" : sig_dtype,
"type": "raw",
"file_path": str(self.filename),
"dtype": sig_dtype,
"order": "C",
"file_offset" : file_offset,
"shape" : shape,
"file_offset": file_offset,
"shape": shape,
}



# Event channels
event_channels = []
event_channels.append(("Trigger", "", "event"))
Expand All @@ -217,14 +216,9 @@ def _parse_header(self):
for seg_index in range(nb_segment):
left_lim = seg_limits[seg_index]
right_lim = seg_limits[seg_index + 1]
keep = (
(rawevent["start"] >= left_lim)
& (rawevent["start"] < right_lim)
& (rawevent["start"] != 0)
)
keep = (rawevent["start"] >= left_lim) & (rawevent["start"] < right_lim) & (rawevent["start"] != 0)
self._raw_events[-1].append(rawevent[keep])


# No spikes
spike_channels = []
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
Expand Down
2 changes: 1 addition & 1 deletion neo/rawio/neuronexusrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, filename: str | Path = ""):
* The *.xdat.json metadata file
* The *_data.xdat binary file of all raw data
* The *_timestamps.xdat binary file of the timestamp data
From the metadata the other two files are located within the same directory
and loaded.
Expand Down
20 changes: 8 additions & 12 deletions neo/rawio/spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class SpikeGLXRawIO(BaseRawWithBufferApiIO):
* This IO reads the entire folder and subfolders locating the `.bin` and `.meta` files
* Handles gates and triggers as segments (based on the `_gt0`, `_gt1`, `_t0` , `_t1` in filenames)
* Handles all signals coming from different acquisition cards ("imec0", "imec1", etc) in a typical
PXIe chassis setup and also external signal like "nidq".
PXIe chassis setup and also external signal like "nidq".
* For imec devices both "ap" and "lf" are extracted so even a one device setup will have several "streams"
Examples
Expand Down Expand Up @@ -227,22 +227,19 @@ def _parse_header(self):

self._t_starts = {stream_name: {} for stream_name in stream_names}
self._t_stops = {seg_index: 0.0 for seg_index in range(nb_segment)}

for stream_name in stream_names:
for seg_index in range(nb_segment):
info = self.signals_info_dict[seg_index, stream_name]

frame_start = float(info["meta"]["firstSample"])
sampling_frequency = info["sampling_rate"]
t_start = frame_start / sampling_frequency
self._t_starts[stream_name][seg_index] = t_start

self._t_starts[stream_name][seg_index] = t_start
t_stop = info["sample_length"] / info["sampling_rate"]
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)




# fill in the header dict
self.header = {}
self.header["nb_block"] = 1
Expand Down Expand Up @@ -361,24 +358,23 @@ def scan_files(dirname):
raise FileNotFoundError(f"No appropriate combination of .meta and .bin files were detected in {dirname}")

# This sets non-integers values before integers
normalize = lambda x: x if isinstance(x, int) else -1
normalize = lambda x: x if isinstance(x, int) else -1

# Segment index is determined by the gate_num and trigger_num in that order
def get_segment_tuple(info):
# Create a key from the normalized gate_num and trigger_num
gate_num = normalize(info.get("gate_num"))
trigger_num = normalize(info.get("trigger_num"))
return (gate_num, trigger_num)

unique_segment_tuples = {get_segment_tuple(info) for info in info_list}
sorted_keys = sorted(unique_segment_tuples)

# Map each unique key to a corresponding index
segment_tuple_to_segment_index = {key: idx for idx, key in enumerate(sorted_keys)}

for info in info_list:
info["seg_index"] = segment_tuple_to_segment_index[get_segment_tuple(info)]

info["seg_index"] = segment_tuple_to_segment_index[get_segment_tuple(info)]

# Probe index calculation
# The calculation is ordered by slot, port, dock in that order, this is the number that appears in the filename
Expand Down Expand Up @@ -409,7 +405,7 @@ def get_probe_tuple(info):
stream_name = f"{device_kind}{device_index}{stream_kind}"

info["stream_name"] = stream_name

return info_list


Expand Down
8 changes: 4 additions & 4 deletions neo/test/rawiotest/test_micromedrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np


class TestMicromedRawIO(
BaseTestRawIO,
unittest.TestCase,
Expand All @@ -25,15 +26,15 @@ class TestMicromedRawIO(
def test_micromed_multi_segments(self):
file_full = self.get_local_path("micromed/File_mircomed2.TRC")
file_splitted = self.get_local_path("micromed/File_mircomed2_2segments.TRC")

# the second file contains 2 pieces of the first file
# so it is 2 segments with the same traces but reduced
# note that traces in the split file can differ at the very end of the cut

reader1 = MicromedRawIO(file_full)
reader1.parse_header()
assert reader1.segment_count(block_index=0) == 1
assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0.
assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0.0
traces1 = reader1.get_analogsignal_chunk(stream_index=0)

reader2 = MicromedRawIO(file_splitted)
Expand All @@ -48,11 +49,10 @@ def test_micromed_multi_segments(self):
sr = reader2.get_signal_sampling_rate(stream_index=0)
ind_start = int(t_start * sr)
traces2 = reader2.get_analogsignal_chunk(block_index=0, seg_index=seg_index, stream_index=0)
traces1_chunk = traces1[ind_start: ind_start+traces2.shape[0]]
traces1_chunk = traces1[ind_start : ind_start + traces2.shape[0]]
# we remove the last 100 samples because the tool that cuts traces truncates the last buffer
assert np.array_equal(traces2[:-100], traces1_chunk[:-100])



if __name__ == "__main__":
unittest.main()
49 changes: 13 additions & 36 deletions neo/test/rawiotest/test_spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,63 +114,40 @@ def test_nidq_digital_channel(self):

def test_t_start_reading(self):
"""Test that t_start values are correctly read for all streams and segments."""

# Expected t_start values for each stream and segment
expected_t_starts = {
'imec0.ap': {
0: 15.319535472007237,
1: 15.339535431281986,
2: 21.284723325294053,
3: 21.3047232845688
},
'imec1.ap': {
0: 15.319554693264516,
1: 15.339521518106308,
2: 21.284735282142822,
3: 21.304702106984614
},
'imec0.lf': {
0: 15.3191688060872,
1: 15.339168765361949,
2: 21.284356659374016,
3: 21.304356618648765
},
'imec1.lf': {
0: 15.319321358082725,
1: 15.339321516521915,
2: 21.284568614155827,
3: 21.30456877259502
}
"imec0.ap": {0: 15.319535472007237, 1: 15.339535431281986, 2: 21.284723325294053, 3: 21.3047232845688},
"imec1.ap": {0: 15.319554693264516, 1: 15.339521518106308, 2: 21.284735282142822, 3: 21.304702106984614},
"imec0.lf": {0: 15.3191688060872, 1: 15.339168765361949, 2: 21.284356659374016, 3: 21.304356618648765},
"imec1.lf": {0: 15.319321358082725, 1: 15.339321516521915, 2: 21.284568614155827, 3: 21.30456877259502},
}

# Initialize the RawIO
rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4"))
rawio.parse_header()

# Get list of stream names
stream_names = rawio.header["signal_streams"]["name"]

# Test t_start for each stream and segment
for stream_name, expected_values in expected_t_starts.items():
# Get stream index
stream_index = list(stream_names).index(stream_name)

# Check each segment
for seg_index, expected_t_start in expected_values.items():
actual_t_start = rawio.get_signal_t_start(
block_index=0,
seg_index=seg_index,
stream_index=stream_index
)

actual_t_start = rawio.get_signal_t_start(block_index=0, seg_index=seg_index, stream_index=stream_index)

# Use numpy.testing for proper float comparison
np.testing.assert_allclose(
actual_t_start,
expected_t_start,
rtol=1e-9,
atol=1e-9,
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}"
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit e9a710d

Please sign in to comment.