From de3e6a356008b95b3630f329049ff775691b495e Mon Sep 17 00:00:00 2001 From: rtosholdings-bot Date: Tue, 23 Apr 2024 14:29:25 -0400 Subject: [PATCH] v1.17.0-rc0 --- README.md | 2 +- conda_recipe/conda_build_config.yaml | 2 +- conda_recipe/meta.yaml | 2 +- dev_tools/_docstring_config.py | 71 ++- dev_tools/docstring_xfails.txt | 8 - dev_tools/validate_docstrings.py | 67 +-- pyproject.toml | 3 +- riptable/rt_categorical.py | 55 ++- riptable/rt_datetime.py | 418 ++++++++++++------ riptable/rt_groupbyops.py | 25 +- riptable/rt_numpy.py | 41 +- riptable/rt_struct.py | 392 ++++++++++------ riptable/rt_utils.py | 52 ++- riptable/tests/test_base_function.py | 16 + .../tests/test_categorical_filter_invalid.py | 17 +- riptable/tests/test_rtutils.py | 18 +- 16 files changed, 808 insertions(+), 381 deletions(-) diff --git a/README.md b/README.md index 5da8526..d385a6c 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ![](https://riptable.readthedocs.io/en/stable/_static/riptable_logo.PNG) An open-source, 64-bit Python analytics engine for high-performance data analysis with -multithreading support. Riptable supports Python 3.10 through 3.11 on 64-bit Linux and +multithreading support. Riptable supports Python 3.10 through 3.12 on 64-bit Linux and Windows. Similar to Pandas and based on NumPy, Riptable optimizes analyzing large volumes of data diff --git a/conda_recipe/conda_build_config.yaml b/conda_recipe/conda_build_config.yaml index 83501de..15919a7 100644 --- a/conda_recipe/conda_build_config.yaml +++ b/conda_recipe/conda_build_config.yaml @@ -1,2 +1,2 @@ python: - - 3.11 + - 3.12 diff --git a/conda_recipe/meta.yaml b/conda_recipe/meta.yaml index 8422121..f993e3d 100644 --- a/conda_recipe/meta.yaml +++ b/conda_recipe/meta.yaml @@ -21,7 +21,7 @@ requirements: - pandas >=1.0,<3.0 - python - python-dateutil - - riptide_cpp >=1.17.0,<2 # run with any (compatible) version in this range + - riptide_cpp >=1.19.0,<2 # run with any (compatible) version in this range - typing-extensions >=4.9.0 about: diff --git a/dev_tools/_docstring_config.py b/dev_tools/_docstring_config.py index 762183c..3abadab 100644 --- a/dev_tools/_docstring_config.py +++ b/dev_tools/_docstring_config.py @@ -1,12 +1,65 @@ import riptable +import contextlib -# Standardized riptable configuration settings applied when doing docstring validation. -# Standardize on these display settings when executing examples -riptable.Display.FORCE_REPR = True # Don't auto-detect console dimensions, just use CONSOLE_X/Y -riptable.Display.options.COL_MAX = 1_000_000 # display all Dataset columns (COL_ALL is incomplete) -riptable.Display.options.E_MAX = 100_000_000 # render up to 100MM before using scientific notation -riptable.Display.options.P_THRESHOLD = 0 # truncate small decimals, rather than scientific notation -riptable.Display.options.NUMBER_SEPARATOR = True # put commas in numbers -riptable.Display.options.HEAD_ROWS = 3 -riptable.Display.options.TAIL_ROWS = 3 +def _setup_display_config(): + """Initialize display config settings. + Any options that can be modified should be set here, even if set to default values. 
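+
+    Called once at import time (via setup_init_config) and again by
+    ScopedExampleSetup cleanup callbacks to restore these defaults.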
+    """
+    riptable.Display.FORCE_REPR = True  # Don't auto-detect console dimensions, just use CONSOLE_X/Y
+    riptable.Display.options.CONSOLE_X = 150
+    riptable.Display.options.COL_MAX = 1_000_000  # display all Dataset columns (COL_ALL is incomplete)
+    riptable.Display.options.E_MAX = 100_000_000  # render up to 100MM before using scientific notation
+    riptable.Display.options.P_THRESHOLD = 0  # truncate small decimals, rather than scientific notation
+    riptable.Display.options.NUMBER_SEPARATOR = True  # put commas in numbers
+    riptable.Display.options.HEAD_ROWS = 3
+    riptable.Display.options.TAIL_ROWS = 3
+    riptable.Display.options.ROW_ALL = False
+    riptable.Display.options.COL_ALL = False
+    riptable.Display.options.MAX_STRING_WIDTH = 15
+
+
+def setup_init_config():
+    """Initialize all config settings. Typically only done once."""
+    _setup_display_config()
+
+
+class ScopedExampleSetup(contextlib.AbstractContextManager):
+    """Context manager to clean up after any changes made during example setup."""
+
+    _CLEANUP_CALLBACKS = []
+
+    @staticmethod
+    def add_cleanup_callback(fn):
+        ScopedExampleSetup._CLEANUP_CALLBACKS.append(fn)
+
+    def __enter__(self) -> None:
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
+        callbacks = ScopedExampleSetup._CLEANUP_CALLBACKS
+        ScopedExampleSetup._CLEANUP_CALLBACKS = []
+        for callback in callbacks:
+            callback()
+        return super().__exit__(exc_type, exc_value, traceback)
+
+
+def setup_for_examples(*configs: str):
+    """Apply the specified config setups for an example.
+
+    Configs are applied in order. Any modifications made here must be undone by
+    registering a cleanup callback with ScopedExampleSetup.
+    """
+    for config in configs:
+        if config == "struct-display":
+            riptable.Display.options.CONSOLE_X = 120
+            riptable.Display.options.HEAD_ROWS = 15
+            riptable.Display.options.TAIL_ROWS = 15
+            ScopedExampleSetup.add_cleanup_callback(_setup_display_config)  # reset all display configs.
+        else:
+            raise NotImplementedError(f"Unknown config: {config}")
+
+
+# Initialize all config globally.
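+# This runs at import time so that consumers (e.g. validate_docstrings) get the
+# standard settings without an explicit call; per-example tweaks go through
+# setup_for_examples and are rolled back by ScopedExampleSetup.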
+setup_init_config() diff --git a/dev_tools/docstring_xfails.txt b/dev_tools/docstring_xfails.txt index a17c2ca..e3652cc 100644 --- a/dev_tools/docstring_xfails.txt +++ b/dev_tools/docstring_xfails.txt @@ -273,18 +273,14 @@ riptable.rt_datetime.DateTimeCommon.tz_offset riptable.rt_datetime.DateTimeCommon.year riptable.rt_datetime.DateTimeCommon.yyyymmdd riptable.rt_datetime.DateTimeNano.cut_time -riptable.rt_datetime.DateTimeNano.diff riptable.rt_datetime.DateTimeNano.display_convert_func riptable.rt_datetime.DateTimeNano.fill_invalid riptable.rt_datetime.DateTimeNano.get_scalar riptable.rt_datetime.DateTimeNano.hstack riptable.rt_datetime.DateTimeNano.info -riptable.rt_datetime.DateTimeNano.isfinite -riptable.rt_datetime.DateTimeNano.isnotfinite riptable.rt_datetime.DateTimeNano.newclassfrominstance riptable.rt_datetime.DateTimeNano.random riptable.rt_datetime.DateTimeNano.random_invalid -riptable.rt_datetime.DateTimeNano.resample riptable.rt_datetime.DateTimeNano.shift riptable.rt_datetime.DateTimeNano.to_arrow riptable.rt_datetime.DateTimeNanoScalar @@ -771,10 +767,6 @@ riptable.rt_str.FAString.substr_char_stop riptable.rt_str.FAString.upper riptable.rt_str.FAString.upper_inplace riptable.rt_struct.Struct -riptable.rt_struct.Struct._A -riptable.rt_struct.Struct._G -riptable.rt_struct.Struct._H -riptable.rt_struct.Struct._V riptable.rt_struct.Struct.all riptable.rt_struct.Struct.any riptable.rt_struct.Struct.apply_schema diff --git a/dev_tools/validate_docstrings.py b/dev_tools/validate_docstrings.py index 85b7db7..8631d59 100644 --- a/dev_tools/validate_docstrings.py +++ b/dev_tools/validate_docstrings.py @@ -58,8 +58,8 @@ # With template backend, matplotlib plots nothing matplotlib.use("template") -# Apply riptable docstring configuration for examples. -from _docstring_config import * +# Riptable docstring configuration setup for examples. 
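+# (The import itself applies the standard display configuration.)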
+import _docstring_config ERROR_MSGS = { @@ -101,6 +101,10 @@ "rt": riptable, } +ATTRIBS_CONTEXT = { + "setup_for_examples": _docstring_config.setup_for_examples, +} + def riptable_error(code, **kwargs): """ @@ -210,10 +214,11 @@ def examples_errors(self): error_msgs = "" current_dir = set(os.listdir()) tempdir = pathlib.Path("tempdir") # special reserved directory for temporary files; will be deleted per test - for test in finder.find(self.raw_doc, self.name, globs=IMPORT_CONTEXT): + for test in finder.find(self.raw_doc, self.name, globs=dict(**IMPORT_CONTEXT, **ATTRIBS_CONTEXT)): tempdir.mkdir() f = io.StringIO() - failed_examples, total_examples = runner.run(test, out=f.write) + with _docstring_config.ScopedExampleSetup(): + failed_examples, total_examples = runner.run(test, out=f.write) if failed_examples: error_msgs += f.getvalue() shutil.rmtree(tempdir) @@ -319,11 +324,11 @@ def non_hyphenated_array_like(self): def riptable_validate( func_name: str, - errors: typing.Optional(list[str]) = None, - not_errors: typing.Optional(list[str]) = None, - flake8_errors: typing.Optional(list[str]) = None, - flake8_not_errors: typing.Optional(list[str]) = None, - xfails: typing.Optional(list[str]) = None, + errors: typing.Optional[list[str]] = None, + not_errors: typing.Optional[list[str]] = None, + flake8_errors: typing.Optional[list[str]] = None, + flake8_not_errors: typing.Optional[list[str]] = None, + xfails: typing.Optional[list[str]] = None, verbose: bool = False, ): """ @@ -513,8 +518,8 @@ def is_default_excluded(fullname: str) -> bool: def is_included( fullname: str, - includes: typing.Optional(list[str]) = None, - excludes: typing.Optional(list[str]) = None, + includes: typing.Optional[list[str]] = None, + excludes: typing.Optional[list[str]] = None, ) -> bool: """Indicates whether the name should be included in validation.""" @@ -537,14 +542,14 @@ def validate_all( match: str, not_match: str = None, names_from: str = NAMES_FROM_OPTS[0], - errors: typing.Optional(list[str]) = None, - not_errors: typing.Optional(list[str]) = None, - flake8_errors: typing.Optional(list[str]) = None, - flake8_not_errors: typing.Optional(list[str]) = None, + errors: typing.Optional[list[str]] = None, + not_errors: typing.Optional[list[str]] = None, + flake8_errors: typing.Optional[list[str]] = None, + flake8_not_errors: typing.Optional[list[str]] = None, ignore_deprecated: bool = False, - includes: typing.Optional(list[str]) = None, - excludes: typing.Optional(list[str]) = None, - xfails: typing.Optional(list[str]) = None, + includes: typing.Optional[list[str]] = None, + excludes: typing.Optional[list[str]] = None, + xfails: typing.Optional[list[str]] = None, verbose: int = 0, ) -> dict: """ @@ -650,17 +655,17 @@ def print_validate_all_results( match: str, not_match: str = None, names_from: str = NAMES_FROM_OPTS[0], - errors: typing.Optional(list[str]) = None, - not_errors: typing.Optional(list[str]) = None, - flake8_errors: typing.Optional(list[str]) = None, - flake8_not_errors: typing.Optional(list[str]) = None, + errors: typing.Optional[list[str]] = None, + not_errors: typing.Optional[list[str]] = None, + flake8_errors: typing.Optional[list[str]] = None, + flake8_not_errors: typing.Optional[list[str]] = None, out_format: str = OUT_FORMAT_OPTS[0], ignore_deprecated: bool = False, - includes: typing.Optional(list[str]) = None, - excludes: typing.Optional(list[str]) = None, - xfails: typing.Optional(list[str]) = None, + includes: typing.Optional[list[str]] = None, + excludes: typing.Optional[list[str]] = 
None, + xfails: typing.Optional[list[str]] = None, outfile: typing.IO = sys.stdout, - outfailsfile: typing.Optional(typing.IO) = None, + outfailsfile: typing.Optional[typing.IO] = None, verbose: int = 0, ): if out_format not in OUT_FORMAT_OPTS: @@ -713,11 +718,11 @@ def print_validate_all_results( def print_validate_one_results( func_name: str, - errors: typing.Optional(list[str]) = None, - not_errors: typing.Optional(list[str]) = None, - flake8_errors: typing.Optional(list[str]) = None, - flake8_not_errors: typing.Optional(list[str]) = None, - xfails: typing.Optional(list[str]) = None, + errors: typing.Optional[list[str]] = None, + not_errors: typing.Optional[list[str]] = None, + flake8_errors: typing.Optional[list[str]] = None, + flake8_not_errors: typing.Optional[list[str]] = None, + xfails: typing.Optional[list[str]] = None, outfile: typing.IO = sys.stdout, verbose: int = 0, ): diff --git a/pyproject.toml b/pyproject.toml index ce467fe..97d1667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "numpy >=1.23", "pandas >=1.0,<3.0", "python-dateutil", - "riptide_cpp >=1.17.0,<2", + "riptide_cpp >=1.19.0,<2", "typing-extensions >=4.9.0", ] classifiers = [ @@ -20,6 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", ] diff --git a/riptable/rt_categorical.py b/riptable/rt_categorical.py index 9eb0642..1a4f1d9 100644 --- a/riptable/rt_categorical.py +++ b/riptable/rt_categorical.py @@ -2462,9 +2462,16 @@ def fill_backward(self, *args, limit: int = 0, fill_val=None, inplace: bool = Fa # ------------------------------------------------------------ def isfiltered(self) -> FastArray: """ - True where bin == 0. - Only applies to categoricals with base index 1, otherwise returns all False. - Different than invalid category. + Returns a boolean array of whether each category value is filtered. + + For base-0 categoricals, return all False. + For base-1 categoricals, returns True where bin == 0. + For dict-based categoricals, returns True for values that don't exist in the provided mapping. + + Returns + ------- + out: FastArray + FastArray of bools. See Also -------- @@ -2473,8 +2480,11 @@ def isfiltered(self) -> FastArray: """ if self.base_index == 1: return self._fa == 0 - else: - return zeros(len(self), dtype=bool) + + if self.base_index is None: + return self._fa.isin(self._grouping._enum.code_array, invert=True) + + return zeros(len(self), dtype=bool) # ------------------------------------------------------------ def set_name(self, name) -> Categorical: @@ -4876,29 +4886,36 @@ def ilastkey(self): @property def unique_count(self): """ - Number of unique values in the categorical. - It is necessary for every groupby operation. + Number of unique values in the :py:class:`~.rt_categorical.Categorical`. - Notes - ----- - For categoricals in dict / enum mode that have generated their grouping object, this - will reflect the number of unique values that `occur` in the non-unique values. Empty - bins will not be included in the count. + This property is used for every groupby operation. + + For :py:class:`~.rt_categorical.Categorical` objects constructed from dictionaries or + :py:class:`~enum.IntEnum` objects, the returned count includes unique invalid values from the + underlying array. Otherwise, invalid values are not counted. 
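+
+        For example (a sketch; in this mapping, codes ``0`` and ``3`` do not
+        appear in the dictionary, so they count as invalid):
+
+        >>> c = rt.Categorical([0, 1, 0, 2, 3], {1: "a", 2: "b"})
+        >>> c.unique_count
+        4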
+ + See Also + -------- + :py:meth:`.rt_categorical.Categorical.nunique` : Number of unique values in the :py:class:`~.rt_categorical.Categorical`. + :py:meth:`.rt_groupbyops.GroupByOps.count_uniques` : Count the unique values for each group. """ return self.grouping.unique_count # ------------------------------------------------------------ def nunique(self): """ - Number of unique values that occur in the Categorical. - Does not include invalids. Not the same as the length of possible uniques. + Number of unique values in the :py:class:`~.rt_categorical.Categorical`. + + Not the same as the length of possible uniques. - Categoricals based on dictionary mapping / enum will return unique count including all possibly - invalid values from underlying array. + For :py:class:`~.rt_categorical.Categorical` objects constructed from dictionaries or + :py:class:`~enum.IntEnum` objects, the returned count includes unique invalid values from the + underlying array. Otherwise, invalid values are not counted. See Also -------- - Categorical.unique_count + :py:attr:`.rt_categorical.Categorical.unique_count` : Number of unique values in the :py:class:`~.rt_categorical.Categorical`. + :py:meth:`.rt_groupbyops.GroupByOps.count_uniques` : Count the unique values for each group. """ un = unique(self._fa, sorted=False) count = len(un) @@ -6330,7 +6347,7 @@ def __del__(self): # python has trouble deleting objects with circular references if hasattr(self, "_categories_wrap"): del self._categories_wrap - self._grouping = None + del self._grouping # ------------------------------------------------------------ @classmethod @@ -6504,7 +6521,7 @@ def column_name(arg): return value - if np.isscalar(example_res) & ~transform: # userfunc is a scalar function + if np.isscalar(example_res) and not transform: # userfunc is a scalar function res = self._scalar_compiled_numba_apply(iGroup, iFirstGroup, nCountGroup, userfunc, args) res_ds = TypeRegister.Dataset(self.gb_keychain.gbkeys) diff --git a/riptable/rt_datetime.py b/riptable/rt_datetime.py index 792e04b..e549427 100644 --- a/riptable/rt_datetime.py +++ b/riptable/rt_datetime.py @@ -937,13 +937,13 @@ def _strftime(self, format, dtype="O"): if isinstance(self, np.ndarray): return np.asarray( [ - dt.utcfromtimestamp(timestamp).strftime(format) + dt.fromtimestamp(timestamp, timezone.utc).strftime(format) for timestamp in self._fa.astype(np.int64) * SECONDS_PER_DAY ], dtype=dtype, ) else: - return dt.strftime(dt.utcfromtimestamp(self * SECONDS_PER_DAY), format) + return dt.strftime(dt.fromtimestamp(self * SECONDS_PER_DAY, timezone.utc), format) # ------------------------------------------------------------ @property @@ -1606,29 +1606,29 @@ def isfinite(self): See Also -------- - :py:meth:`.rt_datetime.Date.isnan` - :py:meth:`.rt_datetime.DateTimeNano.isnan` - :py:meth:`.rt_datetime.DateTimeNano.isnotnan` - :py:func:`.rt_numpy.isnan` - :py:func:`.rt_numpy.isnotnan` - :py:func:`.rt_numpy.isnanorzero` - :py:meth:`.rt_fastarray.FastArray.isnan` - :py:meth:`.rt_fastarray.FastArray.isnotnan` - :py:meth:`.rt_fastarray.FastArray.notna` - :py:meth:`.rt_fastarray.FastArray.isnanorzero` - :py:meth:`.rt_categorical.Categorical.isnan` - :py:meth:`.rt_categorical.Categorical.isnotnan` - :py:meth:`.rt_categorical.Categorical.notna` - :py:meth:`.rt_dataset.Dataset.mask_or_isnan` : - Return a boolean array that's `True` for each - :py:class:`~.rt_dataset.Dataset` row that contains at least one ``NaN``. 
- :py:meth:`.rt_dataset.Dataset.mask_and_isnan` : + :py:meth:`.rt_datetime.Date.isnotfinite` + :py:meth:`.rt_datetime.DateTimeNano.isfinite` + :py:meth:`.rt_datetime.DateTimeNano.isnotfinite` + :py:meth:`.rt_fastarray.FastArray.isfinite` + :py:meth:`.rt_fastarray.FastArray.isnotfinite` + :py:meth:`.rt_fastarray.FastArray.isinf` + :py:meth:`.rt_fastarray.FastArray.isnotinf` + :py:meth:`.rt_dataset.Dataset.mask_or_isfinite` : Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` - row that contains only ``NaN`` values. + row that has at least one finite value. + :py:meth:`.rt_dataset.Dataset.mask_and_isfinite` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all finite values. + :py:meth:`.rt_dataset.Dataset.mask_or_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that has at least one value that's positive or negative infinity. + :py:meth:`.rt_dataset.Dataset.mask_and_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all infinite values. Notes ----- - Riptable currently uses ``0`` to represent ``NaN`` values for the classes in the + Riptable uses ``0`` to represent ``NaN`` values for the classes in the :py:mod:`.rt_datetime` module. This constant is held in the :py:class:`~.rt_datetime.DateTimeBase` class. @@ -1661,29 +1661,29 @@ def isnotfinite(self): See Also -------- - :py:meth:`.rt_datetime.Date.isnotnan` - :py:meth:`.rt_datetime.DateTimeNano.isnan` - :py:meth:`.rt_datetime.DateTimeNano.isnotnan` - :py:func:`.rt_numpy.isnan` - :py:func:`.rt_numpy.isnotnan` - :py:func:`.rt_numpy.isnanorzero` - :py:meth:`.rt_fastarray.FastArray.isnan` - :py:meth:`.rt_fastarray.FastArray.isnotnan` - :py:meth:`.rt_fastarray.FastArray.notna` - :py:meth:`.rt_fastarray.FastArray.isnanorzero` - :py:meth:`.rt_categorical.Categorical.isnan` - :py:meth:`.rt_categorical.Categorical.isnotnan` - :py:meth:`.rt_categorical.Categorical.notna` - :py:meth:`.rt_dataset.Dataset.mask_or_isnan` : - Return a boolean array that's `True` for each - :py:class:`~.rt_dataset.Dataset` row that contains at least one ``NaN``. - :py:meth:`.rt_dataset.Dataset.mask_and_isnan` : + :py:meth:`.rt_datetime.Date.isfinite` + :py:meth:`.rt_datetime.DateTimeNano.isnotfinite` + :py:meth:`.rt_datetime.DateTimeNano.isfinite` + :py:meth:`.rt_fastarray.FastArray.isfinite` + :py:meth:`.rt_fastarray.FastArray.isnotfinite` + :py:meth:`.rt_fastarray.FastArray.isinf` + :py:meth:`.rt_fastarray.FastArray.isnotinf` + :py:meth:`.rt_dataset.Dataset.mask_or_isfinite` : Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` - row that contains only ``NaN`` values. + row that has at least one finite value. + :py:meth:`.rt_dataset.Dataset.mask_and_isfinite` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all finite values. + :py:meth:`.rt_dataset.Dataset.mask_or_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that has at least one value that's positive or negative infinity. + :py:meth:`.rt_dataset.Dataset.mask_and_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all infinite values. Notes ----- - Riptable currently uses ``0`` to represent ``NaN`` values for the classes in the + Riptable uses ``0`` to represent ``NaN`` values for the classes in the :py:mod:`.rt_datetime` module. 
This constant is held in the :py:class:`~.rt_datetime.DateTimeBase` class. @@ -4161,10 +4161,11 @@ def _strftime(self, format, dtype="O"): if to_tz in ["GMT", "UTC"]: if isinstance(in_seconds, np.ndarray): return np.asarray( - [dt.utcfromtimestamp(timestamp).strftime(format) for timestamp in in_seconds], dtype=dtype + [dt.fromtimestamp(timestamp, timezone.utc).strftime(format) for timestamp in in_seconds], + dtype=dtype, ) else: - return dt.strftime(dt.utcfromtimestamp(in_seconds), format) + return dt.strftime(dt.fromtimestamp(in_seconds, timezone.utc), format) else: # Choose timezone from to_tz @@ -5078,16 +5079,6 @@ def info(self): """ print(self.__repr__(verbose=True)) - # ------------------------------------------------------- - def diff(self, periods=1): - """ - Returns - ------- - TimeSpan - """ - result = self._fa.diff(periods=periods) - return TimeSpan(result) - # ------------------------------------------------------------ def __repr__(self, verbose=False): repr_strings = [] @@ -5308,8 +5299,8 @@ def isnan(self): Return a boolean array that's `True` for each invalid :py:class:`~.rt_datetime.DateTimeNano` element, `False` otherwise. - ``0`` and ``NaN`` values are treated as invalid - :py:class:`~.rt_datetime.DateTimeNano` elements. + ``0``, ``NaN`` values, and positive and negative infinity are + treated as invalid :py:class:`~.rt_datetime.DateTimeNano` elements. Returns ------- @@ -5371,8 +5362,8 @@ def isnotnan(self): Return a boolean array that's `True` for each valid :py:class:`~.rt_datetime.DateTimeNano` element, `False` otherwise. - ``0`` and ``NaN`` values are treated as invalid - :py:class:`~.rt_datetime.DateTimeNano` elements. + ``0``, ``NaN`` values, and positive and negative infinity are + treated as invalid :py:class:`~.rt_datetime.DateTimeNano` elements. Returns ------- @@ -5431,90 +5422,133 @@ def isnotnan(self): # ------------------------------------------------------------- def isfinite(self): """ - Return a boolean array that's True for each `DateTimeNano` element - that's not a NaN (Not a Number), False otherwise. + Return a boolean array that's `True` for each finite + :py:class:`~.rt_datetime.DateTimeNano` element, `False` otherwise. - Both the DateTime NaN (0) and Riptable's int64 sentinel value are - considered to be NaN. + These invalid :py:class:`~.rt_datetime.DateTimeNano` values are considered + non-finite: + + - Positive and negative infinity + - ``0`` + - ``NaN`` values (``rt.nan`` and other sentinel values) + - Strings representing dates before the UNIX epoch Returns ------- - `FastArray` - A `FastArray` of booleans that's True for each non-NaN element, - False otherwise. + :py:class:`~.rt_fastarray.FastArray` + A :py:class:`~.rt_fastarray.FastArray` of booleans that's `True` for each + finite :py:class:`~.rt_datetime.DateTimeNano` element, `False` otherwise. See Also -------- - DateTimeNano.isnan, Date.isnan, Date.isnotnan, riptable.isnan, - riptable.isnotnan, riptable.isnanorzero, FastArray.isnan, - FastArray.isnotnan, FastArray.notna, FastArray.isnanorzero, - Categorical.isnan, Categorical.isnotnan, Categorical.notna - Dataset.mask_or_isnan : - Return a boolean array that's True for each `Dataset` row that - contains at least one NaN. - Dataset.mask_and_isnan : - Return a boolean array that's True for each all-NaN `Dataset` row. 
+ :py:meth:`.rt_datetime.Date.isnotfinite` + :py:meth:`.rt_datetime.Date.isfinite` + :py:meth:`.rt_datetime.DateTimeNano.isnotfinite` + :py:meth:`.rt_fastarray.FastArray.isfinite` + :py:meth:`.rt_fastarray.FastArray.isnotfinite` + :py:meth:`.rt_fastarray.FastArray.isinf` + :py:meth:`.rt_fastarray.FastArray.isnotinf` + :py:meth:`.rt_dataset.Dataset.mask_or_isfinite` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that has at least one finite value. + :py:meth:`.rt_dataset.Dataset.mask_and_isfinite` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all finite values. + :py:meth:`.rt_dataset.Dataset.mask_or_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that has at least one value that's positive or negative infinity. + :py:meth:`.rt_dataset.Dataset.mask_and_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all infinite values. Notes ----- - Riptable currently uses 0 for the DateTime NaN value. This constant is - held in the `DateTimeBase` class. + Riptable treats ``0`` as an invalid date for the classes in the + :py:mod:`.rt_datetime` module. This constant is held in + :py:class:`~.rt_datetime.DateTimeBase`. Examples -------- - >>> dtn = rt.DateTimeNano(['20210101 09:31:15', '20210519 05:21:17', - ... '20210713 02:44:19'], from_tz = 'NYC') - >>> dtn[0] = 0 - >>> dtn[1] = dtn.inv - >>> dtn - DateTimeNano(['Inv', 'Inv', '20210712 22:44:19.000000000'], to_tz='NYC') + >>> dtn = rt.DateTimeNano([rt.inf, -rt.inf]) >>> dtn.isfinite() - FastArray([False, False, True]) + FastArray([False, False]) + + >>> dtn2 = rt.DateTimeNano([0.0, rt.nan]) + >>> dtn2.isfinite() + FastArray([False, False]) + + >>> dtn3 = rt.DateTimeNano(["2024-01-01 12:00:00.000000000", + ... "2010-09-24 12:00:00.000000000", + ... "1962-03-17 12:00:00.000000000"], from_tz="NYC") + >>> dtn3.isfinite() + FastArray([ True, True, False]) """ return ~self.isnan() # ------------------------------------------------------------- def isnotfinite(self): """ - Return a boolean array that's True for each `DateTimeNano` element - that's a NaN (Not a Number), False otherwise. + Return a boolean array that's `True` for each non-finite + :py:class:`~.rt_datetime.DateTimeNano` element, `False` otherwise. - Both the DateTime NaN (0) and Riptable's int64 sentinel value are - considered to be NaN. + These invalid :py:class:`~.rt_datetime.DateTimeNano` values are considered + non-finite: + + - Positive and negative infinity + - ``0`` + - ``NaN`` values (``rt.nan`` and other sentinel values) + - Strings representing dates before the UNIX epoch Returns ------- - `FastArray` - A `FastArray` of booleans that's True for each NaN element, False + :py:class:`~.rt_fastarray.FastArray` + A :py:class:`~.rt_fastarray.FastArray` of booleans that's `True` for each + non-finite :py:class:`~.rt_datetime.DateTimeNano` element, `False` otherwise. See Also -------- - DateTimeNano.isnotnan, Date.isnan, Date.isnotnan, riptable.isnan, - riptable.isnotnan, riptable.isnanorzero, FastArray.isnan, - FastArray.isnotnan, FastArray.notna, FastArray.isnanorzero, - Categorical.isnan, Categorical.isnotnan, Categorical.notna - Dataset.mask_or_isnan : - Return a boolean array that's True for each `Dataset` row that contains - at least one NaN. - Dataset.mask_and_isnan : - Return a boolean array that's True for each all-NaN `Dataset` row. 
+ :py:meth:`.rt_datetime.Date.isnotfinite` + :py:meth:`.rt_datetime.Date.isfinite` + :py:meth:`.rt_datetime.DateTimeNano.isfinite` + :py:meth:`.rt_fastarray.FastArray.isfinite` + :py:meth:`.rt_fastarray.FastArray.isnotfinite` + :py:meth:`.rt_fastarray.FastArray.isinf` + :py:meth:`.rt_fastarray.FastArray.isnotinf` + :py:meth:`.rt_dataset.Dataset.mask_or_isfinite` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that has at least one finite value. + :py:meth:`.rt_dataset.Dataset.mask_and_isfinite` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all finite values. + :py:meth:`.rt_dataset.Dataset.mask_or_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that has at least one value that's positive or negative infinity. + :py:meth:`.rt_dataset.Dataset.mask_and_isinf` : + Return a boolean array that's `True` for each :py:class:`~.rt_dataset.Dataset` + row that contains all infinite values. Notes ----- - Riptable currently uses 0 for the DateTime NaN value. This constant is - held in the `DateTimeBase` class. + Riptable treats ``0`` as an invalid date for the classes in the + :py:mod:`.rt_datetime` module. This constant is held in + :py:class:`~.rt_datetime.DateTimeBase`. Examples -------- - >>> dtn = rt.DateTimeNano(['20210101 09:31:15', '20210519 05:21:17', - ... '20210713 02:44:19'], from_tz = 'NYC') - >>> dtn[0] = 0 - >>> dtn[1] = dtn.inv - >>> dtn - DateTimeNano(['Inv', 'Inv', '20210712 22:44:19.000000000'], to_tz='NYC') + >>> dtn = rt.DateTimeNano([rt.inf, -rt.inf]) >>> dtn.isnotfinite() - FastArray([ True, True, False]) + FastArray([ True, True]) + + >>> dtn2 = rt.DateTimeNano([0.0, rt.nan]) + >>> dtn2.isnotfinite() + FastArray([ True, True]) + + >>> dtn3 = rt.DateTimeNano(["2024-01-01 12:00:00.000000000", + ... "2010-09-24 12:00:00.000000000", + ... "1962-03-17 12:00:00.000000000"], from_tz="NYC") + >>> dtn3.isnotfinite() + FastArray([False, False, True]) """ return self._fa.isnanorzero() @@ -5640,17 +5674,61 @@ def max(self, **kwargs): # ------------------------------------------------------------- def diff(self, periods=1): """ - Calculate the n-th discrete difference. + Calculate the differences between elements of a :py:class:`~.rt_datetime.DateTimeNano`. + + Subtracts the :py:class:`~.rt_datetime.DateTimeNano` array shifted to the right + ``periods`` times from the original :py:class:`~.rt_datetime.DateTimeNano` array. + + Spaces at either end of the returned array are filled with invalid values. Parameters ---------- periods : int, optional - The number of times values are differenced. If zero, the input - is returned as-is. + The number of element positions to shift the :py:class:`~.rt_datetime.DateTimeNano` + right (if positive) or left (if negative) before subtracting it from the + original. Returns ------- - obj:`TimeSpan` + :py:class:`~.rt_datetime.TimeSpan` + An equivalent-length :py:class:`~.rt_datetime.TimeSpan` array containing the + differences between :py:class:`~.rt_datetime.DateTimeNano` elements. + + Examples + -------- + Create a :py:class:`~.rt_datetime.DateTimeNano` for the following examples: + + >>> dtn = rt.DateTimeNano(["2021-08-15 00:00", + ... "2021-08-15 00:30", + ... "2021-08-15 01:30", + ... "2021-08-16 01:30", + ... "2021-09-16 01:30", + ... "2022-09-16 01:30"], + ... 
from_tz="NYC") + >>> dtn.diff() + TimeSpan(['Inv', '00:30:00.000000000', '01:00:00.000000000', '1d 00:00:00.000000000', '31d 00:00:00.000000000', '365d 00:00:00.000000000']) + + Calculate the difference between the original + :py:class:`~.rt_datetime.DateTimeNano` array and the array shifted two positions + to the right: + + >>> dtn.diff(periods=2) + TimeSpan(['Inv', 'Inv', '01:30:00.000000000', '1d 01:00:00.000000000', '32d 00:00:00.000000000', '396d 00:00:00.000000000']) + + Calculate the difference between the original + :py:class:`~.rt_datetime.DateTimeNano` array and the array shifted two positions + to the left: + + >>> dtn.diff(periods=-2) + TimeSpan(['-01:30:00.000000000', '-1d 01:00:00.000000000', '-32d 00:00:00.000000000', '-396d 00:00:00.000000000', 'Inv', 'Inv']) + + Calculate the difference between the original + :py:class:`~.rt_datetime.DateTimeNano` array and the array shifted six positions + to the right. Note that with only six elements in the array, the shifted array + is empty, resulting in a :py:class:`~.rt_datetime.TimeSpan` array of invalid values. + + >>> dtn.diff(periods=6) + TimeSpan(['Inv', 'Inv', 'Inv', 'Inv', 'Inv', 'Inv']) """ return TimeSpan(self._fa.diff(periods=periods).astype(np.float64)) @@ -6066,48 +6144,91 @@ def random_invalid(cls, sz, to_tz="NYC", from_tz="NYC", start=None, end=None): # ------------------------------------------------------------- def resample(self, rule: str, dropna: bool = False): - """Convenience method for frequency conversion and resampling of - DateTimeNano arrays. + """ + Create a new :py:class:`~.rt_datetime.DateTimeNano` range with a specified + frequency, or round each element of a :py:class:`~.rt_datetime.DateTimeNano` + down to the nearest specified time unit. + + The action performed by this method is controlled by the ``dropna`` parameter. Parameters ---------- - rule : string - The offset string or object representing target conversion. - Can also begin the string with a number e.g. '3H' - Currently supported: - H hour - T, min minute - S second - L, ms millisecond - U, us microsecond - N, ns nanosecond - - dropna : bool, default False - If True, returns a DateTimeNano the same length as caller, with all values rounded to specified frequency. - If False, returns a DateTimeNano range from caller's min to max with values at every specified frequency. + rule : str + The time unit specifying the frequency of the returned + :py:class:`~.rt_datetime.DateTimeNano` range if ``dropna`` is `False`. Or the + time unit to which the original :py:class:`~.rt_datetime.DateTimeNano` + elements are rounded if ``dropna`` is `True`. + + Supported time units: + + - ``"H"``: Hours + - ``"T"`` or ``"min"``: Minutes + - ``"S"``: Seconds + - ``"L"`` or ``"ms"``: Milliseconds + - ``"U"`` or ``"us"``: Microseconds + - ``"N"`` or ``"ns"``: Nanoseconds + + A number can be added before the time unit to adjust the frequency. + For example: + + - ``"3H"`` means every three hours + - ``"20S"`` means every twenty seconds + - ``"1.5us"`` means every 1.5 microseconds + + dropna : bool, default `False` + Controls the action of the method: + + - If `False`, returns a new :py:class:`~.rt_datetime.DateTimeNano` + range between the original min and max values, with a frequency specified + by ``rule``. + - If `True`, returns a new :py:class:`~.rt_datetime.DateTimeNano` with + the original elements rounded down to the nearest time unit specified by + ``rule``. 
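+
+            For example, ``dtn.resample("3H")`` returns a new range with a
+            three-hour frequency, while ``dtn.resample("3H", dropna=True)``
+            rounds each element of ``dtn`` down to the nearest three-hour
+            boundary.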
+ + Returns + ------- + :py:class:`~.rt_datetime.DateTimeNano` + A :py:class:`~.rt_datetime.DateTimeNano` with resampled or rounded elements. Examples -------- - >>> dtn = DateTimeNano(['2015-04-15 14:26:54.735321368', - '2015-04-20 07:30:00.858219615', - '2015-04-23 13:15:24.526871083', - '2015-04-21 02:25:11.768548100', - '2015-04-24 07:47:54.737776979', - '2015-04-10 23:59:59.376589955'], - from_tz='UTC', to_tz='UTC') - >>> dtn.resample('L', dropna=True) + Create a new range in a :py:class:`~.rt_datetime.DateTimeNano` array with a + frequency of one microsecond: + + >>> dtn = rt.DateTimeNano(["20190417 17:47:00.000001", + ... "20190417 17:47:00.000003", + ... "20190417 17:47:00.000005"], + ... from_tz="NYC") + >>> dtn.resample("us") + DateTimeNano(['20190417 17:47:00.000001000', '20190417 17:47:00.000002000', '20190417 17:47:00.000003000', '20190417 17:47:00.000004000', '20190417 17:47:00.000005000'], to_tz='America/New_York') + + Create a new range in a :py:class:`~.rt_datetime.DateTimeNano` with a frequency + of three microseconds: + + >>> dtn.resample("3us") + DateTimeNano(['20190417 17:47:00.000000000', '20190417 17:47:00.000003000'], to_tz='America/New_York') + + Note that the first element in the new range is rounded down below the minimum + value of the original :py:class:`~.rt_datetime.DateTimeNano`. To get the first + element of the new range, this method rounds the minimum of the original + :py:class:`~.rt_datetime.DateTimeNano` down to the closest multiple of ``rule``. + The closest multiple is less than or equal to the original + minimum, including the zero value of the ``rule`` unit digit. In the previous + example, the method attempts to round one microsecond down to the nearest three + microseconds, which is 0 microseconds. + + Set ``dropna`` to `True` to round the elements of a + :py:class:`~.rt_datetime.DateTimeNano` down to the nearest millisecond: + + >>> dtn = rt.DateTimeNano(["2015-04-15 14:26:54.735321368", + ... "2015-04-20 07:30:00.858219615", + ... "2015-04-23 13:15:24.526871083", + ... "2015-04-21 02:25:11.768548100", + ... "2015-04-24 07:47:54.737776979", + ... "2015-04-10 23:59:59.376589955"], + ... 
from_tz="UTC", to_tz="UTC") + >>> dtn.resample("L", dropna=True) DateTimeNano(['20150415 14:26:54.735000000', '20150420 07:30:00.858000000', '20150423 13:15:24.526000000', '20150421 02:25:11.768000000', '20150424 07:47:54.737000000', '20150410 23:59:59.376000000'], to_tz='UTC') - - >>> dtn = DateTimeNano(['20190417 17:47:00.000001', - '20190417 17:47:00.000003', - '20190417 17:47:00.000005'], - from_tz='NYC') - >>> dtn.resample('1us') - DateTimeNano(['20190417 17:47:00.000001000', '20190417 17:47:00.000002000', '20190417 17:47:00.000003000', '20190417 17:47:00.000004000', '20190417 17:47:00.000005000'], to_tz='NYC') - - Returns - ------- - dtn : `DateTimeNano` """ # ------------------------------------------------------- @@ -6410,7 +6531,10 @@ def _strftime(self, format, dtype="U"): if isinstance(self, np.ndarray): result = np.asarray( - [dt.utcfromtimestamp(timestamp).strftime(format) for timestamp in self._fa.abs() / 1_000_000_000.0], + [ + dt.fromtimestamp(timestamp, timezone.utc).strftime(format) + for timestamp in self._fa.abs() / 1_000_000_000.0 + ], dtype=dtype, ) if isnegative.sum() > 0: @@ -6422,7 +6546,7 @@ def _strftime(self, format, dtype="U"): negcol[isnegative] = "-" result = negcol + result else: - result = dt.strftime(dt.utcfromtimestamp(abs(self) / 1_000_000_000.0), format) + result = dt.strftime(dt.fromtimestamp(abs(self) / 1_000_000_000.0, timezone.utc), format) if isnegative: # check dtype 'S' if dtype == "S": @@ -7398,7 +7522,7 @@ def strftime(self, format): >>> d[0].strftime('%D') '01/01/21' """ - return dt.strftime(dt.utcfromtimestamp(self.astype(np.int64) * SECONDS_PER_DAY), format) + return dt.strftime(dt.fromtimestamp(self.astype(np.int64) * SECONDS_PER_DAY, timezone.utc), format) # ------------------------------------------------------------ @property diff --git a/riptable/rt_groupbyops.py b/riptable/rt_groupbyops.py index 7efd384..9dec598 100644 --- a/riptable/rt_groupbyops.py +++ b/riptable/rt_groupbyops.py @@ -1036,23 +1036,30 @@ def null(self, showfilter=False): # --------------------------------------------------------------- def count_uniques(self, *args, **kwargs): """ - Compute unique count of group + Count the unique values for each group. Returns ------- - Dataset with grouped key plus the unique count for each column by group. + :py:class:`~.rt_dataset.Dataset` : + :py:class:`~.rt_dataset.Dataset` with grouped keys and the unique count for each column by group. + + See Also + -------- + :py:attr:`.rt_categorical.Categorical.unique_count` : Number of unique values in the :py:class:`~.rt_categorical.Categorical`. + :py:meth:`.rt_categorical.Categorical.nunique` : Number of unique values in the :py:class:`~.rt_categorical.Categorical`. 
Examples -------- - >>> N = 17; np.random.seed(1) - >>> ds =Dataset( + >>> N = 17 + >>> np.random.seed(1) + >>> ds = rt.Dataset( dict( - Symbol = Cat(np.random.choice(['SPY','IBM'], N)), - Exchange = Cat(np.random.choice(['AMEX','NYSE'], N)), - TradeSize = np.random.choice([1,5,10], N), - TradePrice = np.random.choice([1.1,2.2,3.3], N), + Symbol=rt.Cat(np.random.choice(["SPY", "IBM"], N)), + Exchange=rt.Cat(np.random.choice(["AMEX", "NYSE"], N)), + TradeSize=np.random.choice([1, 5, 10], N), + TradePrice=np.random.choice([1.1, 2.2, 3.3], N), )) - >>> ds.cat(['Symbol','Exchange']).count_uniques() + >>> ds.cat(["Symbol", "Exchange"]).count_uniques() *Symbol *Exchange TradeSize TradePrice ------- --------- --------- ---------- IBM NYSE 2 2 diff --git a/riptable/rt_numpy.py b/riptable/rt_numpy.py index 2fdc86c..85406da 100644 --- a/riptable/rt_numpy.py +++ b/riptable/rt_numpy.py @@ -651,23 +651,50 @@ def empty_like( # ------------------------------------------------------- def _searchsorted(array, v, side="left", sorter=None): + from .rt_utils import possibly_convert + + def _punt_to_numpy(array, v, side, sorter): + # numpy does not like fastarrays for this routine + if isinstance(array, TypeRegister.FastArray): + array = array._np + if isinstance(v, TypeRegister.FastArray): + v = v._np + return LedgerFunction(np.searchsorted, array, v, side=side, sorter=sorter) + + is_scalar = np.isscalar(v) + dtype = get_common_dtype(array, v) + + if not dtype.kind in 'biuf': + return _punt_to_numpy(array, v, side, sorter) + + if not isinstance(array, np.ndarray): + array = np.array(array, dtype=dtype) + + if not isinstance(v, np.ndarray): + v = np.array(v, dtype=dtype) + + array = possibly_convert(array, dtype) + v = possibly_convert(v, dtype) # we cannot handle a sorter if sorter is None: try: + res = None if side == "leftplus": - return rc.BinsToCutsBSearch(v, array, 0) + res = rc.BinsToCutsBSearch(v, array, 0) elif side == "left": - return rc.BinsToCutsBSearch(v, array, 1) + res = rc.BinsToCutsBSearch(v, array, 1) else: - return rc.BinsToCutsBSearch(v, array, 2) + res = rc.BinsToCutsBSearch(v, array, 2) + + if is_scalar and not np.isscalar(res): + res = res[0] + + return res except: # fall into numpy pass - # numpy does not like fastarrays for this routine - if isinstance(array, TypeRegister.FastArray): - array = array._np - return LedgerFunction(np.searchsorted, array, v, side=side, sorter=sorter) + return _punt_to_numpy(array, v, side, sorter) # ------------------------------------------------------- diff --git a/riptable/rt_struct.py b/riptable/rt_struct.py index f0f650f..ab0537f 100644 --- a/riptable/rt_struct.py +++ b/riptable/rt_struct.py @@ -2712,9 +2712,9 @@ def col_exists(self, name): # ------------------------------------------------------- def _aggregate_column_matches( self, - items: Optional[Union[str, int, Iterable[Union[str, int]]]] = None, - like: Optional[str] = None, - regex: Optional["re.Pattern"] = None, + items: str | int | Iterable[str | int] | None = None, + like: str | None = None, + regex: re.Pattern | str | None = None, on_missing: Literal["raise", "warn", "ignore"] = "raise", func: Optional[Callable] = None, ) -> list[str]: @@ -2810,9 +2810,9 @@ def index_to_col(col_num: Union[str, int]) -> str: # ------------------------------------------------------- def col_filter( self, - items: Optional[Union[str, int, Iterable[Union[str, int]]]] = None, - like: Optional[str] = None, - regex: Optional["re.Pattern"] = None, + items: str | int | Iterable[str | int] | None = None, + like: str 
| None = None, + regex: re.Pattern | str | None = None, on_missing: Literal["raise", "warn", "ignore"] = "raise", ): """ @@ -3511,9 +3511,9 @@ def _ensure_atomic(self, colnames, func): # -------------------------------------------------------- def col_remove( self, - items: Optional[Union[str, int, Iterable[Union[str, int]]]] = None, - like: Optional[str] = None, - regex: Optional["re.Pattern"] = None, + items: str | int | Iterable[str | int] | None = None, + like: str | None = None, + regex: re.Pattern | str | None = None, on_missing: Literal["raise", "warn", "ignore"] = "warn", ): """ @@ -4231,27 +4231,46 @@ def _T(self): @property def _V(self): """ - Display all rows (up to 10,000) of a `.Dataset` or `Struct`. + Display all rows (up to 10,000) returned by a :py:class:`~.rt_dataset.Dataset` + or :py:class:`~.rt_struct.Struct`. - Without this property, rows are elided when there are more than 30 to display. + Without this property, rows are elided when there are more than the sum of + :py:attr:`~riptable.Utils.display_options.DisplayOptions.HEAD_ROWS` (default + 15 rows) and :py:attr:`~riptable.Utils.display_options.DisplayOptions.TAIL_ROWS` + (default 15 rows) and when + :py:attr:`~riptable.Utils.display_options.DisplayOptions.ROW_ALL` is `False`. Returns ------- DisplayString - A wrapper for display operations that don't return a `.Dataset` or `Struct`. + A wrapper for display operations that don't return a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. See Also -------- - Struct._H : Display all columns and long strings of a `.Dataset` or `Struct`. - Struct._A : Display all columns, rows, and long strings of a `.Dataset` or `Struct`. - Struct._G : Display all columns of a `.Dataset` or `Struct`, wrapping the table as needed. - Struct._T : Display a transposed view of a `.Dataset` or `Struct`. + :py:attr:`.rt_struct.Struct._H` : + Display all columns and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._A` : + Display all columns, rows, and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._G` : + Display all columns returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`, wrapping the table as needed. + :py:attr:`.rt_struct.Struct._T` : + Display a transposed view of a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. Examples -------- + The following examples use default display options. The following internal + function resets display options to their defaults: + + >>> setup_for_examples("struct-display") + By default, rows are elided when there are more than 30 to display. 
- >>> ds = rt.Dataset({'a' : rt.arange(31)}) + >>> ds = rt.Dataset({"a": rt.arange(31)}) >>> ds # a --- --- @@ -4286,43 +4305,45 @@ def _V(self): 28 28 29 29 30 30 + + [31 rows x 1 columns] total bytes: 248.0 B Display all rows: >>> ds._V - # a - --- --- - 0 0 - 1 1 - 2 2 - 3 3 - 4 4 - 5 5 - 6 6 - 7 7 - 8 8 - 9 9 - 10 10 - 11 11 - 12 12 - 13 13 - 14 14 - 15 15 - 16 16 - 17 17 - 18 18 - 19 19 - 20 20 - 21 21 - 22 22 - 23 23 - 24 24 - 25 25 - 26 26 - 27 27 - 28 28 - 29 29 - 30 30 + # a + -- -- + 0 0 + 1 1 + 2 2 + 3 3 + 4 4 + 5 5 + 6 6 + 7 7 + 8 8 + 9 9 + 10 10 + 11 11 + 12 12 + 13 13 + 14 14 + 15 15 + 16 16 + 17 17 + 18 18 + 19 19 + 20 20 + 21 21 + 22 22 + 23 23 + 24 24 + 25 25 + 26 26 + 27 27 + 28 28 + 29 29 + 30 30 """ maxrows = 10_000 numrows = self._nrows if hasattr(self, "_nrows") else len(self) @@ -4337,29 +4358,53 @@ def _V(self): @property def _H(self): """ - Display all columns and long strings of a `.Dataset` or `Struct`. + Display all columns and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`, wrapping the + table as needed. + + Without this property, :py:class:`~.rt_dataset.Dataset` and + :py:class:`~.rt_struct.Struct` objects are displayed according to Riptable's + display options: - Without this property, columns are elided when the maximum display - width is reached, and strings are truncated after 15 characters. + - Columns are elided when the maximum display console width is reached and + :py:attr:`~riptable.Utils.display_options.DisplayOptions.COL_ALL` is `False`. + - Strings are truncated after exceeding + :py:attr:`~riptable.Utils.display_options.DisplayOptions.MAX_STRING_WIDTH` + (default 15 characters). Returns ------- DisplayString - A wrapper for display operations that don't return a `.Dataset` or `Struct`. + A wrapper for display operations that don't return a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. See Also -------- - Struct._V : Display all rows of a `.Dataset` or `Struct`. - Struct._A : Display all columns, rows, and long strings of a `.Dataset` or `Struct`. - Struct._G : Display all columns of a `.Dataset` or `Struct`, wrapping the table as needed. - Struct._T : Display a transposed view of a `.Dataset` or `Struct`. + :py:attr:`.rt_struct.Struct._V` : + Display all rows returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._A` : + Display all columns, rows, and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._G` : + Display all columns returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`, wrapping the table as needed. + :py:attr:`.rt_struct.Struct._T` : + Display a transposed view returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. Examples -------- - By default, columns are elided when the maximum display width is reached, and strings are truncated after 15 characters. + The following examples use default display options. The following internal + function resets display options to their defaults: - >>> ds = rt.Dataset({key : rt.FA([i, 2*i, 3*i, 4*i])%3 == 0 for i, key in enumerate('abcdefghijklm')}) - >>> ds[0] = rt.FA('long_string_long_string') + >>> setup_for_examples("struct-display") + + By default, columns are elided when the maximum console display width is + reached, and strings are truncated after 15 characters. 
+ + >>> ds = rt.Dataset({key: rt.FA([i, 2 * i, 3 * i, 4 * i]) % 3 == 0 for i, key in enumerate("abcdefghijklm")}) + >>> ds[0] = rt.FA("long_string_long_string") >>> ds # a b c d e f ... h i j k l m - --------------- ----- ----- ---- ----- ----- --- ----- ----- ---- ----- ----- ---- @@ -4367,16 +4412,25 @@ def _H(self): 1 long_string_lon False False True False False ... False False True False False True 2 long_string_lon True True True True True ... True True True True True True 3 long_string_lon False False True False False ... False False True False False True + + [4 rows x 13 columns] total bytes: 140.0 B - Display all columns and long strings: + Display all columns and long strings, wrapping the table as needed: >>> ds._H - # a b c d e f g h i j k l m - - ----------------------- ----- ----- ---- ----- ----- ---- ----- ----- ---- ----- ----- ---- - 0 long_string_long_string False False True False False True False False True False False True - 1 long_string_long_string False False True False False True False False True False False True - 2 long_string_long_string True True True True True True True True True True True True - 3 long_string_long_string False False True False False True False False True False False True + # a b c d e f g h i j k + - ----------------------- ----- ----- ---- ----- ----- ---- ----- ----- ---- ----- + 0 long_string_long_string False False True False False True False False True False + 1 long_string_long_string False False True False False True False False True False + 2 long_string_long_string True True True True True True True True True True + 3 long_string_long_string False False True False False True False False True False + + # l m + - ----- ---- + 0 False True + 1 False True + 2 True True + 3 False True """ return self._temp_display(["COL_ALL", "MAX_STRING_WIDTH"], [True, 1000]) @@ -4384,27 +4438,46 @@ def _H(self): @property def _G(self): """ - Display all columns of a `.Dataset` or `Struct`, wrapping the table - after the maximum display width is reached. + Display all columns returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`, wrapping the table after the maximum display + width is reached. + + Without this property, columns are elided when the maximum display console width + is reached and :py:attr:`~riptable.Utils.display_options.DisplayOptions.COL_ALL` + is `False`. Note: The table is displayed as text, not HTML. Returns ------- - None + `None` + Returns nothing. See Also -------- - Struct._V : Display all rows of a `.Dataset` or `Struct`. - Struct._H : Display all columns and long strings of a `.Dataset` or `Struct`. - Struct._A : Display all columns, rows, and long strings of a `.Dataset` or `Struct`. - Struct._T : Display a transposed view of a `.Dataset` or `Struct`. + :py:attr:`.rt_struct.Struct._V` : + Display all rows returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._H` : + Display all columns and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._A` : + Display all columns, rows, and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._T` : + Display a transposed view of a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. Examples -------- - >>> ds = rt.Dataset( - ... {key: rt.FA([i, 2 * i, 3 * i, 4 * i]) % 3 == 0 for i, key in enumerate('abcdefghijklmno')} - ... 
) + The following examples use default display options. The following internal + function resets display options to their defaults: + + >>> setup_for_examples("struct-display") + + Create a :py:class:`~.rt_dataset.Dataset` for the following examples: + + >>> ds = rt.Dataset({key: rt.FA([i, 2 * i, 3 * i, 4 * i]) % 3 == 0 for i, key in enumerate("abcdefghijklmno")}) Default behavior: @@ -4415,8 +4488,10 @@ def _G(self): 1 True False False True False False True ... True False False True False False 2 True True True True True True True ... True True True True True True 3 True False False True False False True ... True False False True False False + + [4 rows x 15 columns] total bytes: 60.0 B - Show all rows, wrapping the table as needed: + Show all columns, wrapping the table as needed: >>> ds._G # a b c d e f g h i j k l m @@ -4445,29 +4520,57 @@ def _G(self): @property def _A(self): """ - Display all columns, all rows (up to 10,000), and long strings of a - `.Dataset` or `Struct`. - - Without this property, columns are elided when the maximum display width - is reached, rows are elided when there are more then 30 to display, - and strings are truncated after 15 characters. + Display all columns, rows (up to 10,000), and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`, wrapping the + table as needed. + + Without this property, :py:class:`~.rt_dataset.Dataset` and + :py:class:`~.rt_struct.Struct` objects are displayed according to Riptable's + display options: + + - Columns are elided when the maximum display console width is reached and + :py:attr:`~riptable.Utils.display_options.DisplayOptions.COL_ALL` is `False`. + - Rows are elided when there are more than the sum of + :py:attr:`~riptable.Utils.display_options.DisplayOptions.HEAD_ROWS` (default + 15 rows) and :py:attr:`~riptable.Utils.display_options.DisplayOptions.TAIL_ROWS` + (default 15 rows) and when + :py:attr:`~riptable.Utils.display_options.DisplayOptions.ROW_ALL` is `False`. + - Strings are truncated after exceeding + :py:attr:`~riptable.Utils.display_options.DisplayOptions.MAX_STRING_WIDTH` + (default 15 characters). Returns ------- DisplayString - A wrapper for display operations that don't return a `.Dataset` or `Struct`. + A wrapper for display operations that don't return a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. See Also -------- - Struct._V : Display all rows of a `.Dataset` or `Struct`. - Struct._H : Display all columns and long strings of a `.Dataset` or `Struct`. - Struct._G : Display all columns of a `.Dataset` or `Struct`, wrapping the table as needed. - Struct._T : Display a transposed view of a `.Dataset` or `Struct`. + :py:attr:`.rt_struct.Struct._V` : + Display all rows returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._H` : + Display all columns and long strings returned by a + :py:class:`~.rt_dataset.Dataset` or :py:class:`~.rt_struct.Struct`. + :py:attr:`.rt_struct.Struct._G` : + Display all columns returned by a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`, wrapping the table as needed. + :py:attr:`.rt_struct.Struct._T` : + Display a transposed view of a :py:class:`~.rt_dataset.Dataset` or + :py:class:`~.rt_struct.Struct`. Examples -------- - >>> ds = rt.Dataset({'col_'+str(i):rt.arange(31) for i in range(12)}) - >>> ds[0] = 'long_string_long_string' + The following examples use default display options. 
The following internal + function resets display options to their defaults: + + >>> setup_for_examples("struct-display") + + Create a :py:class:`~.rt_dataset.Dataset` for the following examples: + + >>> ds = rt.Dataset({"col_" + str(i): rt.arange(31) for i in range(12)}) + >>> ds[0] = "long_string_long_string" By default, columns are elided when the maximum display width is reached, rows are elided when there are more then 30 to display, and @@ -4507,43 +4610,80 @@ def _A(self): 28 long_string_lon 28 28 28 28 28 ... 28 28 28 28 28 29 long_string_lon 29 29 29 29 29 ... 29 29 29 29 29 30 long_string_lon 30 30 30 30 30 ... 30 30 30 30 30 + + [31 rows x 12 columns] total bytes: 5.4 KB - Display all columns, rows, and long strings: + Display all columns, rows, and long strings. Note that the columns that exceed + the console dimensions wrap below into another section. >>> ds._A - # col_0 col_1 col_2 col_3 col_4 col_5 col_6 col_7 col_8 col_9 col_10 col_11 - --- ----------------------- ----- ----- ----- ----- ----- ----- ----- ----- ----- ------ ------ - 0 long_string_long_string 0 0 0 0 0 0 0 0 0 0 0 - 1 long_string_long_string 1 1 1 1 1 1 1 1 1 1 1 - 2 long_string_long_string 2 2 2 2 2 2 2 2 2 2 2 - 3 long_string_long_string 3 3 3 3 3 3 3 3 3 3 3 - 4 long_string_long_string 4 4 4 4 4 4 4 4 4 4 4 - 5 long_string_long_string 5 5 5 5 5 5 5 5 5 5 5 - 6 long_string_long_string 6 6 6 6 6 6 6 6 6 6 6 - 7 long_string_long_string 7 7 7 7 7 7 7 7 7 7 7 - 8 long_string_long_string 8 8 8 8 8 8 8 8 8 8 8 - 9 long_string_long_string 9 9 9 9 9 9 9 9 9 9 9 - 10 long_string_long_string 10 10 10 10 10 10 10 10 10 10 10 - 11 long_string_long_string 11 11 11 11 11 11 11 11 11 11 11 - 12 long_string_long_string 12 12 12 12 12 12 12 12 12 12 12 - 13 long_string_long_string 13 13 13 13 13 13 13 13 13 13 13 - 14 long_string_long_string 14 14 14 14 14 14 14 14 14 14 14 - 15 long_string_long_string 15 15 15 15 15 15 15 15 15 15 15 - 16 long_string_long_string 16 16 16 16 16 16 16 16 16 16 16 - 17 long_string_long_string 17 17 17 17 17 17 17 17 17 17 17 - 18 long_string_long_string 18 18 18 18 18 18 18 18 18 18 18 - 19 long_string_long_string 19 19 19 19 19 19 19 19 19 19 19 - 20 long_string_long_string 20 20 20 20 20 20 20 20 20 20 20 - 21 long_string_long_string 21 21 21 21 21 21 21 21 21 21 21 - 22 long_string_long_string 22 22 22 22 22 22 22 22 22 22 22 - 23 long_string_long_string 23 23 23 23 23 23 23 23 23 23 23 - 24 long_string_long_string 24 24 24 24 24 24 24 24 24 24 24 - 25 long_string_long_string 25 25 25 25 25 25 25 25 25 25 25 - 26 long_string_long_string 26 26 26 26 26 26 26 26 26 26 26 - 27 long_string_long_string 27 27 27 27 27 27 27 27 27 27 27 - 28 long_string_long_string 28 28 28 28 28 28 28 28 28 28 28 - 29 long_string_long_string 29 29 29 29 29 29 29 29 29 29 29 - 30 long_string_long_string 30 30 30 30 30 30 30 30 30 30 30 + # col_0 col_1 col_2 col_3 col_4 col_5 col_6 col_7 col_8 col_9 col_10 + -- ----------------------- ----- ----- ----- ----- ----- ----- ----- ----- ----- ------ + 0 long_string_long_string 0 0 0 0 0 0 0 0 0 0 + 1 long_string_long_string 1 1 1 1 1 1 1 1 1 1 + 2 long_string_long_string 2 2 2 2 2 2 2 2 2 2 + 3 long_string_long_string 3 3 3 3 3 3 3 3 3 3 + 4 long_string_long_string 4 4 4 4 4 4 4 4 4 4 + 5 long_string_long_string 5 5 5 5 5 5 5 5 5 5 + 6 long_string_long_string 6 6 6 6 6 6 6 6 6 6 + 7 long_string_long_string 7 7 7 7 7 7 7 7 7 7 + 8 long_string_long_string 8 8 8 8 8 8 8 8 8 8 + 9 long_string_long_string 9 9 9 9 9 9 9 9 9 9 + 10 long_string_long_string 10 10 10 10 10 10 10 
diff --git a/riptable/rt_utils.py b/riptable/rt_utils.py
index 2772969..ff89ee1 100644
--- a/riptable/rt_utils.py
+++ b/riptable/rt_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 __all__ = [  # misc riptable utility funcs
     "get_default_value",
@@ -15,6 +17,7 @@
     "crc_match",
     # h5 -> riptable
     "load_h5",
+    "possibly_convert",
 ]

 import keyword
@@ -36,6 +39,7 @@
     from .rt_dataset import Dataset
     from .rt_struct import Struct
+    from .rt_fastarray import FastArray

 _T = TypeVar("_T")
@@ -389,6 +393,39 @@ def merge_prebinned(key1: np.ndarray, key2: np.ndarray, val1, val2, totalUniqueS
     return rc.MergeBinnedAndSorted(key1, key2, val1, val2, totalUniqueSize)


+def possibly_convert(arr: Union[np.ndarray, FastArray], common_dtype: np.dtype | str):
+    """
+    Helper function for converting a FastArray or ndarray to a specified dtype.
+    This performs a safe conversion that understands riptable sentinels.
+
+    Parameters
+    ----------
+    arr : np.ndarray or FastArray
+        The array to cast.
+    common_dtype : str or np.dtype
+        The dtype to which the array is cast.
+
+    Returns
+    -------
+    arr_t : np.ndarray or FastArray
+        The input array cast to the specified dtype, or the original array if no conversion was needed.
+    """
+    common_dtype = np.dtype(common_dtype)
+    # upcast if needed
+    if arr.dtype.num != common_dtype.num:
+        try:
+            # perform a safe conversion understanding sentinels
+            arr = TypeRegister.MathLedger._AS_FA_TYPE(arr, common_dtype.num)
+        except Exception:
+            # try numpy conversion
+            arr = arr.astype(common_dtype)
+
+    elif arr.itemsize != common_dtype.itemsize:
+        # make string sizes the same
+        arr = arr.astype(common_dtype)
+    return arr
+
+
 # ------------------------------------------------------------------------------------------------------
 def normalize_keys(key1, key2, verbose=False):
     """
@@ -436,21 +473,6 @@ def check_key(key):
         # extract the value in the dictlike object
         return list(key.values())

-    def possibly_convert(arr, common_dtype):
-        # upcast if need to
-        if arr.dtype.num != common_dtype.num:
-            try:
-                # perform a safe conversion understanding sentinels
-                arr = TypeRegister.MathLedger._AS_FA_TYPE(arr, common_dtype.num)
-            except Exception:
-                # try numpy conversion
-                arr = arr.astype(common_dtype)
-
-        elif arr.itemsize != common_dtype.itemsize:
-            # make strings sizes the same
-            arr = arr.astype(common_dtype)
-        return arr
-
     key1 = check_key(key1)
     key2 = check_key(key2)
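A brief usage sketch of the newly exported helper; the expected results mirror the `test_possibly_convert` cases added later in this patch:

```python
import riptable as rt
from riptable.rt_utils import possibly_convert

# Plain numeric upcast: int32 -> int64.
possibly_convert(rt.FA([1, 2, 3], dtype=rt.int32), rt.int64)

# Sentinel-aware upcast: the int8 invalid sentinel is mapped to the
# int32 invalid sentinel rather than being reinterpreted as an
# ordinary small integer.
possibly_convert(rt.FA([rt.int8.inv, 2, 3], dtype=rt.int8), rt.int32)

# Same dtype kind but different itemsize: widen |S1 strings to |S5.
possibly_convert(rt.FA(["a", "b", "c"], dtype="|S1"), "|S5")
```

The sentinel-aware path is the point of promoting this helper out of `normalize_keys`: a plain `astype` would carry the raw sentinel bit pattern across dtypes instead of the target dtype's invalid value.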
diff --git a/riptable/tests/test_base_function.py b/riptable/tests/test_base_function.py
index 22b5a0d..303a35e 100644
--- a/riptable/tests/test_base_function.py
+++ b/riptable/tests/test_base_function.py
@@ -165,6 +165,22 @@ def test_searchsorted(self):

         b = b.astype(np.int32)

+    @pytest.mark.parametrize(
+        "haystack, needle", [(rt.FA([1, 2, 3]), rt.FA([1.5, 2.5, 3.5])), (rt.FA([1.5, 2.5, 3.5]), rt.FA([1, 2, 3]))]
+    )
+    @pytest.mark.parametrize("side", ["left", "right"])
+    def test_searchsorted_mismatch(self, haystack, needle, side):
+        assert_array_equal(
+            rt.searchsorted(haystack, needle, side=side), np.searchsorted(haystack._np, needle._np, side=side)
+        )
+
+    @pytest.mark.parametrize("haystack, needle", [(rt.FA(["AAPL", "AAPL", "AAPL", "AMZN"]), rt.FA(["AAPL"]))])
+    @pytest.mark.parametrize("side", ["left", "right"])
+    def test_searchsorted_string(self, haystack, needle, side):
+        assert_array_equal(
+            rt.searchsorted(haystack, needle, side=side), np.searchsorted(haystack._np, needle._np, side=side)
+        )
+

 class TestStd:
     @pytest.mark.skip(reason="this test depends on implementation specific behavior")
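The new tests above pin `rt.searchsorted` to the NumPy reference even when the haystack and needle dtypes differ. A minimal illustration of the contract being tested (`._np` is the underlying NumPy view, as used in the tests):

```python
import numpy as np
import riptable as rt

haystack = rt.FA([1, 2, 3])       # integer haystack
needle = rt.FA([1.5, 2.5, 3.5])   # float needles

# Both calls should agree: each needle value falls strictly between two
# integers, so the insertion points are [1, 2, 3] regardless of side.
rt.searchsorted(haystack, needle, side="left")
np.searchsorted(haystack._np, needle._np, side="left")
```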
diff --git a/riptable/tests/test_categorical_filter_invalid.py b/riptable/tests/test_categorical_filter_invalid.py
index 22294e2..8692fb4 100644
--- a/riptable/tests/test_categorical_filter_invalid.py
+++ b/riptable/tests/test_categorical_filter_invalid.py
@@ -94,11 +94,18 @@ def test_saveload_invalid(self):
         assert c.invalid_category == c2.invalid_category
         assert c.filtered_name == c2.filtered_name

-    def test_isfiltered(self):
-        c = Categorical(np.random.choice(4, 100), ["a", "b", "c"])
-        flt = c.isfiltered()
-        eq_z = c._fa == 0
-        assert arr_eq(flt, eq_z)
+    @pytest.mark.parametrize(
+        "cat, exp",
+        [
+            (Cat([0, 1, 0, 2, 3], ["a", "b", "c"]), [True, False, True, False, False]),
+            (Cat([0, 1, 0, 2, 3], {1: "a", 2: "b"}), [True, False, True, False, True]),
+            (Cat([0, 1, 0, 2, 3], {"a": 1, "b": 2}), [True, False, True, False, True]),
+            (Cat([0, 1, 0, 2, 3], {"a": 1, "b": 2}, base_index=0), [True, False, True, False, True]),
+            (Cat([0, 1, 0, 2, 3], base_index=0), [False, False, False, False, False]),
+        ],
+    )
+    def test_isfiltered(self, cat, exp):
+        assert arr_eq(cat.isfiltered(), exp)

     def test_isnan(self):
         c = Categorical(np.random.choice(4, 100), ["a", "b", "c"], invalid="a")
diff --git a/riptable/tests/test_rtutils.py b/riptable/tests/test_rtutils.py
index 4781d75..da90797 100644
--- a/riptable/tests/test_rtutils.py
+++ b/riptable/tests/test_rtutils.py
@@ -3,7 +3,7 @@
 from numpy.testing import assert_array_equal

 import riptable as rt
-from riptable.rt_utils import crc_match
+from riptable.rt_utils import crc_match, possibly_convert


 @pytest.mark.parametrize(
@@ -155,3 +155,19 @@ def test_mbget_with_non_FA_subclass():
     # Check that the valid indices fetched the correct values.
     assert_array_equal(rt.FA([3, 28, 13, 20, 38]), result[valid_indices])
     assert type(data) == np.ndarray
+
+
+@pytest.mark.parametrize(
+    "arr, dtype, exp",
+    [
+        pytest.param(rt.FA([1, 2, 3], dtype=rt.int32), rt.int64, rt.FA([1, 2, 3], dtype=rt.int64)),
+        pytest.param(rt.FA([1.0, 2.0, 3.0], dtype=rt.float32), rt.float64, rt.FA([1.0, 2.0, 3.0], dtype=rt.float64)),
+        pytest.param(rt.FA([rt.int8.inv, 2, 3], dtype=rt.int8), rt.int32, rt.FA([rt.int32.inv, 2, 3], dtype=rt.int32)),
+        pytest.param(
+            rt.FA([rt.int32.inv, 2, 3], dtype=rt.int32), rt.float32, rt.FA([rt.float32.inv, 2.0, 3], dtype=rt.float32)
+        ),
+        pytest.param(rt.FA(["a", "b", "c"], dtype="|S1"), "|S5", rt.FA(["a", "b", "c"], dtype="|S5")),
+    ],
+)
+def test_possibly_convert(arr, dtype, exp):
+    assert_array_equal(possibly_convert(arr, dtype), exp, "Arrays were not equal")
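As a closing note on the parametrized `test_isfiltered` cases above: with the default base index of 1, code 0 is the Filtered bin, and for a mapped Categorical, codes absent from the mapping are also reported as filtered. A small sketch mirroring two of the test cases:

```python
from riptable import Categorical

# Default base_index=1: code 0 marks a filtered element, while codes
# 1..3 map onto the categories "a", "b", "c".
c = Categorical([0, 1, 0, 2, 3], ["a", "b", "c"])
c.isfiltered()  # [True, False, True, False, False]

# Mapped Categorical: codes 0 and 3 are absent from the mapping,
# so both are reported as filtered.
m = Categorical([0, 1, 0, 2, 3], {1: "a", 2: "b"})
m.isfiltered()  # [True, False, True, False, True]
```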