From 96802553d9611d44d12fe299cce46a2423733078 Mon Sep 17 00:00:00 2001
From: Gregory Oschwald
Date: Mon, 9 Oct 2023 15:34:26 -0700
Subject: [PATCH] Add iterator over Reader object

Closes #23.
---
 .gitignore            |   1 +
 HISTORY.rst           |   6 +-
 README.rst            |   7 ++
 extension/maxminddb.c | 243 +++++++++++++++++++++++++++++++++++++++++-
 maxminddb/reader.py   |  61 ++++++++---
 tests/data            |   2 +-
 tests/reader_test.py  |  48 +++++++++
 7 files changed, 350 insertions(+), 18 deletions(-)

diff --git a/.gitignore b/.gitignore
index 40c7883..bfa1c1b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@
 .pyre
 .tox
 build
+compile_flags.txt
 core
 dist
 docs/_build
diff --git a/HISTORY.rst b/HISTORY.rst
index 8a32588..878491e 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -3,9 +3,13 @@
 History
 -------
 
-2.4.1
+2.5.0
 ++++++++++++++++++
 
+* The ``Reader`` class now implements the ``__iter__`` method. This will
+  return an iterator that iterates over all records in the database,
+  excluding repeated aliases of the IPv4 network. Requested by
+  Jean-Baptiste Braun and others. GitHub #23.
 * The multiprocessing test now explicitly uses ``fork``. This allows it
   to run successfully on macOS. Pull request by Theodore Ni. GitHub #116.
 
diff --git a/README.rst b/README.rst
index f93eb95..f27d37f 100644
--- a/README.rst
+++ b/README.rst
@@ -69,6 +69,10 @@ If you wish to also retrieve the prefix length for the record, use the
 ``get_with_prefix_len`` method. This returns a tuple containing the record
 followed by the network prefix length associated with the record.
 
+You may also iterate over the whole database. The ``Reader`` class implements
+the ``__iter__`` method, which returns an iterator. This iterator yields a
+tuple containing the network and the record.
+
 Example
 -------
 
@@ -83,6 +87,9 @@ Example
     >>>
     >>> reader.get_with_prefix_len('152.216.7.110')
     ({'country': ... }, 24)
+    >>>
+    >>> for network, record in reader:
+    >>>     ...
 
 Exceptions
 ----------
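For orientation, here is a minimal usage sketch of the iteration API described in the README change above. It is not part of the patch, and the database path is only an assumption -- any MaxMind DB file works:

    import maxminddb

    reader = maxminddb.open_database("GeoLite2-City.mmdb")
    for network, record in reader:
        # network is an ipaddress.IPv4Network or IPv6Network;
        # record is the decoded data stored for that network.
        print(network, record)
    reader.close()

As the extension and pure-Python changes below show, iterating over a reader that has already been closed raises ``ValueError``.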
diff --git a/extension/maxminddb.c b/extension/maxminddb.c
index 6584344..d58b305 100644
--- a/extension/maxminddb.c
+++ b/extension/maxminddb.c
@@ -1,3 +1,4 @@
+#define PY_SSIZE_T_CLEAN
 #include
 #include
 #include
@@ -9,8 +10,10 @@
 #include
 
 static PyTypeObject Reader_Type;
+static PyTypeObject ReaderIter_Type;
 static PyTypeObject Metadata_Type;
 static PyObject *MaxMindDB_error;
+static PyObject *ipaddress_ip_network;
 
 // clang-format off
 typedef struct {
@@ -19,6 +22,22 @@ typedef struct {
     PyObject *closed;
 } Reader_obj;
 
+typedef struct record record;
+struct record {
+    char ip_packed[16];
+    int depth;
+    uint64_t record;
+    uint8_t type;
+    MMDB_entry_s entry;
+    struct record *next;
+};
+
+typedef struct {
+    PyObject_HEAD /* no semicolon */
+    Reader_obj *reader;
+    struct record *next;
+} ReaderIter_obj;
+
 typedef struct {
     PyObject_HEAD /* no semicolon */
     PyObject *binary_format_major_version;
@@ -389,6 +408,199 @@ static void Reader_dealloc(PyObject *self) {
     PyObject_Del(self);
 }
 
+static PyObject *Reader_iter(PyObject *reader) {
+    Reader_obj *reader_obj = (Reader_obj *)reader;
+    if (reader_obj->closed == Py_True) {
+        PyErr_SetString(PyExc_ValueError,
+                        "Attempt to iterate over a closed MaxMind DB.");
+        return NULL;
+    }
+
+    ReaderIter_obj *ri = PyObject_New(ReaderIter_obj, &ReaderIter_Type);
+    if (ri == NULL) {
+        return NULL;
+    }
+
+    ri->reader = reader_obj;
+    Py_INCREF(reader);
+
+    // Currently, we are always starting from the 0 node with the 0 IP
+    ri->next = calloc(1, sizeof(record));
+    if (ri->next == NULL) {
+        Py_DECREF(ri);
+        PyErr_NoMemory();
+        return NULL;
+    }
+
+    return (PyObject *)ri;
+}
+
+static bool is_ipv6(char ip[16]) {
+    char z = 0;
+    for (int i = 0; i < 12; i++) {
+        z |= ip[i];
+    }
+    return z;
+}
+
+static PyObject *ReaderIter_next(PyObject *self) {
+    ReaderIter_obj *ri = (ReaderIter_obj *)self;
+    if (ri->reader->closed == Py_True) {
+        PyErr_SetString(PyExc_ValueError,
+                        "Attempt to iterate over a closed MaxMind DB.");
+        return NULL;
+    }
+
+    while (ri->next != NULL) {
+        record *cur = ri->next;
+        ri->next = cur->next;
+
+        switch (cur->type) {
+            case MMDB_RECORD_TYPE_INVALID:
+                PyErr_SetString(MaxMindDB_error,
+                                "Invalid record when reading node");
+                free(cur);
+                return NULL;
+            case MMDB_RECORD_TYPE_SEARCH_NODE: {
+                if (cur->record ==
+                        ri->reader->mmdb->ipv4_start_node.node_value &&
+                    is_ipv6(cur->ip_packed)) {
+                    // These are aliased networks. Skip them.
+                    break;
+                }
+                MMDB_search_node_s node;
+                int status = MMDB_read_node(
+                    ri->reader->mmdb, (uint32_t)cur->record, &node);
+                if (status != MMDB_SUCCESS) {
+                    const char *error = MMDB_strerror(status);
+                    PyErr_Format(
+                        MaxMindDB_error, "Error reading node: %s", error);
+                    free(cur);
+                    return NULL;
+                }
+                struct record *left = calloc(1, sizeof(record));
+                if (left == NULL) {
+                    free(cur);
+                    PyErr_NoMemory();
+                    return NULL;
+                }
+                memcpy(
+                    left->ip_packed, cur->ip_packed, sizeof(left->ip_packed));
+                left->depth = cur->depth + 1;
+                left->record = node.left_record;
+                left->type = node.left_record_type;
+                left->entry = node.left_record_entry;
+
+                struct record *right = left->next = calloc(1, sizeof(record));
+                if (right == NULL) {
+                    free(left);
+                    free(cur);
+                    PyErr_NoMemory();
+                    return NULL;
+                }
+                memcpy(
+                    right->ip_packed, cur->ip_packed, sizeof(right->ip_packed));
+                right->ip_packed[cur->depth / 8] |= 1 << (7 - cur->depth % 8);
+                right->depth = cur->depth + 1;
+                right->record = node.right_record;
+                right->type = node.right_record_type;
+                right->entry = node.right_record_entry;
+                right->next = ri->next;
+
+                ri->next = left;
+                break;
+            }
+            case MMDB_RECORD_TYPE_EMPTY:
+                break;
+            case MMDB_RECORD_TYPE_DATA: {
+                MMDB_entry_data_list_s *entry_data_list = NULL;
+                int status =
+                    MMDB_get_entry_data_list(&cur->entry, &entry_data_list);
+                if (MMDB_SUCCESS != status) {
+                    PyErr_Format(
+                        MaxMindDB_error,
+                        "Error looking up data while iterating over tree: %s",
+                        MMDB_strerror(status));
+                    MMDB_free_entry_data_list(entry_data_list);
+                    free(cur);
+                    return NULL;
+                }
+
+                MMDB_entry_data_list_s *original_entry_data_list =
+                    entry_data_list;
+                PyObject *record = from_entry_data_list(&entry_data_list);
+                MMDB_free_entry_data_list(original_entry_data_list);
+                if (record == NULL) {
+                    free(cur);
+                    return NULL;
+                }
+
+                int ip_start = 0;
+                Py_ssize_t ip_length = 4;
+                if (ri->reader->mmdb->depth == 128) {
+                    if (is_ipv6(cur->ip_packed)) {
+                        // IPv6 address
+                        ip_length = 16;
+                    } else {
+                        // IPv4 address in IPv6 tree
+                        ip_start = 12;
+                    }
+                }
+                PyObject *network_tuple =
+                    Py_BuildValue("(y#i)",
+                                  &(cur->ip_packed[ip_start]),
+                                  ip_length,
+                                  cur->depth - ip_start * 8);
+                if (network_tuple == NULL) {
+                    Py_DECREF(record);
+                    free(cur);
+                    return NULL;
+                }
+                PyObject *args = PyTuple_Pack(1, network_tuple);
+                Py_DECREF(network_tuple);
+                if (args == NULL) {
+                    Py_DECREF(record);
+                    free(cur);
+                    return NULL;
+                }
+                PyObject *network =
+                    PyObject_CallObject(ipaddress_ip_network, args);
+                Py_DECREF(args);
+                if (network == NULL) {
+                    Py_DECREF(record);
+                    free(cur);
+                    return NULL;
+                }
+
+                PyObject *rv = PyTuple_Pack(2, network, record);
+                Py_DECREF(network);
+                Py_DECREF(record);
+
+                free(cur);
+                return rv;
+            }
+            default:
+                PyErr_Format(
+                    MaxMindDB_error, "Unknown record type: %u", cur->type);
+                free(cur);
+                return NULL;
+        }
+        free(cur);
+    }
+    return NULL;
+}
+
+static void ReaderIter_dealloc(PyObject *self) {
+    ReaderIter_obj *ri = (ReaderIter_obj *)self;
+
+    Py_DECREF(ri->reader);
+
+    struct record *next = ri->next;
+    while (next != NULL) {
+        struct record *cur = next;
+        next = cur->next;
+        free(cur);
+    }
+    PyObject_Del(self);
+}
+
 static int Metadata_init(PyObject *self, PyObject *args, PyObject *kwds) {
 
     PyObject *binary_format_major_version, *binary_format_minor_version,
@@ -644,6 +856,7 @@ static PyTypeObject Reader_Type = {
     .tp_dealloc = Reader_dealloc,
     .tp_doc = "Reader object",
     .tp_flags = Py_TPFLAGS_DEFAULT,
+    .tp_iter = Reader_iter,
     .tp_methods = Reader_methods,
     .tp_members = Reader_members,
     .tp_name = "Reader",
@@ -651,6 +864,22 @@ static PyTypeObject Reader_Type = {
 };
 // clang-format on
 
+static PyMethodDef ReaderIter_methods[] = {{NULL, NULL, 0, NULL}};
+
+// clang-format off
+static PyTypeObject ReaderIter_Type = {
+    PyVarObject_HEAD_INIT(&PyType_Type, 0)
+    .tp_basicsize = sizeof(ReaderIter_obj),
+    .tp_dealloc = ReaderIter_dealloc,
+    .tp_doc = "Iterator for Reader object",
+    .tp_flags = Py_TPFLAGS_DEFAULT,
+    .tp_iter = PyObject_SelfIter,
+    .tp_iternext = ReaderIter_next,
+    .tp_methods = ReaderIter_methods,
+    .tp_name = "ReaderIter",
+};
+// clang-format on
+
 static PyMethodDef Metadata_methods[] = {{NULL, NULL, 0, NULL}};
 
 static PyMemberDef Metadata_members[] = {
@@ -753,9 +982,21 @@ PyMODINIT_FUNC PyInit_extension(void) {
     if (MaxMindDB_error == NULL) {
         return NULL;
     }
-
     Py_INCREF(MaxMindDB_error);
 
+    PyObject *ipaddress_mod = PyImport_ImportModule("ipaddress");
+    if (ipaddress_mod == NULL) {
+        return NULL;
+    }
+
+    ipaddress_ip_network = PyObject_GetAttrString(ipaddress_mod, "ip_network");
+    Py_DECREF(ipaddress_mod);
+
+    if (ipaddress_ip_network == NULL) {
+        return NULL;
+    }
+    Py_INCREF(ipaddress_ip_network);
+
     /* We primarily add it to the module for backwards compatibility */
     PyModule_AddObject(m, "InvalidDatabaseError", MaxMindDB_error);
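The traversal in ``ReaderIter_next`` above is easier to see in miniature. The following sketch is not part of the patch: it mirrors only the control flow -- a depth-first walk driven by an explicit work list instead of recursion -- and the toy ``NODES`` table is a hypothetical stand-in for ``MMDB_read_node``. The real code additionally skips subtrees aliased to the IPv4 start node of an IPv6 tree and decodes each data record with ``from_entry_data_list``:

    import ipaddress

    # Toy stand-in for the MMDB search tree: records smaller than NODE_COUNT
    # are search nodes mapping to (left_record, right_record), a record equal
    # to NODE_COUNT is the empty record, and anything larger points into the
    # data section.
    NODE_COUNT = 3
    NODES = {0: (1, 3), 1: (2, 3), 2: (4, 5)}
    TREE_DEPTH = 32  # 128 for an IPv6 tree

    def walk_tree():
        # Work list of pending records, mirroring the linked list of
        # struct record that ReaderIter_next maintains: (record, path, depth).
        pending = [(0, 0, 0)]
        while pending:
            record, path, depth = pending.pop()
            if record < NODE_COUNT:
                # Search node: queue both children, left on top so that it is
                # visited first and networks come out in address order.
                left, right = NODES[record]
                pending.append((right, (path << 1) | 1, depth + 1))
                pending.append((left, path << 1, depth + 1))
            elif record > NODE_COUNT:
                # Data record: the path bits walked so far are the network.
                # (The extension decodes the record here; we just yield it.)
                network = ipaddress.ip_network(
                    (path << (TREE_DEPTH - depth), depth)
                )
                yield network, record
            # record == NODE_COUNT is the empty record: nothing to emit.

    print(list(walk_tree()))
    # [(IPv4Network('0.0.0.0/3'), 4), (IPv4Network('32.0.0.0/3'), 5)]

The linked list of ``struct record`` in the extension plays the role of the ``pending`` list here, with the right child chained behind the left so that networks are produced in address order.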
diff --git a/maxminddb/reader.py b/maxminddb/reader.py
index 734cd97..f4af118 100644
--- a/maxminddb/reader.py
+++ b/maxminddb/reader.py
@@ -23,6 +23,8 @@
 from maxminddb.file import FileBuffer
 from maxminddb.types import Record
 
+_IPV4_MAX_NUM = 2**32
+
 
 class Reader:
     """
@@ -34,7 +36,7 @@ class Reader:
     _METADATA_START_MARKER = b"\xAB\xCD\xEFMaxMind.com"
 
     _buffer: Union[bytes, FileBuffer, "mmap.mmap"]
-    _ipv4_start: Optional[int] = None
+    _ipv4_start: int
 
     def __init__(
         self, database: Union[AnyStr, int, PathLike, IO], mode: int = MODE_AUTO
@@ -107,6 +109,18 @@ def __init__(
             )
 
         self.closed = False
+        ipv4_start = 0
+        if self._metadata.ip_version == 6:
+            # We are looking up an IPv4 address in an IPv6 tree. Skip over the
+            # first 96 nodes.
+            node = 0
+            for _ in range(96):
+                if node >= self._metadata.node_count:
+                    break
+                node = self._read_node(node, 0)
+            ipv4_start = node
+        self._ipv4_start = ipv4_start
+
     def metadata(self) -> "Metadata":
         """Return the metadata associated with the MaxMind DB file"""
         return self._metadata
@@ -152,6 +166,35 @@ def get_with_prefix_len(
             return self._resolve_data_pointer(pointer), prefix_len
         return None, prefix_len
 
+    def __iter__(self, include_aliased_nodes=False):
+        return self._generate_children(0, 0, 0, include_aliased_nodes)
+
+    def _generate_children(self, node, depth, ip_acc, include_aliased_nodes):
+        if not include_aliased_nodes and ip_acc != 0 and node == self._ipv4_start:
+            # Skip nodes aliased to IPv4
+            return
+
+        node_count = self._metadata.node_count
+        if node > node_count:
+            bits = 128 if self._metadata.ip_version == 6 else 32
+            ip_acc <<= bits - depth
+            if ip_acc <= _IPV4_MAX_NUM and bits == 128:
+                depth -= 96
+            yield ipaddress.ip_network((ip_acc, depth)), self._resolve_data_pointer(
+                node
+            )
+        elif node < node_count:
+            left = self._read_node(node, 0)
+            ip_acc <<= 1
+            depth += 1
+            yield from self._generate_children(
+                left, depth, ip_acc, include_aliased_nodes
+            )
+            right = self._read_node(node, 1)
+            yield from self._generate_children(
+                right, depth, ip_acc | 1, include_aliased_nodes
+            )
+
     def _find_address_in_tree(self, packed: bytearray) -> Tuple[int, int]:
         bit_count = len(packed) * 8
         node = self._start_node(bit_count)
@@ -172,21 +215,9 @@ def _find_address_in_tree(self, packed: bytearray) -> Tuple[int, int]:
         raise InvalidDatabaseError("Invalid node in search tree")
 
     def _start_node(self, length: int) -> int:
-        if self._metadata.ip_version != 6 or length == 128:
-            return 0
-
-        # We are looking up an IPv4 address in an IPv6 tree. Skip over the
-        # first 96 nodes.
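The least obvious step in ``_generate_children`` above is turning the accumulated path bits into an ``ipaddress`` network. The helper below is an illustrative sketch of just that step (``to_network`` is a hypothetical name, not part of the patch): the bits collected while descending are shifted out to the full address width, and an IPv4 network found under ``::/96`` of an IPv6 tree has 96 subtracted from its prefix length so it is reported as plain IPv4:

    import ipaddress

    _IPV4_MAX_NUM = 2**32

    def to_network(ip_acc: int, depth: int, ip_version: int):
        bits = 128 if ip_version == 6 else 32
        ip_acc <<= bits - depth  # move the path bits to the high end
        if bits == 128 and ip_acc <= _IPV4_MAX_NUM:
            depth -= 96  # IPv4 network embedded at ::/96 in an IPv6 tree
        return ipaddress.ip_network((ip_acc, depth))

    print(to_network(0x01010101, 32, 4))   # 1.1.1.1/32 (IPv4-only tree)
    print(to_network(0x01010101, 128, 6))  # 1.1.1.1/32 (found at depth 128
                                           # of a mixed or IPv6 tree)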
-        if self._ipv4_start:
+        if self._metadata.ip_version == 6 and length == 32:
             return self._ipv4_start
-
-        node = 0
-        for _ in range(96):
-            if node >= self._metadata.node_count:
-                break
-            node = self._read_node(node, 0)
-        self._ipv4_start = node
-        return node
+        return 0
 
     def _read_node(self, node_number: int, index: int) -> int:
         base_offset = node_number * self._metadata.node_byte_size
diff --git a/tests/data b/tests/data
index 86095bd..a75bfb1 160000
--- a/tests/data
+++ b/tests/data
@@ -1 +1 @@
-Subproject commit 86095bd9855d6313c501fe0097891a3c6734ae90
+Subproject commit a75bfb17a0e77f576c9eef0cfbf6220909e959e7
diff --git a/tests/reader_test.py b/tests/reader_test.py
index e2f1760..f553ae4 100644
--- a/tests/reader_test.py
+++ b/tests/reader_test.py
@@ -187,6 +187,54 @@ def test_get_with_prefix_len(self):
                 "expected_record for " + test["ip"] + " in " + test["file_name"],
             )
 
+    def test_iterator(self):
+        tests = (
+            {
+                "database": "ipv4",
+                "expected": [
+                    "1.1.1.1/32",
+                    "1.1.1.2/31",
+                    "1.1.1.4/30",
+                    "1.1.1.8/29",
+                    "1.1.1.16/28",
+                    "1.1.1.32/32",
+                ],
+            },
+            {
+                "database": "ipv6",
+                "expected": [
+                    "::1:ffff:ffff/128",
+                    "::2:0:0/122",
+                    "::2:0:40/124",
+                    "::2:0:50/125",
+                    "::2:0:58/127",
+                ],
+            },
+            {
+                "database": "mixed",
+                "expected": [
+                    "1.1.1.1/32",
+                    "1.1.1.2/31",
+                    "1.1.1.4/30",
+                    "1.1.1.8/29",
+                    "1.1.1.16/28",
+                    "1.1.1.32/32",
+                    "::1:ffff:ffff/128",
+                    "::2:0:0/122",
+                    "::2:0:40/124",
+                    "::2:0:50/125",
+                    "::2:0:58/127",
+                ],
+            },
+        )
+
+        for record_size in [24, 28, 32]:
+            for test in tests:
+                f = f'tests/data/test-data/MaxMind-DB-test-{test["database"]}-{record_size}.mmdb'
+                reader = open_database(f, self.mode)
+                networks = [str(n) for (n, _) in reader]
+                self.assertEqual(networks, test["expected"], f)
+
     def test_decoder(self):
         reader = open_database(
             "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode
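To reproduce the listings asserted in ``test_iterator`` from an interactive session, something like the following works; it assumes a checkout with the ``tests/data`` submodule initialized, and the 24-bit mixed database is an arbitrary choice:

    import maxminddb

    reader = maxminddb.open_database(
        "tests/data/test-data/MaxMind-DB-test-mixed-24.mmdb"
    )
    print([str(network) for network, _ in reader])
    # ['1.1.1.1/32', '1.1.1.2/31', ..., '::2:0:58/127'] per the test above
    reader.close()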