Skip to content

Commit

Permalink
Python: Complex Types in store_chunk (#915)
Browse files Browse the repository at this point in the history
Add missing support for complex types in the Python store chunk
interfaces of `Record_Component`.
  • Loading branch information
ax3l authored Jan 31, 2021
1 parent 6d4977a commit 8ea0645
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/binding/python/RecordComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ store_chunk(RecordComponent & r, py::array & a, Offset const & offset, Extent co
else if( dtype == Datatype::LONG_DOUBLE ) store_data( (long double)0 );
else if( dtype == Datatype::DOUBLE ) store_data( double() );
else if( dtype == Datatype::FLOAT ) store_data( float() );
else if( dtype == Datatype::CLONG_DOUBLE ) store_data( std::complex<long double>() );
else if( dtype == Datatype::CDOUBLE ) store_data( std::complex<double>() );
else if( dtype == Datatype::CFLOAT ) store_data( std::complex<float>() );
/* @todo
.value("STRING", Datatype::STRING)
.value("VEC_STRING", Datatype::VEC_STRING)
Expand Down
86 changes: 78 additions & 8 deletions test/python/unittest/API/APITest.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,13 @@ def makeConstantRoundTrip(self, file_ending):
np.dtype('double'))
self.assertTrue(ms["longdouble"][SCALAR].load_chunk(o, e).dtype
== np.dtype('longdouble'))
if file_ending != "json":
self.assertTrue(ms["complex64"][SCALAR].load_chunk(o, e).dtype
== np.dtype('complex64'))
self.assertTrue(ms["complex128"][SCALAR].load_chunk(o, e).dtype
== np.dtype('complex128'))
if file_ending != "bp":
self.assertTrue(ms["clongdouble"][SCALAR].load_chunk(o, e)
.dtype == np.dtype('clongdouble'))
self.assertTrue(ms["complex64"][SCALAR].load_chunk(o, e).dtype
== np.dtype('complex64'))
self.assertTrue(ms["complex128"][SCALAR].load_chunk(o, e).dtype
== np.dtype('complex128'))
if file_ending != "bp":
self.assertTrue(ms["clongdouble"][SCALAR].load_chunk(o, e)
.dtype == np.dtype('clongdouble'))

# FIXME: why does this even work w/o a flush() ?
self.assertEqual(ms["int16"][SCALAR].load_chunk(o, e),
Expand Down Expand Up @@ -567,6 +566,77 @@ def testConstantRecords(self):
for ext in tested_file_extensions:
self.makeConstantRoundTrip(ext)

def makeDataRoundTrip(self, file_ending):
if not found_numpy:
return

# write
series = io.Series(
"unittest_py_data_API." + file_ending,
io.Access.create
)

ms = series.iterations[0].meshes
SCALAR = io.Mesh_Record_Component.SCALAR
DS = io.Dataset

extent = [42, 24, 11]

ms["complex64"][SCALAR].reset_dataset(
DS(np.dtype("complex64"), extent))
ms["complex64"][SCALAR].store_chunk(
np.ones(extent, dtype=np.complex64) *
np.complex64(1.234 + 2.345j))
ms["complex128"][SCALAR].reset_dataset(
DS(np.dtype("complex128"), extent))
ms["complex128"][SCALAR].store_chunk(
np.ones(extent, dtype=np.complex128) *
np.complex128(1.234567 + 2.345678j))
if file_ending != "bp":
ms["clongdouble"][SCALAR].reset_dataset(
DS(np.dtype("clongdouble"), extent))
ms["clongdouble"][SCALAR].store_chunk(
np.ones(extent, dtype=np.clongdouble) *
np.clongdouble(1.23456789 + 2.34567890j))

# flush and close file
del series

# read back
series = io.Series(
"unittest_py_data_API." + file_ending,
io.Access.read_only
)

ms = series.iterations[0].meshes
o = [1, 2, 3]
e = [1, 1, 1]

dc64 = ms["complex64"][SCALAR].load_chunk(o, e)
dc128 = ms["complex128"][SCALAR].load_chunk(o, e)
if file_ending != "bp":
dc256 = ms["clongdouble"][SCALAR].load_chunk(o, e)

self.assertTrue(dc64.dtype == np.dtype('complex64'))
self.assertTrue(dc128.dtype == np.dtype('complex128'))
if file_ending != "bp":
self.assertTrue(
dc256.dtype == np.dtype('clongdouble'))

series.flush()

self.assertEqual(dc64,
np.complex64(1.234 + 2.345j))
self.assertEqual(dc128,
np.complex128(1.234567 + 2.345678j))
if file_ending != "bp":
self.assertEqual(dc256,
np.clongdouble(1.23456789 + 2.34567890j))

def testDataRoundTrip(self):
for ext in io.file_extensions:
self.makeDataRoundTrip(ext)

def makeEmptyRoundTrip(self, file_ending):
# write
series = io.Series(
Expand Down

0 comments on commit 8ea0645

Please sign in to comment.