diff --git a/_gmp_utils.py b/_gmp_utils.py
new file mode 100644
index 0000000..eb2ba94
--- /dev/null
+++ b/_gmp_utils.py
@@ -0,0 +1,73 @@
+import ctypes
+
+
+def multiplicity(n, p):
+    """Return the power of the prime number p in the factorization of n!"""
+    import gmp
+
+    mpz = gmp.mpz
+    if p > n:
+        return mpz(0)
+    if p > n//2:
+        return mpz(1)
+    q, m = n, mpz(0)
+    while q >= p:
+        q //= p
+        m += q
+    return m
+
+
+def primes(n):
+    """Generate a list of the prime numbers [2, 3, ... m], m <= n."""
+    import gmp
+
+    mpz = gmp.mpz
+    isqrt = gmp.isqrt
+    n = n + mpz(1)
+    sieve = [mpz(_) for _ in range(n)]
+    sieve[:2] = [mpz(0), mpz(0)]
+    for i in range(2, isqrt(n) + 1):
+        if sieve[i]:
+            for j in range(i**2, n, i):
+                sieve[j] = mpz(0)
+    # Filter out the composites, which have been replaced by 0's
+    return [p for p in sieve if p]
+
+
+def powproduct(ns):
+    """Return the product of base**exp over the (base, exp) pairs in ns."""
+    import gmp
+
+    mpz = gmp.mpz
+    if not ns:
+        return mpz(1)
+    units = mpz(1)
+    multi = []
+    for base, exp in ns:
+        if exp == 0:
+            continue
+        elif exp == 1:
+            units *= base
+        else:
+            if exp % 2:
+                units *= base
+            multi.append((base, exp//2))
+    return units * powproduct(multi)**2
+
+
+def factorial(n, /):
+    """
+    Find n!.
+
+    Raise a ValueError if n is negative or non-integral.
+    Raise an OverflowError if n does not fit in a C long.
+    """
+    import gmp
+
+    mpz = gmp.mpz
+    n = mpz(n)
+    if n < 0:
+        raise ValueError("factorial() not defined for negative values")
+    if ctypes.c_long(n).value != n:
+        raise OverflowError("factorial() argument should not exceed LONG_MAX")
+    return powproduct((p, multiplicity(n, p)) for p in primes(n))
diff --git a/main.c b/main.c
index ade3cfd..dd43df7 100644
--- a/main.c
+++ b/main.c
@@ -17,7 +17,7 @@ static struct {
 } gmp_tracker;
 
 static void *
-gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
+gmp_allocate_function(size_t size)
 {
     if (gmp_tracker.size >= gmp_tracker.alloc) {
         void **tmp = gmp_tracker.ptrs;
@@ -39,31 +39,13 @@ gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
             goto err;
         }
     }
-    if (!ptr) {
-        void *ret = malloc(new_size);
-
-        if (!ret) {
-            goto err;
-        }
-        gmp_tracker.ptrs[gmp_tracker.size] = ret;
-        gmp_tracker.size++;
-        return ret;
-    }
-
-    size_t i = gmp_tracker.size - 1;
-
-    for (;; i--) {
-        if (gmp_tracker.ptrs[i] == ptr) {
-            break;
-        }
-    }
-
-    void *ret = realloc(ptr, new_size);
+    void *ret = malloc(size);
 
     if (!ret) {
         goto err;
     }
-    gmp_tracker.ptrs[i] = ret;
+    gmp_tracker.ptrs[gmp_tracker.size] = ret;
+    gmp_tracker.size++;
     return ret;
 err:
     for (size_t i = 0; i < gmp_tracker.size; i++) {
@@ -79,12 +61,6 @@ gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
     longjmp(gmp_env, 1);
 }
 
-static void *
-gmp_allocate_function(size_t size)
-{
-    return gmp_reallocate_function(NULL, 0, size);
-}
-
 static void
 gmp_free_function(void *ptr, size_t size)
 {
@@ -3204,70 +3180,6 @@ gmp_isqrt(PyObject *Py_UNUSED(module), PyObject *arg)
     return (PyObject *)res;
 }
 
-static PyObject *
-gmp_factorial(PyObject *Py_UNUSED(module), PyObject *arg)
-{
-    MPZ_Object *x, *res = NULL;
-
-    if (MPZ_Check(arg)) {
-        x = (MPZ_Object *)arg;
-        Py_INCREF(x);
-    }
-    else if (PyLong_Check(arg)) {
-        x = MPZ_from_int(arg);
-        if (!x) {
-            /* LCOV_EXCL_START */
-            goto end;
-            /* LCOV_EXCL_STOP */
-        }
-    }
-    else {
-        PyErr_SetString(PyExc_TypeError,
-                        "factorial() argument must be an integer");
-        return NULL;
-    }
-
-    __mpz_struct tmp;
-
-    tmp._mp_d = x->digits;
-    tmp._mp_size = (x->negative ? -1 : 1) * x->size;
-    if (x->negative) {
-        PyErr_SetString(PyExc_ValueError,
-                        "factorial() not defined for negative values");
-        goto end;
-    }
-    if (!mpz_fits_ulong_p(&tmp)) {
-        PyErr_Format(PyExc_OverflowError,
-                     "factorial() argument should not exceed %ld", LONG_MAX);
-        goto end;
-    }
-
-    unsigned long n = mpz_get_ui(&tmp);
-
-    if (CHECK_NO_MEM_LEAK) {
-        mpz_init(&tmp);
-        mpz_fac_ui(&tmp, n);
-    }
-    else {
-        /* LCOV_EXCL_START */
-        Py_DECREF(x);
-        return PyErr_NoMemory();
-        /* LCOV_EXCL_STOP */
-    }
-    res = MPZ_new(tmp._mp_size, 0);
-    if (!res) {
-        /* LCOV_EXCL_START */
-        mpz_clear(&tmp);
-        goto end;
-        /* LCOV_EXCL_STOP */
-    }
-    mpn_copyi(res->digits, tmp._mp_d, res->size);
-    mpz_clear(&tmp);
-end:
-    Py_XDECREF(x);
-    return (PyObject *)res;
-}
-
 static PyMethodDef functions[] = {
     {"gcd", (PyCFunction)gmp_gcd, METH_FASTCALL,
      ("gcd($module, /, *integers)\n--\n\n"
@@ -3275,9 +3187,6 @@ static PyMethodDef functions[] = {
     {"isqrt", gmp_isqrt, METH_O,
      ("isqrt($module, n, /)\n--\n\n"
       "Return the integer part of the square root of the input.")},
-    {"factorial", gmp_factorial, METH_O,
-     ("factorial($module, n, /)\n--\n\n"
-      "Find n!.\n\nRaise a ValueError if x is negative or non-integral.")},
     {"_from_bytes", _from_bytes, METH_O, NULL},
     {NULL} /* sentinel */
 };
@@ -3310,8 +3219,7 @@ static PyStructSequence_Desc gmp_info_desc = {
 PyMODINIT_FUNC
 PyInit_gmp(void)
 {
-    mp_set_memory_functions(gmp_allocate_function, gmp_reallocate_function,
-                            gmp_free_function);
+    mp_set_memory_functions(gmp_allocate_function, NULL, gmp_free_function);
 #if !defined(PYPY_VERSION)
     /* Query parameters of Python’s internal representation of integers. */
     const PyLongLayout *layout = PyLong_GetNativeLayout();
@@ -3396,13 +3304,13 @@ PyInit_gmp(void)
         return NULL;
         /* LCOV_EXCL_STOP */
     }
+    Py_DECREF(gmp_fractions);
 
     PyObject *mname = PyUnicode_FromString("gmp");
 
     if (!mname) {
         /* LCOV_EXCL_START */
         Py_DECREF(ns);
-        Py_DECREF(gmp_fractions);
         Py_DECREF(mpq);
         return NULL;
         /* LCOV_EXCL_STOP */
@@ -3410,24 +3318,60 @@ PyInit_gmp(void)
     if (PyObject_SetAttrString(mpq, "__module__", mname) < 0) {
         /* LCOV_EXCL_START */
         Py_DECREF(ns);
-        Py_DECREF(gmp_fractions);
         Py_DECREF(mpq);
         Py_DECREF(mname);
         return NULL;
         /* LCOV_EXCL_STOP */
     }
-    Py_DECREF(mname);
     if (PyModule_AddType(m, (PyTypeObject *)mpq) < 0) {
         /* LCOV_EXCL_START */
         Py_DECREF(ns);
-        Py_DECREF(gmp_fractions);
         Py_DECREF(mpq);
+        Py_DECREF(mname);
         return NULL;
         /* LCOV_EXCL_STOP */
     }
-    Py_DECREF(gmp_fractions);
     Py_DECREF(mpq);
 
+    PyObject *gmp_utils = PyImport_ImportModule("_gmp_utils");
+
+    if (!gmp_utils) {
+        /* LCOV_EXCL_START */
+        Py_DECREF(ns);
+        Py_DECREF(mname);
+        return NULL;
+        /* LCOV_EXCL_STOP */
+    }
+
+    PyObject *factorial = PyObject_GetAttrString(gmp_utils, "factorial");
+
+    if (!factorial) {
+        /* LCOV_EXCL_START */
+        Py_DECREF(ns);
+        Py_DECREF(gmp_utils);
+        Py_DECREF(mname);
+        return NULL;
+        /* LCOV_EXCL_STOP */
+    }
+    Py_DECREF(gmp_utils);
+    if (PyObject_SetAttrString(factorial, "__module__", mname) < 0) {
+        /* LCOV_EXCL_START */
+        Py_DECREF(ns);
+        Py_DECREF(factorial);
+        Py_DECREF(mname);
+        return NULL;
+        /* LCOV_EXCL_STOP */
+    }
+    Py_DECREF(mname);
+    if (PyModule_AddObject(m, "factorial", factorial) < 0) {
+        /* LCOV_EXCL_START */
+        Py_DECREF(ns);
+        Py_DECREF(factorial);
+        return NULL;
+        /* LCOV_EXCL_STOP */
+    }
+    /* No Py_DECREF(factorial): PyModule_AddObject() stole the reference. */
+
     PyObject *numbers = PyImport_ImportModule("numbers");
 
     if (!numbers) {
diff --git a/pyproject.toml b/pyproject.toml
index 9becb93..281da07 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ Documentation = "https://python-gmp.readthedocs.io/en/latest/"
 
 [tool.setuptools]
 ext-modules = [{name = "gmp", sources = ["main.c"], libraries = ["gmp"], include-dirs = [".local/include"], library-dirs = [".local/lib"]}]
-py-modules = ["_gmp_fractions"]
+py-modules = ["_gmp_fractions", "_gmp_utils"]
 
 [tool.setuptools.dynamic]
 version = {attr = "setuptools_scm.get_version"}
diff --git a/tests/test_functions.py b/tests/test_functions.py
index 8335bc2..b2f53ba 100644
--- a/tests/test_functions.py
+++ b/tests/test_functions.py
@@ -1,4 +1,5 @@
 import math
+import platform
 
 import pytest
 from gmp import factorial, gcd, isqrt, mpz
@@ -20,6 +21,33 @@ def test_factorial(x):
     assert factorial(mx) == factorial(x) == r
 
 
+@pytest.mark.skipif(platform.system() != "Linux",
+                    reason="FIXME: setrlimit fails with ValueError on MacOS")
+@pytest.mark.skipif(platform.python_implementation() == "PyPy",
+                    reason="XXX: bug in PyNumber_ToBase()?")
+def test_factorial_outofmemory():
+    import random
+    # POSIX-only module: importing it at module level breaks collection
+    # of this whole file on platforms without it (e.g. Windows).
+    import resource
+
+    for _ in range(100):
+        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
+        resource.setrlimit(resource.RLIMIT_AS, (1024*64*1024, hard))
+        try:
+            a = random.randint(12811, 24984)
+            a = mpz(a)
+            while True:
+                try:
+                    factorial(a)
+                    a *= 2
+                except MemoryError:
+                    break
+        finally:
+            # Restore the limit even if an unexpected exception escapes.
+            resource.setrlimit(resource.RLIMIT_AS, (soft, hard))
+
+
 @given(integers(), integers())
 def test_gcd(x, y):
     mx = mpz(x)