diff --git a/LICENSE b/LICENSE
index 6626c0ab405..679847e168b 100644
--- a/LICENSE
+++ b/LICENSE
@@ -64,6 +64,7 @@ Copyright (c) 2020-2021 Cornelis Networks, Inc. All rights reserved.
 Copyright (c) 2021 Nanook Consulting
 Copyright (c) 2017-2019 Iowa State University Research Foundation, Inc.  All rights reserved.
+Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
 $COPYRIGHT$
diff --git a/ompi/mca/coll/acoll/Makefile.am b/ompi/mca/coll/acoll/Makefile.am
new file mode 100644
index 00000000000..fdbd7edbbd2
--- /dev/null
+++ b/ompi/mca/coll/acoll/Makefile.am
@@ -0,0 +1,45 @@
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+
+AM_CPPFLAGS = $(coll_acoll_CPPFLAGS)
+
+sources = \
+        coll_acoll.h \
+        coll_acoll_utils.h \
+        coll_acoll_allgather.c \
+        coll_acoll_bcast.c \
+        coll_acoll_gather.c \
+        coll_acoll_reduce.c \
+        coll_acoll_allreduce.c \
+        coll_acoll_barrier.c \
+        coll_acoll_component.c \
+        coll_acoll_module.c
+
+# Make the output library in this directory, and name it either
+# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
+# (for static builds).
+
+if MCA_BUILD_ompi_coll_acoll_DSO
+component_noinst =
+component_install = mca_coll_acoll.la
+else
+component_noinst = libmca_coll_acoll.la
+component_install =
+endif
+
+mcacomponentdir = $(ompilibdir)
+mcacomponent_LTLIBRARIES = $(component_install)
+mca_coll_acoll_la_SOURCES = $(sources)
+mca_coll_acoll_la_LDFLAGS = -module -avoid-version $(coll_acoll_LDFLAGS)
+mca_coll_acoll_la_LIBADD = $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la $(coll_acoll_LIBS)
+
+noinst_LTLIBRARIES = $(component_noinst)
+libmca_coll_acoll_la_SOURCES = $(sources)
+libmca_coll_acoll_la_LIBADD = $(coll_acoll_LIBS)
+libmca_coll_acoll_la_LDFLAGS = -module -avoid-version $(coll_acoll_LDFLAGS)
diff --git a/ompi/mca/coll/acoll/README b/ompi/mca/coll/acoll/README
new file mode 100644
index 00000000000..d5b5acae8f1
--- /dev/null
+++ b/ompi/mca/coll/acoll/README
@@ -0,0 +1,16 @@
+Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+
+$COPYRIGHT$
+
+Additional copyrights may follow
+
+$HEADER$
+
+===========================================================================
+
+The collective component, AMD Coll (“acoll”), is a high-performance MPI
+collective component for the Open MPI library that is optimized for AMD
+“Zen”-based processors. “acoll” is tuned for communications within a single
+node of AMD “Zen”-based processors and provides the following commonly used
+collective algorithms: broadcast (MPI_Bcast), allreduce (MPI_Allreduce),
+reduce (MPI_Reduce), gather (MPI_Gather), allgather (MPI_Allgather), and
+barrier (MPI_Barrier).
+
+At present, “acoll” has been tested with Open MPI v5.0.2 and can be built as
+part of Open MPI.
+
+To run an application with acoll, use the following command line parameters:
+- mpirun --mca coll acoll,tuned,libnbc,basic --mca coll_acoll_priority 40
diff --git a/ompi/mca/coll/acoll/coll_acoll.h b/ompi/mca/coll/acoll/coll_acoll.h
new file mode 100644
index 00000000000..36769db03fc
--- /dev/null
+++ b/ompi/mca/coll/acoll/coll_acoll.h
@@ -0,0 +1,225 @@
+/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
+/*
+ * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#ifndef MCA_COLL_ACOLL_EXPORT_H
+#define MCA_COLL_ACOLL_EXPORT_H
+
+#include "ompi_config.h"
+
+#include "mpi.h"
+#include "ompi/communicator/communicator.h"
+#include "ompi/mca/coll/base/coll_base_functions.h"
+#include "ompi/mca/coll/coll.h"
+#include "ompi/mca/mca.h"
+#include "ompi/request/request.h"
+
+#ifdef HAVE_XPMEM_H
+#include "opal/mca/rcache/base/base.h"
+#include <xpmem.h>
+#endif
+
+#include "opal/mca/shmem/base/base.h"
+#include "opal/mca/shmem/shmem.h"
+
+BEGIN_C_DECLS
+
+/* Globally exported variables */
+OMPI_DECLSPEC extern const mca_coll_base_component_3_0_0_t mca_coll_acoll_component;
+extern int mca_coll_acoll_priority;
+extern int mca_coll_acoll_sg_size;
+extern int mca_coll_acoll_sg_scale;
+extern int mca_coll_acoll_node_size;
+extern int mca_coll_acoll_use_dynamic_rules;
+extern int mca_coll_acoll_mnode_enable;
+extern int mca_coll_acoll_bcast_lin0;
+extern int mca_coll_acoll_bcast_lin1;
+extern int mca_coll_acoll_bcast_lin2;
+extern int mca_coll_acoll_bcast_nonsg;
+extern int mca_coll_acoll_allgather_lin;
+extern int mca_coll_acoll_allgather_ring_1;
+
+/* API functions */
+int mca_coll_acoll_init_query(bool enable_progress_threads, bool enable_mpi_threads);
+mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *comm, int *priority);
+
+int mca_coll_acoll_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm);
+
+int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
+                             void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
+                             struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datatype, int root,
+                         struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_gather_intra(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
+                                void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root,
+                                struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, size_t count,
+                                struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root,
+                                struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
+                                   struct ompi_datatype_t *dtype, struct ompi_op_t *op,
+                                   struct ompi_communicator_t *comm,
+                                   mca_coll_base_module_t *module);
+
+int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module);
+
+END_C_DECLS
+
+#define MCA_COLL_ACOLL_MAX_CID 100
+#define MCA_COLL_ACOLL_ROOT_CHANGE_THRESH 10
+
+typedef enum MCA_COLL_ACOLL_SG_SIZES {
+    MCA_COLL_ACOLL_SG_SIZE_1 = 8,
+    MCA_COLL_ACOLL_SG_SIZE_2 = 16
+} MCA_COLL_ACOLL_SG_SIZES;
+
+typedef enum MCA_COLL_ACOLL_SG_SCALES {
+    MCA_COLL_ACOLL_SG_SCALE_1 = 1,
+    MCA_COLL_ACOLL_SG_SCALE_2 = 2,
+    MCA_COLL_ACOLL_SG_SCALE_3 = 4,
+    MCA_COLL_ACOLL_SG_SCALE_4 = 8,
+    MCA_COLL_ACOLL_SG_SCALE_5 = 16
+} MCA_COLL_ACOLL_SG_SCALES;
+
+typedef enum MCA_COLL_ACOLL_SUBCOMMS {
+    MCA_COLL_ACOLL_NODE_L = 0,
+    MCA_COLL_ACOLL_INTRA,
+    MCA_COLL_ACOLL_SOCK_L,
+    MCA_COLL_ACOLL_NUMA_L,
+    MCA_COLL_ACOLL_L3_L,
+    MCA_COLL_ACOLL_LEAF,
+    MCA_COLL_ACOLL_NUM_SC
+} MCA_COLL_ACOLL_SUBCOMMS;
+
+typedef enum MCA_COLL_ACOLL_LAYERS {
+    MCA_COLL_ACOLL_LYR_NODE = 0,
+    MCA_COLL_ACOLL_LYR_SOCKET,
+    MCA_COLL_ACOLL_NUM_LAYERS
+} MCA_COLL_ACOLL_LAYERS;
+
+typedef enum MCA_COLL_ACOLL_BASE_LYRS {
+
MCA_COLL_ACOLL_L3CACHE = 0, + MCA_COLL_ACOLL_NUMA, + MCA_COLL_ACOLL_NUM_BASE_LYRS +} MCA_COLL_ACOLL_BASE_LYRS; + +typedef struct coll_acoll_data { +#ifdef HAVE_XPMEM_H + xpmem_segid_t *allseg_id; + xpmem_apid_t *all_apid; + void **allshm_sbuf; + void **allshm_rbuf; + void **xpmem_saddr; + void **xpmem_raddr; + mca_rcache_base_module_t **rcache; + void *scratch; +#endif + opal_shmem_ds_t *allshmseg_id; + void **allshmmmap_sbuf; + + int comm_size; + int l1_local_rank; + int l2_local_rank; + int l1_gp_size; + int *l1_gp; + int *l2_gp; + int l2_gp_size; + int offset[4]; + int sync[2]; +} coll_acoll_data_t; + +typedef struct coll_acoll_subcomms { + ompi_communicator_t *local_comm; + ompi_communicator_t *local_r_comm; + ompi_communicator_t *leader_comm; + ompi_communicator_t *subgrp_comm; + ompi_communicator_t *numa_comm; + ompi_communicator_t *base_comm[MCA_COLL_ACOLL_NUM_BASE_LYRS][MCA_COLL_ACOLL_NUM_LAYERS]; + ompi_communicator_t *orig_comm; + ompi_communicator_t *socket_comm; + ompi_communicator_t *socket_ldr_comm; + int num_nodes; + int derived_node_size; + int is_root_node; + int is_root_sg; + int is_root_numa; + int is_root_socket; + int local_root[MCA_COLL_ACOLL_NUM_LAYERS]; + int outer_grp_root; + int subgrp_root; + int numa_root; + int socket_ldr_root; + int base_root[MCA_COLL_ACOLL_NUM_BASE_LYRS][MCA_COLL_ACOLL_NUM_LAYERS]; + int base_rank[MCA_COLL_ACOLL_NUM_BASE_LYRS]; + int socket_rank; + int subgrp_size; + int initialized; + int prev_init_root; + int num_root_change; + + ompi_communicator_t *numa_comm_ldrs; + ompi_communicator_t *node_comm; + ompi_communicator_t *inter_comm; + int cid; + coll_acoll_data_t *data; + bool initialized_data; + bool initialized_shm_data; +#ifdef HAVE_XPMEM_H + uint64_t xpmem_buf_size; + int without_xpmem; + int xpmem_use_sr_buf; +#endif + +} coll_acoll_subcomms_t; + +typedef struct coll_acoll_reserve_mem { + void *reserve_mem; + uint64_t reserve_mem_size; + bool reserve_mem_allocate; + bool reserve_mem_in_use; +} coll_acoll_reserve_mem_t; + +struct mca_coll_acoll_module_t { + mca_coll_base_module_t super; + MCA_COLL_ACOLL_SG_SIZES sg_size; + MCA_COLL_ACOLL_SG_SCALES sg_scale; + int sg_cnt; + // Todo: Remove log2 variables + int log2_sg_cnt; + int node_cnt; + int log2_node_cnt; + int use_dyn_rules; + // Todo: Use substructure for every API related ones + int use_mnode; + int use_lin0; + int use_lin1; + int use_lin2; + int mnode_sg_size; + int mnode_log2_sg_size; + int allg_lin; + int allg_ring; + coll_acoll_subcomms_t subc[MCA_COLL_ACOLL_MAX_CID]; + coll_acoll_reserve_mem_t reserve_mem_s; +}; + +#ifdef HAVE_XPMEM_H +struct acoll_xpmem_rcache_reg_t { + mca_rcache_base_registration_t base; + void *xpmem_vaddr; +}; +#endif + +typedef struct mca_coll_acoll_module_t mca_coll_acoll_module_t; +OMPI_DECLSPEC OBJ_CLASS_DECLARATION(mca_coll_acoll_module_t); + +#endif /* MCA_COLL_ACOLL_EXPORT_H */ diff --git a/ompi/mca/coll/acoll/coll_acoll_allgather.c b/ompi/mca/coll/acoll/coll_acoll_allgather.c new file mode 100644 index 00000000000..26287215de2 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_allgather.c @@ -0,0 +1,625 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_base_util.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +static inline int log_sg_bcast_intra(void *buff, size_t count, struct ompi_datatype_t *datatype, + int rank, int dim, int size, int sg_size, int cur_base, + int sg_start, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, ompi_request_t **preq, + int *nreqs) +{ + int msb_pos, sub_rank, peer, err; + int i, mask; + int end_sg, end_peer; + + end_sg = sg_start + sg_size - 1; + if (end_sg >= size) { + end_sg = size - 1; + } + end_peer = (end_sg - cur_base) % sg_size; + sub_rank = (rank - cur_base + sg_size) % sg_size; + + msb_pos = opal_hibit(sub_rank, dim); + --dim; + + /* Receive data from parent in the sg tree. */ + if (sub_rank > 0) { + assert(msb_pos >= 0); + peer = (sub_rank & ~(1 << msb_pos)); + if (peer > end_peer) { + peer = (((peer + cur_base - sg_start) % sg_size) + sg_start); + } else { + peer = peer + cur_base; + } + + err = MCA_PML_CALL( + recv(buff, count, datatype, peer, MCA_COLL_BASE_TAG_ALLGATHER, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + for (i = msb_pos + 1, mask = 1 << i; i <= dim; ++i, mask <<= 1) { + peer = sub_rank | mask; + if (peer >= sg_size) { + continue; + } + if (peer >= end_peer) { + peer = (((peer + cur_base - sg_start) % sg_size) + sg_start); + } else { + peer = peer + cur_base; + } + /* Checks to ensure that the sends are limited to the necessary ones. + It also ensures 'preq' not exceeding the max allocated. */ + if ((peer < size) && (peer != rank) && (peer != cur_base)) { + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } + + return err; +} + +static inline int lin_sg_bcast_intra(void *buff, size_t count, struct ompi_datatype_t *datatype, + int rank, int dim, int size, int sg_size, int cur_base, + int sg_start, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, ompi_request_t **preq, + int *nreqs) +{ + int peer; + int err; + int sg_end; + + sg_end = sg_start + sg_size - 1; + if (sg_end >= size) { + sg_end = size - 1; + } + + if (rank == cur_base) { + for (peer = sg_start; peer <= sg_end; peer++) { + if (peer == cur_base) { + continue; + } + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + err = MCA_PML_CALL(recv(buff, count, datatype, cur_base, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +/* + * sg_bcast_intra + * + * Function: broadcast operation within a subgroup + * Accepts: Arguments of MPI_Bcast() plus subgroup params + * Returns: MPI_SUCCESS or error code + * + * Description: O(N) or O(log(N)) algorithm based on count. + * + * Memory: No additional memory requirements beyond user-supplied buffers. 
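+ *
+ * The dispatch between the two variants below can be summarized as
+ * follows (an illustrative restatement of the code that follows,
+ * not additional behavior; 8192 bytes is the threshold used here):
+ *
+ *   size_t dsize;
+ *   ompi_datatype_type_size(datatype, &dsize);
+ *   if (dsize * count <= 8192) {
+ *       // small payload: binomial (log) tree within the subgroup
+ *   } else {
+ *       // large payload: flat (linear) sends from the subgroup base
+ *   }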
+ *
+ */
+static inline int sg_bcast_intra(void *buff, size_t count, struct ompi_datatype_t *datatype, int rank,
+                                 int dim, int size, int sg_size, int cur_base, int sg_start,
+                                 struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
+                                 ompi_request_t **preq, int *nreqs)
+{
+    int err;
+    size_t total_dsize, dsize;
+
+    ompi_datatype_type_size(datatype, &dsize);
+    total_dsize = dsize * count;
+
+    if (total_dsize <= 8192) {
+        err = log_sg_bcast_intra(buff, count, datatype, rank, dim, size, sg_size, cur_base,
+                                 sg_start, comm, module, preq, nreqs);
+    } else {
+        err = lin_sg_bcast_intra(buff, count, datatype, rank, dim, size, sg_size, cur_base,
+                                 sg_start, comm, module, preq, nreqs);
+    }
+    return err;
+}
+
+/*
+ * coll_allgather_decision_fixed
+ *
+ * Function:    Choose optimal allgather algorithm
+ *
+ * Description: Based on the number of processes and the message size,
+ *              chooses whether or not to use subgroups. If the
+ *              subgroup-based algorithm is not chosen, it further decides
+ *              whether the ring or the linear allgather is to be used.
+ *
+ */
+static inline void coll_allgather_decision_fixed(int size, size_t total_dsize, int sg_size,
+                                                 int *use_ring, int *use_lin)
+{
+    *use_ring = 0;
+    *use_lin = 0;
+    if (size <= (sg_size << 1)) {
+        if (total_dsize >= 1048576) {
+            *use_lin = 1;
+        }
+    } else if (size <= (sg_size << 2)) {
+        if ((total_dsize >= 4096) && (total_dsize < 32768)) {
+            *use_ring = 1;
+        } else if (total_dsize >= 1048576) {
+            *use_lin = 1;
+        }
+    } else if (size <= (sg_size << 3)) {
+        if ((total_dsize >= 4096) && (total_dsize < 32768)) {
+            *use_ring = 1;
+        }
+    } else {
+        if (total_dsize >= 4096) {
+            *use_ring = 1;
+        }
+    }
+}
+
+/*
+ * rd_allgather_sub
+ *
+ * Function:    Uses recursive doubling based allgather for the group.
+ *              Group can be all ranks in a subgroup or base ranks across
+ *              subgroups.
+ *
+ * Description: Implementation logic of recursive doubling reused from
+ *              ompi_coll_base_allgather_intra_recursivedoubling().
+ *
+ */
+static inline int rd_allgather_sub(void *rbuf, struct ompi_datatype_t *rdtype,
+                                   struct ompi_communicator_t *comm, size_t count, int send_blk_loc,
+                                   int rank, int virtual_rank, int grp_size, const int across_sg,
+                                   int sg_start, int sg_size, ptrdiff_t rext)
+{
+    int err;
+    /* At step i, rank r exchanges message with rank (r ^ 2^i) */
+    for (int dist = 0x1, i = 0; dist < grp_size; dist <<= 1, i++) {
+        int remote = virtual_rank ^ dist;
+        int recv_blk_loc = virtual_rank < remote ? send_blk_loc + dist : send_blk_loc - dist;
+        size_t sr_cnt = count << i;
+        char *tmpsend = (char *) rbuf + (ptrdiff_t) send_blk_loc * (ptrdiff_t) count * rext;
+        char *tmprecv = (char *) rbuf + (ptrdiff_t) recv_blk_loc * (ptrdiff_t) count * rext;
+        int peer = across_sg ?
remote * sg_size : remote + sg_start; + if (virtual_rank >= remote) { + send_blk_loc -= dist; + } + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, sr_cnt, rdtype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + tmprecv, sr_cnt, rdtype, peer, MCA_COLL_BASE_TAG_ALLGATHER, + comm, MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +static inline int mca_coll_acoll_allgather_intra(const void *sbuf, size_t scount, + struct ompi_datatype_t *sdtype, void *rbuf, + size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int i; + int err; + int size; + int rank, adj_rank; + int sg_id, num_sgs, is_pow2_num_sgs; + int sg_start, sg_end; + int sg_size, log2_sg_size; + int subgrp_size, last_subgrp_size; + ptrdiff_t rlb, rext; + char *tmpsend = NULL, *tmprecv = NULL; + int sendto, recvfrom; + int num_data_blks; + size_t data_blk_size[2] = {0}, blk_ofst[2] = {0}; + size_t bcount; + size_t last_subgrp_rcnt; + int brank, last_brank; + int use_rd_base, use_ring_sg; + int use_ring = 0, use_lin = 0; + int nreqs; + ompi_request_t **preq, **reqs; + size_t dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + err = ompi_datatype_get_extent(rdtype, &rlb, &rext); + if (MPI_SUCCESS != err) { + return err; + } + + ompi_datatype_type_size(rdtype, &dsize); + size = ompi_comm_size(comm); + rank = ompi_comm_rank(comm); + sg_size = acoll_module->sg_cnt; + log2_sg_size = acoll_module->log2_sg_cnt; + + /* Handle non MPI_IN_PLACE */ + tmprecv = (char *) rbuf + (ptrdiff_t) rank * (ptrdiff_t) rcount * rext; + if (MPI_IN_PLACE != sbuf) { + tmpsend = (char *) sbuf; + err = ompi_datatype_sndrcv(tmpsend, scount, sdtype, tmprecv, rcount, rdtype); + if (MPI_SUCCESS != err) { + return err; + } + } + + /* Derive subgroup parameters */ + sg_id = rank >> log2_sg_size; + num_sgs = (size + sg_size - 1) >> log2_sg_size; + sg_start = sg_id << log2_sg_size; + sg_end = sg_start + sg_size; + if (sg_end > size) { + sg_end = size; + } + subgrp_size = sg_end - sg_start; + last_subgrp_size = size - ((num_sgs - 1) << log2_sg_size); + last_subgrp_rcnt = rcount * last_subgrp_size; + use_ring_sg = (subgrp_size != sg_size) ? 
1 : 0; + bcount = rcount << log2_sg_size; + + /* Override subgroup params based on data size */ + coll_allgather_decision_fixed(size, dsize * rcount, sg_size, &use_ring, &use_lin); + + if (use_lin) { + err = ompi_coll_base_allgather_intra_basic_linear(sbuf, scount, sdtype, rbuf, rcount, + rdtype, comm, module); + return err; + } + if (use_ring) { + sg_size = sg_end = subgrp_size = size; + num_sgs = 1; + use_ring_sg = 1; + sg_start = 0; + } + + /* Do ring/recursive doubling based allgather within subgroup */ + adj_rank = rank - sg_start; + if (use_ring_sg) { + recvfrom = ((adj_rank - 1 + subgrp_size) % subgrp_size) + sg_start; + sendto = ((adj_rank + 1) % subgrp_size) + sg_start; + + /* Loop over ranks in subgroup */ + for (i = 0; i < (subgrp_size - 1); i++) { + int recv_peer = ((adj_rank - i - 1 + subgrp_size) % subgrp_size) + sg_start; + int send_peer = ((adj_rank - i + subgrp_size) % subgrp_size) + sg_start; + + tmprecv = (char *) rbuf + (ptrdiff_t) recv_peer * (ptrdiff_t) rcount * rext; + tmpsend = (char *) rbuf + (ptrdiff_t) send_peer * (ptrdiff_t) rcount * rext; + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, rcount, rdtype, sendto, + MCA_COLL_BASE_TAG_ALLGATHER, tmprecv, rcount, rdtype, + recvfrom, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + err = rd_allgather_sub(rbuf, rdtype, comm, rcount, rank, rank, adj_rank, sg_size, 0, + sg_start, sg_size, rext); + if (MPI_SUCCESS != err) { + return err; + } + } + + /* Return if all ranks belong to single subgroup */ + if (num_sgs == 1) { + /* All done */ + return err; + } + + /* Do ring/rd based allgather across start ranks of subgroups */ + is_pow2_num_sgs = 0; + if (num_sgs == (1 << opal_hibit(num_sgs, comm->c_cube_dim))) { + is_pow2_num_sgs = 1; + } + use_rd_base = is_pow2_num_sgs ? ((last_subgrp_rcnt == bcount) ? 1 : 0) : 0; + + brank = sg_id; + last_brank = num_sgs - 1; + + /* Use ring for non-power of 2 cases */ + if (!(rank & (sg_size - 1)) && !use_rd_base) { + recvfrom = ((brank - 1 + num_sgs) % num_sgs) << log2_sg_size; + sendto = ((brank + 1) % num_sgs) << log2_sg_size; + + /* Loop over subgroups */ + for (i = 0; i < (num_sgs - 1); i++) { + int recv_peer = ((brank - i - 1 + num_sgs) % num_sgs); + int send_peer = ((brank - i + num_sgs) % num_sgs); + size_t scnt = (send_peer == last_brank) ? last_subgrp_rcnt : bcount; + size_t rcnt = (recv_peer == last_brank) ? 
last_subgrp_rcnt : bcount; + + tmprecv = (char *) rbuf + (ptrdiff_t) recv_peer * (ptrdiff_t) bcount * rext; + tmpsend = (char *) rbuf + (ptrdiff_t) send_peer * (ptrdiff_t) bcount * rext; + + recv_peer <<= log2_sg_size; + send_peer <<= log2_sg_size; + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, scnt, rdtype, sendto, + MCA_COLL_BASE_TAG_ALLGATHER, tmprecv, rcnt, rdtype, + recvfrom, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + } else if (!(rank & (sg_size - 1))) { + /* Use recursive doubling for power of 2 cases */ + err = rd_allgather_sub(rbuf, rdtype, comm, bcount, brank, rank, brank, num_sgs, 1, sg_start, + sg_size, rext); + if (MPI_SUCCESS != err) { + return err; + } + } + /* Now all base ranks have the full data */ + /* Do broadcast within subgroups from the base ranks for the extra data */ + if (sg_id == 0) { + num_data_blks = 1; + data_blk_size[0] = bcount * (num_sgs - 2) + last_subgrp_rcnt; + blk_ofst[0] = bcount; + } else if (sg_id == num_sgs - 1) { + if (last_subgrp_size < 2) { + return err; + } + num_data_blks = 1; + data_blk_size[0] = bcount * (num_sgs - 1); + blk_ofst[0] = 0; + } else { + num_data_blks = 2; + data_blk_size[0] = bcount * sg_id; + data_blk_size[1] = bcount * (num_sgs - sg_id - 2) + last_subgrp_rcnt; + blk_ofst[0] = 0; + blk_ofst[1] = bcount * (sg_id + 1); + } + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + nreqs = 0; + preq = reqs; + /* Loop over data blocks */ + for (i = 0; i < num_data_blks; i++) { + char *buff = (char *) rbuf + (ptrdiff_t) blk_ofst[i] * rext; + int sg_dim = opal_hibit(subgrp_size - 1, comm->c_cube_dim); + if ((1 << sg_dim) < subgrp_size) { + sg_dim++; + } + /* The size parameters to sg_bcast_intra ensures that the no. of send + requests do not exceed the max allocated. */ + err = sg_bcast_intra(buff, data_blk_size[i], rdtype, rank, sg_dim, size, sg_size, sg_start, + sg_start, comm, module, preq, &nreqs); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Start and wait on all requests. */ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + } + } + + /* All done */ + return err; +} + +/* + * mca_coll_acoll_allgather + * + * Function: Allgather operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Allgather() + * Returns: MPI_SUCCESS or error code + * + * Description: Allgather is performed across and within subgroups. + * Subgroups can be 1 or more based on size and count. + * + * Memory: No additional memory requirements beyond user-supplied buffers. 
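+ *
+ * A sketch of the node decomposition computed below (node_size comes
+ * from subc->derived_node_size; the figures are only an example):
+ *
+ *   num_nodes  = (size + node_size - 1) / node_size;
+ *   node_id    = rank / node_size;
+ *   node_start = node_id * node_size;   // leader rank of this node
+ *
+ * e.g. size = 24, node_size = 8 gives leaders at ranks 0, 8 and 16:
+ * each node first allgathers locally, the leaders then exchange the
+ * per-node blocks, and finally each leader broadcasts the remote
+ * blocks within its node.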
+ * + */ +int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int i; + int err; + int size; + int rank; + int num_nodes, node_start, node_end, node_id; + int node_size, last_node_size; + ptrdiff_t rlb, rext; + char *tmpsend = NULL, *tmprecv = NULL; + int sendto, recvfrom; + int num_data_blks; + size_t data_blk_size[2] = {0}, blk_ofst[2] = {0}; + size_t bcount; + size_t last_subgrp_rcnt; + int brank, last_brank; + int use_rd_base; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + char *local_rbuf; + ompi_communicator_t *intra_comm; + + /* Fallback to ring if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_allgather_intra_ring(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, + module); + } + + subc = &acoll_module->subc[cid]; + size = ompi_comm_size(comm); + if (!subc->initialized && size > 2) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != err) { + return err; + } + } + + err = ompi_datatype_get_extent(rdtype, &rlb, &rext); + if (MPI_SUCCESS != err) { + return err; + } + + rank = ompi_comm_rank(comm); + node_size = size > 2 ? subc->derived_node_size : size; + + /* Derive node parameters */ + num_nodes = (size + node_size - 1) / node_size; + node_id = rank / node_size; + node_start = node_id * node_size; + node_end = node_start + node_size; + if (node_end > size) { + node_end = size; + } + last_node_size = size - (num_nodes - 1) * node_size; + + /* Call intra */ + local_rbuf = (char *) rbuf + (ptrdiff_t) node_start * (ptrdiff_t) rcount * rext; + if (size <= 2) { + intra_comm = comm; + } else { + if (num_nodes > 1) { + assert(subc->local_r_comm != NULL); + } + intra_comm = num_nodes == 1 ? comm : subc->local_r_comm; + } + err = mca_coll_acoll_allgather_intra(sbuf, scount, sdtype, local_rbuf, rcount, rdtype, + intra_comm, module); + if (MPI_SUCCESS != err) { + return err; + } + + /* Return if intra-node communicator */ + if ((num_nodes == 1) || (size <= 2)) { + /* All done */ + return err; + } + + /* Handle inter-case by first doing allgather across node leaders */ + bcount = node_size * rcount; + last_subgrp_rcnt = last_node_size * rcount; + + /* Perform allgather across node leaders */ + if (rank == node_start) { + int is_pow2_num_nodes = 0; + if (num_nodes == (1 << opal_hibit(num_nodes, comm->c_cube_dim))) { + is_pow2_num_nodes = 1; + } + use_rd_base = is_pow2_num_nodes ? ((last_node_size == node_size) ? 1 : 0) : 0; + brank = node_id; + last_brank = num_nodes - 1; + + /* Use ring for non-power of 2 cases */ + if (!use_rd_base) { + recvfrom = ((brank - 1 + num_nodes) % num_nodes) * node_size; + sendto = ((brank + 1) % num_nodes) * node_size; + + /* Loop over nodes */ + for (i = 0; i < (num_nodes - 1); i++) { + int recv_peer = ((brank - i - 1 + num_nodes) % num_nodes); + int send_peer = ((brank - i + num_nodes) % num_nodes); + size_t scnt = (send_peer == last_brank) ? last_subgrp_rcnt : bcount; + size_t rcnt = (recv_peer == last_brank) ? 
last_subgrp_rcnt : bcount; + + tmprecv = (char *) rbuf + (ptrdiff_t) recv_peer * (ptrdiff_t) bcount * rext; + tmpsend = (char *) rbuf + (ptrdiff_t) send_peer * (ptrdiff_t) bcount * rext; + recv_peer *= node_size; + send_peer *= node_size; + + /* Sendreceive */ + err = ompi_coll_base_sendrecv(tmpsend, scnt, rdtype, sendto, + MCA_COLL_BASE_TAG_ALLGATHER, tmprecv, rcnt, rdtype, + recvfrom, MCA_COLL_BASE_TAG_ALLGATHER, comm, + MPI_STATUS_IGNORE, rank); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + /* Use recursive doubling for power of 2 cases */ + err = rd_allgather_sub(rbuf, rdtype, comm, bcount, brank, rank, brank, num_nodes, 1, + node_start, node_size, rext); + if (MPI_SUCCESS != err) { + return err; + } + } + } /* End of if inter leader */ + + /* Do intra node broadcast */ + if (node_id == 0) { + num_data_blks = 1; + data_blk_size[0] = bcount * (num_nodes - 2) + last_subgrp_rcnt; + blk_ofst[0] = bcount; + } else if (node_id == num_nodes - 1) { + if (last_node_size < 2) { + return err; + } + num_data_blks = 1; + data_blk_size[0] = bcount * (num_nodes - 1); + blk_ofst[0] = 0; + } else { + num_data_blks = 2; + data_blk_size[0] = bcount * node_id; + data_blk_size[1] = bcount * (num_nodes - node_id - 2) + last_subgrp_rcnt; + blk_ofst[0] = 0; + blk_ofst[1] = bcount * (node_id + 1); + } + /* Loop over data blocks */ + for (i = 0; i < num_data_blks; i++) { + char *buff = (char *) rbuf + (ptrdiff_t) blk_ofst[i] * rext; + err = (comm)->c_coll->coll_bcast(buff, data_blk_size[i], rdtype, 0, subc->local_r_comm, + module); + if (MPI_SUCCESS != err) { + return err; + } + } + + /* All done */ + return err; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_allreduce.c b/ompi/mca/coll/acoll/coll_acoll_allreduce.c new file mode 100644 index 00000000000..6a248a73c5a --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_allreduce.c @@ -0,0 +1,561 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + + +#include "mpi.h" +#include "ompi/communicator/communicator.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "ompi/op/op.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + + +void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp_size, int rank, int up); +int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module, int intra); + + +static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size) +{ + int alg = 3; + if (msg_size <= 256) { + alg = 1; + } else if (msg_size <= 1045876) { + alg = 2; + } else if (msg_size <= 4194304) { + alg = 3; + } else if (msg_size <= 8388608) { + alg = 0; + } else { + alg = 3; + } + return alg; +} + +#ifdef HAVE_XPMEM_H +static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int size; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * count; + + int l1_gp_size = data->l1_gp_size; + int *l1_gp = data->l1_gp; + int *l2_gp = data->l2_gp; + int l2_gp_size = data->l2_gp_size; + + int l1_local_rank = data->l1_local_rank; + int l2_local_rank = data->l2_local_rank; + char *tmp_sbuf = NULL; + char *tmp_rbuf = NULL; + if (!subc->xpmem_use_sr_buf) { + tmp_rbuf = (char *) data->scratch; + tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; + if ((sbuf == MPI_IN_PLACE)) { + memcpy(tmp_sbuf, rbuf, total_dsize); + } else { + memcpy(tmp_sbuf, sbuf, total_dsize); + } + } else { + tmp_sbuf = (char *) sbuf; + tmp_rbuf = (char *) rbuf; + if (sbuf == MPI_IN_PLACE) { + tmp_sbuf = (char *) rbuf; + } + } + void *sbuf_vaddr[1] = {tmp_sbuf}; + void *rbuf_vaddr[1] = {tmp_rbuf}; + int err = MPI_SUCCESS; + + err = comm->c_coll->coll_allgather(sbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_sbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (err != MPI_SUCCESS) { + return err; + } + + err = comm->c_coll->coll_allgather(rbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_rbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (err != MPI_SUCCESS) { + return err; + } + + register_and_cache(size, total_dsize, rank, data); + + /* reduce to the local group leader */ + size_t chunk = count / l1_gp_size; + size_t my_count_size = (l1_local_rank == (l1_gp_size - 1)) ? 
chunk + count % l1_gp_size : chunk; + + if (rank == l1_gp[0]) { + if (sbuf != MPI_IN_PLACE) + memcpy(tmp_rbuf, sbuf, my_count_size * dsize); + + for (int i = 1; i < l1_gp_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) tmp_rbuf + chunk * l1_local_rank * dsize, my_count_size, dtype); + } + } else { + ompi_3buff_op_reduce(op, + (char *) data->xpmem_saddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + (char *) tmp_sbuf + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + for (int i = 1; i < l1_gp_size; i++) { + if (i == l1_local_rank) { + continue; + } + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + } + } + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (err != MPI_SUCCESS) { + return err; + } + + /* perform reduce to 0 */ + int local_size = l2_gp_size; + if ((rank == l1_gp[0]) && (local_size > 1)) { + chunk = count / local_size; + + my_count_size = (l2_local_rank == (local_size - 1)) ? chunk + (count % local_size) : chunk; + + if (l2_local_rank == 0) { + for (int i = 1; i < local_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_raddr[l2_gp[i]], (char *) tmp_rbuf, + my_count_size, dtype); + } + } else { + for (int i = 1; i < local_size; i++) { + if (i == l2_local_rank) { + continue; + } + + ompi_op_reduce(op, + (char *) data->xpmem_raddr[l2_gp[i]] + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + ompi_op_reduce(op, (char *) tmp_rbuf + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + } + + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (!subc->xpmem_use_sr_buf) { + memcpy(rbuf, tmp_rbuf, total_dsize); + } + return err; +} + +static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int size; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + size = ompi_comm_size(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * count; + + char *tmp_sbuf = NULL; + char *tmp_rbuf = NULL; + if (!subc->xpmem_use_sr_buf) { + tmp_rbuf = (char *) data->scratch; + tmp_sbuf = (char *) data->scratch + (subc->xpmem_buf_size) / 2; + if ((sbuf == MPI_IN_PLACE)) { + memcpy(tmp_sbuf, rbuf, total_dsize); + } else { + memcpy(tmp_sbuf, sbuf, total_dsize); + } + } else { + tmp_sbuf = (char *) sbuf; + tmp_rbuf = (char *) rbuf; + if (sbuf == MPI_IN_PLACE) { + tmp_sbuf = (char *) rbuf; + } + } + void *sbuf_vaddr[1] = {tmp_sbuf}; + void *rbuf_vaddr[1] = {tmp_rbuf}; + int err = MPI_SUCCESS; + int rank = ompi_comm_rank(comm); + + err = comm->c_coll->coll_allgather(sbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_sbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (err != MPI_SUCCESS) { + return err; + } + err = comm->c_coll->coll_allgather(rbuf_vaddr, 
sizeof(void *), MPI_BYTE, data->allshm_rbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + + if (err != MPI_SUCCESS) { + return err; + } + + register_and_cache(size, total_dsize, rank, data); + + size_t chunk = count / size; + size_t my_count_size = (rank == (size - 1)) ? (count / size) + count % size : count / size; + if (rank == 0) { + if (sbuf != MPI_IN_PLACE) + memcpy(tmp_rbuf, sbuf, my_count_size * dsize); + } else { + ompi_3buff_op_reduce(op, (char *) data->xpmem_saddr[0] + chunk * rank * dsize, + (char *) tmp_sbuf + chunk * rank * dsize, + (char *) tmp_rbuf + chunk * rank * dsize, my_count_size, dtype); + } + + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (err != MPI_SUCCESS) { + return err; + } + + for (int i = 1; i < size; i++) { + if (rank == i) { + continue; + } + ompi_op_reduce(op, (char *) data->xpmem_saddr[i] + chunk * rank * dsize, + (char *) tmp_rbuf + chunk * rank * dsize, my_count_size, dtype); + } + err = ompi_coll_base_barrier_intra_tree(comm, module); + if (err != MPI_SUCCESS) { + return err; + } + + size_t tmp = chunk * dsize; + for (int i = 0; i < size; i++) { + if (subc->xpmem_use_sr_buf && (rank == i)) { + continue; + } + my_count_size = (i == (size - 1)) ? (count / size) + count % size : count / size; + size_t tmp1 = i * tmp; + char *dst = (char *) rbuf + tmp1; + char *src = (char *) data->xpmem_raddr[i] + tmp1; + memcpy(dst, src, my_count_size * dsize); + } + + err = ompi_coll_base_barrier_intra_tree(comm, module); + + return err; +} +#endif + +void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp_size, int rank, + int up) +{ + volatile int *tmp, tmp0; + tmp = (int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + CACHE_LINE_SIZE * rank); + tmp0 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + CACHE_LINE_SIZE * group[0]), + __ATOMIC_RELAXED); + + opal_atomic_wmb(); + + int val; + if (up == 1) { + val = data->sync[0]; + } else { + val = data->sync[1]; + } + + if (rank == group[0]) { + __atomic_store_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + CACHE_LINE_SIZE * group[0]), + val, __ATOMIC_RELAXED); + } + + while (tmp0 != val) { + tmp0 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + CACHE_LINE_SIZE * group[0]), + __ATOMIC_RELAXED); + } + + if (rank != group[0]) { + val++; + __atomic_store_n(tmp, val, __ATOMIC_RELAXED); + } + opal_atomic_wmb(); + if (rank == group[0]) { + for (int i = 1; i < gp_size; i++) { + volatile int tmp1 = __atomic_load_n( + (int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + CACHE_LINE_SIZE * group[i]), + __ATOMIC_RELAXED); + while (tmp1 == val) { + tmp1 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + CACHE_LINE_SIZE * group[i]), + __ATOMIC_RELAXED); + } + opal_atomic_wmb(); + } + ++val; + __atomic_store_n(tmp, val, __ATOMIC_RELAXED); + } else { + volatile int tmp1 = __atomic_load_n( + (int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + CACHE_LINE_SIZE * group[0]), + __ATOMIC_RELAXED); + while (tmp1 != val) { + tmp1 = __atomic_load_n((int *) ((char *) data->allshmmmap_sbuf[group[0]] + offset + + CACHE_LINE_SIZE * group[0]), + __ATOMIC_RELAXED); + } + } + if (up == 1) { + data->sync[0] = val; + } else { + data->sync[1] = val; + } +} + +int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + 
mca_coll_base_module_t *module, int intra)
+{
+    size_t dsize;
+    int err = MPI_SUCCESS;
+    mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
+    coll_acoll_subcomms_t *subc;
+    int cid = ompi_comm_get_local_cid(comm);
+    subc = &acoll_module->subc[cid];
+    coll_acoll_init(module, comm, subc->data);
+    coll_acoll_data_t *data = subc->data;
+    if (NULL == data) {
+        return -1;
+    }
+
+    int rank = ompi_comm_rank(comm);
+    ompi_datatype_type_size(dtype, &dsize);
+
+    int l1_gp_size = data->l1_gp_size;
+    int *l1_gp = data->l1_gp;
+    int *l2_gp = data->l2_gp;
+    int l2_gp_size = data->l2_gp_size;
+
+    int l1_local_rank = data->l1_local_rank;
+    int l2_local_rank = data->l2_local_rank;
+    int comm_id = ompi_comm_get_local_cid(comm);
+
+    int offset1 = data->offset[0];
+    int offset2 = data->offset[1];
+    int tshm_offset = data->offset[2];
+    int shm_offset = data->offset[3];
+    const int per_rank_shm_size = 8 * 1024;
+
+    int local_size;
+
+    if (rank == l1_gp[0]) {
+        if (l2_gp_size > 1) {
+            mca_coll_acoll_sync(data, offset2, l2_gp, l2_gp_size, rank, 3);
+        }
+    }
+
+    if (MPI_IN_PLACE == sbuf) {
+        memcpy((char *) data->allshmmmap_sbuf[l1_gp[0]] + shm_offset, rbuf, count * dsize);
+    } else {
+        memcpy((char *) data->allshmmmap_sbuf[l1_gp[0]] + shm_offset, sbuf, count * dsize);
+    }
+
+    mca_coll_acoll_sync(data, offset1, l1_gp, l1_gp_size, rank, 1);
+
+    if (rank == l1_gp[0]) {
+        memcpy((char *) data->allshmmmap_sbuf[l1_gp[l1_local_rank]],
+               (char *) data->allshmmmap_sbuf[l1_gp[0]] + shm_offset, count * dsize);
+        for (int i = 1; i < l1_gp_size; i++) {
+            ompi_op_reduce(op,
+                           (char *) data->allshmmmap_sbuf[l1_gp[0]] + tshm_offset
+                               + l1_gp[i] * per_rank_shm_size,
+                           (char *) data->allshmmmap_sbuf[l1_gp[l1_local_rank]], count, dtype);
+        }
+        memcpy(rbuf, data->allshmmmap_sbuf[l1_gp[l1_local_rank]], count * dsize);
+    }
+
+    if (rank == l1_gp[0]) {
+        if (l2_gp_size > 1) {
+            mca_coll_acoll_sync(data, offset2, l2_gp, l2_gp_size, rank, 3);
+        }
+    }
+
+    /* perform allreduce across leaders */
+    local_size = l2_gp_size;
+    if (local_size > 1) {
+        if (rank == l1_gp[0]) {
+            for (int i = 0; i < local_size; i++) {
+                if (i == l2_local_rank) {
+                    continue;
+                }
+                ompi_op_reduce(op, (char *) data->allshmmmap_sbuf[l2_gp[i]], (char *) rbuf, count,
+                               dtype);
+            }
+        }
+    }
+
+    if (intra && (ompi_comm_size(acoll_module->subc[comm_id].numa_comm) > 1)) {
+        err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, acoll_module->subc[comm_id].numa_comm,
+                                   module);
+    }
+    return err;
+}
+
+int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
+                                   struct ompi_datatype_t *dtype, struct ompi_op_t *op,
+                                   struct ompi_communicator_t *comm, mca_coll_base_module_t *module)
+{
+    int size, alg, err;
+    int num_nodes;
+    size_t total_dsize, dsize;
+    mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
+    size = ompi_comm_size(comm);
+    ompi_datatype_type_size(dtype, &dsize);
+    total_dsize = dsize * count;
+
+    if (size == 1) {
+        if (MPI_IN_PLACE != sbuf) {
+            memcpy((char *) rbuf, sbuf, total_dsize);
+        }
+        return MPI_SUCCESS;
+    }
+
+    coll_acoll_subcomms_t *subc;
+    int cid = ompi_comm_get_local_cid(comm);
+    subc = &acoll_module->subc[cid];
+
+    /* Fall back to recursive doubling for non-commutative operators to be safe */
+    if (!ompi_op_is_commute(op)) {
+        return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm,
+                                                                module);
+    }
+
+    /* Fall back to redscat_allgather if cid is beyond supported limit */
+    if (cid >= MCA_COLL_ACOLL_MAX_CID) {
+        return
ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, + module); + } + + subc = &acoll_module->subc[cid]; + if (!subc->initialized) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != err) + return err; + } + + num_nodes = subc->num_nodes; + + alg = coll_allreduce_decision_fixed(size, total_dsize); + + if (num_nodes == 1) { + if (total_dsize < 32) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, + comm, module); + } else if (total_dsize < 512) { + return mca_coll_acoll_allreduce_small_msgs_h(sbuf, rbuf, count, dtype, op, comm, module, + 1); + } else if (total_dsize <= 2048) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, + comm, module); + } else if (total_dsize < 65536) { + if (alg == 1) { + return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, + op, comm, module); + } else if (alg == 2) { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } else { /*alg == 3 */ + return ompi_coll_base_allreduce_intra_ring_segmented(sbuf, rbuf, count, dtype, op, + comm, module, 0); + } + } else if (total_dsize < 4194304) { +#ifdef HAVE_XPMEM_H + if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) { + return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module); + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } +#else + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, + comm, module); +#endif + } else if (total_dsize <= 16777216) { +#ifdef HAVE_XPMEM_H + if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) { + mca_coll_acoll_reduce_xpmem_h(sbuf, rbuf, count, dtype, op, comm, module); + return mca_coll_acoll_bcast(rbuf, count, dtype, 0, comm, module); + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } +#else + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, + comm, module); +#endif + } else { +#ifdef HAVE_XPMEM_H + if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) { + return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module); + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, + op, comm, module); + } +#else + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, + comm, module); +#endif + } + + } else { + return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, + module); + } + return MPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_barrier.c b/ompi/mca/coll/acoll/coll_acoll_barrier.c new file mode 100644 index 00000000000..a138027f444 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_barrier.c @@ -0,0 +1,223 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+
+
+#include "mpi.h"
+#include "ompi/constants.h"
+#include "ompi/datatype/ompi_datatype.h"
+#include "ompi/mca/coll/base/coll_base_functions.h"
+#include "ompi/mca/coll/base/coll_tags.h"
+#include "ompi/mca/coll/coll.h"
+#include "ompi/mca/pml/pml.h"
+#include "opal/util/bit_ops.h"
+#include "coll_acoll.h"
+#include "coll_acoll_utils.h"
+
+static int mca_coll_acoll_barrier_recv_subc(struct ompi_communicator_t *comm,
+                                            mca_coll_base_module_t *module, ompi_request_t **reqs,
+                                            int *nreqs, int root)
+{
+    int rank = ompi_comm_rank(comm);
+    int size = ompi_comm_size(comm);
+    int err = MPI_SUCCESS;
+
+    if (rank < 0) {
+        return err;
+    }
+
+    /* Non-root ranks receive a zero-byte message from the root */
+    if (rank != root) {
+        err = MCA_PML_CALL(
+            recv(NULL, 0, MPI_BYTE, root, MCA_COLL_BASE_TAG_BARRIER, comm, MPI_STATUS_IGNORE));
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+    } else if (rank == root) {
+        ompi_request_t **preq = reqs;
+        *nreqs = 0;
+        for (int i = 0; i < size; i++) {
+            if (i == root) {
+                continue;
+            }
+            *nreqs = *nreqs + 1;
+            err = MCA_PML_CALL(isend(NULL, 0, MPI_BYTE, i, MCA_COLL_BASE_TAG_BARRIER,
+                                     MCA_PML_BASE_SEND_STANDARD, comm, preq++));
+            if (MPI_SUCCESS != err) {
+                return err;
+            }
+        }
+        err = ompi_request_wait_all(*nreqs, reqs, MPI_STATUSES_IGNORE);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+    }
+
+    return err;
+}
+
+static int mca_coll_acoll_barrier_send_subc(struct ompi_communicator_t *comm,
+                                            mca_coll_base_module_t *module, ompi_request_t **reqs,
+                                            int *nreqs, int root)
+{
+    int rank = ompi_comm_rank(comm);
+    int size = ompi_comm_size(comm);
+    int err = MPI_SUCCESS;
+
+    if (rank < 0) {
+        return err;
+    }
+
+    /* Non-root ranks send a zero-byte message to the root */
+    if (rank != root) {
+        err = MCA_PML_CALL(send(NULL, 0, MPI_BYTE, root, MCA_COLL_BASE_TAG_BARRIER,
+                                MCA_PML_BASE_SEND_STANDARD, comm));
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+    } else if (rank == root) {
+        ompi_request_t **preq = reqs;
+        *nreqs = 0;
+        for (int i = 0; i < size; i++) {
+            if (i == root) {
+                continue;
+            }
+            *nreqs = *nreqs + 1;
+            err = MCA_PML_CALL(
+                irecv(NULL, 0, MPI_BYTE, i, MCA_COLL_BASE_TAG_BARRIER, comm, preq++));
+            if (MPI_SUCCESS != err) {
+                return err;
+            }
+        }
+        err = ompi_request_wait_all(*nreqs, reqs, MPI_STATUSES_IGNORE);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+    }
+
+    return err;
+}
+
+/*
+ * mca_coll_acoll_barrier_intra
+ *
+ * Function:    Barrier operation using subgroup based algorithm
+ * Accepts:     Same arguments as MPI_Barrier()
+ * Returns:     MPI_SUCCESS or error code
+ *
+ * Description: Step 1 - All leaf ranks of a subgroup send to base rank.
+ *              Step 2 - All base ranks send to rank 0.
+ *              Step 3 - Base rank sends to leaf ranks.
+ *
+ * Limitations: None
+ *
+ * Memory:      No additional memory requirements beyond user-supplied buffers.
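+ *
+ * An illustrative trace (assumed topology, for exposition only):
+ * 8 ranks on one node, in two subgroups {0-3} and {4-7} with base
+ * ranks 0 and 4:
+ *
+ *   Step 1: ranks 1,2,3 -> 0 and ranks 5,6,7 -> 4   (subgrp_comm)
+ *   Step 2: rank 4 -> 0                             (base_comm, L3 level)
+ *   Step 3: the fan-out retraces the same edges in reverse order.
+ *
+ * With multiple nodes, node leaders additionally gather at the
+ * outer-group root between steps 2 and 3 (leader_comm).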
+ * + */ +int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int size, ssize, bsize; + int err = MPI_SUCCESS; + int nreqs = 0; + ompi_request_t **reqs; + int num_nodes; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + + /* Fallback to linear if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_barrier_intra_basic_linear(comm, module); + } + + subc = &acoll_module->subc[cid]; + size = ompi_comm_size(comm); + if (size == 1) { + return err; + } + if (!subc->initialized && size > 1) { + err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != err) { + return err; + } + } + num_nodes = size > 1 ? subc->num_nodes : 1; + + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + ssize = ompi_comm_size(subc->subgrp_comm); + bsize = ompi_comm_size(subc->base_comm[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE]); + + /* Sends from leaf ranks at subgroup level */ + if (ssize > 1) { + err = mca_coll_acoll_barrier_send_subc(subc->subgrp_comm, module, reqs, &nreqs, + subc->subgrp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Sends from leaf ranks at base rank level */ + if ((bsize > 1) && (subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] != -1)) { + err = mca_coll_acoll_barrier_send_subc( + subc->base_comm[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE], module, reqs, &nreqs, + subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE]); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Sends from leaf ranks at node leader level */ + if ((num_nodes > 1) && (subc->outer_grp_root != -1)) { + err = mca_coll_acoll_barrier_send_subc(subc->leader_comm, module, reqs, &nreqs, + subc->outer_grp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + /* Leaf ranks at node leader level receive from root */ + if ((num_nodes > 1) && (subc->outer_grp_root != -1)) { + err = mca_coll_acoll_barrier_recv_subc(subc->leader_comm, module, reqs, &nreqs, + subc->outer_grp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Leaf ranks at base rank level receive from inter leader */ + if ((bsize > 1) && (subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE] != -1)) { + err = mca_coll_acoll_barrier_recv_subc( + subc->base_comm[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE], module, reqs, &nreqs, + subc->base_root[MCA_COLL_ACOLL_L3CACHE][MCA_COLL_ACOLL_LYR_NODE]); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + /* Leaf ranks at subgroup level to receive from base ranks */ + if (ssize > 1) { + err = mca_coll_acoll_barrier_recv_subc(subc->subgrp_comm, module, reqs, &nreqs, + subc->subgrp_root); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + /* All done */ + ompi_coll_base_free_reqs(reqs, nreqs); + return err; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_bcast.c b/ompi/mca/coll/acoll/coll_acoll_bcast.c new file mode 100644 index 00000000000..b423479db21 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_bcast.c @@ -0,0 +1,540 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro 
Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +typedef int (*bcast_subc_func)(void *buff, size_t count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, + int world_rank); + +/* + * bcast_binomial + * + * Function: Broadcast operation using balanced binomial tree + * + * Description: Core logic of implementation is derived from that in + * "basic" component. + */ +static int bcast_binomial(void *buff, size_t count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, + int world_rank) +{ + int msb_pos, sub_rank, peer, err = MPI_SUCCESS; + int size, rank, dim; + int i, mask; + + size = ompi_comm_size(comm); + rank = ompi_comm_rank(comm); + dim = comm->c_cube_dim; + sub_rank = (rank - root + size) % size; + + msb_pos = opal_hibit(sub_rank, dim); + --dim; + + /* Receive data from parent in the subgroup tree. */ + if (sub_rank > 0) { + assert(msb_pos >= 0); + peer = ((sub_rank & ~(1 << msb_pos)) + root) % size; + + err = MCA_PML_CALL( + recv(buff, count, datatype, peer, MCA_COLL_BASE_TAG_BCAST, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + for (i = msb_pos + 1, mask = 1 << i; i <= dim; ++i, mask <<= 1) { + peer = sub_rank | mask; + if (peer < size) { + peer = (peer + root) % size; + *nreqs = *nreqs + 1; + + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_BCAST, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } + + return err; +} + +static int bcast_flat_tree(void *buff, size_t count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, ompi_request_t **preq, int *nreqs, + int world_rank) +{ + int peer; + int err = MPI_SUCCESS; + int rank = ompi_comm_rank(comm); + int size = ompi_comm_size(comm); + + if (rank == root) { + for (peer = 0; peer < size; peer++) { + if (peer == root) { + continue; + } + *nreqs = *nreqs + 1; + err = MCA_PML_CALL(isend(buff, count, datatype, peer, MCA_COLL_BASE_TAG_BCAST, + MCA_PML_BASE_SEND_STANDARD, comm, preq++)); + if (MPI_SUCCESS != err) { + return err; + } + } + } else { + err = MCA_PML_CALL( + recv(buff, count, datatype, root, MCA_COLL_BASE_TAG_BCAST, comm, MPI_STATUS_IGNORE)); + if (MPI_SUCCESS != err) { + return err; + } + } + + return err; +} + +/* + * coll_bcast_decision_fixed + * + * Function: Choose optimal broadcast algorithm + * + * Description: Based on no. of processes and message size, chooses [log|lin] + * broadcast and subgroup size to be used. 
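+ *
+ * The three lin_* outputs select the flat tree (1) or the binomial
+ * tree (0) for the successive broadcast stages, as consumed further
+ * below in mca_coll_acoll_bcast_intra_node():
+ *
+ *   bcast_subc_func bcast_intra[2] = {&bcast_binomial, &bcast_flat_tree};
+ *   err = bcast_intra[lin_1](...);   // base-rank (L3/NUMA) stage
+ *   err = bcast_intra[lin_2](...);   // leaf/subgroup stage
+ *
+ * lin_0 is expected to select the variant for the node-leader stage,
+ * which is used outside this excerpt.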
+ * + */ + +#define SET_BCAST_PARAMS(l0, l1, l2) \ + *lin_0 = l0; \ + *lin_1 = l1; \ + *lin_2 = l2; + +static inline void coll_bcast_decision_fixed(int size, size_t total_dsize, int node_size, + int *sg_cnt, int *use_0, int *use_numa, int *lin_0, + int *lin_1, int *lin_2, + mca_coll_acoll_module_t *acoll_module, + coll_acoll_subcomms_t *subc) +{ + int sg_size = *sg_cnt; + *use_0 = 0; + *lin_0 = 0; + *use_numa = 0; + if (size <= node_size) { + if (size <= sg_size) { + *sg_cnt = sg_size; + if (total_dsize <= 8192) { + SET_BCAST_PARAMS(0, 0, 0) + } else { + SET_BCAST_PARAMS(0, 1, 1) + } + } else if (size <= (sg_size << 1)) { + if (total_dsize <= 1024) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 2097152) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else if (size <= (sg_size << 2)) { + if (total_dsize <= 1024) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 32768) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else if (total_dsize <= 4194304) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else if (size <= (sg_size << 3)) { + if (total_dsize <= 1024) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 262144) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } + } else if (size <= (sg_size << 4)) { + if (total_dsize <= 512) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 262144) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } + } else { + if (total_dsize <= 512) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= 262144) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else if (total_dsize <= 16777216) { + *sg_cnt = size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + *use_numa = 1; + SET_BCAST_PARAMS(0, 1, 1) + } + } + } else { + if (acoll_module->use_dyn_rules) { + *sg_cnt = acoll_module->mnode_sg_size; + *use_0 = acoll_module->use_mnode; + SET_BCAST_PARAMS(acoll_module->use_lin0, acoll_module->use_lin1, acoll_module->use_lin2) + } else { + int derived_node_size = subc->derived_node_size; + *use_0 = 1; + if (size <= (derived_node_size << 2)) { + size_t dsize_thresh[2][3] = {{512, 8192, 131072}, {128, 8192, 65536}}; + int thr_ind = (size <= (derived_node_size << 1)) ? 
0 : 1; + if (total_dsize <= dsize_thresh[thr_ind][0]) { + *sg_cnt = node_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= dsize_thresh[thr_ind][1]) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } else if (total_dsize <= dsize_thresh[thr_ind][2]) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 1, 1) + } else { + *sg_cnt = node_size; + SET_BCAST_PARAMS(1, 1, 1) + } + } else if (size <= (derived_node_size << 3)) { + if (total_dsize <= 1024) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 1) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 0, 1) + } else if (total_dsize <= 65536) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 1, 1) + } else if (total_dsize <= 2097152) { + *sg_cnt = node_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else if (size <= (derived_node_size << 4)) { + if (total_dsize <= 64) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 1, 1) + } else if (total_dsize <= 8192) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 1) + } else if (total_dsize <= 32768) { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(1, 1, 1) + } else if (total_dsize <= 2097152) { + *sg_cnt = node_size; + SET_BCAST_PARAMS(0, 1, 1) + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } else { + *sg_cnt = sg_size; + SET_BCAST_PARAMS(0, 0, 0) + } + } + } +} + +static inline void coll_acoll_bcast_subcomms(struct ompi_communicator_t *comm, + coll_acoll_subcomms_t *subc, + struct ompi_communicator_t **subcomms, int *subc_roots, + int root, int num_nodes, int use_0, int no_sg, + int use_numa) +{ + /* Node leaders */ + if (use_0) { + subcomms[MCA_COLL_ACOLL_NODE_L] = subc->leader_comm; + subc_roots[MCA_COLL_ACOLL_NODE_L] = subc->outer_grp_root; + } + /* Intra comm */ + if ((num_nodes > 1) && use_0) { + subc_roots[MCA_COLL_ACOLL_INTRA] = subc->is_root_node + ? 
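+            /* on the root's node the intra stage is rooted at the root's
+               node-local rank, on every other node at local rank 0 */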
subc->local_root[MCA_COLL_ACOLL_LYR_NODE] + : 0; + subcomms[MCA_COLL_ACOLL_INTRA] = subc->local_comm; + } else { + subc_roots[MCA_COLL_ACOLL_INTRA] = root; + subcomms[MCA_COLL_ACOLL_INTRA] = comm; + } + /* Base ranks comm */ + if (no_sg) { + subcomms[MCA_COLL_ACOLL_L3_L] = subcomms[MCA_COLL_ACOLL_INTRA]; + subc_roots[MCA_COLL_ACOLL_L3_L] = subc_roots[MCA_COLL_ACOLL_INTRA]; + } else { + subcomms[MCA_COLL_ACOLL_L3_L] = subc->base_comm[MCA_COLL_ACOLL_L3CACHE] + [MCA_COLL_ACOLL_LYR_NODE]; + subc_roots[MCA_COLL_ACOLL_L3_L] = subc->base_root[MCA_COLL_ACOLL_L3CACHE] + [MCA_COLL_ACOLL_LYR_NODE]; + } + /* Subgroup comm */ + subcomms[MCA_COLL_ACOLL_LEAF] = subc->subgrp_comm; + subc_roots[MCA_COLL_ACOLL_LEAF] = subc->subgrp_root; + + /* Override with numa when needed */ + if (use_numa) { + subcomms[MCA_COLL_ACOLL_L3_L] = subc->base_comm[MCA_COLL_ACOLL_NUMA] + [MCA_COLL_ACOLL_LYR_NODE]; + subc_roots[MCA_COLL_ACOLL_L3_L] = subc->base_root[MCA_COLL_ACOLL_NUMA] + [MCA_COLL_ACOLL_LYR_NODE]; + subcomms[MCA_COLL_ACOLL_LEAF] = subc->numa_comm; + subc_roots[MCA_COLL_ACOLL_LEAF] = subc->numa_root; + } +} + +static int mca_coll_acoll_bcast_intra_node(void *buff, size_t count, struct ompi_datatype_t *datatype, + mca_coll_base_module_t *module, + coll_acoll_subcomms_t *subc, + struct ompi_communicator_t **subcomms, int *subc_roots, + int lin_1, int lin_2, int no_sg, int use_numa, + int world_rank) +{ + int size; + int rank; + int err; + int subgrp_size; + int is_base = 0; + int nreqs; + ompi_request_t **preq, **reqs; + struct ompi_communicator_t *comm = subcomms[MCA_COLL_ACOLL_INTRA]; + bcast_subc_func bcast_intra[2] = {&bcast_binomial, &bcast_flat_tree}; + + rank = ompi_comm_rank(comm); + size = ompi_comm_size(comm); + + reqs = ompi_coll_base_comm_get_reqs(module->base_data, size); + if (NULL == reqs) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + nreqs = 0; + preq = reqs; + err = MPI_SUCCESS; + if (no_sg) { + is_base = 1; + } else { + int ind = use_numa ? MCA_COLL_ACOLL_NUMA : MCA_COLL_ACOLL_L3CACHE; + is_base = rank == subc->base_rank[ind] ? 1 : 0; + } + + /* All base ranks receive from root */ + if (is_base) { + err = bcast_intra[lin_1](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_L3_L], + subcomms[MCA_COLL_ACOLL_L3_L], preq, &nreqs, world_rank); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + } + + /* Start and wait on all requests. */ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + } + } + + /* If single stage, return */ + if (no_sg) { + ompi_coll_base_free_reqs(reqs, nreqs); + return err; + } + + subgrp_size = use_numa ? ompi_comm_size(subc->numa_comm) : subc->subgrp_size; + /* All leaf ranks receive from the respective base rank */ + if ((subgrp_size > 1) && !no_sg) { + err = bcast_intra[lin_2](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_LEAF], + subcomms[MCA_COLL_ACOLL_LEAF], preq, &nreqs, world_rank); + } + + /* Start and wait on all requests. */ + if (nreqs > 0) { + err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE); + if (MPI_SUCCESS != err) { + ompi_coll_base_free_reqs(reqs, nreqs); + } + } + + /* All done */ + ompi_coll_base_free_reqs(reqs, nreqs); + return err; +} + +/* + * mca_coll_acoll_bcast + * + * Function: Broadcast operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Bcast() + * Returns: MPI_SUCCESS or error code + * + * Description: Broadcast is performed across and within subgroups. 
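+ *              For example, with 32 ranks in a node and a subgroup size
+ *              of 8, the root first sends to the bases of the other three
+ *              subgroups, after which each base broadcasts within its own
+ *              subgroup (an illustrative walk-through, not a traced run).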
+ * O(N) or O(log(N)) algorithm within a subgroup based on count.
+ * Subgroups can be 1 or more based on size and count.
+ *
+ * Limitations: None
+ *
+ * Memory: No additional memory requirements beyond user-supplied buffers.
+ *
+ */
+int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datatype, int root,
+                         struct ompi_communicator_t *comm, mca_coll_base_module_t *module)
+{
+    int size;
+    int rank;
+    int err;
+    int nreqs;
+    ompi_request_t **preq, **reqs;
+    int sg_cnt, node_size;
+    int num_nodes;
+    int use_0 = 0;
+    int lin_0 = 0, lin_1 = 0, lin_2 = 0;
+    int use_numa = 0;
+    int no_sg;
+    size_t total_dsize, dsize;
+    mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
+    bcast_subc_func bcast_func[2] = {&bcast_binomial, &bcast_flat_tree};
+    coll_acoll_subcomms_t *subc;
+    struct ompi_communicator_t *subcomms[MCA_COLL_ACOLL_NUM_SC] = {NULL};
+    int subc_roots[MCA_COLL_ACOLL_NUM_SC] = {-1};
+    int cid = ompi_comm_get_local_cid(comm);
+
+    /* Fallback to knomial if cid is beyond supported limit */
+    if (cid >= MCA_COLL_ACOLL_MAX_CID) {
+        return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
+    }
+
+    subc = &acoll_module->subc[cid];
+    /* Fallback to knomial if no. of root changes is beyond a threshold */
+    if (subc->num_root_change > MCA_COLL_ACOLL_ROOT_CHANGE_THRESH) {
+        return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
+    }
+    size = ompi_comm_size(comm);
+    if ((!subc->initialized || (root != subc->prev_init_root)) && size > 2) {
+        err = mca_coll_acoll_comm_split_init(comm, acoll_module, root);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+    }
+
+    ompi_datatype_type_size(datatype, &dsize);
+    total_dsize = dsize * count;
+    rank = ompi_comm_rank(comm);
+    sg_cnt = acoll_module->sg_cnt;
+    if (size > 2) {
+        num_nodes = subc->num_nodes;
+        node_size = ompi_comm_size(subc->local_comm);
+    } else {
+        num_nodes = 1;
+        node_size = size;
+    }
+
+    /* Use knomial for nodes 8 and above and non-large messages */
+    if ((num_nodes >= 8 && total_dsize <= 65536)
+        || (num_nodes == 1 && size >= 256 && total_dsize < 16384)) {
+        return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
+    }
+
+    /* Determine the algorithm to be used based on size and count */
+    /* sg_cnt determines subgroup based communication */
+    /* lin_1 and lin_2 indicate whether to use linear or log based
+       sends/receives across and within subgroups respectively. */
+    coll_bcast_decision_fixed(size, total_dsize, node_size, &sg_cnt, &use_0, &use_numa, &lin_0,
+                              &lin_1, &lin_2, acoll_module, subc);
+    no_sg = (sg_cnt == node_size) ? 1 : 0;
+    if (size <= 2)
+        no_sg = 1;
+
+    coll_acoll_bcast_subcomms(comm, subc, subcomms, subc_roots, root, num_nodes, use_0, no_sg,
+                              use_numa);
+
+    reqs = ompi_coll_base_comm_get_reqs(module->base_data, size);
+    if (NULL == reqs) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+    nreqs = 0;
+    preq = reqs;
+    err = MPI_SUCCESS;
+
+    if (use_0) {
+        if (subc_roots[MCA_COLL_ACOLL_NODE_L] != -1) {
+            err = bcast_func[lin_0](buff, count, datatype, subc_roots[MCA_COLL_ACOLL_NODE_L],
                                    subcomms[MCA_COLL_ACOLL_NODE_L], preq, &nreqs, rank);
+            if (MPI_SUCCESS != err) {
+                ompi_coll_base_free_reqs(reqs, nreqs);
+                return err;
+            }
+        }
+    }
+
+    /* Start and wait on all requests. */
+    if (nreqs > 0) {
+        err = ompi_request_wait_all(nreqs, reqs, MPI_STATUSES_IGNORE);
+        if (MPI_SUCCESS != err) {
+            ompi_coll_base_free_reqs(reqs, nreqs);
+            return err;
+        }
+    }
+
+    err = mca_coll_acoll_bcast_intra_node(buff, count, datatype, module, subc, subcomms,
+                                          subc_roots, lin_1, lin_2, no_sg, use_numa, rank);
+
+    if (MPI_SUCCESS != err) {
+        ompi_coll_base_free_reqs(reqs, nreqs);
+        return err;
+    }
+
+    /* All done */
+    ompi_coll_base_free_reqs(reqs, nreqs);
+    return err;
+}
diff --git a/ompi/mca/coll/acoll/coll_acoll_component.c b/ompi/mca/coll/acoll/coll_acoll_component.c
new file mode 100644
index 00000000000..0214bfc89be
--- /dev/null
+++ b/ompi/mca/coll/acoll/coll_acoll_component.c
@@ -0,0 +1,346 @@
+/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
+/*
+ * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ *
+ */
+
+#include "ompi_config.h"
+
+#include "mpi.h"
+#include "ompi/mca/coll/coll.h"
+#include "coll_acoll.h"
+
+/*
+ * Public string showing the coll ompi_acoll component version number
+ */
+const char *mca_coll_acoll_component_version_string
+    = "Open MPI acoll collective MCA component version " OMPI_VERSION;
+
+/*
+ * Global variables
+ */
+int mca_coll_acoll_priority = 0;
+int mca_coll_acoll_sg_size = 8;
+int mca_coll_acoll_sg_scale = 1;
+int mca_coll_acoll_node_size = 128;
+int mca_coll_acoll_use_dynamic_rules = 0;
+int mca_coll_acoll_mnode_enable = 1;
+int mca_coll_acoll_bcast_lin0 = 0;
+int mca_coll_acoll_bcast_lin1 = 0;
+int mca_coll_acoll_bcast_lin2 = 0;
+int mca_coll_acoll_bcast_nonsg = 0;
+int mca_coll_acoll_allgather_lin = 0;
+int mca_coll_acoll_allgather_ring_1 = 0;
+int mca_coll_acoll_reserve_memory_for_algo = 0;
+uint64_t mca_coll_acoll_reserve_memory_size_for_algo = 128 * 32768; // 4 MB
+uint64_t mca_coll_acoll_xpmem_buffer_size = 128 * 32768;
+
+/* By default, utilize xpmem-based algorithms where applicable when built
+   with xpmem.
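+   Set the MCA parameter coll_acoll_without_xpmem to 1 to disable them
+   at run time.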
*/ +int mca_coll_acoll_without_xpmem = 0; +int mca_coll_acoll_xpmem_use_sr_buf = 1; + +/* + * Local function + */ +static int acoll_register(void); + +/* + * Instantiate the public struct with all of our public information + * and pointers to our public functions in it + */ + +const mca_coll_base_component_3_0_0_t mca_coll_acoll_component = { + + /* First, the mca_component_t struct containing meta information + * about the component itself */ + + .collm_version = { + MCA_COLL_BASE_VERSION_3_0_0, + + /* Component name and version */ + .mca_component_name = "acoll", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + + /* Component open and close functions */ + .mca_register_component_params = acoll_register, + }, + .collm_data = { + /* The component is checkpoint ready */ + MCA_BASE_METADATA_PARAM_CHECKPOINT + }, + + /* Initialization / querying functions */ + + .collm_init_query = mca_coll_acoll_init_query, + .collm_comm_query = mca_coll_acoll_comm_query, +}; + +static int acoll_register(void) +{ + /* Use a low priority, but allow other components to be lower */ + mca_coll_acoll_priority = 0; + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "priority", + "Priority of the acoll coll component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_priority); + + /* Defaults on topology */ + (void) + mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "sg_size", + "Size of subgroup to be used for subgroup based algorithms", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_sg_size); + + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "sg_scale", + "Scale factor for effective subgroup size for subgroup based algorithms", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_sg_scale); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "node_size", + "Size of node for multinode cases", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_node_size); + (void) + mca_base_component_var_register(&mca_coll_acoll_component.collm_version, + "use_dynamic_rules", + "Use dynamic selection of algorithms for multinode cases", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_use_dynamic_rules); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "mnode_enable", + "Enable separate algorithm for multinode cases", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_mnode_enable); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_lin0", + "Use lin/log for stage 0 of multinode algorithm", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_lin0); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_lin1", + "Use lin/log for stage 1 of multinode algorithm", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_lin1); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "bcast_lin2", + "Use lin/log for stage 2 of multinode algorithm", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + 
MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_lin2); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "bcast_nonsg", + "Flag to turn on/off subgroup based algorithms for multinode", MCA_BASE_VAR_TYPE_INT, NULL, + 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_bcast_nonsg); + (void) mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "allgather_lin", + "Flag to indicate use of linear allgather for multinode", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_allgather_lin); + (void) + mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "allgather_ring_1", + "Flag to indicate use of ring/rd allgather for multinode", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_allgather_ring_1); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "reserve_memory_for_algo", + "Flag to inform the acoll component to reserve/pre-allocate memory" + " for use inside collective algorithms.", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_reserve_memory_for_algo); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "reserve_memory_size_for_algo", + "Size of memory to be allocated by acoll component to use as reserve" + "memory inside collective algorithms.", + MCA_BASE_VAR_TYPE_UINT64_T, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_reserve_memory_size_for_algo); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "without_xpmem", + "By default, xpmem-based algorithms are used when applicable. " + "When this flag is set to 1, xpmem-based algorithms are disabled.", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_without_xpmem); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "xpmem_buffer_size", + "Maximum size of memory that can be used for temporary buffers for " + "xpmem-based algorithms. By default these buffers are not created or " + "used unless xpmem_use_sr_buf is set to 0.", + MCA_BASE_VAR_TYPE_UINT64_T, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_xpmem_buffer_size); + (void) mca_base_component_var_register( + &mca_coll_acoll_component.collm_version, "xpmem_use_sr_buf", + "Uses application provided send/recv buffers during xpmem registration " + "when set to 1 instead of temporary buffers. 
The send/recv buffers are " + "assumed to persist for the duration of the application.", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_READONLY, + &mca_coll_acoll_xpmem_use_sr_buf); + + return OMPI_SUCCESS; +} + +/* + * Module constructor + */ +static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module) +{ + for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) { + coll_acoll_subcomms_t *subc = &module->subc[i]; + subc->initialized = 0; + subc->is_root_node = 0; + subc->is_root_sg = 0; + subc->is_root_numa = 0; + subc->outer_grp_root = -1; + subc->subgrp_root = 0; + subc->num_nodes = 1; + subc->prev_init_root = -1; + subc->num_root_change = 0; + subc->numa_root = 0; + subc->socket_ldr_root = -1; + subc->local_comm = NULL; + subc->local_r_comm = NULL; + subc->leader_comm = NULL; + subc->subgrp_comm = NULL; + subc->socket_comm = NULL; + subc->socket_ldr_comm = NULL; + for (int j = 0; j < MCA_COLL_ACOLL_NUM_LAYERS; j++) { + for (int k = 0; k < MCA_COLL_ACOLL_NUM_BASE_LYRS; k++) { + subc->base_comm[k][j] = NULL; + subc->base_root[k][j] = -1; + } + subc->local_root[j] = 0; + } + + subc->numa_comm = NULL; + subc->numa_comm_ldrs = NULL; + subc->node_comm = NULL; + subc->inter_comm = NULL; + subc->cid = -1; + subc->initialized_data = false; + subc->initialized_shm_data = false; + subc->data = NULL; +#ifdef HAVE_XPMEM_H + subc->xpmem_buf_size = mca_coll_acoll_xpmem_buffer_size; + subc->without_xpmem = mca_coll_acoll_without_xpmem; + subc->xpmem_use_sr_buf = mca_coll_acoll_xpmem_use_sr_buf; +#endif + } + + /* Reserve memory init. Lazy allocation of memory when needed. */ + (module->reserve_mem_s).reserve_mem = NULL; + (module->reserve_mem_s).reserve_mem_size = 0; + (module->reserve_mem_s).reserve_mem_allocate = false; + (module->reserve_mem_s).reserve_mem_in_use = false; + if ((0 != mca_coll_acoll_reserve_memory_for_algo) + && (0 < mca_coll_acoll_reserve_memory_size_for_algo) + && (false == ompi_mpi_thread_multiple)) { + (module->reserve_mem_s).reserve_mem_allocate = true; + (module->reserve_mem_s).reserve_mem_size = mca_coll_acoll_reserve_memory_size_for_algo; + } +} + +/* + * Module destructor + */ +static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module) +{ + + for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) { + coll_acoll_subcomms_t *subc = &module->subc[i]; + if (subc->initialized_data) { + if (subc->initialized_shm_data) { + if (subc->orig_comm != NULL) { + opal_shmem_unlink( + &((subc->data)->allshmseg_id[ompi_comm_rank(subc->orig_comm)])); + opal_shmem_segment_detach( + &((subc->data)->allshmseg_id[ompi_comm_rank(subc->orig_comm)])); + } + } + coll_acoll_data_t *data = subc->data; + if (NULL != data) { +#ifdef HAVE_XPMEM_H + for (int j = 0; j < data->comm_size; j++) { + xpmem_release(data->all_apid[j]); + xpmem_remove(data->allseg_id[j]); + mca_rcache_base_module_destroy(data->rcache[j]); + } + + free(data->allseg_id); + data->allseg_id = NULL; + free(data->all_apid); + data->all_apid = NULL; + free(data->allshm_sbuf); + data->allshm_sbuf = NULL; + free(data->allshm_rbuf); + data->allshm_rbuf = NULL; + free(data->xpmem_saddr); + data->xpmem_saddr = NULL; + free(data->xpmem_raddr); + data->xpmem_raddr = NULL; + free(data->scratch); + data->scratch = NULL; + free(data->rcache); + data->rcache = NULL; +#endif + free(data->allshmseg_id); + data->allshmseg_id = NULL; + free(data->allshmmmap_sbuf); + data->allshmmmap_sbuf = NULL; + free(data->l1_gp); + data->l1_gp = NULL; + free(data->l2_gp); + data->l2_gp = NULL; + free(data); + data = 
NULL; + } + } + + if (subc->local_comm != NULL) { + ompi_comm_free(&(subc->local_comm)); + subc->local_comm = NULL; + } + + if (subc->local_r_comm != NULL) { + ompi_comm_free(&(subc->local_r_comm)); + subc->local_r_comm = NULL; + } + + if (subc->leader_comm != NULL) { + ompi_comm_free(&(subc->leader_comm)); + subc->leader_comm = NULL; + } + + if (subc->subgrp_comm != NULL) { + ompi_comm_free(&(subc->subgrp_comm)); + subc->subgrp_comm = NULL; + } + if (subc->socket_comm != NULL) { + ompi_comm_free(&(subc->socket_comm)); + subc->socket_comm = NULL; + } + + if (subc->socket_ldr_comm != NULL) { + ompi_comm_free(&(subc->socket_ldr_comm)); + subc->socket_ldr_comm = NULL; + } + for (int k = 0; k < MCA_COLL_ACOLL_NUM_BASE_LYRS; k++) { + for (int j = 0; j < MCA_COLL_ACOLL_NUM_LAYERS; j++) { + if (subc->base_comm[k][j] != NULL) { + ompi_comm_free(&(subc->base_comm[k][j])); + subc->base_comm[k][j] = NULL; + } + } + } + subc->initialized = 0; + } + + if ((true == (module->reserve_mem_s).reserve_mem_allocate) + && (NULL != (module->reserve_mem_s).reserve_mem)) { + free((module->reserve_mem_s).reserve_mem); + } +} + +OBJ_CLASS_INSTANCE(mca_coll_acoll_module_t, mca_coll_base_module_t, mca_coll_acoll_module_construct, + mca_coll_acoll_module_destruct); diff --git a/ompi/mca/coll/acoll/coll_acoll_gather.c b/ompi/mca/coll/acoll/coll_acoll_gather.c new file mode 100644 index 00000000000..429b61296aa --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_gather.c @@ -0,0 +1,217 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +/* + * mca_coll_acoll_gather_intra + * + * Function: Gather operation using subgroup based algorithm + * Accepts: Same arguments as MPI_Gather() + * Returns: MPI_SUCCESS or error code + * + * Description: Gather is performed across and within subgroups. + * Subgroups can be 1 or more based on size and count. + * + * Limitations: Current implementation is optimal only for map-by core. + * + * Memory: The base rank of each subgroup may create temporary buffer. 
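+ *
+ *              Data flow sketch (three stages, assuming the default
+ *              sg_cnt = 8 and node_cnt = 128):
+ *                leaf ranks   --send-->  subgroup base (rank % sg_cnt == 0)
+ *                base ranks   --send-->  node leader (rank % node_cnt == 0)
+ *                node leaders --send-->  root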
+ * + */ +int mca_coll_acoll_gather_intra(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int i, err, rank, size; + char *wkg = NULL, *workbuf = NULL; + MPI_Status status; + MPI_Aint sextent, sgap = 0, ssize; + MPI_Aint rextent = 0; + size_t total_recv = 0; + int sg_cnt, node_cnt; + int cur_sg, root_sg; + int cur_node, root_node; + int is_base, is_local_root; + int startr, endr, inc; + int startn, endn; + int num_nodes; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_reserve_mem_t *reserve_mem_gather = &(acoll_module->reserve_mem_s); + + size = ompi_comm_size(comm); + rank = ompi_comm_rank(comm); + + sg_cnt = acoll_module->sg_cnt; + node_cnt = acoll_module->node_cnt; + num_nodes = (size + node_cnt - 1) / node_cnt; + /* For small messages for nodes 8 and above, fall back to normal */ + if (num_nodes >= 8 && (rcount < 262144)) { + node_cnt = size; + sg_cnt = size; + num_nodes = 1; + } + + /* Setup root for receive */ + if (rank == root) { + ompi_datatype_type_extent(rdtype, &rextent); + /* Just use the recv buffer */ + wkg = (char *) rbuf; + if (sbuf != MPI_IN_PLACE) { + MPI_Aint root_ofst = rextent * (ptrdiff_t) (rcount * root); + err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, wkg + (ptrdiff_t) root_ofst, + rcount, rdtype); + if (MPI_SUCCESS != err) { + return err; + } + } + total_recv = rcount; + } + + /* Setup base ranks of non-root subgroups for receive */ + cur_sg = rank / sg_cnt; + root_sg = root / sg_cnt; + is_base = (rank % sg_cnt == 0) && (cur_sg != root_sg); + startr = (rank / sg_cnt) * sg_cnt; + cur_node = rank / node_cnt; + root_node = root / node_cnt; + is_local_root = (rank % node_cnt == 0) && (cur_node != root_node); + startn = (rank / node_cnt) * node_cnt; + + if (is_base) { + size_t buf_size = is_local_root ? (size_t) scount * node_cnt : (size_t) scount * sg_cnt; + ompi_datatype_type_extent(sdtype, &sextent); + ssize = opal_datatype_span(&sdtype->super, buf_size, &sgap); + if (cur_sg != root_sg) { + char *tmprecv = NULL; + workbuf = (char *) coll_acoll_buf_alloc(reserve_mem_gather, ssize + sgap); + if (NULL == workbuf) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + wkg = workbuf - sgap; + tmprecv = wkg + sextent * (ptrdiff_t) (rcount * (rank - startr)); + /* local copy to workbuf */ + err = ompi_datatype_sndrcv((void *) sbuf, scount, sdtype, tmprecv, scount, sdtype); + if (MPI_SUCCESS != err) { + + return err; + } + } + rdtype = sdtype; + rcount = scount; + rextent = sextent; + total_recv = rcount; + } else if (rank != root) { + wkg = (char *) sbuf; + total_recv = scount; + } + + /* All base ranks receive from other ranks in their respective subgroup */ + endr = startr + sg_cnt; + if (endr > size) { + endr = size; + } + inc = (rank == root) ? ((root != 0) ? 0 : 1) : 1; + if (is_base || (rank == root)) { + for (i = startr + inc; i < endr; i++) { + char *tmprecv = NULL; + if (i == root) { + continue; + } + if (rank == root) { + tmprecv = wkg + rextent * (ptrdiff_t) (rcount * i); + } else { + tmprecv = wkg + rextent * (ptrdiff_t) (rcount * (i - startr)); + } + err = MCA_PML_CALL( + recv(tmprecv, rcount, rdtype, i, MCA_COLL_BASE_TAG_GATHER, comm, &status)); + total_recv += rcount; + } + } else { + int peer = (cur_sg == root_sg) ? 
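+        /* leaves in the root's subgroup send straight to root;
+           all other leaves send to their subgroup base */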
root : startr; + err = MCA_PML_CALL(send(sbuf, scount, sdtype, peer, MCA_COLL_BASE_TAG_GATHER, + MCA_PML_BASE_SEND_STANDARD, comm)); + return err; + } + + /* All base ranks send to local root */ + endn = startn + node_cnt; + if (endn > size) { + endn = size; + } + if (sg_cnt < size) { + int local_root = (root_node == cur_node) ? root : startn; + for (i = startn; i < endn; i += sg_cnt) { + int i_sg = i / sg_cnt; + if ((rank != local_root) && (rank == i) && is_base) { + err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, local_root, + MCA_COLL_BASE_TAG_GATHER, MCA_PML_BASE_SEND_STANDARD, + comm)); + } + if ((rank == local_root) && (rank != i) && (i_sg != root_sg)) { + size_t recv_amt = (i + sg_cnt > size) ? rcount * (size - i) : rcount * sg_cnt; + MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * (i - startn)); + + err = MCA_PML_CALL(recv(wkg + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i, + MCA_COLL_BASE_TAG_GATHER, comm, &status)); + total_recv += recv_amt; + } + if (MPI_SUCCESS != err) { + if (NULL != workbuf) { + coll_acoll_buf_free(reserve_mem_gather, workbuf); + } + return err; + } + } + } + + /* All local roots ranks send to root */ + if (node_cnt < size && num_nodes > 1) { + for (i = 0; i < size; i += node_cnt) { + int i_node = i / node_cnt; + if ((rank != root) && (rank == i) && is_base) { + err = MCA_PML_CALL(send(workbuf - sgap, total_recv, sdtype, root, + MCA_COLL_BASE_TAG_GATHER, MCA_PML_BASE_SEND_STANDARD, + comm)); + } + if ((rank == root) && (rank != i) && (i_node != root_node)) { + size_t recv_amt = (i + node_cnt > size) ? rcount * (size - i) : rcount * node_cnt; + MPI_Aint rcv_ofst = rextent * (ptrdiff_t) (rcount * i); + + err = MCA_PML_CALL(recv((char *) rbuf + (ptrdiff_t) rcv_ofst, recv_amt, rdtype, i, + MCA_COLL_BASE_TAG_GATHER, comm, &status)); + total_recv += recv_amt; + } + if (MPI_SUCCESS != err) { + if (NULL != workbuf) { + coll_acoll_buf_free(reserve_mem_gather, workbuf); + } + return err; + } + } + } + + if (NULL != workbuf) { + coll_acoll_buf_free(reserve_mem_gather, workbuf); + } + + /* All done */ + return MPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_module.c b/ompi/mca/coll/acoll/coll_acoll_module.c new file mode 100644 index 00000000000..b3b2afddc8b --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_module.c @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include + +#include "mpi.h" +#include "ompi/mca/coll/base/base.h" +#include "ompi/mca/coll/coll.h" +#include "coll_acoll.h" + + +static int acoll_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); +static int acoll_module_disable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm); + +/* + * Initial query function that is invoked during MPI_INIT, allowing + * this component to disqualify itself if it doesn't support the + * required level of thread support. 
+ */ +int mca_coll_acoll_init_query(bool enable_progress_threads, bool enable_mpi_threads) +{ + /* Nothing to do */ + return OMPI_SUCCESS; +} + + + +#define ACOLL_INSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__module->super.coll_##__api) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "acoll"); \ + } \ + } while (0) + +#define ACOLL_UNINSTALL_COLL_API(__comm, __module, __api) \ + do \ + { \ + if (__comm->c_coll->coll_##__api##_module == &__module->super) \ + { \ + MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "acoll"); \ + } \ + } while (0) + +/* + * Invoked when there's a new communicator that has been created. + * Look at the communicator and decide which set of functions and + * priority we want to return. + */ +mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *comm, int *priority) +{ + mca_coll_acoll_module_t *acoll_module; + + acoll_module = OBJ_NEW(mca_coll_acoll_module_t); + if (NULL == acoll_module) { + return NULL; + } + + if (OMPI_COMM_IS_INTER(comm)) { + *priority = 0; + return NULL; + } + if (OMPI_COMM_IS_INTRA(comm) && ompi_comm_size(comm) < 2) { + *priority = 0; + return NULL; + } + + *priority = mca_coll_acoll_priority; + + /* Set topology params */ + acoll_module->sg_scale = mca_coll_acoll_sg_scale; + acoll_module->sg_size = mca_coll_acoll_sg_size; + acoll_module->sg_cnt = mca_coll_acoll_sg_size / mca_coll_acoll_sg_scale; + acoll_module->node_cnt = mca_coll_acoll_node_size; + if (mca_coll_acoll_sg_size == MCA_COLL_ACOLL_SG_SIZE_1) { + assert((acoll_module->sg_cnt == 1) || (acoll_module->sg_cnt == 2) + || (acoll_module->sg_cnt == 4) || (acoll_module->sg_cnt == 8)); + } + if (mca_coll_acoll_sg_size == MCA_COLL_ACOLL_SG_SIZE_2) { + assert((acoll_module->sg_cnt == 1) || (acoll_module->sg_cnt == 2) + || (acoll_module->sg_cnt == 4) || (acoll_module->sg_cnt == 8) + || (acoll_module->sg_cnt == 16)); + } + + switch (acoll_module->sg_cnt) { + case 1: + acoll_module->log2_sg_cnt = 0; + break; + case 2: + acoll_module->log2_sg_cnt = 1; + break; + case 4: + acoll_module->log2_sg_cnt = 2; + break; + case 8: + acoll_module->log2_sg_cnt = 3; + break; + case 16: + acoll_module->log2_sg_cnt = 4; + break; + default: + assert(0); + break; + } + + switch (acoll_module->node_cnt) { + case 96: + case 128: + acoll_module->log2_node_cnt = 7; + break; + case 192: + acoll_module->log2_node_cnt = 8; + break; + case 64: + acoll_module->log2_node_cnt = 6; + break; + case 32: + acoll_module->log2_node_cnt = 5; + break; + default: + assert(0); + break; + } + + acoll_module->use_dyn_rules = mca_coll_acoll_use_dynamic_rules; + acoll_module->use_mnode = mca_coll_acoll_mnode_enable; + acoll_module->use_lin0 = mca_coll_acoll_bcast_lin0; + acoll_module->use_lin1 = mca_coll_acoll_bcast_lin1; + acoll_module->use_lin2 = mca_coll_acoll_bcast_lin2; + if (mca_coll_acoll_bcast_nonsg) { + acoll_module->mnode_sg_size = acoll_module->node_cnt; + acoll_module->mnode_log2_sg_size = acoll_module->log2_node_cnt; + } else { + acoll_module->mnode_sg_size = acoll_module->sg_cnt; + acoll_module->mnode_log2_sg_size = acoll_module->log2_sg_cnt; + } + acoll_module->allg_lin = mca_coll_acoll_allgather_lin; + acoll_module->allg_ring = mca_coll_acoll_allgather_ring_1; + + /* Choose whether to use [intra|inter], and [subgroup|normal]-based + * algorithms. 
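+ * The function pointers installed below are what a communicator
+ * inherits when acoll wins the priority comparison; inter-communicators
+ * and single-rank communicators were already rejected above.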
*/ + acoll_module->super.coll_module_enable = acoll_module_enable; + acoll_module->super.coll_module_disable = acoll_module_disable; + + acoll_module->super.coll_allgather = mca_coll_acoll_allgather; + acoll_module->super.coll_allreduce = mca_coll_acoll_allreduce_intra; + acoll_module->super.coll_barrier = mca_coll_acoll_barrier_intra; + acoll_module->super.coll_bcast = mca_coll_acoll_bcast; + acoll_module->super.coll_gather = mca_coll_acoll_gather_intra; + acoll_module->super.coll_reduce = mca_coll_acoll_reduce_intra; + + return &(acoll_module->super); +} + +/* + * Init module on the communicator + */ +static int acoll_module_enable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) +{ + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + /* prepare the placeholder for the array of request* */ + module->base_data = OBJ_NEW(mca_coll_base_comm_t); + if (NULL == module->base_data) { + return OMPI_ERROR; + } + + ACOLL_INSTALL_COLL_API(comm, acoll_module, allgather); + ACOLL_INSTALL_COLL_API(comm, acoll_module, allreduce); + ACOLL_INSTALL_COLL_API(comm, acoll_module, barrier); + ACOLL_INSTALL_COLL_API(comm, acoll_module, bcast); + ACOLL_INSTALL_COLL_API(comm, acoll_module, gather); + ACOLL_INSTALL_COLL_API(comm, acoll_module, reduce); + + /* All done */ + return OMPI_SUCCESS; +} + +static int acoll_module_disable(mca_coll_base_module_t *module, struct ompi_communicator_t *comm) +{ + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, allgather); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, allreduce); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, barrier); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, bcast); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, gather); + ACOLL_UNINSTALL_COLL_API(comm, acoll_module, reduce); + + return OMPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_reduce.c b/ompi/mca/coll/acoll/coll_acoll_reduce.c new file mode 100644 index 00000000000..836c8893158 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_reduce.c @@ -0,0 +1,397 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/mca/pml/pml.h" +#include "ompi/op/op.h" +#include "opal/util/bit_ops.h" +#include "coll_acoll.h" +#include "coll_acoll_utils.h" + +static inline int coll_reduce_decision_fixed(int comm_size, size_t msg_size) +{ + /* Set default to topology aware algorithm */ + int alg = 0; + if (comm_size <= 8) { + /* Linear */ + alg = 1; + } else if (msg_size <= 8192) { + alg = 0; + } else if (msg_size <= 262144) { + /* Binomial */ + alg = 2; + } else if (msg_size <= 8388608 && comm_size < 64) { + alg = 1; + } else if (msg_size <= 8388608 && comm_size <= 128) { + /* In order binary */ + alg = 3; + } else { + alg = 2; + } + return alg; +} + +static inline int coll_acoll_reduce_topo(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + int root, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int ret = MPI_SUCCESS, rank, sz; + int cid = ompi_comm_get_local_cid(comm); + + ptrdiff_t dsize, gap = 0; + char *free_buffer = NULL; + char *pml_buffer = NULL; + char *tmp_rbuf = NULL; + char *tmp_sbuf = NULL; + + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + coll_acoll_subcomms_t *subc = &acoll_module->subc[cid]; + coll_acoll_reserve_mem_t *reserve_mem_rbuf_reduce = &(acoll_module->reserve_mem_s); + + rank = ompi_comm_rank(comm); + + tmp_sbuf = (char *) sbuf; + if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + tmp_sbuf = (char *) rbuf; + } + + int i; + int ind1 = MCA_COLL_ACOLL_L3CACHE; + int ind2 = MCA_COLL_ACOLL_LYR_NODE; + int is_base = rank == subc->base_rank[ind1] ? 
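+                  /* base ranks lead their L3-cache subgroup and take part
+                     in the upper reduction stage */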
1 : 0; + int bound = subc->subgrp_size; + + sz = ompi_comm_size(subc->base_comm[ind1][ind2]); + dsize = opal_datatype_span(&dtype->super, count, &gap); + if (rank == root) { + tmp_rbuf = rbuf; + } else if (is_base) { + tmp_rbuf = (char *) coll_acoll_buf_alloc(reserve_mem_rbuf_reduce, dsize); + if (NULL == tmp_rbuf) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + } + + if (is_base) { + ret = ompi_datatype_copy_content_same_ddt(dtype, count, (char *) tmp_rbuf, + (char *) tmp_sbuf); + free_buffer = (char *) malloc(dsize); + if (NULL == free_buffer) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + pml_buffer = free_buffer - gap; + } + + /* if not a local root, send the message to the local root */ + if (!is_base) { + ret = MCA_PML_CALL(send(tmp_sbuf, count, dtype, subc->subgrp_root, MCA_COLL_BASE_TAG_REDUCE, + MCA_PML_BASE_SEND_STANDARD, subc->subgrp_comm)); + } + + /* if local root, receive the message from other ranks within that group */ + if (is_base) { + for (i = 0; i < bound; i++) { + if (i == subc->subgrp_root) { + continue; + } + ret = MCA_PML_CALL(recv(pml_buffer, count, dtype, i, MCA_COLL_BASE_TAG_REDUCE, + subc->subgrp_comm, MPI_STATUS_IGNORE)); + ompi_op_reduce(op, pml_buffer, tmp_rbuf, count, dtype); + } + } + /* perform reduction at root */ + if (is_base && (sz > 1)) { + if (rank != root) { + ret = MCA_PML_CALL(send(tmp_rbuf, count, dtype, subc->base_root[ind1][ind2], + MCA_COLL_BASE_TAG_REDUCE, MCA_PML_BASE_SEND_STANDARD, + subc->base_comm[ind1][ind2])); + if (ret != MPI_SUCCESS) { + free(pml_buffer); + if (NULL != tmp_rbuf) { + coll_acoll_buf_free(reserve_mem_rbuf_reduce, tmp_rbuf); + } + return ret; + } + } + if (rank == root) { + for (i = 0; i < sz; i++) { + if (i == subc->base_root[ind1][ind2]) { + continue; + } + ret = MCA_PML_CALL(recv(pml_buffer, count, dtype, i, MCA_COLL_BASE_TAG_REDUCE, + subc->base_comm[ind1][ind2], MPI_STATUS_IGNORE)); + if (ret != MPI_SUCCESS) { + free(pml_buffer); + return ret; + } + ompi_op_reduce(op, pml_buffer, rbuf, count, dtype); + } + } + } + + /* if local root, reduce at root */ + if (is_base && (sz > 1)) { + free(pml_buffer); + if (rank != root && NULL != tmp_rbuf) { + coll_acoll_buf_free(reserve_mem_rbuf_reduce, tmp_rbuf); + } + } + + return ret; +} + +#ifdef HAVE_XPMEM_H +static inline int mca_coll_acoll_reduce_xpmem(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + int root, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + int size; + size_t total_dsize, dsize; + ptrdiff_t gap = 0; + + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + coll_acoll_init(module, comm, subc->data); + coll_acoll_reserve_mem_t *reserve_mem_rbuf_reduce = NULL; + if (subc->xpmem_use_sr_buf != 0) { + reserve_mem_rbuf_reduce = &(acoll_module->reserve_mem_s); + } + coll_acoll_data_t *data = subc->data; + if (NULL == data) { + return -1; + } + + size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + ompi_datatype_type_size(dtype, &dsize); + total_dsize = opal_datatype_span(&dtype->super, count, &gap); + + int l1_gp_size = data->l1_gp_size; + int *l1_gp = data->l1_gp; + int *l2_gp = data->l2_gp; + int l2_gp_size = data->l2_gp_size; + + int l1_local_rank = data->l1_local_rank; + int l2_local_rank = data->l2_local_rank; + + char *tmp_sbuf = NULL; + char *tmp_rbuf = NULL; + + if (subc->xpmem_use_sr_buf == 0) { + tmp_rbuf = (char *) data->scratch; + tmp_sbuf = 
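+                   /* the scratch region is split in half: receive
+                      buffer first, send buffer second */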
(char *) data->scratch + (subc->xpmem_buf_size) / 2; + if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + memcpy(tmp_sbuf, rbuf, total_dsize); + } else { + memcpy(tmp_sbuf, sbuf, total_dsize); + } + } else { + tmp_sbuf = (char *) sbuf; + if ((sbuf == MPI_IN_PLACE) && (rank == root)) { + tmp_sbuf = (char *) rbuf; + } + + if (rank == root) { + tmp_rbuf = rbuf; + } else { + tmp_rbuf = (char *) coll_acoll_buf_alloc(reserve_mem_rbuf_reduce, total_dsize); + if (NULL == tmp_rbuf) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + } + } + void *sbuf_vaddr[1] = {tmp_sbuf}; + void *rbuf_vaddr[1] = {tmp_rbuf}; + + int ret; + + ret = comm->c_coll->coll_allgather(sbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_sbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + if (ret != MPI_SUCCESS) { + return ret; + } + ret = comm->c_coll->coll_allgather(rbuf_vaddr, sizeof(void *), MPI_BYTE, data->allshm_rbuf, + sizeof(void *), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + + if (ret != MPI_SUCCESS) { + return ret; + } + + register_and_cache(size, total_dsize, rank, data); + + /* reduce to the group leader */ + size_t chunk = count / l1_gp_size; + size_t my_count_size = (l1_local_rank == (l1_gp_size - 1)) ? chunk + count % l1_gp_size : chunk; + + if (rank == l1_gp[0]) { + if (sbuf != MPI_IN_PLACE) + memcpy(tmp_rbuf, sbuf, my_count_size * dsize); + for (int i = 1; i < l1_gp_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) tmp_rbuf + chunk * l1_local_rank * dsize, my_count_size, dtype); + } + } else { + ompi_3buff_op_reduce(op, + (char *) data->xpmem_saddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + (char *) tmp_sbuf + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + for (int i = 1; i < l1_gp_size; i++) { + if (i == l1_local_rank) { + continue; + } + ompi_op_reduce(op, (char *) data->xpmem_saddr[l1_gp[i]] + chunk * l1_local_rank * dsize, + (char *) data->xpmem_raddr[l1_gp[0]] + chunk * l1_local_rank * dsize, + my_count_size, dtype); + } + } + ompi_coll_base_barrier_intra_tree(comm, module); + + /* perform reduce to 0 */ + int local_size = l2_gp_size; + if ((rank == l1_gp[0]) && (local_size > 1)) { + chunk = count / local_size; + my_count_size = (l2_local_rank == (local_size - 1)) ? 
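+                            /* the last leader absorbs the remainder when
+                               count does not divide evenly */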
chunk + (count % local_size) : chunk; + + if (l2_local_rank == 0) { + for (int i = 1; i < local_size; i++) { + ompi_op_reduce(op, (char *) data->xpmem_raddr[l2_gp[i]], (char *) tmp_rbuf, + my_count_size, dtype); + } + } else { + for (int i = 1; i < local_size; i++) { + if (i == l2_local_rank) { + continue; + } + ompi_op_reduce(op, + (char *) data->xpmem_raddr[l2_gp[i]] + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + ompi_op_reduce(op, (char *) tmp_rbuf + chunk * l2_local_rank * dsize, + (char *) data->xpmem_raddr[0] + chunk * l2_local_rank * dsize, + my_count_size, dtype); + } + } + ompi_coll_base_barrier_intra_tree(comm, module); + if (subc->xpmem_use_sr_buf == 0) { + if (rank == root) { + memcpy(rbuf, tmp_rbuf, total_dsize); + } + } else { + if ((rank != root) && (subc->xpmem_use_sr_buf != 0)) { + coll_acoll_buf_free(reserve_mem_rbuf_reduce, tmp_rbuf); + } + } + + return MPI_SUCCESS; +} +#endif + +int mca_coll_acoll_reduce_intra(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + int size, alg; + int num_nodes, ret; + size_t total_dsize, dsize; + mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module; + + size = ompi_comm_size(comm); + if (size < 4) + return ompi_coll_base_reduce_intra_basic_linear(sbuf, rbuf, count, dtype, op, root, comm, + module); + + /* Falling back to inorder binary for non-commutative operators to be safe */ + if (!ompi_op_is_commute(op)) { + return ompi_coll_base_reduce_intra_in_order_binary(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + if (root != 0) { // ToDo: support non-zero root + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + + ompi_datatype_type_size(dtype, &dsize); + total_dsize = dsize * count; + + alg = coll_reduce_decision_fixed(size, total_dsize); + + coll_acoll_subcomms_t *subc; + int cid = ompi_comm_get_local_cid(comm); + subc = &acoll_module->subc[cid]; + + /* Fallback to knomial if cid is beyond supported limit */ + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + + subc = &acoll_module->subc[cid]; + if (!subc->initialized || (root != subc->prev_init_root)) { + ret = mca_coll_acoll_comm_split_init(comm, acoll_module, 0); + if (MPI_SUCCESS != ret) { + return ret; + } + } + + num_nodes = subc->num_nodes; + + if (num_nodes == 1) { + if (total_dsize < 262144) { + if (alg == -1 /* interaction with xpmem implementation causing issues 0*/) { + return coll_acoll_reduce_topo(sbuf, rbuf, count, dtype, op, root, comm, module); + } else if (alg == 1) { + return ompi_coll_base_reduce_intra_basic_linear(sbuf, rbuf, count, dtype, op, root, + comm, module); + } else if (alg == 2) { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, + comm, module, 0, 0); + } else { /*(alg == 3)*/ + return ompi_coll_base_reduce_intra_in_order_binary(sbuf, rbuf, count, dtype, op, + root, comm, module, 0, 0); + } + } else { +#ifdef HAVE_XPMEM_H + if ((((subc->xpmem_use_sr_buf != 0) + && (acoll_module->reserve_mem_s).reserve_mem_allocate + && ((acoll_module->reserve_mem_s).reserve_mem_size >= total_dsize)) + || ((subc->xpmem_use_sr_buf == 0) && (subc->xpmem_buf_size > 2 * total_dsize))) + && (subc->without_xpmem != 1)) { + return 
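+                       /* shared-address-space path: peers' buffers are
+                          attached via xpmem, avoiding extra copies */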
mca_coll_acoll_reduce_xpmem(sbuf, rbuf, count, dtype, op, root, comm, + module); + } else { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, + root, comm, module, 0, 0); + } +#else + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, + comm, module, 0, 0); +#endif + } + } else { + return ompi_coll_base_reduce_intra_binomial(sbuf, rbuf, count, dtype, op, root, comm, + module, 0, 0); + } + return MPI_SUCCESS; +} diff --git a/ompi/mca/coll/acoll/coll_acoll_utils.h b/ompi/mca/coll/acoll/coll_acoll_utils.h new file mode 100644 index 00000000000..4b98a73ccd5 --- /dev/null +++ b/ompi/mca/coll/acoll/coll_acoll_utils.h @@ -0,0 +1,788 @@ +/* -*- Mode: C; indent-tabs-mode:nil -*- */ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/communicator/communicator.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "opal/include/opal/align.h" + +#ifdef HAVE_XPMEM_H +#include "opal/mca/rcache/base/base.h" +#include +#endif + + +/* shared memory structure */ +/* first 16 * 1024 bytes (16KB) are used for the leader */ +/* next 2* 64 * comm_size bytes are used for sync variables */ +/* next 8 * 1024 * comm_size are used for per_rank data (8KB per rank) */ +/* offsets for the shared memory region */ +#define CACHE_LINE_SIZE 64 +#define LEADER_SHM_SIZE 16384 +#define PER_RANK_SHM_SIZE 8192 + + +/* Function to allocate scratch buffer */ +static inline void *coll_acoll_buf_alloc(coll_acoll_reserve_mem_t *reserve_mem_ptr, uint64_t size) +{ + void *temp_ptr = NULL; + /* If requested size is within the pre-allocated range, use the + pre-allocated buffer if not in use. 
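+
+       Usage sketch (illustrative; rmem is whatever
+       coll_acoll_reserve_mem_t the caller holds):
+         void *tmp = coll_acoll_buf_alloc(rmem, size);
+         ... use tmp ...
+         coll_acoll_buf_free(rmem, tmp);
+       coll_acoll_buf_free() clears the in-use flag instead of calling
+       free() when tmp is the reserved block.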
*/ + if ((true == reserve_mem_ptr->reserve_mem_allocate) + && (size <= reserve_mem_ptr->reserve_mem_size) + && (false == reserve_mem_ptr->reserve_mem_in_use)) { + if (NULL == reserve_mem_ptr->reserve_mem) { + reserve_mem_ptr->reserve_mem = malloc(reserve_mem_ptr->reserve_mem_size); + } + temp_ptr = reserve_mem_ptr->reserve_mem; + + /* Mark the buffer as "in use" */ + if (NULL != temp_ptr) { + reserve_mem_ptr->reserve_mem_in_use = true; + } + } else { + /* If requested size if greater than that of the pre-allocated + buffer or if the pre-allocated buffer is in use, create new buffer */ + temp_ptr = malloc(size); + } + + return temp_ptr; +} + +/* Function to free scratch buffer */ +static inline void coll_acoll_buf_free(coll_acoll_reserve_mem_t *reserve_mem_ptr, void *ptr) +{ + /* Free the buffer only if it is not the reserved (pre-allocated) one */ + if ((false == reserve_mem_ptr->reserve_mem_allocate) + || (false == reserve_mem_ptr->reserve_mem_in_use)) { + if (NULL != ptr) { + free(ptr); + } + } else if (reserve_mem_ptr->reserve_mem == ptr) { + /* Mark the reserved buffer as free to be used */ + reserve_mem_ptr->reserve_mem_in_use = false; + } +} + +/* Function to compare integer elements */ +static int compare_values(const void *ptra, const void *ptrb) +{ + int a = *((int *) ptra); + int b = *((int *) ptrb); + + if (a < b) { + return -1; + } else if (a > b) { + return 1; + } + + return 0; +} + +/* Function to map ranks from parent communicator to sub-communicator */ +static inline int comm_grp_ranks_local(ompi_communicator_t *comm, ompi_communicator_t *local_comm, + int *is_root_node, int *local_root, int **ranks_buf, + int root) +{ + ompi_group_t *local_grp, *grp; + int local_size = ompi_comm_size(local_comm); + int *ranks = malloc(local_size * sizeof(int)); + int *local_ranks = malloc(local_size * sizeof(int)); + int i, err; + + /* Create parent (comm) and sub-comm (local_comm) groups */ + err = ompi_comm_group(comm, &grp); + err = ompi_comm_group(local_comm, &local_grp); + /* Initialize ranks for sub-communicator (local_comm) */ + for (i = 0; i < local_size; i++) { + local_ranks[i] = i; + } + + /* Translate the ranks among the 2 communicators */ + err = ompi_group_translate_ranks(local_grp, local_size, local_ranks, grp, ranks); + if (ranks_buf != NULL) { + *ranks_buf = malloc(local_size * sizeof(int)); + memcpy(*ranks_buf, ranks, local_size * sizeof(int)); + } + + /* Derive the 'local_root' which is the equivalent rank for 'root' of + 'comm' in 'local_comm' */ + for (i = 0; i < local_size; i++) { + if (ranks[i] == root) { + *is_root_node = 1; + *local_root = i; + break; + } + } + + err = ompi_group_free(&grp); + err = ompi_group_free(&local_grp); + free(ranks); + free(local_ranks); + + return err; +} + +static inline int mca_coll_acoll_create_base_comm(ompi_communicator_t **parent_comm, + coll_acoll_subcomms_t *subc, int color, int rank, + int *root, int base_lyr) +{ + int i; + int err; + + for (i = 0; i < MCA_COLL_ACOLL_NUM_LAYERS; i++) { + int is_root_node = 0; + + /* Create base comm */ + err = ompi_comm_split(parent_comm[i], color, rank, &subc->base_comm[base_lyr][i], false); + if (MPI_SUCCESS != err) + return err; + + /* Find out local rank of root in base comm */ + err = comm_grp_ranks_local(parent_comm[i], subc->base_comm[base_lyr][i], &is_root_node, + &subc->base_root[base_lyr][i], NULL, root[i]); + } + return err; +} + +static inline int mca_coll_acoll_comm_split_init(ompi_communicator_t *comm, + mca_coll_acoll_module_t *acoll_module, int root) +{ + opal_info_t comm_info; + 
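+    /* Save the parent communicator's collectives and temporarily install
+     * base implementations: the splits below invoke allgather, allreduce
+     * and bcast, which must not re-enter acoll before the
+     * subcommunicators exist. */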
mca_coll_base_module_allreduce_fn_t coll_allreduce_org = (comm)->c_coll->coll_allreduce; + mca_coll_base_module_allgather_fn_t coll_allgather_org = (comm)->c_coll->coll_allgather; + mca_coll_base_module_bcast_fn_t coll_bcast_org = (comm)->c_coll->coll_bcast; + mca_coll_base_module_allreduce_fn_t coll_allreduce_loc, coll_allreduce_soc; + mca_coll_base_module_allgather_fn_t coll_allgather_loc, coll_allgather_soc; + mca_coll_base_module_bcast_fn_t coll_bcast_loc, coll_bcast_soc; + coll_acoll_subcomms_t *subc; + int err; + int size = ompi_comm_size(comm); + int rank = ompi_comm_rank(comm); + int cid = ompi_comm_get_local_cid(comm); + if (cid >= MCA_COLL_ACOLL_MAX_CID) { + return MPI_SUCCESS; + } + + /* Derive subcomm structure */ + subc = &acoll_module->subc[cid]; + subc->cid = cid; + subc->orig_comm = comm; + + (comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (comm)->c_coll->coll_allreduce = ompi_coll_base_allreduce_intra_recursivedoubling; + (comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + if (!subc->initialized) { + OBJ_CONSTRUCT(&comm_info, opal_info_t); + opal_info_set(&comm_info, "ompi_comm_coll_preference", "libnbc,basic,^acoll"); + /* Create node-level subcommunicator */ + err = ompi_comm_split_type(comm, MPI_COMM_TYPE_SHARED, 0, &comm_info, &(subc->local_comm)); + if (MPI_SUCCESS != err) { + return err; + } + /* Create socket-level subcommunicator */ + err = ompi_comm_split_type(comm, OMPI_COMM_TYPE_SOCKET, 0, &comm_info, + &(subc->socket_comm)); + if (MPI_SUCCESS != err) { + return err; + } + OBJ_DESTRUCT(&comm_info); + OBJ_CONSTRUCT(&comm_info, opal_info_t); + opal_info_set(&comm_info, "ompi_comm_coll_preference", "libnbc,basic,^acoll"); + /* Create subgroup-level subcommunicator */ + err = ompi_comm_split_type(comm, OMPI_COMM_TYPE_L3CACHE, 0, &comm_info, + &(subc->subgrp_comm)); + if (MPI_SUCCESS != err) { + return err; + } + err = ompi_comm_split_type(comm, OMPI_COMM_TYPE_NUMA, 0, &comm_info, &(subc->numa_comm)); + if (MPI_SUCCESS != err) { + return err; + } + subc->subgrp_size = ompi_comm_size(subc->subgrp_comm); + OBJ_DESTRUCT(&comm_info); + + /* Derive the no. of nodes */ + if (size == ompi_comm_size(subc->local_comm)) { + subc->num_nodes = 1; + } else { + int *size_list_buf = (int *) malloc(size * sizeof(int)); + int num_nodes = 0; + int local_size = ompi_comm_size(subc->local_comm); + /* Perform allgather so that all ranks know the sizes of the nodes + to which all other ranks belong */ + err = (comm)->c_coll->coll_allgather(&local_size, 1, MPI_INT, size_list_buf, 1, MPI_INT, + comm, &acoll_module->super); + if (MPI_SUCCESS != err) { + free(size_list_buf); + return err; + } + /* Find the no. of nodes by counting each node only once. + * E.g., if there are 3 nodes with 2, 3 and 4 ranks on each node, + * first sort the size array so that the array elements are + * {2,2,3,3,3,4,4,4,4}. Read the value at the start of the array, + * offset the array by the read value, increment the counter, + * and repeat the process till end of array is reached. 
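+ * With that array the reads land on indices 0, 2 and 5 (values 2, 3
+ * and 4), giving num_nodes = 3.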
*/ + qsort(size_list_buf, size, sizeof(int), compare_values); + for (int i = 0; i < size;) { + int ofst = size_list_buf[i]; + num_nodes++; + i += ofst; + } + subc->num_nodes = num_nodes; + free(size_list_buf); + } + } + /* Common initializations */ + { + subc->outer_grp_root = -1; + subc->subgrp_root = 0; + subc->is_root_sg = 0; + subc->is_root_numa = 0; + subc->numa_root = 0; + subc->is_root_socket = 0; + subc->socket_ldr_root = -1; + + if (subc->initialized) { + if (subc->num_nodes > 1) { + ompi_comm_free(&(subc->leader_comm)); + subc->leader_comm = NULL; + } + ompi_comm_free(&(subc->socket_ldr_comm)); + subc->socket_ldr_comm = NULL; + } + for (int i = 0; i < MCA_COLL_ACOLL_NUM_LAYERS; i++) { + if (subc->initialized) { + ompi_comm_free(&(subc->base_comm[MCA_COLL_ACOLL_L3CACHE][i])); + subc->base_comm[MCA_COLL_ACOLL_L3CACHE][i] = NULL; + ompi_comm_free(&(subc->base_comm[MCA_COLL_ACOLL_NUMA][i])); + subc->base_comm[MCA_COLL_ACOLL_NUMA][i] = NULL; + } + subc->base_root[MCA_COLL_ACOLL_L3CACHE][i] = -1; + subc->base_root[MCA_COLL_ACOLL_NUMA][i] = -1; + } + /* Store original collectives for local and socket comms */ + coll_allreduce_loc = (subc->local_comm)->c_coll->coll_allreduce; + coll_allgather_loc = (subc->local_comm)->c_coll->coll_allgather; + coll_bcast_loc = (subc->local_comm)->c_coll->coll_bcast; + (subc->local_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (subc->local_comm)->c_coll->coll_allreduce + = ompi_coll_base_allreduce_intra_recursivedoubling; + (subc->local_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + coll_allreduce_soc = (subc->socket_comm)->c_coll->coll_allreduce; + coll_allgather_soc = (subc->socket_comm)->c_coll->coll_allgather; + coll_bcast_soc = (subc->socket_comm)->c_coll->coll_bcast; + (subc->socket_comm)->c_coll->coll_allgather = ompi_coll_base_allgather_intra_ring; + (subc->socket_comm)->c_coll->coll_allreduce + = ompi_coll_base_allreduce_intra_recursivedoubling; + (subc->socket_comm)->c_coll->coll_bcast = ompi_coll_base_bcast_intra_basic_linear; + } + + /* Further subcommunicators based on root */ + if (subc->num_nodes > 1) { + int local_rank = ompi_comm_rank(subc->local_comm); + int color = MPI_UNDEFINED; + int is_root_node = 0, is_root_socket = 0; + int local_root = 0; + int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL; + ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS]; + + /* Initializations */ + subc->local_root[MCA_COLL_ACOLL_LYR_NODE] = 0; + subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] = 0; + + /* Find out the local rank of root */ + err = comm_grp_ranks_local(comm, subc->local_comm, &subc->is_root_node, + &subc->local_root[MCA_COLL_ACOLL_LYR_NODE], NULL, root); + + /* Create subcommunicator with leader ranks */ + color = 1; + if (!subc->is_root_node && (local_rank == 0)) { + color = 0; + } + if (rank == root) { + color = 0; + } + err = ompi_comm_split(comm, color, rank, &subc->leader_comm, false); + if (MPI_SUCCESS != err) { + return err; + } + + /* Find out local rank of root in leader comm */ + err = comm_grp_ranks_local(comm, subc->leader_comm, &is_root_node, &subc->outer_grp_root, + NULL, root); + + /* Find out local rank of root in socket comm */ + if (subc->is_root_node) { + local_root = subc->local_root[MCA_COLL_ACOLL_LYR_NODE]; + } + err = comm_grp_ranks_local(subc->local_comm, subc->socket_comm, &subc->is_root_socket, + &subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET], &socket_ranks, + local_root); + + /* Create subcommunicator with socket leaders */ + subc->socket_rank = 
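+            /* socket leader: the root's node-local rank when root is on
+               this socket, otherwise the socket's first rank */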
subc->is_root_socket == 1 ? local_root : socket_ranks[0];
+        color = local_rank == subc->socket_rank ? 0 : 1;
+        err = ompi_comm_split(subc->local_comm, color, local_rank, &subc->socket_ldr_comm, false);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        /* Find out local rank of root in socket leader comm */
+        err = comm_grp_ranks_local(subc->local_comm, subc->socket_ldr_comm, &is_root_socket,
+                                   &subc->socket_ldr_root, NULL, local_root);
+
+        /* Find out local rank of root in subgroup comm */
+        err = comm_grp_ranks_local(subc->local_comm, subc->subgrp_comm, &subc->is_root_sg,
+                                   &subc->subgrp_root, &subgrp_ranks, local_root);
+
+        /* Create subcommunicator with base ranks */
+        subc->base_rank[MCA_COLL_ACOLL_L3CACHE] = subc->is_root_sg == 1 ? local_root
+                                                                        : subgrp_ranks[0];
+        color = local_rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE] ? 0 : 1;
+        parent_comm[MCA_COLL_ACOLL_LYR_NODE] = subc->local_comm;
+        parent_comm[MCA_COLL_ACOLL_LYR_SOCKET] = subc->socket_comm;
+        err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, local_rank,
+                                              subc->local_root, MCA_COLL_ACOLL_L3CACHE);
+
+        /* Find out local rank of root in numa comm */
+        err = comm_grp_ranks_local(subc->local_comm, subc->numa_comm, &subc->is_root_numa,
+                                   &subc->numa_root, &numa_ranks, local_root);
+
+        subc->base_rank[MCA_COLL_ACOLL_NUMA] = subc->is_root_numa == 1 ? local_root : numa_ranks[0];
+        color = local_rank == subc->base_rank[MCA_COLL_ACOLL_NUMA] ? 0 : 1;
+        err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, local_rank,
+                                              subc->local_root, MCA_COLL_ACOLL_NUMA);
+
+        if (socket_ranks != NULL) {
+            free(socket_ranks);
+            socket_ranks = NULL;
+        }
+        if (subgrp_ranks != NULL) {
+            free(subgrp_ranks);
+            subgrp_ranks = NULL;
+        }
+        if (numa_ranks != NULL) {
+            free(numa_ranks);
+            numa_ranks = NULL;
+        }
+    } else {
+        /* Intra node case */
+        int color;
+        int is_root_socket = 0;
+        int *subgrp_ranks = NULL, *numa_ranks = NULL, *socket_ranks = NULL;
+        ompi_communicator_t *parent_comm[MCA_COLL_ACOLL_NUM_LAYERS];
+
+        /* Initializations */
+        subc->local_root[MCA_COLL_ACOLL_LYR_NODE] = root;
+        subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET] = 0;
+
+        /* Find out local rank of root in socket comm */
+        err = comm_grp_ranks_local(comm, subc->socket_comm, &subc->is_root_socket,
+                                   &subc->local_root[MCA_COLL_ACOLL_LYR_SOCKET], &socket_ranks,
+                                   root);
+
+        /* Create subcommunicator with socket leaders */
+        subc->socket_rank = subc->is_root_socket == 1 ? root : socket_ranks[0];
+        color = rank == subc->socket_rank ? 0 : 1;
+        err = ompi_comm_split(comm, color, rank, &subc->socket_ldr_comm, false);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        /* Find out local rank of root in socket leader comm */
+        err = comm_grp_ranks_local(comm, subc->socket_ldr_comm, &is_root_socket,
+                                   &subc->socket_ldr_root, NULL, root);
+
+        /* Find out local rank of root in subgroup comm */
+        err = comm_grp_ranks_local(comm, subc->subgrp_comm, &subc->is_root_sg, &subc->subgrp_root,
+                                   &subgrp_ranks, root);
+
+        /* Create subcommunicator with base ranks */
+        subc->base_rank[MCA_COLL_ACOLL_L3CACHE] = subc->is_root_sg == 1 ? root : subgrp_ranks[0];
+        color = rank == subc->base_rank[MCA_COLL_ACOLL_L3CACHE] ? 0 : 1;
+        parent_comm[MCA_COLL_ACOLL_LYR_NODE] = subc->local_comm;
+        parent_comm[MCA_COLL_ACOLL_LYR_SOCKET] = subc->socket_comm;
+        err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, rank, subc->local_root,
+                                              MCA_COLL_ACOLL_L3CACHE);
+
+        int numa_rank = ompi_comm_rank(subc->numa_comm);
+        color = (numa_rank == 0) ? 0 : 1;
+        err = ompi_comm_split(subc->local_comm, color, rank, &subc->numa_comm_ldrs, false);
+        if (MPI_SUCCESS != err) {
+            return err;
+        }
+
+        /* Find out local rank of root in numa comm */
+        err = comm_grp_ranks_local(comm, subc->numa_comm, &subc->is_root_numa, &subc->numa_root,
+                                   &numa_ranks, root);
+
+        subc->base_rank[MCA_COLL_ACOLL_NUMA] = subc->is_root_numa == 1 ? root : numa_ranks[0];
+        color = rank == subc->base_rank[MCA_COLL_ACOLL_NUMA] ? 0 : 1;
+        err = mca_coll_acoll_create_base_comm(parent_comm, subc, color, rank, subc->local_root,
+                                              MCA_COLL_ACOLL_NUMA);
+
+        if (socket_ranks != NULL) {
+            free(socket_ranks);
+            socket_ranks = NULL;
+        }
+        if (subgrp_ranks != NULL) {
+            free(subgrp_ranks);
+            subgrp_ranks = NULL;
+        }
+        if (numa_ranks != NULL) {
+            free(numa_ranks);
+            numa_ranks = NULL;
+        }
+    }
+
+    /* Restore originals for local and socket comms */
+    (subc->local_comm)->c_coll->coll_allreduce = coll_allreduce_loc;
+    (subc->local_comm)->c_coll->coll_allgather = coll_allgather_loc;
+    (subc->local_comm)->c_coll->coll_bcast = coll_bcast_loc;
+    (subc->socket_comm)->c_coll->coll_allreduce = coll_allreduce_soc;
+    (subc->socket_comm)->c_coll->coll_allgather = coll_allgather_soc;
+    (subc->socket_comm)->c_coll->coll_bcast = coll_bcast_soc;
+
+    /* For collectives where order is important (like gather, allgather),
+     * split based on ranks. This is optimal for global communicators with
+     * equal split among nodes, but suboptimal for other cases.
+     */
+    if (!subc->initialized) {
+        if (subc->num_nodes > 1) {
+            int node_size = (size + subc->num_nodes - 1) / subc->num_nodes;
+            int color = rank / node_size;
+            err = ompi_comm_split(comm, color, rank, &subc->local_r_comm, false);
+            if (MPI_SUCCESS != err) {
+                return err;
+            }
+        }
+        subc->derived_node_size = (size + subc->num_nodes - 1) / subc->num_nodes;
+    }
+
+    /* Restore originals */
+    (comm)->c_coll->coll_allreduce = coll_allreduce_org;
+    (comm)->c_coll->coll_allgather = coll_allgather_org;
+    (comm)->c_coll->coll_bcast = coll_bcast_org;
+
+    /* Init done */
+    subc->initialized = 1;
+    if (root != subc->prev_init_root) {
+        subc->num_root_change++;
+    }
+    subc->prev_init_root = root;
+
+    return err;
+}
+
+#ifdef HAVE_XPMEM_H
+static inline int mca_coll_acoll_xpmem_register(void *xpmem_apid, void *base, size_t size,
+                                                mca_rcache_base_registration_t *reg)
+{
+    struct xpmem_addr xpmem_addr;
+    xpmem_addr.apid = *((xpmem_apid_t *) xpmem_apid);
+    xpmem_addr.offset = (uintptr_t) base;
+    struct acoll_xpmem_rcache_reg_t *xpmem_reg = (struct acoll_xpmem_rcache_reg_t *) reg;
+    xpmem_reg->xpmem_vaddr = xpmem_attach(xpmem_addr, size, NULL);
+
+    if ((void *) -1 == xpmem_reg->xpmem_vaddr) {
+        return -1;
+    }
+    return 0;
+}
+
+static inline int mca_coll_acoll_xpmem_deregister(void *xpmem_apid,
+                                                  mca_rcache_base_registration_t *reg)
+{
+    int status = xpmem_detach(((struct acoll_xpmem_rcache_reg_t *) reg)->xpmem_vaddr);
+    return status;
+}
+#endif
+
+static inline int coll_acoll_init(mca_coll_base_module_t *module, ompi_communicator_t *comm,
+                                  coll_acoll_data_t *data)
+{
+    int size, ret = 0, rank, line;
+
+    mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
+    coll_acoll_subcomms_t *subc;
+    int cid = ompi_comm_get_local_cid(comm);
+    subc = &acoll_module->subc[cid];
+    if (subc->initialized_data) {
+        return ret;
+    }
+    subc->cid = cid;
+    /* Zero-initialize so that the error path below never frees unset pointers */
+    data = (coll_acoll_data_t *) calloc(1, sizeof(coll_acoll_data_t));
+    if (NULL == data) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    size = ompi_comm_size(comm);
+    rank = ompi_comm_rank(comm);
+    data->comm_size = size;
+
+#ifdef HAVE_XPMEM_H
+    if (subc->xpmem_use_sr_buf == 0) {
+        data->scratch = (char *) malloc(subc->xpmem_buf_size);
+        if (NULL == data->scratch) {
+            line = __LINE__;
+            ret = OMPI_ERR_OUT_OF_RESOURCE;
+            goto error_hndl;
+        }
+    } else {
+        data->scratch = NULL;
+    }
+
+    xpmem_segid_t seg_id;
+    data->allseg_id = (xpmem_segid_t *) malloc(sizeof(xpmem_segid_t) * size);
+    if (NULL == data->allseg_id) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->all_apid = (xpmem_apid_t *) malloc(sizeof(xpmem_apid_t) * size);
+    if (NULL == data->all_apid) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->allshm_sbuf = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->allshm_sbuf) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->allshm_rbuf = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->allshm_rbuf) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->xpmem_saddr = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->xpmem_saddr) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->xpmem_raddr = (void **) malloc(sizeof(void *) * size);
+    if (NULL == data->xpmem_raddr) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    data->rcache = (mca_rcache_base_module_t **) malloc(sizeof(mca_rcache_base_module_t *) * size);
+    if (NULL == data->rcache) {
+        line = __LINE__;
+        ret = OMPI_ERR_OUT_OF_RESOURCE;
+        goto error_hndl;
+    }
+    seg_id = xpmem_make(0, XPMEM_MAXADDR_SIZE, XPMEM_PERMIT_MODE, (void *) 0666);
+    if (seg_id == -1) {
+        line = __LINE__;
+        ret = -1;
+        goto error_hndl;
+    }
+
+    ret = comm->c_coll->coll_allgather(&seg_id, sizeof(xpmem_segid_t), MPI_BYTE, data->allseg_id,
+                                       sizeof(xpmem_segid_t), MPI_BYTE, comm,
+                                       comm->c_coll->coll_allgather_module);
+    if (MPI_SUCCESS != ret) {
+        line = __LINE__;
+        goto error_hndl;
+    }
+
+    /* The rcache name ("acoll_" plus three ints) fits in 50 characters;
+     * snprintf guards against overflow in any case */
+    char rc_name[50];
+    for (int i = 0; i < size; i++) {
+        if (rank != i) {
+            data->all_apid[i] = xpmem_get(data->allseg_id[i], XPMEM_RDWR, XPMEM_PERMIT_MODE,
+                                          (void *) 0666);
+            if (data->all_apid[i] == -1) {
+                line = __LINE__;
+                ret = -1;
+                goto error_hndl;
+            }
+            snprintf(rc_name, sizeof(rc_name), "acoll_%d_%d_%d", cid, rank, i);
+            mca_rcache_base_resources_t rcache_element
+                = {.cache_name = rc_name,
+                   .reg_data = &data->all_apid[i],
+                   .sizeof_reg = sizeof(struct acoll_xpmem_rcache_reg_t),
+                   .register_mem = mca_coll_acoll_xpmem_register,
+                   .deregister_mem = mca_coll_acoll_xpmem_deregister};
+
+            data->rcache[i] = mca_rcache_base_module_create("grdma", NULL, &rcache_element);
+            if (data->rcache[i] == NULL) {
+                ret = -1;
+                line = __LINE__;
+                goto error_hndl;
+            }
+        }
+    }
+#endif
+
+    /* Temporaries for the group/rank queries below; only the group lists
+     * (l1_gp, l2_gp) are used afterwards */
+    int tmp1, tmp2, tmp3 = 0;
+
+    comm_grp_ranks_local(comm, subc->numa_comm, &tmp1, &tmp2, &data->l1_gp, tmp3);
+    data->l1_gp_size = ompi_comm_size(subc->numa_comm);
+    data->l1_local_rank = ompi_comm_rank(subc->numa_comm);
+
+    comm_grp_ranks_local(comm, subc->numa_comm_ldrs, &tmp1, &tmp2, &data->l2_gp, tmp3);
+    data->l2_gp_size = ompi_comm_size(subc->numa_comm_ldrs);
+    data->l2_local_rank = ompi_comm_rank(subc->numa_comm_ldrs);
+    data->offset[0] = LEADER_SHM_SIZE;
+    data->offset[1] = data->offset[0] + size * CACHE_LINE_SIZE;
+    data->offset[2] = data->offset[1] + size * CACHE_LINE_SIZE;
+    data->offset[3] = data->offset[2] + rank * PER_RANK_SHM_SIZE;
+    data->allshmseg_id = (opal_shmem_ds_t *)
malloc(sizeof(opal_shmem_ds_t) * size); + data->allshmmmap_sbuf = (void **) malloc(sizeof(void *) * size); + data->sync[0] = 0; + data->sync[1] = 0; + char *shfn; + + /* Only the leaders need to allocate shared memory */ + /* remaining ranks move their data into their leader's shm */ + if (data->l1_gp[0] == rank) { + subc->initialized_shm_data = true; + ret = asprintf(&shfn, "/dev/shm/acoll_coll_shmem_seg.%u.%x.%d:%d-%d", geteuid(), + OPAL_PROC_MY_NAME.jobid, ompi_comm_rank(MPI_COMM_WORLD), + ompi_comm_get_local_cid(comm), ompi_comm_size(comm)); + } + + if (ret < 0) { + line = __LINE__; + goto error_hndl; + } + + opal_shmem_ds_t seg_ds; + if (data->l1_gp[0] == rank) { + /* Assuming cacheline size is 64 */ + long memsize + = (LEADER_SHM_SIZE /* scratch leader */ + CACHE_LINE_SIZE * size /* sync variables l1 group*/ + + CACHE_LINE_SIZE * size /* sync variables l2 group*/ + PER_RANK_SHM_SIZE * size /*data from ranks*/); + ret = opal_shmem_segment_create(&seg_ds, shfn, memsize); + free(shfn); + } + + if (ret != OPAL_SUCCESS) { + opal_output_verbose(MCA_BASE_VERBOSE_ERROR, ompi_coll_base_framework.framework_output, + "coll:acoll: Error: Could not create shared memory segment"); + line = __LINE__; + goto error_hndl; + } + + ret = comm->c_coll->coll_allgather(&seg_ds, sizeof(opal_shmem_ds_t), MPI_BYTE, + data->allshmseg_id, sizeof(opal_shmem_ds_t), MPI_BYTE, comm, + comm->c_coll->coll_allgather_module); + + if (data->l1_gp[0] != rank) { + data->allshmmmap_sbuf[data->l1_gp[0]] = opal_shmem_segment_attach( + &data->allshmseg_id[data->l1_gp[0]]); + } else { + for (int i = 0; i < data->l2_gp_size; i++) { + data->allshmmmap_sbuf[data->l2_gp[i]] = opal_shmem_segment_attach( + &data->allshmseg_id[data->l2_gp[i]]); + } + } + + int offset = LEADER_SHM_SIZE; + memset(((char *) data->allshmmmap_sbuf[data->l1_gp[0]]) + offset + CACHE_LINE_SIZE * rank, 0, CACHE_LINE_SIZE); + if (data->l1_gp[0] == rank) { + memset(((char *) data->allshmmmap_sbuf[data->l2_gp[0]]) + (offset + CACHE_LINE_SIZE * size) + CACHE_LINE_SIZE * rank, + 0, CACHE_LINE_SIZE); + } + + subc->initialized_data = true; + subc->data = data; + ompi_coll_base_barrier_intra_tree(comm, module); + + return MPI_SUCCESS; +error_hndl: + (void) line; + if (NULL != data) { +#ifdef HAVE_XPMEM_H + free(data->allseg_id); + data->allseg_id = NULL; + free(data->all_apid); + data->all_apid = NULL; + free(data->allshm_sbuf); + data->allshm_sbuf = NULL; + free(data->allshm_rbuf); + data->allshm_rbuf = NULL; + free(data->xpmem_saddr); + data->xpmem_saddr = NULL; + free(data->xpmem_raddr); + data->xpmem_raddr = NULL; + free(data->rcache); + data->rcache = NULL; + free(data->scratch); + data->scratch = NULL; +#endif + free(data->allshmseg_id); + data->allshmseg_id = NULL; + free(data->allshmmmap_sbuf); + data->allshmmmap_sbuf = NULL; + free(data->l1_gp); + data->l1_gp = NULL; + free(data->l2_gp); + data->l2_gp = NULL; + free(data); + data = NULL; + } + return ret; +} + +#ifdef HAVE_XPMEM_H +static inline void register_and_cache(int size, size_t total_dsize, int rank, + coll_acoll_data_t *data) +{ + uintptr_t base, bound; + for (int i = 0; i < size; i++) { + if (rank != i) { + mca_rcache_base_module_t *rcache_i = data->rcache[i]; + int access_flags = 0; + struct acoll_xpmem_rcache_reg_t *sbuf_reg = NULL, *rbuf_reg = NULL; + base = OPAL_DOWN_ALIGN((uintptr_t) data->allshm_sbuf[i], 4096, uintptr_t); + bound = OPAL_ALIGN((uintptr_t) data->allshm_sbuf[i] + total_dsize, 4096, uintptr_t); + int ret = rcache_i->rcache_register(rcache_i, (void *) base, bound - base, 
access_flags,
+                                            MCA_RCACHE_ACCESS_ANY,
+                                            (mca_rcache_base_registration_t **) &sbuf_reg);
+
+            if (ret != 0) {
+                sbuf_reg = NULL;
+                return;
+            }
+            data->xpmem_saddr[i] = (void *) ((uintptr_t) sbuf_reg->xpmem_vaddr
+                                             + ((uintptr_t) data->allshm_sbuf[i]
+                                                - (uintptr_t) sbuf_reg->base.base));
+
+            base = OPAL_DOWN_ALIGN((uintptr_t) data->allshm_rbuf[i], 4096, uintptr_t);
+            bound = OPAL_ALIGN((uintptr_t) data->allshm_rbuf[i] + total_dsize, 4096, uintptr_t);
+            ret = rcache_i->rcache_register(rcache_i, (void *) base, bound - base, access_flags,
+                                            MCA_RCACHE_ACCESS_ANY,
+                                            (mca_rcache_base_registration_t **) &rbuf_reg);
+
+            if (ret != 0) {
+                rbuf_reg = NULL;
+                return;
+            }
+            data->xpmem_raddr[i] = (void *) ((uintptr_t) rbuf_reg->xpmem_vaddr
+                                             + ((uintptr_t) data->allshm_rbuf[i]
+                                                - (uintptr_t) rbuf_reg->base.base));
+        } else {
+            data->xpmem_saddr[i] = data->allshm_sbuf[i];
+            data->xpmem_raddr[i] = data->allshm_rbuf[i];
+        }
+    }
+}
+#endif
diff --git a/ompi/mca/coll/acoll/configure.m4 b/ompi/mca/coll/acoll/configure.m4
new file mode 100644
index 00000000000..339b34c567c
--- /dev/null
+++ b/ompi/mca/coll/acoll/configure.m4
@@ -0,0 +1,18 @@
+#
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+
+AC_DEFUN([MCA_ompi_coll_acoll_CONFIG],[
+    AC_CONFIG_FILES([ompi/mca/coll/acoll/Makefile])
+
+    OPAL_CHECK_XPMEM([coll_acoll], [should_build=1], [should_build=1]) dnl XPMEM is optional; the component builds either way
+
+    AC_SUBST([coll_acoll_CPPFLAGS])
+    AC_SUBST([coll_acoll_LDFLAGS])
+    AC_SUBST([coll_acoll_LIBS])
+])dnl
diff --git a/ompi/mca/coll/acoll/owner.txt b/ompi/mca/coll/acoll/owner.txt
new file mode 100644
index 00000000000..6bd0a386110
--- /dev/null
+++ b/ompi/mca/coll/acoll/owner.txt
@@ -0,0 +1,7 @@
+#
+# owner/status file
+# owner: institution that is responsible for this package
+# status: e.g. active, maintenance, unmaintained
+#
+owner: AMD
+status: active
\ No newline at end of file