diff --git a/prov/efa/src/efa_prov_info.c b/prov/efa/src/efa_prov_info.c index 0bcbf4ef65b..4b659b462e3 100644 --- a/prov/efa/src/efa_prov_info.c +++ b/prov/efa/src/efa_prov_info.c @@ -388,12 +388,36 @@ static int efa_prov_info_set_nic_attr(struct fi_info *prov_info, struct efa_devi } #if HAVE_CUDA || HAVE_NEURON || HAVE_SYNAPSEAI -void efa_prov_info_set_hmem_flags(struct fi_info *prov_info) +void efa_prov_info_set_hmem_flags(struct fi_info *prov_info, enum fi_ep_type ep_type) { - if (prov_info->ep_attr->type == FI_EP_RDM && - (ofi_hmem_is_initialized(FI_HMEM_CUDA) || + int i; + enum fi_hmem_iface iface; + struct efa_hmem_info *hmem_info; + bool enable_hmem = false; + + if (ep_type != FI_EP_RDM) + return; + + /* EFA direct only supports HMEM when p2p is available on every initialized HMEM iface */ + if ((ofi_hmem_is_initialized(FI_HMEM_CUDA) || ofi_hmem_is_initialized(FI_HMEM_NEURON) || ofi_hmem_is_initialized(FI_HMEM_SYNAPSEAI))) { + EFA_HMEM_IFACE_FOREACH(i) { + iface = efa_hmem_ifaces[i]; + hmem_info = &g_efa_hmem_info[iface]; + if (hmem_info->initialized && !hmem_info->p2p_supported_by_device) { + EFA_INFO(FI_LOG_CORE, + "EFA direct provider was compiled with support for %s HMEM interface " + "but the interface does not support p2p transfers. " + "EFA direct provider does not support HMEM transfers without p2p support. 
" + "HMEM support will be disabled.\n", fi_tostr(&iface, FI_TYPE_HMEM_IFACE)); + return; + } + } + enable_hmem = true; + } + + if (enable_hmem) { prov_info->caps |= FI_HMEM; prov_info->tx_attr->caps |= FI_HMEM; prov_info->rx_attr->caps |= FI_HMEM; @@ -401,7 +425,7 @@ void efa_prov_info_set_hmem_flags(struct fi_info *prov_info) } } #else -void efa_prov_info_set_hmem_flags(struct fi_info *prov_info) +void efa_prov_info_set_hmem_flags(struct fi_info *prov_info, enum fi_ep_type ep_type) { } #endif @@ -480,7 +504,7 @@ int efa_prov_info_alloc(struct fi_info **prov_info_ptr, goto err_free; } - efa_prov_info_set_hmem_flags(prov_info); + efa_prov_info_set_hmem_flags(prov_info, ep_type); *prov_info_ptr = prov_info; return 0; diff --git a/prov/efa/src/efa_prov_info.h b/prov/efa/src/efa_prov_info.h index c5b3ff93c4a..ad34e8fe03f 100644 --- a/prov/efa/src/efa_prov_info.h +++ b/prov/efa/src/efa_prov_info.h @@ -22,4 +22,6 @@ int efa_prov_info_compare_domain_name(const struct fi_info *hints, int efa_prov_info_compare_pci_bus_id(const struct fi_info *hints, const struct fi_info *info); +void efa_prov_info_set_hmem_flags(struct fi_info *prov_info, enum fi_ep_type ep_type); + #endif diff --git a/prov/efa/test/efa_unit_test_info.c b/prov/efa/test/efa_unit_test_info.c index db52ccd0594..1f2420ffc97 100644 --- a/prov/efa/test/efa_unit_test_info.c +++ b/prov/efa/test/efa_unit_test_info.c @@ -2,6 +2,7 @@ /* SPDX-FileCopyrightText: Copyright Amazon.com, Inc. or its affiliates. All rights reserved. 
*/ #include "efa_unit_tests.h" +#include "efa_prov_info.h" /** * @brief test that when a wrong fi_info was used to open resource, the error is handled @@ -113,6 +114,60 @@ void test_info_direct_attributes() } } +/** + * @brief Verify that efa_prov_info_set_hmem_flags only sets FI_HMEM when the HMEM iface supports p2p + */ +#if HAVE_CUDA || HAVE_NEURON || HAVE_SYNAPSEAI +void test_info_direct_hmem_support_p2p() +{ + struct fi_info *info; + bool hmem_ops_cuda_init; + + info = fi_allocinfo(); + + memset(g_efa_hmem_info, 0, OFI_HMEM_MAX * sizeof(struct efa_hmem_info)); + + /* Save the current value of hmem_ops[FI_HMEM_CUDA].initialized to restore later. + * hmem_ops is populated in ofi_hmem_init, which only runs once. + * + * The CUDA iface is initialized on Nvidia GPU platforms but not on others. + * Forcing hmem_ops[FI_HMEM_CUDA].initialized to true allows this test to + * run on all instance types. + */ + hmem_ops_cuda_init = hmem_ops[FI_HMEM_CUDA].initialized; + hmem_ops[FI_HMEM_CUDA].initialized = true; + + /* g_efa_hmem_info is populated in efa_hmem_info_initialize, which runs on + * every fi_getinfo call. 
So there is no need to save and restore these fields. + */ + g_efa_hmem_info[FI_HMEM_CUDA].initialized = true; + g_efa_hmem_info[FI_HMEM_CUDA].p2p_supported_by_device = true; + + efa_prov_info_set_hmem_flags(info, FI_EP_RDM); + assert_true(info->caps & FI_HMEM); + assert_true(info->tx_attr->caps & FI_HMEM); + assert_true(info->rx_attr->caps & FI_HMEM); + fi_freeinfo(info); + + info = fi_allocinfo(); + g_efa_hmem_info[FI_HMEM_CUDA].initialized = true; + g_efa_hmem_info[FI_HMEM_CUDA].p2p_supported_by_device = false; + + efa_prov_info_set_hmem_flags(info, FI_EP_RDM); + assert_false(info->caps & FI_HMEM); + assert_false(info->tx_attr->caps & FI_HMEM); + assert_false(info->rx_attr->caps & FI_HMEM); + fi_freeinfo(info); + + /* Restore the saved hmem_ops[FI_HMEM_CUDA].initialized value */ + hmem_ops[FI_HMEM_CUDA].initialized = hmem_ops_cuda_init; +} +#else +void test_info_direct_hmem_support_p2p() +{ +} +#endif + /** * @brief Verify info->tx/rx_attr->msg_order is set according to hints. * diff --git a/prov/efa/test/efa_unit_tests.c b/prov/efa/test/efa_unit_tests.c index 3e831fe3a65..74b6c54a777 100644 --- a/prov/efa/test/efa_unit_tests.c +++ b/prov/efa/test/efa_unit_tests.c @@ -145,6 +145,7 @@ int main(void) cmocka_unit_test_setup_teardown(test_info_rdm_attributes, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_dgram_attributes, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_direct_attributes, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), + cmocka_unit_test_setup_teardown(test_info_direct_hmem_support_p2p, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_tx_rx_msg_order_rdm_order_none, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), cmocka_unit_test_setup_teardown(test_info_tx_rx_msg_order_rdm_order_sas, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), 
cmocka_unit_test_setup_teardown(test_info_tx_rx_msg_order_dgram_order_none, efa_unit_test_mocks_setup, efa_unit_test_mocks_teardown), diff --git a/prov/efa/test/efa_unit_tests.h b/prov/efa/test/efa_unit_tests.h index 5b9379feced..187aea2c4e5 100644 --- a/prov/efa/test/efa_unit_tests.h +++ b/prov/efa/test/efa_unit_tests.h @@ -166,6 +166,7 @@ void test_info_open_ep_with_wrong_info(); void test_info_rdm_attributes(); void test_info_dgram_attributes(); void test_info_direct_attributes(); +void test_info_direct_hmem_support_p2p(); void test_info_tx_rx_msg_order_rdm_order_none(); void test_info_tx_rx_msg_order_rdm_order_sas(); void test_info_tx_rx_msg_order_dgram_order_none();