From daf3c1a55d06c5de612b3a05e33dfada4b74319d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 13 Sep 2024 20:23:14 +0000 Subject: [PATCH] Switch ml_dtypes to use the Python Stable ABI. This allows us to build a single wheel using Python 3.9 and deploy it on all platforms, with the exception of 3.13t. The main downside of this change is it made the hash function for scalars slightly more expensive. --- ml_dtypes/_src/custom_float.h | 51 +++++++++++++++--------------- ml_dtypes/_src/intn_numpy.h | 42 +++++++++++++++++-------- setup.py | 59 ++++++++++++++++++++--------------- 3 files changed, 88 insertions(+), 64 deletions(-) diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 2786fde0..4a4a57c6 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -101,7 +101,7 @@ template Safe_PyObjectPtr PyCustomFloat_FromT(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); - Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); + Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type)); PyCustomFloat* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; @@ -213,7 +213,9 @@ PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x + y).release(); } - return PyArray_Type.tp_as_number->nb_add(a, b); + auto array_nb_add = + reinterpret_cast(PyType_GetSlot(&PyArray_Type, Py_nb_add)); + return array_nb_add(a, b); } template @@ -222,7 +224,9 @@ PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x - y).release(); } - return PyArray_Type.tp_as_number->nb_subtract(a, b); + auto array_nb_subtract = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_subtract)); + return array_nb_subtract(a, b); } template @@ -231,7 +235,9 @@ PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x * y).release(); } - return PyArray_Type.tp_as_number->nb_multiply(a, b); + auto array_nb_multiply = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_multiply)); + return array_nb_multiply(a, b); } template @@ -240,7 +246,9 @@ PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) { if (SafeCastToCustomFloat(a, &x) && SafeCastToCustomFloat(b, &y)) { return PyCustomFloat_FromT(x / y).release(); } - return PyArray_Type.tp_as_number->nb_true_divide(a, b); + auto array_nb_true_divide = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_true_divide)); + return array_nb_true_divide(a, b); } // Constructs a new PyCustomFloat. @@ -281,8 +289,7 @@ PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args, return PyCustomFloat_FromT(value).release(); } } - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(arg)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg)); return nullptr; } @@ -291,7 +298,9 @@ template PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!SafeCastToCustomFloat(a, &x) || !SafeCastToCustomFloat(b, &y)) { - return PyGenericArrType_Type.tp_richcompare(a, b, op); + auto generic_tp_richcompare = reinterpret_cast( + PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare)); + return generic_tp_richcompare(a, b, op); } bool result; switch (op) { @@ -340,25 +349,18 @@ PyObject* PyCustomFloat_Str(PyObject* self) { return PyUnicode_FromString(s.str().c_str()); } -// _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to -// handle the two possibilities. -// NOLINTNEXTLINE(clang-diagnostic-unused-function) -inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double), - PyObject* self, double value) { - return hash_double(self, value); -} - -// NOLINTNEXTLINE(clang-diagnostic-unused-function) -inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self, - double value) { - return hash_double(value); -} - // Hash function for PyCustomFloat. template Py_hash_t PyCustomFloat_Hash(PyObject* self) { T x = reinterpret_cast*>(self)->value; - return HashImpl(&_Py_HashDouble, self, static_cast(x)); + if (std::isnan(x)) { + // NaNs hash as the pointer hash of the object. + auto f = reinterpret_cast( + PyType_GetSlot(&PyBaseObject_Type, Py_tp_hash)); + return f(self); + } + Safe_PyObjectPtr f(PyFloat_FromDouble(static_cast(x))); + return PyObject_Hash(f.get()); } template @@ -428,8 +430,7 @@ template int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToCustomFloat(item, &x)) { - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(item)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item)); return -1; } memcpy(data, &x, sizeof(T)); diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index ccb4ed63..55a84d3a 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -100,7 +100,7 @@ template Safe_PyObjectPtr PyIntN_FromValue(T x) { PyTypeObject* type = reinterpret_cast(TypeDescriptor::type_ptr); - Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0)); + Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type)); PyIntN* p = reinterpret_cast*>(ref.get()); if (p) { p->value = x; @@ -214,16 +214,21 @@ PyObject* PyIntN_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { } } else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) { // Parse float from string, then cast to T. - PyObject* f = PyLong_FromUnicodeObject(arg, /*base=*/0); - if (PyErr_Occurred()) { + Safe_PyObjectPtr bytes(PyUnicode_AsUTF8String(arg)); + if (!bytes) { + return nullptr; + } + PyObject* f = + PyLong_FromString(PyBytes_AsString(bytes.get()), /*end=*/nullptr, + /*base=*/0); + if (!f) { return nullptr; } if (CastToIntN(f, &value)) { return PyIntN_FromValue(value).release(); } } - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(arg)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg)); return nullptr; } @@ -257,7 +262,9 @@ PyObject* PyIntN_nb_add(PyObject* a, PyObject* b) { if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x + y).release(); } - return PyArray_Type.tp_as_number->nb_add(a, b); + auto array_nb_add = + reinterpret_cast(PyType_GetSlot(&PyArray_Type, Py_nb_add)); + return array_nb_add(a, b); } template @@ -266,7 +273,9 @@ PyObject* PyIntN_nb_subtract(PyObject* a, PyObject* b) { if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x - y).release(); } - return PyArray_Type.tp_as_number->nb_subtract(a, b); + auto array_nb_subtract = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_subtract)); + return array_nb_subtract(a, b); } template @@ -275,7 +284,9 @@ PyObject* PyIntN_nb_multiply(PyObject* a, PyObject* b) { if (PyIntN_Value(a, &x) && PyIntN_Value(b, &y)) { return PyIntN_FromValue(x * y).release(); } - return PyArray_Type.tp_as_number->nb_multiply(a, b); + auto array_nb_multiply = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_multiply)); + return array_nb_multiply(a, b); } template @@ -292,7 +303,9 @@ PyObject* PyIntN_nb_remainder(PyObject* a, PyObject* b) { } return PyIntN_FromValue(v).release(); } - return PyArray_Type.tp_as_number->nb_remainder(a, b); + auto array_nb_remainder = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_remainder)); + return array_nb_remainder(a, b); } template @@ -309,7 +322,9 @@ PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) { } return PyIntN_FromValue(v).release(); } - return PyArray_Type.tp_as_number->nb_floor_divide(a, b); + auto array_nb_floor_divide = reinterpret_cast( + PyType_GetSlot(&PyArray_Type, Py_nb_floor_divide)); + return array_nb_floor_divide(a, b); } // Implementation of repr() for PyIntN. @@ -342,7 +357,9 @@ template PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) { T x, y; if (!PyIntN_Value(a, &x) || !PyIntN_Value(b, &y)) { - return PyGenericArrType_Type.tp_richcompare(a, b, op); + auto generic_tp_richcompare = reinterpret_cast( + PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare)); + return generic_tp_richcompare(a, b, op); } bool result; switch (op) { @@ -440,8 +457,7 @@ template int NPyIntN_SetItem(PyObject* item, void* data, void* arr) { T x; if (!CastToIntN(item, &x)) { - PyErr_Format(PyExc_TypeError, "expected number, got %s", - Py_TYPE(item)->tp_name); + PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item)); return -1; } memcpy(data, &x, sizeof(T)); diff --git a/setup.py b/setup.py index 8bbd112f..824399fe 100644 --- a/setup.py +++ b/setup.py @@ -16,45 +16,50 @@ import fnmatch import platform +import sysconfig + import numpy as np from setuptools import Extension from setuptools import setup from setuptools.command.build_py import build_py as build_py_orig +free_threading = sysconfig.get_config_var("Py_GIL_DISABLED") + if platform.system() == "Windows": - COMPILE_ARGS = [ - "/std:c++17", - "/DEIGEN_MPL2_ONLY", - "/EHsc", - "/bigobj", - ] + COMPILE_ARGS = [ + "/std:c++17", + "/DEIGEN_MPL2_ONLY", + "/EHsc", + "/bigobj", + ] else: - COMPILE_ARGS = [ - "-std=c++17", - "-DEIGEN_MPL2_ONLY", - "-fvisibility=hidden", - # -ftrapping-math is necessary because NumPy looks at floating point - # exception state to determine whether to emit, e.g., invalid value - # warnings. Without this setting, on Mac ARM we see spurious "invalid - # value" warnings when running the tests. - "-ftrapping-math", - ] + COMPILE_ARGS = [ + "-std=c++17", + "-DEIGEN_MPL2_ONLY", + "-fvisibility=hidden", + # -ftrapping-math is necessary because NumPy looks at floating point + # exception state to determine whether to emit, e.g., invalid value + # warnings. Without this setting, on Mac ARM we see spurious "invalid + # value" warnings when running the tests. + "-ftrapping-math", + ] + if not free_threading: + COMPILE_ARGS.append("-DPy_LIMITED_API=0x03090000") exclude = ["third_party*"] class build_py(build_py_orig): # pylint: disable=invalid-name - def find_package_modules(self, package, package_dir): - modules = super().find_package_modules(package, package_dir) - return [ # pylint: disable=g-complex-comprehension - (pkg, mod, file) - for (pkg, mod, file) in modules - if not any( - fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) - for pattern in exclude - ) - ] + def find_package_modules(self, package, package_dir): + modules = super().find_package_modules(package, package_dir) + return [ # pylint: disable=g-complex-comprehension + (pkg, mod, file) + for (pkg, mod, file) in modules + if not any( + fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) for pattern in exclude + ) + ] setup( @@ -71,7 +76,9 @@ def find_package_modules(self, package, package_dir): np.get_include(), ], extra_compile_args=COMPILE_ARGS, + py_limited_api=not free_threading, ) ], cmdclass={"build_py": build_py}, + options={} if free_threading else {"bdist_wheel": {"py_limited_api": "cp39"}}, )