diff --git a/src/neuroconv/tools/nwb_helpers/__init__.py b/src/neuroconv/tools/nwb_helpers/__init__.py index 6221aead6..87a1e2081 100644 --- a/src/neuroconv/tools/nwb_helpers/__init__.py +++ b/src/neuroconv/tools/nwb_helpers/__init__.py @@ -5,7 +5,9 @@ from ._backend_configuration import ( BACKEND_CONFIGURATIONS, + BACKEND_NWB_IO, get_default_backend_configuration, + get_existing_backend_configuration, ) from ._configuration_models import DATASET_IO_CONFIGURATIONS from ._configuration_models._base_backend import BackendConfiguration @@ -21,15 +23,15 @@ ZarrDatasetIOConfiguration, ) from ._configure_backend import configure_backend -from ._dataset_configuration import get_default_dataset_io_configurations +from ._dataset_configuration import get_default_dataset_io_configurations, get_existing_dataset_io_configurations from ._metadata_and_file_helpers import ( - BACKEND_NWB_IO, add_device_from_metadata, configure_and_write_nwbfile, get_default_nwbfile_metadata, get_module, make_nwbfile_from_metadata, make_or_load_nwbfile, + repack_nwbfile, ) __all__ = [ @@ -46,6 +48,8 @@ "ZarrDatasetIOConfiguration", "get_default_backend_configuration", "get_default_dataset_io_configurations", + "get_existing_backend_configuration", + "get_existing_dataset_io_configurations", "configure_backend", "get_default_dataset_io_configurations", "get_default_backend_configuration", @@ -55,4 +59,5 @@ "get_module", "make_nwbfile_from_metadata", "make_or_load_nwbfile", + "repack_nwbfile", ] diff --git a/src/neuroconv/tools/nwb_helpers/_backend_configuration.py b/src/neuroconv/tools/nwb_helpers/_backend_configuration.py index 8cb465c76..860d1a2cf 100644 --- a/src/neuroconv/tools/nwb_helpers/_backend_configuration.py +++ b/src/neuroconv/tools/nwb_helpers/_backend_configuration.py @@ -2,12 +2,14 @@ from typing import Literal, Union -from pynwb import NWBFile +from hdmf_zarr import NWBZarrIO +from pynwb import NWBHDF5IO, NWBFile from ._configuration_models._hdf5_backend import HDF5BackendConfiguration from ._configuration_models._zarr_backend import ZarrBackendConfiguration BACKEND_CONFIGURATIONS = dict(hdf5=HDF5BackendConfiguration, zarr=ZarrBackendConfiguration) +BACKEND_NWB_IO = dict(hdf5=NWBHDF5IO, zarr=NWBZarrIO) def get_default_backend_configuration( @@ -17,3 +19,25 @@ def get_default_backend_configuration( BackendConfigurationClass = BACKEND_CONFIGURATIONS[backend] return BackendConfigurationClass.from_nwbfile(nwbfile=nwbfile) + + +def get_existing_backend_configuration(nwbfile: NWBFile) -> Union[HDF5BackendConfiguration, ZarrBackendConfiguration]: + """Fill an existing backend configuration to serve as a starting point for further customization. + + Parameters + ---------- + nwbfile : NWBFile + The NWBFile object to extract the backend configuration from. The nwbfile must have been read from an io object + to work properly. + + Returns + ------- + Union[HDF5BackendConfiguration, ZarrBackendConfiguration] + The backend configuration extracted from the nwbfile. 
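+
+    Examples
+    --------
+    A minimal sketch of the intended usage; ``"existing_file.nwb"`` is a hypothetical path, and the file must be
+    opened for reading first so that ``nwbfile.read_io`` is populated::
+
+        from pynwb import NWBHDF5IO
+
+        from neuroconv.tools.nwb_helpers import get_existing_backend_configuration
+
+        with NWBHDF5IO(path="existing_file.nwb", mode="r") as io:
+            nwbfile = io.read()
+            backend_configuration = get_existing_backend_configuration(nwbfile=nwbfile)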
+ """ + read_io = nwbfile.read_io + for backend, io in BACKEND_NWB_IO.items(): + if isinstance(read_io, io): + break + BackendConfigurationClass = BACKEND_CONFIGURATIONS[backend] + return BackendConfigurationClass.from_nwbfile(nwbfile=nwbfile, use_default_dataset_io_configurations=False) diff --git a/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_backend.py b/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_backend.py index 2c07a1bb0..241da1d66 100644 --- a/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_backend.py +++ b/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_backend.py @@ -9,7 +9,10 @@ from ._base_dataset_io import DatasetIOConfiguration from ._pydantic_pure_json_schema_generator import PureJSONSchemaGenerator -from .._dataset_configuration import get_default_dataset_io_configurations +from .._dataset_configuration import ( + get_default_dataset_io_configurations, + get_existing_dataset_io_configurations, +) class BackendConfiguration(BaseModel): @@ -56,11 +59,31 @@ def model_json_schema(cls, **kwargs) -> Dict[str, Any]: return super().model_json_schema(mode="validation", schema_generator=PureJSONSchemaGenerator, **kwargs) @classmethod - def from_nwbfile(cls, nwbfile: NWBFile) -> Self: - default_dataset_configurations = get_default_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend) + def from_nwbfile(cls, nwbfile: NWBFile, use_default_dataset_io_configurations: bool = True) -> Self: + """ + Create a backend configuration from an NWBFile. + + Parameters + ---------- + nwbfile : pynwb.NWBFile + The NWBFile object to extract the backend configuration from. + use_default_dataset_io_configurations : bool, optional + Whether to use default dataset configurations, by default True. If False, the existing dataset + configurations in the NWBFile will be used, which requires that the NWBFile was read from an io object. + + Returns + ------- + Self + The backend configuration extracted from the NWBFile. 
+ """ + + if use_default_dataset_io_configurations: + dataset_io_configurations = get_default_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend) + else: + dataset_io_configurations = get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=cls.backend) dataset_configurations = { default_dataset_configuration.location_in_file: default_dataset_configuration - for default_dataset_configuration in default_dataset_configurations + for default_dataset_configuration in dataset_io_configurations } return cls(dataset_configurations=dataset_configurations) diff --git a/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_dataset_io.py b/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_dataset_io.py index 01e291034..20fcf9d0c 100644 --- a/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_dataset_io.py +++ b/src/neuroconv/tools/nwb_helpers/_configuration_models/_base_dataset_io.py @@ -147,7 +147,6 @@ def __str__(self) -> str: """ size_in_bytes = math.prod(self.full_shape) * self.dtype.itemsize maximum_ram_usage_per_iteration_in_bytes = math.prod(self.buffer_shape) * self.dtype.itemsize - disk_space_usage_per_chunk_in_bytes = math.prod(self.chunk_shape) * self.dtype.itemsize string = ( f"\n{self.location_in_file}" @@ -159,10 +158,14 @@ def __str__(self) -> str: f"\n buffer shape : {self.buffer_shape}" f"\n expected RAM usage : {human_readable_size(maximum_ram_usage_per_iteration_in_bytes)}" "\n" - f"\n chunk shape : {self.chunk_shape}" - f"\n disk space usage per chunk : {human_readable_size(disk_space_usage_per_chunk_in_bytes)}" - "\n" ) + if self.chunk_shape is not None: + disk_space_usage_per_chunk_in_bytes = math.prod(self.chunk_shape) * self.dtype.itemsize + string += ( + f"\n chunk shape : {self.chunk_shape}" + f"\n disk space usage per chunk : {human_readable_size(disk_space_usage_per_chunk_in_bytes)}" + "\n" + ) if self.compression_method is not None: string += f"\n compression method : {self.compression_method}" if self.compression_options is not None: @@ -182,9 +185,9 @@ def validate_all_shapes(cls, values: Dict[str, Any]) -> Dict[str, Any]: dataset_name == location_in_file.split("/")[-1] ), f"The `dataset_name` ({dataset_name}) does not match the end of the `location_in_file` ({location_in_file})!" 
- chunk_shape = values["chunk_shape"] - buffer_shape = values["buffer_shape"] full_shape = values["full_shape"] + chunk_shape = values["chunk_shape"] if values["chunk_shape"] is not None else full_shape + buffer_shape = values["buffer_shape"] if values["buffer_shape"] is not None else full_shape if len(chunk_shape) != len(buffer_shape): raise ValueError( diff --git a/src/neuroconv/tools/nwb_helpers/_configuration_models/_hdf5_dataset_io.py b/src/neuroconv/tools/nwb_helpers/_configuration_models/_hdf5_dataset_io.py index 828a37998..4a3180b60 100644 --- a/src/neuroconv/tools/nwb_helpers/_configuration_models/_hdf5_dataset_io.py +++ b/src/neuroconv/tools/nwb_helpers/_configuration_models/_hdf5_dataset_io.py @@ -3,9 +3,13 @@ from typing import Any, Dict, Literal, Union import h5py +import numpy as np +from hdmf import Container from pydantic import Field, InstanceOf +from typing_extensions import Self -from ._base_dataset_io import DatasetIOConfiguration +from ._base_dataset_io import DatasetIOConfiguration, _find_location_in_memory_nwbfile +from ...hdmf import SliceableDataChunkIterator from ...importing import is_package_installed _base_hdf5_filters = set(h5py.filters.decode) @@ -78,3 +82,37 @@ def get_data_io_kwargs(self) -> Dict[str, Any]: compression_bundle = dict(compression=self.compression_method, compression_opts=compression_opts) return dict(chunks=self.chunk_shape, **compression_bundle) + + @classmethod + def from_neurodata_object( + cls, + neurodata_object: Container, + dataset_name: Literal["data", "timestamps"], + use_default_dataset_io_configuration: bool = True, + ) -> Self: + if use_default_dataset_io_configuration: + return super().from_neurodata_object(neurodata_object=neurodata_object, dataset_name=dataset_name) + + location_in_file = _find_location_in_memory_nwbfile(neurodata_object=neurodata_object, field_name=dataset_name) + full_shape = getattr(neurodata_object, dataset_name).shape + dtype = getattr(neurodata_object, dataset_name).dtype + chunk_shape = getattr(neurodata_object, dataset_name).chunks + buffer_chunk_shape = chunk_shape or full_shape + buffer_shape = SliceableDataChunkIterator.estimate_default_buffer_shape( + buffer_gb=0.5, chunk_shape=buffer_chunk_shape, maxshape=full_shape, dtype=np.dtype(dtype) + ) + compression_method = getattr(neurodata_object, dataset_name).compression + compression_opts = getattr(neurodata_object, dataset_name).compression_opts + compression_options = dict(compression_opts=compression_opts) + return cls( + object_id=neurodata_object.object_id, + object_name=neurodata_object.name, + location_in_file=location_in_file, + dataset_name=dataset_name, + full_shape=full_shape, + dtype=dtype, + chunk_shape=chunk_shape, + buffer_shape=buffer_shape, + compression_method=compression_method, + compression_options=compression_options, + ) diff --git a/src/neuroconv/tools/nwb_helpers/_configuration_models/_zarr_dataset_io.py b/src/neuroconv/tools/nwb_helpers/_configuration_models/_zarr_dataset_io.py index c070a20e9..3112f5480 100644 --- a/src/neuroconv/tools/nwb_helpers/_configuration_models/_zarr_dataset_io.py +++ b/src/neuroconv/tools/nwb_helpers/_configuration_models/_zarr_dataset_io.py @@ -1,12 +1,15 @@ """Base Pydantic models for the ZarrDatasetConfiguration.""" -from typing import Any, Dict, List, Literal, Union +from typing import Any, Dict, List, Literal, Self, Union import numcodecs +import numpy as np import zarr +from hdmf import Container from pydantic import Field, InstanceOf, model_validator -from ._base_dataset_io import 
DatasetIOConfiguration +from ._base_dataset_io import DatasetIOConfiguration, _find_location_in_memory_nwbfile +from ...hdmf import SliceableDataChunkIterator _base_zarr_codecs = set(zarr.codec_registry.keys()) _lossy_zarr_codecs = set(("astype", "bitround", "quantize")) @@ -130,3 +133,36 @@ def get_data_io_kwargs(self) -> Dict[str, Any]: compressor = False return dict(chunks=self.chunk_shape, filters=filters, compressor=compressor) + + @classmethod + def from_neurodata_object( + cls, + neurodata_object: Container, + dataset_name: Literal["data", "timestamps"], + use_default_dataset_io_configuration: bool = True, + ) -> Self: + if use_default_dataset_io_configuration: + return super().from_neurodata_object(neurodata_object=neurodata_object, dataset_name=dataset_name) + + location_in_file = _find_location_in_memory_nwbfile(neurodata_object=neurodata_object, field_name=dataset_name) + full_shape = getattr(neurodata_object, dataset_name).shape + dtype = getattr(neurodata_object, dataset_name).dtype + chunk_shape = getattr(neurodata_object, dataset_name).chunks + buffer_chunk_shape = chunk_shape or full_shape + buffer_shape = SliceableDataChunkIterator.estimate_default_buffer_shape( + buffer_gb=0.5, chunk_shape=buffer_chunk_shape, maxshape=full_shape, dtype=np.dtype(dtype) + ) + compression_method = getattr(neurodata_object, dataset_name).compressor + filter_methods = getattr(neurodata_object, dataset_name).filters + return cls( + object_id=neurodata_object.object_id, + object_name=neurodata_object.name, + location_in_file=location_in_file, + dataset_name=dataset_name, + full_shape=full_shape, + dtype=dtype, + chunk_shape=chunk_shape, + buffer_shape=buffer_shape, + compression_method=compression_method, + filter_methods=filter_methods, + ) diff --git a/src/neuroconv/tools/nwb_helpers/_configure_backend.py b/src/neuroconv/tools/nwb_helpers/_configure_backend.py index a67308d43..0fcbe5756 100644 --- a/src/neuroconv/tools/nwb_helpers/_configure_backend.py +++ b/src/neuroconv/tools/nwb_helpers/_configure_backend.py @@ -4,6 +4,7 @@ from typing import Union from hdmf.common import Data +from hdmf.data_utils import DataChunkIterator from pynwb import NWBFile, TimeSeries from ._configuration_models._hdf5_backend import HDF5BackendConfiguration @@ -46,16 +47,24 @@ def configure_backend( # Table columns if isinstance(neurodata_object, Data): - neurodata_object.set_data_io(data_io_class=data_io_class, data_io_kwargs=data_io_kwargs) + neurodata_object.set_data_io( + data_io_class=data_io_class, data_io_kwargs=data_io_kwargs, data_chunk_iterator_class=DataChunkIterator + ) # TimeSeries data or timestamps elif isinstance(neurodata_object, TimeSeries) and not is_dataset_linked: neurodata_object.set_data_io( - dataset_name=dataset_name, data_io_class=data_io_class, data_io_kwargs=data_io_kwargs + dataset_name=dataset_name, + data_io_class=data_io_class, + data_io_kwargs=data_io_kwargs, + data_chunk_iterator_class=DataChunkIterator, ) # Special ndx-events v0.2.0 types elif is_ndx_events_installed and isinstance(neurodata_object, ndx_events.Events): neurodata_object.set_data_io( - dataset_name=dataset_name, data_io_class=data_io_class, data_io_kwargs=data_io_kwargs + dataset_name=dataset_name, + data_io_class=data_io_class, + data_io_kwargs=data_io_kwargs, + data_chunk_iterator_class=DataChunkIterator, ) # But temporarily skipping LabeledEvents elif is_ndx_events_installed and isinstance(neurodata_object, ndx_events.LabeledEvents): diff --git a/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py 
b/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py index f3d8e7560..10f788cf7 100644 --- a/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py +++ b/src/neuroconv/tools/nwb_helpers/_dataset_configuration.py @@ -172,3 +172,82 @@ def get_default_dataset_io_configurations( ) yield dataset_io_configuration + + +def get_existing_dataset_io_configurations( + nwbfile: NWBFile, + backend: Literal["hdf5", "zarr"], +) -> Generator[DatasetIOConfiguration, None, None]: + """ + Generate DatasetIOConfiguration objects for each neurodata object in an nwbfile. + + Parameters + ---------- + nwbfile : pynwb.NWBFile + An NWBFile object that has been read from an existing file with an existing backend configuration. + backend : "hdf5" or "zarr" + Which backend format type you would like to use in configuring each dataset's compression methods and options. + + Yields + ------ + DatasetIOConfiguration + A configuration object for each dataset in the NWB file. + """ + + DatasetIOConfigurationClass = DATASET_IO_CONFIGURATIONS[backend] + + known_dataset_fields = ("data", "timestamps") + for neurodata_object in nwbfile.objects.values(): + if isinstance(neurodata_object, DynamicTable): + dynamic_table = neurodata_object # For readability + + for column in dynamic_table.columns: + candidate_dataset = column.data # VectorData object + + # Skip over columns whose values are links, such as the 'group' of an ElectrodesTable + if any(isinstance(value, Container) for value in candidate_dataset): + continue # Skip + + # Skip when columns whose values are a reference type + if isinstance(column, TimeSeriesReferenceVectorData): + continue + + # Skip datasets with any zero-length axes + dataset_name = "data" + candidate_dataset = getattr(column, dataset_name) + full_shape = get_data_shape(data=candidate_dataset) + if any(axis_length == 0 for axis_length in full_shape): + continue + + dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object( + neurodata_object=column, + dataset_name=dataset_name, + use_default_dataset_io_configuration=False, + ) + + yield dataset_io_configuration + elif isinstance(neurodata_object, NWBContainer): + for known_dataset_field in known_dataset_fields: + # Skip optional fields that aren't present + if known_dataset_field not in neurodata_object.fields: + continue + + candidate_dataset = getattr(neurodata_object, known_dataset_field) + + # Skip edge case of in-memory ImageSeries with external mode; data is in fields and is empty array + if isinstance(candidate_dataset, np.ndarray) and candidate_dataset.size == 0: + continue + + # Skip datasets with any zero-length axes + candidate_dataset = getattr(neurodata_object, known_dataset_field) + full_shape = get_data_shape(data=candidate_dataset) + if any(axis_length == 0 for axis_length in full_shape): + continue + + dataset_io_configuration = DatasetIOConfigurationClass.from_neurodata_object( + neurodata_object=neurodata_object, + dataset_name=known_dataset_field, + use_default_dataset_io_configuration=False, + ) + + yield dataset_io_configuration diff --git a/src/neuroconv/tools/nwb_helpers/_metadata_and_file_helpers.py b/src/neuroconv/tools/nwb_helpers/_metadata_and_file_helpers.py index c3aaea48d..b92bf10ad 100644 --- a/src/neuroconv/tools/nwb_helpers/_metadata_and_file_helpers.py +++ b/src/neuroconv/tools/nwb_helpers/_metadata_and_file_helpers.py @@ -10,17 +10,20 @@ from typing import Literal, Optional from warnings import warn -from hdmf_zarr import NWBZarrIO from pydantic import FilePath -from pynwb import 
NWBHDF5IO, NWBFile +from pynwb import NWBFile from pynwb.file import Subject -from . import BackendConfiguration, configure_backend, get_default_backend_configuration +from . import ( + BACKEND_NWB_IO, + BackendConfiguration, + configure_backend, + get_default_backend_configuration, + get_existing_backend_configuration, +) from ...utils.dict import DeepDict, load_dict_from_file from ...utils.json_schema import validate_metadata -BACKEND_NWB_IO = dict(hdf5=NWBHDF5IO, zarr=NWBZarrIO) - def get_module(nwbfile: NWBFile, name: str, description: str = None): """Check if processing module exists. If not, create it. Then return module.""" @@ -337,6 +340,7 @@ def configure_and_write_nwbfile( output_filepath: str, backend: Optional[Literal["hdf5"]] = None, backend_configuration: Optional[BackendConfiguration] = None, + export: bool = False, ) -> None: """ Write an NWB file using a specific backend or backend configuration. @@ -355,6 +359,8 @@ def configure_and_write_nwbfile( backend_configuration: BackendConfiguration, optional Specifies the backend type and the chunking and compression parameters of each dataset. If no ``backend_configuration`` is specified, the default configuration for the specified ``backend`` is used. + export: bool, default: False + Whether to export the NWB file instead of writing. """ @@ -369,4 +375,69 @@ def configure_and_write_nwbfile( IO = BACKEND_NWB_IO[backend_configuration.backend] with IO(output_filepath, mode="w") as io: - io.write(nwbfile) + if export: + nwbfile.set_modified() + io.export(nwbfile=nwbfile, src_io=nwbfile.read_io, write_args=dict(link_data=False)) + else: + io.write(nwbfile) + + +def repack_nwbfile( + *, + nwbfile_path: Path, + export_nwbfile_path: Path, + backend: Literal["hdf5", "zarr"] = "hdf5", + export_backend: Literal["hdf5", "zarr", None] = None, + use_default_backend_configuration: bool = True, + backend_configuration_changes: dict[str, dict] = None, +): + """ + Repack an NWBFile with a new backend configuration. + + Parameters + ---------- + nwbfile_path : Path + Path to the NWB file to be repacked. + export_nwbfile_path : Path + Path to export the repacked NWB file. + backend : {"hdf5", "zarr"}, default: "hdf5" + The type of backend used to read the file. + export_backend : {"hdf5", "zarr", None}, default: None + The type of backend used to write the repacked file. If None, the same backend as the input file is used. + use_default_backend_configuration : bool, default: True + Whether to use the default backend configuration for the specified backend and nwbfile. If False, the nwbfile + must be written to disk and its existing backend configuration is used. + backend_configuration_changes : dict, default: None + Changes to the backend configuration. The keys are the locations of the datasets in the NWB file, and the values + are dictionaries of the changes to be made to the dataset configuration. + + Notes + ----- + The keys for the `backend_configuration_changes` must be as they appear in the BackendConfiguration NOT how they + appear in the H5DataIO. For example, if you want to change the chunking of the 'acquisition/RawTimeSeries/data' + dataset to (10,), you would pass {'acquisition/RawTimeSeries/data': {'chunk_shape': (10,)}}. 
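+
+    Examples
+    --------
+    A minimal sketch that repacks an existing HDF5 file while keeping its on-disk configuration, overriding only
+    the chunking of one dataset; the file paths are hypothetical and the dataset location mirrors the Notes above::
+
+        from neuroconv.tools.nwb_helpers import repack_nwbfile
+
+        repack_nwbfile(
+            nwbfile_path="existing_file.nwb",
+            export_nwbfile_path="repacked_file.nwb",
+            backend="hdf5",
+            use_default_backend_configuration=False,
+            backend_configuration_changes={"acquisition/RawTimeSeries/data": dict(chunk_shape=(10,))},
+        )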
+ """ + backend_configuration_changes = backend_configuration_changes or dict() + export_backend = export_backend or backend + + IO = BACKEND_NWB_IO[backend] + with IO(nwbfile_path, mode="r") as io: + nwbfile = io.read() + if use_default_backend_configuration: + backend_configuration = get_default_backend_configuration(nwbfile=nwbfile, backend=backend) + else: + backend_configuration = get_existing_backend_configuration(nwbfile=nwbfile) + dataset_configurations = backend_configuration.dataset_configurations + + for neurodata_object_location, dataset_config_changes in backend_configuration_changes.items(): + dataset_configuration = dataset_configurations[neurodata_object_location] + for dataset_config_key, dataset_config_value in dataset_config_changes.items(): + setattr(dataset_configuration, dataset_config_key, dataset_config_value) + + configure_and_write_nwbfile( + nwbfile=nwbfile, + backend_configuration=backend_configuration, + output_filepath=export_nwbfile_path, + backend=export_backend, + export=True, + ) diff --git a/temp_test.py b/temp_test.py new file mode 100644 index 000000000..74bd0294c --- /dev/null +++ b/temp_test.py @@ -0,0 +1,63 @@ +import os +import shutil +from pathlib import Path + +import numpy as np +from hdmf_zarr import ZarrDataIO +from hdmf_zarr.nwb import NWBZarrIO +from pynwb import NWBHDF5IO, H5DataIO, TimeSeries +from pynwb.testing.mock.file import mock_NWBFile + +from neuroconv.tools.nwb_helpers import repack_nwbfile + + +def write_nwbfile(nwbfile_path: Path, backend: str = "hdf5"): + if nwbfile_path.exists(): + os.remove(nwbfile_path) + nwbfile = mock_NWBFile() + timestamps = np.arange(10.0) + data = np.arange(100, 200, 10) + if backend == "hdf5": + data = H5DataIO(data=data, compression="gzip", chunks=(1,), compression_opts=2) + elif backend == "zarr": + data = ZarrDataIO(data=data, chunks=(3,), compressor=True) + time_series_with_timestamps = TimeSeries( + name="test_timeseries", + description="an example time series", + data=data, + unit="m", + timestamps=timestamps, + ) + nwbfile.add_acquisition(time_series_with_timestamps) + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), mode="w") as io: + io.write(nwbfile) + + +def main(): + nwbfile_path = Path("temp.nwb.zarr") + repacked_nwbfile_path = Path("repacked_temp.nwb.zarr") + if repacked_nwbfile_path.exists(): + if repacked_nwbfile_path.is_dir(): + shutil.rmtree(repacked_nwbfile_path) + else: + os.remove(repacked_nwbfile_path) + if not nwbfile_path.exists(): + write_nwbfile(nwbfile_path, backend="zarr") + + backend_configuration_changes = {"acquisition/test_timeseries/data": dict(chunk_shape=(2,))} + repack_nwbfile( + nwbfile_path=str(nwbfile_path), + export_nwbfile_path=str(repacked_nwbfile_path), + backend="zarr", + backend_configuration_changes=backend_configuration_changes, + use_default_backend_configuration=False, + ) + + with NWBZarrIO(str(repacked_nwbfile_path), mode="r") as io: + nwbfile = io.read() + print(f'{nwbfile.acquisition["test_timeseries"].data.chunks = }') + + +if __name__ == "__main__": + main() diff --git a/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_get_existing_backend_configuration.py b/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_get_existing_backend_configuration.py new file mode 100644 index 000000000..1497523cb --- /dev/null +++ b/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_get_existing_backend_configuration.py @@ -0,0 
+1,317 @@ +"""Integration tests for `get_existing_backend_configuration`.""" + +from io import StringIO +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest +from hdmf_zarr import ZarrDataIO +from hdmf_zarr.nwb import NWBZarrIO +from numcodecs import Blosc +from pynwb import NWBHDF5IO, H5DataIO, NWBFile +from pynwb.testing.mock.base import mock_TimeSeries +from pynwb.testing.mock.file import mock_NWBFile + +from neuroconv.tools.nwb_helpers import ( + HDF5BackendConfiguration, + ZarrBackendConfiguration, + get_existing_backend_configuration, + get_module, +) + + +def generate_complex_nwbfile() -> NWBFile: + nwbfile = mock_NWBFile() + + raw_array = np.array([[1, 2, 3], [4, 5, 6]]) + raw_time_series = mock_TimeSeries(name="RawTimeSeries", data=raw_array) + nwbfile.add_acquisition(raw_time_series) + + number_of_trials = 10 + for start_time, stop_time in zip( + np.linspace(start=0.0, stop=10.0, num=number_of_trials), np.linspace(start=1.0, stop=11.0, num=number_of_trials) + ): + nwbfile.add_trial(start_time=start_time, stop_time=stop_time) + + ecephys_module = get_module(nwbfile=nwbfile, name="ecephys") + processed_array = np.array([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0], [13.0, 14.0]]) + processed_time_series = mock_TimeSeries(name="ProcessedTimeSeries", data=processed_array) + ecephys_module.add(processed_time_series) + + return nwbfile + + +@pytest.fixture(scope="session") +def hdf5_nwbfile_path(tmpdir_factory): + nwbfile_path = tmpdir_factory.mktemp("data").join("test_existing_backend_configuration_hdf5_nwbfile.nwb.h5") + if not Path(nwbfile_path).exists(): + nwbfile = generate_complex_nwbfile() + + # Add a H5DataIO-compressed time series + raw_array = np.array([[11, 21, 31], [41, 51, 61]], dtype="int32") + data = H5DataIO(data=raw_array, compression="gzip", compression_opts=2) + raw_time_series = mock_TimeSeries(name="CompressedRawTimeSeries", data=data) + nwbfile.add_acquisition(raw_time_series) + + # Add H5DataIO-compressed trials column + number_of_trials = 10 + start_time = np.linspace(start=0.0, stop=10.0, num=number_of_trials) + nwbfile.add_trial_column( + name="compressed_start_time", + description="start time of epoch", + data=H5DataIO(data=start_time, compression="gzip", compression_opts=2), + ) + + with NWBHDF5IO(path=str(nwbfile_path), mode="w") as io: + io.write(nwbfile) + return str(nwbfile_path) + + +@pytest.fixture(scope="session") +def zarr_nwbfile_path(tmpdir_factory): + compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) + filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE) + filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE) + filters = [filter1, filter2] + + nwbfile_path = tmpdir_factory.mktemp("data").join("test_default_backend_configuration_hdf5_nwbfile.nwb.zarr") + if not Path(nwbfile_path).exists(): + nwbfile = generate_complex_nwbfile() + + # Add a ZarrDataIO-compressed time series + raw_array = np.array([[11, 21, 31], [41, 51, 61]], dtype="int32") + data = ZarrDataIO(data=raw_array, chunks=(1, 3), compressor=compressor, filters=filters) + raw_time_series = mock_TimeSeries(name="CompressedRawTimeSeries", data=data) + nwbfile.add_acquisition(raw_time_series) + + # Add ZarrDataIO-compressed trials column + number_of_trials = 10 + start_time = np.linspace(start=0.0, stop=10.0, num=number_of_trials) + data = ZarrDataIO(data=start_time, chunks=(5,), compressor=compressor, filters=filters) + nwbfile.add_trial_column( + name="compressed_start_time", + description="start time of 
epoch", + data=data, + ) + + with NWBZarrIO(path=str(nwbfile_path), mode="w") as io: + io.write(nwbfile) + return str(nwbfile_path) + + +def test_complex_hdf5(hdf5_nwbfile_path): + with NWBHDF5IO(path=hdf5_nwbfile_path, mode="a") as io: + nwbfile = io.read() + backend_configuration = get_existing_backend_configuration(nwbfile=nwbfile) + + assert isinstance(backend_configuration, HDF5BackendConfiguration) + + dataset_configurations = backend_configuration.dataset_configurations + assert len(dataset_configurations) == 6 + + # Best summary test of expected output is the printout + with patch("sys.stdout", new=StringIO()) as stdout: + print(backend_configuration) + + expected_print = """ +HDF5 dataset configurations +--------------------------- + +intervals/trials/start_time/data +-------------------------------- + dtype : float64 + full shape of source array : (10,) + full size of source array : 80 B + + buffer shape : (10,) + expected RAM usage : 80 B + + compression options : {'compression_opts': None} + + +intervals/trials/stop_time/data +------------------------------- + dtype : float64 + full shape of source array : (10,) + full size of source array : 80 B + + buffer shape : (10,) + expected RAM usage : 80 B + + compression options : {'compression_opts': None} + + +intervals/trials/compressed_start_time/data +------------------------------------------- + dtype : float64 + full shape of source array : (10,) + full size of source array : 80 B + + buffer shape : (10,) + expected RAM usage : 80 B + + chunk shape : (10,) + disk space usage per chunk : 80 B + + compression method : gzip + compression options : {'compression_opts': 2} + + +processing/ecephys/ProcessedTimeSeries/data +------------------------------------------- + dtype : float64 + full shape of source array : (4, 2) + full size of source array : 64 B + + buffer shape : (4, 2) + expected RAM usage : 64 B + + compression options : {'compression_opts': None} + + +acquisition/RawTimeSeries/data +------------------------------ + dtype : int64 + full shape of source array : (2, 3) + full size of source array : 48 B + + buffer shape : (2, 3) + expected RAM usage : 48 B + + compression options : {'compression_opts': None} + + +acquisition/CompressedRawTimeSeries/data +---------------------------------------- + dtype : int32 + full shape of source array : (2, 3) + full size of source array : 24 B + + buffer shape : (2, 3) + expected RAM usage : 24 B + + chunk shape : (2, 3) + disk space usage per chunk : 24 B + + compression method : gzip + compression options : {'compression_opts': 2} + +""" + assert stdout.getvalue() == expected_print + + +def test_complex_zarr(zarr_nwbfile_path): + with NWBZarrIO(path=zarr_nwbfile_path, mode="a") as io: + nwbfile = io.read() + backend_configuration = get_existing_backend_configuration(nwbfile=nwbfile) + + assert isinstance(backend_configuration, ZarrBackendConfiguration) + + dataset_configurations = backend_configuration.dataset_configurations + assert len(dataset_configurations) == 6 + + # Best summary test of expected output is the printout + print(backend_configuration) + with patch("sys.stdout", new=StringIO()) as stdout: + print(backend_configuration) + + expected_print = """ +Zarr dataset configurations +--------------------------- + +intervals/trials/start_time/data +-------------------------------- + dtype : float64 + full shape of source array : (10,) + full size of source array : 80 B + + buffer shape : (10,) + expected RAM usage : 80 B + + chunk shape : (10,) + disk space usage per chunk : 
80 B + + compression method : Blosc(cname='lz4', clevel=5, shuffle=SHUFFLE, blocksize=0) + + +intervals/trials/stop_time/data +------------------------------- + dtype : float64 + full shape of source array : (10,) + full size of source array : 80 B + + buffer shape : (10,) + expected RAM usage : 80 B + + chunk shape : (10,) + disk space usage per chunk : 80 B + + compression method : Blosc(cname='lz4', clevel=5, shuffle=SHUFFLE, blocksize=0) + + +intervals/trials/compressed_start_time/data +------------------------------------------- + dtype : float64 + full shape of source array : (10,) + full size of source array : 80 B + + buffer shape : (10,) + expected RAM usage : 80 B + + chunk shape : (5,) + disk space usage per chunk : 40 B + + compression method : Blosc(cname='lz4', clevel=5, shuffle=SHUFFLE, blocksize=0) + + filter methods : [Blosc(cname='zstd', clevel=1, shuffle=SHUFFLE, blocksize=0), Blosc(cname='zstd', clevel=2, shuffle=SHUFFLE, blocksize=0)] + + +processing/ecephys/ProcessedTimeSeries/data +------------------------------------------- + dtype : float64 + full shape of source array : (4, 2) + full size of source array : 64 B + + buffer shape : (4, 2) + expected RAM usage : 64 B + + chunk shape : (4, 2) + disk space usage per chunk : 64 B + + compression method : Blosc(cname='lz4', clevel=5, shuffle=SHUFFLE, blocksize=0) + + +acquisition/RawTimeSeries/data +------------------------------ + dtype : int64 + full shape of source array : (2, 3) + full size of source array : 48 B + + buffer shape : (2, 3) + expected RAM usage : 48 B + + chunk shape : (2, 3) + disk space usage per chunk : 48 B + + compression method : Blosc(cname='lz4', clevel=5, shuffle=SHUFFLE, blocksize=0) + + +acquisition/CompressedRawTimeSeries/data +---------------------------------------- + dtype : int32 + full shape of source array : (2, 3) + full size of source array : 24 B + + buffer shape : (2, 3) + expected RAM usage : 24 B + + chunk shape : (1, 3) + disk space usage per chunk : 12 B + + compression method : Blosc(cname='lz4', clevel=5, shuffle=SHUFFLE, blocksize=0) + + filter methods : [Blosc(cname='zstd', clevel=1, shuffle=SHUFFLE, blocksize=0), Blosc(cname='zstd', clevel=2, shuffle=SHUFFLE, blocksize=0)] + +""" + assert stdout.getvalue() == expected_print diff --git a/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_get_existing_dataset_io_configurations.py b/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_get_existing_dataset_io_configurations.py new file mode 100644 index 000000000..2ea2ed189 --- /dev/null +++ b/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_get_existing_dataset_io_configurations.py @@ -0,0 +1,630 @@ +"""Unit tests for `get_default_dataset_io_configurations`.""" + +from typing import Literal + +import numpy as np +import pytest +from hdmf.common import VectorData +from hdmf_zarr import ZarrDataIO +from hdmf_zarr.nwb import NWBZarrIO +from numcodecs import Blosc +from pynwb import NWBHDF5IO, H5DataIO +from pynwb.base import DynamicTable +from pynwb.behavior import CompassDirection +from pynwb.image import ImageSeries +from pynwb.testing.mock.base import mock_TimeSeries +from pynwb.testing.mock.behavior import mock_SpatialSeries +from pynwb.testing.mock.file import mock_NWBFile + +from neuroconv.tools.importing import is_package_installed +from neuroconv.tools.nwb_helpers import ( + DATASET_IO_CONFIGURATIONS, + get_existing_dataset_io_configurations, + get_module, +) + + 
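+# The tests below share a pattern: write an NWB file containing both an uncompressed dataset and one wrapped in
+# H5DataIO/ZarrDataIO with explicit chunking, compression, and (for Zarr) filters, read the file back, and assert
+# that `get_existing_dataset_io_configurations` reports the on-disk chunk shape, compression, and filter settings.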
+@pytest.mark.parametrize("backend", ["hdf5", "zarr"]) +def test_configuration_on_time_series(tmp_path, backend: Literal["hdf5", "zarr"]): + data = np.array([[1, 2, 3], [4, 5, 6]]) + + nwbfile = mock_NWBFile() + if backend == "zarr": # ZarrDataIO compresses by default, so we disable it to test no-compression + data = ZarrDataIO(data=data, compressor=False) + time_series = mock_TimeSeries(name="TestTimeSeries", data=data) + nwbfile.add_acquisition(time_series) + + data = np.array([[1, 2, 3], [4, 5, 6]]) + if backend == "hdf5": + data = H5DataIO(data=data, compression="gzip", compression_opts=2, chunks=(1, 3)) + elif backend == "zarr": + compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) + filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE) + filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE) + filters = [filter1, filter2] + data = ZarrDataIO(data=data, chunks=(1, 3), compressor=compressor, filters=filters) + compressed_time_series = mock_TimeSeries( + name="CompressedTimeSeries", + data=data, + ) + nwbfile.add_acquisition(compressed_time_series) + + nwbfile_path = tmp_path / "test_existing_dataset_io_configurations_timeseries.nwb" + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), "w") as io: + io.write(nwbfile) + with IO(str(nwbfile_path), "r") as io: + nwbfile = io.read() + + dataset_configurations = list(get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=backend)) + + assert len(dataset_configurations) == 2 + + dataset_configuration = dataset_configurations[0] + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.object_id == time_series.object_id + assert dataset_configuration.location_in_file == "acquisition/TestTimeSeries/data" + assert dataset_configuration.full_shape == data.shape + assert dataset_configuration.dtype == data.dtype + assert dataset_configuration.buffer_shape == data.shape + assert dataset_configuration.compression_method is None + + if backend == "hdf5": + assert dataset_configuration.chunk_shape is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + + elif backend == "zarr": + assert dataset_configuration.chunk_shape == (2, 3) + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = dataset_configurations[1] + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.object_id == compressed_time_series.object_id + assert dataset_configuration.location_in_file == "acquisition/CompressedTimeSeries/data" + assert dataset_configuration.full_shape == data.shape + assert dataset_configuration.dtype == data.dtype + assert dataset_configuration.chunk_shape == (1, 3) + assert dataset_configuration.buffer_shape == data.shape + + if backend == "hdf5": + assert dataset_configuration.compression_method == "gzip" + assert dataset_configuration.compression_options["compression_opts"] == 2 + + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods == filters + assert dataset_configuration.filter_options is None + + +@pytest.mark.parametrize("backend", ["hdf5", "zarr"]) +def test_configuration_on_external_image_series(tmp_path, backend: Literal["hdf5", "zarr"]): + nwbfile = 
mock_NWBFile() + image_series = ImageSeries(name="TestImageSeries", external_file=[""], rate=1.0) + nwbfile.add_acquisition(image_series) + + nwbfile_path = tmp_path / "test_existing_dataset_io_configurations_external_image_series.nwb" + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), "w") as io: + io.write(nwbfile) + with IO(str(nwbfile_path), "r") as io: + nwbfile = io.read() + dataset_configurations = list(get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=backend)) + assert len(dataset_configurations) == 0 + + +@pytest.mark.parametrize("backend", ["hdf5", "zarr"]) +def test_configuration_on_dynamic_table(tmp_path, backend: Literal["hdf5", "zarr"]): + data = np.array([0.1, 0.2, 0.3]) + + nwbfile = mock_NWBFile() + if backend == "zarr": # ZarrDataIO compresses by default, so we disable it to test no-compression + data = ZarrDataIO(data=data, compressor=False) + column = VectorData(name="TestColumn", description="", data=data) + + data = np.array([0.1, 0.2, 0.3]) + if backend == "hdf5": + data = H5DataIO(data=data, compression="gzip", compression_opts=2, chunks=(1,)) + elif backend == "zarr": + compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) + filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE) + filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE) + filters = [filter1, filter2] + data = ZarrDataIO(data=data, chunks=(1,), compressor=compressor, filters=filters) + compressed_column = VectorData( + name="CompressedColumn", + description="", + data=data, + ) + dynamic_table = DynamicTable( + name="TestDynamicTable", description="", columns=[column, compressed_column], id=list(range(len(data))) + ) + nwbfile.add_acquisition(dynamic_table) + + nwbfile_path = tmp_path / "test_existing_dataset_io_configurations_dynamic_table.nwb" + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), "w") as io: + io.write(nwbfile) + with IO(str(nwbfile_path), "r") as io: + nwbfile = io.read() + + dataset_configurations = list(get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=backend)) + + assert len(dataset_configurations) == 2 + + dataset_configuration = dataset_configurations[0] + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.object_id == column.object_id + assert dataset_configuration.location_in_file == "acquisition/TestDynamicTable/TestColumn/data" + assert dataset_configuration.full_shape == data.shape + assert dataset_configuration.dtype == data.dtype + assert dataset_configuration.buffer_shape == data.shape + assert dataset_configuration.compression_method is None + + if backend == "hdf5": + assert dataset_configuration.chunk_shape is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + elif backend == "zarr": + assert dataset_configuration.chunk_shape == (3,) + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = dataset_configurations[1] + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.object_id == compressed_column.object_id + assert dataset_configuration.location_in_file == "acquisition/TestDynamicTable/CompressedColumn/data" + assert dataset_configuration.full_shape == data.shape + assert dataset_configuration.dtype == data.dtype + assert 
dataset_configuration.chunk_shape == (1,) + assert dataset_configuration.buffer_shape == data.shape + + if backend == "hdf5": + assert dataset_configuration.compression_method == "gzip" + assert dataset_configuration.compression_options == dict(compression_opts=2) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods == filters + assert dataset_configuration.filter_options is None + + +@pytest.mark.parametrize("backend", ["hdf5", "zarr"]) +def test_configuration_on_ragged_units_table(tmp_path, backend: Literal["hdf5", "zarr"]): + nwbfile = mock_NWBFile() + + spike_times1 = np.array([0.0, 1.0, 2.0]) + waveforms1 = np.array( + [[[1, 2, 3], [1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3], [1, 2, 3]]], + dtype="int32", + ) + nwbfile.add_unit(spike_times=spike_times1, waveforms=waveforms1) + + spike_times2 = np.array([3.0, 4.0]) + waveforms2 = np.array([[[4, 5, 6], [4, 5, 6], [4, 5, 6]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]], dtype="int32") + nwbfile.add_unit(spike_times=spike_times2, waveforms=waveforms2) + + spike_times = np.concatenate([spike_times1, spike_times2]) + waveforms = np.concatenate([waveforms1, waveforms2], axis=0) + index = [len(spike_times1), len(spike_times1) + len(spike_times2)] + if backend == "hdf5": + spike_times = H5DataIO(data=spike_times, compression="gzip", compression_opts=2, chunks=(2,)) + waveforms = H5DataIO(data=waveforms, compression="gzip", compression_opts=2, chunks=(1, 3, 3)) + elif backend == "zarr": + compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) + filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE) + filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE) + filters = [filter1, filter2] + spike_times = ZarrDataIO(data=spike_times, chunks=(2,), compressor=compressor, filters=filters) + waveforms = ZarrDataIO(data=waveforms, chunks=(1, 3, 3), compressor=compressor, filters=filters) + nwbfile.add_unit_column(name="compressed_spike_times", description="", data=spike_times, index=index) + nwbfile.add_unit_column(name="compressed_waveforms", description="", data=waveforms, index=index) + + nwbfile_path = tmp_path / "test_existing_dataset_io_configurations_ragged_units_table.nwb" + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), "w") as io: + io.write(nwbfile) + with IO(str(nwbfile_path), "r") as io: + nwbfile = io.read() + dataset_configurations = list(get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=backend)) + + assert len(dataset_configurations) == 9 + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/spike_times/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (5,) + assert dataset_configuration.dtype == np.dtype("float64") + assert dataset_configuration.buffer_shape == (5,) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.chunk_shape is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.chunk_shape == (5,) + assert dataset_configuration.compression_options is None + assert 
dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/spike_times_index/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (2,) + assert dataset_configuration.dtype == np.dtype("uint8") + assert dataset_configuration.buffer_shape == (2,) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.chunk_shape is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.chunk_shape == (2,) + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/waveforms/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (15, 3) + assert dataset_configuration.dtype == np.dtype("int32") + assert dataset_configuration.buffer_shape == (15, 3) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.chunk_shape is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.chunk_shape == (15, 3) + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/waveforms_index/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (5,) + assert dataset_configuration.dtype == np.dtype("uint8") + assert dataset_configuration.buffer_shape == (5,) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.chunk_shape is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.chunk_shape == (5,) + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/waveforms_index_index/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (2,) + assert dataset_configuration.dtype == np.dtype("uint8") + assert dataset_configuration.buffer_shape == (2,) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.chunk_shape is None + assert 
dataset_configuration.compression_options == dict(compression_opts=None) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.chunk_shape == (2,) + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/compressed_spike_times/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (5,) + assert dataset_configuration.dtype == np.dtype("float64") + assert dataset_configuration.buffer_shape == (5,) + assert dataset_configuration.chunk_shape == (2,) + if backend == "hdf5": + assert dataset_configuration.compression_method == "gzip" + assert dataset_configuration.compression_options == dict(compression_opts=2) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods == filters + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/compressed_spike_times_index/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (2,) + assert dataset_configuration.dtype == np.dtype("uint8") + assert dataset_configuration.buffer_shape == (2,) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + assert dataset_configuration.chunk_shape is None + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + assert dataset_configuration.chunk_shape == (2,) + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/compressed_waveforms/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (5, 3, 3) + assert dataset_configuration.dtype == np.dtype("int32") + assert dataset_configuration.chunk_shape == (1, 3, 3) + assert dataset_configuration.buffer_shape == (5, 3, 3) + if backend == "hdf5": + assert dataset_configuration.compression_method == "gzip" + assert dataset_configuration.compression_options == dict(compression_opts=2) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods == filters + assert dataset_configuration.filter_options is None + + dataset_configuration = next( + dataset_configuration + for dataset_configuration in dataset_configurations + if dataset_configuration.location_in_file == "units/compressed_waveforms_index/data" + ) + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.full_shape == (2,) + assert 
dataset_configuration.dtype == np.dtype("uint8") + assert dataset_configuration.buffer_shape == (2,) + if backend == "hdf5": + assert dataset_configuration.compression_method is None + assert dataset_configuration.compression_options == dict(compression_opts=None) + assert dataset_configuration.chunk_shape is None + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + assert dataset_configuration.chunk_shape == (2,) + + +@pytest.mark.parametrize("backend", ["hdf5", "zarr"]) +def test_configuration_on_compass_direction(tmp_path, backend: Literal["hdf5", "zarr"]): + data = np.array([[1, 2, 3], [4, 5, 6]]) + + nwbfile = mock_NWBFile() + if backend == "zarr": # ZarrDataIO compresses by default, so we disable it to test no-compression + data = ZarrDataIO(data=data, compressor=False) + spatial_series = mock_SpatialSeries(name="TestSpatialSeries", data=data) + compass_direction = CompassDirection(name="TestCompassDirection", spatial_series=spatial_series) + behavior_module = get_module(nwbfile=nwbfile, name="behavior") + behavior_module.add(compass_direction) + data = np.array([[1, 2, 3], [4, 5, 6]]) + if backend == "hdf5": + data = H5DataIO(data=data, compression="gzip", compression_opts=2, chunks=(1, 3)) + elif backend == "zarr": + filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE) + filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE) + filters = [filter1, filter2] + compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) + data = ZarrDataIO(data=data, chunks=(1, 3), compressor=compressor, filters=filters) + compressed_spatial_series = mock_SpatialSeries( + name="CompressedSpatialSeries", + data=data, + ) + compressed_compass_direction = CompassDirection( + name="CompressedCompassDirection", spatial_series=compressed_spatial_series + ) + behavior_module.add(compressed_compass_direction) + nwbfile_path = tmp_path / "test_existing_dataset_io_configurations_compass_direction.nwb" + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), "w") as io: + io.write(nwbfile) + + with IO(str(nwbfile_path), "r") as io: + nwbfile = io.read() + dataset_configurations = list(get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=backend)) + + assert len(dataset_configurations) == 2 + + dataset_configuration = dataset_configurations[0] + assert isinstance(dataset_configuration, DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.object_id == spatial_series.object_id + assert ( + dataset_configuration.location_in_file == "processing/behavior/TestCompassDirection/TestSpatialSeries/data" + ) + assert dataset_configuration.full_shape == data.shape + assert dataset_configuration.dtype == data.dtype + assert dataset_configuration.buffer_shape == data.shape + assert dataset_configuration.compression_method is None + if backend == "hdf5": + assert dataset_configuration.compression_options == dict(compression_opts=None) + assert dataset_configuration.chunk_shape is None + elif backend == "zarr": + assert dataset_configuration.compression_options is None + assert dataset_configuration.chunk_shape == data.shape + assert dataset_configuration.filter_methods is None + assert dataset_configuration.filter_options is None + + dataset_configuration = dataset_configurations[1] + assert isinstance(dataset_configuration, 
DATASET_IO_CONFIGURATIONS[backend]) + assert dataset_configuration.object_id == compressed_spatial_series.object_id + assert ( + dataset_configuration.location_in_file + == "processing/behavior/CompressedCompassDirection/CompressedSpatialSeries/data" + ) + assert dataset_configuration.full_shape == data.shape + assert dataset_configuration.dtype == data.dtype + assert dataset_configuration.chunk_shape == (1, 3) + assert dataset_configuration.buffer_shape == data.shape + if backend == "hdf5": + assert dataset_configuration.compression_method == "gzip" + assert dataset_configuration.compression_options == dict(compression_opts=2) + elif backend == "zarr": + assert dataset_configuration.compression_method == compressor + assert dataset_configuration.compression_options is None + assert dataset_configuration.filter_methods == filters + assert dataset_configuration.filter_options is None + + +@pytest.mark.skipif( + not is_package_installed(package_name="ndx_events"), + reason="The extra testing package 'ndx-events' is not installed!", +) +@pytest.mark.parametrize("backend", ["hdf5", "zarr"]) +def test_configuration_on_ndx_events(tmp_path, backend: Literal["hdf5", "zarr"]): + from ndx_events import LabeledEvents + + # ndx_events data fields do not support wrapping in DataChunkIterators - data is nearly always small enough + # to fit entirely in memory + data = np.array([1, 2, 3], dtype="uint32") + timestamps = np.array([4.5, 6.7, 8.9]) + + nwbfile = mock_NWBFile() + if backend == "zarr": # ZarrDataIO compresses by default, so we disable it to test no-compression + data = ZarrDataIO(data=data, compressor=False) + timestamps = ZarrDataIO(data=timestamps, compressor=False) + labeled_events = LabeledEvents( + name="TestLabeledEvents", + description="", + timestamps=timestamps, + data=data, + labels=["response_left", "cue_onset", "cue_offset"], + ) + behavior_module = get_module(nwbfile=nwbfile, name="behavior") + behavior_module.add(labeled_events) + data = np.array([1, 2, 3], dtype="uint32") + timestamps = np.array([4.5, 6.7, 8.9]) + if backend == "hdf5": + data = H5DataIO(data=data, compression="gzip", compression_opts=2, chunks=(3,)) + timestamps = H5DataIO(data=timestamps, compression="gzip", compression_opts=2, chunks=(3,)) + elif backend == "zarr": + compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) + filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE) + filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE) + filters = [filter1, filter2] + data = ZarrDataIO(data=data, chunks=(3,), compressor=compressor, filters=filters) + timestamps = ZarrDataIO(data=timestamps, chunks=(3,), compressor=compressor, filters=filters) + compressed_labeled_events = LabeledEvents( + name="CompressedLabeledEvents", + description="", + timestamps=timestamps, + data=data, + labels=["response_left", "cue_onset", "cue_offset"], + ) + behavior_module.add(compressed_labeled_events) + nwbfile_path = tmp_path / "test_existing_dataset_io_configurations_ndx_events.nwb" + IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO + with IO(str(nwbfile_path), "w") as io: + io.write(nwbfile) + + with IO(str(nwbfile_path), "r") as io: + nwbfile = io.read() + + dataset_configurations = list(get_existing_dataset_io_configurations(nwbfile=nwbfile, backend=backend)) + + # Note that the labels dataset is not caught since we search only for 'data' and 'timestamps' fields + assert len(dataset_configurations) == 4 + + data_dataset_configuration = next( + dataset_configuration + for 
+        for dataset_configuration in dataset_configurations
+        if dataset_configuration.location_in_file == "processing/behavior/TestLabeledEvents/data"
+    )
+    assert isinstance(data_dataset_configuration, DATASET_IO_CONFIGURATIONS[backend])
+    assert data_dataset_configuration.object_id == labeled_events.object_id
+    assert data_dataset_configuration.full_shape == data.shape
+    assert data_dataset_configuration.dtype == data.dtype
+    assert data_dataset_configuration.buffer_shape == data.shape
+    assert data_dataset_configuration.compression_method is None
+    if backend == "hdf5":
+        assert data_dataset_configuration.compression_options == dict(compression_opts=None)
+        assert data_dataset_configuration.chunk_shape is None
+    elif backend == "zarr":
+        assert data_dataset_configuration.compression_options is None
+        assert data_dataset_configuration.chunk_shape == data.shape
+        assert data_dataset_configuration.filter_methods is None
+        assert data_dataset_configuration.filter_options is None
+
+    timestamps_dataset_configuration = next(
+        dataset_configuration
+        for dataset_configuration in dataset_configurations
+        if dataset_configuration.location_in_file == "processing/behavior/TestLabeledEvents/timestamps"
+    )
+    assert isinstance(timestamps_dataset_configuration, DATASET_IO_CONFIGURATIONS[backend])
+    assert timestamps_dataset_configuration.object_id == labeled_events.object_id
+    assert timestamps_dataset_configuration.full_shape == timestamps.shape
+    assert timestamps_dataset_configuration.dtype == timestamps.dtype
+    assert timestamps_dataset_configuration.buffer_shape == timestamps.shape
+    assert timestamps_dataset_configuration.compression_method is None
+    if backend == "hdf5":
+        assert timestamps_dataset_configuration.compression_options == dict(compression_opts=None)
+        assert timestamps_dataset_configuration.chunk_shape is None
+    elif backend == "zarr":
+        assert timestamps_dataset_configuration.compression_options is None
+        assert timestamps_dataset_configuration.chunk_shape == timestamps.shape
+        assert timestamps_dataset_configuration.filter_methods is None
+        assert timestamps_dataset_configuration.filter_options is None
+
+    data_dataset_configuration = next(
+        dataset_configuration
+        for dataset_configuration in dataset_configurations
+        if dataset_configuration.location_in_file == "processing/behavior/CompressedLabeledEvents/data"
+    )
+    assert isinstance(data_dataset_configuration, DATASET_IO_CONFIGURATIONS[backend])
+    assert data_dataset_configuration.object_id == compressed_labeled_events.object_id
+    assert data_dataset_configuration.full_shape == data.shape
+    assert data_dataset_configuration.dtype == data.dtype
+    assert data_dataset_configuration.chunk_shape == (3,)
+    assert data_dataset_configuration.buffer_shape == data.shape
+    if backend == "hdf5":
+        assert data_dataset_configuration.compression_method == "gzip"
+        assert data_dataset_configuration.compression_options == dict(compression_opts=2)
+    elif backend == "zarr":
+        assert data_dataset_configuration.compression_method == compressor
+        assert data_dataset_configuration.compression_options is None
+        assert data_dataset_configuration.filter_methods == filters
+        assert data_dataset_configuration.filter_options is None
+
+    timestamps_dataset_configuration = next(
+        dataset_configuration
+        for dataset_configuration in dataset_configurations
+        if dataset_configuration.location_in_file == "processing/behavior/CompressedLabeledEvents/timestamps"
+    )
+    assert isinstance(timestamps_dataset_configuration, DATASET_IO_CONFIGURATIONS[backend])
+    assert timestamps_dataset_configuration.object_id == compressed_labeled_events.object_id
+    assert timestamps_dataset_configuration.full_shape == timestamps.shape
+    assert timestamps_dataset_configuration.dtype == timestamps.dtype
+    assert timestamps_dataset_configuration.chunk_shape == (3,)
+    assert timestamps_dataset_configuration.buffer_shape == timestamps.shape
+    if backend == "hdf5":
+        assert timestamps_dataset_configuration.compression_method == "gzip"
+        assert timestamps_dataset_configuration.compression_options == dict(compression_opts=2)
+    elif backend == "zarr":
+        assert timestamps_dataset_configuration.compression_method == compressor
+        assert timestamps_dataset_configuration.compression_options is None
+        assert timestamps_dataset_configuration.filter_methods == filters
+        assert timestamps_dataset_configuration.filter_options is None
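The tests above check that get_existing_dataset_io_configurations reports each dataset's on-disk chunking, compression, and filter settings after a round trip through NWBHDF5IO or NWBZarrIO. A minimal usage sketch of that call pattern follows (illustrative only, not part of this diff; the file path is hypothetical):

    from pynwb import NWBHDF5IO

    from neuroconv.tools.nwb_helpers import get_existing_dataset_io_configurations

    with NWBHDF5IO("existing_file.nwb", mode="r") as io:  # hypothetical path
        nwbfile = io.read()
        # Each returned configuration records location_in_file, chunk_shape, compression_method, etc.
        for dataset_configuration in get_existing_dataset_io_configurations(nwbfile=nwbfile, backend="hdf5"):
            print(dataset_configuration.location_in_file, dataset_configuration.compression_method)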
diff --git a/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_repack_nwbfile.py b/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_repack_nwbfile.py
new file mode 100644
index 000000000..b0bf138b3
--- /dev/null
+++ b/tests/test_minimal/test_tools/test_backend_and_dataset_configuration/test_helpers/test_repack_nwbfile.py
@@ -0,0 +1,225 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+from hdmf_zarr import NWBZarrIO, ZarrDataIO
+from numcodecs import Blosc, GZip
+from pynwb import NWBHDF5IO, H5DataIO, NWBFile
+from pynwb.testing.mock.base import mock_TimeSeries
+from pynwb.testing.mock.file import mock_NWBFile
+
+from neuroconv.tools.nwb_helpers import (
+    get_module,
+    repack_nwbfile,
+)
+
+
+def generate_complex_nwbfile() -> NWBFile:
+    nwbfile = mock_NWBFile()
+
+    raw_array = np.array([[1, 2, 3], [4, 5, 6]])
+    raw_time_series = mock_TimeSeries(name="RawTimeSeries", data=raw_array)
+    nwbfile.add_acquisition(raw_time_series)
+
+    number_of_trials = 10
+    for start_time, stop_time in zip(
+        np.linspace(start=0.0, stop=10.0, num=number_of_trials), np.linspace(start=1.0, stop=11.0, num=number_of_trials)
+    ):
+        nwbfile.add_trial(start_time=start_time, stop_time=stop_time)
+
+    ecephys_module = get_module(nwbfile=nwbfile, name="ecephys")
+    processed_array = np.array([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0], [13.0, 14.0]])
+    processed_time_series = mock_TimeSeries(name="ProcessedTimeSeries", data=processed_array)
+    ecephys_module.add(processed_time_series)
+
+    return nwbfile
+
+
+@pytest.fixture(scope="session")
+def hdf5_nwbfile_path(tmpdir_factory):
+    nwbfile_path = tmpdir_factory.mktemp("data").join("test_repack_nwbfile.nwb.h5")
+    if not Path(nwbfile_path).exists():
+        nwbfile = generate_complex_nwbfile()
+
+        # Add a H5DataIO-compressed time series
+        raw_array = np.array([[11, 21, 31], [41, 51, 61]], dtype="int32")
+        data = H5DataIO(data=raw_array, compression="gzip", compression_opts=2)
+        raw_time_series = mock_TimeSeries(name="CompressedRawTimeSeries", data=data)
+        nwbfile.add_acquisition(raw_time_series)
+
+        # Add H5DataIO-compressed trials column
+        number_of_trials = 10
+        start_time = np.linspace(start=0.0, stop=10.0, num=number_of_trials)
+        nwbfile.add_trial_column(
+            name="compressed_start_time",
+            description="start time of epoch",
+            data=H5DataIO(data=start_time, compression="gzip", compression_opts=2),
+        )
+
+        with NWBHDF5IO(path=str(nwbfile_path), mode="w") as io:
+            io.write(nwbfile)
+    return str(nwbfile_path)
+
+
+@pytest.fixture(scope="session")
+def zarr_nwbfile_path(tmpdir_factory):
+    compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0)
+    filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)
+    filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE)
+    filters = [filter1, filter2]
+
+    nwbfile_path = tmpdir_factory.mktemp("data").join("test_repack_nwbfile.nwb.zarr")
+    if not Path(nwbfile_path).exists():
+        nwbfile = generate_complex_nwbfile()
+
+        # Add a ZarrDataIO-compressed time series
+        raw_array = np.array([[11, 21, 31], [41, 51, 61]], dtype="int32")
+        data = ZarrDataIO(data=raw_array, chunks=(1, 3), compressor=compressor, filters=filters)
+        raw_time_series = mock_TimeSeries(name="CompressedRawTimeSeries", data=data)
+        nwbfile.add_acquisition(raw_time_series)
+
+        # Add ZarrDataIO-compressed trials column
+        number_of_trials = 10
+        start_time = np.linspace(start=0.0, stop=10.0, num=number_of_trials)
+        data = ZarrDataIO(data=start_time, chunks=(5,), compressor=compressor, filters=filters)
+        nwbfile.add_trial_column(
+            name="compressed_start_time",
+            description="start time of epoch",
+            data=data,
+        )
+
+        with NWBZarrIO(path=str(nwbfile_path), mode="w") as io:
+            io.write(nwbfile)
+    return str(nwbfile_path)
+
+
+@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
+@pytest.mark.parametrize("use_default_backend_configuration", [True, False])
+def test_repack_nwbfile(hdf5_nwbfile_path, zarr_nwbfile_path, backend, use_default_backend_configuration):
+    compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0)
+    filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)
+    filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE)
+    filters = [filter1, filter2]
+    default_compressor = GZip(level=1)
+
+    if backend == "hdf5":
+        nwbfile_path = hdf5_nwbfile_path
+        export_path = Path(hdf5_nwbfile_path).parent / "repacked_test_repack_nwbfile.nwb.h5"
+    elif backend == "zarr":
+        nwbfile_path = zarr_nwbfile_path
+        export_path = Path(zarr_nwbfile_path).parent / "repacked_test_repack_nwbfile.nwb.zarr"
+    repack_nwbfile(
+        nwbfile_path=str(nwbfile_path),
+        export_nwbfile_path=str(export_path),
+        backend=backend,
+        use_default_backend_configuration=use_default_backend_configuration,
+    )
+    IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO
+    with IO(str(export_path), mode="r") as io:
+        nwbfile = io.read()
+
+        if backend == "hdf5":
+            if use_default_backend_configuration:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compression_opts == 4
+                assert nwbfile.intervals["trials"].start_time.data.compression_opts == 4
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compression_opts == 4
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compression_opts == 4
+                assert nwbfile.intervals["trials"].compressed_start_time.data.compression_opts == 4
+            else:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compression_opts is None
+                assert nwbfile.intervals["trials"].start_time.data.compression_opts is None
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compression_opts is None
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compression_opts == 2
+                assert nwbfile.intervals["trials"].compressed_start_time.data.compression_opts == 2
+        elif backend == "zarr":
+            if use_default_backend_configuration:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compressor == default_compressor
+                assert nwbfile.acquisition["RawTimeSeries"].data.filters is None
+                assert nwbfile.intervals["trials"].start_time.data.compressor == default_compressor
+                assert nwbfile.intervals["trials"].start_time.data.filters is None
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compressor == default_compressor
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.filters is None
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compressor == default_compressor
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.filters is None
+            else:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compressor == compressor
+                assert nwbfile.acquisition["RawTimeSeries"].data.filters is None
+                assert nwbfile.intervals["trials"].start_time.data.compressor == compressor
+                assert nwbfile.intervals["trials"].start_time.data.filters is None
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compressor == compressor
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.filters is None
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compressor == compressor
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.filters == filters
+
+
+@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
+@pytest.mark.parametrize("use_default_backend_configuration", [True, False])
+def test_repack_nwbfile_with_changes(hdf5_nwbfile_path, zarr_nwbfile_path, backend, use_default_backend_configuration):
+    compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0)
+    filter1 = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)
+    filter2 = Blosc(cname="zstd", clevel=2, shuffle=Blosc.SHUFFLE)
+    filters = [filter1, filter2]
+    default_compressor = GZip(level=1)
+
+    if backend == "hdf5":
+        nwbfile_path = hdf5_nwbfile_path
+        export_path = Path(hdf5_nwbfile_path).parent / "repacked_test_repack_nwbfile.nwb.h5"
+        backend_configuration_changes = {
+            "acquisition/RawTimeSeries/data": dict(
+                compression_method="gzip", compression_options=dict(compression_opts=1)
+            )
+        }
+    elif backend == "zarr":
+        nwbfile_path = zarr_nwbfile_path
+        export_path = Path(zarr_nwbfile_path).parent / "repacked_test_repack_nwbfile.nwb.zarr"
+        changed_compressor = Blosc(cname="lz4", clevel=3, shuffle=Blosc.SHUFFLE, blocksize=0)
+        changed_filters = [Blosc(cname="zstd", clevel=3, shuffle=Blosc.SHUFFLE)]
+        backend_configuration_changes = {
+            "acquisition/RawTimeSeries/data": dict(
+                compression_method=changed_compressor, filter_methods=changed_filters
+            )
+        }
+    repack_nwbfile(
+        nwbfile_path=str(nwbfile_path),
+        export_nwbfile_path=str(export_path),
+        backend=backend,
+        use_default_backend_configuration=use_default_backend_configuration,
+        backend_configuration_changes=backend_configuration_changes,
+    )
+
+    IO = NWBHDF5IO if backend == "hdf5" else NWBZarrIO
+    with IO(str(export_path), mode="r") as io:
+        nwbfile = io.read()
+        if backend == "hdf5":
+            if use_default_backend_configuration:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compression_opts == 1
+                assert nwbfile.intervals["trials"].start_time.data.compression_opts == 4
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compression_opts == 4
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compression_opts == 4
+                assert nwbfile.intervals["trials"].compressed_start_time.data.compression_opts == 4
+            else:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compression_opts == 1
+                assert nwbfile.intervals["trials"].start_time.data.compression_opts is None
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compression_opts is None
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compression_opts == 2
+                assert nwbfile.intervals["trials"].compressed_start_time.data.compression_opts == 2
+        elif backend == "zarr":
+            if use_default_backend_configuration:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compressor == changed_compressor
+                assert nwbfile.acquisition["RawTimeSeries"].data.filters == changed_filters
+                assert nwbfile.intervals["trials"].start_time.data.compressor == default_compressor
+                assert nwbfile.intervals["trials"].start_time.data.filters is None
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compressor == default_compressor
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.filters is None
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compressor == default_compressor
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.filters is None
+            else:
+                assert nwbfile.acquisition["RawTimeSeries"].data.compressor == changed_compressor
+                assert nwbfile.acquisition["RawTimeSeries"].data.filters == changed_filters
+                assert nwbfile.intervals["trials"].start_time.data.compressor == compressor
+                assert nwbfile.intervals["trials"].start_time.data.filters is None
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.compressor == compressor
+                assert nwbfile.processing["ecephys"]["ProcessedTimeSeries"].data.filters is None
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.compressor == compressor
+                assert nwbfile.acquisition["CompressedRawTimeSeries"].data.filters == filters
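Based on the call signature exercised in test_repack_nwbfile_with_changes above, repacking an existing file while overriding a single dataset's settings presumably looks like the following sketch (illustrative only, not part of this diff; the paths and the gzip level are hypothetical):

    from neuroconv.tools.nwb_helpers import repack_nwbfile

    repack_nwbfile(
        nwbfile_path="existing_file.nwb",  # hypothetical source file
        export_nwbfile_path="repacked_file.nwb",  # hypothetical destination
        backend="hdf5",
        use_default_backend_configuration=False,  # start from the file's existing per-dataset settings
        backend_configuration_changes={
            "acquisition/RawTimeSeries/data": dict(
                compression_method="gzip", compression_options=dict(compression_opts=1)
            )
        },
    )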