From e5ecee98e316da0ec9176286f5d1bfb5623bcbc7 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 31 Oct 2024 08:16:24 +0800 Subject: [PATCH] Improve the data type checking for 2-D arrays passed to the GMT C API (#3563) --- pygmt/clib/session.py | 65 ++++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 09eec6004bf..bdd8a13a1d3 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -82,7 +82,8 @@ REGISTRATIONS = ["GMT_GRID_NODE_REG", "GMT_GRID_PIXEL_REG"] -DTYPES = { +# Dictionary for mapping numpy dtypes to GMT data types. +DTYPES_NUMERIC = { np.int8: "GMT_CHAR", np.int16: "GMT_SHORT", np.int32: "GMT_INT", @@ -93,10 +94,14 @@ np.uint64: "GMT_ULONG", np.float32: "GMT_FLOAT", np.float64: "GMT_DOUBLE", + np.timedelta64: "GMT_LONG", +} +DTYPES_TEXT = { np.str_: "GMT_TEXT", np.datetime64: "GMT_DATETIME", - np.timedelta64: "GMT_LONG", } +DTYPES = DTYPES_NUMERIC | DTYPES_TEXT + # Dictionary for storing the values of GMT constants. GMT_CONSTANTS = {} @@ -879,63 +884,59 @@ def _parse_constant( integer_value = sum(self[part] for part in parts) return integer_value - def _check_dtype_and_dim(self, array, ndim): + def _check_dtype_and_dim(self, array: np.ndarray, ndim: int) -> int: """ Check that a numpy array has the given number of dimensions and is a valid data type. Parameters ---------- - array : numpy.ndarray + array The array to be tested. - ndim : int + ndim The desired number of array dimensions. Returns ------- - gmt_type : int + gmt_type The GMT constant value representing this data type. Raises ------ GMTInvalidInput - If the array has the wrong number of dimensions or - is an unsupported data type. + If the array has the wrong number of dimensions or is an unsupported data + type. Examples -------- - >>> import numpy as np >>> data = np.array([1, 2, 3], dtype="float64") - >>> with Session() as ses: - ... gmttype = ses._check_dtype_and_dim(data, ndim=1) - ... gmttype == ses["GMT_DOUBLE"] + >>> with Session() as lib: + ... gmttype = lib._check_dtype_and_dim(data, ndim=1) + ... gmttype == lib["GMT_DOUBLE"] True >>> data = np.ones((5, 2), dtype="float32") - >>> with Session() as ses: - ... gmttype = ses._check_dtype_and_dim(data, ndim=2) - ... gmttype == ses["GMT_FLOAT"] + >>> with Session() as lib: + ... gmttype = lib._check_dtype_and_dim(data, ndim=2) + ... gmttype == lib["GMT_FLOAT"] True """ - # Check that the array has the given number of dimensions + # Check that the array has the given number of dimensions. if array.ndim != ndim: - raise GMTInvalidInput( - f"Expected a numpy {ndim}-D array, got {array.ndim}-D." - ) + msg = f"Expected a numpy {ndim}-D array, got {array.ndim}-D." + raise GMTInvalidInput(msg) - # Check that the array has a valid/known data type - if array.dtype.type not in DTYPES: - try: - if array.dtype.type is np.object_: - # Try to convert unknown object type to np.datetime64 - array = array_to_datetime(array) - else: - raise ValueError - except ValueError as e: - raise GMTInvalidInput( - f"Unsupported numpy data type '{array.dtype.type}'." - ) from e - return self[DTYPES[array.dtype.type]] + # For 1-D arrays, try to convert unknown object type to np.datetime64. + if ndim == 1 and array.dtype.type is np.object_: + with contextlib.suppress(ValueError): + array = array_to_datetime(array) + + # 1-D arrays can be numeric or text, 2-D arrays can only be numeric. + valid_dtypes = DTYPES if ndim == 1 else DTYPES_NUMERIC + if (dtype := array.dtype.type) not in valid_dtypes: + msg = f"Unsupported numpy data type '{dtype}'." + raise GMTInvalidInput(msg) + return self[DTYPES[dtype]] def put_vector(self, dataset: ctp.c_void_p, column: int, vector: np.ndarray): r"""