Skip to content

Commit

Permalink
Python formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ejaasaari committed Feb 28, 2025
1 parent 0857e1b commit 1108973
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 180 deletions.
4 changes: 2 additions & 2 deletions .clangd
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
CompileFlags:
# Treat code as C++, use C++17 standard, enable more warnings.
Add: [-xc++, -std=c++17, -Wall, -Wno-missing-prototypes]
# Treat code as C++, use C++14 standard, enable more warnings.
Add: [-xc++, -std=c++14, -Wall, -Wno-missing-prototypes]
43 changes: 20 additions & 23 deletions cpp/mrptmodule.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
/*
* This file wraps the C++11 Mrpt code to an extension module compatible with
* Python 3.
*/
#define PY_SSIZE_T_CLEAN

#include <sys/stat.h>
#include <sys/types.h>
Expand Down Expand Up @@ -514,25 +511,25 @@ static PyMethodDef MrptMethods[] = {
};

static PyTypeObject MrptIndexType = {
PyVarObject_HEAD_INIT(NULL, 0) "mrpt.MrptIndex", /*tp_name*/
sizeof(mrptIndex), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor)mrpt_dealloc, /*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_compare*/
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash */
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT, /*tp_flags*/
PyVarObject_HEAD_INIT(NULL, 0) "mrpt.MrptIndex", /* tp_name*/
sizeof(mrptIndex), /* tp_basicsize*/
0, /* tp_itemsize*/
(destructor)mrpt_dealloc, /* tp_dealloc*/
0, /* tp_print*/
0, /* tp_getattr*/
0, /* tp_setattr*/
0, /* tp_compare*/
0, /* tp_repr*/
0, /* tp_as_number*/
0, /* tp_as_sequence*/
0, /* tp_as_mapping*/
0, /* tp_hash */
0, /* tp_call*/
0, /* tp_str*/
0, /* tp_getattro*/
0, /* tp_setattro*/
0, /* tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
"Mrpt index object", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
Expand Down
96 changes: 73 additions & 23 deletions mrpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class MRPTIndex(object):
"""
An MRPT index object
"""

def __init__(self, data, shape=None, mmap=False):
"""
Initializes an MRPT index object.
Expand All @@ -21,18 +22,20 @@ def __init__(self, data, shape=None, mmap=False):
raise ValueError("The data matrix should be non-empty and two-dimensional")
if data.dtype != np.float32:
raise ValueError("The data matrix should have type float32")
if not data.flags['C_CONTIGUOUS'] or not data.flags['ALIGNED']:
if not data.flags["C_CONTIGUOUS"] or not data.flags["ALIGNED"]:
raise ValueError("The data matrix has to be C_CONTIGUOUS and ALIGNED")
n_samples, dim = data.shape
elif isinstance(data, str):
if not isinstance(shape, tuple) or len(shape) != 2:
raise ValueError("You must specify the shape of the data as a tuple (N, dim) "
"when loading data from a binary file")
raise ValueError(
"You must specify the shape of the data as a tuple (N, dim) "
"when loading data from a binary file"
)
n_samples, dim = shape
elif data is not None:
raise ValueError("Data must be either an ndarray or a filepath")

if mmap and os_name == 'nt':
if mmap and os_name == "nt":
raise ValueError("Memory mapping is not available on Windows")

if data is not None:
Expand All @@ -43,14 +46,15 @@ def __init__(self, data, shape=None, mmap=False):
self.autotuned = False

def _compute_sparsity(self, projection_sparsity):
if projection_sparsity == 'auto':
return 1. / np.sqrt(self.dim)
elif projection_sparsity is None:
if projection_sparsity == "auto":
return 1.0 / np.sqrt(self.dim)
if projection_sparsity is None:
return 1
elif not 0 < projection_sparsity <= 1:
if not (0 < projection_sparsity <= 1):
raise ValueError("Sparsity should be in (0, 1]")
return projection_sparsity

def build(self, depth, n_trees, projection_sparsity='auto'):
def build(self, depth, n_trees, projection_sparsity="auto"):
"""
Builds a normal MRPT index.
:param depth: The depth of the trees; should be in the set {1, 2, ..., floor(log2(n))}.
Expand All @@ -65,8 +69,18 @@ def build(self, depth, n_trees, projection_sparsity='auto'):
self.index.build(n_trees, depth, projection_sparsity)
self.built = True

def build_autotune(self, target_recall, Q, k, trees_max=-1, depth_min=-1, depth_max=-1,
votes_max=-1, projection_sparsity='auto', shape=None):
def build_autotune(
self,
target_recall,
Q,
k,
trees_max=-1,
depth_min=-1,
depth_max=-1,
votes_max=-1,
projection_sparsity="auto",
shape=None,
):
"""
Builds an autotuned MRPT index.
:param target_recall: The target recall level (float) or None if the target recall level
Expand Down Expand Up @@ -94,19 +108,23 @@ def build_autotune(self, target_recall, Q, k, trees_max=-1, depth_min=-1, depth_
raise ValueError("The test query matrix should be non-empty and two-dimensional")
if Q.dtype != np.float32:
raise ValueError("The test query matrix should have type float32")
if not Q.flags['C_CONTIGUOUS'] or not Q.flags['ALIGNED']:
if not Q.flags["C_CONTIGUOUS"] or not Q.flags["ALIGNED"]:
raise ValueError("The test query matrix has to be C_CONTIGUOUS and ALIGNED")
n_test, dim = Q.shape
elif isinstance(Q, str):
if not isinstance(shape, tuple) or len(shape) != 2:
raise ValueError("You must specify the shape of the data as a tuple (n_test, dim) "
"when loading the test query matrix from a binary file")
raise ValueError(
"You must specify the shape of the data as a tuple (n_test, dim) "
"when loading the test query matrix from a binary file"
)
n_test, dim = shape
else:
raise ValueError("The test query matrix must be either an ndarray or a filepath")

if dim != self.dim:
raise ValueError("The test query matrix should have the same number of columns as the data matrix")
raise ValueError(
"The test query matrix should have the same number of columns as the data matrix"
)

self.built = target_recall is not None
self.autotuned = True
Expand All @@ -116,10 +134,28 @@ def build_autotune(self, target_recall, Q, k, trees_max=-1, depth_min=-1, depth_

projection_sparsity = self._compute_sparsity(projection_sparsity)
self.index.build_autotune(
target_recall, Q, n_test, k, trees_max, depth_min, depth_max, votes_max, projection_sparsity)

def build_autotune_sample(self, target_recall, k, n_test=100, trees_max=-1,
depth_min=-1, depth_max=-1, votes_max=-1, projection_sparsity='auto'):
target_recall,
Q,
n_test,
k,
trees_max,
depth_min,
depth_max,
votes_max,
projection_sparsity,
)

def build_autotune_sample(
self,
target_recall,
k,
n_test=100,
trees_max=-1,
depth_min=-1,
depth_max=-1,
votes_max=-1,
projection_sparsity="auto",
):
"""
Builds an autotuned MRPT index.
:param target_recall: The target recall level (float) or None if the target recall level
Expand Down Expand Up @@ -148,7 +184,15 @@ def build_autotune_sample(self, target_recall, k, n_test=100, trees_max=-1,

projection_sparsity = self._compute_sparsity(projection_sparsity)
self.index.build_autotune_sample(
target_recall, n_test, k, trees_max, depth_min, depth_max, votes_max, projection_sparsity)
target_recall,
n_test,
k,
trees_max,
depth_min,
depth_max,
votes_max,
projection_sparsity,
)

def subset(self, target_recall):
"""
Expand All @@ -175,9 +219,15 @@ def parameters(self):
n_trees, depth, votes, k, qtime, recall = self.index.parameters()

if self.index.is_autotuned():
return {'n_trees': n_trees, 'depth': depth, 'k': k, 'votes': votes,
'estimated_qtime': qtime, 'estimated_recall': recall}
return {'n_trees': n_trees, 'depth': depth}
return {
"n_trees": n_trees,
"depth": depth,
"k": k,
"votes": votes,
"estimated_qtime": qtime,
"estimated_recall": recall,
}
return {"n_trees": n_trees, "depth": depth}

def save(self, path):
"""
Expand Down
29 changes: 29 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[build-system]
requires = ["setuptools>=42", "wheel", "numpy"]
build-backend = "setuptools.build_meta"

[tool.isort]
profile = "black"
src_paths = ["python"]

[tool.black]
line-length = 100
target-version = ['py312']
include = '(\.pyi?$)'
exclude = '''
(
/(
\.github
| \.vscode
| \.venv
| docs\/
| licenses\/
| src\/
)/
)
'''

[tool.ruff]
line-length = 100
indent-width = 4
1 change: 0 additions & 1 deletion requirements.txt

This file was deleted.

131 changes: 0 additions & 131 deletions utils/binary_converter.py

This file was deleted.

0 comments on commit 1108973

Please sign in to comment.