Skip to content

Commit

Permalink
Merge pull request open-mpi#1291 from vspetrov/hcoll_derived_datatypes
Browse files Browse the repository at this point in the history
coll/hcoll mpi datatypes support
  • Loading branch information
rhc54 authored Aug 2, 2016
2 parents 640bcf6 + 58473c5 commit 0d6ddc5
Show file tree
Hide file tree
Showing 6 changed files with 375 additions and 190 deletions.
8 changes: 7 additions & 1 deletion ompi/mca/coll/hcoll/coll_hcoll.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ typedef struct mca_coll_hcoll_ops_t {
int (*hcoll_barrier)(void *);
} mca_coll_hcoll_ops_t;

/*
 * Free-list element that carries one hcoll datatype representation.
 * Items of this class populate mca_coll_hcoll_component.dtypes.
 */
typedef struct {
opal_free_list_item_t super; /* free-list linkage; must stay the first member so the item can be handed to the free list as &item->super */
dte_data_representation_t type; /* hcoll representation held by this item */
} mca_coll_hcoll_dtype_t;
OBJ_CLASS_DECLARATION(mca_coll_hcoll_dtype_t);

struct mca_coll_hcoll_component_t {
/** Base coll component */
Expand Down Expand Up @@ -80,8 +85,9 @@ struct mca_coll_hcoll_component_t {

/* FCA global stuff */
mca_coll_hcoll_ops_t hcoll_ops;

ompi_free_list_t requests;
opal_free_list_t dtypes;
int derived_types_support_enabled;
};
typedef struct mca_coll_hcoll_component_t mca_coll_hcoll_component_t;

Expand Down
14 changes: 11 additions & 3 deletions ompi/mca/coll/hcoll/coll_hcoll_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "coll_hcoll.h"
#include "opal/mca/installdirs/installdirs.h"
#include "coll_hcoll_dtypes.h"

/*
* Public string showing the coll ompi_hcol component version number
Expand Down Expand Up @@ -205,8 +206,15 @@ static int hcoll_register(void)
1,
&mca_coll_hcoll_component.hcoll_datatype_fallback,
0));


#if HCOLL_API >= HCOLL_VERSION(3,6)
CHECK(reg_int("dts",NULL,
"[1|0|] Enable/Disable derived types support",
1,
&mca_coll_hcoll_component.derived_types_support_enabled,
0));
#else
mca_coll_hcoll_component.derived_types_support_enabled = 0;
#endif
return ret;
}

Expand Down Expand Up @@ -258,7 +266,7 @@ static int hcoll_close(void)

HCOL_VERBOSE(5,"HCOLL FINALIZE");
rc = hcoll_finalize();

OBJ_DESTRUCT(&cm->dtypes);
opal_progress_unregister(mca_coll_hcoll_progress);
if (HCOLL_SUCCESS != rc){
HCOL_VERBOSE(1,"Hcol library finalize failed");
Expand Down
134 changes: 118 additions & 16 deletions ompi/mca/coll/hcoll/coll_hcoll_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
It is used to extract allreduce bcol functions where the arithmetic has to be done*/

#include "ompi/datatype/ompi_datatype.h"
#include "ompi/datatype/ompi_datatype_internal.h"
#include "ompi/mca/op/op.h"
#include "hcoll/api/hcoll_dte.h"
extern int hcoll_type_attr_keyval;

/*to keep this at hand: Ids of the basic opal_datatypes:
#define OPAL_DATATYPE_INT1 4
Expand All @@ -31,9 +33,7 @@
total 15 types
*/



static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX_PREDEFINED] = {
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OMPI_DATATYPE_MAX_PREDEFINED] = {
&DTE_ZERO, /*OPAL_DATATYPE_LOOP 0 */
&DTE_ZERO, /*OPAL_DATATYPE_END_LOOP 1 */
&DTE_ZERO, /*OPAL_DATATYPE_LB 2 */
Expand All @@ -53,34 +53,113 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
&DTE_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */
&DTE_FLOAT96, /*OPAL_DATATYPE_FLOAT12 17 */
&DTE_FLOAT128, /*OPAL_DATATYPE_FLOAT16 18 */
#if defined(DTE_FLOAT32_COMPLEX) && defined(DTE_FLOAT64_COMPLEX)
#if defined(DTE_FLOAT32_COMPLEX)
&DTE_FLOAT32_COMPLEX, /*OPAL_DATATYPE_COMPLEX8 19 */
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */
#else
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX8 19 */
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX16 20 */
&DTE_ZERO,
#endif
#if defined(DTE_FLOAT64_COMPLEX)
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX32 20 */
#else
&DTE_ZERO,
#endif
#if defined(DTE_FLOAT128_COMPLEX)
&DTE_FLOAT128_COMPLEX, /*OPAL_DATATYPE_COMPLEX64 21 */
#else
&DTE_ZERO,
#endif
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX32 21 */
&DTE_ZERO, /*OPAL_DATATYPE_BOOL 22 */
&DTE_ZERO, /*OPAL_DATATYPE_WCHAR 23 */
&DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */
};

static dte_data_representation_t ompi_dtype_2_dte_dtype(ompi_datatype_t *dtype){
/* Lookup modes passed as the `mode` argument of ompi_dtype_2_hcoll_dtype(). */
enum {
TRY_FIND_DERIVED, /* also consult the derived-type mappings when the basic table has no entry */
NO_DERIVED /* NOTE(review): presumably restricts the lookup to the predefined table - confirm against callers */
};


#if HCOLL_API >= HCOLL_VERSION(3,6)
/*
 * Ask hcoll to build a representation for a derived OMPI datatype.
 *
 * dtype   - datatype to map; when dtype->args is NULL the type is treated
 *           as predefined and nothing is created.
 * new_dte - receives the representation produced by hcoll_create_mpi_type().
 *
 * Returns OMPI_SUCCESS when nothing had to be done or hcoll reported
 * HCOLL_SUCCESS, OMPI_ERROR otherwise.
 */
static inline
int hcoll_map_derived_type(ompi_datatype_t *dtype, dte_data_representation_t *new_dte)
{
    if (NULL != dtype->args) {
        const int hcoll_rc = hcoll_create_mpi_type((void*)dtype, new_dte);

        if (HCOLL_SUCCESS != hcoll_rc) {
            return OMPI_ERROR;
        }
    }

    /* Either the type was created, or it is predefined and this helper
     * should not have been called for it in the first place. */
    return OMPI_SUCCESS;
}

/*
 * Return the hcoll representation associated with a derived OMPI datatype.
 *
 * When derived-type support is enabled, the datatype's attribute hash is
 * searched for hcoll_type_attr_keyval. If an entry exists, its cached
 * representation is returned; otherwise hcoll_map_derived_type() is asked
 * to create one.
 *
 * dtype - datatype to look up.
 *
 * Returns the mapped representation, or DTE_ZERO when support is disabled,
 * the attribute lookup fails, or the mapping cannot be created.
 */
static dte_data_representation_t find_derived_mapping(ompi_datatype_t *dtype){
    dte_data_representation_t dte = DTE_ZERO;
    /* Fetch the attribute through a real void* instead of casting the
     * address of a typed pointer to void**, which violates aliasing rules. */
    void *attr_val = NULL;
    int map_found = 0;

    if (!mca_coll_hcoll_component.derived_types_support_enabled) {
        return dte;
    }

    if (OMPI_SUCCESS != ompi_attr_get_c(dtype->d_keyhash, hcoll_type_attr_keyval,
                                        &attr_val, &map_found)) {
        /* A failed lookup is handled like a missing attribute. */
        map_found = 0;
    }

    if (map_found) {
        if (NULL != attr_val) {
            dte = ((mca_coll_hcoll_dtype_t *) attr_val)->type;
        }
    } else if (OMPI_SUCCESS != hcoll_map_derived_type(dtype, &dte)) {
        /* Never hand a possibly half-written representation to the caller. */
        dte = DTE_ZERO;
    }

    return dte;
}



/*
 * Translate the id of a predefined pair-style OMPI datatype into the
 * matching hcoll representation.
 *
 * ompi_id - OMPI datatype id (one of the OMPI_DATATYPE_MPI_* constants).
 *
 * Returns the corresponding DTE_* value, or DTE_ZERO for any id that is
 * not listed below.
 */
static inline dte_data_representation_t
ompi_predefined_derived_2_hcoll(int ompi_id) {
    dte_data_representation_t mapped = DTE_ZERO;

    switch (ompi_id) {
    case OMPI_DATATYPE_MPI_FLOAT_INT:
        mapped = DTE_FLOAT_INT;
        break;
    case OMPI_DATATYPE_MPI_DOUBLE_INT:
        mapped = DTE_DOUBLE_INT;
        break;
    case OMPI_DATATYPE_MPI_LONG_INT:
        mapped = DTE_LONG_INT;
        break;
    case OMPI_DATATYPE_MPI_SHORT_INT:
        mapped = DTE_SHORT_INT;
        break;
    case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT:
        mapped = DTE_LONG_DOUBLE_INT;
        break;
    case OMPI_DATATYPE_MPI_2INT:
        mapped = DTE_2INT;
        break;
    default:
        /* no hcoll counterpart: keep DTE_ZERO */
        break;
    }

    return mapped;
}
#endif

/*
 * Map an OMPI datatype to an hcoll dte_data_representation_t.
 *
 * dtype - datatype to translate; dtype->id and dtype->super.id are read.
 * mode  - TRY_FIND_DERIVED or NO_DERIVED.
 *
 * NOTE(review): this listing comes from a diff rendered without +/- markers,
 * so lines of the previous and the new implementation appear interleaved
 * (two declarations of dte_data_rep, unbalanced braces, two trailing
 * returns). Confirm the exact control flow against the real file before
 * relying on the notes below.
 */
static dte_data_representation_t
ompi_dtype_2_hcoll_dtype( ompi_datatype_t *dtype,
const int mode)
{
int ompi_type_id = dtype->id;
int opal_type_id = dtype->super.id;
dte_data_representation_t dte_data_rep;
if (!(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) {
ompi_type_id = -1;
dte_data_representation_t dte_data_rep = DTE_ZERO;

if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED) {
/* Basic opal ids are resolved through the static lookup table. */
if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) {
dte_data_rep = *ompi_datatype_2_dte_data_rep[opal_type_id];
}
#if HCOLL_API >= HCOLL_VERSION(3,6)
/* Otherwise, in TRY_FIND_DERIVED mode, try the predefined pair-style ids. */
else if (TRY_FIND_DERIVED == mode){
dte_data_rep = ompi_predefined_derived_2_hcoll(ompi_type_id);
}
} else {
/* Id beyond the predefined range: look up or create a derived mapping. */
if (TRY_FIND_DERIVED == mode)
dte_data_rep = find_derived_mapping(dtype);
#endif
}
if (OPAL_UNLIKELY( ompi_type_id < 0 ||
ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED)){
/* Nothing mapped, derived lookup requested, and datatype fallback disabled:
 * build a representation whose pointer_to_handle is the address of the
 * underlying opal datatype. */
if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode &&
!mca_coll_hcoll_component.hcoll_datatype_fallback) {
dte_data_rep = DTE_ZERO;
dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0;
dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t ) &dtype->super;
return dte_data_rep;
}
return *ompi_datatype_2_dte_data_rep[opal_type_id];
return dte_data_rep;
}

static hcoll_dte_op_t* ompi_op_2_hcoll_op[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = {
Expand Down Expand Up @@ -108,4 +187,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
return ompi_op_2_hcoll_op[op->o_f_to_c_index];
}


#if HCOLL_API >= HCOLL_VERSION(3,6)
/*
 * Delete callback for the hcoll datatype attribute.
 *
 * type, keyval, extra - unused here.
 * attr_val            - the mca_coll_hcoll_dtype_t stored as the attribute.
 *
 * Destroys the hcoll representation held by the item and, on success,
 * returns the item to mca_coll_hcoll_component.dtypes.
 *
 * Returns OMPI_SUCCESS, or OMPI_ERROR when hcoll_dt_destroy() fails (the
 * item is then not returned to the free list, as before).
 */
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
    mca_coll_hcoll_dtype_t *dtype = (mca_coll_hcoll_dtype_t*) attr_val;
    int hcoll_rc;

    assert(dtype);

    hcoll_rc = hcoll_dt_destroy(dtype->type);
    if (HCOLL_SUCCESS != hcoll_rc) {
        /* Name the function that was actually called so the log is searchable. */
        HCOL_ERROR("failed to delete type attr: hcoll_dt_destroy returned %d",hcoll_rc);
        return OMPI_ERROR;
    }

    OPAL_FREE_LIST_RETURN(&mca_coll_hcoll_component.dtypes,
                          &dtype->super);

    return OMPI_SUCCESS;
}
#else
/*
 * Delete callback for the hcoll datatype attribute, variant compiled when
 * HCOLL_API is older than HCOLL_VERSION(3,6). All arguments are ignored and
 * OMPI_SUCCESS is always returned.
 */
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
/*Do nothing - it's an old version of hcoll w/o dtypes support */
return OMPI_SUCCESS;
}
#endif
#endif /* COLL_HCOLL_DTYPES_H */
24 changes: 24 additions & 0 deletions ompi/mca/coll/hcoll/coll_hcoll_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

#include "ompi_config.h"
#include "coll_hcoll.h"
#include "coll_hcoll_dtypes.h"

int hcoll_comm_attr_keyval;
int hcoll_type_attr_keyval;

/*
* Initial query function that is invoked during MPI_INIT, allowing
Expand Down Expand Up @@ -211,6 +213,10 @@ int mca_coll_hcoll_progress(void)
}


/*
 * Class instance for mca_coll_hcoll_dtype_t, with opal_free_list_item_t
 * given as its parent class.
 * NOTE(review): the two NULL arguments are presumably the constructor and
 * destructor slots (none needed) - confirm against the OPAL object macros.
 */
OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t,
opal_free_list_item_t,
NULL,NULL);

/*
* Invoked when there's a new communicator that has been created.
* Look at the communicator and decide which set of functions and
Expand Down Expand Up @@ -288,6 +294,24 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
HCOL_ERROR("Hcol comm keyval create failed");
return NULL;
}

if (mca_coll_hcoll_component.derived_types_support_enabled) {
copy_fn.attr_datatype_copy_fn = (MPI_Type_internal_copy_attr_function *) MPI_TYPE_NULL_COPY_FN;
del_fn.attr_datatype_delete_fn = hcoll_type_attr_del_fn;
err = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, &hcoll_type_attr_keyval, NULL ,0, NULL);
if (OMPI_SUCCESS != err) {
cm->hcoll_enable = 0;
hcoll_finalize();
opal_progress_unregister(mca_coll_hcoll_progress);
HCOL_ERROR("Hcol type keyval create failed");
return NULL;
}
}
OBJ_CONSTRUCT(&cm->dtypes, opal_free_list_t);
opal_free_list_init(&cm->dtypes, sizeof(mca_coll_hcoll_dtype_t),
OBJ_CLASS(mca_coll_hcoll_dtype_t),
32, -1, 32);

}

hcoll_module = OBJ_NEW(mca_coll_hcoll_module_t);
Expand Down
Loading

0 comments on commit 0d6ddc5

Please sign in to comment.