Skip to content

Commit

Permalink
Reimplement factorial() without GMP
Browse files Browse the repository at this point in the history
  • Loading branch information
skirpichev committed Dec 30, 2024
1 parent 2641625 commit 911cb7c
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 103 deletions.
71 changes: 71 additions & 0 deletions _gmp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import ctypes


def multiplicity(n, p):
"""Return the power of the prime number p in the factorization of n!"""
import gmp

mpz = gmp.mpz
if p > n:
return mpz(0)
if p > n//2:
return mpz(1)
q, m = n, mpz(0)
while q >= p:
q //= p
m += q
return m


def primes(n):
"""Generate a list of the prime numbers [2, 3, ... m], m <= n."""
import gmp

mpz = gmp.mpz
isqrt = gmp.isqrt
n = n + mpz(1)
sieve = [mpz(_) for _ in range(n)]
sieve[:2] = [mpz(0), mpz(0)]
for i in range(2, isqrt(n) + 1):
if sieve[i]:
for j in range(i**2, n, i):
sieve[j] = mpz(0)
# Filter out the composites, which have been replaced by 0's
return [p for p in sieve if p]


def powproduct(ns):
import gmp

mpz = gmp.mpz
if not ns:
return mpz(1)
units = mpz(1)
multi = []
for base, exp in ns:
if exp == 0:
continue
elif exp == 1:
units *= base
else:
if exp % 2:
units *= base
multi.append((base, exp//2))
return units * powproduct(multi)**2


def factorial(n, /):
"""
Find n!.
Raise a ValueError if n is negative or non-integral.
"""
import gmp

mpz = gmp.mpz
n = mpz(n)
if n < 0:
raise ValueError("factorial() not defined for negative values")
if ctypes.c_long(n).value != n:
raise OverflowError("factorial() argument should not exceed LONG_MAX")
return powproduct((p, multiplicity(n, p)) for p in primes(n))
148 changes: 46 additions & 102 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static struct {
} gmp_tracker;

static void *
gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
gmp_allocate_function(size_t size)
{
if (gmp_tracker.size >= gmp_tracker.alloc) {
void **tmp = gmp_tracker.ptrs;
Expand All @@ -39,31 +39,13 @@ gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
goto err;
}
}
if (!ptr) {
void *ret = malloc(new_size);

if (!ret) {
goto err;
}
gmp_tracker.ptrs[gmp_tracker.size] = ret;
gmp_tracker.size++;
return ret;
}

size_t i = gmp_tracker.size - 1;

for (;; i--) {
if (gmp_tracker.ptrs[i] == ptr) {
break;
}
}

void *ret = realloc(ptr, new_size);
void *ret = malloc(size);

if (!ret) {
goto err;
}
gmp_tracker.ptrs[i] = ret;
gmp_tracker.ptrs[gmp_tracker.size] = ret;
gmp_tracker.size++;
return ret;
err:
for (size_t i = 0; i < gmp_tracker.size; i++) {
Expand All @@ -79,12 +61,6 @@ gmp_reallocate_function(void *ptr, size_t old_size, size_t new_size)
longjmp(gmp_env, 1);
}

static void *
gmp_allocate_function(size_t size)
{
return gmp_reallocate_function(NULL, 0, size);
}

static void
gmp_free_function(void *ptr, size_t size)
{
Expand Down Expand Up @@ -3204,80 +3180,13 @@ gmp_isqrt(PyObject *Py_UNUSED(module), PyObject *arg)
return (PyObject *)res;
}

static PyObject *
gmp_factorial(PyObject *Py_UNUSED(module), PyObject *arg)
{
MPZ_Object *x, *res = NULL;

if (MPZ_Check(arg)) {
x = (MPZ_Object *)arg;
Py_INCREF(x);
}
else if (PyLong_Check(arg)) {
x = MPZ_from_int(arg);
if (!x) {
/* LCOV_EXCL_START */
goto end;
/* LCOV_EXCL_STOP */
}
}
else {
PyErr_SetString(PyExc_TypeError,
"factorial() argument must be an integer");
return NULL;
}

__mpz_struct tmp;

tmp._mp_d = x->digits;
tmp._mp_size = (x->negative ? -1 : 1) * x->size;
if (x->negative) {
PyErr_SetString(PyExc_ValueError,
"factorial() not defined for negative values");
goto end;
}
if (!mpz_fits_ulong_p(&tmp)) {
PyErr_Format(PyExc_OverflowError,
"factorial() argument should not exceed %ld", LONG_MAX);
goto end;
}

unsigned long n = mpz_get_ui(&tmp);

if (CHECK_NO_MEM_LEAK) {
mpz_init(&tmp);
mpz_fac_ui(&tmp, n);
}
else {
/* LCOV_EXCL_START */
Py_DECREF(x);
return PyErr_NoMemory();
/* LCOV_EXCL_STOP */
}
res = MPZ_new(tmp._mp_size, 0);
if (!res) {
/* LCOV_EXCL_START */
mpz_clear(&tmp);
goto end;
/* LCOV_EXCL_STOP */
}
mpn_copyi(res->digits, tmp._mp_d, res->size);
mpz_clear(&tmp);
end:
Py_XDECREF(x);
return (PyObject *)res;
}

static PyMethodDef functions[] = {
{"gcd", (PyCFunction)gmp_gcd, METH_FASTCALL,
("gcd($module, /, *integers)\n--\n\n"
"Greatest Common Divisor.")},
{"isqrt", gmp_isqrt, METH_O,
("isqrt($module, n, /)\n--\n\n"
"Return the integer part of the square root of the input.")},
{"factorial", gmp_factorial, METH_O,
("factorial($module, n, /)\n--\n\n"
"Find n!.\n\nRaise a ValueError if x is negative or non-integral.")},
{"_from_bytes", _from_bytes, METH_O, NULL},
{NULL} /* sentinel */
};
Expand Down Expand Up @@ -3310,8 +3219,7 @@ static PyStructSequence_Desc gmp_info_desc = {
PyMODINIT_FUNC
PyInit_gmp(void)
{
mp_set_memory_functions(gmp_allocate_function, gmp_reallocate_function,
gmp_free_function);
mp_set_memory_functions(gmp_allocate_function, NULL, gmp_free_function);
#if !defined(PYPY_VERSION)
/* Query parameters of Python’s internal representation of integers. */
const PyLongLayout *layout = PyLong_GetNativeLayout();
Expand Down Expand Up @@ -3396,38 +3304,74 @@ PyInit_gmp(void)
return NULL;
/* LCOV_EXCL_STOP */
}
Py_DECREF(gmp_fractions);

PyObject *mname = PyUnicode_FromString("gmp");

if (!mname) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(gmp_fractions);
Py_DECREF(mpq);
return NULL;
/* LCOV_EXCL_STOP */
}
if (PyObject_SetAttrString(mpq, "__module__", mname) < 0) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(gmp_fractions);
Py_DECREF(mpq);
Py_DECREF(mname);
return NULL;
/* LCOV_EXCL_STOP */
}
Py_DECREF(mname);
if (PyModule_AddType(m, (PyTypeObject *)mpq) < 0) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(gmp_fractions);
Py_DECREF(mpq);
Py_DECREF(mname);
return NULL;
/* LCOV_EXCL_STOP */
}
Py_DECREF(gmp_fractions);
Py_DECREF(mpq);

PyObject *gmp_utils = PyImport_ImportModule("_gmp_utils");

if (!gmp_utils) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(mname);
return NULL;
/* LCOV_EXCL_STOP */
}

PyObject *factorial = PyObject_GetAttrString(gmp_utils, "factorial");

if (!factorial) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(gmp_utils);
Py_DECREF(mname);
return NULL;
/* LCOV_EXCL_STOP */
}
Py_DECREF(gmp_utils);
if (PyObject_SetAttrString(factorial, "__module__", mname) < 0) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(factorial);
Py_DECREF(mname);
return NULL;
/* LCOV_EXCL_STOP */
}
Py_DECREF(mname);
if (PyModule_AddObject(m, "factorial", factorial) < 0) {
/* LCOV_EXCL_START */
Py_DECREF(ns);
Py_DECREF(factorial);
return NULL;
/* LCOV_EXCL_STOP */
}
Py_DECREF(factorial);

PyObject *numbers = PyImport_ImportModule("numbers");

if (!numbers) {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Documentation = "https://python-gmp.readthedocs.io/en/latest/"

[tool.setuptools]
ext-modules = [{name = "gmp", sources = ["main.c"], libraries = ["gmp"], include-dirs = [".local/include"], library-dirs = [".local/lib"]}]
py-modules = ["_gmp_fractions"]
py-modules = ["_gmp_fractions", "_gmp_utils"]

[tool.setuptools.dynamic]
version = {attr = "setuptools_scm.get_version"}
Expand Down
23 changes: 23 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import math
import platform
import resource

import pytest
from gmp import factorial, gcd, isqrt, mpz
Expand All @@ -20,6 +22,27 @@ def test_factorial(x):
assert factorial(mx) == factorial(x) == r


@pytest.mark.skipif(platform.system() != "Linux",
reason="FIXME: setrlimit fails with ValueError on MacOS")
@pytest.mark.skipif(platform.python_implementation() == "PyPy",
reason="XXX: bug in PyNumber_ToBase()?")
def test_factorial_outofmemory():
import random

for _ in range(100):
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (1024*64*1024, hard))
a = random.randint(12811, 24984)
a = mpz(a)
while True:
try:
factorial(a)
a *= 2
except MemoryError:
break
resource.setrlimit(resource.RLIMIT_AS, (soft, hard))


@given(integers(), integers())
def test_gcd(x, y):
mx = mpz(x)
Expand Down

0 comments on commit 911cb7c

Please sign in to comment.