diff --git a/src/auto_map.c b/src/auto_map.c index 039251e3..64b65fb5 100644 --- a/src/auto_map.c +++ b/src/auto_map.c @@ -74,20 +74,6 @@ typedef enum KeysArrayType{ KAT_DTas, } KeysArrayType; -NPY_DATETIMEUNIT -dt_unit_from_array(PyArrayObject* a) { - // This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dytpe is of the appropriate type. - PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref - PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta); - return dma->base; -} - -NPY_DATETIMEUNIT -dt_unit_from_scalar(PyDatetimeScalarObject* dts) { - // Based on convert_pyobject_to_datetime and related usage in datetime.c - PyArray_DatetimeMetaData* dma = &(dts->obmeta); - return dma->base; -} KeysArrayType at_to_kat(int array_t, PyArrayObject* a) { @@ -123,7 +109,7 @@ at_to_kat(int array_t, PyArrayObject* a) { return KAT_STRING; case NPY_DATETIME: { - NPY_DATETIMEUNIT dtu = dt_unit_from_array(a); + NPY_DATETIMEUNIT dtu = AK_dt_unit_from_array(a); switch (dtu) { case NPY_FR_Y: return KAT_DTY; @@ -685,9 +671,6 @@ lookup_hash_obj(FAMObject *self, PyObject *key, Py_hash_t hash) int result = -1; Py_hash_t h = 0; - // AK_DEBUG_MSG_OBJ("lookup_hash_obj", key); - // TODO: if key is a dt64, we need to get the units and compare to units before doing PyObject_RichCompareBool - while (1) { for (Py_ssize_t i = 0; i < SCAN; i++) { h = table[table_pos].hash; @@ -702,6 +685,16 @@ lookup_hash_obj(FAMObject *self, PyObject *key, Py_hash_t hash) if (guess == key) { // Hit. Object ID comparison return table_pos; } + + // if key is a dt64, only do PyObject_RichCompareBool if units match + if (PyArray_IsScalar(key, Datetime) && PyArray_IsScalar(guess, Datetime)) { + if (AK_dt_unit_from_scalar((PyDatetimeScalarObject *)key) + != AK_dt_unit_from_scalar((PyDatetimeScalarObject *)guess)) { + table_pos++; + continue; + } + } + result = PyObject_RichCompareBool(guess, key, Py_EQ); if (result < 0) { // Error. return -1; @@ -1030,10 +1023,9 @@ lookup_datetime(FAMObject *self, PyObject* key) { if (PyArray_IsScalar(key, Datetime)) { v = (npy_int64)PyArrayScalar_VAL(key, Datetime); // if we observe a NAT, we skip unit checks - // AK_DEBUG_MSG_OBJ("dt64 value", PyLong_FromLongLong(v)); if (v != NPY_DATETIME_NAT) { - NPY_DATETIMEUNIT key_unit = dt_unit_from_scalar( + NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_scalar( (PyDatetimeScalarObject *)key); if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) { return -1; @@ -1872,7 +1864,7 @@ fam_get_all(FAMObject *self, PyObject *key) { GET_ALL_FLEXIBLE(char, char_get_end_p, lookup_hash_string, string_to_hash, PyBytes_FromStringAndSize); break; case NPY_DATETIME: { - NPY_DATETIMEUNIT key_unit = dt_unit_from_array(key_array); + NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_array(key_array); if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) { PyErr_SetString(PyExc_KeyError, "datetime64 units do not match"); Py_DECREF(array); @@ -2070,7 +2062,7 @@ fam_get_any(FAMObject *self, PyObject *key) { GET_ANY_FLEXIBLE(char, char_get_end_p, lookup_hash_string, string_to_hash); break; case NPY_DATETIME: { - NPY_DATETIMEUNIT key_unit = dt_unit_from_array(key_array); + NPY_DATETIMEUNIT key_unit = AK_dt_unit_from_array(key_array); if (!kat_is_datetime_unit(self->keys_array_type, key_unit)) { return values; } diff --git a/src/tri_map.c b/src/tri_map.c index 48c41c47..6eb0c7d1 100644 --- a/src/tri_map.c +++ b/src/tri_map.c @@ -11,15 +11,6 @@ # include "tri_map.h" # include "utilities.h" -static inline NPY_DATETIMEUNIT -AK_dt_unit_from_array(PyArrayObject* a) { - // This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dtype is of the appropriate type. - PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref - PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta); - // PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyArray_DESCR(a)->c_metadata)->meta); - return dma->base; -} - typedef struct TriMapOne { Py_ssize_t from; // signed Py_ssize_t to; diff --git a/src/utilities.h b/src/utilities.h index 9b85a198..660681c0 100644 --- a/src/utilities.h +++ b/src/utilities.h @@ -9,6 +9,7 @@ # define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION # include "numpy/arrayobject.h" +# include "numpy/arrayscalars.h" static const size_t UCS4_SIZE = sizeof(Py_UCS4); @@ -318,4 +319,20 @@ AK_nonzero_1d(PyArrayObject* array) { return final; } +static inline NPY_DATETIMEUNIT +AK_dt_unit_from_array(PyArrayObject* a) { + // This is based on get_datetime_metadata_from_dtype in the NumPy source, but that function is private. This does not check that the dtype is of the appropriate type. + PyArray_Descr* dt = PyArray_DESCR(a); // borrowed ref + PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyDataType_C_METADATA(dt))->meta); + // PyArray_DatetimeMetaData* dma = &(((PyArray_DatetimeDTypeMetaData *)PyArray_DESCR(a)->c_metadata)->meta); + return dma->base; +} + +static inline NPY_DATETIMEUNIT +AK_dt_unit_from_scalar(PyDatetimeScalarObject* dts) { + // Based on convert_pyobject_to_datetime and related usage in datetime.c + PyArray_DatetimeMetaData* dma = &(dts->obmeta); + return dma->base; +} + #endif /* ARRAYKIT_SRC_UTILITIES_H_ */