diff --git a/ompi/mca/coll/acoll/coll_acoll.h b/ompi/mca/coll/acoll/coll_acoll.h index aaf636b1bae..91f9a2475fa 100644 --- a/ompi/mca/coll/acoll/coll_acoll.h +++ b/ompi/mca/coll/acoll/coll_acoll.h @@ -22,6 +22,7 @@ #ifdef HAVE_XPMEM_H #include "opal/mca/rcache/base/base.h" +#include "opal/class/opal_hash_table.h" #include #endif @@ -37,12 +38,14 @@ extern int mca_coll_acoll_max_comms; extern int mca_coll_acoll_sg_size; extern int mca_coll_acoll_sg_scale; extern int mca_coll_acoll_node_size; +extern int mca_coll_acoll_force_numa; extern int mca_coll_acoll_use_dynamic_rules; extern int mca_coll_acoll_mnode_enable; extern int mca_coll_acoll_bcast_lin0; extern int mca_coll_acoll_bcast_lin1; extern int mca_coll_acoll_bcast_lin2; extern int mca_coll_acoll_bcast_nonsg; +extern int mca_coll_acoll_bcast_socket; extern int mca_coll_acoll_allgather_lin; extern int mca_coll_acoll_allgather_ring_1; @@ -123,6 +126,7 @@ typedef struct coll_acoll_data { void **xpmem_raddr; mca_rcache_base_module_t **rcache; void *scratch; + opal_hash_table_t **xpmem_reg_tracker_ht; #endif opal_shmem_ds_t *allshmseg_id; void **allshmmmap_sbuf; @@ -160,7 +164,7 @@ typedef struct coll_acoll_subcomms { int numa_root; int socket_ldr_root; int base_root[MCA_COLL_ACOLL_NUM_BASE_LYRS][MCA_COLL_ACOLL_NUM_LAYERS]; - int base_rank[MCA_COLL_ACOLL_NUM_BASE_LYRS]; + int base_rank[MCA_COLL_ACOLL_NUM_BASE_LYRS][MCA_COLL_ACOLL_NUM_LAYERS]; int socket_rank; int subgrp_size; int initialized; @@ -198,12 +202,14 @@ struct mca_coll_acoll_module_t { int log2_sg_cnt; int node_cnt; int log2_node_cnt; + int force_numa; int use_dyn_rules; // Todo: Use substructure for every API related ones int use_mnode; int use_lin0; int use_lin1; int use_lin2; + int use_socket; int mnode_sg_size; int mnode_log2_sg_size; int allg_lin; diff --git a/ompi/mca/coll/acoll/coll_acoll_allgather.c b/ompi/mca/coll/acoll/coll_acoll_allgather.c index 5e4db719277..3fc2167193f 100644 --- a/ompi/mca/coll/acoll/coll_acoll_allgather.c +++ b/ompi/mca/coll/acoll/coll_acoll_allgather.c @@ -344,7 +344,7 @@ static inline int mca_coll_acoll_allgather_intra(const void *sbuf, size_t scount } /* Return if all ranks belong to single subgroup */ - if (num_sgs == 1) { + if (1 == num_sgs) { /* All done */ return err; } @@ -396,7 +396,7 @@ static inline int mca_coll_acoll_allgather_intra(const void *sbuf, size_t scount } /* Now all base ranks have the full data */ /* Do broadcast within subgroups from the base ranks for the extra data */ - if (sg_id == 0) { + if (0 == sg_id) { num_data_blks = 1; data_blk_size[0] = bcount * (num_sgs - 2) + last_subgrp_rcnt; blk_ofst[0] = bcount; @@ -527,7 +527,7 @@ int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_dataty if (num_nodes > 1) { assert(subc->local_r_comm != NULL); } - intra_comm = num_nodes == 1 ? comm : subc->local_r_comm; + intra_comm = 1 == num_nodes ? comm : subc->local_r_comm; } err = mca_coll_acoll_allgather_intra(sbuf, scount, sdtype, local_rbuf, rcount, rdtype, intra_comm, module); @@ -536,7 +536,7 @@ int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_dataty } /* Return if intra-node communicator */ - if ((num_nodes == 1) || (size <= 2)) { + if ((1 == num_nodes) || (size <= 2)) { /* All done */ return err; } @@ -592,7 +592,7 @@ int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_dataty } /* End of if inter leader */ /* Do intra node broadcast */ - if (node_id == 0) { + if (0 == node_id) { num_data_blks = 1; data_blk_size[0] = bcount * (num_nodes - 2) + last_subgrp_rcnt; blk_ofst[0] = bcount; @@ -613,7 +613,7 @@ int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_dataty /* Loop over data blocks */ for (i = 0; i < num_data_blks; i++) { char *buff = (char *) rbuf + (ptrdiff_t) blk_ofst[i] * rext; - err = (comm)->c_coll->coll_bcast(buff, data_blk_size[i], rdtype, 0, subc->local_r_comm, + err = ompi_coll_base_bcast_intra_basic_linear(buff, data_blk_size[i], rdtype, 0, subc->local_r_comm, module); if (MPI_SUCCESS != err) { return err; diff --git a/ompi/mca/coll/acoll/coll_acoll_allreduce.c b/ompi/mca/coll/acoll/coll_acoll_allreduce.c index 46c5554810c..79ef9c4807e 100644 --- a/ompi/mca/coll/acoll/coll_acoll_allreduce.c +++ b/ompi/mca/coll/acoll/coll_acoll_allreduce.c @@ -59,7 +59,7 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si int size; size_t total_dsize, dsize; - coll_acoll_init(module, comm, subc->data, subc); + coll_acoll_init(module, comm, subc->data, subc, 0); coll_acoll_data_t *data = subc->data; if (NULL == data) { return -1; @@ -82,7 +82,7 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si if (!subc->xpmem_use_sr_buf) { tmp_rbuf = (char *) data->scratch; tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; - if ((sbuf == MPI_IN_PLACE)) { + if ((MPI_IN_PLACE == sbuf)) { memcpy(tmp_sbuf, rbuf, total_dsize); } else { memcpy(tmp_sbuf, sbuf, total_dsize); @@ -90,7 +90,7 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si } else { tmp_sbuf = (char *) sbuf; tmp_rbuf = (char *) rbuf; - if (sbuf == MPI_IN_PLACE) { + if (MPI_IN_PLACE == sbuf) { tmp_sbuf = (char *) rbuf; } } @@ -153,7 +153,7 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si my_count_size = (l2_local_rank == (local_size - 1)) ? chunk + (count % local_size) : chunk; - if (l2_local_rank == 0) { + if (0 == l2_local_rank) { for (int i = 1; i < local_size; i++) { ompi_op_reduce(op, (char *) data->xpmem_raddr[l2_gp[i]], (char *) tmp_rbuf, my_count_size, dtype); @@ -192,7 +192,7 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf, int size; size_t total_dsize, dsize; - coll_acoll_init(module, comm, subc->data, subc); + coll_acoll_init(module, comm, subc->data, subc, 0); coll_acoll_data_t *data = subc->data; if (NULL == data) { return -1; @@ -207,7 +207,7 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf, if (!subc->xpmem_use_sr_buf) { tmp_rbuf = (char *) data->scratch; tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; - if ((sbuf == MPI_IN_PLACE)) { + if ((MPI_IN_PLACE == sbuf)) { memcpy(tmp_sbuf, rbuf, total_dsize); } else { memcpy(tmp_sbuf, sbuf, total_dsize); @@ -215,7 +215,7 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf, } else { tmp_sbuf = (char *) sbuf; tmp_rbuf = (char *) rbuf; - if (sbuf == MPI_IN_PLACE) { + if (MPI_IN_PLACE == sbuf) { tmp_sbuf = (char *) rbuf; } } @@ -242,7 +242,7 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf, size_t chunk = count / size; size_t my_count_size = (rank == (size - 1)) ? (count / size) + count % size : count / size; - if (rank == 0) { + if (0 == rank) { if (sbuf != MPI_IN_PLACE) memcpy(tmp_rbuf, sbuf, my_count_size * dsize); } else { @@ -299,7 +299,7 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp opal_atomic_wmb(); int val; - if (up == 1) { + if (1 == up) { val = data->sync[0]; } else { val = data->sync[1]; @@ -346,7 +346,7 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp __ATOMIC_RELAXED); } } - if (up == 1) { + if (1 == up) { data->sync[0] = val; } else { data->sync[1] = val; @@ -361,8 +361,7 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c { size_t dsize; int err = MPI_SUCCESS; - - coll_acoll_init(module, comm, subc->data, subc); + coll_acoll_init(module, comm, subc->data, subc, 0); coll_acoll_data_t *data = subc->data; if (NULL == data) { return -1; @@ -434,7 +433,7 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c } if (intra && (ompi_comm_size(subc->numa_comm) > 1)) { - err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, subc->numa_comm, module); + err = ompi_coll_base_bcast_intra_basic_linear(rbuf, count, dtype, 0, subc->numa_comm, module); } return err; } @@ -451,7 +450,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count, ompi_datatype_type_size(dtype, &dsize); total_dsize = dsize * count; - if (size == 1) { + if (1 == size) { if (MPI_IN_PLACE != sbuf) { memcpy((char *) rbuf, sbuf, total_dsize); } @@ -483,7 +482,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count, alg = coll_allreduce_decision_fixed(size, total_dsize); - if (num_nodes == 1) { + if (1 == num_nodes) { if (total_dsize < 32) { return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm, module); @@ -494,10 +493,10 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count, return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm, module); } else if (total_dsize < 65536) { - if (alg == 1) { + if (1 == alg) { return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm, module); - } else if (alg == 2) { + } else if (2 == alg) { return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, module); } else { /*alg == 3 */ diff --git a/ompi/mca/coll/acoll/coll_acoll_barrier.c b/ompi/mca/coll/acoll/coll_acoll_barrier.c index 8272136ad25..d57db48b91f 100644 --- a/ompi/mca/coll/acoll/coll_acoll_barrier.c +++ b/ompi/mca/coll/acoll/coll_acoll_barrier.c @@ -141,7 +141,7 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base } size = ompi_comm_size(comm); - if (size == 1) { + if (1 == size) { return err; } if (!subc->initialized && size > 1) { diff --git a/ompi/mca/coll/acoll/coll_acoll_bcast.c b/ompi/mca/coll/acoll/coll_acoll_bcast.c index 22103317d22..9a9d447a167 100644 --- a/ompi/mca/coll/acoll/coll_acoll_bcast.c +++ b/ompi/mca/coll/acoll/coll_acoll_bcast.c @@ -126,8 +126,9 @@ static int bcast_flat_tree(void *buff, size_t count, struct ompi_datatype_t *dat *lin_2 = l2; static inline void coll_bcast_decision_fixed(int size, size_t total_dsize, int node_size, - int *sg_cnt, int *use_0, int *use_numa, int *lin_0, - int *lin_1, int *lin_2, + int *sg_cnt, int *use_0, int *use_numa, + int *use_socket, int *lin_0, + int *lin_1, int *lin_2, int num_nodes, mca_coll_acoll_module_t *acoll_module, coll_acoll_subcomms_t *subc) { @@ -135,8 +136,13 @@ static inline void coll_bcast_decision_fixed(int size, size_t total_dsize, int n *use_0 = 0; *lin_0 = 0; *use_numa = 0; + *use_socket = 0; if (size <= node_size) { - if (size <= sg_size) { + if (acoll_module->use_dyn_rules) { + *sg_cnt = (acoll_module->mnode_sg_size == acoll_module->sg_cnt) ? acoll_module->sg_cnt : node_size; + *use_0 = 0; + SET_BCAST_PARAMS(acoll_module->use_lin0, acoll_module->use_lin1, acoll_module->use_lin2) + } else if (size <= sg_size) { *sg_cnt = sg_size; if (total_dsize <= 8192) { SET_BCAST_PARAMS(0, 0, 0) @@ -223,100 +229,109 @@ static inline void coll_bcast_decision_fixed(int size, size_t total_dsize, int n } } else { if (acoll_module->use_dyn_rules) { - *sg_cnt = acoll_module->mnode_sg_size; + *sg_cnt = (acoll_module->mnode_sg_size == acoll_module->sg_cnt) ? acoll_module->sg_cnt : node_size; *use_0 = acoll_module->use_mnode; SET_BCAST_PARAMS(acoll_module->use_lin0, acoll_module->use_lin1, acoll_module->use_lin2) } else { - int derived_node_size = subc->derived_node_size; *use_0 = 1; - if (size <= (derived_node_size << 2)) { - size_t dsize_thresh[2][3] = {{512, 8192, 131072}, {128, 8192, 65536}}; - int thr_ind = (size <= (derived_node_size << 1)) ? 0 : 1; - if (total_dsize <= dsize_thresh[thr_ind][0]) { - *sg_cnt = node_size; - SET_BCAST_PARAMS(0, 0, 0) - } else if (total_dsize <= dsize_thresh[thr_ind][1]) { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 0, 0) - } else if (total_dsize <= dsize_thresh[thr_ind][2]) { - *sg_cnt = sg_size; + *sg_cnt = sg_size; + if (2 == num_nodes) { + SET_BCAST_PARAMS(1, 1, 1) + *use_socket = 1; + *use_numa = (total_dsize <= 2097152) ? 0 : 1; + } else if (num_nodes <= 4) { + if (total_dsize <= 512) { + *use_socket = 1; + SET_BCAST_PARAMS(1, 1, 0) + } else if (total_dsize <= 2097152) { + *use_socket = 1; SET_BCAST_PARAMS(1, 1, 1) } else { - *sg_cnt = node_size; + *use_numa = 1; + *use_socket = (total_dsize <= 4194304) ? 0 : 1; SET_BCAST_PARAMS(1, 1, 1) } - } else if (size <= (derived_node_size << 3)) { - if (total_dsize <= 1024) { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 0, 1) - } else if (total_dsize <= 8192) { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(1, 0, 1) - } else if (total_dsize <= 65536) { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(1, 1, 1) - } else if (total_dsize <= 2097152) { - *sg_cnt = node_size; - SET_BCAST_PARAMS(0, 1, 1) + } else if (num_nodes <= 6) { + SET_BCAST_PARAMS(1, 1, 1) + if (total_dsize <= 524288) { + *use_socket = 1; } else { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 0, 0) + *use_numa = 1; } - } else if (size <= (derived_node_size << 4)) { - if (total_dsize <= 64) { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 1, 1) - } else if (total_dsize <= 8192) { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 0, 1) - } else if (total_dsize <= 32768) { - *sg_cnt = sg_size; + } else if (num_nodes <= 8) { + SET_BCAST_PARAMS(1, 1, 1) + if (total_dsize <= 8192) { + *use_numa = 0; + } else { + *use_numa = 1; + } + } else if (num_nodes <= 10) { + *use_numa = 1; + if (total_dsize <= 32768) { + SET_BCAST_PARAMS(1, 1, 0) + } else { SET_BCAST_PARAMS(1, 1, 1) + } + } else { + *use_numa = 1; + if (total_dsize <= 64) { + SET_BCAST_PARAMS(1, 0, 1) } else if (total_dsize <= 2097152) { - *sg_cnt = node_size; - SET_BCAST_PARAMS(0, 1, 1) + SET_BCAST_PARAMS(1, 1, 1) } else { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 0, 0) + *use_socket = 1; + SET_BCAST_PARAMS(0, 1, 1) } - } else { - *sg_cnt = sg_size; - SET_BCAST_PARAMS(0, 0, 0) } } } + if (-1 != acoll_module->force_numa) { + *use_numa = acoll_module->force_numa; + if (acoll_module->force_numa) { + *sg_cnt = sg_size; + } + } + if (-1 != acoll_module->use_socket) { + *use_socket = acoll_module->use_socket; + } } static inline void coll_acoll_bcast_subcomms(struct ompi_communicator_t *comm, coll_acoll_subcomms_t *subc, struct ompi_communicator_t **subcomms, int *subc_roots, int root, int num_nodes, int use_0, int no_sg, - int use_numa) + int use_numa, int use_socket) { + int lyr_id = use_socket ? MCA_COLL_ACOLL_LYR_SOCKET : MCA_COLL_ACOLL_LYR_NODE; /* Node leaders */ if (use_0) { subcomms[MCA_COLL_ACOLL_NODE_L] = subc->leader_comm; subc_roots[MCA_COLL_ACOLL_NODE_L] = subc->outer_grp_root; } + /* Socket leaders */ + if (use_socket) { + subcomms[MCA_COLL_ACOLL_NODE_L] = subc->socket_ldr_comm; + subc_roots[MCA_COLL_ACOLL_NODE_L] = subc->socket_ldr_root; + } /* Intra comm */ - if ((num_nodes > 1) && use_0) { - subc_roots[MCA_COLL_ACOLL_INTRA] = subc->is_root_node - ? subc->local_root[MCA_COLL_ACOLL_LYR_NODE] - : 0; - subcomms[MCA_COLL_ACOLL_INTRA] = subc->local_comm; + if (((num_nodes > 1) && use_0) || use_socket) { + int is_root = use_socket ? subc->is_root_socket : subc->is_root_node; + subc_roots[MCA_COLL_ACOLL_INTRA] = is_root ? subc->local_root[lyr_id] : 0; + subcomms[MCA_COLL_ACOLL_INTRA] = use_socket ? subc->socket_comm : subc->local_comm; } else { subc_roots[MCA_COLL_ACOLL_INTRA] = root; subcomms[MCA_COLL_ACOLL_INTRA] = comm; } /* Base ranks comm */ + int parent = lyr_id; if (no_sg) { subcomms[MCA_COLL_ACOLL_L3_L] = subcomms[MCA_COLL_ACOLL_INTRA]; subc_roots[MCA_COLL_ACOLL_L3_L] = subc_roots[MCA_COLL_ACOLL_INTRA]; } else { subcomms[MCA_COLL_ACOLL_L3_L] = subc->base_comm[MCA_COLL_ACOLL_L3CACHE] - [MCA_COLL_ACOLL_LYR_NODE]; + [parent]; subc_roots[MCA_COLL_ACOLL_L3_L] = subc->base_root[MCA_COLL_ACOLL_L3CACHE] - [MCA_COLL_ACOLL_LYR_NODE]; + [parent]; } /* Subgroup comm */ subcomms[MCA_COLL_ACOLL_LEAF] = subc->subgrp_comm; @@ -325,9 +340,9 @@ static inline void coll_acoll_bcast_subcomms(struct ompi_communicator_t *comm, /* Override with numa when needed */ if (use_numa) { subcomms[MCA_COLL_ACOLL_L3_L] = subc->base_comm[MCA_COLL_ACOLL_NUMA] - [MCA_COLL_ACOLL_LYR_NODE]; + [parent]; subc_roots[MCA_COLL_ACOLL_L3_L] = subc->base_root[MCA_COLL_ACOLL_NUMA] - [MCA_COLL_ACOLL_LYR_NODE]; + [parent]; subcomms[MCA_COLL_ACOLL_LEAF] = subc->numa_comm; subc_roots[MCA_COLL_ACOLL_LEAF] = subc->numa_root; } @@ -338,7 +353,7 @@ static int mca_coll_acoll_bcast_intra_node(void *buff, size_t count, struct ompi coll_acoll_subcomms_t *subc, struct ompi_communicator_t **subcomms, int *subc_roots, int lin_1, int lin_2, int no_sg, int use_numa, - int world_rank) + int use_socket, int world_rank) { int size; int rank; @@ -363,8 +378,9 @@ static int mca_coll_acoll_bcast_intra_node(void *buff, size_t count, struct ompi if (no_sg) { is_base = 1; } else { - int ind = use_numa ? MCA_COLL_ACOLL_NUMA : MCA_COLL_ACOLL_L3CACHE; - is_base = rank == subc->base_rank[ind] ? 1 : 0; + int ind1 = use_numa ? MCA_COLL_ACOLL_NUMA : MCA_COLL_ACOLL_L3CACHE; + int ind2 = use_socket ? MCA_COLL_ACOLL_LYR_SOCKET : MCA_COLL_ACOLL_LYR_NODE; + is_base = rank == subc->base_rank[ind1][ind2] ? 1 : 0; } /* All base ranks receive from root */ @@ -439,7 +455,7 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat int num_nodes; int use_0 = 0; int lin_0 = 0, lin_1 = 0, lin_2 = 0; - int use_numa = 0; + int use_numa = 0, use_socket = 0; int no_sg; size_t total_dsize, dsize; mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; @@ -448,6 +464,12 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat struct ompi_communicator_t *subcomms[MCA_COLL_ACOLL_NUM_SC] = {NULL}; int subc_roots[MCA_COLL_ACOLL_NUM_SC] = {-1}; + /* For small communicators, use linear bcast */ + size = ompi_comm_size(comm); + if (size < 8) { + return ompi_coll_base_bcast_intra_basic_linear(buff, count, datatype, root, comm, module); + } + /* Obtain the subcomms structure */ err = check_and_create_subc(comm, acoll_module, &subc); /* Fallback to knomial if subcomms is not obtained */ @@ -460,7 +482,6 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat && (root != subc->prev_init_root)) { return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4); } - size = ompi_comm_size(comm); if ((!subc->initialized || (root != subc->prev_init_root)) && size > 2) { err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, root); if (MPI_SUCCESS != err) { @@ -482,7 +503,7 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat /* Use knomial for nodes 8 and above and non-large messages */ if ((num_nodes >= 8 && total_dsize <= 65536) - || (num_nodes == 1 && size >= 256 && total_dsize < 16384)) { + || (1 == num_nodes && size >= 256 && total_dsize < 16384)) { return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4); } @@ -490,14 +511,14 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat /* sg_cnt determines subgroup based communication */ /* lin_1 and lin_2 indicate whether to use linear or log based sends/receives across and within subgroups respectively. */ - coll_bcast_decision_fixed(size, total_dsize, node_size, &sg_cnt, &use_0, &use_numa, &lin_0, - &lin_1, &lin_2, acoll_module, subc); + coll_bcast_decision_fixed(size, total_dsize, node_size, &sg_cnt, &use_0, &use_numa, &use_socket, &lin_0, + &lin_1, &lin_2, num_nodes, acoll_module, subc); no_sg = (sg_cnt == node_size) ? 1 : 0; if (size <= 2) no_sg = 1; coll_acoll_bcast_subcomms(comm, subc, subcomms, subc_roots, root, num_nodes, use_0, no_sg, - use_numa); + use_numa, use_socket); reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); if (NULL == reqs) { @@ -507,7 +528,7 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat preq = reqs; err = MPI_SUCCESS; - if (use_0) { + if (use_0 || use_socket) { if (subc_roots[MCA_COLL_ACOLL_NODE_L] != -1) { err = bcast_func[lin_0](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_NODE_L], subcomms[MCA_COLL_ACOLL_NODE_L], preq, &nreqs, rank); @@ -528,7 +549,7 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat } err = mca_coll_acoll_bcast_intra_node(buff, count, datatype, module, subc, subcomms, subc_roots, - lin_1, lin_2, no_sg, use_numa, rank); + lin_1, lin_2, no_sg, use_numa, use_socket, rank); if (MPI_SUCCESS != err) { ompi_coll_base_free_reqs(reqs, nreqs); diff --git a/ompi/mca/coll/acoll/coll_acoll_component.c b/ompi/mca/coll/acoll/coll_acoll_component.c index 8f15b6b265c..6a8651fcf81 100644 --- a/ompi/mca/coll/acoll/coll_acoll_component.c +++ b/ompi/mca/coll/acoll/coll_acoll_component.c @@ -29,12 +29,14 @@ int mca_coll_acoll_max_comms = 10; int mca_coll_acoll_sg_size = 8; int mca_coll_acoll_sg_scale = 1; int mca_coll_acoll_node_size = 128; +int mca_coll_acoll_force_numa = -1; int mca_coll_acoll_use_dynamic_rules = 0; int mca_coll_acoll_mnode_enable = 1; int mca_coll_acoll_bcast_lin0 = 0; int mca_coll_acoll_bcast_lin1 = 0; int mca_coll_acoll_bcast_lin2 = 0; int mca_coll_acoll_bcast_nonsg = 0; +int mca_coll_acoll_bcast_socket = -1; int mca_coll_acoll_allgather_lin = 0; int mca_coll_acoll_allgather_ring_1 = 0; int mca_coll_acoll_reserve_memory_for_algo = 0; @@ -112,6 +114,10 @@ static int acoll_register(void) "Size of node for multinode cases", MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_node_size); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "force_numa", + "Force enable/disable NUMA based comm split", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_force_numa); (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "use_dynamic_rules", @@ -140,6 +146,10 @@ static int acoll_register(void) &mca_coll_acoll_component.collm_version, "bcast_nonsg", "Flag to turn on/off subgroup based algorithms for multinode", MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_nonsg); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_socket", + "Flag to turn on/off socket based algorithms for bcast", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_socket); (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "allgather_lin", "Flag to indicate use of linear allgather for multinode", MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, @@ -230,10 +240,29 @@ static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module) if (NULL != data) { #ifdef HAVE_XPMEM_H for (int j = 0; j < data->comm_size; j++) { + if (ompi_comm_rank(subc->orig_comm) == j) { + continue; + } + // Dereg all rcache regs. + uint64_t key = 0; + uint64_t value = 0; + uint64_t zero_value = 0; + OPAL_HASH_TABLE_FOREACH(key,uint64,value,(data->xpmem_reg_tracker_ht[j])) { + mca_rcache_base_registration_t* reg = + (mca_rcache_base_registration_t*) key; + + for (uint64_t d_i = 0; d_i < value; ++d_i) { + (data->rcache[j])->rcache_deregister(data->rcache[j], reg); + } + opal_hash_table_set_value_uint64(data->xpmem_reg_tracker_ht[j], + key, (void*)(zero_value)); + } xpmem_release(data->all_apid[j]); - xpmem_remove(data->allseg_id[j]); mca_rcache_base_module_destroy(data->rcache[j]); + opal_hash_table_remove_all(data->xpmem_reg_tracker_ht[j]); + OBJ_RELEASE(data->xpmem_reg_tracker_ht[j]); } + xpmem_remove(data->allseg_id[ompi_comm_rank(subc->orig_comm)]); free(data->allseg_id); data->allseg_id = NULL; @@ -249,6 +278,8 @@ static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module) data->xpmem_raddr = NULL; free(data->scratch); data->scratch = NULL; + free(data->xpmem_reg_tracker_ht); + data->xpmem_reg_tracker_ht = NULL; free(data->rcache); data->rcache = NULL; #endif diff --git a/ompi/mca/coll/acoll/coll_acoll_module.c b/ompi/mca/coll/acoll/coll_acoll_module.c index 3d9242226cf..bbab6034132 100644 --- a/ompi/mca/coll/acoll/coll_acoll_module.c +++ b/ompi/mca/coll/acoll/coll_acoll_module.c @@ -132,11 +132,15 @@ mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *co break; } + acoll_module->force_numa = mca_coll_acoll_force_numa; acoll_module->use_dyn_rules = mca_coll_acoll_use_dynamic_rules; acoll_module->use_mnode = mca_coll_acoll_mnode_enable; + /* Value of 0 is currently unsupported for mnode_enable */ + acoll_module->use_mnode = 1; acoll_module->use_lin0 = mca_coll_acoll_bcast_lin0; acoll_module->use_lin1 = mca_coll_acoll_bcast_lin1; acoll_module->use_lin2 = mca_coll_acoll_bcast_lin2; + acoll_module->use_socket = mca_coll_acoll_bcast_socket; if (mca_coll_acoll_bcast_nonsg) { acoll_module->mnode_sg_size = acoll_module->node_cnt; acoll_module->mnode_log2_sg_size = acoll_module->log2_node_cnt; @@ -182,6 +186,11 @@ static int acoll_module_enable(mca_coll_base_module_t *module, struct ompi_commu ACOLL_INSTALL_COLL_API(comm, acoll_module, gather); ACOLL_INSTALL_COLL_API(comm, acoll_module, reduce); + /* Initialize k-nomial tree */ + module->base_data->cached_kmtree = NULL; + module->base_data->cached_kmtree_root = -1; + module->base_data->cached_kmtree_radix = 4; + /* All done */ return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/acoll/coll_acoll_reduce.c b/ompi/mca/coll/acoll/coll_acoll_reduce.c index 505c3da5206..32442fc4889 100644 --- a/ompi/mca/coll/acoll/coll_acoll_reduce.c +++ b/ompi/mca/coll/acoll/coll_acoll_reduce.c @@ -64,14 +64,14 @@ static inline int coll_acoll_reduce_topo(const void *sbuf, void *rbuf, size_t co rank = ompi_comm_rank(comm); tmp_sbuf = (char *) sbuf; - if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + if ((MPI_IN_PLACE == sbuf) && (rank == root)) { tmp_sbuf = (char *) rbuf; } int i; int ind1 = MCA_COLL_ACOLL_L3CACHE; int ind2 = MCA_COLL_ACOLL_LYR_NODE; - int is_base = rank == subc->base_rank[ind1] ? 1 : 0; + int is_base = rank == subc->base_rank[ind1][ind2] ? 1 : 0; int bound = subc->subgrp_size; sz = ompi_comm_size(subc->base_comm[ind1][ind2]); @@ -166,7 +166,7 @@ static inline int mca_coll_acoll_reduce_xpmem(const void *sbuf, void *rbuf, size mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; - coll_acoll_init(module, comm, subc->data, subc); + coll_acoll_init(module, comm, subc->data, subc, 0); coll_acoll_reserve_mem_t *reserve_mem_rbuf_reduce = NULL; if (subc->xpmem_use_sr_buf != 0) { reserve_mem_rbuf_reduce = &(acoll_module->reserve_mem_s); @@ -192,17 +192,17 @@ static inline int mca_coll_acoll_reduce_xpmem(const void *sbuf, void *rbuf, size char *tmp_sbuf = NULL; char *tmp_rbuf = NULL; - if (subc->xpmem_use_sr_buf == 0) { + if (0 == subc->xpmem_use_sr_buf) { tmp_rbuf = (char *) data->scratch; tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; - if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + if ((MPI_IN_PLACE == sbuf) && (rank == root)) { memcpy(tmp_sbuf, rbuf, total_dsize); } else { memcpy(tmp_sbuf, sbuf, total_dsize); } } else { tmp_sbuf = (char *) sbuf; - if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + if ((MPI_IN_PLACE == sbuf) && (rank == root)) { tmp_sbuf = (char *) rbuf; } @@ -270,7 +270,7 @@ static inline int mca_coll_acoll_reduce_xpmem(const void *sbuf, void *rbuf, size chunk = count / local_size; my_count_size = (l2_local_rank == (local_size - 1)) ? chunk + (count % local_size) : chunk; - if (l2_local_rank == 0) { + if (0 == l2_local_rank) { for (int i = 1; i < local_size; i++) { ompi_op_reduce(op, (char *) data->xpmem_raddr[l2_gp[i]], (char *) tmp_rbuf, my_count_size, dtype); @@ -291,7 +291,7 @@ static inline int mca_coll_acoll_reduce_xpmem(const void *sbuf, void *rbuf, size } } ompi_coll_base_barrier_intra_tree(comm, module); - if (subc->xpmem_use_sr_buf == 0) { + if (0 == subc->xpmem_use_sr_buf) { if (rank == root) { memcpy(rbuf, tmp_rbuf, total_dsize); } @@ -353,15 +353,15 @@ int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, size_t count, num_nodes = subc->num_nodes; - if (num_nodes == 1) { + if (1 == num_nodes) { if (total_dsize < 262144) { - if (alg == -1 /* interaction with xpmem implementation causing issues 0*/) { + if (-1 == alg /* interaction with xpmem implementation causing issues 0*/) { return coll_acoll_reduce_topo(sbuf, rbuf, count, dtype, op, root, comm, module, subc); - } else if (alg == 1) { + } else if (1 == alg) { return ompi_coll_base_reduce_intra_basic_linear(sbuf, rbuf, count, dtype, op, root, comm, module); - } else if (alg == 2) { + } else if (2 == alg) { return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, module, 0, 0); } else { /*(alg == 3)*/ @@ -373,7 +373,7 @@ int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, size_t count, if ((((subc->xpmem_use_sr_buf != 0) && (acoll_module->reserve_mem_s).reserve_mem_allocate && ((acoll_module->reserve_mem_s).reserve_mem_size >= total_dsize)) - || ((subc->xpmem_use_sr_buf == 0) && (subc->xpmem_buf_size > 2 * total_dsize))) + || ((0 == subc->xpmem_use_sr_buf) && (subc->xpmem_buf_size > 2 * total_dsize))) && (subc->without_xpmem != 1)) { return mca_coll_acoll_reduce_xpmem(sbuf, rbuf, count, dtype, op, root, comm, module, subc); diff --git a/ompi/mca/coll/acoll/coll_acoll_utils.h b/ompi/mca/coll/acoll/coll_acoll_utils.h index 2ba56275db3..c665ad2babc 100644 --- a/ompi/mca/coll/acoll/coll_acoll_utils.h +++ b/ompi/mca/coll/acoll/coll_acoll_utils.h @@ -230,7 +230,7 @@ static inline int comm_grp_ranks_local(ompi_communicator_t *comm, ompi_communica } static inline int mca_coll_acoll_create_base_comm(ompi_communicator_t **parent_comm, - coll_acoll_subcomms_t *subc, int color, int rank, + coll_acoll_subcomms_t *subc, int color, int *rank, int *root, int base_lyr) { int i; @@ -240,7 +240,7 @@ static inline int mca_coll_acoll_create_base_comm(ompi_communicator_t **parent_c int is_root_node = 0; /* Create base comm */ - err = ompi_comm_split(parent_comm[i], color, rank, &subc->base_comm[base_lyr][i], false); + err = ompi_comm_split(parent_comm[i], color, rank[i], &subc->base_comm[base_lyr][i], false); if (MPI_SUCCESS != err) return err; @@ -340,6 +340,7 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, subc->numa_root = 0; subc->is_root_socket = 0; subc->socket_ldr_root = -1; + subc->is_root_node = 0; if (subc->initialized) { if (subc->num_nodes > 1) { @@ -377,13 +378,14 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, } /* Further subcommunicators based on root */ - if (subc->num_nodes > 1) { + int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL; + ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS]; + int parent_rank[MCA_COLL_ACOLL_NUM_LAYERS]; + if (subc->num_nodes > 1) { /* Multinode case */ int local_rank = ompi_comm_rank(subc->local_comm); int color = MPI_UNDEFINED; int is_root_node = 0, is_root_socket = 0; int local_root = 0; - int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL; - ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS]; /* Initializations */ subc->local_root[MCA_COLL_ACOLL_LYR_NODE] = 0; @@ -395,7 +397,7 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, /* Create subcommunicator with leader ranks */ color = 1; - if (!subc->is_root_node && (local_rank == 0)) { + if (!subc->is_root_node && (0 == local_rank)) { color = 0; } if (rank == root) { @@ -419,56 +421,59 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, local_root); /* Create subcommunicator with socket leaders */ - subc->socket_rank = subc->is_root_socket == 1 ? local_root : socket_ranks[0]; + subc->socket_rank = 1 == subc->is_root_socket ? local_root : socket_ranks[0]; color = local_rank == subc->socket_rank ? 0 : 1; - err = ompi_comm_split(subc->local_comm, color, local_rank, &subc->socket_ldr_comm, false); + err = ompi_comm_split(comm, color, rank, &subc->socket_ldr_comm, false); if (MPI_SUCCESS != err) return err; /* Find out local rank of root in socket leader comm */ - err = comm_grp_ranks_local(subc->local_comm, subc->socket_ldr_comm, &is_root_socket, - &subc->socket_ldr_root, NULL, local_root); + err = comm_grp_ranks_local(comm, subc->socket_ldr_comm, &is_root_socket, + &subc->socket_ldr_root, NULL, root); /* Find out local rank of root in subgroup comm */ err = comm_grp_ranks_local(subc->local_comm, subc->subgrp_comm, &subc->is_root_sg, &subc->subgrp_root, &subgrp_ranks, local_root); + subc->base_rank[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] = + 1 == subc->is_root_sg ? local_root : subgrp_ranks[0]; + /* Find out socket rank of root in subgroup comm */ + int tmp_root; + err = comm_grp_ranks_local(subc->socket_comm, subc->subgrp_comm, &subc->is_root_sg, + &tmp_root, &subgrp_ranks, + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET]); + subc->base_rank[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_SOCKET] = + 1 == subc->is_root_sg ? subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] : subgrp_ranks[0]; + /* Create subcommunicator with base ranks */ - subc->base_rank[MCA_COLL_ACOLL_L3CACHE] = subc->is_root_sg == 1 ? local_root - : subgrp_ranks[0]; - color = local_rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE] ? 0 : 1; + color = local_rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] ? 0 : 1; parent_comm[MCA_COLL_ACOLL_LYR_NODE] = subc->local_comm; parent_comm[MCA_COLL_ACOLL_LYR_SOCKET] = subc->socket_comm; - err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, local_rank, + parent_rank[MCA_COLL_ACOLL_LYR_NODE] = local_rank; + parent_rank[MCA_COLL_ACOLL_LYR_SOCKET] = ompi_comm_rank(subc->socket_comm); + err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, parent_rank, subc->local_root, MCA_COLL_ACOLL_L3CACHE); /* Find out local rank of root in numa comm */ err = comm_grp_ranks_local(subc->local_comm, subc->numa_comm, &subc->is_root_numa, &subc->numa_root, &numa_ranks, local_root); - subc->base_rank[MCA_COLL_ACOLL_NUMA] = subc->is_root_numa == 1 ? local_root : numa_ranks[0]; - color = local_rank == subc->base_rank[MCA_COLL_ACOLL_NUMA] ? 0 : 1; - err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, local_rank, + subc->base_rank[MCA_COLL_ACOLL_NUMA][MCA_COLL_ACOLL_LYR_NODE] = + 1 == subc->is_root_numa ? local_root : numa_ranks[0]; + /* Find out socket rank of root in numa comm */ + err = comm_grp_ranks_local(subc->socket_comm, subc->numa_comm, &subc->is_root_numa, + &tmp_root, &numa_ranks, + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET]); + subc->base_rank[MCA_COLL_ACOLL_NUMA][MCA_COLL_ACOLL_LYR_SOCKET] = + 1 == subc->is_root_numa ? subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] : numa_ranks[0]; + + color = local_rank == subc->base_rank[MCA_COLL_ACOLL_NUMA][MCA_COLL_ACOLL_LYR_NODE] ? 0 : 1; + err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, parent_rank, subc->local_root, MCA_COLL_ACOLL_NUMA); - - if (socket_ranks != NULL) { - free(socket_ranks); - socket_ranks = NULL; - } - if (subgrp_ranks != NULL) { - free(subgrp_ranks); - subgrp_ranks = NULL; - } - if (numa_ranks != NULL) { - free(numa_ranks); - numa_ranks = NULL; - } } else { /* Intra node case */ int color; int is_root_socket = 0; - int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL; - ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS]; /* Initializations */ subc->local_root[MCA_COLL_ACOLL_LYR_NODE] = root; @@ -480,7 +485,7 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, root); /* Create subcommunicator with socket leaders */ - subc->socket_rank = subc->is_root_socket == 1 ? root : socket_ranks[0]; + subc->socket_rank = 1 == subc->is_root_socket ? root : socket_ranks[0]; color = rank == subc->socket_rank ? 0 : 1; err = ompi_comm_split(comm, color, rank, &subc->socket_ldr_comm, false); if (MPI_SUCCESS != err) { @@ -495,40 +500,59 @@ static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, err = comm_grp_ranks_local(comm, subc->subgrp_comm, &subc->is_root_sg, &subc->subgrp_root, &subgrp_ranks, root); + subc->base_rank[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] = + 1 == subc->is_root_sg ? root : subgrp_ranks[0]; + /* Find out socket rank of root in subgroup comm */ + int tmp_root; + err = comm_grp_ranks_local(subc->socket_comm, subc->subgrp_comm, &subc->is_root_sg, + &tmp_root, &subgrp_ranks, + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET]); + subc->base_rank[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_SOCKET] = + 1 == subc->is_root_sg ? subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] : subgrp_ranks[0]; + /* Create subcommunicator with base ranks */ - subc->base_rank[MCA_COLL_ACOLL_L3CACHE] = subc->is_root_sg == 1 ? root : subgrp_ranks[0]; - color = rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE] ? 0 : 1; + color = rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] ? 0 : 1; parent_comm[MCA_COLL_ACOLL_LYR_NODE] = subc->local_comm; parent_comm[MCA_COLL_ACOLL_LYR_SOCKET] = subc->socket_comm; - err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, rank, subc->local_root, + parent_rank[MCA_COLL_ACOLL_LYR_NODE] = rank; + parent_rank[MCA_COLL_ACOLL_LYR_SOCKET] = ompi_comm_rank(subc->socket_comm); + err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, parent_rank, subc->local_root, MCA_COLL_ACOLL_L3CACHE); int numa_rank; numa_rank = ompi_comm_rank(subc->numa_comm); - color = (numa_rank == 0) ? 0 : 1; + color = (0 == numa_rank) ? 0 : 1; err = ompi_comm_split(subc->local_comm, color, rank, &subc->numa_comm_ldrs, false); /* Find out local rank of root in numa comm */ err = comm_grp_ranks_local(comm, subc->numa_comm, &subc->is_root_numa, &subc->numa_root, &numa_ranks, root); - subc->base_rank[MCA_COLL_ACOLL_NUMA] = subc->is_root_numa == 1 ? root : numa_ranks[0]; - color = rank == subc->base_rank[MCA_COLL_ACOLL_NUMA] ? 0 : 1; - err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, rank, subc->local_root, + subc->base_rank[MCA_COLL_ACOLL_NUMA][MCA_COLL_ACOLL_LYR_NODE] = + 1 == subc->is_root_numa ? root : numa_ranks[0]; + /* Find out socket rank of root in numa comm */ + err = comm_grp_ranks_local(subc->socket_comm, subc->numa_comm, &subc->is_root_numa, + &tmp_root, &numa_ranks, + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET]); + subc->base_rank[MCA_COLL_ACOLL_NUMA][MCA_COLL_ACOLL_LYR_SOCKET] = + 1 == subc->is_root_numa ? subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] : numa_ranks[0]; + + color = rank == subc->base_rank[MCA_COLL_ACOLL_NUMA][MCA_COLL_ACOLL_LYR_NODE] ? 0 : 1; + err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, parent_rank, subc->local_root, MCA_COLL_ACOLL_NUMA); + } - if (socket_ranks != NULL) { - free(socket_ranks); - socket_ranks = NULL; - } - if (subgrp_ranks != NULL) { - free(subgrp_ranks); - subgrp_ranks = NULL; - } - if (numa_ranks != NULL) { - free(numa_ranks); - numa_ranks = NULL; - } + if (socket_ranks != NULL) { + free(socket_ranks); + socket_ranks = NULL; + } + if (subgrp_ranks != NULL) { + free(subgrp_ranks); + subgrp_ranks = NULL; + } + if (numa_ranks != NULL) { + free(numa_ranks); + numa_ranks = NULL; } /* Restore originals for local and socket comms */ @@ -595,7 +619,7 @@ static inline int mca_coll_acoll_xpmem_deregister(void *xpmem_apid, #endif static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communicator_t *comm, - coll_acoll_data_t *data, coll_acoll_subcomms_t *subc) + coll_acoll_data_t *data, coll_acoll_subcomms_t *subc, int root) { int size, ret = 0, rank, line; @@ -614,7 +638,7 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica data->comm_size = size; #ifdef HAVE_XPMEM_H - if (subc->xpmem_use_sr_buf == 0) { + if (0 == subc->xpmem_use_sr_buf) { data->scratch = (char *) malloc(subc->xpmem_buf_size); if (NULL == data->scratch) { line = __LINE__; @@ -668,8 +692,16 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica ret = OMPI_ERR_OUT_OF_RESOURCE; goto error_hndl; } + data->xpmem_reg_tracker_ht = NULL; + data->xpmem_reg_tracker_ht = (opal_hash_table_t **) malloc(sizeof(opal_hash_table_t*) * size); + if (NULL == data->xpmem_reg_tracker_ht) { + line = __LINE__; + ret = OMPI_ERR_OUT_OF_RESOURCE; + goto error_hndl; + } + seg_id = xpmem_make(0, XPMEM_MAXADDR_SIZE, XPMEM_PERMIT_MODE, (void *) 0666); - if (seg_id == -1) { + if (-1 == seg_id) { line = __LINE__; ret = -1; goto error_hndl; @@ -685,12 +717,12 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica if (rank != i) { data->all_apid[i] = xpmem_get(data->allseg_id[i], XPMEM_RDWR, XPMEM_PERMIT_MODE, (void *) 0666); - if (data->all_apid[i] == -1) { + if (-1 == data->all_apid[i]) { line = __LINE__; ret = -1; goto error_hndl; } - if (data->all_apid[i] == -1) { + if (-1 == data->all_apid[i]) { line = __LINE__; ret = -1; goto error_hndl; @@ -704,17 +736,19 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica .deregister_mem = mca_coll_acoll_xpmem_deregister}; data->rcache[i] = mca_rcache_base_module_create("grdma", NULL, &rcache_element); - if (data->rcache[i] == NULL) { + if (NULL == data->rcache[i]) { ret = -1; line = __LINE__; goto error_hndl; } + data->xpmem_reg_tracker_ht[i] = OBJ_NEW(opal_hash_table_t); + opal_hash_table_init(data->xpmem_reg_tracker_ht[i], 2048); } } #endif /* temporary variables */ - int tmp1, tmp2, tmp3 = 0; + int tmp1, tmp2, tmp3 = root; comm_grp_ranks_local(comm, subc->numa_comm, &tmp1, &tmp2, &data->l1_gp, tmp3); data->l1_gp_size = ompi_comm_size(subc->numa_comm); data->l1_local_rank = ompi_comm_rank(subc->numa_comm); @@ -777,6 +811,8 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica } } + data->allshmmmap_sbuf[root] = opal_shmem_segment_attach(&data->allshmseg_id[0]); + int offset = LEADER_SHM_SIZE; memset(((char *) data->allshmmmap_sbuf[data->l1_gp[0]]) + offset + CACHE_LINE_SIZE * rank, 0, CACHE_LINE_SIZE); if (data->l1_gp[0] == rank) { @@ -805,6 +841,8 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica data->xpmem_saddr = NULL; free(data->xpmem_raddr); data->xpmem_raddr = NULL; + free(data->xpmem_reg_tracker_ht); + data->xpmem_reg_tracker_ht = NULL; free(data->rcache); data->rcache = NULL; free(data->scratch); @@ -825,6 +863,25 @@ static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communica } #ifdef HAVE_XPMEM_H +static inline void update_rcache_reg_hashtable_entry + (struct acoll_xpmem_rcache_reg_t *reg, + opal_hash_table_t* ht) +{ + // Converting pointer to uint64 to use as key. + uint64_t key = (uint64_t)reg; + // Converting uint64_t to pointer type to use for value. + uint64_t value = 1; + int ht_ret = opal_hash_table_get_value_uint64(ht, key, (void**)(&value)); + + if (OPAL_ERR_NOT_FOUND == ht_ret) { + value = 1; + opal_hash_table_set_value_uint64(ht, key, (void*)(value)); + } else if (OPAL_SUCCESS == ht_ret) { + value += 1; + opal_hash_table_set_value_uint64(ht, key, (void*)(value)); + } +} + static inline void register_and_cache(int size, size_t total_dsize, int rank, coll_acoll_data_t *data) { @@ -844,6 +901,8 @@ static inline void register_and_cache(int size, size_t total_dsize, int rank, sbuf_reg = NULL; return; } + update_rcache_reg_hashtable_entry(sbuf_reg, data->xpmem_reg_tracker_ht[i]); + data->xpmem_saddr[i] = (void *) ((uintptr_t) sbuf_reg->xpmem_vaddr + ((uintptr_t) data->allshm_sbuf[i] - (uintptr_t) sbuf_reg->base.base)); @@ -858,6 +917,8 @@ static inline void register_and_cache(int size, size_t total_dsize, int rank, rbuf_reg = NULL; return; } + update_rcache_reg_hashtable_entry(rbuf_reg, data->xpmem_reg_tracker_ht[i]); + data->xpmem_raddr[i] = (void *) ((uintptr_t) rbuf_reg->xpmem_vaddr + ((uintptr_t) data->allshm_rbuf[i] - (uintptr_t) rbuf_reg->base.base));