diff --git a/csvy/readers.py b/csvy/readers.py
index 9edc1fd..1b52f83 100644
--- a/csvy/readers.py
+++ b/csvy/readers.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+from itertools import zip_longest
 from pathlib import Path
 from typing import Any, Literal
 
@@ -299,3 +300,56 @@ def read_to_list(
             data.append(row)
 
     return data, header
+
+
+def read_to_dict(
+    filename: Path | str,
+    marker: str = "---",
+    encoding: str = "utf-8",
+    csv_options: dict[str, Any] | None = None,
+    yaml_options: dict[str, Any] | None = None,
+    *,
+    column_names: list[Any] | int | None = None,
+    fillvalue: Any = None,
+) -> tuple[dict[str, list[Any]], dict[str, Any]]:
+    """Read a CSVY file into a dictionary with the header and the data as dictionaries.
+
+    Internally, it calls `read_to_list` and then transforms the data into a dictionary.
+
+    Args:
+        filename: Name of the file to read.
+        marker: The marker characters that indicate the yaml header.
+        encoding: The character encoding in the file to read.
+        csv_options: Options to pass to csv.reader.
+        yaml_options: Options to pass to yaml.safe_load.
+        column_names: Either a list with the column names, the row number containing the
+            column names or None. If None (the default) an automatic column name
+            ('col_0', 'col_1', ...) will be used.
+        fillvalue: Value to use for missing data in the columns.
+
+    Returns:
+        Tuple containing: The data and the header both as dictionaries.
+
+    Raises:
+        ValueError: If the number of column names differs from the length of the
+            longest row.
+
+    """
+    data, header = read_to_list(filename, marker, encoding, csv_options, yaml_options)
+
+    # `default=0` avoids the ValueError that `max` raises when there are no rows.
+    longest_row = max(map(len, data), default=0)
+    if column_names is None:
+        column_names = [f"col_{i}" for i in range(longest_row)]
+    else:
+        if isinstance(column_names, int):
+            column_names = data.pop(column_names)
+
+        if len(column_names) != longest_row:
+            raise ValueError(
+                "The number of column names must be exactly the length of the longest "
+                f"row ({len(column_names)} != {longest_row})."
+            )
+
+    columns = list(map(list, zip_longest(*data, fillvalue=fillvalue)))
+    return dict(zip(column_names, columns)), header
diff --git a/csvy/writers.py b/csvy/writers.py
index 9e9f12f..77ff1c7 100644
--- a/csvy/writers.py
+++ b/csvy/writers.py
@@ -4,8 +4,9 @@
 
 import csv
 import logging
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Iterable, Mapping
 from io import TextIOBase
+from itertools import zip_longest
 from pathlib import Path
 from typing import Any
 
@@ -310,6 +311,43 @@ def write_polars(
     return False
 
 
+@register_writer
+def write_dict(
+    filename: Path | str,
+    data: Any,
+    comment: str = "",
+    encoding: str = "utf-8",
+    *,
+    fillvalue: Any = None,
+    **kwargs: Any,
+) -> bool:
+    """Write the dictionary to the chosen file, adding it after the header.
+
+    It transforms the dictionary into a tabular format before saving it using the
+    generic `write_csv` function.
+
+    Args:
+        filename: Name of the file to save the data into. The data will be added to the
+            end of the file.
+        data: The data as a dictionary.
+        comment: String to use to mark the header lines as comments.
+        encoding: The character encoding to use in the file to write.
+        fillvalue: Value to use to fill the missing values in the dictionary.
+        **kwargs: Arguments to be passed to the underlying saving method.
+
+    Returns:
+        True if the writer worked, False otherwise.
+
+    """
+    if not isinstance(data, Mapping):
+        return False
+
+    data_ = [list(data.keys())]
+    data_.extend(map(list, zip_longest(*data.values(), fillvalue=fillvalue)))
+
+    return write_csv(filename, data_, comment, encoding, **kwargs)
+
+
 def write_csv(
     filename: Path | str,
     data: Any,
diff --git a/tests/test_read.py b/tests/test_read.py
index d748b17..406e75d 100644
--- a/tests/test_read.py
+++ b/tests/test_read.py
@@ -132,3 +132,59 @@ def test_read_to_list(array_data_path):
     assert len(data[0]) == 4
     assert isinstance(header, dict)
     assert len(header) > 0
+
+
+def test_read_to_dict_with_default_column_names(array_data_path):
+    """Test the read_to_dict function with automatic column names."""
+    from csvy.readers import read_to_dict
+
+    data, header = read_to_dict(array_data_path, csv_options={"delimiter": ","})
+
+    assert isinstance(data, dict)
+    assert len(data) == 4
+    assert list(data.keys()) == ["col_0", "col_1", "col_2", "col_3"]
+    assert len(data["col_0"]) == 15
+    assert len(header) > 0
+
+
+def test_read_to_dict_with_custom_column_names(array_data_path):
+    """Test the read_to_dict function with an explicit list of column names."""
+    from csvy.readers import read_to_dict
+
+    column_names = ["A", "B", "C", "D"]
+    data, header = read_to_dict(
+        array_data_path, column_names=column_names, csv_options={"delimiter": ","}
+    )
+
+    assert isinstance(data, dict)
+    assert len(data) == 4
+    assert list(data.keys()) == column_names
+    assert len(data["A"]) == 15
+    assert len(header) > 0
+
+
+def test_read_to_dict_with_row_based_column_names(data_path):
+    """Test the read_to_dict function with column names taken from a data row."""
+    from csvy.readers import read_to_dict
+
+    data, header = read_to_dict(
+        data_path, column_names=0, csv_options={"delimiter": ","}
+    )
+
+    assert isinstance(data, dict)
+    assert len(data) == 2
+    assert list(data.keys()) == ["Date", "WTI"]
+    assert len(data["Date"]) == 15
+    assert len(header) > 0
+
+
+def test_read_to_dict_with_wrong_number_of_column_names(array_data_path):
+    """Test that read_to_dict rejects a wrong number of column names."""
+    import pytest
+
+    from csvy.readers import read_to_dict
+
+    with pytest.raises(ValueError, match="number of column names"):
+        read_to_dict(
+            array_data_path, column_names=["A", "B"], csv_options={"delimiter": ","}
+        )
diff --git a/tests/test_write.py b/tests/test_write.py
index 34a843d..3878ac4 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -126,6 +126,25 @@ class Writer:
     assert Writer.writerow.call_count == len(data)
 
 
+@patch("csv.writer")
+def test_write_dict(mock_save, tmpdir):
+    """Test the write_dict function."""
+    from csvy.writers import write_dict
+
+    class Writer:
+        writerow = MagicMock()
+
+    mock_save.return_value = Writer
+    filename = tmpdir / "some_file.csv"
+
+    data = {"a": [1, 2, 3, 4], "b": [1, 2, 3], "c": [1, 2, 3, 4, 5]}
+    expected_rows = max(map(len, data.values())) + 1  # +1 for the column names
+    assert write_dict(filename, data)
+
+    mock_save.assert_called_once()
+    assert Writer.writerow.call_count == expected_rows
+
+
 @patch("csv.writer")
 @patch("csvy.writers.write_header")
 @pytest.mark.parametrize(