Skip to content

Commit

Permalink
Add iterator over Reader object
Browse files Browse the repository at this point in the history
Closes #23.
  • Loading branch information
oschwald committed Oct 12, 2023
1 parent f6f983e commit 9680255
Show file tree
Hide file tree
Showing 7 changed files with 350 additions and 18 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
.pyre
.tox
build
compile_flags.txt
core
dist
docs/_build
Expand Down
6 changes: 5 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
History
-------

2.4.1
2.5.0
++++++++++++++++++

* The ``Reader`` class now implements the ``__iter__`` method. This will
return an iterator that iterates over all records in the database,
excluding repeated aliases of the IPv4 network. Requested by
Jean-Baptiste Braun and others. GitHub #23.
* The multiprocessing test now explicitly uses ``fork``. This allows it
to run successfully on macOS. Pull request by Theodore Ni. GitHub #116.

Expand Down
7 changes: 7 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ If you wish to also retrieve the prefix length for the record, use the
``get_with_prefix_len`` method. This returns a tuple containing the record
followed by the network prefix length associated with the record.

You may also iterate over the whole database. The ``Reader`` class implements
the ``__iter__`` method that returns an iterator. This iterator yields a
tuple containing the network and the record.

Example
-------

Expand All @@ -83,6 +87,9 @@ Example
>>>
>>> reader.get_with_prefix_len('152.216.7.110')
({'country': ... }, 24)
>>>
>>> for network, record in reader:
...     ...
Exceptions
----------
Expand Down
243 changes: 242 additions & 1 deletion extension/maxminddb.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <arpa/inet.h>
#include <maxminddb.h>
Expand All @@ -9,8 +10,10 @@
#include <inttypes.h>

static PyTypeObject Reader_Type;
static PyTypeObject ReaderIter_Type;
static PyTypeObject Metadata_Type;
static PyObject *MaxMindDB_error;
static PyObject *ipaddress_ip_network;

// clang-format off
typedef struct {
Expand All @@ -19,6 +22,22 @@ typedef struct {
PyObject *closed;
} Reader_obj;

/*
 * One pending entry on the iterator's work list.
 *
 * The iterator walks the search tree with an explicit singly linked list of
 * these entries instead of recursion; each entry describes one record of a
 * search node that still has to be examined.
 */
typedef struct record {
    /* Network address bits accumulated on the way down to this record.
     * Children copy the parent's bytes; the right child additionally sets
     * the bit at the parent's depth. */
    char ip_packed[16];
    /* Number of address bits consumed so far (the root starts at 0 and each
     * child is one deeper than its parent). */
    int depth;
    /* Raw record value copied from the parent search node; for search-node
     * records it is passed to MMDB_read_node() as the node number. */
    uint64_t record;
    /* One of the MMDB_RECORD_TYPE_* values copied from the parent node. */
    uint8_t type;
    /* Entry handle copied from the parent node; handed to
     * MMDB_get_entry_data_list() for data records. */
    MMDB_entry_s entry;
    /* Next pending entry, or NULL at the end of the list. */
    struct record *next;
} record;

/*
 * Instance layout of the iterator object created by Reader_iter().
 */
typedef struct {
PyObject_HEAD /* no semicolon */
/* Reader being iterated. Reader_iter() takes a reference to it and
 * ReaderIter_dealloc() releases that reference. */
Reader_obj *reader;
/* Head of the linked list of records still to be visited; NULL once the
 * iteration is exhausted. */
struct record *next;
} ReaderIter_obj;

typedef struct {
PyObject_HEAD /* no semicolon */
PyObject *binary_format_major_version;
Expand Down Expand Up @@ -389,6 +408,199 @@ static void Reader_dealloc(PyObject *self) {
PyObject_Del(self);
}

/*
 * tp_iter implementation for Reader: create an iterator over the database.
 *
 * Returns a new ReaderIter object holding a reference to `reader` and a
 * single zero-initialised pending record (calloc leaves ip_packed, depth and
 * record all zero, i.e. node 0 at the all-zero address).
 *
 * Returns NULL with an exception set when the reader is closed (ValueError)
 * or when memory cannot be allocated (MemoryError).
 */
static PyObject *Reader_iter(PyObject *reader) {
    Reader_obj *reader_obj = (Reader_obj *)reader;

    // Validate before allocating anything so the error path has nothing to
    // clean up.
    if (reader_obj->closed == Py_True) {
        PyErr_SetString(PyExc_ValueError,
                        "Attempt to iterate over a closed MaxMind DB.");
        return NULL;
    }

    ReaderIter_obj *ri = PyObject_New(ReaderIter_obj, &ReaderIter_Type);
    if (ri == NULL) {
        return NULL;
    }

    // Currently, we are always starting from the 0 node with the 0 IP
    struct record *root = calloc(1, sizeof(*root));
    if (root == NULL) {
        // The fields of `ri` are still uninitialised, so release the raw
        // object directly instead of going through the type's dealloc.
        PyObject_Del(ri);
        PyErr_NoMemory();
        return NULL;
    }

    Py_INCREF(reader);
    ri->reader = reader_obj;
    ri->next = root;

    return (PyObject *)ri;
}

/*
 * Report whether any of the first 12 bytes of a 16-byte packed address is
 * non-zero. Addresses whose first 12 bytes are all zero are treated by the
 * callers as IPv4 addresses stored in an IPv6 tree.
 */
static bool is_ipv6(char ip[16]) {
    const char *byte = ip;
    const char *const prefix_end = ip + 12;

    while (byte != prefix_end) {
        if (*byte != 0) {
            return true;
        }
        byte++;
    }
    return false;
}

/*
 * tp_iternext implementation for the Reader iterator.
 *
 * Pops pending records off the iterator's linked list until a data record is
 * found. Search-node records are expanded into their left and right children,
 * which are pushed back onto the front of the list (left first); empty
 * records are dropped.
 *
 * Returns a new 2-tuple (network, record) where `network` is the result of
 * calling ipaddress.ip_network with (packed address bytes, prefix length)
 * and `record` is the decoded data for that network.
 *
 * Returns NULL with no exception set when the list is exhausted, and NULL
 * with an exception set on error (ValueError for a closed reader,
 * MemoryError on allocation failure, MaxMindDB_error for database problems).
 */
static PyObject *ReaderIter_next(PyObject *self) {
    ReaderIter_obj *ri = (ReaderIter_obj *)self;
    if (ri->reader->closed == Py_True) {
        PyErr_SetString(PyExc_ValueError,
                        "Attempt to iterate over a closed MaxMind DB.");
        return NULL;
    }

    while (ri->next != NULL) {
        // Detach the head of the list; `cur` is owned by this function from
        // here on and must be freed on every path.
        struct record *cur = ri->next;
        ri->next = cur->next;

        switch (cur->type) {
        case MMDB_RECORD_TYPE_INVALID:
            PyErr_SetString(MaxMindDB_error,
                            "Invalid record when reading node");
            free(cur);
            return NULL;
        case MMDB_RECORD_TYPE_SEARCH_NODE: {
            if (cur->record ==
                    ri->reader->mmdb->ipv4_start_node.node_value &&
                is_ipv6(cur->ip_packed)) {
                // These are aliased networks. Skip them.
                break;
            }
            // The right child sets bit `depth` of ip_packed. Refuse to
            // descend past the end of the buffer, which could otherwise
            // happen with a corrupt (e.g. cyclic) search tree.
            if (cur->depth < 0 ||
                (size_t)cur->depth >= sizeof(cur->ip_packed) * 8) {
                PyErr_SetString(
                    MaxMindDB_error,
                    "Invalid search tree: node depth exceeds address length");
                free(cur);
                return NULL;
            }
            MMDB_search_node_s node;
            int status = MMDB_read_node(
                ri->reader->mmdb, (uint32_t)cur->record, &node);
            if (status != MMDB_SUCCESS) {
                const char *error = MMDB_strerror(status);
                PyErr_Format(
                    MaxMindDB_error, "Error reading node: %s", error);
                free(cur);
                return NULL;
            }

            struct record *left = calloc(1, sizeof(*left));
            if (left == NULL) {
                free(cur);
                PyErr_NoMemory();
                return NULL;
            }
            struct record *right = calloc(1, sizeof(*right));
            if (right == NULL) {
                // `left` is not linked into the list yet, so it has to be
                // released here or it would leak.
                free(left);
                free(cur);
                PyErr_NoMemory();
                return NULL;
            }

            // Left child: same address bits as the parent (next bit is 0).
            memcpy(
                left->ip_packed, cur->ip_packed, sizeof(left->ip_packed));
            left->depth = cur->depth + 1;
            left->record = node.left_record;
            left->type = node.left_record_type;
            left->entry = node.left_record_entry;
            left->next = right;

            // Right child: parent's address with the bit at `depth` set.
            memcpy(
                right->ip_packed, cur->ip_packed, sizeof(right->ip_packed));
            right->ip_packed[cur->depth / 8] |= 1 << (7 - cur->depth % 8);
            right->depth = cur->depth + 1;
            right->record = node.right_record;
            right->type = node.right_record_type;
            right->entry = node.right_record_entry;
            right->next = ri->next;

            // Push both children on the front so the walk is depth-first,
            // left before right.
            ri->next = left;
            break;
        }
        case MMDB_RECORD_TYPE_EMPTY:
            break;
        case MMDB_RECORD_TYPE_DATA: {
            MMDB_entry_data_list_s *entry_data_list = NULL;
            int status =
                MMDB_get_entry_data_list(&cur->entry, &entry_data_list);
            if (MMDB_SUCCESS != status) {
                PyErr_Format(
                    MaxMindDB_error,
                    "Error looking up data while iterating over tree: %s",
                    MMDB_strerror(status));
                MMDB_free_entry_data_list(entry_data_list);
                free(cur);
                return NULL;
            }

            // from_entry_data_list() is handed a pointer to the list
            // pointer, so keep the original head around for freeing.
            MMDB_entry_data_list_s *original_entry_data_list =
                entry_data_list;
            PyObject *data = from_entry_data_list(&entry_data_list);
            MMDB_free_entry_data_list(original_entry_data_list);
            if (data == NULL) {
                free(cur);
                return NULL;
            }

            int ip_start = 0;
            // Must be Py_ssize_t: PY_SSIZE_T_CLEAN is defined, so the
            // length consumed by the "y#" format below is a Py_ssize_t.
            Py_ssize_t ip_length = 4;
            if (ri->reader->mmdb->depth == 128) {
                if (is_ipv6(cur->ip_packed)) {
                    // IPv6 address
                    ip_length = 16;
                } else {
                    // IPv4 address in IPv6 tree
                    ip_start = 12;
                }
            }
            PyObject *network_tuple =
                Py_BuildValue("(y#i)",
                              &(cur->ip_packed[ip_start]),
                              ip_length,
                              cur->depth - ip_start * 8);
            if (network_tuple == NULL) {
                Py_DECREF(data);
                free(cur);
                return NULL;
            }
            PyObject *args = PyTuple_Pack(1, network_tuple);
            Py_DECREF(network_tuple);
            if (args == NULL) {
                Py_DECREF(data);
                free(cur);
                return NULL;
            }
            PyObject *network =
                PyObject_CallObject(ipaddress_ip_network, args);
            Py_DECREF(args);
            if (network == NULL) {
                Py_DECREF(data);
                free(cur);
                return NULL;
            }

            // PyTuple_Pack takes its own references; rv may be NULL, in
            // which case the exception it set is propagated.
            PyObject *rv = PyTuple_Pack(2, network, data);
            Py_DECREF(network);
            Py_DECREF(data);

            free(cur);
            return rv;
        }
        default:
            PyErr_Format(MaxMindDB_error,
                         "Unknown record type: %u",
                         (unsigned int)cur->type);
            free(cur);
            return NULL;
        }
        free(cur);
    }
    // List exhausted: NULL without an exception signals end of iteration.
    return NULL;
}

/*
 * tp_dealloc implementation for the Reader iterator: drop the reference to
 * the Reader, free every record still pending on the list, then release the
 * iterator object itself.
 */
static void ReaderIter_dealloc(PyObject *self) {
    ReaderIter_obj *iter = (ReaderIter_obj *)self;

    Py_DECREF(iter->reader);

    for (struct record *node = iter->next; node != NULL;) {
        // Read the link before the node is freed.
        struct record *following = node->next;
        free(node);
        node = following;
    }

    PyObject_Del(self);
}

static int Metadata_init(PyObject *self, PyObject *args, PyObject *kwds) {

PyObject *binary_format_major_version, *binary_format_minor_version,
Expand Down Expand Up @@ -644,13 +856,30 @@ static PyTypeObject Reader_Type = {
.tp_dealloc = Reader_dealloc,
.tp_doc = "Reader object",
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_iter = Reader_iter,
.tp_methods = Reader_methods,
.tp_members = Reader_members,
.tp_name = "Reader",
.tp_init = Reader_init,
};
// clang-format on

/* Method table for the iterator type: it holds only the terminating
 * sentinel entry, so the type defines no methods through this table. */
static PyMethodDef ReaderIter_methods[] = {
    {NULL, NULL, 0, NULL},
};

// clang-format off
/*
 * Type object for the iterator returned by Reader's tp_iter slot.
 *
 * tp_iter returns the iterator itself (PyObject_SelfIter) and tp_iternext
 * produces the items. The type name previously read "odict_iterator", which
 * mislabelled these objects in reprs and error messages; it now follows the
 * bare-name convention used by Reader_Type ("Reader").
 */
static PyTypeObject ReaderIter_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    .tp_basicsize = sizeof(ReaderIter_obj),
    .tp_dealloc = ReaderIter_dealloc,
    .tp_doc = "Iterator for Reader object",
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_iter = PyObject_SelfIter,
    .tp_iternext = ReaderIter_next,
    .tp_methods = ReaderIter_methods,
    .tp_name = "ReaderIter",
};
// clang-format on

static PyMethodDef Metadata_methods[] = {{NULL, NULL, 0, NULL}};

static PyMemberDef Metadata_members[] = {
Expand Down Expand Up @@ -753,9 +982,21 @@ PyMODINIT_FUNC PyInit_extension(void) {
if (MaxMindDB_error == NULL) {
return NULL;
}

Py_INCREF(MaxMindDB_error);

PyObject *ipaddress_mod = PyImport_ImportModule("ipaddress");
if (error_mod == NULL) {
return NULL;
}

ipaddress_ip_network = PyObject_GetAttrString(ipaddress_mod, "ip_network");
Py_DECREF(ipaddress_mod);

if (ipaddress_ip_network == NULL) {
return NULL;
}
Py_INCREF(ipaddress_ip_network);

/* We primarily add it to the module for backwards compatibility */
PyModule_AddObject(m, "InvalidDatabaseError", MaxMindDB_error);

Expand Down
Loading

0 comments on commit 9680255

Please sign in to comment.