Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dgRMatrix, dgCMatrix and dgTMatrix #9

Merged
merged 15 commits into from
Sep 25, 2018
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ r_packages:
- knitr
- rmarkdown
- pkgKitten
- Matrix
addons:
apt:
sources:
Expand Down
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ Suggests: pkgKitten,
rmarkdown,
covr,
testthat,
inline
inline,
Matrix
VignetteBuilder: knitr
67 changes: 63 additions & 4 deletions inst/include/RcppArrayFireAs.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,35 @@ namespace RcppArrayFire{
template<> struct dtype2cpp<s32>{ typedef int type ; };
template<> struct dtype2cpp<u32>{ typedef unsigned int type ; };

template<af::storage AF_STORAGETYPE> struct af_storage_traits{};
template<> struct af_storage_traits<AF_STORAGE_CSR>{
static constexpr auto row_idx = "p";
static constexpr auto col_idx = "j";
static void check_s4_class(const SEXP &x){
if(!Rf_inherits(x, "dgRMatrix"))
throw std::invalid_argument("Need S4 class dgRMatrix for a typed_array<af::dtype, AF_STORAGE_CSR>");
}
};
template<> struct af_storage_traits<AF_STORAGE_CSC>{
static constexpr auto row_idx = "i";
static constexpr auto col_idx = "p";
static void check_s4_class(const SEXP &x){
if(!Rf_inherits(x, "dgCMatrix"))
throw std::invalid_argument("Need S4 class dgCMatrix for a typed_array<af::dtype, AF_STORAGE_CSC>");
}
};
template<> struct af_storage_traits<AF_STORAGE_COO>{
static constexpr auto row_idx = "i";
static constexpr auto col_idx = "j";
static void check_s4_class(const SEXP &x){
if(!Rf_inherits(x, "dgTMatrix"))
throw std::invalid_argument("Need S4 class dgTMatrix for a typed_array<af::dtype, AF_STORAGE_COO>");
}
};

template<af::dtype AF_DTYPE>
template<
af::dtype AF_DTYPE,
af::storage AF_STORAGETYPE = AF_STORAGE_DENSE>
class typed_array : public af::array{
public:
typed_array() : af::array() {}
Expand Down Expand Up @@ -137,15 +164,15 @@ namespace internal{
namespace traits {

template<af::dtype AF_DTYPE>
class Exporter< ::RcppArrayFire::typed_array<AF_DTYPE> >{
class Exporter< ::RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGE_DENSE> >{
private:
SEXP object ;

public:
Exporter( SEXP x ) : object(x){}
~Exporter(){}

::RcppArrayFire::typed_array<AF_DTYPE> get() {
::RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGE_DENSE> get() {
typedef typename ::RcppArrayFire::dtype2cpp<AF_DTYPE>::type cpp_type ;
//std::vector<cpp_type> buff( Rf_length( object ) );

Expand All @@ -165,10 +192,42 @@ namespace traits {
result = af::array(::Rcpp::as<af::dim4>(dims), buff.data());
}

return ::RcppArrayFire::typed_array<AF_DTYPE>( result );
return ::RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGE_DENSE>( result );
}
};


// Exporter for compressed matrix (dgRMatrix, dgCMatrix, dgTMatrix)
template<af::dtype AF_DTYPE, af::storage AF_STORAGETYPE>
class Exporter< ::RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGETYPE> >{
private:
S4 d_x;
IntegerVector d_dims, d_row, d_col;

public:
Exporter(SEXP x){
::RcppArrayFire::af_storage_traits<AF_STORAGETYPE>::check_s4_class(x);
d_x = x;
d_dims = d_x.slot("Dim");
d_row = d_x.slot(::RcppArrayFire::af_storage_traits<AF_STORAGETYPE>::row_idx);
d_col = d_x.slot(::RcppArrayFire::af_storage_traits<AF_STORAGETYPE>::col_idx);
}
~Exporter(){}

::RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGETYPE> get() {
typedef typename ::RcppArrayFire::dtype2cpp<AF_DTYPE>::type cpp_type ;
::RcppArrayFire::SEXP2CxxPtr<cpp_type> buff(
static_cast<SEXP>(d_x.slot("x")) ) ;

af::array result;
result = af::sparse(
d_dims[0], d_dims[1], buff.size(),
buff.data(), d_row.begin(), d_col.begin(),
AF_DTYPE, AF_STORAGETYPE);

return ::RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGETYPE>( result );
}
};
}
}

Expand Down
8 changes: 6 additions & 2 deletions inst/include/RcppArrayFireForward.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

/* forward declarations */
namespace RcppArrayFire{
template<af::dtype AF_DTYPE> class typed_array;
template<
af::dtype AF_DTYPE,
af::storage AF_STORAGETYPE> class typed_array;
}

namespace Rcpp {
Expand All @@ -43,7 +45,9 @@ namespace Rcpp {

namespace traits {
/* support for as */
template<af::dtype AF_DTYPE> class Exporter<RcppArrayFire::typed_array<AF_DTYPE>>;
template<
af::dtype AF_DTYPE,
af::storage AF_STORAGETYPE> class Exporter<RcppArrayFire::typed_array<AF_DTYPE, AF_STORAGETYPE>>;
}
}

Expand Down
62 changes: 60 additions & 2 deletions inst/include/RcppArrayFireWrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,69 @@ namespace RcppArrayFire{
return ::Rcpp::wrap_extra_steps<T>(x) ;
}

template<typename T> SEXP wrap_array( const af::array& object ){

template<typename T> SEXP wrap_dense_array( const af::array& object ){
return wrap_array_dispatch<T>(object, typename ::Rcpp::traits::r_sexptype_needscast<T>());
}

SEXP af_wrap( const af::array& object ) ;
template<typename T> SEXP wrap_sparse_array( const af::array& object, const af::storage storage_type ){
const int RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype;

std::string major;
switch ( storage_type ) {
case AF_STORAGE_CSR: major = "R"; break;
case AF_STORAGE_CSC: major = "C"; break;
case AF_STORAGE_COO: major = "T"; break;
}

std::string klass;
switch( RTYPE ){
case REALSXP:
klass = std::string("dg") + major + "Matrix";
break;

case LGLSXP:
klass = std::string("lg") + major + "Matrix";
break;

default:
throw std::invalid_argument( "RTYPE not matched in conversion to sparse matrix" );
break;
}

::Rcpp::S4 s(klass);
switch ( storage_type ) {
case AF_STORAGE_CSR:
s.slot("p") = wrap_dense_array<int>( af::sparseGetRowIdx( object ) );
s.slot("j") = wrap_dense_array<int>( af::sparseGetColIdx( object ) );
break;

case AF_STORAGE_CSC:
s.slot("i") = wrap_dense_array<int>( af::sparseGetRowIdx( object ) );
s.slot("p") = wrap_dense_array<int>( af::sparseGetColIdx( object ) );
break;

case AF_STORAGE_COO:
s.slot("i") = wrap_dense_array<int>( af::sparseGetRowIdx( object ) );
s.slot("j") = wrap_dense_array<int>( af::sparseGetColIdx( object ) );
break;
}
s.slot("x") = wrap_dense_array<T>( af::sparseGetValues( object ) );
s.slot("Dim") = ::Rcpp::IntegerVector::create( object.dims(0), object.dims(1) );

return s;
}

template<typename T> SEXP wrap_array( const af::array& object ){
if ( object.issparse() ) {
return wrap_sparse_array<T>( object, af::sparseGetStorage( object ) );
}
else { // if dense array
return wrap_dense_array<T>( object );
}
}

SEXP wrap_af_impl( const af::array& object ) ;
} /* namespace RcppArrayFire */

#endif
72 changes: 24 additions & 48 deletions src/RcppArrayFireWrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,57 +23,33 @@
#include <RcppArrayFireWrap.h>

namespace RcppArrayFire{
SEXP af_wrap( const af::array& object ){
SEXP wrap_af_impl( const af::array& object ){
::Rcpp::RObject x;

switch(object.type()){
case f64:
x = wrap_array<double>( object ) ;
break;
case c64:
x = wrap_array<std::complex<double>>( object ) ;
break;
case f32:
x = wrap_array<float>( object ) ;
break;
case c32:
x = wrap_array<std::complex<float>>( object ) ;
break;
case s32:
x = wrap_array<int>( object ) ;
break;
case u32:
x = wrap_array<unsigned int>( object ) ;
break;
default:
Rcpp::stop("Unsopprted data type");
case f64:
x = wrap_array<double>( object ) ;
break;
case c64:
x = wrap_array<std::complex<double>>( object ) ;
break;
case f32:
x = wrap_array<float>( object ) ;
break;
case c32:
x = wrap_array<std::complex<float>>( object ) ;
break;
case s32:
x = wrap_array<int>( object ) ;
break;
case u32:
x = wrap_array<unsigned int>( object ) ;
break;
default:
Rcpp::stop("Unsopprted data type");
break;
}
//NOTE:there is no af::sparse() in the current open source version of arrayfire
//if(object.issparse() == true){
// const int RTYPE = Rcpp::traits::r_sexptype_traits<T>::rtype;

// IntegerVector dim = IntegerVector::create( object.dims(0), object.dims(1) );

// // copy the data into R objects
// Vector<RTYPE> x(object.device<T>(), object.device<T>() + object.nonzeros() ) ;
// IntegerVector i(/*begin of the row indices of object*/, /*end*/);
// IntegerVector p(/*begin of the col indices of object*/, /*end*/);

// std::string klass ;
// switch( RTYPE ){
// case REALSXP: klass = "dgCMatrix" ; break ;
// // case INTSXP : klass = "igCMatrix" ; break ; class not exported
// case LGLSXP : klass = "lgCMatrix" ; break ;
// default:
// throw std::invalid_argument( "RTYPE not matched in conversion to sparse matrix" ) ;
// }
// S4 s(klass);
// s.slot("i") = i;
// s.slot("p") = p;
// s.slot("x") = x;
// s.slot("Dim") = dim;
// return s;
//}
return x;
}
} /* namespace RcppArrayFire */
Expand All @@ -88,8 +64,8 @@ namespace Rcpp{
}

template <> SEXP wrap (const af::array& data ){
::Rcpp::RObject x = ::RcppArrayFire::af_wrap( data );
if (data.numdims() > 1)
::Rcpp::RObject x = ::RcppArrayFire::wrap_af_impl( data );
if (data.issparse() == false && data.numdims() > 1)
x.attr("dim") = wrap(data.dims());
return x;
}
Expand Down
60 changes: 60 additions & 0 deletions tests/testthat/test-exporter-wrap-sparse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
context("export and wrap sparse array")

library(Matrix)
library(Rcpp)

create.src <- function(dtype, storage){
src.template <- '
af::array asis%s_%s(const RcppArrayFire::typed_array<%s, AF_STORAGE_%s>& x) {
return x;
}'
fill.template <- function(d, s){
sprintf(src.template, s, d, d, s)
}

return(outer(dtype, storage, FUN = fill.template))
}

# NOTE: Matrix package <= 1.2-14 does not support integer and complex matrices
# NOTE: Arrayfire v3.6.1 only supports f64/f32/c64/c32
# TODO: Add some tests for other data types(complex, bool, etc.) when they are supported.
for (src in create.src(c('f64', 'f32'), c('CSR', 'CSC', 'COO'))){
Rcpp::cppFunction(code = src, depends = "RcppArrayFire")
}

test_that("export and wrap CSR array", {
x <- as(matrix(c(1, 0, 0, 2, 3,
0, 0, 1, 0, 2), 2, 5),
'dgRMatrix')
invalid.x <- new('dgCMatrix')

expect_equal(x, asisCSR_f64(x))
expect_equal(x, asisCSR_f32(x))
expect_error(asisCSR_f64(invalid.x))
})


test_that("export and wrap CSC array", {
x <- as(matrix(c(1, 0, 0, 2, 3,
0, 0, 1, 0, 2), 2, 5),
'dgCMatrix')
invalid.x <- new('dgRMatrix')

expect_equal(x, asisCSC_f64(x))
expect_equal(x, asisCSC_f32(x))
expect_error(asisCSC_f64(invalid.x))
})


test_that("export and wrap COO array", {
x <- as(matrix(c(1, 0, 0, 2, 3,
0, 0, 1, 0, 2), 2, 5),
'dgTMatrix')
invalid.x <- new('dgCMatrix')

expect_equal(x, asisCOO_f64(x))
expect_equal(x, asisCOO_f32(x))
expect_error(asisCOO_f64(invalid.x))

})