Skip to content

Commit

Permalink
Add force_offset_64 to LightweightTableCollection.asdict()
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 29, 2021
1 parent 100338a commit 0ef9b51
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 18 deletions.
51 changes: 50 additions & 1 deletion python/lwt_interface/dict_encoding_testlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2018-2020 Tskit Developers
# Copyright (c) 2018-2021 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -633,3 +633,52 @@ def test_del_lwt_and_tables(self, tables):
tables2 = tables.copy()
del tables
assert tskit.TableCollection.fromdict(lwt_dict) == tables2


class TestForceOffset64:
def get_offset_columns(self, dict_encoding):
for table_name, table in dict_encoding.items():
if isinstance(table, dict):
for name, array in table.items():
if name.endswith("_offset"):
yield f"{table_name}/{name}", array

def test_bad_args(self, tables):
lwt = lwt_module.LightweightTableCollection()
lwt.fromdict(tables.asdict())
for bad_type in [None, {}, "sdf"]:
with pytest.raises(TypeError):
lwt.asdict(bad_type)

def test_off_by_default(self, tables):
lwt = lwt_module.LightweightTableCollection()
lwt.fromdict(tables.asdict())
d = lwt.asdict()
for _, array in self.get_offset_columns(d):
assert array.dtype == np.uint32

def test_types_64(self, tables):
lwt = lwt_module.LightweightTableCollection()
lwt.fromdict(tables.asdict())
d = lwt.asdict(force_offset_64=True)
for _, array in self.get_offset_columns(d):
assert array.dtype == np.uint64

def test_types_32(self, tables):
lwt = lwt_module.LightweightTableCollection()
lwt.fromdict(tables.asdict())
d = lwt.asdict(force_offset_64=False)
for _, array in self.get_offset_columns(d):
assert array.dtype == np.uint32

def test_values_equal(self, tables):
lwt = lwt_module.LightweightTableCollection()
lwt.fromdict(tables.asdict())
d64 = lwt.asdict(force_offset_64=True)
d32 = lwt.asdict(force_offset_64=False)
offsets_64 = dict(self.get_offset_columns(d64))
offsets_32 = dict(self.get_offset_columns(d32))
for col_name, col_32 in offsets_32.items():
col_64 = offsets_64[col_name]
assert col_64.shape == col_32.shape
assert np.all(col_64 == col_32)
54 changes: 37 additions & 17 deletions python/lwt_interface/tskit_lwt_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -1503,24 +1503,37 @@ write_table_col(const tsklwt_table_col_t *col, PyObject *table_dict)
}

static int
write_ragged_col(const tsklwt_ragged_col_t *col, PyObject *table_dict)
write_ragged_col(
const tsklwt_ragged_col_t *col, PyObject *table_dict, bool force_offset_64)
{
int ret = -1;
char offset_col_name[128];
npy_intp offset_len = col->num_rows + 1;
PyArrayObject *data_array
= (PyArrayObject *) PyArray_EMPTY(1, &col->data_len, col->type, 0);
PyArrayObject *offset_array
= (PyArrayObject *) PyArray_EMPTY(1, &offset_len, NPY_UINT32, 0);

PyArrayObject *data_array = NULL;
PyArrayObject *offset_array = NULL;
bool offset_64 = force_offset_64; // || col->offset[col->num_rows] > UINT32_MAX
int offset_type = offset_64 ? NPY_UINT64 : NPY_UINT32;
/* TODO change this to 32 bit when we flip tsk_size_t over */
uint64_t *dest;
npy_intp j;

data_array = (PyArrayObject *) PyArray_EMPTY(1, &col->data_len, col->type, 0);
offset_array = (PyArrayObject *) PyArray_EMPTY(1, &offset_len, offset_type, 0);
if (data_array == NULL || offset_array == NULL) {
goto out;
}

memcpy(PyArray_DATA(data_array), col->data,
col->data_len * PyArray_ITEMSIZE(data_array));
memcpy(PyArray_DATA(offset_array), col->offset,
offset_len * PyArray_ITEMSIZE(offset_array));
if (offset_64) {
dest = (uint64_t *) PyArray_DATA(offset_array);
for (j = 0; j < offset_len; j++) {
dest[j] = col->offset[j];
}
} else {
memcpy(PyArray_DATA(offset_array), col->offset,
offset_len * PyArray_ITEMSIZE(offset_array));
}

assert(strlen(col->name) + strlen("_offset") + 2 < sizeof(offset_col_name));
strcpy(offset_col_name, col->name);
Expand All @@ -1541,7 +1554,7 @@ write_ragged_col(const tsklwt_ragged_col_t *col, PyObject *table_dict)
}

static PyObject *
write_table_dict(const tsklwt_table_desc_t *table_desc)
write_table_dict(const tsklwt_table_desc_t *table_desc, bool force_offset_64)
{
PyObject *ret = NULL;
PyObject *str = NULL;
Expand All @@ -1563,7 +1576,7 @@ write_table_dict(const tsklwt_table_desc_t *table_desc)
if (table_desc->ragged_cols != NULL) {
for (ragged_col = table_desc->ragged_cols; ragged_col->name != NULL;
ragged_col++) {
if (write_ragged_col(ragged_col, table_dict) != 0) {
if (write_ragged_col(ragged_col, table_dict, force_offset_64) != 0) {
goto out;
}
}
Expand All @@ -1587,7 +1600,8 @@ write_table_dict(const tsklwt_table_desc_t *table_desc)
}

static int
write_table_arrays(tsk_table_collection_t *tables, PyObject *dict)
write_table_arrays(
const tsk_table_collection_t *tables, PyObject *dict, bool force_offset_64)
{
int ret = -1;
PyObject *table_dict = NULL;
Expand Down Expand Up @@ -1759,7 +1773,7 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict)
};

for (j = 0; j < sizeof(table_descs) / sizeof(*table_descs); j++) {
table_dict = write_table_dict(&table_descs[j]);
table_dict = write_table_dict(&table_descs[j], force_offset_64);
if (table_dict == NULL) {
goto out;
}
Expand All @@ -1776,7 +1790,7 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict)

/* Returns a dictionary encoding of the specified table collection */
static PyObject *
dump_tables_dict(tsk_table_collection_t *tables)
dump_tables_dict(tsk_table_collection_t *tables, bool force_offset_64)
{
PyObject *ret = NULL;
PyObject *dict = NULL;
Expand Down Expand Up @@ -1834,7 +1848,7 @@ dump_tables_dict(tsk_table_collection_t *tables)
val = NULL;
}

err = write_table_arrays(tables, dict);
err = write_table_arrays(tables, dict, force_offset_64);
if (err != 0) {
goto out;
}
Expand Down Expand Up @@ -1903,14 +1917,20 @@ LightweightTableCollection_init(
}

static PyObject *
LightweightTableCollection_asdict(LightweightTableCollection *self)
LightweightTableCollection_asdict(
LightweightTableCollection *self, PyObject *args, PyObject *kwds)
{
PyObject *ret = NULL;
static char *kwlist[] = { "force_offset_64", NULL };
int force_offset_64 = 0;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "|i", kwlist, &force_offset_64)) {
goto out;
}
if (LightweightTableCollection_check_state(self) != 0) {
goto out;
}
ret = dump_tables_dict(self->tables);
ret = dump_tables_dict(self->tables, force_offset_64);
out:
return ret;
}
Expand Down Expand Up @@ -1940,7 +1960,7 @@ LightweightTableCollection_fromdict(LightweightTableCollection *self, PyObject *
static PyMethodDef LightweightTableCollection_methods[] = {
{ .ml_name = "asdict",
.ml_meth = (PyCFunction) LightweightTableCollection_asdict,
.ml_flags = METH_NOARGS,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Returns the tables encoded as a dictionary." },
{ .ml_name = "fromdict",
.ml_meth = (PyCFunction) LightweightTableCollection_fromdict,
Expand Down

0 comments on commit 0ef9b51

Please sign in to comment.