diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 2786fde0..4a4a57c6 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -101,7 +101,7 @@ template Safe_PyObjectPtr PyCustomFloat_FromT(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); - Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); + Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type)); PyCustomFloat* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; @@ -213,7 +213,9 @@ PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x + y).release(); } - return PyArray_Type.tp_as_number->nb_add(a, b); + auto array_nb_add = + reinterpret_cast(PyType_GetSlot(&PyArray_Type, Py_nb_add)); + return array_nb_add(a, b); } template @@ -222,7 +224,9 @@ PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x - y).release(); } - return PyArray_Type.tp_as_number->nb_subtract(a, b); + auto array_nb_subtract = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_subtract)); + return array_nb_subtract(a, b); } template @@ -231,7 +235,9 @@ PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x * y).release(); } - return PyArray_Type.tp_as_number->nb_multiply(a, b); + auto array_nb_multiply = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_multiply)); + return array_nb_multiply(a, b); } template @@ -240,7 +246,9 @@ PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x / y).release(); } - return PyArray_Type.tp_as_number->nb_true_divide(a, b); + auto array_nb_true_divide = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_true_divide)); + return array_nb_true_divide(a, b); } // Constructs a new PyCustomFloat. @@ -281,8 +289,7 @@ PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args, return PyCustomFloat_FromT(value).release(); } } - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(arg)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg)); return nullptr; } @@ -291,7 +298,9 @@ template PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!SafeCastToCustomFloat(a, &x) || !SafeCastToCustomFloat(b, &y)) { - return PyGenericArrType_Type.tp_richcompare(a, b, op); + auto generic_tp_richcompare = reinterpret_cast( + PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare)); + return generic_tp_richcompare(a, b, op); } bool result; switch (op) { @@ -340,25 +349,18 @@ PyObject* PyCustomFloat_Str(PyObject* self) { return PyUnicode_FromString(s.str().c_str()); } -// _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to -// handle the two possibilities. -// NOLINTNEXTLINE(clang-diagnostic-unused-function) -inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double), - PyObject* self, double value) { - return hash_double(self, value); -} - -// NOLINTNEXTLINE(clang-diagnostic-unused-function) -inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self, - double value) { - return hash_double(value); -} - // Hash function for PyCustomFloat. template Py_hash_t PyCustomFloat_Hash(PyObject* self) { T x = reinterpret_cast*>(self)->value; - return HashImpl(&_Py_HashDouble, self, static_cast(x)); + if (std::isnan(x)) { + // NaNs hash as the pointer hash of the object. + auto f = reinterpret_cast( + PyType_GetSlot(&PyBaseObject_Type, Py_tp_hash)); + return f(self); + } + Safe_PyObjectPtr f(PyFloat_FromDouble(static_cast(x))); + return PyObject_Hash(f.get()); } template @@ -428,8 +430,7 @@ template int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToCustomFloat(item, &x)) { - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(item)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item)); return -1; } memcpy(data, &x, sizeof(T)); diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index ccb4ed63..55a84d3a 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -100,7 +100,7 @@ template Safe_PyObjectPtr PyIntN_FromValue(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); - Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); + Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type)); PyIntN* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; @@ -214,16 +214,21 @@ PyObject* PyIntN_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { } } else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) { // Parse float from string, then cast to T. - PyObject* f = PyLong_FromUnicodeObject(arg, /*base=*/0); - if (PyErr_Occurred()) { + Safe_PyObjectPtr bytes(PyUnicode_AsUTF8String(arg)); + if (!bytes) { + return nullptr; + } + PyObject* f = + PyLong_FromString(PyBytes_AsString(bytes.get()), /*end=*/nullptr, + /*base=*/0); + if (!f) { return nullptr; } if (CastToIntN(f, &value)) { return PyIntN_FromValue(value).release(); } } - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(arg)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg)); return nullptr; } @@ -257,7 +262,9 @@ PyObject* PyIntN_nb_add(PyObject* a, PyObject* b) { if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x + y).release(); } - return PyArray_Type.tp_as_number->nb_add(a, b); + auto array_nb_add = + reinterpret_cast(PyType_GetSlot(&PyArray_Type, Py_nb_add)); + return array_nb_add(a, b); } template @@ -266,7 +273,9 @@ PyObject* PyIntN_nb_subtract(PyObject* a, PyObject* b) { if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x - y).release(); } - return PyArray_Type.tp_as_number->nb_subtract(a, b); + auto array_nb_subtract = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_subtract)); + return array_nb_subtract(a, b); } template @@ -275,7 +284,9 @@ PyObject* PyIntN_nb_multiply(PyObject* a, PyObject* b) { if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x * y).release(); } - return PyArray_Type.tp_as_number->nb_multiply(a, b); + auto array_nb_multiply = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_multiply)); + return array_nb_multiply(a, b); } template @@ -292,7 +303,9 @@ PyObject* PyIntN_nb_remainder(PyObject* a, PyObject* b) { } return PyIntN_FromValue(v).release(); } - return PyArray_Type.tp_as_number->nb_remainder(a, b); + auto array_nb_remainder = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_remainder)); + return array_nb_remainder(a, b); } template @@ -309,7 +322,9 @@ PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) { } return PyIntN_FromValue(v).release(); } - return PyArray_Type.tp_as_number->nb_floor_divide(a, b); + auto array_nb_floor_divide = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_floor_divide)); + return array_nb_floor_divide(a, b); } // Implementation of repr() for PyIntN. @@ -342,7 +357,9 @@ template PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!PyIntN_Value(a, &x) || !PyIntN_Value(b, &y)) { - return PyGenericArrType_Type.tp_richcompare(a, b, op); + auto generic_tp_richcompare = reinterpret_cast( + PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare)); + return generic_tp_richcompare(a, b, op); } bool result; switch (op) { @@ -440,8 +457,7 @@ template int NPyIntN_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToIntN(item, &x)) { - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(item)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item)); return -1; } memcpy(data, &x, sizeof(T)); diff --git a/setup.py b/setup.py index 8bbd112f..824399fe 100644 --- a/setup.py +++ b/setup.py @@ -16,45 +16,50 @@ import fnmatch import platform +import sysconfig + import numpy as np from setuptools import Extension from setuptools import setup from setuptools.command.build_py import build_py as build_py_orig +free_threading = sysconfig.get_config_var("Py_GIL_DISABLED") + if platform.system() == "Windows": - COMPILE_ARGS = [ - "/std:c++17", - "/DEIGEN_MPL2_ONLY", - "/EHsc", - "/bigobj", - ] + COMPILE_ARGS = [ + "/std:c++17", + "/DEIGEN_MPL2_ONLY", + "/EHsc", + "/bigobj", + ] else: - COMPILE_ARGS = [ - "-std=c++17", - "-DEIGEN_MPL2_ONLY", - "-fvisibility=hidden", - # -ftrapping-math is necessary because NumPy looks at floating point - # exception state to determine whether to emit, e.g., invalid value - # warnings. Without this setting, on Mac ARM we see spurious "invalid - # value" warnings when running the tests. - "-ftrapping-math", - ] + COMPILE_ARGS = [ + "-std=c++17", + "-DEIGEN_MPL2_ONLY", + "-fvisibility=hidden", + # -ftrapping-math is necessary because NumPy looks at floating point + # exception state to determine whether to emit, e.g., invalid value + # warnings. Without this setting, on Mac ARM we see spurious "invalid + # value" warnings when running the tests. + "-ftrapping-math", + ] + if not free_threading: + COMPILE_ARGS.append("-DPy_LIMITED_API=0x03090000") exclude = ["third_party*"] class build_py(build_py_orig): # pylint: disable=invalid-name - def find_package_modules(self, package, package_dir): - modules = super().find_package_modules(package, package_dir) - return [ # pylint: disable=g-complex-comprehension - (pkg, mod, file) - for (pkg, mod, file) in modules - if not any( - fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) - for pattern in exclude - ) - ] + def find_package_modules(self, package, package_dir): + modules = super().find_package_modules(package, package_dir) + return [ # pylint: disable=g-complex-comprehension + (pkg, mod, file) + for (pkg, mod, file) in modules + if not any( + fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) for pattern in exclude + ) + ] setup( @@ -71,7 +76,9 @@ def find_package_modules(self, package, package_dir): np.get_include(), ], extra_compile_args=COMPILE_ARGS, + py_limited_api=not free_threading, ) ], cmdclass={"build_py": build_py}, + options={} if free_threading else {"bdist_wheel": {"py_limited_api": "cp39"}}, )