diff --git a/CMakeLists.txt b/CMakeLists.txt index b25a6d1641e42817a237b7699f0b5b37840baf42..4dabc368efb5b828c5dfdb284fdbaccebf50ef96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,8 +38,8 @@ message(STATUS "USE_EXAMPLES:${USE_EXAMPLES}") message(STATUS "ENABLE_ASCENDC_DUMP:${ENABLE_ASCENDC_DUMP}") option(USE_FUZZ_TEST "USE_FUZZ_TEST" OFF) message(STATUS "USE_FUZZ_TEST:${USE_FUZZ_TEST}") - -set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) +option(SHMEM_RDMA_SUPPORT "SHMEM_RDMA_SUPPORT" OFF) +message(STATUS "SHMEM_RDMA_SUPPORT:${SHMEM_RDMA_SUPPORT}") set(CMAKE_COMPILER bisheng) set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) @@ -85,15 +85,30 @@ include_directories( ${ASCEND_HOME_PATH}/include ${ASCEND_HOME_PATH}/include/experiment/runtime ${ASCEND_HOME_PATH}/include/experiment/msprof - ${ASCEND_DRIVER_PATH}/kernel/inc ) link_directories( ${ASCEND_HOME_PATH}/lib64 - ${ASCEND_DRIVER_PATH}/lib64/driver ) -link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase ascend_hal pthread) +link_libraries(runtime stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread) + +# MF_BACKEND +set(USE_MF "0") + +if ("${USE_MF}" STREQUAL "1") + add_compile_definitions(BACKEND_MF=1) + include_directories( + ${PROJECT_SOURCE_DIR}/include/ + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ + ) + + link_libraries( + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so + ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so + ) +endif() # 添加子目录 add_subdirectory(src) diff --git a/OWNERS b/OWNERS index e90f9031b13b1c7f93aa3143a99ab27bbe3cea1a..0b91d03669c84c08134cfb0c7ed78dfc6423f681 100644 --- a/OWNERS +++ b/OWNERS @@ -4,6 +4,8 @@ approvers: - nino233 - baoxiaom - victorwaang +- oioring +- gujianxiao reviewers: - git_ray @@ -19,4 +21,5 @@ reviewers: - huangxiaolan - Vector - lenokia -- victorwaang \ No newline at end of file +- victorwaang +- gujianxiao \ No newline at end of file diff --git a/docs/quickstart.md b/docs/quickstart.md index 705ed1f63f2899941256467a660153280ad462fd..5124f318d12b06c0456d2f43d43fbb6036db3715 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -168,4 +168,24 @@ shmem_init_attr_t *attr; int ret = shmem_get_uniqueid(&uid); ret = shmem_set_attr(my_rank, n_ranks, mem_size, nullptr, &attr); // 第4个参数是ip_port,当前场景传入nullptr ret = shmem_set_attr_uniqueid_args(my_rank, n_ranks, &uid, attr); + +## shmem方式 +注:使用unique id的接口初始化,可以手动配置环境变量SHMEM_UID_SESSION_ID或者SHMEM_UID_SOCK_IFNAM,同时配置时只读SHMEM_UID_SESSION_ID,都不配置会自动搜索可用网口。 +SHMEM_UID_SESSION_ID配置示例: +SHMEM_UID_SESSION_ID=127.0.0.1:1234 +SHMEM_UID_SESSION_ID=[6666:6666:6666:6666:6666:6666:6666:6666]:886 +SHMEM_UID_SESSION_ID=[6666:6666:6666:6666:6666:6666:6666:6666%eth]:886 +SHMEM_UID_SOCK_IFNAM配置示例: +SHMEM_UID_SOCK_IFNAM=enpxxxx:inet4 取ipv4 +SHMEM_UID_SOCK_IFNAM=enpxxxx:inet6 取ipv6 +不配置默认取inet4自动搜索可用网口,搜索优先级:非docker、lo>>docker>>lo。 + + +- c++初始化例子 +```cpp +shmemx_uniqueid_t uid; +shmem_init_attr_t *attr; +int ret = shmem_get_uniqueid(&uid); +shmemx_set_attr_uniqueid_args(rank, rank_size, local_mem_size, &uid, &attributes); +status = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attributes); ``` \ No newline at end of file diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 83c2de081bd026a3dc357f1c3d3b39ec10d4da73..617cb916eb159806e3947b354545db1280738977 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -15,19 +15,21 @@ function(shmem_add_fusion_example NAME) target_compile_options(${NAME} PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220) target_include_directories(${NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ ${PROJECT_SOURCE_DIR}/3rdparty/catlass/include ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/templates/include ${PROJECT_SOURCE_DIR}/examples/utils + ${MPI_INCLUDE_PATH} ) target_link_options(${NAME} PRIVATE --cce-fatobj-link) - target_link_libraries(${NAME} PRIVATE shmem ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + if(DEFINED ENABLE_ASCENDC_DUMP AND ENABLE_ASCENDC_DUMP) target_link_libraries(${NAME} PRIVATE ascend_dump) endif() + target_link_libraries(${NAME} PRIVATE shmem) + target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) + endfunction() function(shmem_add_collective_example NAME) @@ -35,26 +37,27 @@ function(shmem_add_collective_example NAME) target_compile_options(${NAME}_kernel PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220-vec) target_include_directories(${NAME}_kernel PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ ${PROJECT_SOURCE_DIR}/3rdparty/catlass/include ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/utils ) target_link_options(${NAME}_kernel PRIVATE --cce-fatobj-link) - + add_executable(${NAME} main.cpp) target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(${NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ ${PROJECT_SOURCE_DIR}/3rdparty/catlass/examples/common ${PROJECT_SOURCE_DIR}/examples/${NAME} ${PROJECT_SOURCE_DIR}/examples/templates/include ${PROJECT_SOURCE_DIR}/examples/utils ${PROJECT_SOURCE_DIR}/src/host + ${MPI_INCLUDE_PATH} ) - target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so) + target_link_libraries(${NAME} PRIVATE shmem ${NAME}_kernel) + target_compile_options(${NAME} PRIVATE ${MPI_CXX_COMPILE_FLAGS}) + endfunction() foreach(EXAMPLE @@ -68,10 +71,17 @@ foreach(EXAMPLE matmul_reduce_scatter matmul_reduce_scatter_padding dynamic_tiling - rdma_perftest - rdma_demo - rdma_handlewait_test/unuse_handlewait - rdma_handlewait_test/use_handlewait ) add_subdirectory(${EXAMPLE}) -endforeach() \ No newline at end of file +endforeach() + +if(SHMEM_RDMA_SUPPORT) + foreach(EXAMPLE + rdma_perftest + rdma_demo + rdma_handlewait_test/unuse_handlewait + rdma_handlewait_test/use_handlewait + ) + add_subdirectory(${EXAMPLE}) +endforeach() +endif() \ No newline at end of file diff --git a/examples/allgather/README.md b/examples/allgather/README.md index 039fb7b357978b7b11c56d6190edf44035cbb01d..1f291184bc823259c90b9a94d9d0a2733760f9c3 100644 --- a/examples/allgather/README.md +++ b/examples/allgather/README.md @@ -6,4 +6,16 @@ # 完成RANKS卡下的allgather同时验证精度,性能数据会输出在result.csv中。 # RANKS : [2, 4, 8] # TYPES : [int, int32_t, float16_t, bfloat16_t] - bash run.sh -ranks ${RANKS} -type ${TYPES} \ No newline at end of file + bash run.sh -ranks ${RANKS} -type ${TYPES} + +跨机使用方式: +1.在shmem/目录编译: + bash scripts/build.sh + +2.在两台机器上shmem/examples/allgather目录中分别生成golden数据: + rm -rf ./golden + mkdir -p golden + python3 ./scripts/data_gen.py 8 "int" + +3. 在其中一台机器上shmem/examples/allgather执行(ip_host1和ip_host2为各自机器的ip地址, PROJECT_ROOT为shmem/目录) + mpirun -host ip_host1:4,ip_host2:4 -np 8 ${PROJECT_ROOT}/build/bin/allgather \ No newline at end of file diff --git a/examples/allgather/main.cpp b/examples/allgather/main.cpp index 3028b592581aa16ef8b5ded646a5cd937d7c132f..99dc8cbea72113379cf639c967fdd25ad94d7631 100644 --- a/examples/allgather/main.cpp +++ b/examples/allgather/main.cpp @@ -36,10 +36,10 @@ using bfloat16 = op::bfloat16; #include "allgather_kernel.h" int g_npus = 8; -const char *ipport; +const char *ipport = "tcp://127.0.0.1:8998"; int f_rank = 0; int f_npu = 0; -const char *data_type; +const char *data_type = "int"; constexpr int64_t SYNC_FLAG_INTERVAL = 16; constexpr int64_t UB_DMA_MAX_SIZE = 190 * 1024; @@ -49,22 +49,14 @@ constexpr uint32_t DATA_SIZE_THRESHOLD = 2097152; constexpr uint32_t BLOCK_NUM_SMALL_DATA = 8; constexpr uint32_t BLOCK_NUM_LARGE_DATA = 16; -template -int test_shmem_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) +template +int test_shmem_all_gather(int rank_id, int n_ranks) { - // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; + // ACLStream init int status = 0; aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(attributes); - // Prepare FFTS address uint64_t fftsAddr = shmemx_get_ffts_config(); @@ -163,12 +155,8 @@ int test_shmem_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size) } outFile.close(); - - status = shmem_finalize(); status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); - return 0; + return status; } int main(int argc, char *argv[]) @@ -181,23 +169,33 @@ int main(int argc, char *argv[]) f_rank = atoi(argv[INDEX5]); f_npu = atoi(argv[INDEX6]); data_type = argv[INDEX7]; + + // Acl && Shmem init + int32_t device_id = rank_id % g_npus + f_npu; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; uint64_t local_mem_size = 1024UL * 1024UL * 1024; - int32_t ret = shmem_set_conf_store_tls(false, nullptr, 0); - std::cout << "init shmem tls result:" << ret << std::endl; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + if (std::string(data_type) == "int") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks); } else if (std::string(data_type) == "int32_t") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks); } else if (std::string(data_type) == "float16_t") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks); } else if (std::string(data_type) == "bfloat16_t") { - status = test_shmem_all_gather(rank_id, n_ranks, local_mem_size); + status = test_shmem_all_gather(rank_id, n_ranks); } + status = shmem_finalize(); + status = aclrtResetDevice(device_id); + status = aclFinalize(); if (status) { std::exit(EXIT_FAILURE); } std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; - return 0; } \ No newline at end of file diff --git a/examples/allgather/run.sh b/examples/allgather/run.sh index 678f79d127d6d60f7cc8d4cc01d6b616ea8f5d68..1d8bc8135acb8404d3b8f0c9a2b64d981ebae098 100644 --- a/examples/allgather/run.sh +++ b/examples/allgather/run.sh @@ -95,6 +95,7 @@ python3 ./scripts/data_gen.py $RANK_SIZE $TEST_TYPE # Kernel test rm -rf ./output +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/install/memfabric_hybrid/lib/:${ASCEND_HOME_PATH}/lib64:$LD_LIBRARY_PATH pids=() for (( idx =0; idx < ${GNPU_NUM}; idx = idx + 1 )); do diff --git a/examples/allgather/scripts/data_gen.py b/examples/allgather/scripts/data_gen.py index c88eeac262b4cf8efc28947bcdfc3204995d78ff..9642d2a2b029c4bebec6b571d394d1e818ec8103 100644 --- a/examples/allgather/scripts/data_gen.py +++ b/examples/allgather/scripts/data_gen.py @@ -12,6 +12,8 @@ import numpy as np from ml_dtypes import bfloat16 +# Set seed for multi-node situation +np.random.seed(42) def gen_random_data(size, dtype): return np.random.uniform(low=0.0, high=10.0, size=size).astype(dtype) diff --git a/examples/allgather_matmul/main.cpp b/examples/allgather_matmul/main.cpp index 2a6ae52648637e6132f97633b717aa184dffcdc9..c7302f9c79189e2feb9cfb26b5071933111606a9 100644 --- a/examples/allgather_matmul/main.cpp +++ b/examples/allgather_matmul/main.cpp @@ -232,48 +232,54 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; - std::string ipPort = options.ipPort; + + int n_ranks = options.rankSize; + int rank_id = options.rankId; + std::string ipport = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << - " input_ip: " << ipPort << std::endl; + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; + + // status = shmem_set_conf_store_tls(false, nullptr, 0); size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); - size_t cSize = static_cast(m) * rankSize * n * sizeof(__fp16); + size_t cSize = static_cast(m) * n_ranks * n * sizeof(__fp16); uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *cDevice; @@ -309,7 +315,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), cHost, cSize); - if (rankId == 0) { + if (rank_id == 0) { std::printf("test finished\n"); } @@ -322,11 +328,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(cDevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } \ No newline at end of file diff --git a/examples/allgather_matmul/scripts/run.sh b/examples/allgather_matmul/scripts/run.sh index 18d95cf52f8f7d560a6a3e1c612b660db9df285e..f461369227b4fc7e5f4c1e8b883d7ab524c26823 100644 --- a/examples/allgather_matmul/scripts/run.sh +++ b/examples/allgather_matmul/scripts/run.sh @@ -37,9 +37,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do python3 ${UTILS_PATH}/gen_data.py 1 1 ${RANK_SIZE} ${M} ${N} ${K} 0 0 ${DATA_DIR} # Set necessary parameters - IPPORT="tcp://127.0.0.1:8788" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & done diff --git a/examples/allgather_matmul_padding/main.cpp b/examples/allgather_matmul_padding/main.cpp index 6dd72dd75bf7c50435bce076e25266f5b86a4cbb..7a6207774bc597efd2460f76355ce45e8b702663 100644 --- a/examples/allgather_matmul_padding/main.cpp +++ b/examples/allgather_matmul_padding/main.cpp @@ -219,31 +219,35 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; + + int n_ranks = options.rankSize; + int rank_id = options.rankId; std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << - " rank_id:" << rankId << " input_ip: " << ipPort << std::endl; + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipPort.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; LayoutB layoutB{k, n}; constexpr uint32_t alignByByte = 512; @@ -261,20 +265,20 @@ int main(int argc, char **argv) size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); - size_t cSize = static_cast(m) * rankSize * n * sizeof(__fp16); + size_t cSize = static_cast(m) * n_ranks * n * sizeof(__fp16); uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *cDevice; @@ -318,7 +322,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), cHost, cSize); - if (rankId == 0) { + if (rank_id == 0) { std::printf("test finished\n"); } @@ -334,11 +338,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(workspaceDevice)); } - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } \ No newline at end of file diff --git a/examples/allgather_matmul_padding/scripts/run.sh b/examples/allgather_matmul_padding/scripts/run.sh index 4ea64b7055fdf11be888ab8a7acb1bd376f4a040..8a4f8f1aa256575ba75c084a84eb1d42ba16f658 100644 --- a/examples/allgather_matmul_padding/scripts/run.sh +++ b/examples/allgather_matmul_padding/scripts/run.sh @@ -37,9 +37,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do python3 ${UTILS_PATH}/gen_data.py 1 1 ${RANK_SIZE} ${M} ${N} ${K} 0 0 ${DATA_DIR} # Set necessary parameters - IPPORT="tcp://127.0.0.1:8788" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & done diff --git a/examples/allgather_matmul_with_gather_result/main.cpp b/examples/allgather_matmul_with_gather_result/main.cpp index ea5d177aebe372eafefb7fd228ba851209d10d68..62974c5c11b55ca17c341d0d4401dcacf7266e00 100644 --- a/examples/allgather_matmul_with_gather_result/main.cpp +++ b/examples/allgather_matmul_with_gather_result/main.cpp @@ -247,49 +247,54 @@ struct Options { int main(int argc, char **argv) { int status = SHMEM_SUCCESS; + + // Kernel-need params parse Options options; if (options.Parse(argc, argv) != 0) { std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; + + int n_ranks = options.rankSize; + int rank_id = options.rankId; std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << - " rank_id:" << rankId << " input_ip: " << ipPort << std::endl; + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipPort.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); + + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << std::endl; size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); - size_t cSize = static_cast(m) * rankSize * n * sizeof(__fp16); - size_t gatherASize = static_cast(m) * rankSize * k * sizeof(__fp16); + size_t cSize = static_cast(m) * n_ranks * n * sizeof(__fp16); + size_t gatherASize = static_cast(m) * n_ranks * k * sizeof(__fp16); uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *cDevice; @@ -318,7 +323,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtSynchronizeStream(stream)); std::cout << "After calling AG_MM kernel " << std::endl; - if (rankId == 0) { + if (rank_id == 0) { ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(gatherAHost, gatherASize, gatherADevice, gatherASize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), cHost, cSize); @@ -337,11 +342,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(cDevice)); ACL_CHECK(aclrtFree(gatherADevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } diff --git a/examples/allgather_matmul_with_gather_result/scripts/run.sh b/examples/allgather_matmul_with_gather_result/scripts/run.sh index 41f8b33d38667b88babcc36cf2f60665dca6c4fd..c2ed538c879ededa490b91c39e268d2e7b4b7fcb 100644 --- a/examples/allgather_matmul_with_gather_result/scripts/run.sh +++ b/examples/allgather_matmul_with_gather_result/scripts/run.sh @@ -37,9 +37,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do python3 ${UTILS_PATH}/gen_data.py 4 1 ${RANK_SIZE} ${M} ${N} ${K} 0 0 ${DATA_DIR} # Set necessary parameters - IPPORT="tcp://127.0.0.1:8788" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & done diff --git a/examples/dispatch_gmm_combine/main.cpp b/examples/dispatch_gmm_combine/main.cpp index 31b56eb49ecea4aea4bc21870eb8085290e0e750..6dbc67723ed7fa6ce839f703eb1af910fbe01116 100644 --- a/examples/dispatch_gmm_combine/main.cpp +++ b/examples/dispatch_gmm_combine/main.cpp @@ -292,25 +292,27 @@ void InitData(uint8_t **hostPtr, uint8_t **devicePtr, size_t aSize, std::string int main(int argc, char **argv) { int status = SHMEM_SUCCESS; - int rankSize = atoi(argv[1]); - int rankId = atoi(argv[2]); + int n_ranks = atoi(argv[1]); + int rank_id = atoi(argv[2]); std::string ipport = argv[3]; + // Acl && Shmem init ACL_CHECK(aclInit(nullptr)); - int32_t deviceId = atoi(argv[4]) + rankId % gNpuNum; + int32_t deviceId = atoi(argv[4]) + rank_id % gNpuNum; ACL_CHECK(aclrtSetDevice(deviceId)); aclrtStream stream = nullptr; ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); + + // status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, gNpuMallocSpace, ipport.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint32_t m = atoi(argv[5]); uint32_t k = atoi(argv[6]); uint32_t n = atoi(argv[7]); - uint32_t EP = rankSize; + uint32_t EP = n_ranks; uint32_t expertPerRank = atoi(argv[8]); uint32_t dataType = atoi(argv[9]); uint32_t weightNz = atoi(argv[10]); @@ -369,13 +371,13 @@ int main(int argc, char **argv) "_" + std::to_string(dataType) + "_1_" + std::to_string(m) + "_" + std::to_string(k) + "_" + std::to_string(n) + "_" + std::to_string(expertPerRank) + "_" + std::to_string(EP) + "_1.bin"; - InitData(&b1Host, &b1Device, b1Size, filePrefix + "matrix_b1_" + std::to_string(rankId) + fileSuffix); - InitData(&b2Host, &b2Device, b2Size, filePrefix + "matrix_b2_" + std::to_string(rankId) + fileSuffix); + InitData(&b1Host, &b1Device, b1Size, filePrefix + "matrix_b1_" + std::to_string(rank_id) + fileSuffix); + InitData(&b2Host, &b2Device, b2Size, filePrefix + "matrix_b2_" + std::to_string(rank_id) + fileSuffix); InitData(&cHost, &cDevice, cSize); InitData(&scale1Host, &scale1Device, dequantScale1Size, - filePrefix + "matrix_dequant_scale1_" + std::to_string(rankId) + fileSuffix); + filePrefix + "matrix_dequant_scale1_" + std::to_string(rank_id) + fileSuffix); InitData(&scale2Host, &scale2Device, dequantScale2Size, - filePrefix + "matrix_dequant_scale2_" + std::to_string(rankId) + fileSuffix); + filePrefix + "matrix_dequant_scale2_" + std::to_string(rank_id) + fileSuffix); InitData(&probsHost, &probsDevice, probsSize, filePrefix + "probs" + fileSuffix); uint8_t *expertIdx; @@ -395,9 +397,9 @@ int main(int argc, char **argv) int64_t quantMode = 1; std::string dispatchFileSuffix = ""; InitData(&aHost, &aDevice, m * k * sizeof(float16_t), - filePrefix + "matrix_a_" + std::to_string(rankId) + fileSuffix); + filePrefix + "matrix_a_" + std::to_string(rank_id) + fileSuffix); InitData(&expertIdxHost, &expertIdx, m * topK * sizeof(int32_t), - filePrefix + "expert_idx_" + std::to_string(rankId) + fileSuffix); + filePrefix + "expert_idx_" + std::to_string(rank_id) + fileSuffix); moeInitRoutingQuantV2Scale = nullptr; moeInitRoutingQuantV2Offset = nullptr; @@ -414,7 +416,7 @@ int main(int argc, char **argv) size_t initRoutingWorkspace = moeInitRoutingQuantV2TilingBase.workspaceSize_; workspaceSize += initRoutingWorkspace; printf("!!!!!!!!!! initRoutingQuantTilingKey %lu\n\n", initRoutingQuantTilingKey); - if (rankId == 0) { + if (rank_id == 0) { moeInitRoutingQuantV2TilingBase.ShowTilingData(); } @@ -455,8 +457,8 @@ int main(int argc, char **argv) ACL_CHECK(aclrtSynchronizeStream(stream)); ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile("./out/output_" + std::to_string(rankId) + ".bin", cHost, cSize); - if (rankId == 0) { + WriteFile("./out/output_" + std::to_string(rank_id) + ".bin", cHost, cSize); + if (rank_id == 0) { std::printf("\ntest finished\n"); } shmem_free(symmPtr); @@ -469,11 +471,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFreeHost(expertIdxHost)); ACL_CHECK(aclrtFree(expertIdx)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(deviceId); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } diff --git a/examples/dispatch_gmm_combine/scripts/run.sh b/examples/dispatch_gmm_combine/scripts/run.sh index b611efd31e91b4512dd864fbd23d13cce146ec97..d690241eeb79017edc87d72f296631f6a625fc88 100644 --- a/examples/dispatch_gmm_combine/scripts/run.sh +++ b/examples/dispatch_gmm_combine/scripts/run.sh @@ -121,6 +121,7 @@ if [[ $? -ne 0 ]]; then fi echo "Test Case, M: ${M}, K: ${K}, N: ${N}, expertPerRank: ${expertPerRank}" +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 export LD_LIBRARY_PATH=${PROJECT_ROOT}/install/shmem/lib:${ASCEND_HOME_PATH}/lib64:${PROJECT_ROOT}/install/memfabric_hybrid/lib:$LD_LIBRARY_PATH for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do export INPUT_PATH=${EXAMPLE_DIR}/utils/test_data/ diff --git a/examples/dispatch_gmm_combine/utils/gen_data.py b/examples/dispatch_gmm_combine/utils/gen_data.py index 38d628af52c0c3d465a0405487cce93689d3d771..140dd639bf0286fd710b10aaa51bb56654337f88 100755 --- a/examples/dispatch_gmm_combine/utils/gen_data.py +++ b/examples/dispatch_gmm_combine/utils/gen_data.py @@ -542,7 +542,6 @@ class MoeTestDate: matrix_a_block_list[src_ep].append(src_offset - src_offset_old) return matrix_a_i_list, matrix_a_block_list - @staticmethod def convert_nd_to_nz(self, coc_dtype_desc, input_tensor): split_tensors = torch.unbind(input_tensor, dim=0) split_tensors = [t.unsqueeze(0) for t in split_tensors] @@ -556,14 +555,12 @@ class MoeTestDate: output_tensor = torch.cat(processed_tensors, dim=0) return output_tensor - @staticmethod def swiglu(self, x: torch.Tensor) -> torch.Tensor: x0, gate = x.chunk(2, dim=-1) swish = x0 * torch.sigmoid(x0) y = swish * gate return y - @staticmethod def quant(self, x: torch.Tensor): x_row_max = torch.max(torch.abs(x), dim=-1).values quant_result = x * 127. / x_row_max[:, None] @@ -571,7 +568,6 @@ class MoeTestDate: scale = (x_row_max / 127.).to(torch.float32) return y, scale - @staticmethod def unpermute(self, permuted_tokens, origin_sorted_indices, probs): orgin_dtype = permuted_tokens.dtype permuted_tokens = permuted_tokens.to(torch.float).cpu() diff --git a/examples/dynamic_tiling/CMakeLists.txt b/examples/dynamic_tiling/CMakeLists.txt index e31dff88df45e4e560a21555236a10d27963a4af..4f50cf1b86edc7a831a5dd53d28b4579be0afd20 100644 --- a/examples/dynamic_tiling/CMakeLists.txt +++ b/examples/dynamic_tiling/CMakeLists.txt @@ -2,7 +2,7 @@ add_custom_target(lib_impl) function(add_impl_share_lib NAME) add_library(${NAME} SHARED ${ARGN}) - target_compile_options(${NAME} PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220) + target_compile_options(${NAME} PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} ${MPI_CXX_COMPILE_FLAGS} --cce-aicore-arch=dav-c220) target_include_directories(${NAME} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src/memfabric_hybrid/src/smem/include/host/ @@ -50,6 +50,6 @@ target_include_directories(dynamic_tiling PRIVATE ) target_link_options(dynamic_tiling PRIVATE --cce-fatobj-link) target_link_libraries(dynamic_tiling PRIVATE tiling_lib shmem ${SHARE_LIB_LINK}) -target_compile_options(dynamic_tiling PRIVATE -O3) +target_compile_options(dynamic_tiling PRIVATE -O3 ${MPI_CXX_COMPILE_FLAGS}) add_dependencies(dynamic_tiling lib_impl tiling_lib) diff --git a/examples/dynamic_tiling/main.cpp b/examples/dynamic_tiling/main.cpp index 981bc0bdc91a199bf199c16a5481a25bd9681ee9..5f380fed1fba6fc6cfa162c07cd66bf1a1e37a99 100644 --- a/examples/dynamic_tiling/main.cpp +++ b/examples/dynamic_tiling/main.cpp @@ -200,27 +200,27 @@ int main(int argc, char **argv) options.Parse(argc, argv); CocCommType commType = options.commType; CocDataType dataType = options.dataType; - int rankSize = options.rankSize; - int rankId = options.rankId; + int n_ranks = options.rankSize; + int rank_id = options.rankId; std::string ipPort = options.ipPort; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t deviceId = options.deviceIdList[rank_id]; std::string data_file = options.data_file; if (data_file.empty()) { return -1; } const std::vector> shapes = InitTestShapes(options); - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id: " << rankId << " input_ip: " << ipPort << "\n"; + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id: " << rank_id << " input_ip: " << ipPort << "\n"; aclrtStream stream = nullptr; ACL_CHECK(aclInit(nullptr)); ACL_CHECK(aclrtSetDevice(deviceId)); ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); + // status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, SHMEM_MALLOC_MAX_SIZE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipPort.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint64_t fftsAddr{0}; uint32_t fftsLen{0}; @@ -229,7 +229,7 @@ int main(int argc, char **argv) std::string currentTime = GetCurrentTime(); std::string currentDir = options.parentPath; std::string tilingFileName = currentDir + "/output/tiling/tilingData_" + currentTime + ".csv"; - if (rankId == 0) { + if (rank_id == 0) { CreateTilingFile(tilingFileName); } @@ -254,19 +254,19 @@ int main(int argc, char **argv) cocTiling.commNpuSplit = COMM_NPU_SPLIT; cocTiling.commDataSplit = COMM_DATA_SPLIT; cocTiling.commBlockM = COMM_BLOCK_M; - cocTiling.rankSize = rankSize; + cocTiling.rankSize = n_ranks; size_t aSize = static_cast(m) * k * sizeof(half); size_t bSize = static_cast(k) * n * sizeof(half); size_t cSize = static_cast(m) * n * sizeof(half); size_t cSizePerRank; - size_t gatherASize = aSize * rankSize; + size_t gatherASize = aSize * n_ranks; size_t wASize = 0; size_t wBSize = 0; if (commType == MATMUL_REDUCE_SCATTER) { - cSizePerRank = cSize / rankSize; + cSizePerRank = cSize / n_ranks; } else if (commType == MATMUL_REDUCE_SCATTER_PADDING) { - cSizePerRank = cSize / rankSize; + cSizePerRank = cSize / n_ranks; bool isNeedPaddingA = IsNeedPadding(m, k, transA); bool isNeedPaddingB = IsNeedPadding(k, n, transB); @@ -284,9 +284,9 @@ int main(int argc, char **argv) kernelType = MATMUL_REDUCE_SCATTER; } } else if (commType == ALLGATHER_MATMUL || commType == ALLGATHER_MATMUL_WITH_GATHER_RESULT) { - cSizePerRank = cSize * rankSize; + cSizePerRank = cSize * n_ranks; } else if (commType == ALLGATHER_MATMUL_PADDING) { - cSizePerRank = cSize * rankSize; + cSizePerRank = cSize * n_ranks; bool isNeedPaddingB = IsNeedPadding(k, n, transB); if (isNeedPaddingB) { @@ -306,7 +306,7 @@ int main(int argc, char **argv) uint8_t *aHost; if (data_file != "") { ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(data_file + "/rank_" + std::to_string(rankId) + "_a.bin", aHost, aSize); + ReadFile(data_file + "/rank_" + std::to_string(rank_id) + "_a.bin", aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); } else { std::vector matrixA(m * k, 1); @@ -318,7 +318,7 @@ int main(int argc, char **argv) uint8_t *bHost; if (data_file != "") { ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(data_file + "/rank_" + std::to_string(rankId) + "_b.bin", bHost, bSize); + ReadFile(data_file + "/rank_" + std::to_string(rank_id) + "_b.bin", bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); } else { std::vector matrixB(k * n, 1); @@ -368,7 +368,7 @@ int main(int argc, char **argv) } else { if (searchparams == 1) { // 搜索 tiling - GetTilings(cocTilings, cocTiling, commType, rankSize); + GetTilings(cocTilings, cocTiling, commType, n_ranks); } else { ApplyLookupTable(info, commType, rankSize, cocTiling); cocTilings.push_back(cocTiling); @@ -404,22 +404,22 @@ int main(int argc, char **argv) } if (commType == MATMUL_ALLREDUCE) { - if (rankId == 0) { + if (rank_id == 0) { WriteFile(data_file + "/output.bin", cHost, cSizePerRank); } } else if (commType == ALLGATHER_MATMUL || commType == ALLGATHER_MATMUL_PADDING || commType == ALLGATHER_MATMUL_WITH_GATHER_RESULT) { - if (rankId == 0) { + if (rank_id == 0) { WriteFile(data_file + "/output.bin", cHost, cSizePerRank); if (commType == ALLGATHER_MATMUL_WITH_GATHER_RESULT) { WriteFile(data_file + "/output_gather_a.bin", gatherAHost, gatherASize); } } } else if (commType == MATMUL_REDUCE_SCATTER || commType == MATMUL_REDUCE_SCATTER_PADDING) { - WriteFile(data_file + "/output.bin", cHost, cSizePerRank, rankId * cSizePerRank); + WriteFile(data_file + "/output.bin", cHost, cSizePerRank, rank_id * cSizePerRank); } - if (rankId == 0) { + if (rank_id == 0) { WriteTilingInfos(opName, cocTilings, tilingFileName, transA, transB); std::printf("M: %d, K: %d, N: %d aclrtSynchronizeStream success!\n", cocTiling.m, cocTiling.k, cocTiling.n); } @@ -441,11 +441,16 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(cDevice)); } - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(deviceId); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } \ No newline at end of file diff --git a/examples/dynamic_tiling/scripts/run.sh b/examples/dynamic_tiling/scripts/run.sh index 78ae948ead799cdc9e711b7c7870703dc3db1e95..66577e95efdc2f9456e781e780c49874a60b3d78 100644 --- a/examples/dynamic_tiling/scripts/run.sh +++ b/examples/dynamic_tiling/scripts/run.sh @@ -79,9 +79,10 @@ if [ "$TEST_TYPE" = "0" ]; then python3 ${UTILS_PATH}/gen_data.py ${COMM_TYPE} ${DATA_TYPE} ${RANK_SIZE} ${M} ${N} ${K} ${TA} ${TB} ${DATA_PATH} # Set necessary parameters - IPPORT="tcp://127.0.0.1:27008" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR $DATA_PATH" ${APP}& @@ -115,11 +116,12 @@ else echo "Processing test case: M=${M}, K=${K}, N=${N}, TransA=${TA}, TransB=${TB}" # Set necessary parameters - IPPORT="tcp://127.0.0.1:27009" + IPPORT="tcp://127.0.0.1:8899" OUTPUT_PATH="./output/msprof/start_line${IDX}_run_rows${TEST_COLLECT_ROWS}/" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do APP="$EXEC_BIN $COMM_TYPE $DATA_TYPE $RANK_SIZE $idx $IPPORT $M $N $K $TEST_START_LINE $TEST_COLLECT_ROWS $PARENT_PATH $CSV_FILE $DEVICE_ID_STR" msprof --application="${APP}" --output="${OUTPUT_PATH}"& diff --git a/examples/kv_shuffle/main.cpp b/examples/kv_shuffle/main.cpp index 50e85b757aed2fb1425e581e797c269a287d2f4b..267c904d7e2ee8eea0ed612e263ef0295fcc9d7a 100644 --- a/examples/kv_shuffle/main.cpp +++ b/examples/kv_shuffle/main.cpp @@ -47,21 +47,13 @@ constexpr int64_t max_block_nums = MAX_SEQLEN * MAX_BATCH / page_size; constexpr int64_t kv_head_num = 8; constexpr int64_t head_dim = 128; -int test_shmem_kv_shuffle(int rank_id, int n_ranks, uint64_t local_mem_size) +int test_shmem_kv_shuffle(int rank_id, int n_ranks) { - // 初始化ACL和SHMEM - int32_t device_id = rank_id % g_npus + f_npu; + // ACLStream init int status = 0; aclrtStream stream = nullptr; - - status = aclInit(nullptr); - status = aclrtSetDevice(device_id); status = aclrtCreateStream(&stream); - shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); - status = shmem_init_attr(attributes); - uint32_t BLOCK_NUM = 16; int64_t kv_cache_size = max_block_nums * kv_head_num * page_size * head_dim * sizeof(int8_t); @@ -185,11 +177,8 @@ int test_shmem_kv_shuffle(int rank_id, int n_ranks, uint64_t local_mem_size) status = aclrtFreeHost(k_output_host); status = aclrtFreeHost(v_output_host); - status = shmem_finalize(); status = aclrtDestroyStream(stream); - status = aclrtResetDevice(device_id); - status = aclFinalize(); - return 0; + return status; } int main(int argc, char *argv[]) @@ -198,12 +187,27 @@ int main(int argc, char *argv[]) int n_ranks = atoi(argv[INDEX1]); int rank_id = atoi(argv[INDEX2]); ipport = argv[INDEX3]; + // int32_t ret = shmem_set_conf_store_tls(false, nullptr, 0); + + // Acl && Shmem init + int32_t device_id = rank_id % g_npus + f_npu; + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); + + shmem_init_attr_t *attributes; uint64_t local_mem_size = 1024UL * 1024UL * 1024; - int32_t ret = shmem_set_conf_store_tls(false, nullptr, 0); + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); - status = test_shmem_kv_shuffle(rank_id, n_ranks, local_mem_size); + status = test_shmem_kv_shuffle(rank_id, n_ranks); - std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; + status = shmem_finalize(); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } diff --git a/examples/kv_shuffle/scripts/run.sh b/examples/kv_shuffle/scripts/run.sh index 12e8f6df4339bcf21e677574681dfc3fd36ea08a..8f6d40dd97f5c76c82d28fb69582068b40d9f8ef 100644 --- a/examples/kv_shuffle/scripts/run.sh +++ b/examples/kv_shuffle/scripts/run.sh @@ -23,6 +23,7 @@ rm -rf scripts/output/*.bin python3 scripts/golden.py $RANK_SIZE # Start Process +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx =0; idx < ${RANK_SIZE}; idx = idx + 1 )); do APP="$EXEC_BIN $RANK_SIZE $idx $IPPORT" ${APP}& diff --git a/examples/matmul_allreduce/main.cpp b/examples/matmul_allreduce/main.cpp index f257d0c1fe15659bed8d01bd4eaef278e79c2f54..60ce7c7c3c68cb13ba487d140265aff83d0ecad4 100644 --- a/examples/matmul_allreduce/main.cpp +++ b/examples/matmul_allreduce/main.cpp @@ -236,25 +236,28 @@ int main(int argc, char **argv) std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; + int n_ranks = options.rankSize; + int rank_id = options.rankId; std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << "\n"; + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << " input_ip: " << ipPort << "\n"; + + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, NPU_MALLOC_SPACE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipPort.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); @@ -264,14 +267,14 @@ int main(int argc, char **argv) ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *dDevice; @@ -297,7 +300,7 @@ int main(int argc, char **argv) ACL_CHECK(aclrtSynchronizeStream(stream)); std::cout << "After calling MM_AR kernel " << std::endl; - if (rankId == 0) { + if (rank_id == 0) { ACL_CHECK(aclrtMemcpy(dHost, dSize, dDevice, dSize, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSize); std::printf("test finished\n"); @@ -312,11 +315,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(dDevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } diff --git a/examples/matmul_allreduce/scripts/run.sh b/examples/matmul_allreduce/scripts/run.sh index aecb3139b9d29c3c42bc624ddda185fae1ff8fe4..bf06e19d46c07c4eafd5a11fc6b749bcdbba6f76 100644 --- a/examples/matmul_allreduce/scripts/run.sh +++ b/examples/matmul_allreduce/scripts/run.sh @@ -36,9 +36,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do python3 ${UTILS_PATH}/gen_data.py 0 1 ${RANK_SIZE} ${M} ${N} ${K} 0 0 ${DATA_DIR} # Set necessary parameters - IPPORT="tcp://127.0.0.1:8788" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & done diff --git a/examples/matmul_reduce_scatter/main.cpp b/examples/matmul_reduce_scatter/main.cpp index ba170fbe318cfcf91f54d04f8705ddf7b5a8f32d..7d8b3108c468320a9cc02b8b0507cc087e3e3683 100644 --- a/examples/matmul_reduce_scatter/main.cpp +++ b/examples/matmul_reduce_scatter/main.cpp @@ -223,43 +223,46 @@ int main(int argc, char **argv) std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; + int n_ranks = options.rankSize; + int rank_id = options.rankId; std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << "\n"; + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << " input_ip: " << ipPort << "\n"; + + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, NPU_MALLOC_SPACE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipPort.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); size_t dSize = static_cast(m) * n * sizeof(__fp16); - size_t dSizeScatter = dSize / options.rankSize; + size_t dSizeScatter = dSize / n_ranks; uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *dDevice; @@ -284,8 +287,8 @@ int main(int argc, char **argv) std::cout << "After calling MM_RS kernel " << std::endl; ACL_CHECK(aclrtMemcpy(dHost, dSizeScatter, dDevice, dSizeScatter, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rankId * dSizeScatter); - if (rankId == 0) { + WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rank_id * dSizeScatter); + if (rank_id == 0) { std::printf("test finished\n"); } @@ -298,11 +301,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(bDevice)); ACL_CHECK(aclrtFree(dDevice)); - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } diff --git a/examples/matmul_reduce_scatter/scripts/run.sh b/examples/matmul_reduce_scatter/scripts/run.sh index 808f1cf1ec7826bcd9cd01406067b3eef1b9eb3e..7fe5f6ebcf515a88cf373f3f72b30ef7407ff857 100644 --- a/examples/matmul_reduce_scatter/scripts/run.sh +++ b/examples/matmul_reduce_scatter/scripts/run.sh @@ -36,9 +36,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do python3 ${UTILS_PATH}/gen_data.py 2 1 ${RANK_SIZE} ${M} ${N} ${K} 0 0 ${DATA_DIR} # Set necessary parameters - IPPORT="tcp://127.0.0.1:8788" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & done diff --git a/examples/matmul_reduce_scatter_padding/main.cpp b/examples/matmul_reduce_scatter_padding/main.cpp index d975ce5f886818cf9937db80ca9198eb7a31d6c5..512a9dad67546c66390fb9eac87ed6d6ce752ba2 100644 --- a/examples/matmul_reduce_scatter_padding/main.cpp +++ b/examples/matmul_reduce_scatter_padding/main.cpp @@ -235,25 +235,28 @@ int main(int argc, char **argv) std::cerr << "Invalid arguments\n"; return 1; } - int rankSize = options.rankSize; - int rankId = options.rankId; + int n_ranks = options.rankSize; + int rank_id = options.rankId; std::string ipPort = options.ipPort; uint32_t m = options.m; uint32_t n = options.n; uint32_t k = options.k; - int32_t deviceId = options.deviceIdList[rankId]; + int32_t device_id = options.deviceIdList[rank_id]; - std::cout << "[TEST] input rank_size: " << rankSize << " rank_id:" << rankId << " input_ip: " << ipPort << "\n"; + std::cout << "[TEST] input rank_size: " << n_ranks << " rank_id:" << rank_id << " input_ip: " << ipPort << "\n"; + + // Acl && Shmem init + status = aclInit(nullptr); + status = aclrtSetDevice(device_id); - aclrtStream stream = nullptr; - ACL_CHECK(aclInit(nullptr)); - ACL_CHECK(aclrtSetDevice(deviceId)); - ACL_CHECK(aclrtCreateStream(&stream)); - status = shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; - status = shmem_set_attr(rankId, rankSize, NPU_MALLOC_SPACE, ipPort.c_str(), &attributes); - status = shmem_init_attr(attributes); - status = shmem_init_status(); + uint64_t local_mem_size = 1024UL * 1024UL * 1024; + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipPort.c_str(), &attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + // ACLStream init + aclrtStream stream = nullptr; + status = aclrtCreateStream(&stream); LayoutA layoutA{m, k}; LayoutB layoutB{k, n}; @@ -271,20 +274,20 @@ int main(int argc, char **argv) size_t aSize = static_cast(m) * k * sizeof(__fp16); size_t bSize = static_cast(k) * n * sizeof(__fp16); size_t dSize = static_cast(m) * n * sizeof(__fp16); - size_t dSizeScatter = dSize / options.rankSize; + size_t dSizeScatter = dSize / n_ranks; uint8_t *aDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&aDevice), aSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *aHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&aHost), aSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_a.bin"), aHost, aSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_a.bin"), aHost, aSize); ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *bDevice; ACL_CHECK(aclrtMalloc(reinterpret_cast(&bDevice), bSize, ACL_MEM_MALLOC_HUGE_FIRST)); uint8_t *bHost; ACL_CHECK(aclrtMallocHost(reinterpret_cast(&bHost), bSize)); - ReadFile(options.GetDataPath("rank_" + std::to_string(rankId) + "_b.bin"), bHost, bSize); + ReadFile(options.GetDataPath("rank_" + std::to_string(rank_id) + "_b.bin"), bHost, bSize); ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); uint8_t *dDevice; @@ -349,8 +352,8 @@ int main(int argc, char **argv) std::cout << "After calling MM_RS kernel " << std::endl; ACL_CHECK(aclrtMemcpy(dHost, dSizeScatter, dDevice, dSizeScatter, ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rankId * dSizeScatter); - if (rankId == 0) { + WriteFile(options.GetDataPath("shmem_output.bin"), dHost, dSizeScatter, rank_id * dSizeScatter); + if (rank_id == 0) { std::printf("test finished\n"); } @@ -369,11 +372,15 @@ int main(int argc, char **argv) ACL_CHECK(aclrtFree(wbDevice)); } - std::cout << "[TEST] begin to exit...... rankId: " << rankId << std::endl; + status = aclrtDestroyStream(stream); + status = shmem_finalize(); - ACL_CHECK(aclrtDestroyStream(stream)); - ACL_CHECK(aclrtResetDevice(deviceId)); - ACL_CHECK(aclFinalize()); + status = aclrtResetDevice(device_id); + status = aclFinalize(); + if (status) { + std::exit(EXIT_FAILURE); + } + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } diff --git a/examples/matmul_reduce_scatter_padding/scripts/run.sh b/examples/matmul_reduce_scatter_padding/scripts/run.sh index b2743dbf8a7af18d65b2c5715e7c0b7cc7619c75..04d85b3d4d63edfc1a4ff4a866cf198c6ef47c92 100644 --- a/examples/matmul_reduce_scatter_padding/scripts/run.sh +++ b/examples/matmul_reduce_scatter_padding/scripts/run.sh @@ -36,9 +36,10 @@ tail -n +2 "$CSV_FILE" | while IFS=',' read -r M K N; do python3 ${UTILS_PATH}/gen_data.py 2 1 ${RANK_SIZE} ${M} ${N} ${K} 0 0 ${DATA_DIR} # Set necessary parameters - IPPORT="tcp://127.0.0.1:8788" + IPPORT="tcp://127.0.0.1:8899" # Start Process + export SHMEM_UID_SESSION_ID=127.0.0.1:8899 for (( idx = 0; idx < ${RANK_SIZE}; idx = idx + 1 )); do ${EXEC_BIN} "$RANK_SIZE" "$idx" "$IPPORT" "$M" "$N" "$K" "${DATA_DIR}" "$1" & done diff --git a/examples/rdma_demo/README.md b/examples/rdma_demo/README.md index b986fd66bbbd05529e6c8998cc3b51958e0f5061..8b59b171a8fbff15af72410e6380c6dccac72039 100644 --- a/examples/rdma_demo/README.md +++ b/examples/rdma_demo/README.md @@ -7,15 +7,15 @@ bash scripts/build.sh ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 ./build/bin/rdma_demo 2 0 tcp://127.0.0.1:8765 2 0 0 & # rank 0 ./build/bin/rdma_demo 2 1 tcp://127.0.0.1:8765 2 0 0 & # rank 1 ``` 3.命令行参数说明 - ./rdma_demo + ./rdma_demo -- n_ranks: 全局Rank数量,只支持2个Rank。 -- rank_id: 当前进程的Rank号。 +- n_ranks: 全局Rank数量。 - ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。如果执行跨机测试,需要讲IP设为rank0所在Host的IP。 - g_npus: 当前卡上启动的NPU数量。 - f_rank: 当前卡上使用的第一个Rank号。 diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index d923ac8308735d7a4c5444515e5d255de025ad2e..57fc9424527c0ef8c626f13d682cd9e0dbbf7d57 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -36,10 +36,9 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size status |= aclrtCreateStream(&stream); shmem_init_attr_t *attributes; - status |= shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status |= shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint8_t *ptr = static_cast(shmem_malloc(1024)); @@ -59,6 +58,7 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size handle.team_id = SHMEM_TEAM_WORLD; shmem_handle_wait(handle, stream); status |= aclrtSynchronizeStream(stream); + shmemi_control_barrier_all(); // 结果校验打印 int32_t *y_host; @@ -99,7 +99,7 @@ int main(int argc, char *argv[]) f_npu = atoi(argv[argIdx++]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); - std::cout << "demo run finished in rank " << rank_id << " with status " << status << std::endl; + std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; return 0; } \ No newline at end of file diff --git a/examples/rdma_demo/run.sh b/examples/rdma_demo/run.sh index 8089b34702cdff51ac194a8c296374673698a880..652ade38a4dd660544cdcf2f0b76c2af5ccb4369 100644 --- a/examples/rdma_demo/run.sh +++ b/examples/rdma_demo/run.sh @@ -4,13 +4,15 @@ script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" project_root="$(cd ${script_dir}/../../ && pwd)" export PROJECT_ROOT=${project_root} export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 cd ${PROJECT_ROOT} pids=() -./build/bin/rdma_demo 2 0 tcp://127.0.0.1:8765 2 0 0 & # rank 0 +./build/bin/rdma_demo 2 0 tcp://127.0.0.1:8899 2 0 0 & # rank 0 pid=$! pids+=("$pid") -./build/bin/rdma_demo 2 1 tcp://127.0.0.1:8765 2 0 0 & # rank 1 +./build/bin/rdma_demo 2 1 tcp://127.0.0.1:8899 2 0 0 & # rank 1 pid=$! pids+=("$pid") @@ -23,4 +25,4 @@ for pid in ${pids[@]}; do fi echo "wait $pid finished" done -exit $ret \ No newline at end of file +exit $ret diff --git a/examples/rdma_handlewait_test/unuse_handlewait/README.md b/examples/rdma_handlewait_test/unuse_handlewait/README.md index 065cb99600800a58435cedcca0bdda6487745756..63e83f3157b5836d518b1cc6a5f8b7f27dedf176 100644 --- a/examples/rdma_handlewait_test/unuse_handlewait/README.md +++ b/examples/rdma_handlewait_test/unuse_handlewait/README.md @@ -7,8 +7,9 @@ bash scripts/build.sh -examples ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -./build/bin/unuse_handlewait 2 0 tcp://127.0.0.1:8765 2 0 0 # rank 0 -./build/bin/unuse_handlewait 2 1 tcp://127.0.0.1:8765 2 0 0 # rank 1 +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +./build/bin/unuse_handlewait 2 0 tcp://127.0.0.1:8899 2 0 0 # rank 0 +./build/bin/unuse_handlewait 2 1 tcp://127.0.0.1:8899 2 0 0 # rank 1 ``` 3.命令行参数说明 diff --git a/examples/rdma_handlewait_test/unuse_handlewait/main.cpp b/examples/rdma_handlewait_test/unuse_handlewait/main.cpp index 1fd3248e701b7853557744175bd72f9407c6e122..517380a56036e3d6f08fbf59560ef2e859c0bd9d 100644 --- a/examples/rdma_handlewait_test/unuse_handlewait/main.cpp +++ b/examples/rdma_handlewait_test/unuse_handlewait/main.cpp @@ -45,8 +45,7 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint8_t *ptr = static_cast(shmem_malloc(mem_size)); uint8_t *ptr_A = ptr + half_mem_size; diff --git a/examples/rdma_handlewait_test/unuse_handlewait/run.sh b/examples/rdma_handlewait_test/unuse_handlewait/run.sh index a8ef5a88e113511b1356092b8060e5df8d72ae69..8bb1738b03f9649d31c29b080711b1a132a7d11b 100644 --- a/examples/rdma_handlewait_test/unuse_handlewait/run.sh +++ b/examples/rdma_handlewait_test/unuse_handlewait/run.sh @@ -4,14 +4,16 @@ script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" project_root="$(cd ${script_dir}/../../../ && pwd)" export PROJECT_ROOT=${project_root} export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 cd $project_root pids=() -./build/bin/unuse_handlewait 2 0 tcp://127.0.0.1:8765 2 0 0 & # rank 0 +./build/bin/unuse_handlewait 2 0 tcp://127.0.0.1:8899 2 0 0 & # rank 0 pid=$! pids+=("$pid") -./build/bin/unuse_handlewait 2 1 tcp://127.0.0.1:8765 2 0 0 & # rank 1 +./build/bin/unuse_handlewait 2 1 tcp://127.0.0.1:8899 2 0 0 & # rank 1 pid=$! pids+=("$pid") @@ -24,4 +26,4 @@ for pid in ${pids[@]}; do fi echo "wait $pid finished" done -exit $ret \ No newline at end of file +exit $ret diff --git a/examples/rdma_handlewait_test/use_handlewait/README.md b/examples/rdma_handlewait_test/use_handlewait/README.md index e0bdcd8c47ae25106fdfbb10a1d2c5262adda7d1..1263a91b4b037301f7cbe54cf1758698bdc2cef6 100644 --- a/examples/rdma_handlewait_test/use_handlewait/README.md +++ b/examples/rdma_handlewait_test/use_handlewait/README.md @@ -7,6 +7,7 @@ bash scripts/build.sh -examples ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 ./build/bin/use_handlewait 2 0 tcp://127.0.0.1:8765 2 0 0 # rank 0 ./build/bin/use_handlewait 2 1 tcp://127.0.0.1:8765 2 0 0 # rank 1 ``` diff --git a/examples/rdma_handlewait_test/use_handlewait/main.cpp b/examples/rdma_handlewait_test/use_handlewait/main.cpp index 9109b99647d23cef8a89712025ebaaf647f06ca2..6e859d7d605fc9e80044fec05a1cdf2c84bf9a4f 100644 --- a/examples/rdma_handlewait_test/use_handlewait/main.cpp +++ b/examples/rdma_handlewait_test/use_handlewait/main.cpp @@ -45,8 +45,7 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size shmem_init_attr_t *attributes; status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint8_t *ptr = static_cast(shmem_malloc(mem_size)); uint8_t *ptr_A = ptr + half_mem_size; diff --git a/examples/rdma_handlewait_test/use_handlewait/run.sh b/examples/rdma_handlewait_test/use_handlewait/run.sh index 7f802d4d87f503b4dce4c06c406003e393b12efd..e80b85e0e46591b6d5b92e7f39a78d53b1dd982b 100644 --- a/examples/rdma_handlewait_test/use_handlewait/run.sh +++ b/examples/rdma_handlewait_test/use_handlewait/run.sh @@ -4,13 +4,15 @@ script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" project_root="$(cd ${script_dir}/../../../ && pwd)" export PROJECT_ROOT=${project_root} export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 cd $PROJECT_ROOT pids=() -./build/bin/use_handlewait 2 0 tcp://127.0.0.1:8765 2 0 0 & # rank 0 +./build/bin/use_handlewait 2 0 tcp://127.0.0.1:8899 2 0 0 & # rank 0 pid=$! pids+=("$pid") -./build/bin/use_handlewait 2 1 tcp://127.0.0.1:8765 2 0 0 & # rank 1 +./build/bin/use_handlewait 2 1 tcp://127.0.0.1:8899 2 0 0 & # rank 1 pid=$! pids+=("$pid") @@ -23,4 +25,4 @@ for pid in ${pids[@]}; do fi echo "wait $pid finished" done -exit $ret \ No newline at end of file +exit $ret diff --git a/examples/rdma_perftest/README.md b/examples/rdma_perftest/README.md index 1c73d65e46ebc43faf2227f8d410f60d8fa01038..98cff8fe57158602c10327660c7417b1abf19e4f 100644 --- a/examples/rdma_perftest/README.md +++ b/examples/rdma_perftest/README.md @@ -7,15 +7,14 @@ bash scripts/build.sh ```bash export PROJECT_ROOT= export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/smem/lib64:${PROJECT_ROOT}/3rdparty/memfabric_hybrid/output/hybm/lib64:$LD_LIBRARY_PATH -./build/bin/rdma_perftest 2 0 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 # rank 0 -./build/bin/rdma_perftest 2 1 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 # rank 1 +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 +./build/bin/rdma_perftest 2 0 tcp://127.0.0.1:8899 2 0 0 highlevel_put_pingpong_latency 64 # rank 0 +./build/bin/rdma_perftest 2 1 tcp://127.0.0.1:8899 2 0 0 highlevel_put_pingpong_latency 64 # rank 1 ``` 3.命令行参数说明 - ./rdma_perftest + ./rdma_perftest -- n_ranks: 全局Rank数量,只支持2个Rank。 -- rank_id: 当前进程的Rank号。 - ipport: SHMEM初始化需要的IP及端口号,格式为tcp://:<端口号>。如果执行跨机测试,需要讲IP设为rank0所在Host的IP。 - g_npus: 当前卡上启动的NPU数量。 - f_rank: 当前卡上使用的第一个Rank号。 diff --git a/examples/rdma_perftest/main.cpp b/examples/rdma_perftest/main.cpp index a36d0b652491a9728935607f6948852bac87d952..5328ee85a386516b55f68da61303081a4f2997bb 100644 --- a/examples/rdma_perftest/main.cpp +++ b/examples/rdma_perftest/main.cpp @@ -31,7 +31,7 @@ extern void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fft extern void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t cfg, uint8_t* gva, int len); extern void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t cfg, uint8_t* gva, int len, int64_t iter); -int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t mem_size, int message_length) +int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uint64_t local_mem_size, int message_length) { uint32_t iteration = 1; int32_t device_id = rank_id % g_npus + f_npu; @@ -47,10 +47,10 @@ int test_shmem_rdma_highlevel_put_pingpong_latency(int rank_id, int n_ranks, uin status = aclrtCreateStream(&stream); shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, n_ranks, mem_size, ipport, &attributes); + status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size6M)); @@ -103,7 +103,7 @@ int test_shmem_rdma_postsend_cost(int rank_id, int n_ranks, uint64_t local_mem_s status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size6M)); @@ -153,7 +153,7 @@ int test_shmem_rdma_highlevel_put_bw(int rank_id, int n_ranks, uint64_t local_me status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size6M)); @@ -200,8 +200,9 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size status = shmem_set_attr(rank_id, n_ranks, local_mem_size, ipport, &attributes); attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - shmem_mte_set_ub_params(0, size128K, 0); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + shmem_mte_set_ub_params(0, 128 * 1024, 0); uint64_t fftsConfig = shmemx_get_ffts_config(); uint8_t *gva = static_cast(shmem_malloc(size32M)); @@ -229,7 +230,7 @@ int test_shmem_rdma_mte_put_bw(int rank_id, int n_ranks, uint64_t local_mem_size inHost[i + (rank_id + n_ranks) * message_length / sizeof(int64_t)] = rank_id + iterRange + iter; } aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); rdma_mte_put_bw_do(1, stream, fftsConfig, gva, message_length, iter); aclrtSynchronizeStream(stream); if (rank_id == 0 && iter >= iterRange) { @@ -291,6 +292,5 @@ int main(int argc, char *argv[]) } std::cout << "[SUCCESS] demo run success in rank " << rank_id << std::endl; - return 0; } \ No newline at end of file diff --git a/examples/rdma_perftest/rdma_perftest_kernel.cpp b/examples/rdma_perftest/rdma_perftest_kernel.cpp index 3df0df732258053dba3fd030be9972ecb06d4f35..a1eefdf090e16f74b0580e3b8be3e7b91e41d44d 100644 --- a/examples/rdma_perftest/rdma_perftest_kernel.cpp +++ b/examples/rdma_perftest/rdma_perftest_kernel.cpp @@ -12,10 +12,10 @@ #include "shmem_api.h" constexpr uint32_t MAGIC_VAL = 10; -constexpr uint32_t WARMUP_MSG_LEN = 32; +constexpr uint32_t WARMUP_MESSAGE_LENGTH = 32; -extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64_t cfg, GM_ADDR gva, int msg_len) { - shmemx_set_ffts_config(cfg); +extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64_t fftsConfig, GM_ADDR gva, int message_length) { + shmemx_set_ffts_config(fftsConfig); if (AscendC::GetSubBlockIdx() != 0) { return; } @@ -24,57 +24,56 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_pingpong_latency(uint64 pipe.InitBuffer(buf, UB_ALIGN_SIZE); AscendC::LocalTensor ubLocalRead = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Warm up - GM_ADDR warm_addr = gva + rank_size * msg_len + WARMUP_MSG_LEN * (rank + 1); + GM_ADDR warm_addr = gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (rank + 1); if (rank == 0) { peer = 1; - shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MSG_LEN, peer); - while (*(__gm__ uint32_t*)(gva + rank_size * msg_len + WARMUP_MSG_LEN * (peer + 1)) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + rank_size * msg_len + WARMUP_MSG_LEN * (peer + 1), sizeof(uint32_t)); + shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MESSAGE_LENGTH, peer); + while (*(__gm__ uint32_t*)(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1)) != peer + MAGIC_VAL) { + dcci_cachelines(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); AscendC::GetSystemCycle(); } } else { peer = 0; - while (*(__gm__ uint32_t*)(gva + rank_size * msg_len + WARMUP_MSG_LEN * (peer + 1)) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + rank_size * msg_len + WARMUP_MSG_LEN * (peer + 1), sizeof(uint32_t)); + while (*(__gm__ uint32_t*)(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1)) != peer + MAGIC_VAL) { + dcci_cachelines(gva + rank_size * message_length + WARMUP_MESSAGE_LENGTH * (peer + 1), sizeof(uint32_t)); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MSG_LEN, peer); + shmem_put_uint8_mem_nbi(warm_addr, warm_addr, WARMUP_MESSAGE_LENGTH, peer); } AscendC::PipeBarrier(); // Actual test - GM_ADDR src_addr = gva + rank * msg_len; + GM_ADDR src_addr = gva + rank * message_length; if (rank == 0) { peer = 1; int64_t start = AscendC::GetSystemCycle(); - shmem_put_uint8_mem_nbi(src_addr, src_addr, msg_len, peer); - while (*(__gm__ uint32_t*)(gva + msg_len * 2 - 8) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + msg_len * 2 - 8, 8); + shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); + while (*(__gm__ uint32_t*)(gva + message_length * 2 - 8) != peer + MAGIC_VAL) { + dcci_cachelines(gva + message_length * 2 - 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); int64_t end = AscendC::GetSystemCycle(); - *(__gm__ int64_t*)(gva + msg_len * 2) = end - start; + *(__gm__ int64_t*)(gva + message_length * 2) = end - start; } else { peer = 0; - while (*(__gm__ uint32_t*)(gva + msg_len * 1 - 8) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + msg_len * 1 - 8, 8); + while (*(__gm__ uint32_t*)(gva + message_length * 1 - 8) != peer + MAGIC_VAL) { + dcci_cachelines(gva + message_length * 1 - 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - shmem_put_uint8_mem_nbi(src_addr, src_addr, msg_len, peer); + shmem_put_uint8_mem_nbi(src_addr, src_addr, message_length, peer); } } -void rdma_highlevel_put_pingpong_latency_do(uint32_t block_dim, void* stream, uint64_t cfg, uint8_t* gva, int len) -{ - rdma_highlevel_put_pingpong_latency<<<1, nullptr, stream>>>(cfg, gva, len); +void rdma_highlevel_put_pingpong_latency_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { + rdma_highlevel_put_pingpong_latency<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); } extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM_ADDR gva, int message_length) { @@ -86,19 +85,18 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM AscendC::TBuf buf; pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = - buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Actual test GM_ADDR src_addr = gva + rank * message_length; - + if (rank == 0) { peer = 1; - GM_ADDR dest_addr = (GM_ADDR)(shmem_ptr(src_addr, peer)); + GM_ADDR dest_addr = (GM_ADDR)(shmem_roce_ptr(src_addr, peer)); int64_t start = AscendC::GetSystemCycle(); for (uint32_t i = 0; i < 500; i++) { shmemi_roce_write(dest_addr, src_addr, peer, 0, message_length, ubLocal64, ubLocal32); @@ -109,8 +107,7 @@ extern "C" __global__ __aicore__ void rdma_postsend_cost(uint64_t fftsConfig, GM } } -void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) -{ +void rdma_postsend_cost_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { rdma_postsend_cost<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); } @@ -123,11 +120,10 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, AscendC::TBuf buf; pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); AscendC::LocalTensor ubLocal32 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint32_t), 0); - AscendC::LocalTensor ubLocal64 = - buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); + AscendC::LocalTensor ubLocal64 = buf.GetWithOffset(UB_ALIGN_SIZE / sizeof(uint64_t), UB_ALIGN_SIZE); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Actual test @@ -141,7 +137,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); shmem_put_uint8_mem_nbi(gva + rank_size * message_length + 8, src_addr, sizeof(uint32_t), peer); while (*(__gm__ uint32_t*)(gva + message_length * rank_size + 16) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + message_length * rank_size + 16, 8); + dcci_cachelines(gva + message_length * rank_size + 16, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -150,7 +146,7 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, } else { peer = 0; while (*(__gm__ uint32_t*)(gva + rank_size * message_length + 8) != peer + MAGIC_VAL) { - cacheWriteThrough(gva + rank_size * message_length + 8, 8); + dcci_cachelines(gva + rank_size * message_length + 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -158,13 +154,12 @@ extern "C" __global__ __aicore__ void rdma_highlevel_put_bw(uint64_t fftsConfig, } } -void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) -{ +void rdma_highlevel_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length) { rdma_highlevel_put_bw<<<1, nullptr, stream>>>(fftsConfig, gva, message_length); } -extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t cfg, GM_ADDR gva, int message_length, int64_t iter) { - shmemx_set_ffts_config(cfg); +extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t fftsConfig, GM_ADDR gva, int message_length, int64_t iter) { + shmemx_set_ffts_config(fftsConfig); AscendC::LocalTensor ubLocal32; ubLocal32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ubLocal32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); @@ -174,8 +169,8 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t cfg, GM_ADDR gva, ubLocal64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); ubLocal64.address_.dataLen = UB_ALIGN_SIZE; - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); uint32_t peer; // Core 0, RDMA @@ -185,14 +180,12 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t cfg, GM_ADDR gva, peer = 1; int64_t start = AscendC::GetSystemCycle(); for (int i = 0; i < 10000; i++) { - shmemi_roce_write((GM_ADDR)shmem_ptr(src_addr, peer), src_addr, peer, 0, - message_length, ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_roce_ptr(src_addr, peer), src_addr, peer, 0, message_length, ubLocal64, ubLocal32); } shmemi_roce_quiet(peer, 0, ubLocal64, ubLocal32); - shmemi_roce_write((GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 8, peer), - src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_roce_ptr(gva + rank_size * message_length * 2 + 8, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); while (*(__gm__ int64_t*)(gva + message_length * rank_size * 2 + 16) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + message_length * rank_size * 2 + 16, 8); + dcci_cachelines(gva + message_length * rank_size * 2 + 16, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -201,12 +194,11 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t cfg, GM_ADDR gva, } else { peer = 0; while (*(__gm__ int64_t*)(gva + rank_size * message_length * 2 + 8) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + rank_size * message_length * 2 + 8, 8); + dcci_cachelines(gva + rank_size * message_length * 2 + 8, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - shmemi_roce_write((GM_ADDR)shmem_ptr(gva + rank_size * message_length * 2 + 16, peer), - src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); + shmemi_roce_write((GM_ADDR)shmem_roce_ptr(gva + rank_size * message_length * 2 + 16, peer), src_addr, peer, 0, sizeof(int64_t), ubLocal64, ubLocal32); } } else { // core 1, MTE GM_ADDR src_addr = gva + (rank + rank_size) * message_length; @@ -219,14 +211,12 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t cfg, GM_ADDR gva, peer = 1; int64_t start = AscendC::GetSystemCycle(); for (int i = 0; i < 10000; i++) { - shmem_mte_put_mem_nbi(src_addr, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), - copy_ub_size, message_length, peer, copy_event_id); + shmem_mte_put_mem_nbi(src_addr, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, message_length, peer, copy_event_id); } AscendC::PipeBarrier(); - shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 24, src_addr, - reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); + shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 24, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); while (*(__gm__ uint32_t*)(gva + message_length * rank_size * 2 + 32) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + message_length * rank_size * 2 + 32, 8); + dcci_cachelines(gva + message_length * rank_size * 2 + 32, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); @@ -235,17 +225,15 @@ extern "C" __global__ __aicore__ void rdma_mte_put_bw(uint64_t cfg, GM_ADDR gva, } else { peer = 0; while (*(__gm__ uint32_t*)(gva + rank_size * message_length * 2 + 24) != peer + MAGIC_VAL + iter) { - cacheWriteThrough(gva + rank_size * message_length * 2 + 24, 8); + dcci_cachelines(gva + rank_size * message_length * 2 + 24, 8); AscendC::GetSystemCycle(); } AscendC::PipeBarrier(); - shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 32, src_addr, - reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); + shmem_mte_put_mem_nbi(gva + rank_size * message_length * 2 + 32, src_addr, reinterpret_cast<__ubuf__ uint8_t*>(copy_ub), copy_ub_size, sizeof(uint32_t), peer, copy_event_id); } } } -void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int len, int64_t iter) -{ - rdma_mte_put_bw<<<2, nullptr, stream>>>(fftsConfig, gva, len, iter); +void rdma_mte_put_bw_do(uint32_t block_dim, void* stream, uint64_t fftsConfig, uint8_t* gva, int message_length, int64_t iter) { + rdma_mte_put_bw<<<2, nullptr, stream>>>(fftsConfig, gva, message_length, iter); } \ No newline at end of file diff --git a/examples/rdma_perftest/run.sh b/examples/rdma_perftest/run.sh index c418b0183740937c333e65f45432dc1fd87cca9b..b3b0abac2d5302854652ea6260e282becaf2cc5a 100644 --- a/examples/rdma_perftest/run.sh +++ b/examples/rdma_perftest/run.sh @@ -4,13 +4,15 @@ script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" project_root="$(cd ${script_dir}/../../ && pwd)" export PROJECT_ROOT=${project_root} export LD_LIBRARY_PATH=${PROJECT_ROOT}/build/lib:${PROJECT_ROOT}/src/memfabric_hybrid/output/smem/lib64/:${PROJECT_ROOT}/src/memfabric_hybrid/output/hybm/lib64/:$LD_LIBRARY_PATH + +export SHMEM_UID_SESSION_ID=127.0.0.1:8899 cd $PROJECT_ROOT pids=() -./build/bin/rdma_perftest 2 0 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 & # rank 0 +./build/bin/rdma_perftest 2 0 tcp://127.0.0.1:8899 2 0 0 highlevel_put_pingpong_latency 64 & # rank 0 pid=$! pids+=("$pid") -./build/bin/rdma_perftest 2 1 tcp://127.0.0.1:8765 2 0 0 highlevel_put_pingpong_latency 64 & # rank 1 +./build/bin/rdma_perftest 2 1 tcp://127.0.0.1:8899 2 0 0 highlevel_put_pingpong_latency 64 & # rank 1 pid=$! pids+=("$pid") @@ -23,4 +25,4 @@ for pid in ${pids[@]}; do fi echo "wait $pid finished" done -exit $ret \ No newline at end of file +exit $ret diff --git a/include/device/low_level/shmem_device_low_level_rma.h b/include/device/low_level/shmem_device_low_level_rma.h index 611667df38106bce801f516c2c74a6e0bb5113fa..0a5343c48068ac9a426426f73b27f81aef78adf2 100644 --- a/include/device/low_level/shmem_device_low_level_rma.h +++ b/include/device/low_level/shmem_device_low_level_rma.h @@ -34,14 +34,188 @@ SHMEM_DEVICE __gm__ void *shmem_ptr(__gm__ void *ptr, int pe) uint64_t offset = reinterpret_cast(ptr) - reinterpret_cast(device_state->heap_base); // Address translate - uint64_t remote_ptr = reinterpret_cast(device_state->p2p_heap_device_base[pe]) + offset; + uint64_t remote_ptr = reinterpret_cast(device_state->device_p2p_heap_base[pe]) + offset; return reinterpret_cast<__gm__ void *>(remote_ptr); } /** - * @brief Asynchronous interface. Copy contiguous data on symmetric memory from the specified - * PE to address on the local device. + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] Pointer on Symmetric memory of the destination data. + * @param srcUb [in] Pointer on local UB of the source data. + * @param elem_size [in] Byte Size of data in the destination and source arrays. + * @param toL2Cache [in] Enable L2Cache or not. False means disable L2Cache. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + uint32_t size, bool toL2Cache = true) +{ + ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); + + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); + if (!toL2Cache) { + gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + AscendC::DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] GlobalTensor on Symmetric memory of the destination data. + * @param srcUb [in] LocalTensor on local UB of the source data. + * @param size [in] Byte Size of data in the destination and source arrays. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, + const AscendC::LocalTensor &srcUb, uint32_t size) +{ + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + AscendC::DataCopyPad(dstGva, srcUb, dataCopyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] Pointer on Symmetric memory of the destination data. + * @param srcUb [in] Pointer on local UB of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(__gm__ T* dstGva, __ubuf__ T* srcUb, + AscendC::DataCopyExtParams ©Params) +{ + ASCENDC_ASSERT((dstGva != nullptr), "input gva is null"); + + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(srcUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGva)); + AscendC::DataCopyPad(gmTensor, ubTensor, copyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstGva [in] GlobalTensor on Symmetric memory of the destination data. + * @param srcUb [in] LocalTensor on local UB of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_ub2gm(const AscendC::GlobalTensor &dstGva, + const AscendC::LocalTensor &srcUb, AscendC::DataCopyExtParams ©Params) +{ + AscendC::DataCopyPad(dstGva, srcUb, copyParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on symmetric memory from the specified PE to local UB. + * + * @param dstUb [in] Pointer on local UB of the destination data. + * @param srcGva [in] Pointer on Symmetric memory of the source data. + * @param size [in] Byte Size of data in the destination and source arrays. + * @param toL2Cache [in] Enable L2Cache or not. False means disable L2Cache. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + uint32_t size, bool toL2Cache = true) +{ + ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); + if (!toL2Cache) { + gmTensor.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on symmetric memory from the specified PE to local UB. + * + * @param dstUb [in] LocalTensor on local UB of the destination data. + * @param srcGva [in] GlobalTensor on Symmetric memory of the source data. + * @param size [in] Byte Size of data in the destination and source arrays. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, + const AscendC::GlobalTensor &srcGva, uint32_t size) +{ + AscendC::DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(dstUb, srcGva, dataCopyParams, padParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstUb [in] Pointer on local UB of the destination data. + * @param srcGva [in] Pointer on Symmetric memory of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(__ubuf__ T* dstUb, __gm__ T* srcGva, + AscendC::DataCopyExtParams ©Params) +{ + ASCENDC_ASSERT((srcGva != nullptr), "input gva is null"); + AscendC::LocalTensor ubTensor; + AscendC::GlobalTensor gmTensor; + ubTensor.address_.logicPos = static_cast(AscendC::TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(dstUb); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGva)); + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(ubTensor, gmTensor, copyParams, padParams); +} + +/** + * @brief Simple Copy interface. Copy contiguous data on local UB memory to symmetric address on the specified PE. + * + * @param dstUb [in] LocalTensor on local UB of the destination data. + * @param srcGva [in] GlobalTensor on Symmetric memory of the source data. + * @param copyParams [in] Describe non-contiguous data copy. + */ +template +SHMEM_DEVICE void shmemi_copy_gm2ub(const AscendC::LocalTensor &dstUb, + const AscendC::GlobalTensor &srcGva, AscendC::DataCopyExtParams ©Params) +{ + AscendC::DataCopyPadExtParams padParams; + AscendC::DataCopyPad(dstUb, srcGva, copyParams, padParams); +} + +/** + * @brief Translate an local symmetric address to remote symmetric address on the specified PE used by RDMA. + * + * @param ptr [in] Symmetric address on local PE. + * @param pe [in] The number of the remote PE. + * @return A remote symmetric address on the specified PE that can be accessed using memory loads and stores. + */ +SHMEM_DEVICE __gm__ void *shmem_roce_ptr(__gm__ void *ptr, int pe) +{ + // Get Global State + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + + // Back to root address + uint64_t offset = reinterpret_cast(ptr) - reinterpret_cast(device_state->heap_base); + + // Address translate + uint64_t remote_ptr = reinterpret_cast(device_state->device_rdma_heap_base[pe]) + offset; + + return reinterpret_cast<__gm__ void *>(remote_ptr); +} + +/** + * @brief Asynchronous interface. Copy contiguous data on symmetric memory from the specified PE to address on the local device. * * @param dst [in] Pointer on local device of the destination data. * @param src [in] Pointer on Symmetric memory of the source data. @@ -66,20 +240,20 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size); + shmemi_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + i * repeat_elem, buf, block_size); + shmemi_copy_ub2gm(dst + i * repeat_elem, buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain); + shmemi_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain); + shmemi_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain); } } @@ -98,7 +272,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T template SHMEM_DEVICE void shmem_roce_get_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ T* buf, uint32_t elem_size, int pe) { - auto ptr = shmem_ptr(src, pe); + auto ptr = shmem_roce_ptr(src, pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf); @@ -145,7 +319,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); + shmemi_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -153,7 +327,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); } /** @@ -185,20 +359,20 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, remote_buff[i * repeat_elem], block_size); + shmemi_copy_gm2ub(buf, remote_buff[i * repeat_elem], block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst[i * repeat_elem], buf, block_size); + shmemi_copy_ub2gm(dst[i * repeat_elem], buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, remote_buff[repeat_times * repeat_elem], remain); + shmemi_copy_gm2ub(buf, remote_buff[repeat_times * repeat_elem], remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst[repeat_times * repeat_elem], buf, remain); + shmemi_copy_ub2gm(dst[repeat_times * repeat_elem], buf, remain); } } @@ -218,7 +392,7 @@ template SHMEM_DEVICE void shmem_roce_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, AscendC::LocalTensor buf, uint32_t elem_size, int pe) { - auto ptr = shmem_ptr((__gm__ void *)src.GetPhyAddr(), pe); + auto ptr = shmem_roce_ptr((__gm__ void *)src.GetPhyAddr(), pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()); @@ -257,7 +431,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(buf, remote_buff, data_copy_params_gm2ub); + shmemi_copy_gm2ub(buf, remote_buff, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -265,7 +439,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst, buf, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst, buf, data_copy_params_ub2gm); } /** @@ -294,20 +468,20 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, src + i * repeat_elem, block_size); + shmemi_copy_gm2ub(buf, src + i * repeat_elem, block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size); + shmemi_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain); + shmemi_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain); + shmemi_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain); } } @@ -325,7 +499,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T template SHMEM_DEVICE void shmem_roce_put_mem_nbi(__gm__ T* dst, __gm__ T* src, __ubuf__ T* buf, uint32_t elem_size, int pe) { - auto ptr = shmem_ptr(dst, pe); + auto ptr = shmem_roce_ptr(dst, pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf); @@ -371,7 +545,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); + shmemi_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -379,7 +553,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __gm__ T *src, __ubuf__ T AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); } /** @@ -410,20 +584,20 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G uint64_t repeat_elem = block_size / sizeof(T); uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, src[i * repeat_elem], block_size); + shmemi_copy_gm2ub(buf, src[i * repeat_elem], block_size); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_buff[i * repeat_elem], buf, block_size); + shmemi_copy_ub2gm(remote_buff[i * repeat_elem], buf, block_size); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, src[repeat_times * repeat_elem], remain); + shmemi_copy_gm2ub(buf, src[repeat_times * repeat_elem], remain); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_buff[repeat_times * repeat_elem], buf, remain); + shmemi_copy_ub2gm(remote_buff[repeat_times * repeat_elem], buf, remain); } } @@ -442,7 +616,7 @@ template SHMEM_DEVICE void shmem_roce_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::GlobalTensor src, AscendC::LocalTensor buf, uint32_t elem_size, int pe, AscendC::TEventID EVENT_ID) { - auto ptr = shmem_ptr((__gm__ void *)dst.GetPhyAddr(), pe); + auto ptr = shmem_roce_ptr((__gm__ void *)dst.GetPhyAddr(), pe); AscendC::LocalTensor ub_tensor_32; ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); ub_tensor_32.address_.bufferAddr = reinterpret_cast(buf.GetPhyAddr()); @@ -481,7 +655,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(buf, src, data_copy_params_gm2ub); + shmemi_copy_gm2ub(buf, src, data_copy_params_gm2ub); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); @@ -489,7 +663,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::G AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (ub_stride - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(remote_buff, buf, data_copy_params_ub2gm); + shmemi_copy_ub2gm(remote_buff, buf, data_copy_params_ub2gm); } /** @@ -516,7 +690,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__ubuf__ T *dst, __gm__ T *src, uint32_t __gm__ T *remote_ptr = reinterpret_cast<__gm__ T *>(ptr); - smem_shm_copy_gm2ub(dst, remote_ptr, elem_size * sizeof(T)); + shmemi_copy_gm2ub(dst, remote_ptr, elem_size * sizeof(T)); } /** @@ -545,7 +719,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::LocalTensor dst, AscendC::Gl AscendC::GlobalTensor remote_buff; remote_buff.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(ptr)); - smem_shm_copy_gm2ub(dst, remote_buff, elem_size * sizeof(T)); + shmemi_copy_gm2ub(dst, remote_buff, elem_size * sizeof(T)); } /** @@ -582,7 +756,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(__ubuf__ T *dst, __gm__ T *src, const no AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (copy_params.dst_ld - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); + shmemi_copy_gm2ub(ub_tensor, src_tensor, data_copy_params_gm2ub); } /** @@ -616,7 +790,7 @@ SHMEM_DEVICE void shmem_mte_get_mem_nbi(AscendC::LocalTensor dst, AscendC::Gl AscendC::DataCopyExtParams data_copy_params_gm2ub(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) * sizeof(T), (copy_params.dst_ld - copy_params.length) / ELE_NUM_PER_UNIT, 0); - smem_shm_copy_gm2ub(dst, remote_buff, data_copy_params_gm2ub); + shmemi_copy_gm2ub(dst, remote_buff, data_copy_params_gm2ub); } /** @@ -642,7 +816,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __ubuf__ T *src, uint32_t __gm__ T *remote_ptr = reinterpret_cast<__gm__ T *>(ptr); - smem_shm_copy_ub2gm(remote_ptr, src, elem_size * sizeof(T)); + shmemi_copy_ub2gm(remote_ptr, src, elem_size * sizeof(T)); } /** @@ -670,7 +844,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::L AscendC::GlobalTensor remote_buff; remote_buff.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(ptr)); - smem_shm_copy_ub2gm(remote_buff, src, elem_size * sizeof(T)); + shmemi_copy_ub2gm(remote_buff, src, elem_size * sizeof(T)); } /** @@ -708,7 +882,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(__gm__ T *dst, __ubuf__ T *src, const no AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); + shmemi_copy_ub2gm(dst_tensor, ub_tensor, data_copy_params_ub2gm); } /** @@ -742,7 +916,7 @@ SHMEM_DEVICE void shmem_mte_put_mem_nbi(AscendC::GlobalTensor dst, AscendC::L AscendC::DataCopyExtParams data_copy_params_ub2gm(copy_params.repeat, copy_params.length * sizeof(T), (copy_params.src_ld - copy_params.length) / ELE_NUM_PER_UNIT, (copy_params.dst_ld - copy_params.length) * sizeof(T), 0); - smem_shm_copy_ub2gm(remote_buff, src, data_copy_params_ub2gm); + shmemi_copy_ub2gm(remote_buff, src, data_copy_params_ub2gm); } #endif \ No newline at end of file diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index b5cc49a7e1e35d339eee1e2745831721a866259e..b02ae17a319d5531c864af74c8c3a9a275a4f181 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -102,6 +102,19 @@ struct SHMEMHybmDeviceMeta { uint64_t reserved[12]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE }; +SHMEM_DEVICE __gm__ SHMEMAIVRDMAInfo* shmemi_qp_info_fetch() +{ +#ifdef BACKEND_MF + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)( + SMEM_SHM_DEVICE_META_ADDR + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); +#else + __gm__ shmemi_device_host_state_t *device_state = shmemi_get_state(); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(device_state->qp_info); +#endif + return RDMAInfo; +} + SHMEM_DEVICE void shmemi_roce_poll_cq_update_info(AscendC::LocalTensor &ubLocal64, AscendC::LocalTensor &ubLocal32, uint32_t &curTail, uint32_t &rRankId, uint32_t &qpIdx); SHMEM_DEVICE void shmemi_rdma_post_send_update_info(AscendC::LocalTensor &ubLocal64, @@ -123,10 +136,8 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, if (idx == 0) { return 0; } + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); @@ -172,9 +183,8 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, SHMEM_DEVICE void shmemi_roce_poll_cq_update_info(AscendC::LocalTensor &ubLocal64, AscendC::LocalTensor &ubLocal32, uint32_t &curTail, uint32_t &remoteRankId, uint32_t &qpIdx) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); @@ -236,9 +246,8 @@ SHMEM_DEVICE void shmemi_rdma_post_send(__gm__ uint8_t* remoteAddr, __gm__ uint8 AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); @@ -377,9 +386,8 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + uint32_t qpNum = RDMAInfo->qpNum; __gm__ SHMEMWQCtx* qpCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (remoteRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); @@ -391,9 +399,8 @@ SHMEM_DEVICE void shmemi_roce_quiet(uint32_t remoteRankId, uint32_t qpIdx, SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRankId, uint32_t qpIdx) { - __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + - SMEM_SHM_DEVICE_GLOBAL_META_SIZE); - __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + *(__gm__ uint64_t*)(gva) = (uint64_t)RDMAInfo; uint32_t qpNum = RDMAInfo->qpNum; *(__gm__ uint64_t*)(gva + 8) = (uint64_t)qpNum; @@ -435,4 +442,88 @@ SHMEM_DEVICE void shmemi_roce_qpinfo_test(__gm__ uint8_t* gva, uint32_t destRank AscendC::PipeBarrier(); } +template +SHMEM_DEVICE void shmemi_roce_pollcq_test(__gm__ T* srcDmaAddr, __gm__ T* destDmaAddr, uint32_t destRankId, + uint32_t qpIdx, uint64_t messageLen, + AscendC::LocalTensor ubLocal64, + AscendC::LocalTensor ubLocal32, __gm__ uint8_t* gva) +{ + shmemi_rdma_post_send(destDmaAddr, srcDmaAddr, destRankId, qpIdx, SHMEMAIVOPCODE::OP_RDMA_WRITE, + messageLen, ubLocal64, ubLocal32); + uint32_t idx = 1; + __gm__ SHMEMAIVRDMAInfo* RDMAInfo = shmemi_qp_info_fetch(); + + uint32_t qpNum = RDMAInfo->qpNum; + __gm__ SHMEMCQCtx* cqCtxEntry = (__gm__ SHMEMCQCtx*)(RDMAInfo->scqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMCQCtx)); + *(__gm__ uint64_t*)(gva) = (uint64_t)cqCtxEntry; + auto cqBaseAddr = cqCtxEntry->bufAddr; + auto cqeSize = cqCtxEntry->cqeSize; + auto depth = cqCtxEntry->depth; + *(__gm__ uint64_t*)(gva + 8) = (uint64_t)cqBaseAddr; + *(__gm__ uint64_t*)(gva + 16) = (uint64_t)cqeSize; + *(__gm__ uint64_t*)(gva + 24) = (uint64_t)depth; + auto curHardwareTailAddr = cqCtxEntry->tailAddr; + *(__gm__ uint64_t*)(gva + 32) = (uint64_t)curHardwareTailAddr; + dcci_cachelines((__gm__ uint8_t*)curHardwareTailAddr, 8); + uint32_t curTail = *(__gm__ uint32_t*)(curHardwareTailAddr); + *(__gm__ uint64_t*)(gva + 40) = (uint64_t)curTail; + + AscendC::DataCopyExtParams copyParamsTail{1, 1 * sizeof(uint32_t), 0, 0, 0}; + + __gm__ SHMEMcqeCtx* cqeAddr = (__gm__ SHMEMcqeCtx*)(cqBaseAddr + cqeSize * (curTail & (depth - 1))); + uint32_t cqeByte4 = *(__gm__ uint32_t*)cqeAddr; + while (!(cqeByte4 & (1 << 7))) { + int64_t tmp = AscendC::GetSystemCycle(); + dcci_cachelines((__gm__ uint8_t*)cqeAddr, 32); + cqeByte4 = *(__gm__ uint32_t*)cqeAddr; + } + *(__gm__ uint64_t*)(gva + 56) = (uint64_t)(cqeAddr->byte4); + *(__gm__ uint64_t*)(gva + 64) = (uint64_t)(cqeAddr->immtdata); + *(__gm__ uint64_t*)(gva + 72) = (uint64_t)(cqeAddr->byte12); + *(__gm__ uint64_t*)(gva + 80) = (uint64_t)(cqeAddr->byte16); + *(__gm__ uint64_t*)(gva + 88) = (uint64_t)(cqeAddr->byteCnt); + *(__gm__ uint64_t*)(gva + 96) = (uint64_t)(cqeAddr->smac); + curTail++; + // Process each CQE, and update WQ tail + uint32_t wqn = cqeAddr->byte16 & 0xFFFFFF; + __gm__ SHMEMWQCtx* wqCtxEntry = (__gm__ SHMEMWQCtx*)(RDMAInfo->sqPtr + (destRankId * qpNum + qpIdx) * sizeof(SHMEMWQCtx)); + *(__gm__ uint64_t*)(gva + 104) = (uint64_t)(wqCtxEntry->wqn == wqn); + auto curWQTailAddr = wqCtxEntry->tailAddr; + dcci_cachelines((__gm__ uint8_t*)curWQTailAddr, 8); + uint32_t curWQTail = *(__gm__ uint32_t*)(curWQTailAddr); + ubLocal32.SetValue(0, curWQTail + 1); + AscendC::GlobalTensor WQTailGlobalTensor; + WQTailGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curWQTailAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(WQTailGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + dcci_cachelines((__gm__ uint8_t*)curWQTailAddr, 8); + + // Check CQE status + uint32_t status = (cqeAddr->byte4 >> 8) & 0xFF; + *(__gm__ uint64_t*)(gva + 112) = status; + if (status) { + return; + } + + // Update tail + ubLocal32.SetValue(0, (uint32_t)curTail); + AscendC::GlobalTensor TailGlobalTensor; + TailGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)curHardwareTailAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(TailGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + dcci_cachelines((__gm__ uint8_t*)curHardwareTailAddr, 8); + + // Ring CQ Doorbell + auto cqDBAddr = cqCtxEntry->dbAddr; + ubLocal32.SetValue(0, (uint32_t)(curTail & 0xFFFFFF)); + AscendC::GlobalTensor CQDBGlobalTensor; + CQDBGlobalTensor.SetGlobalBuffer((__gm__ uint32_t*)cqDBAddr); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(CQDBGlobalTensor, ubLocal32, copyParamsTail); + AscendC::PipeBarrier(); + dcci_cachelines((__gm__ uint8_t*)cqDBAddr, 8); +} + #endif // SHMEM_DEVICE_LOW_LEVEL_ROCE_H \ No newline at end of file diff --git a/include/device/low_level/shmemx_device_low_level_rma.h b/include/device/low_level/shmemx_device_low_level_rma.h index 95a86915e9a7c0ca83fae973788a106ed700f9ea..7d31cb3948da347d0fccef5e8a08af6bfc97f801 100644 --- a/include/device/low_level/shmemx_device_low_level_rma.h +++ b/include/device/low_level/shmemx_device_low_level_rma.h @@ -40,20 +40,20 @@ SHMEM_DEVICE void shmemx_mte_get_mem_nbi_low_level(__gm__ int8_t* dst, __gm__ in uint64_t repeat_elem = block_size; uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size, enable_L2); + shmemi_copy_gm2ub(buf, remote_ptr + i * repeat_elem, block_size, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + i * repeat_elem, buf, block_size, enable_L2); + shmemi_copy_ub2gm(dst + i * repeat_elem, buf, block_size, enable_L2); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain, enable_L2); + shmemi_copy_gm2ub(buf, remote_ptr + repeat_times * repeat_elem, remain, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain, enable_L2); + shmemi_copy_ub2gm(dst + repeat_times * repeat_elem, buf, remain, enable_L2); } } @@ -83,20 +83,20 @@ SHMEM_DEVICE void shmemx_mte_put_mem_nbi_low_level(__gm__ int8_t* dst, __gm__ in uint64_t repeat_elem = block_size; uint64_t loop_times = remain > 0 ? repeat_times + 1 : repeat_times; for (uint64_t i = 0; i < repeat_times; i++) { - smem_shm_copy_gm2ub(buf, src + i * repeat_elem, block_size, enable_L2); + shmemi_copy_gm2ub(buf, src + i * repeat_elem, block_size, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size, enable_L2); + shmemi_copy_ub2gm(remote_ptr + i * repeat_elem, buf, block_size, enable_L2); if (i != loop_times - 1) { // Last PIPE Sync Should be done outside AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); } } if (remain > 0) { - smem_shm_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain, enable_L2); + shmemi_copy_gm2ub(buf, src + repeat_times * repeat_elem, remain, enable_L2); AscendC::SetFlag(EVENT_ID); AscendC::WaitFlag(EVENT_ID); - smem_shm_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain, enable_L2); + shmemi_copy_ub2gm(remote_ptr + repeat_times * repeat_elem, buf, remain, enable_L2); } } diff --git a/include/device/shmem_device_atomic.h b/include/device/shmem_device_atomic.h index 3bdb6a9e8e05f76c63638fc8b8524bf93661aaa3..a0a1e5cc7d01c4deb62b76fbeface80e22ed7bf4 100644 --- a/include/device/shmem_device_atomic.h +++ b/include/device/shmem_device_atomic.h @@ -54,7 +54,7 @@ dcci_atomic(); \ dsb_all(); \ set_st_atomic_cfg(ATOMIC_TYPE, ATOMIC_SUM); \ - st_atomic(value, (__gm__ TYPE *)shmemi_ptr(dst, pe)); \ + st_atomic(value, (__gm__ TYPE *)shmem_ptr(dst, pe)); \ dcci_atomic(); \ } @@ -77,7 +77,7 @@ SHMEM_TYPE_FUNC_ATOMIC_INT(SHMEM_ATOMIC_ADD_TYPENAME); dcci_atomic(); \ dsb_all(); \ set_st_atomic_cfg(ATOMIC_TYPE, ATOMIC_SUM); \ - st_atomic(value, (__gm__ TYPE *)shmemi_ptr(dst, pe)); \ + st_atomic(value, (__gm__ TYPE *)shmem_ptr(dst, pe)); \ dcci_atomic(); \ } diff --git a/include/device/shmem_device_rma.h b/include/device/shmem_device_rma.h index 8299ee18029053809c2699bba3e1e27cbec68e85..e53d9a90febee3493f510ac69667b34f72a1a29a 100644 --- a/include/device/shmem_device_rma.h +++ b/include/device/shmem_device_rma.h @@ -257,7 +257,7 @@ SHMEM_DEVICE void shmem_getmem_nbi(__gm__ void *dst, __gm__ void *src, uint32_t copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr(src, pe); \ + auto ptr = shmem_roce_ptr(src, pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ @@ -334,7 +334,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_DETAILED_NBI); shmem_mte_get_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr((__gm__ void *)src.GetPhyAddr(), pe); \ + auto ptr = shmem_roce_ptr((__gm__ void *)src.GetPhyAddr(), pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ @@ -409,7 +409,7 @@ SHMEM_TYPE_FUNC(SHMEM_GET_TYPENAME_MEM_TENSOR_DETAILED_NBI); copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr(dst, pe); \ + auto ptr = shmem_roce_ptr(dst, pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ @@ -485,7 +485,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_DETAILED_NBI); shmem_mte_put_mem_nbi(dst, src, ub_tensor, elem_size, pe, copy_event_id); \ } else if (device_state->topo_list[pe] & SHMEM_TRANSPORT_ROCE) { \ /* RoCE */ \ - auto ptr = shmem_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \ + auto ptr = shmem_roce_ptr((__gm__ void *)dst.GetPhyAddr(), pe); \ if (ptr == nullptr) return; \ /* Create LocalTensor */ \ AscendC::LocalTensor ub_tensor_32; \ diff --git a/include/host/shmem_host_def.h b/include/host/shmem_host_def.h index 000766693e466c6e4c2380081e377283f5ff4cfa..a3a227f1d4180142b91c9da895a483ed8d0464bd 100644 --- a/include/host/shmem_host_def.h +++ b/include/host/shmem_host_def.h @@ -10,6 +10,7 @@ #ifndef SHMEM_HOST_DEF_H #define SHMEM_HOST_DEF_H #include +#include #include "host_device/shmem_types.h" #ifdef __cplusplus @@ -79,6 +80,17 @@ enum shmem_error_code_t : int { SHMEM_SMEM_ERROR = -3, ///< There is a problem with SMEM. SHMEM_INNER_ERROR = -4, ///< This is a problem caused by an internal error. SHMEM_NOT_INITED = -5, ///< This is a problem caused by an uninitialization. + SHMEM_BOOTSTRAP_ERROR = -6,///< This is a problem with BOOTSTRAP. + SHMEM_TIMEOUT_ERROR = -7, ///< This is a problem caused by TIMEOUT. +}; + +/** + * @brief init flags +*/ +enum shmemx_bootstrap_t : int { + SHMEMX_INIT_WITH_UNIQUEID = 1, + SHMEMX_INIT_WITH_MPI = 1 << 1, + SHMEMX_INIT_WITH_DEFAULT = 1 << 2, }; /** @@ -143,6 +155,7 @@ typedef struct { char ip_port[SHMEM_MAX_IP_PORT_LEN]; uint64_t local_mem_size; shmem_init_optional_attr_t option_attr; + void *comm_args; } shmem_init_attr_t; /** @@ -156,14 +169,14 @@ typedef struct { typedef int (*shmem_decrypt_handler)(const char *cipherText, size_t cipherTextLen, char *plainText, size_t &plainTextLen); -constexpr uint16_t SHMEM_UNIQUE_ID_INNER_LEN = 60; +constexpr uint16_t SHMEM_UNIQUE_ID_INNER_LEN = 124; typedef struct { int32_t version; char internal[SHMEM_UNIQUE_ID_INNER_LEN]; -} shmem_uniqueid_t; +} shmemx_uniqueid_t; -constexpr int32_t SHMEM_UNIQUEID_VERSION = (1 << 16) + sizeof(shmem_uniqueid_t); +constexpr int32_t SHMEM_UNIQUEID_VERSION = (1 << 16) + sizeof(shmemx_uniqueid_t); #define SHMEM_UNIQUEID_INITIALIZER \ { \ diff --git a/include/host/shmem_host_init.h b/include/host/shmem_host_init.h index 4e19de10572d7e7ec29acf8be4d5680304a12f39..ec29de84c3a3eaf4ff73f48a27f36a5c35bf3998 100644 --- a/include/host/shmem_host_init.h +++ b/include/host/shmem_host_init.h @@ -64,75 +64,22 @@ SHMEM_HOST_API int shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, */ SHMEM_HOST_API int shmem_set_timeout(shmem_init_attr_t *attributes, uint32_t value); -/** - * @brief get the unique id and return it by intput argument uid. This function need run with PTA. - * - * @param uid [out] a ptr to uid generate by shmem - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int shmem_get_uniqueid(shmem_uniqueid_t *uid); - -/** - * @brief init process with unique id. This function need run with PTA. - * - * @param rank_id [in] current rank id - * @param nranks [in] total ranks - * @param uid [in] a ptr to uid, generated by shmem_get_uniqueid - * @param attr [out] a ptr to shmem_init_attr_t - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int shmem_set_attr_uniqueid_args(int rank_id, int nranks, - const shmem_uniqueid_t *uid, shmem_init_attr_t *attr); +SHMEM_HOST_API int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid); +SHMEM_HOST_API int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const int64_t local_mem_size, + const shmemx_uniqueid_t *uid, + shmem_init_attr_t **shmem_attr); /** * @brief Initialize the resources required for SHMEM task based on attributes. * Attributes can be created by users or obtained by calling shmem_set_attr(). * if the self-created attr structure is incorrect, the initialization will fail. * It is recommended to build the attributes by shmem_set_attr(). * + * @param bootstrap_flags [in] bootstrap_flags for init. * @param attributes [in] Pointer to the user-defined attributes. * @return Returns 0 on success or an error code on failure */ -SHMEM_HOST_API int shmem_init_attr(shmem_init_attr_t *attributes); - -/** - * @brief Set the TLS private key and password, and register a decrypt key password handler. - * - * @param tls_pk the content of tls private key - * @param tls_pk_len length of tls private key - * @param tls_pk_pw the content of tls private key password - * @param tls_pk_pw_len length of tls private key password - * @param decrypt_handler decrypt function pointer - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_config_store_tls_key(const char *tls_pk, const uint32_t tls_pk_len, - const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const shmem_decrypt_handler decrypt_handler); - -/** - * @brief Set the log print function for the SHMEM library. - * - * @param func the logging function, takes level and msg as parameter - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)); - -/** - * @brief Set the logging level. - * - * @param level the logging level. 0-debug, 1-info, 2-warn, 3-error - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_log_level(int level); - -/** - * @brief Initialize the config store tls info. - * - * @param enable whether to enable tls - * @param tls_info the format describle in memfabric SECURITYNOTE.md, if disabled tls_info won't be use - * @param tls_info_len length of tls_info, if disabled tls_info_len won't be use - * @return Returns 0 on success or an error code on failure - */ -SHMEM_HOST_API int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len); +SHMEM_HOST_API int shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes); /** * @brief Release all resources used by the SHMEM library. @@ -158,11 +105,22 @@ SHMEM_HOST_API void shmem_info_get_version(int *major, int *minor); SHMEM_HOST_API void shmem_info_get_name(char *name); /** - * @brief exit all ranks. + * @brief Set the logging level. * - * @param status [IN] name + * @param level the logging level. 0-debug, 1-info, 2-warn, 3-error + * @return Returns 0 on success or an error code on failure */ +SHMEM_HOST_API int32_t shmem_set_log_level(int level); + +SHMEM_HOST_API int32_t shmem_set_config_store_tls_key(const char *tls_pk, const uint32_t tls_pk_len, + const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const shmem_decrypt_handler decrypt_handler); + +SHMEM_HOST_API int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)); + SHMEM_HOST_API void shmem_global_exit(int status); + +SHMEM_HOST_API int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len); + #ifdef __cplusplus } #endif diff --git a/include/internal/device/shmemi_device_common.h b/include/internal/device/shmemi_device_common.h index 6cdb71e95315c8fc6a6a559f72a3487d64759af0..5ea1906f668b255cbd9a97d2eee29287fc8b8da6 100644 --- a/include/internal/device/shmemi_device_common.h +++ b/include/internal/device/shmemi_device_common.h @@ -13,14 +13,32 @@ #include "shmemi_device_arch.h" #include "shmemi_device_def.h" -#include "smem_shm_aicore_base_api.h" - constexpr int ub_limit = 192 * 1024; -SHMEM_DEVICE __gm__ shmemi_device_host_state_t *shmemi_get_state() -{ +#ifdef BACKEND_MF +#include "smem_shm_aicore_base_api.h" + +SHMEM_DEVICE __gm__ shmemi_device_host_state_t *shmemi_get_state() { return reinterpret_cast<__gm__ shmemi_device_host_state_t *>(smem_shm_get_extra_context_addr()); } +#else + +// rdma +constexpr uint64_t SMEM_SHM_DEVICE_PRE_META_SIZE = 128UL; // 128B +constexpr uint64_t SMEM_SHM_DEVICE_GLOBAL_META_SIZE = SMEM_SHM_DEVICE_PRE_META_SIZE; // 128B +constexpr uint64_t SMEM_OBJECT_NUM_MAX = 511UL; // entity最大数量 +constexpr uint64_t SMEM_SHM_DEVICE_META_SIZE = SMEM_SHM_DEVICE_PRE_META_SIZE * SMEM_OBJECT_NUM_MAX + + SMEM_SHM_DEVICE_GLOBAL_META_SIZE; // 64K + +constexpr uint64_t SMEM_SHM_DEVICE_USER_CONTEXT_PRE_SIZE = 64UL * 1024UL; // 64K +constexpr uint64_t SMEM_SHM_DEVICE_INFO_SIZE = SMEM_SHM_DEVICE_USER_CONTEXT_PRE_SIZE * SMEM_OBJECT_NUM_MAX + + SMEM_SHM_DEVICE_META_SIZE; // 元数据+用户context,总大小32M, 对齐2M +constexpr uint64_t SMEM_SHM_DEVICE_META_ADDR = SVM_END_ADDR - SMEM_SHM_DEVICE_INFO_SIZE; + +SHMEM_DEVICE __gm__ shmemi_device_host_state_t *shmemi_get_state() { + return reinterpret_cast<__gm__ shmemi_device_host_state_t *>((__gm__ void*)(SVM_END_ADDR - GLOBAL_STATE_SIZE)); +} +#endif SHMEM_DEVICE int shmemi_get_my_pe() { @@ -48,14 +66,4 @@ SHMEM_DEVICE T shmemi_load(__gm__ T *cache) { return *((__gm__ T *)cache); } - -template -SHMEM_DEVICE __gm__ T *shmemi_ptr(__gm__ T *local, int pe) -{ - uint64_t shm_size = shmemi_get_heap_size(); - int my_pe = shmemi_get_my_pe(); - - uint64_t remote = reinterpret_cast(local) + shm_size * (pe - my_pe); - return reinterpret_cast<__gm__ T*>(remote); -} #endif diff --git a/include/internal/device/sync/shmemi_device_barrier.h b/include/internal/device/sync/shmemi_device_barrier.h index 2b9caa629654459ac8b70ded4d752929b8b8b42c..2cae28247e86d79c66c7c3c20f58760b23ca8eb8 100644 --- a/include/internal/device/sync/shmemi_device_barrier.h +++ b/include/internal/device/sync/shmemi_device_barrier.h @@ -282,7 +282,7 @@ SHMEM_DEVICE void shmemi_barrier_npu_v3(shmemi_team_t *team) } else { // read remote int remote_pe = start + i * stride; - shmemi_signal_wait_until_eq_for_barrier((__gm__ int32_t *)shmemi_ptr(sync_array, remote_pe), count); + shmemi_signal_wait_until_eq_for_barrier((__gm__ int32_t *)shmem_ptr(sync_array, remote_pe), count); } } diff --git a/include/internal/device/sync/shmemi_device_p2p.h b/include/internal/device/sync/shmemi_device_p2p.h index 5d04477a6f75927d013424c7447624db7f3d8bd2..de0de0cf783ca5afac229fbb4aef5365c2670fe3 100644 --- a/include/internal/device/sync/shmemi_device_p2p.h +++ b/include/internal/device/sync/shmemi_device_p2p.h @@ -22,7 +22,7 @@ SHMEM_DEVICE void shmemi_signal_set(__gm__ int32_t *addr, int32_t val) SHMEM_DEVICE void shmemi_signal_set(__gm__ int32_t *addr, int pe, int32_t val) { - shmemi_signal_set(shmemi_ptr(addr, pe), val); + shmemi_signal_set((__gm__ int32_t *)shmem_ptr(addr, pe), val); } SHMEM_DEVICE void shmemi_highlevel_signal_set(__gm__ int32_t *dst, __gm__ int32_t *src, int pe) @@ -36,7 +36,7 @@ SHMEM_DEVICE void shmemi_highlevel_signal_set(__gm__ int32_t *dst, __gm__ int32_ ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - shmemi_roce_write((__gm__ uint8_t*)shmem_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), + shmemi_roce_write((__gm__ uint8_t*)shmem_roce_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), ub_tensor_64, ub_tensor_32); shmemi_roce_quiet(pe, 0, ub_tensor_64, ub_tensor_32); } @@ -49,7 +49,7 @@ SHMEM_DEVICE void shmemi_signal_add(__gm__ int32_t *addr, int pe, int32_t val) // atomic add set_st_atomic_cfg(ATOMIC_S32, ATOMIC_SUM); - st_atomic(val, shmemi_ptr(addr, pe)); + st_atomic(val, (__gm__ int32_t *)shmem_ptr(addr, pe)); dcci_atomic(); } diff --git a/include/internal/device/sync/shmemi_device_partial_barrier.h b/include/internal/device/sync/shmemi_device_partial_barrier.h index 555f59eb62e1b36ab96c906798c4d0e1c7ba39d2..970945ed75b79930eb9d17ac368ddc0e85ae5259 100644 --- a/include/internal/device/sync/shmemi_device_partial_barrier.h +++ b/include/internal/device/sync/shmemi_device_partial_barrier.h @@ -68,7 +68,7 @@ SHMEM_DEVICE void shmemi_partial_barrier_npu_v3(shmemi_team_t *team, shmemi_signal_set(slot_base, 1); } else { shmemi_signal_wait_until_eq_for_barrier( - (__gm__ int32_t *)shmemi_ptr(slot_base, (int)(remote_pe * stride + start)), 1); + (__gm__ int32_t *)shmem_ptr(slot_base, (int)(remote_pe * stride + start)), 1); } } } diff --git a/include/internal/host/shmemi_host_def.h b/include/internal/host/shmemi_host_def.h index 9d0eaec946ab8c1f9a9528d5c535b54b21b6229d..22adde7e9be2ef364ab0ee4d42bdd63ed8a8e283 100644 --- a/include/internal/host/shmemi_host_def.h +++ b/include/internal/host/shmemi_host_def.h @@ -17,21 +17,24 @@ typedef enum { ADDR_IPv4, ADDR_IPv6 -} shmem_addr_type_t; +} addr_type_t; +// shmem unique id typedef struct { union { - struct sockaddr_in addr4; - struct sockaddr_in6 addr6; + struct sockaddr sa; + struct sockaddr_in addr4; // IPv4地址(含端口) + struct sockaddr_in6 addr6; // IPv6地址(含端口) } addr; - shmem_addr_type_t type; -} shmem_sockaddr_t; + addr_type_t type; +} sockaddr_t; typedef struct { int32_t version; - int32_t inner_sockFd; - shmem_sockaddr_t addr; + int32_t inner_sockFd; // for mf backend + sockaddr_t addr; // 动态传入的地址(含端口) uint64_t magic; -} shmem_uniqueid_inner_t; - + int rank; + int nranks; +} shmemx_bootstrap_uid_state_t; #endif // SHMEMI_HOST_DEF_H \ No newline at end of file diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 141531aa661eaacdf167fbcab67eeaa9d0774e8c..e033eae7666b5b51f6653f55c418d656c9185a31 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -52,6 +52,11 @@ extern "C" { #define SHMEM_EXTRA_SIZE_UNALIGHED (SYNC_POOL_SIZE + SHMEM_PARTIAL_BARRIER_POOL_SIZE) #define SHMEM_EXTRA_SIZE ALIGH_TO(SHMEM_EXTRA_SIZE_UNALIGHED, SHMEM_PAGE_SIZE) +// global_state +constexpr uint64_t DEVMM_SVM_MEM_START = 0x100000000000ULL; +constexpr uint64_t SVM_END_ADDR = 0x100000000000ULL + 0x80000000000ULL - (1UL << 30UL); // svm end +constexpr uint64_t GLOBAL_STATE_SIZE = 4UL * 1024UL * 1024UL; // global_state fixed length + // synchronization typedef int32_t shmemi_sync_bit[SHMEMI_SYNCBIT_SIZE / sizeof(int32_t)]; @@ -78,12 +83,16 @@ typedef struct { int npes; void *heap_base; - void **p2p_heap_host_base; - void **sdma_heap_host_base; - void **roce_heap_host_base; - void **p2p_heap_device_base; - void **sdma_heap_device_base; - void **roce_heap_device_base; + // Store All Devices' heap_base in Host. + void **host_p2p_heap_base; + void **host_rdma_heap_base; + void **host_sdma_heap_base; + + // Store All Devices' heap_base in Device. + void **device_p2p_heap_base; + void **device_rdma_heap_base; + void **device_sdma_heap_base; + uint8_t topo_list[SHMEM_MAX_RANKS]; size_t heap_size; @@ -104,6 +113,7 @@ typedef struct { bool is_shmem_created; shmemi_mte_config_t mte_config; + uint64_t qp_info; } shmemi_device_host_state_t; // host only state diff --git a/scripts/build.sh b/scripts/build.sh index 7bc8555c223422fb8088856e49eb0797479f4683..4aaf124c48eb4c8f5ff585688f6bdbd8b30fa406 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -316,6 +316,10 @@ while [[ $# -gt 0 ]]; do COMPILE_OPTIONS="${COMPILE_OPTIONS} -DUSE_EXAMPLES=ON" shift ;; + -enable_rdma) + COMPILE_OPTIONS="${COMPILE_OPTIONS} -DSHMEM_RDMA_SUPPORT=ON" + shift + ;; -python_extension) PYEXPAND_TYPE=ON shift diff --git a/scripts/run.sh b/scripts/run.sh index 7db18eaba9009dc845f1a068f245b975293189ac..8372b582059bb554b57016655c4c4fc5779849be 100644 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -23,6 +23,7 @@ rm -rf "$COVERAGE_PATH" set -e RANK_SIZE="8" IPPORT="tcp://127.0.0.1:8666" +SESSION_ID="127.0.0.1:8766" GNPU_NUM="8" FIRST_NPU="0" FIRST_RANK="0" @@ -67,7 +68,9 @@ while [[ $# -gt 0 ]]; do -ipport) if [ -n "$2" ]; then if [[ "$2" =~ ^[a-zA-z0-9.:/_-]+$ ]]; then - IPPORT="$2" + IPPORT="tcp://${2}" + SESSION_ID="${2}" + export SHMEM_UID_SESSION_ID=$SESSION_ID shift 2 else echo "Error: Invalid -ipport format, only alphanumeric and :/_- allowed" diff --git a/scripts/set_env.sh b/scripts/set_env.sh index 546f467f162b4e13fd35fabd6d2bc66fd112e25a..f8cd6e4e38f9c5a6f5e07d7e503ced704cd7842a 100644 --- a/scripts/set_env.sh +++ b/scripts/set_env.sh @@ -13,6 +13,7 @@ if [[ -f "$set_env_path" ]] && [[ "$(basename "$set_env_path")" == "set_env.sh" shmem_path=$(cd $(dirname $set_env_path); pwd) export SHMEM_HOME_PATH="$shmem_path" export LD_LIBRARY_PATH=$SHMEM_HOME_PATH/shmem/lib:$SHMEM_HOME_PATH/memfabric_hybrid/lib:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:$LD_LIBRARY_PATH export PATH=$SHMEM_HOME_PATH/bin:$PATH fi # 是否有python扩展 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 77bdee943ed61d27399cc0d527f953751584b3f7..424eec30ee03dd364c702d778b59d44a5067326a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,23 +10,34 @@ if (BUILD_PYTHON STREQUAL "ON") add_subdirectory(host/python_wrapper) endif () +set(SHMEM_MPI_SUPPORT OFF) + +if(SHMEM_MPI_SUPPORT) + find_package(MPI REQUIRED) +else() + find_package(MPI) + if(MPI_FOUND) + set(SHMEM_MPI_SUPPORT ON) + endif() +endif() + file(GLOB_RECURSE SHMEM_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) add_library(shmem_device OBJECT ${SHMEM_KERNEL_FILES}) target_compile_options(shmem_device PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=dav-c220) target_include_directories(shmem_device PUBLIC ${PROJECT_SOURCE_DIR}/include/ - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/device/ ) file(GLOB_RECURSE SHMEM_HOST_FILES ${CMAKE_CURRENT_SOURCE_DIR}/host/*.cpp) list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "python_wrapper") +list(FILTER SHMEM_HOST_FILES EXCLUDE REGEX "modules") + add_library(shmem_host OBJECT ${SHMEM_HOST_FILES}) target_compile_options(shmem_host PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) target_include_directories(shmem_host PUBLIC ${PROJECT_SOURCE_DIR}/include/ - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/include/smem/host/ ${PROJECT_SOURCE_DIR}/src/host ${PROJECT_SOURCE_DIR}/src/device ) @@ -34,11 +45,82 @@ target_include_directories(shmem_host add_library(shmem SHARED $ $) target_link_options(shmem PRIVATE --cce-fatobj-link) -target_link_libraries(shmem - PUBLIC - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_smem.so - ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib/libmf_hybm_core.so +set(SHMEM_MTE_SUPPORT ON) +if(SHMEM_MTE_SUPPORT) + add_library(shmem_transport_mte SHARED) + + target_sources(shmem_transport_mte PRIVATE + modules/transport/shmemi_mte.cpp) + + target_include_directories(shmem_transport_mte PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host) + + set_target_properties(shmem_transport_mte PROPERTIES PREFIX "") + + install(TARGETS shmem_transport_mte + LIBRARY DESTINATION lib) +endif() + +# MPI +if(SHMEM_MPI_SUPPORT) + separate_arguments(SHMEM_CXX_LINK_FLAGS NATIVE_COMMAND "${MPI_CXX_LINK_FLAGS}") + target_link_options(shmem INTERFACE ${SHMEM_CXX_LINK_FLAGS}) + target_compile_definitions(shmem INTERFACE ${MPI_CXX_COMPILE_DEFINITIONS}) + target_compile_options(shmem INTERFACE ${MPI_CXX_COMPILE_OPTIONS}) + + add_library( + shmem_bootstrap_mpi SHARED + ) + target_sources(shmem_bootstrap_mpi PRIVATE modules/bootstrap/shmemi_bootstrap_mpi.cpp) + target_link_libraries(shmem_bootstrap_mpi PRIVATE MPI::MPI_CXX) + target_include_directories(shmem_bootstrap_mpi + PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host + ) + set_target_properties(shmem_bootstrap_mpi PROPERTIES PREFIX "") + install(TARGETS shmem_bootstrap_mpi + LIBRARY DESTINATION lib + ) + +endif() +# UID +add_library(shmem_bootstrap_uid SHARED) + +target_sources(shmem_bootstrap_uid PRIVATE modules/bootstrap/socket/uid_socket.cpp + modules/bootstrap/shmemi_bootstrap_uid.cpp) +target_include_directories(shmem_bootstrap_uid + PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host ) +set_target_properties(shmem_bootstrap_uid PROPERTIES PREFIX "") +install(TARGETS shmem_bootstrap_uid + LIBRARY DESTINATION lib) + + +if(SHMEM_RDMA_SUPPORT) + add_library( + shmem_transport_rdma SHARED + ) + target_sources(shmem_transport_rdma PRIVATE + modules/transport/shmemi_rdma.cpp + modules/transport/rdma/device_qp_manager.cpp + modules/transport/rdma/dl_hccp_api.cpp + ) + target_link_libraries(shmem_transport_rdma PRIVATE MPI::MPI_CXX) + target_include_directories(shmem_transport_rdma + PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/host + ${PROJECT_SOURCE_DIR}/src/modules + ) + set_target_properties(shmem_transport_rdma PROPERTIES PREFIX "") + install(TARGETS shmem_transport_rdma + LIBRARY DESTINATION lib + ) +endif() # 安装配置 install(TARGETS shmem diff --git a/src/host/bootstrap/shmemi_bootstrap.cpp b/src/host/bootstrap/shmemi_bootstrap.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e5ed2469f5fbf2d4240062464d42fbf46c68799 --- /dev/null +++ b/src/host/bootstrap/shmemi_bootstrap.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "shmemi_host_common.h" +#include "dlfcn.h" + +#define BOOTSTRAP_MODULE_MPI "shmem_bootstrap_mpi.so" +#define BOOTSTRAP_MODULE_UID "shmem_bootstrap_uid.so" + +#define BOOTSTRAP_PLUGIN_INIT_FUNC "shmemi_bootstrap_plugin_init" +#define BOOTSTRAP_PLUGIN_PREINIT_FUNC "shmemi_bootstrap_plugin_pre_init" + +shmemi_bootstrap_handle_t g_boot_handle; + +static void *plugin_hdl = nullptr; +static char *plugin_name = nullptr; + +int bootstrap_loader_finalize(shmemi_bootstrap_handle_t *handle) +{ + int status = handle->finalize(handle); + + if (status != 0) + SHM_LOG_ERROR("Bootstrap plugin finalize failed for " << plugin_name); + + dlclose(plugin_hdl); + plugin_hdl = nullptr; + free(plugin_name); + plugin_name = nullptr; + + return 0; +} + + + +void shmemi_bootstrap_loader() +{ + dlerror(); + if (plugin_hdl == nullptr) { + + plugin_hdl = dlopen(plugin_name, RTLD_NOW); + } + dlerror(); +} + +void shmemi_bootstrap_free() +{ + if (plugin_hdl != nullptr) { + dlclose(plugin_hdl); + plugin_hdl = nullptr; + } + + if (plugin_name != nullptr) { + free(plugin_name); + plugin_name = nullptr; + } +} + +// rank0 requires preloading uid.so to obtain the getuid capability +int32_t shmemi_bootstrap_pre_init(int flags, shmemi_bootstrap_handle_t *handle) { + int32_t status = SHMEM_SUCCESS; + + if (flags & SHMEMX_INIT_WITH_MPI) { + SHM_LOG_ERROR("Unsupport Type for bootstrap preinit."); + return SHMEM_INVALID_PARAM; + } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { + plugin_name = BOOTSTRAP_MODULE_UID; + } else { + SHM_LOG_ERROR("Unknown Type for bootstrap"); + status = SHMEM_INVALID_PARAM; + } + shmemi_bootstrap_loader(); + + if (!plugin_hdl) { + SHM_LOG_ERROR("Bootstrap unable to load " << plugin_name << ", err is: " << stderr); + shmemi_bootstrap_free(); + return SHMEM_INVALID_VALUE; + } + int (*plugin_pre_init)(shmemi_bootstrap_handle_t *); + *((void **)&plugin_pre_init) = dlsym(plugin_hdl, BOOTSTRAP_PLUGIN_PREINIT_FUNC); + if (!plugin_pre_init) { + SHM_LOG_ERROR("Bootstrap plugin init func dlsym failed"); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + status = plugin_pre_init(&g_boot_handle); + if (status != 0) { + SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin_name); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + return status; +} + +void remove_tcp_prefix_and_copy(const char* input, char* output, size_t output_len) { + memset(output, 0, output_len); + if (output_len == 0) return; + + if (input == nullptr || strlen(input) == 0) { + return; + } + + const char* prefix_tcp = "tcp://"; + const char* prefix_tcp6 = "tcp6://"; + size_t len_tcp = strlen(prefix_tcp); + size_t len_tcp6 = strlen(prefix_tcp6); + const char* result_ptr = input; + + if (strncmp(input, prefix_tcp, len_tcp) == 0) { + result_ptr = input + len_tcp; + } + else if (strncmp(input, prefix_tcp6, len_tcp6) == 0) { + result_ptr = input + len_tcp6; + } + + strncpy(output, result_ptr, output_len - 1); + output[output_len - 1] = '\0'; +} + +int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr) { + int32_t status = SHMEM_SUCCESS; + void *arg; + g_boot_handle.use_attr_ipport= false; + if (flags & SHMEMX_INIT_WITH_DEFAULT){ + SHM_LOG_INFO("SHMEMX_INIT_WITH_DEFAULT"); + g_boot_handle.use_attr_ipport= true; + remove_tcp_prefix_and_copy(attr->ip_port, + g_boot_handle.ipport, + sizeof(g_boot_handle.ipport)); + plugin_name = BOOTSTRAP_MODULE_UID; + arg = (attr != NULL) ? attr->comm_args : NULL; + } else if (flags & SHMEMX_INIT_WITH_MPI) { + SHM_LOG_INFO("SHMEMX_INIT_WITH_MPI"); + plugin_name = BOOTSTRAP_MODULE_MPI; + arg = (attr != NULL) ? attr->comm_args : NULL; + } else if (flags & SHMEMX_INIT_WITH_UNIQUEID) { + SHM_LOG_INFO("SHMEMX_INIT_WITH_UNIQUEID"); + plugin_name = BOOTSTRAP_MODULE_UID; + arg = (attr != NULL) ? attr->comm_args : NULL; + } else { + SHM_LOG_ERROR("Unknown Type for bootstrap"); + status = SHMEM_INVALID_PARAM; + } + shmemi_bootstrap_loader(); + + if (!plugin_hdl) { + SHM_LOG_ERROR("Bootstrap unable to load " << plugin_name << ", err is: " << stderr); + shmemi_bootstrap_free(); + return SHMEM_INVALID_VALUE; + } + + int (*plugin_init)(void *, shmemi_bootstrap_handle_t *); + *((void **)&plugin_init) = dlsym(plugin_hdl, BOOTSTRAP_PLUGIN_INIT_FUNC); + if (!plugin_init) { + SHM_LOG_ERROR("Bootstrap plugin init func dlsym failed"); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + SHM_LOG_INFO("plugin_init"); + status = plugin_init(arg, &g_boot_handle); + if (status != 0) { + SHM_LOG_ERROR("Bootstrap plugin init failed for " << plugin_name); + shmemi_bootstrap_free(); + return SHMEM_INNER_ERROR; + } + g_boot_handle.is_bootstraped = true; + return status; +} + +void shmemi_bootstrap_finalize() { + g_boot_handle.finalize(&g_boot_handle); + g_boot_handle.is_bootstraped = false; + dlclose(plugin_hdl); +} diff --git a/src/host/bootstrap/shmemi_bootstrap.h b/src/host/bootstrap/shmemi_bootstrap.h new file mode 100644 index 0000000000000000000000000000000000000000..c5e56e56802d7d70503d6cc5016abba2b17f5266 --- /dev/null +++ b/src/host/bootstrap/shmemi_bootstrap.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef SHMEMI_BOOTSTRAP_H +#define SHMEMI_BOOTSTRAP_H +#include "shmem_api.h" +#ifdef __cplusplus +extern "C" { +#endif + +int32_t shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t *handle); + +int32_t shmemi_bootstrap_pre_init(int flags, shmemi_bootstrap_handle_t *handle); + +int32_t shmemi_bootstrap_init(int flags, shmem_init_attr_t *attr); + +void shmemi_bootstrap_finalize(); + +int shmemi_bootstrap_plugin_init(void *mpi_comm, shmemi_bootstrap_handle_t *handle); + +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/src/host/common/shmemi_functions.h b/src/host/common/shmemi_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..0d6b850d4666b05b9cf2cfb1b31aef24b48299e2 --- /dev/null +++ b/src/host/common/shmemi_functions.h @@ -0,0 +1,93 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEM_SHM_FUNCTION_H +#define SHMEM_SHM_FUNCTION_H + +#include +#include "shmemi_logger.h" + +namespace shm { +class funci { +public: + /** + * @brief Get real path + * + * @param path [in/out] input path, converted realpath + * @return true if successful + */ + static bool get_real_path(std::string &path); + + /** + * @brief Get real path of a library and check if exists + * + * @param lib_dir_path [in] dir path of the library + * @param lib_name [in] library name + * @param real_path [out] realpath of the library + * @return true if successful + */ + static bool get_library_real_path(const std::string &lib_dir_path, const std::string &lib_name, + std::string &real_path); +}; + +inline bool funci::get_real_path(std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + SHM_LOG_ERROR("Failed to get realpath, path is invalid"); + return false; + } + + /* It will allocate memory to store path */ + char *real_path = realpath(path.c_str(), nullptr); + if (real_path == nullptr) { + SHM_LOG_ERROR("Failed to get realpath, error " << errno); + return false; + } + + path = real_path; + free(real_path); + real_path = nullptr; + return true; +} + +inline bool funci::get_library_real_path(const std::string &lib_dir_path, const std::string &lib_name, + std::string &real_path) +{ + std::string tmp_full_path = lib_dir_path; + if (!get_real_path(tmp_full_path)) { + return false; + } + + if (tmp_full_path.back() != '/') { + tmp_full_path.push_back('/'); + } + + tmp_full_path.append(lib_name); + auto ret = ::access(tmp_full_path.c_str(), F_OK); + if (ret != 0) { + SHM_LOG_ERROR(tmp_full_path << " cannot be accessed, ret: " << ret); + return false; + } + + real_path = tmp_full_path; + return true; +} + +#define DL_LOAD_SYM(TARGET_FUNC_VAR, TARGET_FUNC_TYPE, FILE_HANDLE, SYMBOL_NAME) \ + do { \ + (TARGET_FUNC_VAR) = (TARGET_FUNC_TYPE)dlsym((FILE_HANDLE), (SYMBOL_NAME)); \ + if ((TARGET_FUNC_VAR) == nullptr) { \ + SHM_LOG_ERROR("Failed to call dlsym to load SYMBOL_NAME, error" << dlerror()); \ + dlclose((FILE_HANDLE)); \ + return false; \ + } \ + } while (0) +} // namespace shm + +#endif // SHMEM_SHM_FUNCTION_H diff --git a/src/host/common/shmemi_host_types.h b/src/host/common/shmemi_host_types.h new file mode 100644 index 0000000000000000000000000000000000000000..45ef8aebf57a3b9b518d6702bc9c2664cd13704e --- /dev/null +++ b/src/host/common/shmemi_host_types.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_HOST_TYPES_H +#define SHMEMI_HOST_TYPES_H + +#define SHMEM_MAX_TRANSPORT_NUM 16 + +#include "internal/host_device/shmemi_types.h" +#define SHMEM_MAX_HANDLE_IP_PORT_LEN 64 +typedef struct shmemi_bootstrap_attr { + shmemi_bootstrap_attr() : initialize_mf(0), mpi_comm(NULL), uid_args(NULL) + {} + int initialize_mf; + void *mpi_comm; + void *mete_data; + void *uid_args; +} shmemi_bootstrap_attr_t; + +typedef struct shmemi_bootstrap_init_ops { + void *cookie; + int (*get_unique_id)(void *cookit); + int (*get_unique_id_static_magic)(void *uid_info, bool is_root); +} shmemi_bootstrap_init_ops_t; + +typedef struct shmemi_bootstrap_handle { + int32_t mype, npes; + void *bootstrap_state; + + int (*finalize)(shmemi_bootstrap_handle *boot_handle); + int (*allgather)(const void *sendbuf, void *recvbuf, int size, shmemi_bootstrap_handle *boot_handle); + int (*barrier)(shmemi_bootstrap_handle *boot_handle); + int (*alltoall)(const void *sendbuf, void *recvbuf, int size, shmemi_bootstrap_handle *boot_handle); + void (*global_exit)(int status); + shmemi_bootstrap_init_ops_t *pre_init_ops; + bool is_bootstraped = false; + char ipport[SHMEM_MAX_HANDLE_IP_PORT_LEN]; + bool use_attr_ipport = false; +} shmemi_bootstrap_handle_t; + +typedef struct shmemi_bootstrap_mpi_options { + // TBD +} shmemi_bootstrap_mpi_options_t; + +typedef struct shmemi_bootstrap_uid_options { + // TBD +} shmemi_bootstrap_uid_options_t; + +typedef struct shmemi_transport_pe_info { + int32_t pe; + int32_t dev_id; + int64_t server_id; + int64_t superpod_id; +} shmemi_transport_pe_info_t; + +typedef struct shmemi_transport { + // control plane + int (*can_access_peer)(int *access, shmemi_transport_pe_info_t *peer_info, + shmemi_transport_pe_info_t *my_info, struct shmemi_transport *t); + int (*connect_peers)(struct shmemi_transport *t, int *selected_dev_ids, + int num_selected_devs, shmemi_device_host_state_t *g_state); + int (*finalize)(struct shmemi_transport *t, + shmemi_device_host_state_t *g_state); + + // data plane, TBD + void (*rma)(struct shmemi_transport *t, int32_t type, void *dst, void *src, size_t size, int32_t pe); + void (*amo)(struct shmemi_transport *t, int32_t type, void *dst, void *src, size_t size, int32_t pe); + void (*quiet)(struct shmemi_transport *t); + void (*fence)(struct shmemi_transport *t); + int32_t logical_dev_id; + int32_t dev_id; +} shmemi_transport_t; + +typedef struct { + int32_t pe, npes; + + shmemi_bootstrap_mpi_options_t mpi_options; + shmemi_bootstrap_uid_options_t uid_options; + + // other options + bool rdma_enabled; +} shmemi_options_t; + +// host only state +typedef struct { + // typedef void *aclrtStream; as in https://www.hiascend.com/document/detail/zh/canncommercial/80RC3/apiref/appdevgapi/aclcppdevg_03_1355.html + void *default_stream; + // using TEventID = int8_t; as in https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/800alpha003/apiref/ascendcopapi/atlasascendc_api_07_0181.html + int8_t default_event_id; + uint32_t default_block_num; + + // topo + int32_t *transport_map; /* npes * npes, 2D-Array, point-to-connectivity. */ + shmemi_transport_pe_info *pe_info; /* All pe's host info, need to build transports. */ + + shmemi_options_t options; + + shmemi_bootstrap_handle_t *boot_handle; + shmemi_transport_t choosen_transports[SHMEM_MAX_TRANSPORT_NUM]; + int32_t num_choosen_transport; +} shmemi_host_trans_state_t; + +extern shmemi_bootstrap_handle_t g_boot_handle; +extern shmemi_host_trans_state_t g_host_state; +#endif // SHMEMI_HOST_TYPES_H \ No newline at end of file diff --git a/src/host/common/shmemi_logger.h b/src/host/common/shmemi_logger.h index 3bbe0c4c0592ecf1965a4897d5a6ba841fe76f04..49488c18eddef48072ed56bf0a28a4d5c668ca8f 100644 --- a/src/host/common/shmemi_logger.h +++ b/src/host/common/shmemi_logger.h @@ -119,11 +119,11 @@ private: #ifndef SHM_LOG_FILENAME_SHORT #define SHM_LOG_FILENAME_SHORT (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) #endif -#define SHM_OUT_LOG(LEVEL, ARGS) \ - do { \ - std::ostringstream oss; \ - oss << "[SHM_SHMEM " << SHM_LOG_FILENAME_SHORT << ":" << __LINE__ << "] " << ARGS; \ - shm::shm_out_logger::Instance().log(LEVEL, oss); \ +#define SHM_OUT_LOG(LEVEL, ARGS) \ + do { \ + std::ostringstream oss; \ + oss << "[SHM_SHMEM " << SHM_LOG_FILENAME_SHORT << ":" << __LINE__ << "] " << ARGS; \ + shm::shm_out_logger::Instance().log(LEVEL, oss); \ } while (0) #define SHM_LOG_DEBUG(ARGS) SHM_OUT_LOG(shm::DEBUG_LEVEL, ARGS) @@ -173,18 +173,46 @@ private: } \ } while (0) -#define SHMEM_CHECK_RET(x, ...) \ +#define SHMEM_CHECK(x) \ do { \ int32_t check_ret = x; \ if (check_ret != 0) { \ - if (sizeof(#__VA_ARGS__) > 1) { \ - SHM_LOG_ERROR(" return shmem error: " << check_ret << " - " \ - << #__VA_ARGS__ << " failed. More error information can be found in plog"); \ - } else { \ - SHM_LOG_ERROR(" return shmem error: " << check_ret); \ - } \ + SHM_LOG_ERROR(" return shmem error: " << check_ret << " - " << #x << " failed."); \ + return ; \ + } \ + } while (0) + +#define SHMEM_CHECK_RET(...) \ + _SHMEM_CHECK_RET_HELPER(__VA_ARGS__, _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE, _SHMEM_CHECK_RET_WITH_LOG, _SHMEM_CHECK_RET)(__VA_ARGS__) + +#define _SHMEM_CHECK_RET(x) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" return shmem error: " << check_ret << " - " << #x << " failed."); \ return check_ret; \ } \ } while (0) +#define _SHMEM_CHECK_RET_WITH_LOG(x, LOG_STR) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << LOG_STR << " return shmem error: " << check_ret); \ + return check_ret; \ + } \ + } while (0) + +#define _SHMEM_CHECK_RET_WITH_LOG_AND_ERR_CODE(x, LOG_STR, ERR_CODE) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << LOG_STR << " return shmem error: " << ERR_CODE); \ + return ERR_CODE; \ + } \ + } while (0) + +#define _SHMEM_CHECK_RET_HELPER(_1, _2, _3, FUNC, ...) FUNC + + #endif // SHMEM_SHM_OUT_LOGGER_H diff --git a/src/host/init/init_backends/default/shmemi_init_default.cpp b/src/host/init/init_backends/default/shmemi_init_default.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7241aad97285cbd2bd32de8f9cae897c7a20aeb6 --- /dev/null +++ b/src/host/init/init_backends/default/shmemi_init_default.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "shmemi_init_default.h" +#include "common/shmemi_logger.h" + +shmemi_init_default::shmemi_init_default(shmem_init_attr_t *attr, shmemi_device_host_state_t *global_state) +{ + mype = attr->my_rank; + npes = attr->n_ranks; + option_attr_ = attr->option_attr; + g_state = global_state; + + auto status = aclrtGetDevice(&device_id); + if (status != 0) { + SHM_LOG_ERROR("Get Device_id error"); + } +} + +shmemi_init_default::~shmemi_init_default() +{} + +int shmemi_init_default::init_device_state() +{ + global_state_d = new global_state_reigister(device_id); + if (global_state_d->get_init_status() != 0) { + SHM_LOG_ERROR("global_state reigister error"); + } + return SHMEM_SUCCESS; +} + +int shmemi_init_default::finalize_device_state() +{ + delete global_state_d; + return SHMEM_SUCCESS; +} + +int shmemi_init_default::update_device_state(void* host_ptr, size_t size) +{ + int32_t ptr_size = npes * sizeof(void *); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_p2p_heap_base, ptr_size, g_state->host_p2p_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_rdma_heap_base, ptr_size, g_state->host_rdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_sdma_heap_base, ptr_size, g_state->host_sdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + + SHMEM_CHECK_RET(aclrtMemcpy(global_state_d->get_ptr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE)); + return SHMEM_SUCCESS; +} + +int shmemi_init_default::reserve_heap() +{ + heap_obj = new shmem_symmetric_heap(mype, npes, device_id); + + SHMEM_CHECK_RET(heap_obj->reserve_heap(g_state->heap_size)); + + g_state->heap_base = heap_obj->get_heap_base(); + + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_p2p_heap_base, npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_rdma_heap_base, npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_sdma_heap_base, npes * sizeof(void *))); + + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_p2p_heap_base, npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_rdma_heap_base, npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_sdma_heap_base, npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + return SHMEM_SUCCESS; +} + +int shmemi_init_default::setup_heap() +{ + SHMEM_CHECK_RET(heap_obj->setup_heap()); + + for (int32_t i = 0; i < g_state->npes; i++) { + g_state->host_p2p_heap_base[i] = heap_obj->get_peer_heap_base_p2p(i); + } + g_state->is_shmem_created = true; + + return SHMEM_SUCCESS; +} + +int shmemi_init_default::remove_heap() +{ + SHMEM_CHECK_RET(heap_obj->remove_heap()); + return SHMEM_SUCCESS; +} + +int shmemi_init_default::release_heap() +{ + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_p2p_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_rdma_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_sdma_heap_base)); + } + if (g_state->device_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_p2p_heap_base)); + } + if (g_state->device_rdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_rdma_heap_base)); + } + if (g_state->device_sdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_sdma_heap_base)); + } + SHMEM_CHECK_RET(heap_obj->unreserve_heap()); + return SHMEM_SUCCESS; +} + +int shmemi_init_default::transport_init() +{ + SHMEM_CHECK_RET(shmemi_transport_init(g_state, option_attr_)); // mte init && rdma init + SHMEM_CHECK_RET(shmemi_build_transport_map(g_state)); // build transport_map + SHMEM_CHECK_RET(shmemi_transport_setup_connections(g_state)); // connect_endpoints by transpost_map + return SHMEM_SUCCESS; +} + +int shmemi_init_default::transport_finalize() +{ + SHMEM_CHECK_RET(shmemi_transport_finalize(g_state)); + return SHMEM_SUCCESS; +} + +int32_t shmemi_control_barrier_all_default(shmemi_bootstrap_handle_t boot_handle) +{ + SHMEM_CHECK_RET((boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + return boot_handle.barrier(&boot_handle); +} \ No newline at end of file diff --git a/src/host/init/init_backends/default/shmemi_init_default.h b/src/host/init/init_backends/default/shmemi_init_default.h new file mode 100644 index 0000000000000000000000000000000000000000..16b294519643f020e026b3a50093b727d10fd9c2 --- /dev/null +++ b/src/host/init/init_backends/default/shmemi_init_default.h @@ -0,0 +1,60 @@ + +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_INIT_NORMAL_H +#define SHMEMI_INIT_NORMAL_H + +#include + +#include "init/init_backends/shmemi_init_base.h" + +#include "host/shmem_host_def.h" +#include "internal/host_device/shmemi_types.h" + +#include "mem/shmemi_global_state.h" +#include "mem/shmemi_heap.h" + +#include "bootstrap/shmemi_bootstrap.h" + +#include "transport/shmemi_transport.h" + +class shmemi_init_default: public shmemi_init_base { +public: + shmemi_init_default(shmem_init_attr_t *attr, shmemi_device_host_state_t *global_state); + ~shmemi_init_default(); + + int init_device_state() override; + int finalize_device_state() override; + int update_device_state(void* host_ptr, size_t size) override; + + int reserve_heap() override; + int setup_heap() override; + int remove_heap() override; + int release_heap() override; + + int transport_init() override; + int transport_finalize() override; +private: + int mype; + int npes; + int device_id; + shmemi_device_host_state_t *g_state; + + // global_state + global_state_reigister *global_state_d = nullptr; + + // heap_obj + shmem_symmetric_heap *heap_obj = nullptr; + shmem_init_optional_attr_t option_attr_; +}; + +int32_t shmemi_control_barrier_all_default(shmemi_bootstrap_handle_t boot_handle); + +#endif // SHMEMI_INIT_NORMAL_H \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.cpp b/src/host/init/init_backends/mf/shmemi_init_mf.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3652ebfb5f5c7cdaef7ad904d3f0b7625344fd78 --- /dev/null +++ b/src/host/init/init_backends/mf/shmemi_init_mf.cpp @@ -0,0 +1,590 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "shmemi_init_mf.h" + +#ifdef BACKEND_MF +#include +#include +#include +#include +#include "internal/host/shmemi_host_def.h" + + +constexpr int MIN_PORT = 1024; +constexpr int MAX_PORT = 65536; +constexpr int MAX_ATTEMPTS = 1000; +constexpr int MAX_IFCONFIG_LENGTH = 23; +constexpr int MAX_IP = 48; +constexpr int DEFAULT_IFNAME_LNEGTH = 4; + +constexpr int DEFAULT_FLAG = 0; +constexpr int DEFAULT_ID = 0; +constexpr int DEFAULT_TIMEOUT = 120; +constexpr int DEFAULT_TEVENT = 0; +constexpr int DEFAULT_BLOCK_NUM = 1; + +// smem need +static smem_shm_t g_smem_handle = nullptr; +static char g_ipport[SHMEM_MAX_IP_PORT_LEN] = {0}; + +shmemi_init_mf::shmemi_init_mf(shmem_init_attr_t *attr, char *ipport, shmemi_device_host_state_t *global_state) +{ + attributes = attr; + g_ipport = ipport; + g_state = global_state; + + aclrtGetDevice(&device_id); + smem_set_conf_store_tls(false, nullptr, 0); + + int32_t status = smem_init(DEFAULT_FLAG); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_init Failed"); + } +} + +shmemi_init_mf::~shmemi_init_mf() +{} + +int shmemi_init_mf::init_device_state() +{ + int32_t status = SHMEM_SUCCESS; + smem_shm_config_t config; + status = smem_shm_config_init(&config); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_config_init Failed"); + return SHMEM_SMEM_ERROR; + } + config.sockFd = attributes->option_attr.sockFd; + status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_init Failed"); + return SHMEM_SMEM_ERROR; + } + + config.shmInitTimeout = attributes->option_attr.shm_init_timeout; + config.shmCreateTimeout = attributes->option_attr.shm_create_timeout; + config.controlOperationTimeout = attributes->option_attr.control_operation_timeout; + + return SHMEM_SUCCESS; +} + +int shmemi_init_mf::update_device_state(void* host_ptr, size_t size) +{ + if (g_smem_handle == nullptr) { + SHM_LOG_ERROR("smem_shm_create Not Success, update_device_state Failed"); + return SHMEM_SMEM_ERROR; + } + int32_t ptr_size = g_state->npes * sizeof(void *); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_p2p_heap_base, ptr_size, g_state->host_p2p_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_rdma_heap_base, ptr_size, g_state->host_rdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + SHMEM_CHECK_RET(aclrtMemcpy(g_state->device_sdma_heap_base, ptr_size, g_state->host_sdma_heap_base, ptr_size, ACL_MEMCPY_HOST_TO_DEVICE)); + + SHMEM_CHECK_RET(smem_shm_set_extra_context(g_smem_handle, host_ptr, size)); + return SHMEM_SUCCESS; +} + +int shmemi_init_mf::finalize_device_state() +{ + // dummy function + return SHMEM_SUCCESS; +} + +int shmemi_init_mf::reserve_heap() +{ + int32_t status = SHMEM_SUCCESS; + void *gva = nullptr; + + g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state->heap_size, + static_cast(attributes->option_attr.data_op_engine_type), + DEFAULT_FLAG, &gva); + + if (g_smem_handle == nullptr || gva == nullptr) { + SHM_LOG_ERROR("smem_shm_create Failed"); + return SHMEM_SMEM_ERROR; + } + g_state->heap_base = (void *)((uintptr_t)gva + g_state->heap_size * attributes->my_rank); + + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_p2p_heap_base, g_state->npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_rdma_heap_base, g_state->npes * sizeof(void *))); + SHMEM_CHECK_RET(aclrtMallocHost((void **)&g_state->host_sdma_heap_base, g_state->npes * sizeof(void *))); + + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_p2p_heap_base, g_state->npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_rdma_heap_base, g_state->npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + SHMEM_CHECK_RET(aclrtMalloc((void **)&g_state->device_sdma_heap_base, g_state->npes * sizeof(void *), ACL_MEM_MALLOC_HUGE_FIRST)); + + uint32_t reach_info = 0; + for (int32_t i = 0; i < g_state->npes; i++) { + status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_topology_can_reach failed"); + } + g_state->host_p2p_heap_base[i] = (void *)((uintptr_t)gva + g_state->heap_size * i); + if (reach_info & SMEMS_DATA_OP_MTE) { + g_state->topo_list[i] |= SHMEM_TRANSPORT_MTE; + } + if (reach_info & SMEMS_DATA_OP_SDMA) { + g_state->host_sdma_heap_base[i] = (void *)((uintptr_t)gva + g_state->heap_size * i); + } else { + g_state->host_sdma_heap_base[i] = NULL; + } + if (reach_info & SMEMS_DATA_OP_RDMA) { + g_state->host_rdma_heap_base[i] = g_state->host_p2p_heap_base[i]; + g_state->topo_list[i] |= SHMEM_TRANSPORT_ROCE; + } + } + if (g_ipport[0] != '\0') { + g_ipport[0] = '\0'; + bzero(attributes->ip_port, sizeof(attributes->ip_port)); + } else { + SHM_LOG_WARN("my_rank:" << attributes->my_rank << " g_ipport is released in advance!"); + bzero(attributes->ip_port, sizeof(attributes->ip_port)); + } + g_state->is_shmem_created = true; + return status; +} + +int shmemi_init_mf::setup_heap() +{ + int32_t status = SHMEM_SUCCESS; + return status; +} + +int shmemi_init_mf::remove_heap() +{ + int32_t status = SHMEM_SUCCESS; + return status; +} + +int shmemi_init_mf::release_heap() +{ + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_p2p_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_rdma_heap_base)); + } + if (g_state->host_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFreeHost(g_state->host_sdma_heap_base)); + } + if (g_state->device_p2p_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_p2p_heap_base)); + } + if (g_state->device_rdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_rdma_heap_base)); + } + if (g_state->device_sdma_heap_base != nullptr) { + SHMEM_CHECK_RET(aclrtFree(g_state->device_sdma_heap_base)); + } + + if (g_smem_handle != nullptr) { + int32_t status = smem_shm_destroy(g_smem_handle, 0); + if (status != SHMEM_SUCCESS) { + SHM_LOG_ERROR("smem_shm_destroy Failed"); + return SHMEM_SMEM_ERROR; + } + g_smem_handle = nullptr; + } + smem_shm_uninit(0); + smem_uninit(); + return SHMEM_SUCCESS; +} + +int shmemi_init_mf::transport_init() +{ + return SHMEM_SUCCESS; +} + +int shmemi_init_mf::transport_finalize() +{ + return SHMEM_SUCCESS; +} + +int32_t shmem_get_uid_magic(shmemx_bootstrap_uid_state_t *innerUId) +{ + std::ifstream urandom("/dev/urandom", std::ios::binary); + if (!urandom) { + SHM_LOG_ERROR("open random failed"); + return SHMEM_INNER_ERROR; + } + + urandom.read(reinterpret_cast(&innerUId->magic), sizeof(innerUId->magic)); + if (urandom.fail()) { + SHM_LOG_ERROR("read random failed."); + return SHMEM_INNER_ERROR; + } + SHM_LOG_DEBUG("init magic id to " << innerUId->magic); + return SHMEM_SUCCESS; +} + +int32_t bind_tcp_port_v4(int &sockfd, int port, shmemx_bootstrap_uid_state_t *innerUId, char *ip_str) +{ + sockfd = ::socket(AF_INET, SOCK_STREAM, 0); + if (sockfd != -1) { + int on_v4 = 1; + if (::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on_v4, sizeof(on_v4)) == 0) { + innerUId->addr.addr.addr4.sin_port = htons(port); + sockaddr *cur_addr = reinterpret_cast(&innerUId->addr.addr.addr4); + if (::bind(sockfd, cur_addr, sizeof(innerUId->addr.addr.addr4)) == 0) { + SHM_LOG_INFO("bind ipv4 success " << ", fd:" << sockfd << ", " << ip_str << ":" << port); + return 0; + } else { + SHM_LOG_ERROR("bind socket fail:" << errno << "," << ip_str << ":" << port); + } + } else { + SHM_LOG_ERROR("set socket opt fail:" << errno << "," << ip_str << ":" << port); + } + close(sockfd); + sockfd = -1; + } else { + SHM_LOG_ERROR("create socket fail:" << errno << ", " << ip_str << ":" << port); + } + return -1; +} + +int32_t bind_tcp_port_v6(int &sockfd, int port, shmemx_bootstrap_uid_state_t *innerUId, char *ip_str) +{ + sockfd = ::socket(AF_INET6, SOCK_STREAM, 0); + if (sockfd != -1) { + int on_v6 = 1; + if (::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on_v6, sizeof(on_v6)) == 0) { + innerUId->addr.addr.addr6.sin6_port = htons(port); + sockaddr *cur_addr = reinterpret_cast(&innerUId->addr.addr.addr6); + if (::bind(sockfd, cur_addr, sizeof(innerUId->addr.addr.addr6)) == 0) { + SHM_LOG_INFO("bind ipv6 success " << ", fd:" << sockfd << ", " << ip_str << ":" << port); + return 0; + } else { + SHM_LOG_ERROR("bind socket6 fail:" << errno << "," << ip_str << ":" << port); + } + } else { + SHM_LOG_ERROR("set socket6 opt fail:" << errno << "," << ip_str << ":" << port); + } + close(sockfd); + sockfd = -1; + } else { + SHM_LOG_ERROR("create socket6 fail:" << errno << "," << ip_str << ":" << port); + } + return -1; +} + +int32_t shmem_get_port_magic(shmemx_bootstrap_uid_state_t *innerUId, char *ip_str) +{ + static std::random_device rd; + const int min_port = MIN_PORT; + const int max_port = MAX_PORT; + const int max_attempts = MAX_ATTEMPTS; + const int offset_bit = 32; + uint64_t seed = 1; + seed |= static_cast(getpid()) << offset_bit; + seed |= static_cast(std::chrono::system_clock::now().time_since_epoch().count() & 0xFFFFFFFF); + static std::mt19937_64 gen(seed); + std::uniform_int_distribution<> dis(min_port, max_port); + + int sockfd = -1; + int32_t ret; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + int port = dis(gen); + if (innerUId->addr.type == ADDR_IPv4) { + ret = bind_tcp_port_v4(sockfd, port, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } else { + ret = bind_tcp_port_v6(sockfd, port, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } + } + SHM_LOG_ERROR("Not find a available tcp port"); + return -1; +} + +int32_t shmem_using_env_port(shmemx_bootstrap_uid_state_t *innerUId, char *ip_str, uint16_t envPort) +{ + if (envPort < MIN_PORT) { // envPort > MAX_PORT always false + SHM_LOG_ERROR("env port is invalid. " << envPort); + return SHMEM_INVALID_PARAM; + } + + int sockfd = -1; + int32_t ret; + if (innerUId->addr.type == ADDR_IPv4) { + ret = bind_tcp_port_v4(sockfd, envPort, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } else { + ret = bind_tcp_port_v6(sockfd, envPort, innerUId, ip_str); + if (ret == 0) { + innerUId->inner_sockFd = sockfd; + return 0; + } + } + SHM_LOG_ERROR("init with env port fialed " << envPort << ", ret=" << ret); + return ret; +} + +int32_t ParseInterfaceWithType(const char *ipInfo, char *IP, sa_family_t &sockType, bool &flag) +{ + const char *delim = ":"; + const char *sep = strchr(ipInfo, delim[0]); + if (sep != nullptr) { + size_t leftLen = sep - ipInfo; + if (leftLen >= MAX_IFCONFIG_LENGTH - 1 || leftLen == 0) { + return SHMEM_INVALID_VALUE; + } + strncpy(IP, ipInfo, leftLen); + IP[leftLen] = '\0'; + sockType = (strcmp(sep + 1, "inet6") != 0) ? AF_INET : AF_INET6; + flag = true; + } + return SHMEM_SUCCESS; +} + +int32_t shmem_auto_get_ip(struct sockaddr *ifaAddr, char *local, sa_family_t &sockType) +{ + sockType = ifaAddr->sa_family; + if (sockType == AF_INET) { + auto localIp = reinterpret_cast(ifaAddr)->sin_addr; + if (inet_ntop(sockType, &localIp, local, MAX_IP) == nullptr) { + SHM_LOG_ERROR("convert local ipv4 to string failed. "); + return SHMEM_INVALID_PARAM; + } + return SHMEM_SUCCESS; + } else if (sockType == AF_INET6) { + auto localIp = reinterpret_cast(ifaAddr)->sin6_addr; + if (inet_ntop(sockType, &localIp, local, MAX_IP) == nullptr) { + SHM_LOG_ERROR("convert local ipv6 to string failed. "); + return SHMEM_INVALID_PARAM; + } + return SHMEM_SUCCESS; + } + return SHMEM_INVALID_PARAM; +} + +bool shmem_check_ifa(struct ifaddrs *ifa, sa_family_t sockType, bool flag, char *ifaName, size_t ifaLen) +{ + if (ifa->ifa_addr == nullptr || ifa->ifa_netmask == nullptr || ifa->ifa_name == nullptr) { + SHM_LOG_DEBUG("loop ifa_addr/ifa_netmask/ifa_name is nullptr"); + return false; + } + + // socket type match and input env ifa valid + if (ifa->ifa_addr->sa_family != sockType && flag) { + SHM_LOG_DEBUG("sa family is not match, get " << ifa->ifa_addr->sa_family << ", expect " << sockType); + return false; + } + + // prefix match with input ifa name + if (strncmp(ifa->ifa_name, ifaName, ifaLen) != 0) { + SHM_LOG_DEBUG("ifa name prefix un-match, get " << ifa->ifa_name << ", expect " << ifaName); + return false; + } + + // ignore ifa which is down or loopback or not running + if ((ifa->ifa_flags & IFF_LOOPBACK) || !(ifa->ifa_flags & IFF_RUNNING) || !(ifa->ifa_flags & IFF_UP)) { + SHM_LOG_DEBUG("ifa flag un-match, flag=" << ifa->ifa_flags); + return false; + } + + if (sockType == AF_INET6) { + struct sockaddr_in6 *sa6 = reinterpret_cast(ifa->ifa_addr); + if (IN6_IS_ADDR_LINKLOCAL(&sa6->sin6_addr)) { + SHM_LOG_DEBUG("ifa is scope link addr " << ifaName); + return false; + } + } + return true; +} + +int32_t shmem_get_ip_from_ifa(char *local, sa_family_t &sockType, const char *ipInfo) +{ + struct ifaddrs *ifaddr; + char ifaName[MAX_IFCONFIG_LENGTH]; + sockType = AF_INET; + bool flag = false; + if (ipInfo == nullptr) { + strncpy(ifaName, "eth", DEFAULT_IFNAME_LNEGTH); + ifaName[DEFAULT_IFNAME_LNEGTH - 1] = '\0'; + SHM_LOG_INFO("use default if to find IP:" << ifaName); + } else if (ParseInterfaceWithType(ipInfo, ifaName, sockType, flag) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("IP size set in SHMEM_CONF_STORE_MASTER_IF format has wrong length"); + return SHMEM_INVALID_PARAM; + } + if (getifaddrs(&ifaddr) == -1) { + SHM_LOG_ERROR("get local net interfaces failed: " << errno); + return SHMEM_INVALID_PARAM; + } + int32_t result = SHMEM_INVALID_PARAM; + for (auto ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { + if (!shmem_check_ifa(ifa, sockType, flag, ifaName, strlen(ifaName))) { + continue; + } + if (sockType == AF_INET && flag) { + auto localIp = reinterpret_cast(ifa->ifa_addr)->sin_addr; + if (inet_ntop(sockType, &localIp, local, 64) == nullptr) { + SHM_LOG_ERROR("convert local ipv4 to string failed. "); + continue; + } + result = SHMEM_SUCCESS; + break; + } else if (sockType == AF_INET6 && flag) { + auto localIp = reinterpret_cast(ifa->ifa_addr)->sin6_addr; + if (inet_ntop(sockType, &localIp, local, 64) == nullptr) { + SHM_LOG_ERROR("convert local ipv6 to string failed. "); + continue; + } + result = SHMEM_SUCCESS; + break; + } else { + auto ret = shmem_auto_get_ip(ifa->ifa_addr, local, sockType); + if (ret != SHMEM_SUCCESS) { + continue; + } + result = SHMEM_SUCCESS; + break; + } + } + freeifaddrs(ifaddr); + return result; +} + +int32_t shmem_get_ip_from_env(char *ip, uint16_t &port, sa_family_t &sockType, const char *ipPort) +{ + if (ipPort != nullptr) { + SHM_LOG_DEBUG("get env SHMEM_UID_SESSION_ID value:" << ipPort); + std::string ipPortStr = ipPort; + + if (ipPort[0] == '[') { + sockType = AF_INET6; + size_t found = ipPortStr.find_last_of(']'); + if (found == std::string::npos || ipPortStr.length() - found <= 1) { + SHM_LOG_ERROR("get env SHMEM_UID_SESSION_ID is invalid"); + return SHMEM_INVALID_PARAM; + } + std::string ipStr = ipPortStr.substr(1, found - 1); + std::string portStr = ipPortStr.substr(found + 2); + + std::snprintf(ip, MAX_IP, "%s", ipStr.c_str()); + + port = std::stoi(portStr); + } else { + sockType = AF_INET; + size_t found = ipPortStr.find_last_of(':'); + if (found == std::string::npos || ipPortStr.length() - found <= 1) { + SHM_LOG_ERROR("get env SHMEM_UID_SESSION_ID is invalid"); + return SHMEM_INVALID_PARAM; + } + std::string ipStr = ipPortStr.substr(0, found); + std::string portStr = ipPortStr.substr(found + 1); + + std::snprintf(ip, MAX_IP, "%s", ipStr.c_str()); + + port = std::stoi(portStr); + } + return SHMEM_SUCCESS; + } + return SHMEM_INVALID_PARAM; +} + +int32_t shmem_set_ip_info(shmemx_uniqueid_t *uid, sa_family_t &sockType, char *pta_env_ip, uint16_t pta_env_port, + bool is_from_ifa) +{ + // init default uid + SHM_ASSERT_RETURN(uid != nullptr, SHMEM_INVALID_PARAM); + *uid = SHMEM_UNIQUEID_INITIALIZER; + shmemx_bootstrap_uid_state_t *innerUID = reinterpret_cast(uid); + if (sockType == AF_INET) { + innerUID->addr.addr.addr4.sin_family = AF_INET; + if (inet_pton(AF_INET, pta_env_ip, &(innerUID->addr.addr.addr4.sin_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv4 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.type = ADDR_IPv4; + } else if (sockType == AF_INET6) { + innerUID->addr.addr.addr6.sin6_family = AF_INET6; + if (inet_pton(AF_INET6, pta_env_ip, &(innerUID->addr.addr.addr6.sin6_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv6 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.type = ADDR_IPv6; + } else { + SHM_LOG_ERROR("IP Type is not IPv4 or IPv6"); + return SHMEM_INVALID_PARAM; + } + + // fill ip port as part of uid + if (is_from_ifa) { + int32_t ret = shmem_get_port_magic(innerUID, pta_env_ip); + if (ret != 0) { + SHM_LOG_ERROR("get available port failed."); + return SHMEM_INVALID_PARAM; + } + } else { + int32_t ret = shmem_using_env_port(innerUID, pta_env_ip, pta_env_port); + if (ret != 0) { + SHM_LOG_ERROR("using env port failed."); + return SHMEM_INVALID_PARAM; + } + } + + SHM_LOG_INFO("gen unique id success."); + return SHMEM_SUCCESS; +} + +int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid) +{ + if (shmem_set_log_level(shm::WARN_LEVEL) != 0) { + SHM_LOG_ERROR("failed to set log level"); + return SHMEM_INNER_ERROR; + } + char pta_env_ip[MAX_IP]; + uint16_t pta_env_port; + sa_family_t sockType; + const char *ipPort = std::getenv("SHMEM_UID_SESSION_ID"); + const char *ipInfo = std::getenv("SHMEM_UID_SOCK_IFNAM"); + bool is_from_ifa = false; + if (ipPort != nullptr) { + if (shmem_get_ip_from_env(pta_env_ip, pta_env_port, sockType, ipPort) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("cant get pta master addr."); + return SHMEM_INVALID_PARAM; + } + } else { + is_from_ifa = true; + if (shmem_get_ip_from_ifa(pta_env_ip, sockType, ipInfo) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("cant get available ip port."); + return SHMEM_INVALID_PARAM; + } + } + SHM_LOG_INFO("get master IP value:" << pta_env_ip); + return shmem_set_ip_info(uid, sockType, pta_env_ip, pta_env_port, is_from_ifa); +} + +int32_t shmemi_control_barrier_all_mf() +{ + SHM_ASSERT_RETURN(g_smem_handle != nullptr, SHMEM_INVALID_PARAM); + auto ret = smem_shm_control_barrier(g_smem_handle); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("Barrier failed"); + return ret; + } + return SHMEM_SUCCESS; +} + +void shmemi_global_exit_mf(int status) +{ + smem_shm_global_exit(g_smem_handle, status); +} + +#endif \ No newline at end of file diff --git a/src/host/init/init_backends/mf/shmemi_init_mf.h b/src/host/init/init_backends/mf/shmemi_init_mf.h new file mode 100644 index 0000000000000000000000000000000000000000..f099977152b0119920146bd40d95ae0963e7edae --- /dev/null +++ b/src/host/init/init_backends/mf/shmemi_init_mf.h @@ -0,0 +1,54 @@ + +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_INIT_MF_H +#define SHMEMI_INIT_MF_H + +#include + +#include "init/init_backends/shmemi_init_base.h" +#include "shmemi_host_common.h" +#include "internal/host_device/shmemi_types.h" +#ifdef BACKEND_MF +// smem api +#include +#include +#include +#endif + +class shmemi_init_mf: public shmemi_init_base { +public: + shmemi_init_mf(shmem_init_attr_t *attr, char *ipport, shmemi_device_host_state_t *g_state); + ~shmemi_init_mf(); + + int init_device_state() override; + int finalize_device_state() override; + int update_device_state(void* host_ptr, size_t size) override; + + int reserve_heap() override; + int setup_heap() override; + int remove_heap() override; + int release_heap() override; + + int transport_init() override; + int transport_finalize() override; +private: + int32_t device_id; + + shmem_init_attr_t *attributes; + char *g_ipport = nullptr; + shmemi_device_host_state_t *g_state; +}; + +int32_t shmem_get_uniqueid_mf(shmemx_uniqueid_t *uid); +int32_t shmemi_control_barrier_all_mf(); +void shmemi_global_exit_mf(int status); + +#endif // SHMEMI_INIT_MF_H \ No newline at end of file diff --git a/src/host/init/init_backends/shmemi_init_base.h b/src/host/init/init_backends/shmemi_init_base.h new file mode 100644 index 0000000000000000000000000000000000000000..f1599cb87dc99ebf969438dfdc8c544c176046fb --- /dev/null +++ b/src/host/init/init_backends/shmemi_init_base.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_INIT_BASE_H +#define SHMEMI_INIT_BASE_H + +#include + +#include "acl/acl.h" +#include "internal/host_device/shmemi_types.h" + +class shmemi_init_base { +public: + virtual int init_device_state() = 0; + virtual int finalize_device_state() = 0; + virtual int update_device_state(void* host_ptr, size_t size) = 0; + + virtual int reserve_heap() = 0; + virtual int setup_heap() = 0; + virtual int remove_heap() = 0; + virtual int release_heap() = 0; + + virtual int transport_init() = 0; + virtual int transport_finalize() = 0; + + virtual ~shmemi_init_base() {} +}; + +#endif // SHMEMI_INIT_BASE_H \ No newline at end of file diff --git a/src/host/init/shmem_init.cpp b/src/host/init/shmem_init.cpp index a552f345d370fe9cc5782e3f82ed689928f15639..acbb3df731e96febe0e7b6bc6df210d0cfee015b 100644 --- a/src/host/init/shmem_init.cpp +++ b/src/host/init/shmem_init.cpp @@ -13,64 +13,57 @@ #include #include #include -#include -#include +#include #include -#include -#include -#include +#include + #include "acl/acl.h" #include "shmemi_host_common.h" #include "internal/host/shmemi_host_def.h" using namespace std; -namespace shm { -constexpr uint64_t MIN_PORT = 1024; -constexpr uint64_t MAX_PORT = 65536; -constexpr uint64_t MAX_ATTEMPTS = 1000; -constexpr uint64_t MAX_IFCONFIG_LENGTH = 23; -constexpr uint64_t MAX_IP = 48; -constexpr int DEFAULT_MY_PE = -1; -constexpr int DEFAULT_N_PES = -1; +#define DEFAULT_MY_PE (-1) +#define DEFAULT_N_PES (-1) constexpr int DEFAULT_FLAG = 0; constexpr int DEFAULT_ID = 0; constexpr int DEFAULT_TIMEOUT = 120; constexpr int DEFAULT_TEVENT = 0; constexpr int DEFAULT_BLOCK_NUM = 1; -constexpr int DEFAULT_IFNAME_LNEGTH = 4; // initializer -#define SHMEM_DEVICE_HOST_STATE_INITIALIZER \ - { \ - (1 << 16) + sizeof(shmemi_device_host_state_t), /* version */ \ - (DEFAULT_MY_PE), /* mype */ \ - (DEFAULT_N_PES), /* npes */ \ - NULL, /* heap_base */ \ - NULL, /* p2p_heap_host_base */ \ - NULL, /* sdma_heap_host_base */ \ - NULL, /* roce_heap_host_base */ \ - NULL, /* p2p_heap_device_base */ \ - NULL, /* sdma_heap_device_base */ \ - NULL, /* roce_heap_device_base */ \ - {}, /* topo_list */ \ - SIZE_MAX, /* heap_size */ \ - {NULL}, /* team_pools */ \ - 0, /* sync_pool */ \ - 0, /* sync_counter */ \ - 0, /* core_sync_pool */ \ - 0, /* core_sync_counter */ \ - 0, /* partial_barrier_pool */ \ - false, /* shmem_is_shmem_initialized */ \ - false, /* shmem_is_shmem_created */ \ - {0, 16 * 1024, 0}, /* shmem_mte_config */ \ +#define SHMEM_DEVICE_HOST_STATE_INITIALIZER \ + { \ + (1 << 16) + sizeof(shmemi_device_host_state_t), /* version */ \ + (DEFAULT_MY_PE), /* mype */ \ + (DEFAULT_N_PES), /* npes */ \ + NULL, /* heap_base */ \ + NULL, /* host_p2p_heap_base */ \ + NULL, /* host_rdma_heap_base */ \ + NULL, /* host_sdma_heap_base */ \ + NULL, /* p2p_heap_device_base */ \ + NULL, /* sdma_heap_device_base */ \ + NULL, /* roce_heap_device_base */ \ + {}, /* topo_list */ \ + SIZE_MAX, /* heap_size */ \ + {NULL}, /* team_pools */ \ + 0, /* sync_pool */ \ + 0, /* sync_counter */ \ + 0, /* core_sync_pool */ \ + 0, /* core_sync_counter */ \ + 0, /* partial_barrier_pool */ \ + false, /* shmem_is_shmem_initialized */\ + false, /* shmem_is_shmem_created */ \ + {0, 16 * 1024, 0}, /* shmem_mte_config */ \ + 0, /* qp_info */ \ } shmemi_device_host_state_t g_state = SHMEM_DEVICE_HOST_STATE_INITIALIZER; shmemi_host_state_t g_state_host = {nullptr, DEFAULT_TEVENT, DEFAULT_BLOCK_NUM}; shmem_init_attr_t g_attr; -static smem_shm_t g_smem_handle = nullptr; +shmemx_uniqueid_t default_flag_uid; + static bool g_attr_init = false; static char g_ipport[SHMEM_MAX_IP_PORT_LEN] = {0}; @@ -80,56 +73,6 @@ int32_t version_compatible() return status; } -int32_t bind_tcp_port_v4(int &sockfd, int port, shmem_uniqueid_inner_t *innerUId, char *ip_str) -{ - sockfd = ::socket(AF_INET, SOCK_STREAM, 0); - if (sockfd != -1) { - int on_v4 = 1; - if (::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on_v4, sizeof(on_v4)) == 0) { - innerUId->addr.addr.addr4.sin_port = htons(port); - sockaddr *cur_addr = reinterpret_cast(&innerUId->addr.addr.addr4); - if (::bind(sockfd, cur_addr, sizeof(innerUId->addr.addr.addr4)) == 0) { - SHM_LOG_INFO("bind ipv4 success " << ", fd:" << sockfd << ", " << ip_str << ":" << port); - return 0; - } else { - SHM_LOG_ERROR("bind socket fail:" << errno << "," << ip_str << ":" << port); - } - } else { - SHM_LOG_ERROR("set socket opt fail:" << errno << "," << ip_str << ":" << port); - } - close(sockfd); - sockfd = -1; - } else { - SHM_LOG_ERROR("create socket fail:" << errno << ", " << ip_str << ":" << port); - } - return -1; -} - -int32_t bind_tcp_port_v6(int &sockfd, int port, shmem_uniqueid_inner_t *innerUId, char *ip_str) -{ - sockfd = ::socket(AF_INET6, SOCK_STREAM, 0); - if (sockfd != -1) { - int on_v6 = 1; - if (::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on_v6, sizeof(on_v6)) == 0) { - innerUId->addr.addr.addr6.sin6_port = htons(port); - sockaddr *cur_addr = reinterpret_cast(&innerUId->addr.addr.addr6); - if (::bind(sockfd, cur_addr, sizeof(innerUId->addr.addr.addr6)) == 0) { - SHM_LOG_INFO("bind ipv6 success " << ", fd:" << sockfd << ", " << ip_str << ":" << port); - return 0; - } else { - SHM_LOG_ERROR("bind socket6 fail:" << errno << "," << ip_str << ":" << port); - } - } else { - SHM_LOG_ERROR("set socket6 opt fail:" << errno << "," << ip_str << ":" << port); - } - close(sockfd); - sockfd = -1; - } else { - SHM_LOG_ERROR("create socket6 fail:" << errno << "," << ip_str << ":" << port); - } - return -1; -} - int32_t shmemi_options_init() { int32_t status = SHMEM_SUCCESS; @@ -144,132 +87,13 @@ int32_t shmemi_state_init_attr(shmem_init_attr_t *attributes) g_state.heap_size = attributes->local_mem_size + SHMEM_EXTRA_SIZE; aclrtStream stream = nullptr; - SHMEM_CHECK_RET(aclrtCreateStream(&stream), aclrtCreateStream); + SHMEM_CHECK_RET(aclrtCreateStream(&stream)); g_state_host.default_stream = stream; g_state_host.default_event_id = DEFAULT_TEVENT; g_state_host.default_block_num = DEFAULT_BLOCK_NUM; return status; } -void shmemi_reach_info_init(void *&gva) -{ - uint32_t reach_info = 0; - int32_t status = SHMEM_SUCCESS; - for (int32_t i = 0; i < g_state.npes; i++) { - status = smem_shm_topology_can_reach(g_smem_handle, i, &reach_info); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_topology_can_reach failed"); - } - g_state.p2p_heap_host_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * static_cast(i)); - if (reach_info & SMEMS_DATA_OP_MTE) { - g_state.topo_list[i] |= SHMEM_TRANSPORT_MTE; - } - if (reach_info & SMEMS_DATA_OP_SDMA) { - g_state.sdma_heap_host_base[i] = (void *)((uintptr_t)gva + g_state.heap_size * static_cast(i)); - } else { - g_state.sdma_heap_host_base[i] = NULL; - } - if (reach_info & SMEMS_DATA_OP_RDMA) { - g_state.topo_list[i] |= SHMEM_TRANSPORT_ROCE; - } - } -} - -int32_t shmemi_heap_init(shmem_init_attr_t *attributes) -{ - void *gva = nullptr; - int32_t status = SHMEM_SUCCESS; - int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice); - - status = smem_init(DEFAULT_FLAG); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_init Failed"); - return SHMEM_SMEM_ERROR; - } - smem_shm_config_t config; - status = smem_shm_config_init(&config); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_config_init Failed"); - return SHMEM_SMEM_ERROR; - } - // set config.sockFd value - config.sockFd = attributes->option_attr.sockFd; - status = smem_shm_init(attributes->ip_port, attributes->n_ranks, attributes->my_rank, device_id, &config); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_init Failed"); - return SHMEM_SMEM_ERROR; - } - - config.shmInitTimeout = attributes->option_attr.shm_init_timeout; - config.shmCreateTimeout = attributes->option_attr.shm_create_timeout; - config.controlOperationTimeout = attributes->option_attr.control_operation_timeout; - - g_smem_handle = smem_shm_create(DEFAULT_ID, attributes->n_ranks, attributes->my_rank, g_state.heap_size, - static_cast(attributes->option_attr.data_op_engine_type), - DEFAULT_FLAG, &gva); - if (g_smem_handle == nullptr || gva == nullptr) { - SHM_LOG_ERROR("smem_shm_create Failed"); - return SHMEM_SMEM_ERROR; - } - SHMEM_CHECK_RET( - aclrtMallocHost(((void **)&g_state.p2p_heap_host_base), g_state.npes * sizeof(void *))); - SHMEM_CHECK_RET( - aclrtMallocHost(((void **)&g_state.sdma_heap_host_base), g_state.npes * sizeof(void *))); - SHMEM_CHECK_RET( - aclrtMallocHost(((void **)&g_state.roce_heap_host_base), g_state.npes * sizeof(void *))); - - SHMEM_CHECK_RET(aclrtMalloc(((void **)&g_state.p2p_heap_device_base), g_state.npes * sizeof(void *), - ACL_MEM_MALLOC_HUGE_FIRST)); - SHMEM_CHECK_RET(aclrtMalloc(((void **)&g_state.sdma_heap_device_base), g_state.npes * sizeof(void *), - ACL_MEM_MALLOC_HUGE_FIRST)); - SHMEM_CHECK_RET(aclrtMalloc(((void **)&g_state.roce_heap_device_base), g_state.npes * sizeof(void *), - ACL_MEM_MALLOC_HUGE_FIRST)); - - g_state.heap_base = (void *)((uintptr_t)gva + g_state.heap_size * static_cast(attributes->my_rank)); - shmemi_reach_info_init(gva); - if (shm::g_ipport[0] != '\0') { - g_ipport[0] = '\0'; - bzero(attributes->ip_port, sizeof(attributes->ip_port)); - } else { - SHM_LOG_WARN("my_rank:" << attributes->my_rank << " shm::g_ipport is released in advance!"); - bzero(attributes->ip_port, sizeof(attributes->ip_port)); - } - g_state.is_shmem_created = true; - return status; -} - -int32_t shmemi_control_barrier_all() -{ - SHM_ASSERT_RETURN(g_smem_handle != nullptr, SHMEM_INVALID_PARAM); - auto ret = smem_shm_control_barrier(g_smem_handle); - if (ret != SHMEM_SUCCESS) { - SHM_LOG_ERROR("Barrier failed"); - return ret; - } - return SHMEM_SUCCESS; -} - -int32_t update_device_state() -{ - if (!g_state.is_shmem_created) { - return SHMEM_NOT_INITED; - } - - SHMEM_CHECK_RET(aclrtMemcpy(g_state.p2p_heap_device_base, g_state.npes * sizeof(void *), - g_state.p2p_heap_host_base, g_state.npes * sizeof(void *), ACL_MEMCPY_HOST_TO_DEVICE), aclrtMemcpy); - SHMEM_CHECK_RET(aclrtMemcpy(g_state.sdma_heap_device_base, g_state.npes * sizeof(void *), - g_state.sdma_heap_host_base, g_state.npes * sizeof(void *), ACL_MEMCPY_HOST_TO_DEVICE), aclrtMemcpy); - SHMEM_CHECK_RET(aclrtMemcpy(g_state.roce_heap_device_base, g_state.npes * sizeof(void *), - g_state.roce_heap_host_base, g_state.npes * sizeof(void *), ACL_MEMCPY_HOST_TO_DEVICE), aclrtMemcpy); - auto ret = smem_shm_set_extra_context(g_smem_handle, (void *)&g_state, sizeof(shmemi_device_host_state_t)); - if (ret != SHMEM_SUCCESS) { - SHM_LOG_ERROR("Failed to attach extra context to segment"); - return ret; - } - return SHMEM_SUCCESS; -} - int32_t check_attr(shmem_init_attr_t *attributes) { if ((attributes->my_rank < 0) || (attributes->n_ranks <= 0)) { @@ -289,7 +113,21 @@ int32_t check_attr(shmem_init_attr_t *attributes) return SHMEM_SUCCESS; } -} // namespace shm +shmemi_init_base* init_manager; + +int32_t shmemi_control_barrier_all() +{ +#ifdef BACKEND_MF + return shmemi_control_barrier_all_mf(); +#else + return shmemi_control_barrier_all_default(g_boot_handle); +#endif +} + +int32_t update_device_state() +{ + return init_manager->update_device_state((void *)&g_state, sizeof(shmemi_device_host_state_t)); +} int32_t shmem_set_data_op_engine_type(shmem_init_attr_t *attributes, data_op_engine_type_t value) { @@ -313,474 +151,220 @@ int32_t shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size SHM_ASSERT_RETURN(local_mem_size <= SHMEM_MAX_LOCAL_SIZE, SHMEM_INVALID_VALUE); SHM_ASSERT_RETURN(n_ranks <= SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); SHM_ASSERT_RETURN(my_rank < SHMEM_MAX_RANKS, SHMEM_INVALID_VALUE); - *attributes = &shm::g_attr; + *attributes = &g_attr; size_t ip_len = 0; if (ip_port != nullptr) { - ip_len = std::min(strlen(ip_port), sizeof(shm::g_ipport) - 1); + ip_len = std::min(strlen(ip_port), sizeof(g_ipport) - 1); - std::copy_n(ip_port, ip_len, shm::g_ipport); - shm::g_ipport[ip_len] = '\0'; - std::copy_n(shm::g_ipport, ip_len, shm::g_attr.ip_port); - if (shm::g_ipport[0] == '\0') { - SHM_LOG_ERROR("my_rank:" << my_rank << " shm::g_ipport is nullptr!"); + std::copy_n(ip_port, ip_len, g_ipport); + g_ipport[ip_len] = '\0'; + std::copy_n(g_ipport, ip_len, g_attr.ip_port); + if (g_ipport[0] == '\0') { + SHM_LOG_ERROR("my_rank:" << my_rank << " g_ipport is nullptr!"); return SHMEM_INVALID_VALUE; } } else { SHM_LOG_WARN("init with my_rank:" << my_rank << " ip_port is nullptr!"); } - int attr_version = static_cast((1 << 16) + sizeof(shmem_init_attr_t)); - shm::g_attr.my_rank = my_rank; - shm::g_attr.n_ranks = n_ranks; - shm::g_attr.ip_port[ip_len] = '\0'; - shm::g_attr.local_mem_size = local_mem_size; - shm::g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, shm::DEFAULT_TIMEOUT, - shm::DEFAULT_TIMEOUT, shm::DEFAULT_TIMEOUT, 0}; - shm::g_attr_init = true; + int attr_version = (1 << 16) + sizeof(shmem_init_attr_t); + g_attr.my_rank = my_rank; + g_attr.n_ranks = n_ranks; + g_attr.ip_port[ip_len] = '\0'; + g_attr.local_mem_size = local_mem_size; + g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, DEFAULT_TIMEOUT, + DEFAULT_TIMEOUT, DEFAULT_TIMEOUT}; + g_attr.comm_args = reinterpret_cast(&default_flag_uid); + shmemx_bootstrap_uid_state_t *uid_args = (shmemx_bootstrap_uid_state_t *)(g_attr.comm_args); + uid_args->rank = my_rank; + uid_args->nranks = n_ranks; + g_attr_init = true; return SHMEM_SUCCESS; } -int32_t shmem_get_uid_magic(shmem_uniqueid_inner_t *innerUId) -{ - std::ifstream urandom("/dev/urandom", std::ios::binary); - if (!urandom) { - SHM_LOG_ERROR("open random failed"); - return SHMEM_INNER_ERROR; - } - - urandom.read(reinterpret_cast(&innerUId->magic), sizeof(innerUId->magic)); - if (urandom.fail()) { - SHM_LOG_ERROR("read random failed."); - return SHMEM_INNER_ERROR; - } - SHM_LOG_DEBUG("init magic id to " << innerUId->magic); - return SHMEM_SUCCESS; -} - -int32_t shmem_get_port_magic(shmem_uniqueid_inner_t *innerUId, char *ip_str) +int32_t shmem_init_status(void) { - static std::random_device rd; - const int min_port = shm::MIN_PORT; - const int max_port = shm::MAX_PORT; - const int max_attempts = shm::MAX_ATTEMPTS; - const int offset_bit = 32; - uint64_t seed = 1; - seed |= static_cast(getpid()) << offset_bit; - seed |= static_cast(static_cast(std::chrono::system_clock::now().time_since_epoch().count()) - & 0xFFFFFFFF); - static std::mt19937_64 gen(seed); - std::uniform_int_distribution<> dis(min_port, max_port); - - int sockfd = -1; - int32_t ret; - for (int attempt = 0; attempt < max_attempts; ++attempt) { - int port = dis(gen); - if (innerUId->addr.type == ADDR_IPv4) { - ret = shm::bind_tcp_port_v4(sockfd, port, innerUId, ip_str); - if (ret == 0) { - innerUId->inner_sockFd = sockfd; - return 0; - } - } else { - ret = shm::bind_tcp_port_v6(sockfd, port, innerUId, ip_str); - if (ret == 0) { - innerUId->inner_sockFd = sockfd; - return 0; - } - } - } - SHM_LOG_ERROR("Not find a available tcp port"); - return -1; + if (!g_state.is_shmem_created) + return SHMEM_STATUS_NOT_INITIALIZED; + else if (!g_state.is_shmem_initialized) + return SHMEM_STATUS_SHM_CREATED; + else if (g_state.is_shmem_initialized) + return SHMEM_STATUS_IS_INITIALIZED; + else + return SHMEM_STATUS_INVALID; } -int32_t shmem_using_env_port(shmem_uniqueid_inner_t *innerUId, char *ip_str, uint16_t envPort) -{ - if (envPort < shm::MIN_PORT) { // envPort > MAX_PORT always false - SHM_LOG_ERROR("env port is invalid. " << envPort); - return SHMEM_INVALID_PARAM; - } - - int sockfd = -1; - int32_t ret; - if (innerUId->addr.type == ADDR_IPv4) { - ret = shm::bind_tcp_port_v4(sockfd, envPort, innerUId, ip_str); - if (ret == 0) { - innerUId->inner_sockFd = sockfd; - return 0; +int shmemx_set_attr_uniqueid_args(const int my_rank, const int n_ranks, const int64_t local_mem_size, + const shmemx_uniqueid_t *uid, + shmem_init_attr_t **shmem_attr) { + /* Save to uid_args */ + *shmem_attr = &g_attr; + shmemx_bootstrap_uid_state_t *uid_args = (shmemx_bootstrap_uid_state_t *)(uid); + uid_args->rank = my_rank; + uid_args->nranks = n_ranks; + void * comm_args = reinterpret_cast(uid_args); + g_attr.comm_args = comm_args; + g_attr.my_rank = my_rank; + g_attr.n_ranks = n_ranks; + g_attr.local_mem_size = local_mem_size; +#ifdef BACKEND_MF + std::string ipPort; + if (uid_args->addr.type == ADDR_IPv6) { + char ipStr[INET6_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET6, &(uid_args->addr.addr.addr6.sin6_addr), ipStr, sizeof(ipStr)) == nullptr) { + SHM_LOG_ERROR("inet_ntop failed for IPv6"); + return SHMEM_INNER_ERROR; } + uint16_t port = ntohs(uid_args->addr.addr.addr6.sin6_port); + ipPort = "tcp6://[" + std::string(ipStr) + "]:" + std::to_string(port); } else { - ret = shm::bind_tcp_port_v6(sockfd, envPort, innerUId, ip_str); - if (ret == 0) { - innerUId->inner_sockFd = sockfd; - return 0; - } - } - SHM_LOG_ERROR("init with env port fialed " << envPort << ", ret=" << ret); - return ret; -} - -int32_t ParseInterfaceWithType(const char *ipInfo, char *IP, sa_family_t &sockType, bool &flag) -{ - const char *delim = ":"; - const char *sep = strchr(ipInfo, delim[0]); - if (sep != nullptr) { - size_t leftLen = sep - ipInfo; - if (leftLen >= shm::MAX_IFCONFIG_LENGTH - 1 || leftLen == 0) { - return SHMEM_INVALID_VALUE; + char ipStr[INET_ADDRSTRLEN] = {0}; + if (inet_ntop(AF_INET, &(uid_args->addr.addr.addr4.sin_addr), ipStr, sizeof(ipStr)) == nullptr) { + SHM_LOG_ERROR("inet_ntop failed for IPv4"); + return SHMEM_INNER_ERROR; } - std::copy_n(ipInfo, leftLen, IP); - IP[leftLen] = '\0'; - sockType = (strcmp(sep + 1, "inet6") != 0) ? AF_INET : AF_INET6; - flag = true; + uint16_t port = ntohs(uid_args->addr.addr.addr4.sin_port); + ipPort = "tcp://" + std::string(ipStr) + ":" + std::to_string(port); } + std::copy(ipPort.begin(), ipPort.end(), g_ipport); + std::copy(ipPort.begin(), ipPort.end(), g_attr.ip_port); + g_ipport[ipPort.size()] = '\0'; + g_attr.ip_port[ipPort.size()] = '\0'; + int attr_version = static_cast((1 << 16) + sizeof(shmem_init_attr_t)); + g_attr.option_attr = {attr_version, SHMEM_DATA_OP_MTE, DEFAULT_TIMEOUT, + DEFAULT_TIMEOUT, DEFAULT_TIMEOUT, 0}; + g_attr.option_attr.sockFd = uid_args->inner_sockFd; + SHM_LOG_INFO("extract ip port:" << ipPort); +#endif + g_attr_init = true; return SHMEM_SUCCESS; } -int32_t shmem_auto_get_ip(struct sockaddr *ifaAddr, char *local, sa_family_t &sockType) +int32_t shmem_init_attr(shmemx_bootstrap_t bootstrap_flags, shmem_init_attr_t *attributes) { - sockType = ifaAddr->sa_family; - if (sockType == AF_INET) { - auto localIp = reinterpret_cast(ifaAddr)->sin_addr; - if (inet_ntop(sockType, &localIp, local, shm::MAX_IP) == nullptr) { - SHM_LOG_ERROR("convert local ipv4 to string failed. "); - return SHMEM_INVALID_PARAM; - } - return SHMEM_SUCCESS; - } else if (sockType == AF_INET6) { - auto localIp = reinterpret_cast(ifaAddr)->sin6_addr; - if (inet_ntop(sockType, &localIp, local, shm::MAX_IP) == nullptr) { - SHM_LOG_ERROR("convert local ipv6 to string failed. "); - return SHMEM_INVALID_PARAM; - } - return SHMEM_SUCCESS; - } - return SHMEM_INVALID_PARAM; + int32_t ret; + SHMEM_CHECK_RET(shmem_set_log_level(shm::ERROR_LEVEL)); + + // config init + SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); + SHMEM_CHECK_RET(check_attr(attributes), "An error occurred while checking the initialization attributes. Please check the initialization parameters."); + SHMEM_CHECK_RET(version_compatible(), "SHMEM Version mismatch."); + SHMEM_CHECK_RET(shmemi_options_init()); + + // shmem basic init +#ifdef BACKEND_MF + SHM_LOG_INFO("The current backend is MF."); + SHMEM_CHECK_RET(bootstrap_flags != SHMEMX_INIT_WITH_DEFAULT, "The current backend is MF, and the value of bootstrap_flags only supports SHMEMX_INIT_WITH_DEFAULT.", SHMEM_INVALID_PARAM); + init_manager = new shmemi_init_mf(attributes, g_ipport, &g_state); +#else + SHM_LOG_INFO("The current backend is SHMEM default."); + // bootstrap init + SHMEM_CHECK_RET(shmemi_bootstrap_init(bootstrap_flags, attributes)); + init_manager = new shmemi_init_default(attributes, &g_state); +#endif + SHMEM_CHECK_RET(shmemi_state_init_attr(attributes)); + SHMEM_CHECK_RET(init_manager->init_device_state()); + SHMEM_CHECK_RET(init_manager->reserve_heap()); + SHMEM_CHECK_RET(init_manager->transport_init()); + SHMEM_CHECK_RET(init_manager->setup_heap()); + + // shmem submodules init + SHMEM_CHECK_RET(memory_manager_initialize(g_state.heap_base, g_state.heap_size)); + SHMEM_CHECK_RET(shmemi_team_init(g_state.mype, g_state.npes)); + SHMEM_CHECK_RET(shmemi_sync_init()); + g_state.is_shmem_initialized = true; + SHMEM_CHECK_RET(update_device_state()); + SHMEM_CHECK_RET(shmemi_control_barrier_all()); + SHM_LOG_INFO("SHMEM init success."); + return SHMEM_SUCCESS; } -bool shmem_check_ifa(struct ifaddrs *ifa, sa_family_t sockType, bool flag, char *ifaName, size_t ifaLen) +int32_t shmem_finalize() { - if (ifa->ifa_addr == nullptr || ifa->ifa_netmask == nullptr || ifa->ifa_name == nullptr) { - SHM_LOG_DEBUG("loop ifa_addr/ifa_netmask/ifa_name is nullptr"); - return false; - } - - // socket type match and input env ifa valid - if (ifa->ifa_addr->sa_family != sockType && flag) { - SHM_LOG_DEBUG("sa family is not match, get " << ifa->ifa_addr->sa_family << ", expect " << sockType); - return false; - } + SHM_LOG_INFO("The pe: " << shmem_my_pe() << " begins to finalize."); + // shmem submodules finalize + SHMEM_CHECK_RET(shmemi_team_finalize()); - // prefix match with input ifa name - if (strncmp(ifa->ifa_name, ifaName, ifaLen) != 0) { - SHM_LOG_DEBUG("ifa name prefix un-match, get " << ifa->ifa_name << ", expect " << ifaName); - return false; - } + // shmem basic finalize + SHMEM_CHECK_RET(init_manager->remove_heap()); + SHMEM_CHECK_RET(init_manager->transport_finalize()); + SHMEM_CHECK_RET(init_manager->release_heap()); + SHMEM_CHECK_RET(init_manager->finalize_device_state()); + delete init_manager; - // ignore ifa which is down or loopback or not running - if ((ifa->ifa_flags & IFF_LOOPBACK) || !(ifa->ifa_flags & IFF_RUNNING) || !(ifa->ifa_flags & IFF_UP)) { - SHM_LOG_DEBUG("ifa flag un-match, flag=" << ifa->ifa_flags); - return false; - } +#ifdef BACKEND_MF - if (sockType == AF_INET6) { - struct sockaddr_in6 *sa6 = reinterpret_cast(ifa->ifa_addr); - if (IN6_IS_ADDR_LINKLOCAL(&sa6->sin6_addr)) { - SHM_LOG_DEBUG("ifa is scope link addr " << ifaName); - return false; - } - } - return true; +#else + shmemi_bootstrap_finalize(); +#endif + SHM_LOG_INFO("The pe: " << shmem_my_pe() << " finalize success."); + return SHMEM_SUCCESS; } -int32_t shmem_get_ip_from_ifa(char *local, sa_family_t &sockType, const string ipInfo) +void shmem_info_get_version(int *major, int *minor) { - struct ifaddrs *ifaddr; - char ifaName[shm::MAX_IFCONFIG_LENGTH]; - sockType = AF_INET; - bool flag = false; - if (ipInfo.empty()) { - std::copy_n("eth", shm::DEFAULT_IFNAME_LNEGTH, ifaName); - ifaName[shm::DEFAULT_IFNAME_LNEGTH - 1] = '\0'; - SHM_LOG_INFO("use default if to find IP:" << ifaName); - } else if (ParseInterfaceWithType(ipInfo.c_str(), ifaName, sockType, flag) != SHMEM_SUCCESS) { - SHM_LOG_ERROR("IP size set in SHMEM_CONF_STORE_MASTER_IF format has wrong length"); - return SHMEM_INVALID_PARAM; - } - if (getifaddrs(&ifaddr) == -1) { - SHM_LOG_ERROR("get local net interfaces failed: " << errno); - return SHMEM_INVALID_PARAM; - } - int32_t result = SHMEM_INVALID_PARAM; - const int IP_STR_BUFFER_SIZE = 64; - for (auto ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { - if (!shmem_check_ifa(ifa, sockType, flag, ifaName, strlen(ifaName))) { - continue; - } - if (sockType == AF_INET && flag) { - auto localIp = reinterpret_cast(ifa->ifa_addr)->sin_addr; - if (inet_ntop(sockType, &localIp, local, IP_STR_BUFFER_SIZE) == nullptr) { - SHM_LOG_ERROR("convert local ipv4 to string failed. "); - continue; - } - result = SHMEM_SUCCESS; - break; - } else if (sockType == AF_INET6 && flag) { - auto localIp = reinterpret_cast(ifa->ifa_addr)->sin6_addr; - if (inet_ntop(sockType, &localIp, local, IP_STR_BUFFER_SIZE) == nullptr) { - SHM_LOG_ERROR("convert local ipv6 to string failed. "); - continue; - } - result = SHMEM_SUCCESS; - break; - } else { - auto ret = shmem_auto_get_ip(ifa->ifa_addr, local, sockType); - if (ret != SHMEM_SUCCESS) { - continue; - } - result = SHMEM_SUCCESS; - break; - } - } - freeifaddrs(ifaddr); - return result; + SHM_ASSERT_RET_VOID(major != nullptr && minor != nullptr); + *major = SHMEM_MAJOR_VERSION; + *minor = SHMEM_MINOR_VERSION; } -int32_t shmem_get_ip_from_env(char *ip, uint16_t &port, sa_family_t &sockType, const string ipPort) +void shmem_info_get_name(char *name) { - if (!ipPort.empty()) { - SHM_LOG_DEBUG("get env SHMEM_UID_SESSION_ID value:" << ipPort); - std::string ipPortStr = ipPort; - - if (ipPort[0] == '[') { - sockType = AF_INET6; - size_t found = ipPortStr.find_last_of(']'); - if (found == std::string::npos || ipPortStr.length() - found <= 1) { - SHM_LOG_ERROR("get env SHMEM_UID_SESSION_ID is invalid"); - return SHMEM_INVALID_PARAM; - } - std::string ipStr = ipPortStr.substr(1, found - 1); - std::string portStr = ipPortStr.substr(found + 2); - - std::string result = ipStr; - if (result.length() >= shm::MAX_IP) { - SHM_LOG_ERROR("IP address is too long"); - return SHMEM_INVALID_PARAM; - } - std::copy(result.begin(), result.end(), ip); - ip[result.length()] = '\0'; - - port = std::stoi(portStr); - } else { - sockType = AF_INET; - size_t found = ipPortStr.find_last_of(':'); - if (found == std::string::npos || ipPortStr.length() - found <= 1) { - SHM_LOG_ERROR("get env SHMEM_UID_SESSION_ID is invalid"); - return SHMEM_INVALID_PARAM; - } - std::string ipStr = ipPortStr.substr(0, found); - std::string portStr = ipPortStr.substr(found + 1); - - std::string result = ipStr; - if (result.length() >= shm::MAX_IP) { - SHM_LOG_ERROR("IP address is too long"); - return SHMEM_INVALID_PARAM; - } - std::copy(result.begin(), result.end(), ip); - ip[result.length()] = '\0'; - - port = std::stoi(portStr); - } - return SHMEM_SUCCESS; + SHM_ASSERT_RET_VOID(name != nullptr); + std::ostringstream oss; + oss << "SHMEM v" << SHMEM_VENDOR_MAJOR_VER << "." << SHMEM_VENDOR_MINOR_VER << "." << SHMEM_VENDOR_PATCH_VER; + auto version_str = oss.str(); + size_t i; + for (i = 0; i < SHMEM_MAX_NAME_LEN - 1 && version_str[i] != '\0'; i++) { + name[i] = version_str[i]; } - return SHMEM_INVALID_PARAM; + name[i] = '\0'; } -int32_t shmem_set_ip_info(shmem_uniqueid_t *uid, sa_family_t &sockType, char *pta_env_ip, uint16_t pta_env_port, - bool is_from_ifa) +int32_t shmem_get_uniqueid_default(shmemx_uniqueid_t *uid) { - // init default uid - SHM_ASSERT_RETURN(uid != nullptr, SHMEM_INVALID_PARAM); - *uid = SHMEM_UNIQUEID_INITIALIZER; - shmem_uniqueid_inner_t *innerUID = reinterpret_cast(uid); - if (sockType == AF_INET) { - innerUID->addr.addr.addr4.sin_family = AF_INET; - if (inet_pton(AF_INET, pta_env_ip, &(innerUID->addr.addr.addr4.sin_addr)) <= 0) { - SHM_LOG_ERROR("inet_pton IPv4 failed"); - return SHMEM_NOT_INITED; - } - innerUID->addr.type = ADDR_IPv4; - } else if (sockType == AF_INET6) { - innerUID->addr.addr.addr6.sin6_family = AF_INET6; - if (inet_pton(AF_INET6, pta_env_ip, &(innerUID->addr.addr.addr6.sin6_addr)) <= 0) { - SHM_LOG_ERROR("inet_pton IPv6 failed"); - return SHMEM_NOT_INITED; - } - innerUID->addr.type = ADDR_IPv6; - } else { - SHM_LOG_ERROR("IP Type is not IPv4 or IPv6"); - return SHMEM_INVALID_PARAM; - } + int status = 0; + SHMEM_CHECK_RET(shmemi_options_init(), "Bootstrap failed during the preloading step."); + SHMEM_CHECK_RET(shmemi_bootstrap_pre_init(SHMEMX_INIT_WITH_UNIQUEID, &g_boot_handle), "Get uniqueid failed during the bootstrap preloading step."); - // fill ip port as part of uid - if (is_from_ifa) { - int32_t ret = shmem_get_port_magic(innerUID, pta_env_ip); - if (ret != 0) { - SHM_LOG_ERROR("get available port failed."); - return SHMEM_INVALID_PARAM; - } + if (g_boot_handle.pre_init_ops) { + SHMEM_CHECK_RET(g_boot_handle.pre_init_ops->get_unique_id((void *)uid), "Get uniqueid failed during the get uniqueid step."); } else { - int32_t ret = shmem_using_env_port(innerUID, pta_env_ip, pta_env_port); - if (ret != 0) { - SHM_LOG_ERROR("using env port failed."); - return SHMEM_INVALID_PARAM; - } + SHM_LOG_ERROR("Pre_init_ops is empty, unique_id cannot be obtained."); + status = SHMEM_INVALID_PARAM; } - SHM_LOG_INFO("gen unique id success."); - return SHMEM_SUCCESS; + return (status); } -int32_t shmem_get_uniqueid(shmem_uniqueid_t *uid) -{ - if (shmem_set_log_level(shm::WARN_LEVEL) != 0) { - SHM_LOG_ERROR("failed to set log level"); - return SHMEM_INNER_ERROR; - } - char pta_env_ip[shm::MAX_IP]; - uint16_t pta_env_port{}; - sa_family_t sockType; - const char *ipPortInput = std::getenv("SHMEM_UID_SESSION_ID"); - const char *ipInfoInput = std::getenv("SHMEM_UID_SOCK_IFNAM"); - const string ipPort = ipPortInput ? ipPortInput : ""; - const string ipInfo = ipInfoInput ? ipInfoInput : ""; - bool is_from_ifa = false; - if (!ipPort.empty()) { - if (shmem_get_ip_from_env(pta_env_ip, pta_env_port, sockType, ipPort) != SHMEM_SUCCESS) { - SHM_LOG_ERROR("cant get pta master addr."); - return SHMEM_INVALID_PARAM; - } - } else { - is_from_ifa = true; - if (shmem_get_ip_from_ifa(pta_env_ip, sockType, ipInfo) != SHMEM_SUCCESS) { - SHM_LOG_ERROR("cant get available ip port."); - return SHMEM_INVALID_PARAM; - } - } - SHM_LOG_INFO("get master IP value:" << pta_env_ip); - return shmem_set_ip_info(uid, sockType, pta_env_ip, pta_env_port, is_from_ifa); +int32_t shmem_get_uniqueid(shmemx_uniqueid_t *uid){ + shmem_set_log_level(shm::ERROR_LEVEL); + *uid = SHMEM_UNIQUEID_INITIALIZER; +#ifdef BACKEND_MF + SHMEM_CHECK_RET(shmem_get_uniqueid_mf(uid), "shmem_get_uniqueid failed, backend: mf"); + return SHMEM_SUCCESS; +#else + SHMEM_CHECK_RET(shmem_get_uniqueid_default(uid), "shmem_get_uniqueid failed, backend: default"); + return SHMEM_SUCCESS; +#endif } -int32_t shmem_set_attr_uniqueid_args(int rank_id, int nranks, const shmem_uniqueid_t *uid, shmem_init_attr_t *attr) -{ - if (attr == nullptr || uid == nullptr) { - SHM_LOG_ERROR("set unique id attr/uid is null"); - return SHMEM_INVALID_PARAM; - } - - if (rank_id != shm::g_attr.my_rank || nranks != shm::g_attr.n_ranks) { - SHM_LOG_ERROR("rankid/nranks invalid, maybe call shmem_set_attr firstly."); - return SHMEM_INVALID_PARAM; - } - - if (uid->version != SHMEM_UNIQUEID_VERSION) { - SHM_LOG_ERROR("uid version invalid, init unique id with shmem_get_uniqueid firstly."); - return SHMEM_INVALID_PARAM; - } - - // extract ip port from inner unique id - shmem_uniqueid_inner_t *innerUID = reinterpret_cast(const_cast(uid)); +int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root) { + shmem_set_log_level(shm::ERROR_LEVEL); + *uid = SHMEM_UNIQUEID_INITIALIZER; + int status = 0; + SHMEM_CHECK_RET(shmemi_options_init(), "Bootstrap failed during the preloading step."); + SHMEM_CHECK_RET(shmemi_bootstrap_pre_init(SHMEMX_INIT_WITH_UNIQUEID, &g_boot_handle), "Get uniqueid failed during the bootstrap preloading step."); - // compatibility with shmem_init_attr, init ip_port from unique id - std::string ipPort; - if (innerUID->addr.type == ADDR_IPv6) { - char ipStr[INET6_ADDRSTRLEN] = {0}; - if (inet_ntop(AF_INET6, &(innerUID->addr.addr.addr6.sin6_addr), ipStr, sizeof(ipStr)) == nullptr) { - SHM_LOG_ERROR("inet_ntop failed for IPv6"); - return SHMEM_INNER_ERROR; - } - uint16_t port = ntohs(innerUID->addr.addr.addr6.sin6_port); - ipPort = "tcp6://[" + std::string(ipStr) + "]:" + std::to_string(port); + if (g_boot_handle.pre_init_ops) { + SHMEM_CHECK_RET(g_boot_handle.pre_init_ops->get_unique_id_static_magic((void *)uid, is_root), "Get uniqueid failed during the get uniqueid step."); } else { - char ipStr[INET_ADDRSTRLEN] = {0}; - if (inet_ntop(AF_INET, &(innerUID->addr.addr.addr4.sin_addr), ipStr, sizeof(ipStr)) == nullptr) { - SHM_LOG_ERROR("inet_ntop failed for IPv4"); - return SHMEM_INNER_ERROR; - } - uint16_t port = ntohs(innerUID->addr.addr.addr4.sin_port); - ipPort = "tcp://" + std::string(ipStr) + ":" + std::to_string(port); + SHM_LOG_ERROR("Pre_init_ops is empty, unique_id cannot be obtained."); + status = SHMEM_INVALID_PARAM; } - std::copy(ipPort.begin(), ipPort.end(), shm::g_ipport); - std::copy(ipPort.begin(), ipPort.end(), shm::g_attr.ip_port); - std::copy(ipPort.begin(), ipPort.end(), attr->ip_port); - shm::g_ipport[ipPort.size()] = '\0'; - shm::g_attr.ip_port[ipPort.size()] = '\0'; - attr->ip_port[ipPort.size()] = '\0'; - attr->option_attr.sockFd = innerUID->inner_sockFd; - SHM_LOG_INFO("extract ip port:" << ipPort); - - int32_t status = shmem_init_attr(attr); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("shmem_init_attr failed"); - return status; - } - return SHMEM_SUCCESS; -} - -int32_t shmem_init_status(void) -{ - if (!shm::g_state.is_shmem_created) - return SHMEM_STATUS_NOT_INITIALIZED; - else if (!shm::g_state.is_shmem_initialized) - return SHMEM_STATUS_SHM_CREATED; - else if (shm::g_state.is_shmem_initialized) - return SHMEM_STATUS_IS_INITIALIZED; - else - return SHMEM_STATUS_INVALID; -} - -void shmem_rank_exit(int status) -{ - SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); - exit(status); -} - -int32_t shmem_init_attr(shmem_init_attr_t *attributes) -{ - int32_t ret; - - SHM_ASSERT_RETURN(attributes != nullptr, SHMEM_INVALID_PARAM); - SHMEM_CHECK_RET(shmem_set_log_level(shm::WARN_LEVEL), shmem_set_log_level); - SHMEM_CHECK_RET(shm::check_attr(attributes), check_attr); - SHMEM_CHECK_RET(shm::version_compatible(), version_compatible); - SHMEM_CHECK_RET(shm::shmemi_options_init(), shmemi_options_init); - - SHMEM_CHECK_RET(shm::shmemi_state_init_attr(attributes), shmemi_state_init_attr); - SHMEM_CHECK_RET(shm::shmemi_heap_init(attributes), shmemi_heap_init); - SHMEM_CHECK_RET(shm::update_device_state(), update_device_state); - - SHMEM_CHECK_RET(shm::memory_manager_initialize(shm::g_state.heap_base, shm::g_state.heap_size), - memory_manager_initialize); - SHMEM_CHECK_RET(shm::shmemi_team_init(shm::g_state.mype, shm::g_state.npes), shmemi_team_init); - SHMEM_CHECK_RET(shm::update_device_state(), update_device_state); - SHMEM_CHECK_RET(shm::shmemi_sync_init(), shmemi_sync_init); - SHMEM_CHECK_RET(smem_shm_register_exit(shm::g_smem_handle, &shmem_rank_exit), smem_shm_register_exit); - shm::g_state.is_shmem_initialized = true; - SHMEM_CHECK_RET(shm::shmemi_control_barrier_all(), shmemi_control_barrier_all); return SHMEM_SUCCESS; } -int32_t shmem_set_config_store_tls_key(const char *tls_pk, const uint32_t tls_pk_len, - const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const shmem_decrypt_handler decrypt_handler) -{ - return smem_set_config_store_tls_key(tls_pk, tls_pk_len, tls_pk_pw, tls_pk_pw_len, decrypt_handler); -} - -int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) -{ - SHM_ASSERT_RETURN(func != nullptr, SHMEM_INVALID_PARAM); - shm::shm_out_logger::Instance().set_extern_log_func(func, true); - return smem_set_extern_logger(func); -} int32_t shmem_set_log_level(int level) { @@ -800,76 +384,53 @@ int32_t shmem_set_log_level(int level) level = shm::FATAL_LEVEL; } } - shm::shm_out_logger::Instance().set_log_level(static_cast(level)); - if (smem_set_log_level(level) != SHMEM_SUCCESS) { - SHM_LOG_ERROR("Failed to set ock::mf::OutLogger level"); - } - return SHMEM_SUCCESS; + #ifdef BACKEND_MF + smem_set_log_level(level); + #endif + + return shm::shm_out_logger::Instance().set_log_level(static_cast(level)); } int32_t shmem_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len) { +#ifdef BACKEND_MF return smem_set_conf_store_tls(enable, tls_info, tls_info_len); +#else + return SHMEM_SUCCESS; +#endif } -int32_t shmem_finalize(void) +void shmem_rank_exit(int status) { - SHMEM_CHECK_RET(shm::shmemi_team_finalize()); - - if (shm::g_state.p2p_heap_host_base != nullptr) { - aclrtFree(shm::g_state.p2p_heap_host_base); - } - if (shm::g_state.sdma_heap_host_base != nullptr) { - aclrtFree(shm::g_state.sdma_heap_host_base); - } - if (shm::g_state.roce_heap_host_base != nullptr) { - aclrtFree(shm::g_state.roce_heap_host_base); - } - - if (shm::g_state.p2p_heap_device_base != nullptr) { - aclrtFree(shm::g_state.p2p_heap_device_base); - } - if (shm::g_state.sdma_heap_device_base != nullptr) { - aclrtFree(shm::g_state.sdma_heap_device_base); - } - if (shm::g_state.roce_heap_device_base != nullptr) { - aclrtFree(shm::g_state.roce_heap_device_base); - } - - if (shm::g_smem_handle != nullptr) { - int32_t status = smem_shm_destroy(shm::g_smem_handle, 0); - if (status != SHMEM_SUCCESS) { - SHM_LOG_ERROR("smem_shm_destroy Failed"); - return SHMEM_SMEM_ERROR; - } - shm::g_smem_handle = nullptr; - } - smem_shm_uninit(0); - smem_uninit(); - return SHMEM_SUCCESS; + SHM_LOG_DEBUG("shmem_rank_exit is work ,status: " << status); + exit(status); } -void shmem_info_get_version(int *major, int *minor) +int32_t shmem_set_config_store_tls_key(const char *tls_pk, const uint32_t tls_pk_len, + const char *tls_pk_pw, const uint32_t tls_pk_pw_len, const shmem_decrypt_handler decrypt_handler) { - SHM_ASSERT_RET_VOID(major != nullptr && minor != nullptr); - *major = SHMEM_MAJOR_VERSION; - *minor = SHMEM_MINOR_VERSION; +#ifdef BACKEND_MF + return smem_set_config_store_tls_key(tls_pk, tls_pk_len, tls_pk_pw, tls_pk_pw_len, decrypt_handler); +#else + return SHMEM_SUCCESS; +#endif } -void shmem_info_get_name(char *name) +int32_t shmem_set_extern_logger(void (*func)(int level, const char *msg)) { - SHM_ASSERT_RET_VOID(name != nullptr); - std::ostringstream oss; - oss << "SHMEM v" << SHMEM_VENDOR_MAJOR_VER << "." << SHMEM_VENDOR_MINOR_VER << "." << SHMEM_VENDOR_PATCH_VER; - auto version_str = oss.str(); - size_t i; - for (i = 0; i < SHMEM_MAX_NAME_LEN - 1 && version_str[i] != '\0'; i++) { - name[i] = version_str[i]; - } - name[i] = '\0'; +#ifdef BACKEND_MF + SHM_ASSERT_RETURN(func != nullptr, SHMEM_INVALID_PARAM); + shm::shm_out_logger::Instance().set_extern_log_func(func, true); + return smem_set_extern_logger(func); +#else + return SHMEM_SUCCESS; +#endif } void shmem_global_exit(int status) { - smem_shm_global_exit(shm::g_smem_handle, status); -} +#ifdef BACKEND_MF + shmemi_global_exit_mf(status); +#else +#endif +} \ No newline at end of file diff --git a/src/host/init/shmemi_init.h b/src/host/init/shmemi_init.h index 44908ba7e09e57badc2954358c6da9a723c63a99..b1b46967abe6ea8685284ccb0a3576c7e7763cf2 100644 --- a/src/host/init/shmemi_init.h +++ b/src/host/init/shmemi_init.h @@ -13,14 +13,19 @@ #include "stdint.h" #include "internal/host_device/shmemi_types.h" -namespace shm { +#ifdef BACKEND_MF +#include "init/init_backends/mf/shmemi_init_mf.h" +#else +#include "init/init_backends/default/shmemi_init_default.h" +#endif + extern shmemi_device_host_state_t g_state; extern shmemi_host_state_t g_state_host; -int32_t update_device_state(void); - int32_t shmemi_control_barrier_all(); -} // namespace shm +int32_t update_device_state(void); + +int32_t shmemi_get_uniqueid_static_magic(shmemx_uniqueid_t *uid, bool is_root); #endif // SHMEMI_INIT_H diff --git a/src/host/mem/shmem_mm.cpp b/src/host/mem/shmem_mm.cpp index ffe3e26448770619079da0043db29675f6ec673a..ce39da45083c14b38c66a9461ff1a4dc2f1a136f 100644 --- a/src/host/mem/shmem_mm.cpp +++ b/src/host/mem/shmem_mm.cpp @@ -10,17 +10,307 @@ #include #include "acl/acl.h" #include "shmemi_host_common.h" -#include "shmemi_mm_heap.h" -namespace shm { +bool range_size_first_comparator::operator()(const memory_range &mr1, const memory_range &mr2) const noexcept +{ + if (mr1.size != mr2.size) { + return mr1.size < mr2.size; + } + + return mr1.offset < mr2.offset; +} + +memory_heap::memory_heap(void *base, uint64_t size) noexcept : base_{reinterpret_cast(base)}, size_{size} +{ + pthread_spin_init(&spinlock_, 0); + address_idle_tree_[0] = size; + size_idle_tree_.insert({0, size}); +} + +memory_heap::~memory_heap() noexcept +{ + pthread_spin_destroy(&spinlock_); +} + +void *memory_heap::allocate(uint64_t size) noexcept +{ + if (size == 0 || size > g_state.heap_size) { + SHM_LOG_ERROR("cannot allocate with size " << size); + return nullptr; + } + + auto aligned_size = allocated_size_align_up(size); + memory_range anchor{0, aligned_size}; + + pthread_spin_lock(&spinlock_); + auto size_pos = size_idle_tree_.lower_bound(anchor); + if (size_pos == size_idle_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("cannot allocate with size: " << size); + return nullptr; + } + + auto target_offset = size_pos->offset; + auto target_size = size_pos->size; + auto addr_pos = address_idle_tree_.find(target_offset); + if (addr_pos == address_idle_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("offset(" << target_offset << ") size(" << target_size << ") in size tree, not in address tree."); + return nullptr; + } + + size_idle_tree_.erase(size_pos); + address_idle_tree_.erase(addr_pos); + address_used_tree_.emplace(target_offset, aligned_size); + if (target_size > aligned_size) { + memory_range left{target_offset + aligned_size, target_size - aligned_size}; + address_idle_tree_.emplace(left.offset, left.size); + size_idle_tree_.emplace(left); + } + pthread_spin_unlock(&spinlock_); + + return base_ + target_offset; +} + +void *memory_heap::aligned_allocate(uint64_t alignment, uint64_t size) noexcept +{ + if (size == 0 || alignment == 0 || size > g_state.heap_size) { + SHM_LOG_ERROR("invalid input, align=" << alignment << ", size=" << size); + return nullptr; + } + + if ((alignment & (alignment - 1UL)) != 0) { + SHM_LOG_ERROR("alignment should be power of 2, but real " << alignment); + return nullptr; + } + + uint64_t head_skip = 0; + auto aligned_size = allocated_size_align_up(size); + memory_range anchor{0, aligned_size}; + + pthread_spin_lock(&spinlock_); + auto size_pos = size_idle_tree_.lower_bound(anchor); + while (size_pos != size_idle_tree_.end() && !alignment_matches(*size_pos, alignment, aligned_size, head_skip)) { + ++size_pos; + } + + if (size_pos == size_idle_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("cannot allocate with size: " << size << ", alignment: " << alignment); + return nullptr; + } + + auto target_offset = size_pos->offset; + auto target_size = size_pos->size; + memory_range result_range{size_pos->offset + head_skip, aligned_size}; + size_idle_tree_.erase(size_pos); + + if (head_skip > 0) { + size_idle_tree_.emplace(memory_range{target_offset, head_skip}); + address_idle_tree_.emplace(target_offset, head_skip); + } + + if (head_skip + aligned_size < target_size) { + memory_range leftMR{target_offset + head_skip + aligned_size, target_size - head_skip - aligned_size}; + size_idle_tree_.emplace(leftMR); + address_idle_tree_.emplace(leftMR.offset, leftMR.size); + } + + address_used_tree_.emplace(result_range.offset, result_range.size); + pthread_spin_unlock(&spinlock_); + + return base_ + result_range.offset; +} + +bool memory_heap::change_size(void *address, uint64_t size) noexcept +{ + auto u8a = reinterpret_cast(address); + if (u8a < base_ || u8a >= base_ + size_) { + SHM_LOG_ERROR("release invalid address " << address); + return false; + } + + if (size == 0) { + release(address); + return true; + } + + auto offset = u8a - base_; + pthread_spin_lock(&spinlock_); + auto pos = address_used_tree_.find(offset); + if (pos == address_used_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("change size for address " << address << " not allocated."); + return false; + } + + // size不变 + if (pos->second == size) { + pthread_spin_unlock(&spinlock_); + return true; + } + + // 缩小size + if (pos->second > size) { + reduce_size_in_lock(pos, size); + pthread_spin_unlock(&spinlock_); + return true; + } + + // 扩大size + auto success = expend_size_in_lock(pos, size); + pthread_spin_unlock(&spinlock_); + + return success; +} + +int32_t memory_heap::release(void *address) noexcept +{ + auto u8a = reinterpret_cast(address); + if (u8a < base_ || u8a >= base_ + size_) { + SHM_LOG_ERROR("release invalid address " << address); + return -1; + } + + auto offset = u8a - base_; + pthread_spin_lock(&spinlock_); + auto pos = address_used_tree_.find(offset); + if (pos == address_used_tree_.end()) { + pthread_spin_unlock(&spinlock_); + SHM_LOG_ERROR("release address " << address << " not allocated."); + return -1; + } + + auto size = pos->second; + uint64_t final_offset = static_cast(offset); + uint64_t final_size = size; + address_used_tree_.erase(pos); + + auto prev_addr_pos = address_idle_tree_.lower_bound(offset); + if (prev_addr_pos != address_idle_tree_.begin()) { + --prev_addr_pos; + if (prev_addr_pos != address_idle_tree_.end() && + prev_addr_pos->first + prev_addr_pos->second == static_cast(offset)) { + // 合并前一个range + final_offset = prev_addr_pos->first; + final_size += prev_addr_pos->second; + + auto prev_addr_range = *prev_addr_pos; + address_idle_tree_.erase(prev_addr_pos); + size_idle_tree_.erase(memory_range{prev_addr_range.first, prev_addr_range.second}); + } + } + + auto next_addr_pos = address_idle_tree_.find(offset + size); + if (next_addr_pos != address_idle_tree_.end()) { // 合并后一个range + uint64_t next_addr = next_addr_pos->first; + uint64_t next_size = next_addr_pos->second; + final_size += next_size; + address_idle_tree_.erase(next_addr_pos); + size_idle_tree_.erase(memory_range{next_addr, next_size}); + } + address_idle_tree_.emplace(final_offset, final_size); + size_idle_tree_.emplace(memory_range{final_offset, final_size}); + pthread_spin_unlock(&spinlock_); + + return 0; +} + +bool memory_heap::allocated_size(void *address, uint64_t &size) const noexcept +{ + auto u8a = reinterpret_cast(address); + if (u8a < base_ || u8a >= base_ + size_) { + SHM_LOG_ERROR("release invalid address " << address); + return false; + } + + auto offset = u8a - base_; + bool exist = false; + pthread_spin_lock(&spinlock_); + auto pos = address_used_tree_.find(offset); + if (pos != address_used_tree_.end()) { + exist = true; + size = pos->second; + } + pthread_spin_unlock(&spinlock_); + + return exist; +} + +uint64_t memory_heap::allocated_size_align_up(uint64_t input_size) noexcept +{ + constexpr uint64_t align_size = 16UL; + constexpr uint64_t align_size_mask = ~(align_size - 1UL); + return (input_size + align_size - 1UL) & align_size_mask; +} + +bool memory_heap::alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, + uint64_t &head_skip) noexcept +{ + if (mr.size < size) { + return false; + } + + if ((mr.offset & (alignment - 1UL)) == 0UL) { + head_skip = 0; + return true; + } + + auto aligned_offset = ((mr.offset + alignment - 1UL) & (~(alignment - 1UL))); + head_skip = aligned_offset - mr.offset; + return mr.size >= size + head_skip; +} + +void memory_heap::reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept +{ + auto offset = pos->first; + auto old_size = pos->second; + pos->second = new_size; + auto next_addr_pos = address_idle_tree_.find(offset + old_size); + if (next_addr_pos == address_idle_tree_.end()) { + address_idle_tree_.emplace(offset + new_size, old_size - new_size); + size_idle_tree_.emplace(memory_range{offset + new_size, old_size - new_size}); + } else { + auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); + size_idle_tree_.erase(next_size_pos); + next_addr_pos->second += (old_size - new_size); + size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); + } +} + +bool memory_heap::expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept +{ + auto offset = pos->first; + auto old_size = pos->second; + auto delta = new_size - old_size; + + auto next_addr_pos = address_idle_tree_.find(offset + old_size); + if (next_addr_pos == address_idle_tree_.end() || next_addr_pos->second < delta) { + return false; + } + + pos->second = new_size; + auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); + if (next_addr_pos->second == delta) { + size_idle_tree_.erase(next_size_pos); + address_idle_tree_.erase(next_addr_pos); + } else { + size_idle_tree_.erase(next_size_pos); + next_addr_pos->second -= delta; + size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); + } + + return true; +} + namespace { -std::shared_ptr shm_memory_heap; +std::shared_ptr shmemi_memory_manager; } int32_t memory_manager_initialize(void *base, uint64_t size) { - shm_memory_heap = std::make_shared(base, size); - if (shm_memory_heap == nullptr) { + shmemi_memory_manager = std::make_shared(base, size); + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Failed to initialize shared memory heap"); return SHMEM_INNER_ERROR; } @@ -29,24 +319,23 @@ int32_t memory_manager_initialize(void *base, uint64_t size) void memory_manager_destroy() { - shm_memory_heap.reset(); + shmemi_memory_manager.reset(); } -} // namespace shm void *shmem_malloc(size_t size) { - if (shm::shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - void *ptr = shm::shm_memory_heap->allocate(size); + void *ptr = shmemi_memory_manager->allocate(size); SHM_LOG_DEBUG("shmem_malloc(" << size << ")"); - auto ret = shm::shmemi_control_barrier_all(); + auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("malloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { - shm::shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -55,28 +344,28 @@ void *shmem_malloc(size_t size) void *shmem_calloc(size_t nmemb, size_t size) { - if (shm::shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - SHM_ASSERT_MULTIPLY_OVERFLOW(nmemb, size, shm::g_state.heap_size, nullptr); + SHM_ASSERT_MULTIPLY_OVERFLOW(nmemb, size, g_state.heap_size, nullptr); auto total_size = nmemb * size; - auto ptr = shm::shm_memory_heap->allocate(total_size); + auto ptr = shmemi_memory_manager->allocate(total_size); if (ptr != nullptr) { auto ret = aclrtMemset(ptr, size, 0, size); if (ret != 0) { SHM_LOG_ERROR("shmem_calloc(" << nmemb << ", " << size << ") memset failed: " << ret); - shm::shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } - auto ret = shm::shmemi_control_barrier_all(); + auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("calloc mem barrier failed, ret: " << ret); if (ptr != nullptr) { - shm::shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -87,17 +376,17 @@ void *shmem_calloc(size_t nmemb, size_t size) void *shmem_align(size_t alignment, size_t size) { - if (shm::shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return nullptr; } - auto ptr = shm::shm_memory_heap->aligned_allocate(alignment, size); - auto ret = shm::shmemi_control_barrier_all(); + auto ptr = shmemi_memory_manager->aligned_allocate(alignment, size); + auto ret = shmemi_control_barrier_all(); if (ret != 0) { SHM_LOG_ERROR("shmem_align barrier failed, ret: " << ret); if (ptr != nullptr) { - shm::shm_memory_heap->release(ptr); + shmemi_memory_manager->release(ptr); ptr = nullptr; } } @@ -107,7 +396,7 @@ void *shmem_align(size_t alignment, size_t size) void shmem_free(void *ptr) { - if (shm::shm_memory_heap == nullptr) { + if (shmemi_memory_manager == nullptr) { SHM_LOG_ERROR("Memory Heap Not Initialized."); return; } @@ -115,7 +404,7 @@ void shmem_free(void *ptr) return; } - auto ret = shm::shm_memory_heap->release(ptr); + auto ret = shmemi_memory_manager->release(ptr); if (ret != 0) { SHM_LOG_ERROR("release failed: " << ret); } diff --git a/src/host/mem/shmem_rma.cpp b/src/host/mem/shmem_rma.cpp index 40d5ed67f6abca37aa63d0e1fe3cc226721530bb..9893320cb9403aae4f336ac778d4f0185f31d728 100644 --- a/src/host/mem/shmem_rma.cpp +++ b/src/host/mem/shmem_rma.cpp @@ -21,31 +21,31 @@ void *shmem_ptr(void *ptr, int32_t pe) SHM_LOG_ERROR("shmem_ptr Failed. PE: " << shmem_my_pe() << " Got Ilegal PE !!"); return nullptr; } - uint64_t lower_bound = (uint64_t)shm::g_state.heap_base; - uint64_t upper_bound = lower_bound + shm::g_state.heap_size; + uint64_t lower_bound = (uint64_t)g_state.heap_base; + uint64_t upper_bound = lower_bound + g_state.heap_size; if (uint64_t(ptr) < lower_bound || uint64_t(ptr) >= upper_bound) { SHM_LOG_ERROR("shmem_ptr Failed. PE: " << shmem_my_pe() << " Got Ilegal Address !!"); return nullptr; } - uint64_t offset = (uint64_t)ptr - (uint64_t)shm::g_state.heap_base; - void *symm_ptr = shm::g_state.p2p_heap_host_base[pe]; + uint64_t offset = (uint64_t)ptr - (uint64_t)g_state.heap_base; + void *symm_ptr = g_state.host_p2p_heap_base[pe]; if (symm_ptr != nullptr) { symm_ptr = reinterpret_cast(reinterpret_cast(symm_ptr) + offset); return symm_ptr; } SHM_LOG_ERROR("shmem_ptr Failed. PE: " << shmem_my_pe() - << " g_state.p2p_heap_host_base contains nullptr, Please Check Init Status!!"); + << " g_state.host_p2p_heap_base contains nullptr, Please Check Init Status!!"); return nullptr; } // Set Memcpy Interfaces necessary UB Buffer. int32_t shmem_mte_set_ub_params(uint64_t offset, uint32_t ub_size, uint32_t event_id) { - shm::g_state.mte_config.shmem_ub = static_cast(offset); - shm::g_state.mte_config.ub_size = ub_size; - shm::g_state.mte_config.event_id = event_id; - SHMEM_CHECK_RET(shm::update_device_state(), update_device_state); + g_state.mte_config.shmem_ub = offset; + g_state.mte_config.ub_size = ub_size; + g_state.mte_config.event_id = event_id; + SHMEM_CHECK_RET(update_device_state()); return SHMEM_SUCCESS; } @@ -62,7 +62,7 @@ int32_t shmem_mte_set_ub_params(uint64_t offset, uint32_t ub_size, uint32_t even { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem", SHMEMI_OP_PUT, NO_NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -84,7 +84,7 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_PUT) { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem_nbi", SHMEMI_OP_PUT, NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -107,7 +107,7 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_PUT_NBI) { \ int ret = shmemi_prepare_and_post_rma("shmem_get_" #NAME "_mem", SHMEMI_OP_GET, NO_NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -130,7 +130,7 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_GET) { \ int ret = shmemi_prepare_and_post_rma("shmem_get_" #NAME "_mem_nbi", SHMEMI_OP_GET, NBI, (uint8_t *)dest, \ (uint8_t *)source, nelems, sizeof(TYPE), pe, nullptr, 0, 0, 1, 1, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -157,8 +157,8 @@ SHMEM_TYPE_FUNC(SHMEM_TYPE_GET_NBI) { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem_signal", SHMEMI_OP_PUT_SIGNAL, NO_NBI, \ (uint8_t *)dst, (uint8_t *)src, elem_size, sizeof(TYPE), pe, sig_addr, \ - signal, sig_op, 1, 1, shm::g_state_host.default_stream, \ - shm::g_state_host.default_block_num); \ + signal, sig_op, 1, 1, g_state_host.default_stream, \ + g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -185,8 +185,8 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_SIGNAL) { \ int ret = shmemi_prepare_and_post_rma("shmem_put_" #NAME "_mem_signal_nbi", SHMEMI_OP_PUT_SIGNAL, NBI, \ (uint8_t *)dst, (uint8_t *)src, elem_size, sizeof(TYPE), pe, sig_addr, \ - signal, sig_op, 1, 1, shm::g_state_host.default_stream, \ - shm::g_state_host.default_block_num); \ + signal, sig_op, 1, 1, g_state_host.default_stream, \ + g_state_host.default_block_num); \ if (ret < 0) { \ SHM_LOG_ERROR("device calling transfer failed"); \ } \ @@ -206,7 +206,7 @@ SHMEM_TYPE_FUNC(SHMEM_PUT_TYPENAME_MEM_SIGNAL_NBI) SHMEM_HOST_API void shmem_##NAME##_p(TYPE *dst, const TYPE value, int pe) \ { \ shmemi_prepare_and_post_rma_##NAME##_p("shmem_" #NAME "_p", (uint8_t *)dst, value, pe, \ - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); \ + g_state_host.default_stream, g_state_host.default_block_num); \ } SHMEM_TYPE_FUNC(SHMEM_TYPENAME_P) @@ -242,8 +242,8 @@ SHMEM_TYPE_FUNC(SHMEM_TYPENAME_G) void shmem_putmem(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem putmem", SHMEMI_OP_PUT, NO_NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_putmem failed"); } @@ -252,8 +252,8 @@ void shmem_putmem(void *dst, void *src, size_t elem_size, int32_t pe) void shmem_getmem(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem getmem", SHMEMI_OP_GET, NO_NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_getmem failed"); } @@ -262,8 +262,8 @@ void shmem_getmem(void *dst, void *src, size_t elem_size, int32_t pe) void shmem_putmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem_putmem_nbi", SHMEMI_OP_PUT, NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_putmem_nbi failed"); } @@ -272,8 +272,8 @@ void shmem_putmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe) void shmem_getmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe) { int ret = shmemi_prepare_and_post_rma("shmem_getmem_nbi", SHMEMI_OP_GET, NBI, (uint8_t *)dst, (uint8_t *)src, - elem_size, 1, pe, nullptr, 0, 0, 1, 1, shm::g_state_host.default_stream, - shm::g_state_host.default_block_num); + elem_size, 1, pe, nullptr, 0, 0, 1, 1, g_state_host.default_stream, + g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("shmem_getmem_nbi failed"); } @@ -283,7 +283,7 @@ void shmem_putmem_signal_nbi(void *dst, void *src, size_t elem_size, void *sig_a { int ret = shmemi_prepare_and_post_rma("shmem_putmem_signal_nbi", SHMEMI_OP_PUT_SIGNAL, NBI, (uint8_t *)dst, (uint8_t *)src, elem_size, 1, pe, (uint8_t *)sig_addr, signal, sig_op, 1, 1, - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); + g_state_host.default_stream, g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("device calling transfer failed"); } @@ -293,7 +293,7 @@ void shmem_putmem_signal(void *dst, void *src, size_t elem_size, void *sig_addr, { int ret = shmemi_prepare_and_post_rma("shmem_putmem_signal", SHMEMI_OP_PUT_SIGNAL, NO_NBI, (uint8_t *)dst, (uint8_t *)src, elem_size, 1, pe, (uint8_t *)sig_addr, signal, sig_op, 1, 1, - shm::g_state_host.default_stream, shm::g_state_host.default_block_num); + g_state_host.default_stream, g_state_host.default_block_num); if (ret < 0) { SHM_LOG_ERROR("device calling transfer failed"); } diff --git a/src/host/mem/shmemi_global_state.cpp b/src/host/mem/shmemi_global_state.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34f24f05f3f051b017a984899474f27a637fa110 --- /dev/null +++ b/src/host/mem/shmemi_global_state.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include "shmemi_global_state.h" +#include "host/shmem_host_def.h" +#include "common/shmemi_host_types.h" +#include "common/shmemi_logger.h" +#include "utils.h" + +#define LOAD_SYM(TARGET_FUNC, FILE_HANDLE, SYMBOL_NAME) \ + dlerror(); \ + *((void **)&TARGET_FUNC) = dlsym(FILE_HANDLE, SYMBOL_NAME); \ + error = dlerror(); \ + if (error != NULL) { \ + fprintf(stderr, "dlsym failed: %s\n", error); \ + dlclose(hal_handle); \ + } + +std::mutex g_mutex; + +bool g_hal_loaded = false; +static void *hal_handle; +const char *g_hal_lib_name = "libascend_hal.so"; + +int (*halMemAddressReserveFunc)(void **ptr, size_t size, size_t alignment, void *addr, uint64_t flag); +int (*halMemAddressFreeFunc)(void *ptr); +int (*halMemCreateFunc)(drv_mem_handle_t **handle, size_t size, const struct drv_mem_prop *prop, uint64_t flag); +int (*halMemReleaseFunc)(drv_mem_handle_t *handle); +int (*halMemMapFunc)(void *ptr, size_t size, size_t offset, drv_mem_handle_t *handle, uint64_t flag); +int (*halMemUnmapFunc)(void *ptr); + +int32_t load_hal_library() +{ + char *error; + std::lock_guard guard(g_mutex); + if (g_hal_loaded) { + return 0; + } + + dlerror(); + + hal_handle = dlopen(g_hal_lib_name, RTLD_NOW); + if (!hal_handle) { + fprintf(stderr, "dlopen failed: %s\n", dlerror()); + return 1; + } + + LOAD_SYM(halMemAddressReserveFunc, hal_handle, "halMemAddressReserve"); + LOAD_SYM(halMemAddressFreeFunc, hal_handle, "halMemAddressFree"); + LOAD_SYM(halMemCreateFunc, hal_handle, "halMemCreate"); + LOAD_SYM(halMemReleaseFunc, hal_handle, "halMemRelease"); + LOAD_SYM(halMemMapFunc, hal_handle, "halMemMap"); + LOAD_SYM(halMemUnmapFunc, hal_handle, "halMemUnmap"); + + g_hal_loaded = true; + return 0; +} + +global_state_reigister::global_state_reigister(int device_id): device_id_{device_id} +{ + SHMEM_CHECK(load_hal_library()); + + SHMEM_CHECK(halMemAddressReserveFunc(&device_ptr_, GLOBAL_STATE_SIZE, 0, (void *)(SVM_END_ADDR - GLOBAL_STATE_SIZE), 1)); + + int32_t logicDeviceId = -1; + rtLibLoader& loader = rtLibLoader::getInstance(); + if (loader.isLoaded()) { + loader.getLogicDevId(device_id_, &logicDeviceId); + } + + drv_mem_prop memprop; + memprop.side = 1; + memprop.devid = logicDeviceId; + memprop.module_id = 0; + memprop.pg_type = 0; + memprop.mem_type = 0; + memprop.reserve = 0; + + SHMEM_CHECK(halMemCreateFunc(&alloc_handle, GLOBAL_STATE_SIZE, &memprop, 0)); + + SHMEM_CHECK(halMemMapFunc(device_ptr_, GLOBAL_STATE_SIZE, 0, alloc_handle, 0)); + + // init success + init_status_ = 0; +} + +global_state_reigister::~global_state_reigister() +{ + SHMEM_CHECK(halMemUnmapFunc(device_ptr_)); + + SHMEM_CHECK(halMemReleaseFunc(alloc_handle)); + + SHMEM_CHECK(halMemAddressFreeFunc(device_ptr_)); + + if (hal_handle != nullptr) + dlclose(hal_handle); +} + +void *global_state_reigister::get_ptr() +{ + return device_ptr_; +} + +int global_state_reigister::get_init_status() +{ + return init_status_; +} diff --git a/src/host/mem/shmemi_global_state.h b/src/host/mem/shmemi_global_state.h new file mode 100644 index 0000000000000000000000000000000000000000..1fb56cfc8a491638fb7096bcb8af3adb5f567954 --- /dev/null +++ b/src/host/mem/shmemi_global_state.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_GLOBAL_STATE_H +#define SHMEMI_GLOBAL_STATE_H + +#include +#include + +#include + +#include "internal/host_device/shmemi_types.h" + +typedef struct drv_mem_handle drv_mem_handle_t; + +struct drv_mem_prop { + uint32_t side; + uint32_t devid; + uint32_t module_id; + + uint32_t pg_type; + uint32_t mem_type; + uint64_t reserve; +}; + +class global_state_reigister { +public: + global_state_reigister(); + global_state_reigister(int device_id); + + ~global_state_reigister(); + + void *get_ptr(); + int get_init_status(); +private: + void *device_ptr_ = nullptr; + + drv_mem_handle_t *alloc_handle; + + int device_id_; + + // 1 means no-init + int init_status_ = 1; +}; + + +#endif // SHMEMI_GLOBAL_STATE_H \ No newline at end of file diff --git a/src/host/mem/shmemi_heap.cpp b/src/host/mem/shmemi_heap.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d5c175663f5e5a7fdceb8ed3dd8124f6f2273ade --- /dev/null +++ b/src/host/mem/shmemi_heap.cpp @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "shmemi_heap.h" +#include "host/shmem_host_def.h" +#include "common/shmemi_host_types.h" +#include "common/shmemi_logger.h" + +shmem_symmetric_heap::shmem_symmetric_heap(int pe_id, int pe_size, int dev_id): mype(pe_id), npes(pe_size), device_id(dev_id) +{ + pid_list.resize(pe_size); + sdid_list.resize(pe_size); + + memprop.handleType = ACL_MEM_HANDLE_TYPE_NONE; + memprop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED; + memprop.memAttr = ACL_HBM_MEM_HUGE; + memprop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; + memprop.location.id = dev_id; + memprop.reserve = 0; +} + +int shmem_symmetric_heap::reserve_heap(size_t size) +{ + peer_heap_base_p2p_ = (void **)std::calloc(npes, sizeof(void *)); + + // reserve virtual ptrs + for (int i = 0; i < npes; i++) { + peer_heap_base_p2p_[i] = NULL; + } + + // reserve local heap_base_ + SHMEM_CHECK_RET(aclrtReserveMemAddress(&(peer_heap_base_p2p_[mype]), size, 0, nullptr, 1)); + heap_base_ = peer_heap_base_p2p_[mype]; + + // alloc local physical memory + SHMEM_CHECK_RET(aclrtMallocPhysical(&local_handle, size, &memprop, 0)); + + alloc_size = size; + SHMEM_CHECK_RET(aclrtMapMem(peer_heap_base_p2p_[mype], alloc_size, 0, local_handle, 0)); + + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::export_memory() +{ + // Get memory_name + char memoryName[IPC_NAME_SIZE] = {}; + SHMEM_CHECK_RET(rtIpcSetMemoryName(peer_heap_base_p2p_[mype], alloc_size, memoryName, IPC_NAME_SIZE)); + + memory_name = memoryName; + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::export_pid() +{ + // Get local pid + SHMEM_CHECK_RET(aclrtDeviceGetBareTgid(&my_pid)); + + // Get Sdid + const int rtModuleTypeSystem = 0; + const int infoTypeSdid = 26; + SHMEM_CHECK_RET(rtGetDeviceInfo(device_id, rtModuleTypeSystem, infoTypeSdid, &my_sdid)); + + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::import_pid() +{ + // Get all pids + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + g_boot_handle.allgather(&my_pid, pid_list.data(), 1 * sizeof(int), &g_boot_handle); + + // Get all sdids + g_boot_handle.allgather(&my_sdid, sdid_list.data(), 1 * sizeof(int64_t), &g_boot_handle); + + // Set Sdid and pid into Shared Memory + int local_offset = mype * npes; + for (int i = 0; i < npes; i++) { + if (i == mype || !(g_host_state.transport_map[local_offset + i] & 0x1)) { + continue; + } + SHMEM_CHECK_RET(rtSetIpcMemorySuperPodPid(memory_name.c_str(), sdid_list[i], &pid_list[i], 1)); + } + + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::import_memory() +{ + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + g_boot_handle.allgather(memory_name.c_str(), names, IPC_NAME_SIZE, &g_boot_handle); + + static std::mutex mut; + std::lock_guard lock(mut); + + int local_offset = mype * npes; + for (int i = 0; i < npes; i++) { + if (i == mype || !(g_host_state.transport_map[local_offset + i] & 0x1)) { + continue; + } + SHMEM_CHECK_RET(rtIpcOpenMemory(reinterpret_cast(&peer_heap_base_p2p_[i]), names[i])); + } + + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::setup_heap() +{ + SHMEM_CHECK_RET(export_memory()); + SHMEM_CHECK_RET(export_pid()); + SHMEM_CHECK_RET(import_pid()); + SHMEM_CHECK_RET(import_memory()); + + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::remove_heap() +{ + for (int i = 0; i < npes; i++) { + if (i == mype || peer_heap_base_p2p_[i] == NULL) { + continue; + } + SHMEM_CHECK_RET(rtIpcCloseMemory(static_cast(peer_heap_base_p2p_[i]))); + peer_heap_base_p2p_[i] = NULL; + } + + // This barrier is necessary, otherwise Unmap will fail. + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + g_boot_handle.barrier(&g_boot_handle); + + SHMEM_CHECK_RET(rtIpcDestroyMemoryName(memory_name.c_str())); + + SHMEM_CHECK_RET(aclrtUnmapMem(peer_heap_base_p2p_[mype])); + return SHMEM_SUCCESS; +} + +int shmem_symmetric_heap::unreserve_heap() +{ + for (int i = 0; i < npes; i++) { + if (peer_heap_base_p2p_[i] != NULL) { + SHMEM_CHECK_RET(aclrtReleaseMemAddress(peer_heap_base_p2p_[i])); + } + } + SHMEM_CHECK_RET(aclrtFreePhysical(local_handle)); + return SHMEM_SUCCESS; +} + +void *shmem_symmetric_heap::get_heap_base() +{ + return heap_base_; +} + +void *shmem_symmetric_heap::get_peer_heap_base_p2p(int pe_id) +{ + return peer_heap_base_p2p_[pe_id]; +} \ No newline at end of file diff --git a/src/host/mem/shmemi_heap.h b/src/host/mem/shmemi_heap.h new file mode 100644 index 0000000000000000000000000000000000000000..4dfbc4b76969aa2ae8a1d9fed928c9336cbca493 --- /dev/null +++ b/src/host/mem/shmemi_heap.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_HEAP_H +#define SHMEMI_HEAP_H + +#include +#include +#include +#include + +#include + +#include "internal/host_device/shmemi_types.h" +#include "common/shmemi_host_types.h" +#include "bootstrap/shmemi_bootstrap.h" + +#include "runtime/kernel.h" +#include "runtime/mem.h" +#include "runtime/dev.h" +#include "runtime/rt_ffts.h" + +const int IPC_NAME_SIZE = 65; + +class shmem_symmetric_heap { +public: + shmem_symmetric_heap() {} + shmem_symmetric_heap(int pe_id, int pe_size, int dev_id); + ~shmem_symmetric_heap() {}; + + int reserve_heap(size_t size); // aclrtReserveMemAddress && aclrtMallocPhysical + int unreserve_heap(); // halMemAddressFree && aclrtFreePhysical + + int setup_heap(); // export && import p2p memories && aclrtMapMem + int remove_heap(); // aclrtUnmapMem + + void *get_heap_base(); // return heap_base_ + void *get_peer_heap_base_p2p(int pe_id); // peer_heap_base_p2p_ + +private: + int export_memory(); + int import_memory(); + + int export_pid(); + int import_pid(); + + int32_t mype; + int32_t npes; + int32_t device_id; + + uint64_t alloc_size; + + void *heap_base_; + void **peer_heap_base_p2p_; + + // handle used to map local virtual ptr + aclrtPhysicalMemProp memprop; + aclrtDrvMemHandle local_handle; + + // names used to share memory + std::string memory_name; + char names[SHMEM_MAX_RANKS][IPC_NAME_SIZE]; + + // pid set to memory_name + int32_t my_pid = 0UL; + std::vector pid_list = {}; + + // sdid set to memory_name in 910_93 + int64_t my_sdid = 0UL; + std::vector sdid_list = {}; +}; + + +#endif // SHMEMI_HEAP_H \ No newline at end of file diff --git a/src/host/mem/shmemi_mm.h b/src/host/mem/shmemi_mm.h index 0282dde79e5cd5340c011e2c1d958f2a093ea60d..e89f95804c46f20215c18476f4200540e1652291 100644 --- a/src/host/mem/shmemi_mm.h +++ b/src/host/mem/shmemi_mm.h @@ -10,11 +10,54 @@ #ifndef SHMEMI_MM_H #define SHMEMI_MM_H +#include +#include +#include +#include + #include "host/shmem_host_def.h" -namespace shm { +struct memory_range { + const uint64_t offset; + const uint64_t size; + + memory_range(uint64_t o, uint64_t s) noexcept : offset{o}, size{s} + {} +}; + +struct range_size_first_comparator { + bool operator()(const memory_range &mr1, const memory_range &mr2) const noexcept; +}; + +class memory_heap { +public: + memory_heap(void *base, uint64_t size) noexcept; + ~memory_heap() noexcept; + +public: + void *allocate(uint64_t size) noexcept; + void *aligned_allocate(uint64_t alignment, uint64_t size) noexcept; + bool change_size(void *address, uint64_t size) noexcept; + int32_t release(void *address) noexcept; + bool allocated_size(void *address, uint64_t &size) const noexcept; + +private: + static uint64_t allocated_size_align_up(uint64_t input_size) noexcept; + static bool alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, + uint64_t &head_skip) noexcept; + void reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; + bool expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; + +private: + uint8_t *const base_; + const uint64_t size_; + mutable pthread_spinlock_t spinlock_{}; + std::map address_idle_tree_; + std::map address_used_tree_; + std::set size_idle_tree_; +}; + int32_t memory_manager_initialize(void *base, uint64_t size); void memory_manager_destroy(); -} // namespace shm #endif // SHMEMI_MM_H diff --git a/src/host/mem/shmemi_mm_heap.cpp b/src/host/mem/shmemi_mm_heap.cpp deleted file mode 100644 index 0c0ad243332b44e68832c4e30395a5c7a2e50da1..0000000000000000000000000000000000000000 --- a/src/host/mem/shmemi_mm_heap.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#include "shmemi_host_common.h" -#include "shmemi_mm_heap.h" - -namespace shm { -bool range_size_first_comparator::operator()(const memory_range &mr1, const memory_range &mr2) const noexcept -{ - if (mr1.size != mr2.size) { - return mr1.size < mr2.size; - } - - return mr1.offset < mr2.offset; -} - -memory_heap::memory_heap(void *base, uint64_t size) noexcept : base_{reinterpret_cast(base)}, size_{size} -{ - pthread_spin_init(&spinlock_, 0); - address_idle_tree_[0] = size; - size_idle_tree_.insert({0, size}); -} - -memory_heap::~memory_heap() noexcept -{ - pthread_spin_destroy(&spinlock_); -} - -void *memory_heap::allocate(uint64_t size) noexcept -{ - if (size == 0 || size > shm::g_state.heap_size) { - SHM_LOG_ERROR("cannot allocate with size " << size); - return nullptr; - } - - auto aligned_size = allocated_size_align_up(size); - memory_range anchor{0, aligned_size}; - - pthread_spin_lock(&spinlock_); - auto size_pos = size_idle_tree_.lower_bound(anchor); - if (size_pos == size_idle_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("cannot allocate with size: " << size); - return nullptr; - } - - auto target_offset = size_pos->offset; - auto target_size = size_pos->size; - auto addr_pos = address_idle_tree_.find(target_offset); - if (addr_pos == address_idle_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("offset(" << target_offset << ") size(" << target_size << ") in size tree, not in address tree."); - return nullptr; - } - - size_idle_tree_.erase(size_pos); - address_idle_tree_.erase(addr_pos); - address_used_tree_.emplace(target_offset, aligned_size); - if (target_size > aligned_size) { - memory_range left{target_offset + aligned_size, target_size - aligned_size}; - address_idle_tree_.emplace(left.offset, left.size); - size_idle_tree_.emplace(left); - } - pthread_spin_unlock(&spinlock_); - - return base_ + target_offset; -} - -void *memory_heap::aligned_allocate(uint64_t alignment, uint64_t size) noexcept -{ - if (size == 0 || alignment == 0 || size > shm::g_state.heap_size) { - SHM_LOG_ERROR("invalid input, align=" << alignment << ", size=" << size); - return nullptr; - } - - if ((alignment & (alignment - 1UL)) != 0) { - SHM_LOG_ERROR("alignment should be power of 2, but real " << alignment); - return nullptr; - } - - uint64_t head_skip = 0; - auto aligned_size = allocated_size_align_up(size); - memory_range anchor{0, aligned_size}; - - pthread_spin_lock(&spinlock_); - auto size_pos = size_idle_tree_.lower_bound(anchor); - while (size_pos != size_idle_tree_.end() && !alignment_matches(*size_pos, alignment, aligned_size, head_skip)) { - ++size_pos; - } - - if (size_pos == size_idle_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("cannot allocate with size: " << size << ", alignment: " << alignment); - return nullptr; - } - - auto target_offset = size_pos->offset; - auto target_size = size_pos->size; - memory_range result_range{size_pos->offset + head_skip, aligned_size}; - size_idle_tree_.erase(size_pos); - - if (head_skip > 0) { - size_idle_tree_.emplace(memory_range{target_offset, head_skip}); - address_idle_tree_.emplace(target_offset, head_skip); - } - - if (head_skip + aligned_size < target_size) { - memory_range leftMR{target_offset + head_skip + aligned_size, target_size - head_skip - aligned_size}; - size_idle_tree_.emplace(leftMR); - address_idle_tree_.emplace(leftMR.offset, leftMR.size); - } - - address_used_tree_.emplace(result_range.offset, result_range.size); - pthread_spin_unlock(&spinlock_); - - return base_ + result_range.offset; -} - -bool memory_heap::change_size(void *address, uint64_t size) noexcept -{ - auto u8a = reinterpret_cast(address); - if (u8a < base_ || u8a >= base_ + size_) { - SHM_LOG_ERROR("release invalid address " << address); - return false; - } - - if (size == 0) { - release(address); - return true; - } - - auto offset = u8a - base_; - pthread_spin_lock(&spinlock_); - auto pos = address_used_tree_.find(offset); - if (pos == address_used_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("change size for address " << address << " not allocated."); - return false; - } - - // size不变 - if (pos->second == size) { - pthread_spin_unlock(&spinlock_); - return true; - } - - // 缩小size - if (pos->second > size) { - reduce_size_in_lock(pos, size); - pthread_spin_unlock(&spinlock_); - return true; - } - - // 扩大size - auto success = expend_size_in_lock(pos, size); - pthread_spin_unlock(&spinlock_); - - return success; -} - -int32_t memory_heap::release(void *address) noexcept -{ - auto u8a = reinterpret_cast(address); - if (u8a < base_ || u8a >= base_ + size_) { - SHM_LOG_ERROR("release invalid address " << address); - return -1; - } - - auto offset = u8a - base_; - pthread_spin_lock(&spinlock_); - auto pos = address_used_tree_.find(offset); - if (pos == address_used_tree_.end()) { - pthread_spin_unlock(&spinlock_); - SHM_LOG_ERROR("release address " << address << " not allocated."); - return -1; - } - - auto size = pos->second; - uint64_t final_offset = static_cast(offset); - uint64_t final_size = size; - address_used_tree_.erase(pos); - - auto prev_addr_pos = address_idle_tree_.lower_bound(offset); - if (prev_addr_pos != address_idle_tree_.begin()) { - --prev_addr_pos; - if (prev_addr_pos != address_idle_tree_.end() && - prev_addr_pos->first + prev_addr_pos->second == static_cast(offset)) { - // 合并前一个range - final_offset = prev_addr_pos->first; - final_size += prev_addr_pos->second; - - auto prev_addr_range = *prev_addr_pos; - address_idle_tree_.erase(prev_addr_pos); - size_idle_tree_.erase(memory_range{prev_addr_range.first, prev_addr_range.second}); - } - } - - auto next_addr_pos = address_idle_tree_.find(offset + size); - if (next_addr_pos != address_idle_tree_.end()) { // 合并后一个range - uint64_t next_addr = next_addr_pos->first; - uint64_t next_size = next_addr_pos->second; - final_size += next_size; - address_idle_tree_.erase(next_addr_pos); - size_idle_tree_.erase(memory_range{next_addr, next_size}); - } - address_idle_tree_.emplace(final_offset, final_size); - size_idle_tree_.emplace(memory_range{final_offset, final_size}); - pthread_spin_unlock(&spinlock_); - - return 0; -} - -bool memory_heap::allocated_size(void *address, uint64_t &size) const noexcept -{ - auto u8a = reinterpret_cast(address); - if (u8a < base_ || u8a >= base_ + size_) { - SHM_LOG_ERROR("release invalid address " << address); - return false; - } - - auto offset = u8a - base_; - bool exist = false; - pthread_spin_lock(&spinlock_); - auto pos = address_used_tree_.find(offset); - if (pos != address_used_tree_.end()) { - exist = true; - size = pos->second; - } - pthread_spin_unlock(&spinlock_); - - return exist; -} - -uint64_t memory_heap::allocated_size_align_up(uint64_t input_size) noexcept -{ - constexpr uint64_t align_size = 16UL; - constexpr uint64_t align_size_mask = ~(align_size - 1UL); - return (input_size + align_size - 1UL) & align_size_mask; -} - -bool memory_heap::alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, - uint64_t &head_skip) noexcept -{ - if (mr.size < size) { - return false; - } - - if ((mr.offset & (alignment - 1UL)) == 0UL) { - head_skip = 0; - return true; - } - - auto aligned_offset = ((mr.offset + alignment - 1UL) & (~(alignment - 1UL))); - head_skip = aligned_offset - mr.offset; - return mr.size >= size + head_skip; -} - -void memory_heap::reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept -{ - auto offset = pos->first; - auto old_size = pos->second; - pos->second = new_size; - auto next_addr_pos = address_idle_tree_.find(offset + old_size); - if (next_addr_pos == address_idle_tree_.end()) { - address_idle_tree_.emplace(offset + new_size, old_size - new_size); - size_idle_tree_.emplace(memory_range{offset + new_size, old_size - new_size}); - } else { - auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); - size_idle_tree_.erase(next_size_pos); - next_addr_pos->second += (old_size - new_size); - size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); - } -} - -bool memory_heap::expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept -{ - auto offset = pos->first; - auto old_size = pos->second; - auto delta = new_size - old_size; - - auto next_addr_pos = address_idle_tree_.find(offset + old_size); - if (next_addr_pos == address_idle_tree_.end() || next_addr_pos->second < delta) { - return false; - } - - pos->second = new_size; - auto next_size_pos = size_idle_tree_.find(memory_range{next_addr_pos->first, next_addr_pos->second}); - if (next_addr_pos->second == delta) { - size_idle_tree_.erase(next_size_pos); - address_idle_tree_.erase(next_addr_pos); - } else { - size_idle_tree_.erase(next_size_pos); - next_addr_pos->second -= delta; - size_idle_tree_.emplace(memory_range{next_addr_pos->first, next_addr_pos->second}); - } - - return true; -} -} // namespace shm \ No newline at end of file diff --git a/src/host/mem/shmemi_mm_heap.h b/src/host/mem/shmemi_mm_heap.h deleted file mode 100644 index 7818a1e0c8da6051db39ec22082c803b21c7b010..0000000000000000000000000000000000000000 --- a/src/host/mem/shmemi_mm_heap.h +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ -#ifndef SHMEMI_MM_HEAP_H -#define SHMEMI_MM_HEAP_H - -#include -#include -#include -#include - -namespace shm { -struct memory_range { - const uint64_t offset; - const uint64_t size; - - memory_range(uint64_t o, uint64_t s) noexcept : offset{o}, size{s} - { - } -}; - -struct range_size_first_comparator { - bool operator()(const memory_range &mr1, const memory_range &mr2) const noexcept; -}; - -class memory_heap { -public: - memory_heap(void *base, uint64_t size) noexcept; - ~memory_heap() noexcept; - -public: - void *allocate(uint64_t size) noexcept; - void *aligned_allocate(uint64_t alignment, uint64_t size) noexcept; - bool change_size(void *address, uint64_t size) noexcept; - int32_t release(void *address) noexcept; - bool allocated_size(void *address, uint64_t &size) const noexcept; - -private: - static uint64_t allocated_size_align_up(uint64_t input_size) noexcept; - static bool alignment_matches(const memory_range &mr, uint64_t alignment, uint64_t size, - uint64_t &head_skip) noexcept; - void reduce_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; - bool expend_size_in_lock(const std::map::iterator &pos, uint64_t new_size) noexcept; - -private: - uint8_t *const base_; - const uint64_t size_; - mutable pthread_spinlock_t spinlock_{}; - std::map address_idle_tree_; - std::map address_used_tree_; - std::set size_idle_tree_; -}; -} // namespace shm - -#endif // SHMEMI_MM_HEAP_H diff --git a/src/host/python_wrapper/pyshmem.cpp b/src/host/python_wrapper/pyshmem.cpp index 7dde98d884fa02f01cd890273bae47780f4c028f..91aa37150697be5cbc38c9c2f8e1a38134cad642 100644 --- a/src/host/python_wrapper/pyshmem.cpp +++ b/src/host/python_wrapper/pyshmem.cpp @@ -56,7 +56,7 @@ inline std::string get_connect_url() int shmem_initialize(shmem_init_attr_t &attributes) { - auto ret = shmem_init_attr(&attributes); + auto ret = shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, &attributes); if (ret != 0) { std::cerr << "initialize shmem failed, ret: " << ret; return ret; @@ -67,7 +67,7 @@ int shmem_initialize(shmem_init_attr_t &attributes) py::bytes shmem_get_unique_id() { - shmem_uniqueid_t uid; + shmemx_uniqueid_t uid; auto ret = shmem_get_uniqueid(&uid); if (ret != 0) { std::cerr << "get unique id failed " << ret << std::endl; @@ -77,22 +77,22 @@ py::bytes shmem_get_unique_id() int shmem_initialize_unique_id(int rank, int world_size, int64_t mem_size, const std::string &bytes) { - if (bytes.size() < sizeof(shmem_uniqueid_t)) { + if (bytes.size() < sizeof(shmemx_uniqueid_t)) { std::cerr << "Error: Input bytes size (" << bytes.size() - << ") is smaller than required size (" << sizeof(shmem_uniqueid_t) + << ") is smaller than required size (" << sizeof(shmemx_uniqueid_t) << ")." << std::endl; return -1; } - shmem_uniqueid_t uid; + shmemx_uniqueid_t uid; std::copy_n(bytes.data(), sizeof(uid), reinterpret_cast(&uid)); shmem_init_attr_t *attr; - auto ret = shmem_set_attr(rank, world_size, mem_size, nullptr, &attr); + auto ret = shmemx_set_attr_uniqueid_args(rank, world_size, mem_size, &uid, &attr); if (ret != 0) { std::cerr << "set attr failed " << ret << std::endl; return ret; } - return shmem_set_attr_uniqueid_args(rank, world_size, &uid, attr); + return shmem_init_attr(SHMEMX_INIT_WITH_UNIQUEID, attr); } int32_t shmem_set_op_engine_type(shmem_init_attr_t &attributes, data_op_engine_type_t value) @@ -202,7 +202,8 @@ void DefineShmemAttr(py::module_ &m) .def_readwrite("data_op_engine_type", &shmem_init_optional_attr_t::data_op_engine_type) .def_readwrite("shm_init_timeout", &shmem_init_optional_attr_t::shm_init_timeout) .def_readwrite("shm_create_timeout", &shmem_init_optional_attr_t::shm_create_timeout) - .def_readwrite("control_operation_timeout", &shmem_init_optional_attr_t::control_operation_timeout); + .def_readwrite("control_operation_timeout", &shmem_init_optional_attr_t::control_operation_timeout) + .def_readwrite("sockFd", &shmem_init_optional_attr_t::sockFd); py::class_(m, "InitAttr") .def(py::init([]() { diff --git a/src/host/shmemi_host_common.h b/src/host/shmemi_host_common.h index 977a55757a4410397a5489409e20f44c02bf17f0..5696519e99d9147d5ca06089f9bb4dd77401ba27 100644 --- a/src/host/shmemi_host_common.h +++ b/src/host/shmemi_host_common.h @@ -11,16 +11,14 @@ #define SHMEM_SHMEMI_HOST_COMMON_H #include "shmem_api.h" - #include "common/shmemi_logger.h" +#include "common/shmemi_host_types.h" #include "init/shmemi_init.h" #include "team/shmemi_team.h" #include "mem/shmemi_mm.h" #include "sync/shmemi_sync.h" - -// smem api -#include -#include -#include +#include "bootstrap/shmemi_bootstrap.h" +#include "transport/shmemi_transport.h" +#include "utils.h" #endif // SHMEM_SHMEMI_HOST_COMMON_H diff --git a/src/host/sync/shmemi_sync.cpp b/src/host/sync/shmemi_sync.cpp index d8adb2e01d6f6badf1c48f09d4c610e087ac7a8d..77a9cc242062bbef6055936717510456738167ef 100644 --- a/src/host/sync/shmemi_sync.cpp +++ b/src/host/sync/shmemi_sync.cpp @@ -20,7 +20,6 @@ extern "C" int rtGetC2cCtrlAddr(uint64_t *config, uint32_t *len); -namespace shm { static uint64_t ffts_config; int32_t shmemi_sync_init() @@ -29,11 +28,9 @@ int32_t shmemi_sync_init() return rtGetC2cCtrlAddr(&ffts_config, &len); } -} // namespace - uint64_t shmemx_get_ffts_config() { - return shm::ffts_config; + return ffts_config; } void shmem_barrier(shmem_team_t tid) diff --git a/src/host/sync/shmemi_sync.h b/src/host/sync/shmemi_sync.h index fd07bb01e2c8a33cfb7391601822046e1a6de885..4401920794412b6655578e30d7c4f60beaed30f6 100644 --- a/src/host/sync/shmemi_sync.h +++ b/src/host/sync/shmemi_sync.h @@ -10,10 +10,6 @@ #ifndef SHMEMI_SYNC_H #define SHMEMI_SYNC_H -namespace shm { - int32_t shmemi_sync_init(); -} - #endif // SHMEMI_TEAM_H diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index 70366107048a4d68e2b3bc59490117ecfb7c2dcf..d50138e91b543b93625f6ea5ff255b9010dc8735 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -20,7 +20,6 @@ #include "shmemi_device_intf.h" using namespace std; -namespace shm { uint64_t g_team_mask = 0; shmemi_team_t *g_shmem_team_pool = nullptr; @@ -58,12 +57,12 @@ inline int32_t device_team_update(int team_idx, shmemi_team_t *host_team_ptr) { // device_ptr Malloc void *team_ptr = nullptr; - SHMEM_CHECK_RET(aclrtMalloc(&team_ptr, sizeof(shmemi_team_t), ACL_MEM_MALLOC_NORMAL_ONLY), aclrtMalloc); + SHMEM_CHECK_RET(aclrtMalloc(&team_ptr, sizeof(shmemi_team_t), ACL_MEM_MALLOC_NORMAL_ONLY)); auto ret = aclrtMemcpy((shmemi_team_t *)team_ptr, sizeof(shmemi_team_t), host_team_ptr, sizeof(shmemi_team_t), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != 0) { SHM_LOG_ERROR("memcpy device team info failed, ret: " << ret); - SHMEM_CHECK_RET(aclrtFree(team_ptr), aclrtFree); + SHMEM_CHECK_RET(aclrtFree(team_ptr)); return SHMEM_INNER_ERROR; } g_state.team_pools[team_idx] = (shmemi_team_t *)team_ptr; @@ -240,11 +239,11 @@ int32_t shmemi_team_finalize() g_state.partial_barrier_pool = 0; } if (g_state.core_sync_counter != 0) { - SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_counter)), aclrtFree); + SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_counter))); g_state.core_sync_counter = 0; } if (g_state.core_sync_pool != 0) { - SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_pool)), aclrtFree); + SHMEM_CHECK_RET(aclrtFree(reinterpret_cast(g_state.core_sync_pool))); g_state.core_sync_pool = 0; } if (g_shmem_team_pool != nullptr) { @@ -254,8 +253,6 @@ int32_t shmemi_team_finalize() return 0; } -} // namespace shm - int32_t shmem_team_split_strided_precheck(shmem_team_t parent_team, int32_t pe_start, int32_t pe_stride, int32_t pe_size, shmem_team_t *&new_team) { @@ -265,16 +262,18 @@ int32_t shmem_team_split_strided_precheck(shmem_team_t parent_team, int32_t pe_s } *new_team = SHMEM_TEAM_INVALID; - if (!shm::is_valid_team(parent_team)) { + if (!is_valid_team(parent_team)) { SHM_LOG_ERROR("input parent team is invalid!, team: " << parent_team); return SHMEM_INVALID_PARAM; } - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t my_team; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; + if (pe_start >= SHMEM_MAX_RANKS || pe_stride >= SHMEM_MAX_RANKS || pe_size > SHMEM_MAX_RANKS) { SHM_LOG_ERROR("create team failed, input invalid, pe_start:" << pe_start << " pe_size:" << pe_size << " pe_stride:" << pe_stride << " parent:" - << shm::team_config2string(src_team)); + << team_config2string(src_team)); return SHMEM_INVALID_PARAM; } return SHMEM_SUCCESS; @@ -288,21 +287,23 @@ int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int return ret; } - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; int32_t global_pe = src_team->mype; int32_t global_pe_start = src_team->start + pe_start * src_team->stride; int32_t global_pe_stride = src_team->stride * pe_stride; int32_t global_pe_end = global_pe_start + global_pe_stride * (pe_size - 1); if (pe_start < 0 || pe_start >= src_team->size || pe_size <= 0 || pe_size > src_team->size || pe_stride < 1) { - SHM_LOG_ERROR("create team failed, input invalid:" << pe_start << ":" << pe_size << ":" << pe_stride << ":" - << shm::team_config2string(src_team)); + SHM_LOG_ERROR("create team failed, input invalid, pe_start:" << pe_start << " pe_size:" << pe_size + << " pe_stride:" << pe_stride << " parent:" + << team_config2string(src_team)); return SHMEM_INVALID_PARAM; } if (global_pe_start >= shmem_n_pes() || global_pe_end >= shmem_n_pes()) { - SHM_LOG_ERROR("create team failed, large than world size:" << pe_start << ":" << pe_size << ":" << pe_stride - << ":" << shmem_n_pes() << ":" << shm::team_config2string(src_team)); + SHM_LOG_ERROR("create team failed, large than world size, pe_start:" + << pe_start << " pe_size:" << pe_size << " pe_stride:" << pe_stride + << " world_size:" << shmem_n_pes() << " parent:" << team_config2string(src_team)); return SHMEM_INVALID_PARAM; } @@ -318,24 +319,24 @@ int32_t shmem_team_split_strided(shmem_team_t parent_team, int32_t pe_start, int my_team.stride = global_pe_stride; my_team.size = pe_size; - my_team.team_idx = shm::first_free_idx_fetch(); + my_team.team_idx = first_free_idx_fetch(); if (my_team.team_idx == -1) { SHM_LOG_ERROR("create team failed, team num is full!"); return SHMEM_INNER_ERROR; } - shm::g_shmem_team_pool[my_team.team_idx] = my_team; - if (shm::device_team_update(my_team.team_idx, &shm::g_shmem_team_pool[my_team.team_idx]) != 0) { + g_shmem_team_pool[my_team.team_idx] = my_team; + if (device_team_update(my_team.team_idx, &g_shmem_team_pool[my_team.team_idx]) != 0) { shmem_team_destroy(my_team.team_idx); SHM_LOG_ERROR("create team failed, malloc device state failed!"); return SHMEM_INNER_ERROR; } - if (shm::update_device_state() != 0) { + if (update_device_state() != 0) { shmem_team_destroy(my_team.team_idx); SHM_LOG_ERROR("create team failed, update state failed!"); return SHMEM_INNER_ERROR; } - SHM_LOG_INFO("create team success:" << shm::team_config2string(&my_team)); + SHM_LOG_INFO("create team success:" << team_config2string(&my_team)); *new_team = my_team.team_idx; return 0; } @@ -354,7 +355,7 @@ int shmemi_team_split_2d_precheck(shmem_team_t p_team, int x_range, shmem_team_t *x_team = SHMEM_TEAM_INVALID; *y_team = SHMEM_TEAM_INVALID; - if (!shm::is_valid_team(p_team)) { + if (!is_valid_team(p_team)) { SHM_LOG_ERROR("input parent team is invalid!, team: " << p_team); return SHMEM_INVALID_PARAM; } @@ -367,7 +368,7 @@ int shmemi_team_split_2d_x(shmem_team_t &parent_team, int32_t &x_team_counts, in { int start = 0; int errorCode = 0; - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; for (int i = 0; i < x_team_counts; ++i) { shmem_team_t my_xteam; @@ -396,7 +397,7 @@ int shmemi_team_split_2d_y(shmem_team_t &parent_team, int32_t &y_team_counts, in { int start = 0; int errorCode = 0; - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; for (int i = 0; i < y_team_counts; ++i) { shmem_team_t my_yteam; @@ -429,7 +430,7 @@ int shmem_team_split_2d(shmem_team_t parent_team, int x_range, shmem_team_t *x_t return ret; } - shmemi_team_t *src_team = &shm::g_shmem_team_pool[parent_team]; + shmemi_team_t *src_team = &g_shmem_team_pool[parent_team]; int32_t src_start = src_team->start; int32_t src_stride = src_team->stride; int32_t src_size = src_team->size; @@ -451,12 +452,12 @@ int shmem_team_split_2d(shmem_team_t parent_team, int x_range, shmem_team_t *x_t int32_t shmem_team_translate_pe(shmem_team_t src_team, int32_t src_pe, shmem_team_t dest_team) { - if (!shm::is_valid_team(src_team) || !shm::is_valid_team(dest_team)) { + if (!is_valid_team(src_team) || !is_valid_team(dest_team)) { return -1; } - shmemi_team_t *src_team_ptr = &shm::g_shmem_team_pool[src_team]; - shmemi_team_t *dest_team_ptr = &shm::g_shmem_team_pool[dest_team]; + shmemi_team_t *src_team_ptr = &g_shmem_team_pool[src_team]; + shmemi_team_t *dest_team_ptr = &g_shmem_team_pool[dest_team]; if (src_pe > src_team_ptr->size) { return -1; @@ -477,32 +478,32 @@ int32_t shmem_team_translate_pe(shmem_team_t src_team, int32_t src_pe, shmem_tea void shmem_team_destroy(shmem_team_t team) { - if (!shm::is_valid_team(team)) { + if (!is_valid_team(team)) { SHM_LOG_WARN("input team is invalid!, team: " << team); return; } - shm::device_team_destroy(team); - shm::g_team_mask ^= 1ULL << team; - if (shm::update_device_state() != SHMEM_SUCCESS) { + device_team_destroy(team); + g_team_mask ^= 1ULL << team; + if (update_device_state() != SHMEM_SUCCESS) { SHM_LOG_WARN("update state failed when destroy team!"); } } int32_t shmem_my_pe(void) { - return shm::g_state.mype; + return g_state.mype; } int32_t shmem_n_pes(void) { - return shm::g_state.npes; + return g_state.npes; } int32_t shmem_team_my_pe(shmem_team_t team) { - if (shm::is_valid_team(team)) { - return shm::g_shmem_team_pool[team].mype; + if (is_valid_team(team)) { + return g_shmem_team_pool[team].mype; } else { return -1; } @@ -510,8 +511,8 @@ int32_t shmem_team_my_pe(shmem_team_t team) int32_t shmem_team_n_pes(shmem_team_t team) { - if (shm::is_valid_team(team)) { - return shm::g_shmem_team_pool[team].size; + if (is_valid_team(team)) { + return g_shmem_team_pool[team].size; } else { return -1; } @@ -520,7 +521,7 @@ int32_t shmem_team_n_pes(shmem_team_t team) int shmem_team_get_config(shmem_team_t team, shmem_team_config_t *config) { SHMEM_CHECK_RET(config == nullptr); - if (shm::is_valid_team(team)) { + if (is_valid_team(team)) { config->num_contexts = 0; return 0; } else { diff --git a/src/host/team/shmemi_team.h b/src/host/team/shmemi_team.h index 6903f8f2aa2cb430f919f4b992f185b794cf543d..3f6990ab91204e414995b6292fafd3398bed0e75 100644 --- a/src/host/team/shmemi_team.h +++ b/src/host/team/shmemi_team.h @@ -12,12 +12,8 @@ #include "stdint.h" -namespace shm { - int32_t shmemi_team_init(int32_t rank, int32_t size); int32_t shmemi_team_finalize(); -} // namespace shm - #endif // SHMEMI_TEAM_H diff --git a/src/host/transport/shmemi_transport.cpp b/src/host/transport/shmemi_transport.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0271e9c0e3cdd578675c6ad6978be164e8e1470d --- /dev/null +++ b/src/host/transport/shmemi_transport.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "mem/shmemi_heap.h" +#include "shmemi_host_common.h" +#include "dlfcn.h" + +#include "transport/shmemi_transport.h" + +static void *transport_mte_lib = NULL; +static void *transport_rdma_lib = NULL; + +uint64_t *host_hash_list; + +shmemi_host_trans_state_t g_host_state; + +int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_optional_attr_t& option_attr) { + // Initialize MTE by default + g_host_state.num_choosen_transport = 1; + g_host_state.transport_map = (int *)calloc(g_state->npes * g_state->npes, sizeof(int)); + g_host_state.pe_info = (shmemi_transport_pe_info *)calloc(g_state->npes, sizeof(shmemi_transport_pe_info)); + + transport_mte_lib = dlopen("shmem_transport_mte.so", RTLD_NOW); + if (!transport_mte_lib) { + SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_mte.so" << ", err is: " << stderr); + return SHMEM_INVALID_VALUE; + } + + transport_init_func init_mte_fn; + init_mte_fn = (transport_init_func)dlsym(transport_mte_lib, "shmemi_mte_init"); + if (!init_mte_fn) { + dlclose(transport_mte_lib); + transport_mte_lib = NULL; + SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_mte.so" << "."); + return SHMEM_INVALID_VALUE; + } + + // Package my_info + int32_t device_id; + int64_t server_id; + int64_t superpod_id; + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); + const int infoTypeServerId = 27; + SHMEM_CHECK_RET(rtGetDeviceInfo(device_id, 0, infoTypeServerId, &server_id)); + const int infoTypeSuperpodId = 29; + SHMEM_CHECK_RET(rtGetDeviceInfo(device_id, 0, infoTypeSuperpodId, &superpod_id)); + + shmemi_transport_pe_info_t my_info; + my_info.pe = g_state->mype; + my_info.dev_id = device_id; + my_info.server_id = server_id; + my_info.superpod_id = superpod_id; + + // server_id invalid + if (server_id == 0x3FFU) { + static uint32_t bootIdHead; + static std::string sysBootId; + + std::string bootIdPath("/proc/sys/kernel/random/boot_id"); + std::ifstream input(bootIdPath); + input >> sysBootId; + + std::stringstream ss(sysBootId); + ss >> std::hex >> bootIdHead; + + my_info.server_id = bootIdHead; + } + + // AllGather All pe's host info + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + g_boot_handle.allgather((void *)&my_info, g_host_state.pe_info, sizeof(shmemi_transport_pe_info_t), &g_boot_handle); + SHMEM_CHECK_RET(init_mte_fn(&g_host_state.choosen_transports[0], g_state)); + + // If enable RDMA + if (option_attr.data_op_engine_type & SHMEM_DATA_OP_ROCE) { + g_host_state.num_choosen_transport++; + + int32_t logicDeviceId = -1; + rtLibLoader& loader = rtLibLoader::getInstance(); + if (loader.isLoaded()) { + loader.getLogicDevId(device_id, &logicDeviceId); + } + g_host_state.choosen_transports[1].logical_dev_id = logicDeviceId; + g_host_state.choosen_transports[1].dev_id = device_id; + + transport_rdma_lib = dlopen("shmem_transport_rdma.so", RTLD_NOW); + if (!transport_rdma_lib) { + SHM_LOG_ERROR("Transport unable to load " << "shmem_transport_rdma.so" << ", err is: " << stderr); + return SHMEM_INVALID_VALUE; + } + + transport_init_func init_rdma_fn; + init_rdma_fn = (transport_init_func)dlsym(transport_rdma_lib, "shmemi_rdma_init"); + if (!init_rdma_fn) { + dlclose(transport_rdma_lib); + transport_rdma_lib = NULL; + SHM_LOG_ERROR("Unable to get info from " << "shmem_transport_rdma.so" << "."); + return SHMEM_INVALID_VALUE; + } + SHMEM_CHECK_RET(init_rdma_fn(&g_host_state.choosen_transports[1], g_state)); + } + + return SHMEM_SUCCESS; +} + +int32_t shmemi_build_transport_map(shmemi_device_host_state_t *g_state) { + int *local_map = NULL; + local_map = (int *)calloc(g_state->npes, sizeof(int)); + + shmemi_transport_t t; + + // Loop can_access_peer, j = 0 means MTE, j = 1 means RDMA ... + for (int j = 0; j < g_host_state.num_choosen_transport; j++) { + t = g_host_state.choosen_transports[j]; + + for (int i = 0; i < g_state->npes; i++) { + int reach = 0; + + SHMEM_CHECK_RET(t.can_access_peer(&reach, g_host_state.pe_info + i, g_host_state.pe_info + g_state->mype, &t)); + + if (reach) { + int m = 1 << j; + local_map[i] |= m; + } + } + } + + for (int i = 0; i < g_state->npes; i++) { + g_state->topo_list[i] = static_cast(local_map[i]); + } + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + g_boot_handle.allgather(local_map, g_host_state.transport_map, g_state->npes * sizeof(int), &g_boot_handle); + + if (local_map) free(local_map); + return SHMEM_SUCCESS; +} + +int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t *g_state) { + shmemi_transport_t t; + // MTE + t = g_host_state.choosen_transports[0]; + + int *mte_peer_list; + int mte_peer_num = 0; + mte_peer_list = (int *)calloc(g_state->npes, sizeof(int)); + + int local_offset = g_state->mype * g_state->npes; + for (int i = 0; i < g_state->npes; i++) { + if (i == g_state->mype) + continue; + /* Check if MTE connected. */ + if (g_host_state.transport_map[local_offset + i] & 0x1) { + shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); + shmemi_transport_pe_info_t *my_info = (g_host_state.pe_info + g_state->mype); + // Only PEs in the same Node need to build up MTE connection. + if (my_info->server_id == peer_info->server_id) { + mte_peer_list[mte_peer_num] = peer_info->dev_id; + ++mte_peer_num; + } + } + } + + t.connect_peers(&t, mte_peer_list, mte_peer_num, g_state); + + if (g_host_state.num_choosen_transport > 1) { + int *rdma_peer_list; + int rdma_peer_num = 0; + rdma_peer_list = (int *)calloc(g_state->npes, sizeof(int)); + + int local_offset = g_state->mype * g_state->npes; + for (int i = 0; i < g_state->npes; i++) { + if (i == g_state->mype) + continue; + if (g_host_state.transport_map[local_offset + i] & 2) { + shmemi_transport_pe_info_t *peer_info = (g_host_state.pe_info + i); + rdma_peer_list[rdma_peer_num] = peer_info->dev_id; + ++rdma_peer_num; + } + } + t = g_host_state.choosen_transports[1]; + t.connect_peers(&t, rdma_peer_list, rdma_peer_num, g_state); + } + + return 0; +} + +int32_t shmemi_transport_finalize(shmemi_device_host_state_t *g_state) { + shmemi_transport_t t; + // MTE + t = g_host_state.choosen_transports[0]; + t.finalize(&t, g_state); + + if (transport_mte_lib != NULL) { + dlclose(transport_mte_lib); + transport_mte_lib = NULL; + } + + if (g_host_state.num_choosen_transport > 1) { + t = g_host_state.choosen_transports[1]; + t.finalize(&t, g_state); + + if (transport_rdma_lib != NULL) { + dlclose(transport_rdma_lib); + transport_rdma_lib = NULL; + } + } + return 0; +} diff --git a/src/host/transport/shmemi_transport.h b/src/host/transport/shmemi_transport.h new file mode 100644 index 0000000000000000000000000000000000000000..e89bfaf12d692388fa167fd240a261d357d65cc8 --- /dev/null +++ b/src/host/transport/shmemi_transport.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEMI_TRANSPORT_H +#define SHMEMI_TRANSPORT_H + +typedef int(*transport_init_func)(shmemi_transport_t *transport, shmemi_device_host_state_t *g_state); + +int32_t shmemi_transport_init(shmemi_device_host_state_t *g_state, shmem_init_optional_attr_t &option_attr); + +int32_t shmemi_build_transport_map(shmemi_device_host_state_t *g_state); + +int32_t shmemi_transport_setup_connections(shmemi_device_host_state_t *g_state); + +int32_t shmemi_transport_finalize(shmemi_device_host_state_t *g_state); + +#endif // SHMEMI_TRANSPORT_H \ No newline at end of file diff --git a/src/host/utils.h b/src/host/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0d398de9b9e7d2c7d0b6df77e5e6ae7901a646f2 --- /dev/null +++ b/src/host/utils.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEM_UTILS_H +#define SHMEM_UTILS_H + +#include +#include +#include "common/shmemi_logger.h" + +class rtLibLoader { +private: + void* rt_handle_; + int (*rtGetLogicDevIdByUserDevIdFunc_)(const int32_t, int32_t *const); + + rtLibLoader() : rt_handle_(nullptr) { + if (!loadLibrary()) { + SHM_LOG_ERROR("Failed to initialize rtLibLoader: could not load liba.so or foo function"); + } + } + + rtLibLoader(const rtLibLoader&) = delete; + rtLibLoader& operator=(const rtLibLoader&) = delete; + + bool loadLibrary() { + rt_handle_ = dlopen("libascendcl.so", RTLD_NOW); + if (!rt_handle_) { + SHM_LOG_ERROR("dlopen failed: " << dlerror()); + return false; + } + + *((void**)&rtGetLogicDevIdByUserDevIdFunc_) = dlsym(rt_handle_, "rtGetLogicDevIdByUserDevId"); + if (!rtGetLogicDevIdByUserDevIdFunc_) { + dlclose(rt_handle_); + rt_handle_ = nullptr; + SHM_LOG_ERROR("Unable to get info from " << "libascendcl.so" << "."); + return SHMEM_INVALID_VALUE; + } + return true; + } + +public: + static rtLibLoader& getInstance() { + static rtLibLoader instance; + return instance; + } + + void getLogicDevId(const int32_t userDeviceId, int32_t *const logicDeviceId) { + if (rtGetLogicDevIdByUserDevIdFunc_) { + rtGetLogicDevIdByUserDevIdFunc_(userDeviceId, logicDeviceId); + } else { + SHM_LOG_ERROR("rtGetLogicDevIdByUserDevIdFunc function is not available"); + } + } + + bool isLoaded() const { + return rt_handle_ != nullptr && rtGetLogicDevIdByUserDevIdFunc_ != nullptr; + } + + ~rtLibLoader() { + if (rt_handle_) { + dlclose(rt_handle_); + } + } +}; + +#endif // SHMEM_UTILS_H \ No newline at end of file diff --git a/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d0ebba3d1aba7200e39ad5596fddc8d28ae34657 --- /dev/null +++ b/src/modules/bootstrap/shmemi_bootstrap_mpi.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include +#include "host/shmem_host_def.h" +#include "common/shmemi_logger.h" +#include "common/shmemi_host_types.h" +#include "bootstrap/shmemi_bootstrap.h" + +typedef struct { + MPI_Comm comm; + int mpi_initialized; +} shmemi_bootstrap_mpi_state_t; +static shmemi_bootstrap_mpi_state_t shmemi_bootstrap_mpi_state = {MPI_COMM_NULL, 0}; + +static int shmemi_bootstrap_mpi_barrier(shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS; + + status = MPI_Barrier(shmemi_bootstrap_mpi_state.comm); + SHMEM_CHECK_RET(status); + + return status; +} + +static int shmemi_bootstrap_mpi_allgather(const void *sendbuf, void *recvbuf, int length, + shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS; + + status = MPI_Allgather(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_mpi_state.comm); + SHMEM_CHECK_RET(status); + + return status; +} + +static int shmemi_bootstrap_mpi_alltoall(const void *sendbuf, void *recvbuf, int length, + shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS; + + status = MPI_Alltoall(sendbuf, length, MPI_BYTE, recvbuf, length, MPI_BYTE, shmemi_bootstrap_mpi_state.comm); + SHMEM_CHECK_RET(status); + + return status; +} + +static void shmemi_bootstrap_mpi_global_exit(int status) { + int rc = MPI_SUCCESS; + + rc = MPI_Abort(shmemi_bootstrap_mpi_state.comm, status); + if (rc != MPI_SUCCESS) { + exit(1); + } +} + +static int shmemi_bootstrap_mpi_finalize(shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS, finalized; + + status = MPI_Finalized(&finalized); + SHMEM_CHECK_RET(status); + + if (finalized) { + if (shmemi_bootstrap_mpi_state.mpi_initialized) { + status = SHMEM_INNER_ERROR; + } else { + status = 0; + } + + return status; + } + + if (!finalized && shmemi_bootstrap_mpi_state.mpi_initialized) { + status = MPI_Comm_free(&shmemi_bootstrap_mpi_state.comm); + SHMEM_CHECK_RET(status); + } + + if (shmemi_bootstrap_mpi_state.mpi_initialized) MPI_Finalize(); + + return status; +} + +int shmemi_bootstrap_plugin_init(void *mpi_comm, shmemi_bootstrap_handle_t *handle) { + int status = MPI_SUCCESS, initialized = 0, finalized = 0; + MPI_Comm src_comm; + if (NULL == mpi_comm) + src_comm = MPI_COMM_WORLD; + else + src_comm = *((MPI_Comm *)mpi_comm); + status = MPI_Initialized(&initialized); + SHMEM_CHECK_RET(status); + status = MPI_Finalized(&finalized); + SHMEM_CHECK_RET(status); + if (!initialized && !finalized) { + MPI_Init(NULL, NULL); + shmemi_bootstrap_mpi_state.mpi_initialized = 1; + + if (src_comm != MPI_COMM_WORLD && src_comm != MPI_COMM_SELF) { + status = SHMEM_INNER_ERROR; + if (shmemi_bootstrap_mpi_state.mpi_initialized) { + MPI_Finalize(); + shmemi_bootstrap_mpi_state.mpi_initialized = 0; + } + } + } else if (finalized) { + status = SHMEM_INNER_ERROR; + if (shmemi_bootstrap_mpi_state.mpi_initialized) { + MPI_Finalize(); + shmemi_bootstrap_mpi_state.mpi_initialized = 0; + } + } + status = MPI_Comm_dup(src_comm, &shmemi_bootstrap_mpi_state.comm); + SHMEM_CHECK_RET(status); + status = MPI_Comm_rank(shmemi_bootstrap_mpi_state.comm, &handle->mype); + SHMEM_CHECK_RET(status); + status = MPI_Comm_size(shmemi_bootstrap_mpi_state.comm, &handle->npes); + SHMEM_CHECK_RET(status); + handle->allgather = shmemi_bootstrap_mpi_allgather; + handle->alltoall = shmemi_bootstrap_mpi_alltoall; + handle->barrier = shmemi_bootstrap_mpi_barrier; + handle->global_exit = shmemi_bootstrap_mpi_global_exit; + handle->finalize = shmemi_bootstrap_mpi_finalize; + handle->pre_init_ops = NULL; + handle->bootstrap_state = &shmemi_bootstrap_mpi_state.comm; + return status; +} + diff --git a/src/modules/bootstrap/shmemi_bootstrap_uid.cpp b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf496a123f7cfbce12fc37857574739247593bcb --- /dev/null +++ b/src/modules/bootstrap/shmemi_bootstrap_uid.cpp @@ -0,0 +1,1379 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "socket/uid_socket.h" +#include "socket/uid_utils.h" + + + +#define SHMEM_UNIQUEID_INITIALIZER \ + { \ + SHMEM_UNIQUEID_VERSION, \ + { \ + 0 \ + } \ + } \ + + +#define MAX_ATTEMPTS 500 +#define MAX_IFCONFIG_LENGTH 23 +#define MAX_IP 48 +#define DEFAULT_IFNAME_LNEGTH 4 +#define BOOTSTRAP_IN_PLACE (void*)0x1 +#define SOCKET_MAGIC 0x243ab9f2fc4b9d6cULL + +static const char* env_ip_port = nullptr; +static const char* env_ifname = nullptr; +static shmemx_bootstrap_uid_state_t shmemi_bootstrap_uid_state; +static struct bootstrap_netstate priv_info; +static int static_magic_count = 1; + +bool is_ipv6_loopback(const struct in6_addr *addr6) { + static const struct in6_addr loopback6 = IN6ADDR_LOOPBACK_INIT; + return memcmp(addr6, &loopback6, sizeof(struct in6_addr)) == 0; +} + +// fe80 check +bool is_ipv6_link_local(const struct in6_addr *addr6) { + if (addr6 == nullptr) { + return false; + } + + const uint8_t* bytes = addr6->s6_addr; + + if (bytes[0] != 0xfe) { + return false; + } + + if ((bytes[1] & 0xc0) != 0x80) { + return false; + } + + SHM_LOG_DEBUG("It is IPv6 link-local address (fe80::/10)."); + return true; +} + +bool is_ipv4_loopback(const struct in_addr *addr4) { + return ((ntohl(addr4->s_addr) >> 24) & 0xFF) == IN_LOOPBACKNET; +} + +bool is_ipv4_link_local(const struct in_addr *addr4) { + if (addr4 == nullptr) { + return false; + } + + uint32_t ip_addr = ntohl(addr4->s_addr); + + uint8_t byte1 = (ip_addr >> 24) & 0xff; + uint8_t byte2 = (ip_addr >> 16) & 0xff; + if (byte1 != 169 || byte2 != 254) { + return false; + } + SHM_LOG_DEBUG("It is IPv4 link-local address (169.254.x.x)."); + return true; +} + +static bool is_loopback_addr(const sockaddr_t* addr) { + if (addr == nullptr) { + return false; + } + if (addr->type == ADDR_IPv4) { + return is_ipv4_loopback(&addr->addr.addr4.sin_addr); + } else if (addr->type == ADDR_IPv6) { + return is_ipv6_loopback(&addr->addr.addr6.sin6_addr); + } else { + return false; + } +} + +static bool is_link_local_addr(const sockaddr_t* addr) { + if (addr == nullptr) { + return false; + } + if (addr->type == ADDR_IPv4) { + return is_ipv4_link_local(&addr->addr.addr4.sin_addr); + } else if (addr->type == ADDR_IPv6) { + return is_ipv6_link_local(&addr->addr.addr6.sin6_addr); + } else { + return false; + } +} +static int32_t shmemi_get_uid_magic(shmemx_bootstrap_uid_state_t *innerUId) +{ + std::ifstream urandom("/dev/urandom", std::ios::binary); + if (!urandom) { + SHM_LOG_ERROR("open random failed"); + return SHMEM_INNER_ERROR; + } + + urandom.read(reinterpret_cast(&innerUId->magic), sizeof(innerUId->magic)); + if (urandom.fail()) { + SHM_LOG_ERROR("read random failed."); + return SHMEM_INNER_ERROR; + } + SHM_LOG_DEBUG("init magic id to " << innerUId->magic); + return SHMEM_SUCCESS; +} + + +static int32_t shmemi_uid_parse_interface_with_type(const char *ipInfo, char *IP, sa_family_t &sockType, bool &flag) +{ + const char *delim = ":"; + const char *sep = strchr(ipInfo, delim[0]); + if (sep != nullptr) { + size_t leftLen = sep - ipInfo; + if (leftLen >= MAX_IFCONFIG_LENGTH - 1 || leftLen == 0) { + SHM_LOG_ERROR("Invalid interface prefix length: " << leftLen); + return SHMEM_INVALID_VALUE; + } + strncpy(IP, ipInfo, leftLen); + IP[leftLen] = '\0'; + sockType = (strcmp(sep + 1, "inet6") != 0) ? AF_INET : AF_INET6; + flag = true; + SHM_LOG_INFO("Parse ipInfo success: ifaPrefix=" << IP << ", sockType=" << (sockType == AF_INET ? "IPv4" : "IPv6")); + } + return SHMEM_SUCCESS; +} + +int32_t shmemi_traverse_ifa( + struct ifaddrs *ifaddr, + sa_family_t &sockType, + bool flag, + const char **prefixes, + bool exclude, + shmemx_bootstrap_uid_state_t *uid_args, + bool skipStateCheck = false, + bool allow_local = false +) { + for (struct ifaddrs *ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr) continue; + const char* ifname = ifa->ifa_name; + if (!allow_local && (strstr(ifname, "lo") != nullptr || strstr(ifname, "docker") != nullptr)) { + SHM_LOG_DEBUG("Skip interface: " << ifname << " (lo/docker, allow_local=false)"); + continue; + } + + bool match = false; + const char **p = prefixes; + + while (*p != nullptr) { + if (**p == '\0') { + p++; + continue; + } + size_t prefix_len = strlen(*p); + size_t ifname_len = strlen(ifa->ifa_name); + if (ifname_len < prefix_len) { + p++; + continue; + } + if (strncmp(ifa->ifa_name, *p, prefix_len) == 0) { + match = true; + break; + } + p++; + } + if (exclude && match) continue; + if (!exclude && !match) continue; + + if (!skipStateCheck && (!(ifa->ifa_flags & IFF_UP) || !(ifa->ifa_flags & IFF_RUNNING))) continue; + + if (flag) { + if (ifa->ifa_addr->sa_family != sockType) { + SHM_LOG_DEBUG("Protocol family not match (flag=true), interface: " << ifa->ifa_name << ", get: " << ifa->ifa_addr->sa_family << ", expect: "<< sockType); + continue; + } + } + bool is_invalid_addr = false; + if (!allow_local) { + if (ifa->ifa_addr->sa_family == AF_INET) { + struct sockaddr_in *addr4 = (struct sockaddr_in *)ifa->ifa_addr; + if (is_ipv4_link_local(&addr4->sin_addr)) { + SHM_LOG_INFO("Blocked ipv4 link local address."); + continue; + } + if (is_ipv4_loopback(&addr4->sin_addr)) { + is_invalid_addr = true; + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)ifa->ifa_addr; + if (is_ipv6_link_local(&addr6->sin6_addr)) { + SHM_LOG_INFO("Blocked ipv6 link local address."); + continue; + } + if (is_ipv6_loopback(&addr6->sin6_addr)) { + is_invalid_addr = true; + } + } + } + if (is_invalid_addr) { + SHM_LOG_DEBUG("Skip invalid address (lo/fe80, allow_local=false) on interface: " << ifname); + continue; + } + + if (ifa->ifa_addr->sa_family == AF_INET && (sockType == AF_UNSPEC || sockType == AF_INET)) { + memset(&uid_args->addr.addr.addr4, 0, sizeof(struct sockaddr_in)); + uid_args->addr.type = ADDR_IPv4; + uid_args->addr.addr.addr4 = *(struct sockaddr_in *)ifa->ifa_addr; + uid_args->addr.addr.addr4.sin_port = 0; + sockType = AF_INET; + SHM_LOG_INFO("Assign IPv4 from interface: " << ifa->ifa_name); + return SHMEM_SUCCESS; + } + + if (ifa->ifa_addr->sa_family == AF_INET6 && (sockType == AF_UNSPEC || sockType == AF_INET6)) { + memset(&uid_args->addr.addr.addr6, 0, sizeof(struct sockaddr_in6)); + uid_args->addr.type = ADDR_IPv6; + uid_args->addr.addr.addr6 = *(struct sockaddr_in6 *)ifa->ifa_addr; + uid_args->addr.addr.addr6.sin6_port = 0; + uid_args->addr.addr.addr6.sin6_flowinfo = 0; + + sockType = AF_INET6; + SHM_LOG_INFO("Assign IPv6 from interface: " << ifa->ifa_name <<" scope_id: " << uid_args->addr.addr.addr6.sin6_scope_id); + return SHMEM_SUCCESS; + } + } + return SHMEM_INVALID_PARAM; +} + +int32_t shmemi_get_ip_from_env(shmemx_bootstrap_uid_state_t *uid_args, const char *ipPort) { + if (uid_args == nullptr || ipPort == nullptr || strlen(ipPort) == 0) { + SHM_LOG_ERROR("Invalid param: uid_args is null or ipPort is empty"); + return SHMEM_INVALID_PARAM; + } + + shmemi_get_uid_magic(uid_args); + SHM_LOG_DEBUG("get env SHMEM_UID_SESSION_ID value: " << ipPort); + std::string ipPortStr = ipPort; + + if (ipPort[0] == '[') { + size_t bracket_end = ipPortStr.find_last_of(']'); + if (bracket_end == std::string::npos || ipPortStr.length() - bracket_end <= 1) { + SHM_LOG_ERROR("Invalid IPv6 format: no closing ']'"); + return SHMEM_INVALID_PARAM; + } + + std::string ip_with_scope = ipPortStr.substr(1, bracket_end - 1); + size_t scope_sep = ip_with_scope.find('%'); + std::string ipStr; + std::string if_name; + + memset(&uid_args->addr.addr.addr6, 0, sizeof(struct sockaddr_in6)); + uid_args->addr.type = ADDR_IPv6; + uid_args->addr.addr.addr6.sin6_family = AF_INET6; + + if (scope_sep != std::string::npos) { + ipStr = ip_with_scope.substr(0, scope_sep); + if_name = ip_with_scope.substr(scope_sep + 1); + uid_args->addr.addr.addr6.sin6_scope_id = if_nametoindex(if_name.c_str()); + if (uid_args->addr.addr.addr6.sin6_scope_id == 0) { + SHM_LOG_WARN("Interface " << if_name.c_str() << "not found, scope_id set to 0"); + } + } else { + ipStr = ip_with_scope; + uid_args->addr.addr.addr6.sin6_scope_id = 0; + } + + std::string portStr = ipPortStr.substr(bracket_end + 2); + if (portStr.empty()) { + SHM_LOG_ERROR("IPv6 port is empty"); + return SHMEM_INVALID_PARAM; + } + uint16_t port = static_cast(std::stoi(portStr)); + uid_args->addr.addr.addr6.sin6_port = htons(port); + uid_args->addr.addr.addr6.sin6_flowinfo = 0; + + if (inet_pton(AF_INET6, ipStr.c_str(), &uid_args->addr.addr.addr6.sin6_addr) <= 0) { + SHM_LOG_ERROR("inet_pton IPv6 failed: " << strerror(errno)); + return SHMEM_NOT_INITED; + } + } else { + size_t colon_pos = ipPortStr.find_last_of(':'); + if (colon_pos == std::string::npos || ipPortStr.length() - colon_pos <= 1) { + SHM_LOG_ERROR("Invalid IPv4 format: no colon separator"); + return SHMEM_INVALID_PARAM; + } + + std::string ipStr = ipPortStr.substr(0, colon_pos); + std::string portStr = ipPortStr.substr(colon_pos + 1); + + memset(&uid_args->addr.addr.addr4, 0, sizeof(struct sockaddr_in)); + uid_args->addr.type = ADDR_IPv4; + uid_args->addr.addr.addr4.sin_family = AF_INET; + + uint16_t port = static_cast(std::stoi(portStr)); + uid_args->addr.addr.addr4.sin_port = htons(port); + + if (inet_pton(AF_INET, ipStr.c_str(), &uid_args->addr.addr.addr4.sin_addr) <= 0) { + SHM_LOG_ERROR("inet_pton IPv4 failed: " << strerror(errno)); + return SHMEM_NOT_INITED; + } + } + + SHM_LOG_INFO("Assign IP/Port from env success"); + return SHMEM_SUCCESS; +} + +int32_t shmemi_get_ip_from_ifa(shmemx_bootstrap_uid_state_t *uid_args, const char *ipInfo) { + if (uid_args == nullptr) { + SHM_LOG_ERROR("uid_args is nullptr"); + return SHMEM_INVALID_PARAM; + } + + struct ifaddrs *ifaddr = nullptr; + char ifaPrefix[MAX_IFCONFIG_LENGTH] = {0}; + bool flag = false; + sa_family_t sockType = AF_INET; + bool foundValidIp = false; + + shmemi_get_uid_magic(uid_args); + + bool isIpInfoConfigured = (ipInfo != nullptr && strlen(ipInfo) > 0); + if (isIpInfoConfigured) { + int32_t ret = shmemi_uid_parse_interface_with_type(ipInfo, ifaPrefix, sockType, flag); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("Parse ipInfo failed, ret: " << ret); + return ret; + } + } + bool allow_local = isIpInfoConfigured; + + if (getifaddrs(&ifaddr) == -1) { + SHM_LOG_ERROR("getifaddrs failed: " << strerror(errno)); + return SHMEM_INVALID_PARAM; + } + + if (isIpInfoConfigured) { + const char *specifiedPrefixes[] = {ifaPrefix, nullptr}; + SHM_LOG_INFO("Search interface with specified prefix: " << ifaPrefix); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, specifiedPrefixes, false, uid_args, false, allow_local) == SHMEM_SUCCESS); + } else { + const char *ethPrefixes[] = {"eth", nullptr}; + const char *excludePrefixes[] = {"docker", "lo", nullptr}; + + SHM_LOG_INFO("Step 1: Search interfaces match 'eth'"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, ethPrefixes, false, uid_args, false, allow_local) == SHMEM_SUCCESS); + + if (!foundValidIp) { + SHM_LOG_WARN("Step 3: Search interfaces exclude 'docker' and 'lo/fe80'"); + foundValidIp = (shmemi_traverse_ifa(ifaddr, sockType, flag, excludePrefixes, true, uid_args, false, allow_local) == SHMEM_SUCCESS); + } + } + + if (ifaddr != nullptr) { + freeifaddrs(ifaddr); + ifaddr = nullptr; + } + if (!foundValidIp) { + SHM_LOG_ERROR("No valid IP address found from any interface!"); + return SHMEM_INVALID_PARAM; + } + + SHM_LOG_INFO("Assign IP/Port from interface success"); + return SHMEM_SUCCESS; +} + +int32_t shmemi_set_ip_info(void *uid, sa_family_t &sockType, char *pta_env_ip, uint16_t pta_env_port, + bool is_from_ifa) +{ + // init default uid + shmemx_bootstrap_uid_state_t *innerUID = (shmemx_bootstrap_uid_state_t *)(uid); + SHM_LOG_INFO(" ENV IP: " << pta_env_ip << " ENV port: " << pta_env_port << " sockType: " << sockType); + SHMEM_CHECK_RET(shmemi_get_uid_magic(innerUID)); + + // fill ip port as part of uid + uint16_t port = 0; + if (is_from_ifa) { + SHM_LOG_DEBUG("Automatically obtain the value of port. port: " << port); + } else { + port = pta_env_port; + SHM_LOG_DEBUG("Get the port from the environment variable. port: " << port); + } + + if (sockType == AF_INET) { + SHM_LOG_INFO("SockType is AF_INET."); + innerUID->addr.addr.addr4.sin_family = AF_INET; + if (inet_pton(AF_INET, pta_env_ip, &(innerUID->addr.addr.addr4.sin_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv4 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.addr.addr4.sin_port = htons(port); + innerUID->addr.type = ADDR_IPv4; + } else if (sockType == AF_INET6) { + SHM_LOG_INFO("SockType is AF_INET6."); + innerUID->addr.addr.addr6.sin6_family = AF_INET6; + if (inet_pton(AF_INET6, pta_env_ip, &(innerUID->addr.addr.addr6.sin6_addr)) <= 0) { + SHM_LOG_ERROR("inet_pton IPv6 failed"); + return SHMEM_NOT_INITED; + } + innerUID->addr.addr.addr6.sin6_port = htons(port); + innerUID->addr.type = ADDR_IPv6; + } else { + SHM_LOG_ERROR("IP Type is not IPv4 or IPv6"); + return SHMEM_INVALID_PARAM; + } + SHM_LOG_INFO("gen unique id success."); + return SHMEM_SUCCESS; +} + +static int shmemi_bootstrap_uid_finalize(shmemi_bootstrap_handle_t *handle) { + if (!handle) { + return SHMEM_SUCCESS; + } + + if (handle->bootstrap_state) { + uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; + unexpected_conn_t* elem = state->unexpected_conns; + while (elem != NULL) { + unexpected_conn_t* next = elem->next; + socket_close(&elem->sock); // 关闭socket句柄 + SHMEM_BOOTSTRAP_PTR_FREE(elem); + elem = next; + } + state->unexpected_conns = NULL; + socket_close(&state->listen_sock); + socket_close(&state->ring_send_sock); + socket_close(&state->ring_recv_sock); + + SHMEM_BOOTSTRAP_PTR_FREE(state->peer_addrs); + state->peer_addrs = nullptr; + SHMEM_BOOTSTRAP_PTR_FREE(state); + handle->bootstrap_state = nullptr; + } + + if (handle->pre_init_ops) { + SHMEM_BOOTSTRAP_PTR_FREE(handle->pre_init_ops); + handle->pre_init_ops = nullptr; + } + + return SHMEM_SUCCESS; +} + + +static int shmemi_bootstrap_uid_allgather(const void *in, void *out, int len, shmemi_bootstrap_handle_t *handle) { + if (!in || !out || !handle || !handle->bootstrap_state) { + SHM_LOG_ERROR("bootstrap allgather: invalid arguments."); + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; + int rank = state->rank; + int nranks = state->nranks; + char* send_buf = (char*)in; + + if (state->ring_send_sock.state != SOCKET_STATE_READY || + state->ring_recv_sock.state != SOCKET_STATE_READY) { + SHM_LOG_ERROR("bootstrap allgather: rank " << rank << ": sockets not ready for allgather"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (in != BOOTSTRAP_IN_PLACE) { + memcpy((char*)out + (rank % nranks) * len, send_buf, len); + } + + for (int i = 0; i < nranks - 1; i++) { + size_t rslice = (rank - i - 1 + nranks) % nranks; + size_t sslice = (rank - i + nranks) % nranks; + + SHMEM_CHECK_RET(socket_send(&state->ring_send_sock, ((char*)out + sslice * len), len), "rank " << rank << ": barrier send failed"); + SHMEM_CHECK_RET(socket_recv(&state->ring_recv_sock, ((char*)out + rslice * len), len), "rank " << rank << ": barrier recv failed"); + } + return SHMEM_SUCCESS; +} + + +static int shmemi_bootstrap_uid_barrier(shmemi_bootstrap_handle_t *handle) { + SHM_LOG_INFO("shmemi_bootstrap_uid_barrier"); + if (!handle || !handle->bootstrap_state) { + SHM_LOG_ERROR("bootstrap barrier: invalid arguments"); + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*) handle->bootstrap_state; + int rank = state->rank; + int nranks = state->nranks; + + if (nranks == 1) { + return SHMEM_SUCCESS; + } + + if (state->ring_send_sock.state != SOCKET_STATE_READY || + state->ring_recv_sock.state != SOCKET_STATE_READY) { + SHM_LOG_ERROR("bootstrap barrier: rank " << rank << ": sockets not ready for barrier"); + return SHMEM_BOOTSTRAP_ERROR; + } + + char token = 0; + if (rank == 0) { + SHMEM_CHECK_RET(socket_send(&state->ring_send_sock, &token, 1), "rank 0: barrier send failed"); + SHMEM_CHECK_RET(socket_recv(&state->ring_recv_sock, &token, 1), "rank 0: barrier recv failed"); + } else { + SHMEM_CHECK_RET(socket_recv(&state->ring_recv_sock, &token, 1), "rank " << rank << ": barrier recv failed"); + SHMEM_CHECK_RET(socket_send(&state->ring_send_sock, &token, 1), "rank " << rank << ": barrier send failed"); + } + return SHMEM_SUCCESS; +} +static int unexpected_dequeue(uid_bootstrap_state* state, int peer, int tag, socket_t* sock, int* found) { + SHM_LOG_INFO("unexpected_dequeue start."); + if (state == NULL || sock == NULL || found == NULL) { + return SHMEM_BOOTSTRAP_ERROR; + } + + unexpected_conn_t* elem = state->unexpected_conns; + unexpected_conn_t* prev = NULL; + *found = 0; + while (elem != NULL) { + if (elem->peer == peer && elem->tag == tag) { + if (prev == NULL) { + state->unexpected_conns = elem->next; + } else { + prev->next = elem->next; + } + + memcpy(sock, &elem->sock, sizeof(socket_t)); + SHMEM_BOOTSTRAP_PTR_FREE(elem); + *found = 1; + return SHMEM_SUCCESS; + } + + prev = elem; + elem = elem->next; + } + return SHMEM_SUCCESS; +} + +static int unexpected_enqueue(uid_bootstrap_state* state, int peer, int tag, socket_t* sock) { + SHM_LOG_INFO("unexpected_enqueue start."); + if (state == NULL || sock == NULL) { + return SHMEM_BOOTSTRAP_ERROR; + } + + unexpected_conn_t* new_conn = NULL; + SHMEM_BOOTSTRAP_CALLOC(&new_conn, 1); + if (new_conn == NULL) { + return SHMEM_BOOTSTRAP_ERROR; + } + + new_conn->peer = peer; + new_conn->tag = tag; + memcpy(&new_conn->sock, sock, sizeof(socket_t)); + new_conn->next = NULL; + if (state->unexpected_conns == NULL) { + state->unexpected_conns = new_conn; + } else { + new_conn->next = state->unexpected_conns; + state->unexpected_conns = new_conn; + } + + return SHMEM_SUCCESS; +} +static int bootstrap_send(void* comm_state, int peer, int tag, void* data, int size) { + if (comm_state == nullptr || data == nullptr || size < 0 || peer < 0) { + SHM_LOG_ERROR("bootstrap_send: invalid arguments"); + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*)comm_state; + socket_t sock; + SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, state->magic, &state->peer_addrs[peer]), "bootstrap_send: socket_init failed for peer " << peer); + + SHMEM_CHECK_RET_CLOSE_SOCK(socket_connect(&sock), "bootstrap_send: socket_connect failed for peer " << peer, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &state->rank, sizeof(int)), "bootstrap_send: send rank failed to peer " << peer, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &tag, sizeof(int)), "bootstrap_send: send tag " << tag << " failed to peer " << peer, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, data, size), "bootstrap_send: send data (size=" << size << ") failed to peer " << peer, sock); + if (sock.fd >= 0) { + socket_close(&sock); + } + return SHMEM_SUCCESS; +} + + +static int bootstrap_recv(void* comm_state, int peer, int tag, void* data, int size) { + if (comm_state == NULL || data == NULL || size < 0 || peer < 0) { + return SHMEM_BOOTSTRAP_ERROR; + } + + uid_bootstrap_state* state = (uid_bootstrap_state*)comm_state; + socket_t sock; + int found = 0; + int retry_count = 0; + int ret = SHMEM_SUCCESS; + SHMEM_CHECK_RET(unexpected_dequeue(state, peer, tag, &sock, &found)); + + if (found == 1) { + ret = socket_recv(&sock, data, size); + socket_close(&sock); + return (ret == SHMEM_SUCCESS) ? SHMEM_SUCCESS : SHMEM_BOOTSTRAP_ERROR; + } + while (1) { + socket_t new_sock; + int new_peer = -1; + int new_tag = -1; + SHMEM_CHECK_RET(socket_init(&new_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, NULL), "socket_init new_sock failed"); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_accept(&new_sock, &state->listen_sock), "socket_accept new_sock failed", new_sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&new_sock, &new_peer, sizeof(int)), "socket_recv new_peer failed", new_sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&new_sock, &new_tag, sizeof(int)), "socket_recv new_tag failed", new_sock); + if (new_peer == peer && new_tag == tag) { + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&new_sock, data, size), "socket_recv failed", new_sock); + return SHMEM_SUCCESS; + } else { + SHMEM_CHECK_RET_CLOSE_SOCK(unexpected_enqueue(state, new_peer, new_tag, &new_sock), "unexpected_enqueue failed", new_sock); + } + } +} + +static int shmemi_bootstrap_uid_barrier_v2(shmemi_bootstrap_handle_t *handle) { + SHM_LOG_INFO("shmemi_bootstrap_uid_barrier_v2"); + uid_bootstrap_state* state = (uid_bootstrap_state*)(handle->bootstrap_state); + int rank = state->rank; + int tag = 0; + int nranks = state->nranks; + + if (nranks == 1) { + SHM_LOG_DEBUG("Single rank, skip barrier"); + return SHMEM_SUCCESS; + } + + SHM_LOG_DEBUG("Barrier start. rank: " << rank << " nranks: " << nranks <<" tag: "<< tag); + + int data[1]; + for (int mask = 1; mask < nranks; mask <<= 1) { + int src = (rank - mask + nranks) % nranks; + int dst = (rank + mask) % nranks; + tag++; + + SHMEM_CHECK_RET(bootstrap_send(state, dst, tag, data, sizeof(data)), "rank " << rank << ": barrier send failed, dst: " << dst << "tag: " << tag); + SHMEM_CHECK_RET(bootstrap_recv(state, src, tag, data, sizeof(data)), "rank " << rank << ": barrier recv failed, src: " << src << "tag: " << tag); + } + + SHM_LOG_DEBUG("Barrier end. rank: " << rank << " nranks: " << nranks <<" tag: "<< tag); + return SHMEM_SUCCESS; +} + +static int shmemi_bootstrap_uid_alltoall(const void *sendbuf, void *recvbuf, int length, + shmemi_bootstrap_handle_t *handle) { + +} + +static void shmemi_bootstrap_uid_global_exit(int status) { + +} + +static bool matchSubnet(struct ifaddrs local_if, sockaddr_t* remote) { + int family; + bool is_lo_interface = (strncmp(local_if.ifa_name, "lo", 2) == 0); + if (remote->type == ADDR_IPv4) { + family = AF_INET; + } else if (remote->type == ADDR_IPv6) { + family = AF_INET6; + } else { + return false; + } + + SHM_LOG_DEBUG("local_if family: " << local_if.ifa_addr->sa_family << " remote family: " << family); + if (family != local_if.ifa_addr->sa_family) { + SHM_LOG_DEBUG(" matchSubnet family unmatch."); + return false; + } + + if (family == AF_INET) { + struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr); + struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask); + struct sockaddr_in* remote_addr = &remote->addr.addr4; + + uint32_t local_subnet = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr; + uint32_t remote_subnet = remote_addr->sin_addr.s_addr & mask->sin_addr.s_addr; + SHM_LOG_DEBUG("ipv4 matchSubnet result:" << (local_subnet == remote_subnet)); + return local_subnet == remote_subnet; + } else if (family == AF_INET6) { + struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr); + struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask); + struct sockaddr_in6* remote_addr = &remote->addr.addr6; + + bool same = true; + for (int c = 0; c < 16; c++) { + uint8_t l = local_addr->sin6_addr.s6_addr[c] & mask->sin6_addr.s6_addr[c]; + uint8_t r = remote_addr->sin6_addr.s6_addr[c] & mask->sin6_addr.s6_addr[c]; + if (l != r) { + same = false; + break; + } + } + if (is_lo_interface) { + SHM_LOG_DEBUG("IPv6 on lo interface, skipping sin6_scope_id validation"); + SHM_LOG_DEBUG("ipv6 matchSubnet result:" << same); + return same; + } + same &= (local_addr->sin6_scope_id == remote_addr->sin6_scope_id); + SHM_LOG_DEBUG("ipv6 matchSubnet result:" << same << " local_addr->sin6_scope_id: " <sin6_scope_id << " remote_addr->sin6_scope_id: "<< remote_addr->sin6_scope_id); + return same; + } + return false; +} + +static int find_interface_match_subnet(char* ifNames, sockaddr_t* localAddrs, sockaddr_t* remoteAddr) { + int found = 0; + struct ifaddrs *interfaces, *interface; + if (getifaddrs(&interfaces) != 0) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (remoteAddr) { + if (remoteAddr->type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN]; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &remoteAddr->addr.addr4.sin_addr, ip_str, INET_ADDRSTRLEN) == nullptr, "convert remote ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(remoteAddr->addr.addr4.sin_port); + SHM_LOG_INFO(" Type: IPv4" << " IP: " << ip_str <<" Port: " << (port ? port : 0) << " (0 means not set)"); + } else if (remoteAddr->type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN]; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &remoteAddr->addr.addr6.sin6_addr, ip_str, INET6_ADDRSTRLEN) == nullptr, "convert remote ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(remoteAddr->addr.addr6.sin6_port); + SHM_LOG_INFO(" Type: IPv6" << " IP: " << ip_str <<" Port: " << (port ? port : 0) << " (0 means not set)"); + } else { + SHM_LOG_ERROR(" remoteAddr: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + } else { + SHM_LOG_ERROR(" remoteAddr is NULL."); + return SHMEM_BOOTSTRAP_ERROR; + } + SHMEM_CHECK_RET(is_link_local_addr(remoteAddr), "Remote address is link_local", SHMEM_BOOTSTRAP_ERROR); + bool remote_is_loopback = is_loopback_addr(remoteAddr); + SHM_LOG_INFO("Remote address is loopback:" << remote_is_loopback); + + if (remote_is_loopback) { + SHM_LOG_DEBUG("Remote address is loopback, check lo interface first"); + for (interface = interfaces; interface && !found; interface = interface->ifa_next) { + if (interface->ifa_addr == NULL) continue; + int family = interface->ifa_addr->sa_family; + if (family != AF_INET && family != AF_INET6) continue; + if (strcmp(interface->ifa_name, "lo") != 0) continue; + + if (matchSubnet(*interface, remoteAddr)) { + if (family == AF_INET) { + localAddrs->type = ADDR_IPv4; + memcpy(&localAddrs->addr.addr4, interface->ifa_addr, sizeof(struct sockaddr_in)); + } else { + localAddrs->type = ADDR_IPv6; + memcpy(&localAddrs->addr.addr6, interface->ifa_addr, sizeof(struct sockaddr_in6)); + } + strncpy(ifNames, interface->ifa_name, MAX_IF_NAME_SIZE); + ifNames[MAX_IF_NAME_SIZE] = '\0'; + found = 1; + break; + } + } + } + if (!found) { + for (interface = interfaces; interface && !found; interface = interface->ifa_next) { + if (interface->ifa_addr == NULL) continue; + int family = interface->ifa_addr->sa_family; + if (family != AF_INET && family != AF_INET6) continue; + + if (!remote_is_loopback && strcmp(interface->ifa_name, "lo") == 0) { + continue; + } + + if (matchSubnet(*interface, remoteAddr)) { + if (family == AF_INET) { + localAddrs->type = ADDR_IPv4; + memcpy(&localAddrs->addr.addr4, interface->ifa_addr, sizeof(struct sockaddr_in)); + } else { + localAddrs->type = ADDR_IPv6; + memcpy(&localAddrs->addr.addr6, interface->ifa_addr, sizeof(struct sockaddr_in6)); + } + strncpy(ifNames, interface->ifa_name, MAX_IF_NAME_SIZE); + ifNames[MAX_IF_NAME_SIZE] = '\0'; + found = 1; + break; + } + } + } + + freeifaddrs(interfaces); + return (found == 0) ? SHMEM_BOOTSTRAP_ERROR : SHMEM_SUCCESS; +} + +static int bootstrap_get_sock_addr(socket_t* sock, sockaddr_t* addr) { + if (sock == NULL) return SHMEM_BOOTSTRAP_ERROR; + struct sockaddr_storage temp_storage; + memset(&temp_storage, 0, sizeof(temp_storage)); + struct sockaddr* temp_addr = reinterpret_cast(&temp_storage); + socklen_t addr_len = 0; + int ret = socket_get_sainfo(sock, temp_addr, &addr_len); + if (ret != 0) { + return SHMEM_BOOTSTRAP_ERROR; + } + + if (temp_storage.ss_family == AF_INET) { + addr->type = ADDR_IPv4; + const struct sockaddr_in* ipv4_src = reinterpret_cast(&temp_storage); + memcpy(&addr->addr.addr4, ipv4_src, sizeof(struct sockaddr_in)); + } else if (temp_storage.ss_family == AF_INET6) { + addr->type = ADDR_IPv6; + const struct sockaddr_in6* ipv6_src = reinterpret_cast(&temp_storage); + memcpy(&addr->addr.addr6, ipv6_src, sizeof(struct sockaddr_in6)); + } else { + SHM_LOG_ERROR("Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + return SHMEM_SUCCESS; +} + +// Network Initialization (Locating Local Interface Matching Subnet / initialize root node UID information when is_arg_init == false) +static int shmemi_bootstrap_net_init(shmemx_bootstrap_uid_state_t* uid_args, bool is_arg_init = true) { + SHM_LOG_INFO(" Network Initialization, Finding Interfaces Matching Subnets"); + pthread_mutex_lock(&priv_info.bootstrap_netlock); + + if (!is_arg_init) { + SHM_LOG_INFO("net_init uid_args is NULL, get uid arg"); + bool is_from_ifa = false; + + if (env_ip_port != nullptr) { + SHM_LOG_INFO("Environment variable SHMEM_UID_SESSION_ID has been set."); + SHMEM_CHECK_RET(shmemi_get_ip_from_env(uid_args, env_ip_port), + "No available addresses were found with env_ip_port."); + } else { + SHM_LOG_INFO("Environment variable SHMEM_UID_SESSION_ID is not set, automatically obtaining ipPort."); + is_from_ifa = true; + SHMEM_CHECK_RET(shmemi_get_ip_from_ifa(uid_args, env_ifname), + "No available addresses were found with auto."); + } + + SHM_LOG_INFO("Get uid arg success."); + is_arg_init = true; + } + + if (priv_info.bootstrap_netinitdone) { + // Initialized, printing currently saved information + SHM_LOG_INFO(" priv_info already inited: " << " bootstrap_netifname: " << (priv_info.bootstrap_netifname ? priv_info.bootstrap_netifname : "nullptr")); + if (priv_info.bootstrap_netifaddr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &priv_info.bootstrap_netifaddr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv4): " << ip_str << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr4.sin_port)); + } else if (priv_info.bootstrap_netifaddr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &priv_info.bootstrap_netifaddr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv6): " << ip_str << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr6.sin6_port)); + } else { + SHM_LOG_ERROR(" bootstrap_netifaddr: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + pthread_mutex_unlock(&priv_info.bootstrap_netlock); + return SHMEM_SUCCESS; + } + + // Print the root node address to be matched (uid_args->addr) + if (uid_args->addr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &uid_args->addr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert uid_args addr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" Root address (IPv4): " << ip_str << ":" << ntohs(uid_args->addr.addr.addr4.sin_port)); + } else if (uid_args->addr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &uid_args->addr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert uid_args addr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" Root address (IPv6): " << ip_str << ":" << ntohs(uid_args->addr.addr.addr6.sin6_port)); + } else { + SHM_LOG_ERROR(" Root address: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + int find_result = find_interface_match_subnet(priv_info.bootstrap_netifname, + &priv_info.bootstrap_netifaddr, + &uid_args->addr); + if (find_result != 0) { + SHM_LOG_ERROR(" Failed to find matching interface."); + pthread_mutex_unlock(&priv_info.bootstrap_netlock); + return SHMEM_BOOTSTRAP_ERROR; + } + + // Print the information of priv_info. + if (priv_info.bootstrap_netifaddr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &priv_info.bootstrap_netifaddr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv4): " << ip_str + << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr4.sin_port)); + } else if (priv_info.bootstrap_netifaddr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &priv_info.bootstrap_netifaddr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert bootstrap_netifaddr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_INFO(" bootstrap_netifaddr (IPv6): " << ip_str + << ":" << ntohs(priv_info.bootstrap_netifaddr.addr.addr6.sin6_port)); + } else { + SHM_LOG_ERROR(" Root bootstrap_netifaddr: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + priv_info.bootstrap_netinitdone = 1; + pthread_mutex_unlock(&priv_info.bootstrap_netlock); + SHM_LOG_INFO(" Net init success, priv_info.bootstrap_netinitdone = 1"); + return SHMEM_SUCCESS; +} + +static int set_files_limit() { + struct rlimit files_limit, old_limit; + + SHMEM_CHECK_RET(getrlimit(RLIMIT_NOFILE, &old_limit), "getrlimit failed", SHMEM_BOOTSTRAP_ERROR); + SHM_LOG_DEBUG("Original file descriptor limit - soft limit: " << old_limit.rlim_cur << ", hard limit: " << old_limit.rlim_max); + + files_limit = old_limit; + files_limit.rlim_cur = files_limit.rlim_max; + SHMEM_CHECK_RET(setrlimit(RLIMIT_NOFILE, &files_limit), "setrlimit failed", SHMEM_BOOTSTRAP_ERROR); + + struct rlimit new_limit; + getrlimit(RLIMIT_NOFILE, &new_limit); + SHM_LOG_DEBUG("Updated file descriptor limit - soft limit: " << new_limit.rlim_cur << ", hard limit: " << new_limit.rlim_max); + + return SHMEM_SUCCESS; +} + +static void* bootstrap_root(void* rargs) { + struct bootstrap_root_args* args = (struct bootstrap_root_args*)rargs; + if (args == NULL || args->listen_sock == NULL) { + SHM_LOG_ERROR("bootstrap_root: invalid args"); + return NULL; + } + + socket_t* listen_sock = args->listen_sock; + uint64_t magic = args->magic; + int root_version = args->version; + int nranks = 0; + int c = 0; // Number of received nodes. + bootstrap_ext_info info; + sockaddr_t* zero_addr = nullptr; + SHMEM_BOOTSTRAP_CALLOC(&zero_addr, 1); + sockaddr_t* rank_addrs = NULL; // Store the common listening addresses of all nodes. + sockaddr_t* rank_addrs_root = NULL; // Store the dedicated root addresses for all nodes. + + if (zero_addr == NULL) { + SHM_LOG_ERROR("bootstrap_root: calloc zero_addr failed"); + SHMEM_BOOTSTRAP_PTR_FREE(args); + return NULL; + } + + // Adjusting file descriptor limits (the root node needs to handle multiple connections) + if (set_files_limit() != 0) { + SHM_LOG_ERROR("bootstrap_root: set_files_limit failed"); + SHMEM_BOOTSTRAP_PTR_FREE(zero_addr); + SHMEM_BOOTSTRAP_PTR_FREE(args); + return NULL; + } + + // Continuously receive connections and information from all slave nodes + while (1) { + socket_t client_sock; + // Initialize client socket (for receiving connections from a single slave node) + if (socket_init(&client_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, NULL) != 0) { + SHM_LOG_ERROR("bootstrap_root: socket_init failed"); + break; + } + + // Accept connections from the node (blocking wait) + if (socket_accept(&client_sock, listen_sock) != 0) { + SHM_LOG_ERROR("bootstrap_root: socket_accept failed"); + socket_close(&client_sock); + break; + } + + // Version verification + int peer_version; + if (socket_recv(&client_sock, &peer_version, sizeof(peer_version)) != 0) { + SHM_LOG_ERROR("bootstrap_root: recv peer_version failed"); + socket_close(&client_sock); + break; + } + if (socket_send(&client_sock, &root_version, sizeof(root_version)) != 0) { + SHM_LOG_ERROR("bootstrap_root: send root_version failed"); + socket_close(&client_sock); + break; + } + if (peer_version != root_version) { + SHM_LOG_ERROR("bootstrap_root: version mismatch"); + socket_close(&client_sock); + break; + } + + // Receive address information from the node + if (socket_recv(&client_sock, &info, sizeof(info)) != 0) { + SHM_LOG_ERROR("bootstrap_root: recv info failed"); + socket_close(&client_sock); + break; + } + socket_close(&client_sock); + + // Initialize the address array upon first reception + if (c == 0) { + nranks = info.nranks; + if (nranks <= 0) { + SHM_LOG_ERROR("bootstrap_root: invalid nranks"); + break; + } + SHMEM_BOOTSTRAP_CALLOC(&rank_addrs, nranks); + SHMEM_BOOTSTRAP_CALLOC(&rank_addrs_root, nranks); + if (rank_addrs == NULL || rank_addrs_root == NULL) { + SHM_LOG_ERROR("bootstrap_root: calloc addr arrays failed"); + break; + } + } + + if (info.nranks != nranks || info.rank < 0 || info.rank >= nranks) { + SHM_LOG_ERROR("bootstrap_root: invalid info from rank " << info.rank); + break; + } + // Check if the rank is duplicated + if (memcmp(zero_addr, &rank_addrs_root[info.rank], sizeof(sockaddr_t)) != 0) { + SHM_LOG_ERROR("bootstrap_root: duplicate rank " << info.rank); + break; + } + + memcpy(&rank_addrs_root[info.rank], &info.ext_address_listen_root, sizeof(sockaddr_t)); + memcpy(&rank_addrs[info.rank], &info.ext_addr_listen, sizeof(sockaddr_t)); + c++; + + if (c >= nranks) { + SHM_LOG_INFO("bootstrap_root: Address receiving completed"); + break; + } + } + + if (c == nranks && rank_addrs != NULL && rank_addrs_root != NULL) { + SHM_LOG_INFO("bootstrap_root: Start distributing addresses."); + for (int r = 0; r < nranks; r++) { + int next_rank = (r + 1) % nranks; + socket_t send_sock; + + if (socket_init(&send_sock, SOCKET_TYPE_BOOTSTRAP, magic, &rank_addrs_root[r]) != 0) { + SHM_LOG_ERROR("bootstrap_root: init send_sock for rank " << r << " failed"); + break; + } + + if (socket_connect(&send_sock) != 0) { + SHM_LOG_ERROR("bootstrap_root: connect to rank " << r << " failed"); + socket_close(&send_sock); + break; + } + + if (socket_send(&send_sock, &rank_addrs[next_rank], sizeof(sockaddr_t)) != 0) { + SHM_LOG_ERROR("bootstrap_root: send next_addr to rank " << r << " failed"); + socket_close(&send_sock); + break; + } + + socket_close(&send_sock); + } + } + + SHMEM_BOOTSTRAP_PTR_FREE(zero_addr); + SHMEM_BOOTSTRAP_PTR_FREE(rank_addrs); + SHMEM_BOOTSTRAP_PTR_FREE(rank_addrs_root); + if (listen_sock != NULL) { + socket_close(listen_sock); + SHMEM_BOOTSTRAP_PTR_FREE(listen_sock); + } + SHMEM_BOOTSTRAP_PTR_FREE(args); + return NULL; +} + +static int bootstrap_create_root(shmemx_bootstrap_uid_state_t* uid_args) { + if (uid_args == NULL) { + SHM_LOG_ERROR("bootstrap_create_root: invalid uid_args"); + return SHMEM_BOOTSTRAP_ERROR; + } + + // 1. Create a dedicated listening socket for the root node. + socket_t* listen_sock_root = nullptr; + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&listen_sock_root, 1), "bootstrap_create_root: malloc listen_sock_root failed"); + + // 2. Initialize the listening socket (using the global network interface address) + SHMEM_CHECK_RET(socket_init(listen_sock_root, SOCKET_TYPE_BOOTSTRAP, uid_args->magic, &uid_args->addr), "bootstrap_create_root: socket_init failed"); + + SHMEM_CHECK_RET_CLOSE_SOCK(socket_listen(listen_sock_root), "Listen_sock_root failed while executing listen. fd=" << listen_sock_root->fd, *listen_sock_root); + + // 3. Write the root node's listening address into uid_args (for slave nodes to connect to). + memcpy(&uid_args->addr, &listen_sock_root->addr, sizeof(sockaddr_t)); + + // 4. Prepare thread parameters + struct bootstrap_root_args* args = nullptr; + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&args, 1), "bootstrap_create_root: malloc args failed"); + + args->listen_sock = listen_sock_root; + args->magic = uid_args->magic; + args->version = uid_args->version; + + // 5. Create detached thread + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); + int ret = pthread_create(&priv_info.bootstrap_root, &attr, bootstrap_root, args); + if (ret != 0) { + SHM_LOG_ERROR("bootstrap_create_root: pthread_create failed"); + SHMEM_BOOTSTRAP_PTR_FREE(args); + socket_close(listen_sock_root); + SHMEM_BOOTSTRAP_PTR_FREE(listen_sock_root); + return SHMEM_BOOTSTRAP_ERROR; + } + pthread_attr_destroy(&attr); + return SHMEM_SUCCESS; +} + + + +int shmemi_bootstrap_get_unique_id(void* uid) { + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)uid; + + if (env_ip_port == nullptr) { + const char* envip = std::getenv("SHMEM_UID_SESSION_ID"); + if (envip != nullptr) { + env_ip_port = envip; + SHM_LOG_DEBUG("SHMEM_UID_SESSION_ID is: " << env_ip_port); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SESSION_ID is not set."); + } + } + + if (env_ifname == nullptr) { + const char* envinfo = std::getenv("SHMEM_UID_SOCK_IFNAME"); + if (envinfo != nullptr) { + env_ifname = envinfo; + SHM_LOG_DEBUG("SHMEM_UID_SOCK_IFNAME is: " << env_ifname); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SOCK_IFNAME is not set."); + } + } + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args, false), "rank 0: failed to init bootstrap net."); + SHMEM_CHECK_RET(bootstrap_create_root(uid_args), "rank 0: failed to create root thread"); + return SHMEM_SUCCESS; +} + +int shmemi_bootstrap_get_unique_id_static_magic(void* uid, bool is_root) { + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)uid; + + if (env_ip_port == nullptr) { + const char* envip = std::getenv("SHMEM_UID_SESSION_ID"); + if (envip != nullptr) { + env_ip_port = envip; + SHM_LOG_DEBUG("SHMEM_UID_SESSION_ID is: " << env_ip_port); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SESSION_ID is not set."); + } + } + if (env_ip_port == nullptr) { + SHM_LOG_ERROR("Using method get_unique_id_static_magic requires setting SHMEM_UID_SESSION_ID."); + return SHMEM_BOOTSTRAP_ERROR; + } + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args, false), "rank 0: failed to init bootstrap net."); + uid_args->magic = SOCKET_MAGIC + static_magic_count; + static_magic_count++; + if (is_root) { + SHMEM_CHECK_RET(bootstrap_create_root(uid_args), "rank 0: failed to create root thread"); + } + return SHMEM_SUCCESS; +} + +int shmemi_bootstrap_get_unique_id_by_ipport(void* uid, const char *ipport) { + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)uid; + + if (ipport != nullptr) { + env_ip_port = ipport; + SHM_LOG_DEBUG("The ipport param is: " << env_ip_port); + } else { + + SHM_LOG_DEBUG("The ipport param is not set. Try to use SHMEM_UID_SESSION_ID."); + const char* envip = std::getenv("SHMEM_UID_SESSION_ID"); + if (envip != nullptr) { + env_ip_port = envip; + SHM_LOG_DEBUG("SHMEM_UID_SESSION_ID is: " << env_ip_port); + } else { + SHM_LOG_DEBUG("The environment variable SHMEM_UID_SESSION_ID is not set."); + } + } + if (env_ip_port == nullptr) { + SHM_LOG_ERROR("Using method get_unique_id_by_ipport requires setting ipport or SHMEM_UID_SESSION_ID."); + return SHMEM_BOOTSTRAP_ERROR; + } + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args, false), "rank 0: failed to init bootstrap net."); + uid_args->magic = SOCKET_MAGIC + static_magic_count; + static_magic_count++; + if (uid_args->rank == 0) { + SHMEM_CHECK_RET(bootstrap_create_root(uid_args), "rank 0: failed to create root thread"); + } + return SHMEM_SUCCESS; +} + +// Plugin pre-initialization entry function. +int shmemi_bootstrap_plugin_pre_init(shmemi_bootstrap_handle_t* handle) { + if (handle->pre_init_ops == nullptr) { + SHM_LOG_DEBUG(" bootstrap plugin pre init start."); + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1)); + handle->pre_init_ops->get_unique_id = shmemi_bootstrap_get_unique_id; + handle->pre_init_ops->get_unique_id_static_magic = shmemi_bootstrap_get_unique_id_static_magic; + handle->pre_init_ops->cookie = nullptr; + SHM_LOG_DEBUG(" bootstrap plugin pre init end."); + } else { + SHM_LOG_DEBUG(" pre_init_ops had already prepared."); + } + return SHMEM_SUCCESS; +} + + +int shmemi_bootstrap_plugin_init(void* comm, shmemi_bootstrap_handle_t* handle) { + if (comm == nullptr || handle == nullptr) { + SHM_LOG_ERROR(" shmemi_bootstrap_plugin_init: invalid arguments (nullptr)"); + return SHMEM_BOOTSTRAP_ERROR; + } + socket_t sock, listen_sock_root; + uid_bootstrap_state* state = nullptr; + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&state, 1)); + shmemx_bootstrap_uid_state_t* uid_args = (shmemx_bootstrap_uid_state_t*)comm; + sockaddr_t next_addr; + bootstrap_ext_info info = {}; + int rank = uid_args->rank; + int nranks = uid_args->nranks; + + if (handle->use_attr_ipport && handle->ipport != nullptr) { + SHM_LOG_DEBUG("shmemi_bootstrap_get_unique_id_by_ipport start. ipport: " << handle->ipport); + shmemi_bootstrap_get_unique_id_by_ipport(comm, handle->ipport); + } + uint64_t magic = uid_args->magic; + + SHMEM_CHECK_RET(shmemi_bootstrap_net_init(uid_args), " rank: " << rank << ": network interface init failed."); + + if (state == nullptr) { + SHM_LOG_ERROR(" rank: " << rank << ": failed to allocate uid_bootstrap_state"); + return SHMEM_BOOTSTRAP_ERROR; + } + + state->rank = rank; + state->nranks = nranks; + state->magic = magic; + + SHMEM_CHECK_RET(SHMEM_BOOTSTRAP_CALLOC(&state->peer_addrs, nranks)); + + if (state->peer_addrs == nullptr) { + SHM_LOG_ERROR(" rank: " << rank << ": failed to allocate peer_addrs"); + SHMEM_BOOTSTRAP_PTR_FREE(state); + return SHMEM_BOOTSTRAP_ERROR; + } + + handle->bootstrap_state = state; + handle->mype = rank; + handle->npes = nranks; + + SHMEM_CHECK_RET(socket_init(&state->listen_sock, SOCKET_TYPE_BOOTSTRAP, state->magic, &priv_info.bootstrap_netifaddr), "State's listen_sock failed while executing init. fd=" << state->listen_sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_listen(&state->listen_sock), "State's listen_sock failed while executing listen. fd=" << state->listen_sock.fd, state->listen_sock); + SHMEM_CHECK_RET(bootstrap_get_sock_addr(&state->listen_sock, &info.ext_addr_listen), "Get addr failed, the listen_sock in state maybe null. fd=" << state->listen_sock.fd); + + SHMEM_CHECK_RET(socket_init(&listen_sock_root, SOCKET_TYPE_BOOTSTRAP, state->magic, &priv_info.bootstrap_netifaddr), "Listen_sock_root failed while executing init. fd=" << listen_sock_root.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_listen(&listen_sock_root), "listen_sock_root failed while executing listen. fd=" << listen_sock_root.fd, listen_sock_root); + SHMEM_CHECK_RET(bootstrap_get_sock_addr(&listen_sock_root, &info.ext_address_listen_root), "Get addr failed, the listen_sock_root maybe null. fd=" << listen_sock_root.fd); + + + SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, magic, &uid_args->addr), "Sock failed while executing init. fd=" << sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_connect(&sock), "Sock failed while executing connect. fd=" << sock.fd, sock); + int peer_version = uid_args->version; + int root_version; + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &peer_version, sizeof(peer_version)), "Sock failed while executing send peer_version. fd=" << sock.fd, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&sock, &root_version, sizeof(root_version)), "Sock failed while executing recv root_version. fd=" << sock.fd, sock); + SHMEM_CHECK_RET(peer_version != root_version, " rank: " << rank << " . version mismatch with root", SHMEM_SMEM_ERROR); + + info.rank = rank; + info.nranks = nranks; + + if (info.ext_addr_listen.type == ADDR_IPv4) { + struct sockaddr_in* ipv4 = &info.ext_addr_listen.addr.addr4; + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_addr_listen ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv4->sin_port); + SHM_LOG_INFO(" Ext_addr_listen socket: Type: IPv4, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); + + } else if (info.ext_addr_listen.type == ADDR_IPv6) { + struct sockaddr_in6* ipv6 = &info.ext_addr_listen.addr.addr6; + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_addr_listen ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv6->sin6_port); + SHM_LOG_INFO(" Ext_addr_listen socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv6->sin6_family); + } else { + SHM_LOG_ERROR(" Ext_address_listen_root socket: Type: Unknown address type is not within IPv4 or IPv6. (type=" << info.ext_addr_listen.type << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (info.ext_address_listen_root.type == ADDR_IPv4) { + struct sockaddr_in* ipv4 = &info.ext_address_listen_root.addr.addr4; + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_address_listen_root ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv4->sin_port); + SHM_LOG_INFO(" Ext_address_listen_root socket: Type: IPv4, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv4->sin_family); + + } else if (info.ext_address_listen_root.type == ADDR_IPv6) { + struct sockaddr_in6* ipv6 = &info.ext_address_listen_root.addr.addr6; + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ext_address_listen_root ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv6->sin6_port); + SHM_LOG_INFO(" Ext_address_listen_root socket: Type: IPv6, IP: " << ip_str << ", Port: " << port << ", sa_family: " << ipv6->sin6_family); + } else { + SHM_LOG_ERROR(" Ext_address_listen_root socket: Type: Unknown address type is not within IPv4 or IPv6. (type=" << info.ext_address_listen_root.type << ")"); + return SHMEM_BOOTSTRAP_ERROR; + + } + + + SHMEM_CHECK_RET_CLOSE_SOCK(socket_send(&sock, &info, sizeof(info)), "Sock failed while executing send info. fd=" << sock.fd, sock); + SHMEM_CHECK_RET(socket_close(&sock), "Sock failed while executing close. fd=" << sock.fd); + + + SHMEM_CHECK_RET(socket_init(&sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, nullptr), "Sock failed while executing init. fd=" << sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_accept(&sock, &listen_sock_root), "Sock failed while executing accept listen_sock_root. fd=" << sock.fd, sock); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_recv(&sock, &next_addr, sizeof(next_addr)), "Sock failed while executing recv next_addr. fd=" << sock.fd, sock); + SHMEM_CHECK_RET(socket_close(&sock), "Sock failed while executing close. fd=" << sock.fd); + SHMEM_CHECK_RET(socket_close(&listen_sock_root), "Listen_sock_root failed while executing close. fd=" << listen_sock_root.fd); + + + if (next_addr.type == ADDR_IPv4) { + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &next_addr.addr.addr4.sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert next_addr ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(next_addr.addr.addr4.sin_port); + SHM_LOG_INFO(" Received next socket: Type: IPv4, IP: " << ip_str << ", Port: " << port); + } else if (next_addr.type == ADDR_IPv6) { + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &next_addr.addr.addr6.sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert next_addr ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(next_addr.addr.addr6.sin6_port); + SHM_LOG_INFO(" Received next socket: Type: IPv6, IP: " << ip_str << ", Port: " << port); + } else { + SHM_LOG_ERROR(" Received next socket: Type: Unknown address type is not within IPv4 or IPv6."); + return SHMEM_BOOTSTRAP_ERROR; + } + + // Initialize ring send socket + SHMEM_CHECK_RET(socket_init(&state->ring_send_sock, SOCKET_TYPE_BOOTSTRAP, magic, &next_addr), "State's ring_send_sock failed while executing init. fd=" << state->ring_send_sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_connect(&state->ring_send_sock), "State's ring_send_sock failed while executing connect. fd=" << state->ring_send_sock.fd, state->ring_send_sock); + SHMEM_CHECK_RET(socket_init(&state->ring_recv_sock, SOCKET_TYPE_BOOTSTRAP, SOCKET_MAGIC, nullptr), "State's ring_recv_sock failed while executing init. fd=" << state->ring_recv_sock.fd); + SHMEM_CHECK_RET_CLOSE_SOCK(socket_accept(&state->ring_recv_sock, &state->listen_sock),"State's ring_recv_sock failed while executing accept State's listen_sock. fd=" << state->ring_recv_sock.fd, state->ring_recv_sock); + SHMEM_CHECK_RET(bootstrap_get_sock_addr(&state->listen_sock, state->peer_addrs + handle->mype), "Get addr failed, the listen_sock in state maybe null. fd=" << state->listen_sock.fd); + + SHMEM_CHECK_RET(shmemi_bootstrap_uid_allgather(BOOTSTRAP_IN_PLACE, state->peer_addrs, sizeof(sockaddr_t), handle), "Bootstrap_uid_allgather failed"); + + handle->allgather = shmemi_bootstrap_uid_allgather; + handle->barrier = shmemi_bootstrap_uid_barrier_v2; + handle->finalize = shmemi_bootstrap_uid_finalize; + handle->alltoall = nullptr; + handle->global_exit = nullptr; + + SHM_LOG_INFO("rank " << rank << ": bootstrap plugin initialized successfully"); + return SHMEM_SUCCESS; +} \ No newline at end of file diff --git a/src/modules/bootstrap/socket/uid_socket.cpp b/src/modules/bootstrap/socket/uid_socket.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d87bc10bc5647ddeb58551afd6edf498e31652f --- /dev/null +++ b/src/modules/bootstrap/socket/uid_socket.cpp @@ -0,0 +1,591 @@ + +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include +#include "uid_socket.h" + +static int socket_poll_fd(int fd, int events, int timeout_ms) { + struct pollfd pfd = {0}; + pfd.fd = fd; + pfd.events = events; + + int ret = poll(&pfd, 1, timeout_ms); + if (ret == -1) { + SHM_LOG_ERROR("poll failed: " << strerror(errno) << " (fd: " << fd << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } else if (ret == 0) { + SHM_LOG_ERROR("poll timeout (" << timeout_ms << "ms) - fd: " << fd); + return SHMEM_TIMEOUT_ERROR; + } + + // 检查fd错误 + if (pfd.revents & (POLLERR | POLLHUP | POLLNVAL)) { + SHM_LOG_ERROR("fd error (revents: " << pfd.revents << ") - fd: " << fd); + return SHMEM_BOOTSTRAP_ERROR; + } + + return SHMEM_SUCCESS; +} + +static int socket_progress(int op, socket_t* sock, void* ptr, int size, int* offset, bool block = false, bool state_check = true) { + if (sock == nullptr || ptr == nullptr || offset == nullptr || size < 0 || *offset < 0 || *offset > size) { + SHM_LOG_ERROR("Invalid arguments: sock=" << sock << ", ptr=" << ptr + << ", size=" << size << ", offset=" << *offset); + return SHMEM_BOOTSTRAP_ERROR; + } + if (state_check && sock->state != SOCKET_STATE_READY) { + SHM_LOG_ERROR("socket_progress: invalid state " << sock->state << " (expected READY)"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + int poll_events = (op == SOCKET_TYPE_RECV) ? POLLIN : POLLOUT; + SHMEM_CHECK_RET(socket_poll_fd(sock->fd, poll_events, SOCKET_RECV_TIMEOUT_MS), "socket_poll_fd failed."); + + int bytes = 0; + int closed = 0; + char* data = (char*)(ptr); + SHM_LOG_DEBUG("socket_progress: start"); + do { + if (op == SOCKET_TYPE_RECV) { + int flags = block ? 0 : MSG_DONTWAIT; + SHM_LOG_DEBUG("Executing RECV operation - fd: " << sock->fd << ", buffer offset: " << *offset << ", remaining size: " << (size - *offset) << ", flags: " << flags); + bytes = recv(sock->fd, data + *offset, size - *offset, flags); + SHM_LOG_DEBUG("RECV result - bytes received: " << bytes); + } else if (op == SOCKET_TYPE_SEND) { + int flags = block ? MSG_NOSIGNAL : (MSG_DONTWAIT | MSG_NOSIGNAL); + SHM_LOG_DEBUG("Executing SEND operation - fd: " << sock->fd << ", buffer offset: " << *offset << ", remaining size: " << (size - *offset) << ", flags: " << flags); + bytes = send(sock->fd, data + *offset, size - *offset, flags); + SHM_LOG_DEBUG("SEND result - bytes sent: " << bytes); + } else { + SHM_LOG_ERROR("Invalid operation type: " << op); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (op == SOCKET_TYPE_RECV && bytes == 0) { + SHM_LOG_DEBUG("RECV operation got 0 bytes - remote peer closed the connection (fd: " << sock->fd << ")"); + closed = 1; + break; + } + + if (bytes == -1) { + int err = errno; + if (err != EINTR && err != EWOULDBLOCK && err != EAGAIN) { + SHM_LOG_ERROR("Socket operation failed (fd: " << sock->fd << ", op: " << op << ") - error: " << strerror(err) << " (errno: " << err << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } else { + SHM_LOG_DEBUG("Socket operation would block (fd: " << sock->fd << ", op: " << op << ") - errno: " << err << ", setting bytes to 0"); + bytes = 0; + } + } + + *offset += bytes; + SHM_LOG_DEBUG("Updated buffer offset - current offset: " << *offset << ", total size: " << size); + } while (bytes > 0 && *offset < size); + + if (closed) { + SHM_LOG_ERROR("Loop exited - remote connection closed (fd: " << sock->fd << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_progress: success"); + + return SHMEM_SUCCESS; +} + +static int socket_wait(int op, socket_t* sock, void* ptr, int size, int* offset, bool block = false, bool state_check = true) { + while (*offset < size) + if (socket_progress(op, sock, ptr, size, offset, block, state_check) != SHMEM_SUCCESS) { + SHM_LOG_ERROR("socket_wait fail!"); + return SHMEM_BOOTSTRAP_ERROR; + } + return SHMEM_SUCCESS; +} + +int socket_send(socket_t* sock, void* ptr, int size) { + SHM_LOG_DEBUG("socket_send: start"); + int offset = 0; + if (sock == NULL || ptr == NULL || size <= 0 ) { + SHM_LOG_ERROR("send sock == NULL"); + return SHMEM_BOOTSTRAP_ERROR; + } + + return socket_wait(SOCKET_TYPE_SEND, sock, ptr, size, &offset); +} + +int socket_recv(socket_t* sock, void* ptr, int size) { + SHM_LOG_DEBUG("socket_recv: start"); + int offset = 0; + if (sock == NULL) { + SHM_LOG_ERROR("recv sock == NULL"); + return SHMEM_BOOTSTRAP_ERROR; + } + return socket_wait(SOCKET_TYPE_RECV, sock, ptr, size, &offset); +} + + +int socket_close(socket_t* sock) { + if (sock) { + if (sock->fd >= 0) { + shutdown(sock->fd, SHUT_RDWR); + close(sock->fd); + } + sock->fd = -1; + sock->accept_fd = -1; + sock->state = SOCKET_STATE_CLOSED; + } else { + SHM_LOG_DEBUG("socket_close: sock is null"); + } + SHM_LOG_DEBUG("socket_close: success"); + return SHMEM_SUCCESS; +} + +int socket_get_sainfo(socket_t* sock, sockaddr* sa, socklen_t* addr_len) { + if (sock == nullptr || sa == nullptr || addr_len == nullptr) { + SHM_LOG_ERROR("Some of sock, sa and addr_len are null."); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->addr.type == ADDR_IPv4) { + SHM_LOG_DEBUG("socket_get_sainfo memcpy addr4"); + memcpy(sa, &sock->addr.addr.addr4, sizeof(struct sockaddr_in)); + *addr_len = sizeof(struct sockaddr_in); + } else { + SHM_LOG_DEBUG("socket_get_sainfo memcpy addr6"); + memcpy(sa, &sock->addr.addr.addr6, sizeof(struct sockaddr_in6)); + *addr_len = sizeof(struct sockaddr_in6); + } + + return SHMEM_SUCCESS; +} + + +int socket_listen(socket_t* sock) { + if (!sock || sock->fd < 0 || sock->state == SOCKET_STATE_ERROR) { + SHM_LOG_ERROR("socket_listen Precondition failed! " + << "sock is null: " << (sock == nullptr) + << ", invalid fd: " << (sock ? (sock->fd < 0) : true) + << ", state is error: " << (sock ? (sock->state == SOCKET_STATE_ERROR) : false)); + if (sock) sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_INFO("socket_listen Entering. sock fd: " << (sock ? sock->fd : -1) + << ", current state: " << (sock ? sock->state : -1)); + + if (sock->state == SOCKET_STATE_CREATED) { + SHM_LOG_DEBUG("socket_listen State is CREATED, starting bind process"); + struct sockaddr_storage sa_storage; + memset(&sa_storage, 0, sizeof(sa_storage)); + struct sockaddr* sa = reinterpret_cast(&sa_storage); + socklen_t addr_len; + + SHMEM_CHECK_RET(socket_get_sainfo(sock, sa, &addr_len),"socket_listen socket_get_sainfo failed"); + + + std::string target_ip = "unknown"; + uint16_t target_port = 0; + if (sa->sa_family == AF_INET) { + struct sockaddr_in* ipv4 = reinterpret_cast(sa); + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + target_ip = ip_str; + target_port = ntohs(ipv4->sin_port); + } else if (sa->sa_family == AF_INET6) { + struct sockaddr_in6* ipv6 = reinterpret_cast(sa); + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + target_ip = ip_str; + target_port = ntohs(ipv6->sin6_port); + } + SHM_LOG_DEBUG("socket_listen socket_get_sainfo succeeded, addr_len: " << addr_len + << ", target IP: " << target_ip << ", target port: " << target_port); + + int opt = 1; + if (setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + SHM_LOG_ERROR("socket_listen setsockopt(SO_REUSEADDR) failed! " + << "errno: " << errno << ", reason: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_listen setsockopt(SO_REUSEADDR) succeeded"); + + if (bind(sock->fd, sa, addr_len) < 0) { + SHM_LOG_ERROR("socket_listen bind failed! " + << "errno: " << errno << ", reason: " << strerror(errno) + << ", fd: " << sock->fd << ", addr_len: " << addr_len + << ", target IP: " << target_ip << ", target port: " << target_port); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("[socket_listen] bind succeeded"); + + if (getsockname(sock->fd, &sock->addr.addr.sa, &addr_len) < 0) { + SHM_LOG_ERROR("socket_listen getsockname failed! " + << "errno: " << errno << ", reason: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->addr.type == ADDR_IPv4) { + struct sockaddr_in* ipv4 = &sock->addr.addr.addr4; + char ip_str[INET_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET, &ipv4->sin_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv4 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv4->sin_port); + SHM_LOG_DEBUG(" Stored IPv4 address: " << ip_str << ":" << port << " sa_family: " << ipv4->sin_family << " (expected AF_INET=" << AF_INET << ")"); + } else if (sock->addr.type == ADDR_IPv6) { + struct sockaddr_in6* ipv6 = &sock->addr.addr.addr6; + char ip_str[INET6_ADDRSTRLEN] = {0}; + SHMEM_CHECK_RET(inet_ntop(AF_INET6, &ipv6->sin6_addr, ip_str, sizeof(ip_str)) == nullptr, "convert ipv6 to string failed. ", SHMEM_BOOTSTRAP_ERROR); + uint16_t port = ntohs(ipv6->sin6_port); + SHM_LOG_DEBUG(" Stored IPv6 address: " << ip_str << ":" << port << " sa_family: " << ipv6->sin6_family << " (expected AF_INET6=" << AF_INET6 << ")"); + } else { + SHM_LOG_ERROR(" Stored address type: unknown (type=" << sock->addr.type << ")"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + sock->state = SOCKET_STATE_BOUND; + SHM_LOG_DEBUG("socket_listen State updated to BOUND"); + } + + if (sock->state == SOCKET_STATE_BOUND) { + SHM_LOG_DEBUG("socket_listen State is BOUND, starting listen"); + if (listen(sock->fd, SOCKET_BACKLOG) < 0) { + SHM_LOG_ERROR("socket_listen] listen failed! " + << "errno: " << errno << ", reason: " << strerror(errno) + << ", fd: " << sock->fd << ", backlog: " << SOCKET_BACKLOG); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + sock->accept_fd = sock->fd; + sock->state = SOCKET_STATE_LISTENING; + SHM_LOG_DEBUG("socket_listen listen succeeded. New state: LISTENING, accept_fd: " << sock->accept_fd); + } else { + SHM_LOG_ERROR("socket_listen Skip listen: current state is " << sock->state << " (expected BOUND)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + SHM_LOG_DEBUG("socket_listen Exiting with success"); + return SHMEM_SUCCESS; +} + +static int socket_try_accept(socket_t* sock) { + if (sock->state != SOCKET_STATE_ACCEPTING) { + SHM_LOG_ERROR("socket_try_accept: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + SHMEM_CHECK_RET(socket_poll_fd(sock->accept_fd, POLLIN, SOCKET_ACCEPT_TIMEOUT_MS), "socket_poll_fd failed."); + struct sockaddr sa; + socklen_t socklen = sizeof(sa); + + sock->fd = accept(sock->accept_fd, &sa, &socklen); + if (sock->fd != -1) { + if (sa.sa_family == AF_INET) { + sock->addr.type = ADDR_IPv4; + memcpy(&sock->addr.addr.addr4, &sa, sizeof(struct sockaddr_in)); + } else { + sock->addr.type = ADDR_IPv6; + memcpy(&sock->addr.addr.addr6, &sa, sizeof(struct sockaddr_in6)); + } + sock->state = SOCKET_STATE_ACCEPTED; + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + SHM_LOG_ERROR("socket_try_accept failed: " << strerror(errno)); + return SHMEM_BOOTSTRAP_ERROR; + } + + return SHMEM_SUCCESS; +} + +static int socket_finalize_accept(socket_t* sock) { + if (sock->state != SOCKET_STATE_ACCEPTED) { + SHM_LOG_ERROR("socket_finalize_accept: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + uint64_t magic; + socket_type_t type; + int received = 0; + const int one = 1; + + if (setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) < 0) { + SHM_LOG_ERROR("setsockopt TCP_NODELAY failed: " << strerror(errno)); + close(sock->fd); + sock->fd = -1; + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + if (socket_progress(SOCKET_TYPE_RECV, sock, &magic, sizeof(magic), &received, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (received == 0) return SHMEM_SUCCESS; + if (socket_wait(SOCKET_TYPE_RECV, sock, &magic, sizeof(magic), &received, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + + if (magic != sock->magic) { + SHM_LOG_DEBUG("socket_finalize_accept: wrong magic " << magic << " != " << sock->magic); + close(sock->fd); + sock->fd = -1; + sock->state = SOCKET_STATE_ACCEPTING; + return SHMEM_SUCCESS; + } + + received = 0; + if (socket_wait(SOCKET_TYPE_RECV, sock, &type, sizeof(type), &received, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (type != sock->type) { + SHM_LOG_ERROR("socket_finalize_accept: wrong type " << type << " != " << sock->type); + close(sock->fd); + sock->fd = -1; + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + sock->state = SOCKET_STATE_READY; + return SHMEM_SUCCESS; +} + +static int socket_start_connect(socket_t* sock) { + if (sock->state != SOCKET_STATE_CONNECTING) { + SHM_LOG_ERROR("socket_start_connect: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + struct sockaddr_storage sa_storage; + memset(&sa_storage, 0, sizeof(sa_storage)); + struct sockaddr* sa = reinterpret_cast(&sa_storage); + socklen_t addr_len; + if (socket_get_sainfo(sock, sa, &addr_len) != 0) { + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + int ret = connect(sock->fd, sa, addr_len); + if (ret == 0) { + sock->state = SOCKET_STATE_CONNECTED; + SHM_LOG_DEBUG("socket_start_connect: success!"); + } else if (errno == ECONNREFUSED) { + SHM_LOG_DEBUG("socket_start_connect: refused retry time:" << sock->refused_retries); + if (++sock->refused_retries >= RETRY_REFUSED_TIMES) { + SHM_LOG_ERROR("exceeded refused retries"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + usleep(SLEEP_INT); + } else if (errno == ETIMEDOUT) { + SHM_LOG_DEBUG("socket_start_connect: timeout retry time:" << sock->timeout_retries); + if (++sock->timeout_retries >= RETRY_TIMEDOUT_TIMES) { + SHM_LOG_ERROR("exceeded timeout retries"); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + usleep(SLEEP_INT); + } else { + SHM_LOG_ERROR("connect failed: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_start_connect: end!"); + + return SHMEM_SUCCESS; +} + + +static int socket_finalize_connect(socket_t* sock) { + SHM_LOG_DEBUG("socket_finalize_connect socket_finalize_connect: start!"); + if (sock->state != SOCKET_STATE_CONNECTED) { + SHM_LOG_ERROR("socket_finalize_connect: invalid state " << sock->state); + return SHMEM_BOOTSTRAP_ERROR; + } + + int sent = 0; + if (socket_progress(SOCKET_TYPE_SEND, sock, &sock->magic, sizeof(sock->magic), &sent, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + if (sent == 0) return SHMEM_SUCCESS; + if (socket_wait(SOCKET_TYPE_SEND, sock, &sock->magic, sizeof(sock->magic), &sent, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + + sent = 0; + if (socket_wait(SOCKET_TYPE_SEND, sock, &sock->type, sizeof(sock->type), &sent, false, false) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_finalize_connect socket_finalize_connect: end!"); + + sock->state = SOCKET_STATE_READY; + return SHMEM_SUCCESS; +} + +static int socket_progress_state(socket_t* sock) { + if (sock == nullptr) { + SHM_LOG_ERROR("socket_progress_state: null socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->state == SOCKET_STATE_ACCEPTING) { + SHMEM_CHECK_RET(socket_try_accept(sock), "socket_try_accept failed"); + } + if (sock->state == SOCKET_STATE_ACCEPTED) { + SHMEM_CHECK_RET(socket_finalize_accept(sock), "socket_finalize_accept failed"); + } + if (sock->state == SOCKET_STATE_CONNECTING) { + SHMEM_CHECK_RET(socket_start_connect(sock), "socket_start_connect failed"); + } + + if (sock->state == SOCKET_STATE_CONNECTED) { + SHMEM_CHECK_RET(socket_finalize_connect(sock), "socket_finalize_connect failed"); + } + + return SHMEM_SUCCESS; +} + +int socket_connect(socket_t* sock) { + if (sock == nullptr) { + SHM_LOG_ERROR("socket_connect: NULL socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + if (sock->fd == -1) { + SHM_LOG_ERROR("socket_connect: invalid fd (-1)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (sock->state != SOCKET_STATE_CREATED) { + SHM_LOG_ERROR("socket_connect: invalid state " << sock->state << " (expected CREATED)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + const int one = 1; + // Disabling the Nagle algorithm + if (setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) < 0) { + SHM_LOG_ERROR("setsockopt TCP_NODELAY failed: " << strerror(errno)); + sock->state = SOCKET_STATE_ERROR; + return SHMEM_BOOTSTRAP_ERROR; + } + + sock->state = SOCKET_STATE_CONNECTING; + SHM_LOG_DEBUG("socket_connect: start!"); + do { + if (socket_progress_state(sock) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + } while (sock->state == SOCKET_STATE_CONNECTING || + sock->state == SOCKET_STATE_CONNECTED); + + switch (sock->state) { + case SOCKET_STATE_READY: + return SHMEM_SUCCESS; + case SOCKET_STATE_ERROR: + return SHMEM_BOOTSTRAP_ERROR; + default: + return SHMEM_BOOTSTRAP_ERROR; + } +} + +int socket_accept(socket_t* client_sock, socket_t* listen_sock) { + if (listen_sock == nullptr || client_sock == nullptr) { + SHM_LOG_ERROR("socket_accept: NULL socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (listen_sock->state != SOCKET_STATE_LISTENING) { + SHM_LOG_ERROR("socket_accept: listen socket state " << listen_sock->state << " (expected LISTENING)"); + return SHMEM_BOOTSTRAP_ERROR; + } + + if (client_sock->accept_fd == -1) { + client_sock->addr = listen_sock->addr; + client_sock->magic = listen_sock->magic; + client_sock->type = listen_sock->type; + client_sock->refused_retries = 0; + client_sock->timeout_retries = 0; + client_sock->accept_fd = listen_sock->fd; + client_sock->fd = -1; + client_sock->state = SOCKET_STATE_ACCEPTING; + } + SHM_LOG_DEBUG("socket_accept: start!"); + do { + if (socket_progress_state(client_sock) != SHMEM_SUCCESS) { + return SHMEM_BOOTSTRAP_ERROR; + } + } while (client_sock->state == SOCKET_STATE_ACCEPTING || + client_sock->state == SOCKET_STATE_ACCEPTED); + + switch (client_sock->state) { + case SOCKET_STATE_READY: + return SHMEM_SUCCESS; + case SOCKET_STATE_ERROR: + return SHMEM_BOOTSTRAP_ERROR; + default: + return SHMEM_BOOTSTRAP_ERROR; + } +} + +int socket_init(socket_t* sock, socket_type_t type, uint64_t magic, const sockaddr_t* init_addr) { + if (sock == nullptr) { + SHM_LOG_ERROR("socket_init: NULL socket"); + return SHMEM_BOOTSTRAP_ERROR; + } + SHM_LOG_DEBUG("socket_init: start"); + memset(sock, 0, sizeof(socket_t)); + sock->fd = -1; + sock->accept_fd = -1; + sock->state = SOCKET_STATE_CREATED; + sock->type = type; + sock->magic = magic; + sock->refused_retries = 0; + sock->timeout_retries = 0; + + if (init_addr != nullptr) { + int family; + if (init_addr->type == ADDR_IPv4) { + family = AF_INET; + memcpy(&sock->addr.addr.addr4, &init_addr->addr.addr4, sizeof(struct sockaddr_in)); + } else if (init_addr->type == ADDR_IPv6) { + family = AF_INET6; + memcpy(&sock->addr.addr.addr6, &init_addr->addr.addr6, sizeof(struct sockaddr_in6)); + } else { + SHM_LOG_ERROR("socket_init: unsupported address type " << init_addr->type); + return SHMEM_BOOTSTRAP_ERROR; + } + sock->addr.type = init_addr->type; + + sock->fd = socket(family, SOCK_STREAM, 0); + if (sock->fd == -1) { + SHM_LOG_ERROR("socket_init: create socket failed: " << strerror(errno)); + return SHMEM_BOOTSTRAP_ERROR; + } + } else { + SHM_LOG_DEBUG("socket_init: init_addr is null"); + memset(&sock->addr, 0, sizeof(sock->addr)); + sock->addr.type = ADDR_IPv4; + } + + // set blocking + if (sock->fd >= 0) { + int32_t value = 1; + if ((value = fcntl(sock->fd, F_GETFL)) == -1) { + SHM_LOG_ERROR("sock: " << sock->fd <<" failed to get control value"); + return SHMEM_BOOTSTRAP_ERROR; + } + int new_flags = value & ~O_NONBLOCK; + if (fcntl(sock->fd, F_SETFL, new_flags) == -1) { + SHM_LOG_ERROR("sock: " << sock->fd << "Failed to set control value of link"); + return SHMEM_BOOTSTRAP_ERROR; + } + } + + SHM_LOG_DEBUG("socket_init: success"); + return SHMEM_SUCCESS; +} \ No newline at end of file diff --git a/src/modules/bootstrap/socket/uid_socket.h b/src/modules/bootstrap/socket/uid_socket.h new file mode 100644 index 0000000000000000000000000000000000000000..0402a888658e45bd4fa9f9b3e22d65d0874fc46a --- /dev/null +++ b/src/modules/bootstrap/socket/uid_socket.h @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef SHMEM_SOCKET_H +#define SHMEM_SOCKET_H + +#include "host/shmem_host_def.h" +#include "common/shmemi_logger.h" +#include "common/shmemi_host_types.h" +#include "internal/host/shmemi_host_def.h" +#include "bootstrap/shmemi_bootstrap.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MAX_IF_NAME_SIZE 16 +#define SOCKET_TYPE_SEND 0 +#define SOCKET_TYPE_RECV 1 + +#define RETRY_REFUSED_TIMES 1e5 // 100s超时 +#define RETRY_TIMEDOUT_TIMES 50 +#define SLEEP_INT 1000 // 重试间隔(微秒) + +#define SOCKET_ACCEPT_TIMEOUT_MS 50000 // accept超时50秒 +#define SOCKET_RECV_TIMEOUT_MS 30000 // recv超时30秒 + +#define SOCKET_BACKLOG 16384 + +typedef enum { + // 初始状态:套接字刚创建,未执行任何操作 + SOCKET_STATE_CREATED, + + // 服务器端状态 + SOCKET_STATE_BOUND, // 已绑定地址(bind 成功) + SOCKET_STATE_LISTENING, // 正在监听连接(listen 成功) + SOCKET_STATE_ACCEPTING, // 正在等待接受连接(准备调用 accept) + SOCKET_STATE_ACCEPTED, // 已接受连接(accept 成功,未完成验证) + + // 客户端状态 + SOCKET_STATE_CONNECTING, // 正在发起连接(connect 调用中) + SOCKET_STATE_CONNECTED, // 已建立连接(未完成验证) + + // 公共状态 + SOCKET_STATE_READY, // 连接已验证,可进行数据传输(最终就绪状态) + SOCKET_STATE_ERROR, // 发生错误 + SOCKET_STATE_CLOSED // 已关闭 +} socket_state_t; + +typedef enum { + SOCKET_TYPE_BOOTSTRAP, // 用于初始化信息交换 + SOCKET_TYPE_DATA // 用于实际数据传输 +} socket_type_t; + +typedef struct { + int fd; // 套接字fd + int accept_fd; // 监听用fd,初始化-1 + sockaddr_t addr; // 存储地址信息(socket_init阶段初始化) + socket_state_t state; + uint64_t magic; + socket_type_t type; + int refused_retries; // 连接被拒绝重试计数 + int timeout_retries; // 超时重试计数 +} socket_t; + +struct bootstrap_root_args { + socket_t* listen_sock; + uint64_t magic; + int version; +}; + +// 其他内部结构体 +typedef struct { + int rank; + int nranks; + sockaddr_t ext_addr_listen; + sockaddr_t ext_address_listen_root; +} bootstrap_ext_info; + +struct bootstrap_netstate { + char bootstrap_netifname[MAX_IF_NAME_SIZE + 1]; /* Socket Interface Name */ + sockaddr_t bootstrap_netifaddr; /* Socket Interface Address */ + int bootstrap_netinitdone = 0; /* Socket Interface Init Status */ + pthread_mutex_t bootstrap_netlock = PTHREAD_MUTEX_INITIALIZER; /* Socket Interface Lock */ + pthread_t bootstrap_root; /* Socket Root Thread for phoning root to non-root peers */ +}; + +typedef struct unexpected_conn { + int peer; // 发送方rank + int tag; // 消息tag + socket_t sock; // 对应的socket连接 + struct unexpected_conn* next; // 链表下一个节点 +} unexpected_conn_t; + +typedef struct { + int rank; + int nranks; + uint64_t magic; + socket_t listen_sock; + socket_t ring_send_sock; + socket_t ring_recv_sock; + sockaddr_t* peer_addrs; + unexpected_conn_t* unexpected_conns; // 意外连接队列 +} uid_bootstrap_state; + +int socket_init(socket_t* sock, socket_type_t type, uint64_t magic, const sockaddr_t* init_addr); +int socket_listen(socket_t* sock); +int socket_connect(socket_t* sock); +int socket_accept(socket_t* client_sock, socket_t* listen_sock); +int socket_send(socket_t* sock, void* ptr, int size); +int socket_recv(socket_t* sock, void* ptr, int size); +int socket_close(socket_t* sock); +int socket_get_sainfo(socket_t* sock, sockaddr* sa, socklen_t* addr_len); + +#ifdef __cplusplus +} +#endif +#endif // SHMEM_SOCKET_H \ No newline at end of file diff --git a/src/modules/bootstrap/socket/uid_utils.h b/src/modules/bootstrap/socket/uid_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..65ad8ca7cc0f9c83e0326faddaa46c9862e7f26b --- /dev/null +++ b/src/modules/bootstrap/socket/uid_utils.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef SHMEM_UID_UTILS_H +#define SHMEM_UID_UTILS_H + + +#include "common/shmemi_logger.h" +#include // malloc, free +#include // memset + +template +inline int bootstrap_calloc(T** ptr, size_t nelem, const char* file, int line) { + if (ptr == nullptr || nelem == 0) { // 校验输入:指针为空或元素数为 0 均为无效 + SHM_LOG_ERROR("Invalid arguments: ptr=" << ptr << ", nelem=" << nelem + << " (" << file << ":" << line << ")"); + return SHMEM_BOOTSTRAP_ERROR; + } + size_t total_size = nelem * sizeof(T); // 计算总内存大小 + void* p = malloc(total_size); + if (p == nullptr) { + SHM_LOG_ERROR("Allocation failed: " << total_size << " bytes (nelem=" << nelem + << ") at " << file << ":" << line); + return SHMEM_BOOTSTRAP_ERROR; + } + + memset(p, 0, total_size); // 内存清零 + *ptr = static_cast(p); // 类型转换,赋值给输出指针 + + // 调试日志:输出分配信息 + SHM_LOG_DEBUG("Allocated " << total_size << " bytes (" << nelem + << " elements of " << sizeof(T) << " bytes) at " + << static_cast(p) << " (" << file << ":" << line << ")"); + return SHMEM_SUCCESS; +} +#define SHMEM_BOOTSTRAP_CALLOC(ptr, nelem) \ + bootstrap_calloc((ptr), (nelem), __FILE__, __LINE__) + + +#define SHMEM_BOOTSTRAP_PTR_FREE(ptr) \ + do { \ + if ((ptr) != NULL) { \ + free(ptr); \ + } \ + } while (0) + +#define SHMEM_CHECK_RET_CLOSE_SOCK(x, LOG_STR, SOCK) \ + do { \ + int32_t check_ret = (x); \ + if (check_ret != 0) { \ + SHM_LOG_ERROR(" " << LOG_STR << " close sock " << #SOCK << " and return shmem error: " << check_ret); \ + if ((&(SOCK)) != nullptr) { \ + socket_close(&(SOCK)); \ + } \ + return check_ret; \ + } \ + } while (0) + +#endif //SHMEM_UID_UTILS_H \ No newline at end of file diff --git a/src/modules/transport/rdma/device_qp_manager.cpp b/src/modules/transport/rdma/device_qp_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc37f32a4f50c1b331b58381065774fb17a4d155 --- /dev/null +++ b/src/modules/transport/rdma/device_qp_manager.cpp @@ -0,0 +1,681 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "dl_hccp_api.h" +#include "device_qp_manager.h" +using namespace shm; + +DeviceQpManager::DeviceQpManager(uint32_t deviceId, uint32_t rankId, uint32_t rankCount, sockaddr_in devNet, + hybm_role_type role) noexcept + : deviceId_{deviceId}, + rankId_{rankId}, + rankCount_{rankCount}, + deviceAddress_{devNet}, + rankRole_{role} +{ +} + +void *DeviceQpManager::CreateLocalSocket() noexcept +{ + void *socketHandle = nullptr; + HccpRdev rdev; + rdev.phyId = deviceId_; + rdev.family = AF_INET; + rdev.localIp.addr = deviceAddress_.sin_addr; + auto ret = DlHccpApi::RaSocketInit(HccpNetworkMode::NETWORK_OFFLINE, rdev, socketHandle); + if (ret != 0) { + SHM_LOG_ERROR("initialize socket handle failed: " << ret); + return nullptr; + } + + return socketHandle; +} + +int DeviceQpManager::CreateServerSocket() noexcept +{ + if (serverSocketHandle_ != nullptr) { + return SHMEM_SUCCESS; + } + + auto socketHandle = CreateLocalSocket(); + if (socketHandle == nullptr) { + SHM_LOG_ERROR("create local socket handle failed."); + return SHMEM_INNER_ERROR; + } + + HccpSocketListenInfo listenInfo{}; + listenInfo.handle = socketHandle; + listenInfo.port = deviceAddress_.sin_port; + bool successListen = false; + while (listenInfo.port <= std::numeric_limits::max()) { + auto ret = DlHccpApi::RaSocketListenStart(&listenInfo, 1); + if (ret == 0) { + deviceAddress_.sin_port = listenInfo.port; + successListen = true; + break; + } + listenInfo.port++; + } + if (!successListen) { + SHM_LOG_ERROR("start to listen server socket failed."); + DlHccpApi::RaSocketDeinit(socketHandle); + return SHMEM_INNER_ERROR; + } + + SHM_LOG_INFO("start to listen on port: " << listenInfo.port << " success."); + serverSocketHandle_ = socketHandle; + return SHMEM_SUCCESS; +} + +void DeviceQpManager::DestroyServerSocket() noexcept +{ + if (serverSocketHandle_ == nullptr) { + return; + } + + HccpSocketListenInfo listenInfo{}; + listenInfo.handle = serverSocketHandle_; + listenInfo.port = deviceAddress_.sin_port; + auto ret = DlHccpApi::RaSocketListenStop(&listenInfo, 1); + if (ret != 0) { + SHM_LOG_INFO("stop to listen on port: " << listenInfo.port << " return: " << ret); + } + + ret = DlHccpApi::RaSocketDeinit(serverSocketHandle_); + if (ret != 0) { + SHM_LOG_INFO("deinit server socket return: " << ret); + } + serverSocketHandle_ = nullptr; +} + +static constexpr uint32_t SEND_CQ_DEPTH = 8192; +static constexpr uint32_t RECV_CQ_DEPTH = 128; +static constexpr uint32_t MAX_SEND_WR = 8192; +static constexpr uint32_t MAX_RECV_WR = 128; +static constexpr uint32_t MAX_SEND_SGE = 1; +static constexpr uint32_t MAX_RECV_SGE = 1; +static constexpr uint32_t QP_MODE = 2; +static constexpr uint32_t CALLER_POLL_CQ_CSTM = 1; + +DeviceQpManager::~DeviceQpManager() noexcept +{ + CloseServices(); +} + +int DeviceQpManager::SetRemoteRankInfo(const std::unordered_map &ranks) noexcept +{ + if (started_) { + SHM_LOG_ERROR("fixed ranks not support update ranks info after startup"); + return SHMEM_INNER_ERROR; + } + + currentRanksInfo_ = ranks; + return SHMEM_SUCCESS; +} + +int DeviceQpManager::SetLocalMemories(const MemoryRegionMap &mrs) noexcept +{ + if (started_) { + SHM_LOG_INFO("fixed ranks not support update register MRs after startup"); + return SHMEM_SUCCESS; + } + + currentLocalMrs_ = mrs; + return SHMEM_SUCCESS; +} + +int DeviceQpManager::Startup(void *rdma) noexcept +{ + if (rdma == nullptr) { + SHM_LOG_ERROR("input rdma is null"); + return SHMEM_INVALID_PARAM; + } + + if (started_) { + SHM_LOG_ERROR("already started."); + return SHMEM_INNER_ERROR; + } + + rdmaHandle_ = rdma; + if (!ReserveQpInfoSpace()) { + SHM_LOG_ERROR("reserve qp info space failed."); + return SHMEM_INNER_ERROR; + } + + if (currentRanksInfo_.size() != rankCount_) { + SHM_LOG_ERROR("set rank count = " << currentRanksInfo_.size() << ", but rank_size = " << rankCount_); + return SHMEM_INVALID_PARAM; + } + + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + if (it->first >= rankCount_) { + SHM_LOG_ERROR("input options of nics contains rankId:" << it->first << ", rank count: " << rankCount_); + return SHMEM_INVALID_PARAM; + } + } + + auto ret = StartServerSide(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("start server side failed: " << ret); + return ret; + } + + ret = StartClientSide(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("start client side failed: " << ret); + return ret; + } + + started_ = true; + return SHMEM_SUCCESS; +} + +void DeviceQpManager::Shutdown() noexcept +{ + CloseServices(); +} + +int DeviceQpManager::WaitingConnectionReady() noexcept +{ + if (serverConnectResult == SHMEM_SUCCESS && clientConnectResult == SHMEM_SUCCESS) { + SHM_LOG_INFO("client & server connections ready."); + return SHMEM_SUCCESS; + } + + SHM_LOG_ERROR("background connection thread not started."); + return SHMEM_INNER_ERROR; +} + +void *DeviceQpManager::GetQpInfoAddress() const noexcept +{ + return qpInfo_; +} + +void *DeviceQpManager::GetQpHandleWithRankId(uint32_t rankId) const noexcept +{ + auto connections = rankId < rankId_ ? &clientConnections_ : &serverConnections_; + auto pos = connections->find(rankId); + if (pos == connections->end()) { + return nullptr; + } + + return pos->second.qpHandles[CONN_QP_STARS]; +} + +bool DeviceQpManager::ReserveQpInfoSpace() noexcept +{ + if (qpInfo_ != nullptr) { + return true; + } + + void *ptr = nullptr; + auto oneQpSize = 2U * (sizeof(AiQpRMAWQ) + sizeof(AiQpRMACQ)) + sizeof(RdmaMemRegionInfo); + qpInfoSize_ = sizeof(AiQpRMAQueueInfo) + oneQpSize * rankCount_; + auto ret = aclrtMalloc(&ptr, qpInfoSize_, ACL_MEM_MALLOC_HUGE_FIRST); + if (ret != 0) { + SHM_LOG_ERROR("allocate device size: " << qpInfoSize_ << ", failed: " << ret); + return false; + } + + qpInfo_ = (AiQpRMAQueueInfo *)ptr; + return true; +} + +int DeviceQpManager::StartServerSide() noexcept +{ + if (rankId_ + 1U == rankCount_) { + serverConnectResult = 0; + return SHMEM_SUCCESS; + } + + auto ret = CreateServerSocket(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("create server socket failed: " << ret); + return ret; + } + + ret = GenerateWhiteList(); + if (ret != 0) { + SHM_LOG_ERROR("generate white list failed: " << ret); + return SHMEM_INNER_ERROR; + } + + aclrtSetDevice(deviceId_); + ret = WaitConnectionsReady(serverConnections_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait connection ready failed: " << ret); + serverConnectResult = ret; + return SHMEM_INNER_ERROR; + } + ret = CreateQpWaitingReady(serverConnections_, CONN_QP_AI_CORE); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait connection AI qp ready failed: " << ret); + serverConnectResult = ret; + } + + ret = CreateQpWaitingReady(serverConnections_, CONN_QP_STARS); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait connection STARS qp ready failed: " << ret); + serverConnectResult = ret; + } + + serverConnectResult = SHMEM_SUCCESS; + + return SHMEM_SUCCESS; +} + +int DeviceQpManager::StartClientSide() noexcept +{ + if (rankId_ == 0U) { + SHM_LOG_INFO("rankId: " << rankId_ << " need not connect to others."); + clientConnectResult = SHMEM_SUCCESS; + return SHMEM_SUCCESS; + } + + std::vector connectInfos; + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + if (it->first >= rankId_) { + continue; // client connect to small ranks. + } + + auto socketHandle = CreateLocalSocket(); + if (socketHandle == nullptr) { + SHM_LOG_ERROR("create local socket handle failed"); + CloseClientConnections(); + return SHMEM_INNER_ERROR; + } + + clientConnections_.emplace(it->first, ConnectionChannel{it->second.network.sin_addr, socketHandle}); + HccpSocketConnectInfo connectInfo; + connectInfo.handle = socketHandle; + connectInfo.remoteIp.addr = it->second.network.sin_addr; + connectInfo.port = it->second.network.sin_port; + bzero(connectInfo.tag, sizeof(connectInfo.tag)); + SHM_LOG_DEBUG("add connecting server " << connectInfo); + connectInfos.emplace_back(connectInfo); + } + + auto ret = DlHccpApi::RaSocketBatchConnect(connectInfos.data(), connectInfos.size()); + if (ret != 0) { + SHM_LOG_ERROR("connect to all servers failed: " << ret << ", servers count = " << connectInfos.size()); + CloseClientConnections(); + return SHMEM_INNER_ERROR; + } + + aclrtSetDevice(deviceId_); + ret = WaitConnectionsReady(clientConnections_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("client wait connections failed: " << ret); + CloseClientConnections(); + return ret; + } + + ret = CreateQpWaitingReady(clientConnections_, CONN_QP_AI_CORE); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("client create qp for AI CORE failed: " << ret); + CloseClientConnections(); + return ret; + } + + ret = CreateQpWaitingReady(clientConnections_, CONN_QP_STARS); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("client create qp for STARS failed: " << ret); + CloseClientConnections(); + return ret; + } + clientConnectResult = SHMEM_SUCCESS; + return SHMEM_SUCCESS; +} + +int DeviceQpManager::GenerateWhiteList() noexcept +{ + std::vector whitelist; + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + if (it->first <= rankId_) { + continue; // small id as server, large id as client + } + HccpSocketWhiteListInfo info{}; + info.remoteIp.addr = it->second.network.sin_addr; + info.connLimit = rankCount_; + bzero(info.tag, sizeof(info.tag)); + whitelist.emplace_back(info); + serverConnections_.emplace(it->first, ConnectionChannel{info.remoteIp.addr, serverSocketHandle_}); + } + + if (whitelist.empty()) { + return SHMEM_SUCCESS; + } + + auto ret = DlHccpApi::RaSocketWhiteListAdd(serverSocketHandle_, whitelist.data(), whitelist.size()); + if (ret != 0) { + SHM_LOG_ERROR("socket handle add white list failed: " << ret); + return SHMEM_INNER_ERROR; + } + + return SHMEM_SUCCESS; +} + +int DeviceQpManager::WaitConnectionsReady(std::unordered_map &connections) noexcept +{ + uint32_t totalSuccessCount = 0; + auto start = std::chrono::steady_clock::now(); + auto timeout = start + std::chrono::minutes(2); + while (totalSuccessCount < connections.size()) { + if (std::chrono::steady_clock::now() >= timeout) { + SHM_LOG_ERROR("waiting connection ready timeout."); + return SHMEM_INNER_ERROR; + } + + uint32_t successCount = 0; + std::vector socketInfos; + std::unordered_map addr2index; + for (auto it = connections.begin(); it != connections.end(); ++it) { + if (it->second.socketFd != nullptr) { + continue; + } + + HccpSocketInfo info{}; + info.handle = it->second.socketHandle; + info.fd = nullptr; + info.remoteIp.addr = it->second.remoteIp; + info.status = 0; + bzero(info.tag, sizeof(info.tag)); + socketInfos.push_back(info); + addr2index.emplace(it->second.remoteIp.s_addr, it->first); + } + + auto role = (&connections == &clientConnections_) ? 1 : 0; + auto ret = DlHccpApi::RaGetSockets(role, socketInfos.data(), socketInfos.size(), successCount); + if (ret != 0) { + SHM_LOG_ERROR("role(" << role << ") side get sockets failed: " << ret); + return SHMEM_INNER_ERROR; + } + + for (auto i = 0U; i < successCount; i++) { + auto socketInfoPos = addr2index.find(socketInfos[i].remoteIp.addr.s_addr); + if (socketInfoPos == addr2index.end()) { + SHM_LOG_ERROR("socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) << ") should not exist."); + return SHMEM_INNER_ERROR; + } + + auto rankId = socketInfoPos->second; + auto pos = connections.find(rankId); + if (pos == connections.end()) { + SHM_LOG_ERROR("socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) << ") should not exist."); + return SHMEM_INNER_ERROR; + } + + if (pos->second.socketFd != nullptr) { + SHM_LOG_ERROR("get socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) << ") already get socket fd."); + return SHMEM_INNER_ERROR; + } + + if (pos->second.socketHandle != socketInfos[i].handle) { + SHM_LOG_ERROR("get socket ip(" << inet_ntoa(socketInfos[i].remoteIp.addr) + << ") socket handle not match."); + return SHMEM_INNER_ERROR; + } + + pos->second.socketFd = socketInfos[i].fd; + SHM_LOG_INFO("connect to (" << rankId << ") ready."); + } + + totalSuccessCount += successCount; + } + + return SHMEM_SUCCESS; +} + +int DeviceQpManager::CreateQpWaitingReady(std::unordered_map &connections, + ConnQpType qpType) noexcept +{ + for (auto it = connections.begin(); it != connections.end(); ++it) { + auto ret = CreateOneQp(qpType, it->second); + if (ret != 0) { + SHM_LOG_ERROR("create QP type:" << qpType << " to " << it->first << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + + for (auto pos = currentLocalMrs_.begin(); pos != currentLocalMrs_.end(); ++pos) { + HccpMrInfo info{}; + info.addr = (void *)(ptrdiff_t)pos->second.address; + info.size = pos->second.size; + info.access = 7; + ret = DlHccpApi::RaMrReg(it->second.qpHandles[qpType], info); + if (ret != 0) { + SHM_LOG_ERROR("register MR failed: " << ret); + return SHMEM_INNER_ERROR; + } + } + + ret = DlHccpApi::RaQpConnectAsync(it->second.qpHandles[qpType], it->second.socketFd); + if (ret != 0) { + SHM_LOG_ERROR("connect AI QP to " << it->first << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + } + + auto start = std::chrono::steady_clock::now(); + auto timeout = start + std::chrono::minutes(1); + while (std::chrono::steady_clock::now() < timeout) { + int connectingCount = 0; + for (auto it = connections.begin(); it != connections.end(); ++it) { + int status = 0; + auto ret = DlHccpApi::RaGetQpStatus(it->second.qpHandles[qpType], status); + if (ret != 0) { + SHM_LOG_ERROR("get AI QP status to " << it->first << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + if (status != 1) { + connectingCount++; + } + } + if (connectingCount == 0) { + return FillQpInfo(qpType); + } + } + return SHMEM_INNER_ERROR; +} + +int DeviceQpManager::CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) noexcept +{ + int ret; + if (qpType == CONN_QP_AI_CORE) { + HccpQpExtAttrs attr{}; + attr.qpMode = NETWORK_OFFLINE; + attr.version = 1; + attr.cqAttr.sendCqDepth = SEND_CQ_DEPTH; + attr.cqAttr.recvDqDepth = RECV_CQ_DEPTH; + attr.qp_attr.cap.max_send_wr = MAX_SEND_WR; + attr.qp_attr.cap.max_send_sge = MAX_SEND_SGE; + attr.qp_attr.cap.max_recv_wr = MAX_RECV_WR; + attr.qp_attr.cap.max_recv_sge = MAX_RECV_SGE; + attr.qp_attr.qp_type = IBV_QPT_RC; + attr.data_plane_flag.bs.cq_cstm = CALLER_POLL_CQ_CSTM; + ret = DlHccpApi::RaQpAiCreate(rdmaHandle_, attr, channel.aiQpInfo, channel.qpHandles[qpType]); + } else { + ret = DlHccpApi::RaQpCreate(rdmaHandle_, 0, QP_MODE, channel.qpHandles[qpType]); + } + return ret; +} + +int DeviceQpManager::FillQpInfo(ConnQpType qpType) noexcept +{ + if (qpType != CONN_QP_AI_CORE) { + return SHMEM_SUCCESS; + } + + const uint32_t slevel = 4; + std::vector qpInfoBuffer(qpInfoSize_); + auto copyInfo = (AiQpRMAQueueInfo *)(void *)qpInfoBuffer.data(); + copyInfo->count = 1; + copyInfo->sq = (AiQpRMAWQ *)(void *)(copyInfo + 1); + copyInfo->rq = (AiQpRMAWQ *)(void *)(copyInfo->sq + rankCount_); + copyInfo->scq = (AiQpRMACQ *)(void *)(copyInfo->rq + rankCount_); + copyInfo->rcq = (AiQpRMACQ *)(void *)(copyInfo->scq + rankCount_); + copyInfo->mr = (RdmaMemRegionInfo *)(void *)(copyInfo->rcq + rankCount_); + for (auto it = currentRanksInfo_.begin(); it != currentRanksInfo_.end(); ++it) { + copyInfo->mr[it->first].size = it->second.mr.size; + copyInfo->mr[it->first].addr = it->second.mr.address; + copyInfo->mr[it->first].lkey = it->second.mr.lkey; + copyInfo->mr[it->first].rkey = it->second.mr.rkey; + if (it->first == rankId_) { + continue; + } + + std::unordered_map *connections; + if (it->first < rankId_) { + connections = &clientConnections_; + } else { + connections = &serverConnections_; + } + + auto pos = connections->find(it->first); + if (pos == connections->end()) { + SHM_LOG_ERROR("missing for remote: " << it->first); + return SHMEM_INNER_ERROR; + } + + CopyAiWQInfo(copyInfo->sq[it->first], pos->second.aiQpInfo.data_plane_info.sq, DBMode::HW_DB, slevel); + CopyAiWQInfo(copyInfo->rq[it->first], pos->second.aiQpInfo.data_plane_info.rq, DBMode::SW_DB, slevel); + CopyAiCQInfo(copyInfo->scq[it->first], pos->second.aiQpInfo.data_plane_info.scq, DBMode::HW_DB); + CopyAiCQInfo(copyInfo->rcq[it->first], pos->second.aiQpInfo.data_plane_info.rcq, DBMode::SW_DB); + } + + auto pointer = (ptrdiff_t)(void *)(qpInfo_); + pointer += sizeof(AiQpRMAQueueInfo); + copyInfo->sq = (AiQpRMAWQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMAWQ) * rankCount_; + copyInfo->rq = (AiQpRMAWQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMAWQ) * rankCount_; + copyInfo->scq = (AiQpRMACQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMACQ) * rankCount_; + copyInfo->rcq = (AiQpRMACQ *)(void *)(pointer); + + pointer += sizeof(AiQpRMACQ) * rankCount_; + copyInfo->mr = (RdmaMemRegionInfo *)(void *)pointer; + + auto ret = aclrtMemcpy(qpInfo_, qpInfoSize_, copyInfo, qpInfoSize_, ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != 0) { + SHM_LOG_ERROR("copy qp info to device failed: " << ret); + return SHMEM_INNER_ERROR; + } + SHM_LOG_INFO("copy qp info success"); + + return SHMEM_SUCCESS; +} + +void DeviceQpManager::CopyAiWQInfo(struct AiQpRMAWQ &dest, const struct ai_data_plane_wq &src, DBMode dbMode, + uint32_t sl) noexcept +{ + dest.wqn = src.wqn; + dest.bufAddr = src.buf_addr; + dest.wqeSize = src.wqebb_size; + dest.depth = src.depth; + dest.headAddr = src.head_addr; + dest.tailAddr = src.tail_addr; + dest.dbMode = dbMode; + if (dbMode == DBMode::SW_DB) { + dest.dbAddr = src.swdb_addr; + } else if (dbMode == DBMode::HW_DB) { + dest.dbAddr = src.db_reg; + } + dest.sl = sl; + SHM_LOG_INFO("CopyAiWQInfo: wqn = " << dest.wqn << ", bufAddr = " << dest.bufAddr << ", wqeSize = " + << dest.wqeSize << ", depth = " << dest.depth << ", headAddr = " << dest.headAddr + << ", tailAddr = " << dest.tailAddr << ", dbAddr = " << dest.dbAddr + << ", sl = " << dest.sl); +} + +void DeviceQpManager::CopyAiCQInfo(struct AiQpRMACQ &dest, const ai_data_plane_cq &source, DBMode dbMode) noexcept +{ + dest.cqn = source.cqn; + dest.bufAddr = source.buf_addr; + dest.cqeSize = source.cqe_size; + dest.depth = source.depth; + dest.headAddr = source.head_addr; + dest.tailAddr = source.tail_addr; + dest.dbMode = dbMode; + if (dbMode == DBMode::SW_DB) { + dest.dbAddr = source.swdb_addr; + } else if (dbMode == DBMode::HW_DB) { + dest.dbAddr = source.db_reg; + } + SHM_LOG_INFO("CopyAiCQInfo: cqn = " << dest.cqn << ", bufAddr = " << dest.bufAddr << ", cqeSize = " + << dest.cqeSize << ", depth = " << dest.depth << ", headAddr = " << dest.headAddr + << ", tailAddr = " << dest.tailAddr << ", dbAddr = " << dest.dbAddr); +} + +void DeviceQpManager::CloseServices() noexcept +{ + CloseServerConnections(); + CloseClientConnections(); +} + +void DeviceQpManager::CloseClientConnections() noexcept +{ + CloseConnections(clientConnections_); +} + +void DeviceQpManager::CloseServerConnections() noexcept +{ + DestroyServerSocket(); + CloseConnections(serverConnections_); +} + +void DeviceQpManager::CloseConnections(std::unordered_map &connections) noexcept +{ + std::vector socketCloseInfos; + for (auto it = connections.begin(); it != connections.end(); ++it) { + if (it->second.qpHandles[CONN_QP_AI_CORE] != nullptr) { + auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_AI_CORE]); + if (ret != 0) { + SHM_LOG_WARN("destroy AI QP to server: " << it->first << " failed: " << ret); + } + it->second.qpHandles[CONN_QP_AI_CORE] = nullptr; + } + + if (it->second.qpHandles[CONN_QP_STARS] != nullptr) { + auto ret = DlHccpApi::RaQpDestroy(it->second.qpHandles[CONN_QP_STARS]); + if (ret != 0) { + SHM_LOG_WARN("destroy stars QP to server: " << it->first << " failed: " << ret); + } + it->second.qpHandles[CONN_QP_STARS] = nullptr; + } + + if (it->second.socketFd != nullptr) { + HccpSocketCloseInfo info; + info.handle = it->second.socketHandle; + info.fd = it->second.socketFd; + info.linger = 0; + socketCloseInfos.push_back(info); + it->second.socketFd = nullptr; + } + } + + if (!socketCloseInfos.empty()) { + auto ret = DlHccpApi::RaSocketBatchClose(socketCloseInfos.data(), socketCloseInfos.size()); + if (ret != 0) { + SHM_LOG_INFO("close sockets return: " << ret); + } + } + + for (auto it = connections.begin(); it != connections.end(); ++it) { + auto ret = DlHccpApi::RaSocketDeinit(it->second.socketHandle); + if (ret != 0) { + SHM_LOG_INFO("deinit socket to server: " << it->first << " return: " << ret); + } + } + + connections.clear(); +} \ No newline at end of file diff --git a/src/modules/transport/rdma/device_qp_manager.h b/src/modules/transport/rdma/device_qp_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..313a9316bcfbb9fc5aeccdce3d343cbf72ec3149 --- /dev/null +++ b/src/modules/transport/rdma/device_qp_manager.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DEVICE_QP_MANAGER_H +#define DEVICE_QP_MANAGER_H + +#include +#include +#include +#include "dl_hccp_api.h" + +class DeviceQpManager { +public: + DeviceQpManager(uint32_t deviceId, uint32_t rankId, uint32_t rankCount, sockaddr_in devNet, + hybm_role_type role) noexcept; + ~DeviceQpManager() noexcept; + + int SetRemoteRankInfo(const std::unordered_map &ranks) noexcept; + int SetLocalMemories(const MemoryRegionMap &mrs) noexcept; + int Startup(void *rdma) noexcept; + void Shutdown() noexcept; + int WaitingConnectionReady() noexcept; + void *GetQpInfoAddress() const noexcept; + void *GetQpHandleWithRankId(uint32_t rankId) const noexcept; + +protected: + void *CreateLocalSocket() noexcept; + int CreateServerSocket() noexcept; + void DestroyServerSocket() noexcept; + +protected: + const uint32_t deviceId_; + const uint32_t rankId_; + const uint32_t rankCount_; + const hybm_role_type rankRole_; + sockaddr_in deviceAddress_; + void *serverSocketHandle_{nullptr}; + +private: + enum ConnQpType : uint32_t { + CONN_QP_AI_CORE, // AI core使用的QP + CONN_QP_STARS, // Host侧使用STARS驱动的QP + CONN_QP_COUNT + }; + + struct ConnectionChannel { + in_addr remoteIp; + void *socketHandle; + void *socketFd{nullptr}; + void *qpHandles[CONN_QP_COUNT]{}; + HccpAiQpInfo aiQpInfo{}; + int qpStatus{-1}; + + explicit ConnectionChannel(const in_addr ip) : ConnectionChannel{ip, nullptr} {} + ConnectionChannel(in_addr ip, void *sock) : remoteIp{ip}, socketHandle{sock} {} + }; + + bool ReserveQpInfoSpace() noexcept; + int StartServerSide() noexcept; + int StartClientSide() noexcept; + int GenerateWhiteList() noexcept; + int WaitConnectionsReady(std::unordered_map &connections) noexcept; + int CreateQpWaitingReady(std::unordered_map &connections, ConnQpType qpType) noexcept; + int CreateOneQp(ConnQpType qpType, ConnectionChannel &channel) noexcept; + int FillQpInfo(ConnQpType qpType) noexcept; + void CopyAiWQInfo(struct AiQpRMAWQ &dest, const struct ai_data_plane_wq &src, DBMode dbMode, uint32_t sl) noexcept; + void CopyAiCQInfo(struct AiQpRMACQ &dest, const ai_data_plane_cq &source, DBMode dbMode) noexcept; + void CloseServices() noexcept; + void CloseClientConnections() noexcept; + void CloseServerConnections() noexcept; + void CloseConnections(std::unordered_map &connections) noexcept; + + bool started_{false}; + int serverConnectResult{-1}; + int clientConnectResult{-1}; + uint32_t qpInfoSize_{0}; + void *rdmaHandle_{nullptr}; + std::unordered_map currentRanksInfo_; + MemoryRegionMap currentLocalMrs_; + AiQpRMAQueueInfo *qpInfo_{nullptr}; + std::unordered_map clientConnections_; + std::unordered_map serverConnections_; +}; + +#endif // DEVICE_QP_MANAGER_H \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_api.cpp b/src/modules/transport/rdma/dl_hccp_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbcc92bec8329d5a43e346934ed61770b052c154 --- /dev/null +++ b/src/modules/transport/rdma/dl_hccp_api.cpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include "dl_hccp_api.h" + +bool DlHccpApi::gLoaded = false; +std::mutex DlHccpApi::gMutex; +void *DlHccpApi::raHandle; +void *DlHccpApi::tsdHandle; + +const char *DlHccpApi::gRaLibName = "libra.so"; +const char *DlHccpApi::gTsdLibName = "libtsdclient.so"; + +raRdevGetHandleFunc DlHccpApi::gRaRdevGetHandle; + +raInitFunc DlHccpApi::gRaInit; +raGetInterfaceVersionFunc DlHccpApi::gRaGetInterfaceVersion; +raSocketInitFunc DlHccpApi::gRaSocketInit; +raSocketDeinitFunc DlHccpApi::gRaSocketDeinit; +raRdevInitV2Func DlHccpApi::gRaRdevInitV2; +raSocketBatchConnectFunc DlHccpApi::gRaSocketBatchConnect; +raSocketBatchCloseFunc DlHccpApi::gRaSocketBatchClose; +raSocketBatchAbortFunc DlHccpApi::gRaSocketBatchAbort; +raSocketListenStartFunc DlHccpApi::gRaSocketListenStart; +raSocketListenStopFunc DlHccpApi::gRaSocketListenStop; +raGetSocketsFunc DlHccpApi::gRaGetSockets; +raSocketSendFunc DlHccpApi::gRaSocketSend; +raSocketRecvFunc DlHccpApi::gRaSocketRecv; +raGetIfNumFunc DlHccpApi::gRaGetIfNum; +raGetIfAddrsFunc DlHccpApi::gRaGetIfAddrs; +raSocketWhiteListAddFunc DlHccpApi::gRaSocketWhiteListAdd; +raSocketWhiteListDelFunc DlHccpApi::gRaSocketWhiteListDel; +raQpCreateFunc DlHccpApi::gRaQpCreate; +raQpAiCreateFunc DlHccpApi::gRaQpAiCreate; +raQpDestroyFunc DlHccpApi::gRaQpDestroy; +raGetQpStatusFunc DlHccpApi::gRaGetQpStatus; +raQpConnectAsyncFunc DlHccpApi::gRaQpConnectAsync; +raRegisterMrFunc DlHccpApi::gRaRegisterMR; +raDeregisterMrFunc DlHccpApi::gRaDeregisterMR; +raMrRegFunc DlHccpApi::gRaMrReg; +raMrDeregFunc DlHccpApi::gRaMrDereg; +raSendWrFunc DlHccpApi::gRaSendWr; +raPollCqFunc DlHccpApi::gRaPollCq; + +tsdOpenFunc DlHccpApi::gTsdOpen; + +Result DlHccpApi::LoadLibrary() +{ + std::lock_guard guard(gMutex); + if (gLoaded) { + return 0; + } + + raHandle = dlopen(gRaLibName, RTLD_NOW); + if (raHandle == nullptr) { + std::cout << "Failed to open library [" + << gRaLibName + << "], please source ascend-toolkit set_env.sh, or add ascend driver lib path into LD_LIBRARY_PATH," + << " error: " << dlerror() << std::endl; + return -1; + } + + tsdHandle = dlopen(gTsdLibName, RTLD_NOW); + if (tsdHandle == nullptr) { + std::cout << "Failed to open library [" + << gTsdLibName + << "], please source ascend-toolkit set_env.sh, or add ascend driver lib path into LD_LIBRARY_PATH," + << " error: " << dlerror() << std::endl; + dlclose(raHandle); + raHandle = nullptr; + return -1; + } + + /* load sym */ + DL_LOAD_SYM(gRaGetInterfaceVersion, raGetInterfaceVersionFunc, raHandle, "ra_get_interface_version"); + DL_LOAD_SYM(gRaSocketInit, raSocketInitFunc, raHandle, "ra_socket_init"); + DL_LOAD_SYM(gRaInit, raInitFunc, raHandle, "ra_init"); + DL_LOAD_SYM(gRaSocketDeinit, raSocketDeinitFunc, raHandle, "ra_socket_deinit"); + DL_LOAD_SYM(gRaRdevInitV2, raRdevInitV2Func, raHandle, "ra_rdev_init_v2"); + DL_LOAD_SYM(gRaRdevGetHandle, raRdevGetHandleFunc, raHandle, "ra_rdev_get_handle"); + DL_LOAD_SYM(gRaSocketBatchConnect, raSocketBatchConnectFunc, raHandle, "ra_socket_batch_connect"); + DL_LOAD_SYM(gRaSocketBatchClose, raSocketBatchCloseFunc, raHandle, "ra_socket_batch_close"); + DL_LOAD_SYM(gRaSocketBatchAbort, raSocketBatchAbortFunc, raHandle, "ra_socket_batch_abort"); + DL_LOAD_SYM(gRaSocketListenStart, raSocketListenStartFunc, raHandle, "ra_socket_listen_start"); + DL_LOAD_SYM(gRaSocketListenStop, raSocketListenStopFunc, raHandle, "ra_socket_listen_stop"); + DL_LOAD_SYM(gRaGetSockets, raGetSocketsFunc, raHandle, "ra_get_sockets"); + DL_LOAD_SYM(gRaSocketSend, raSocketSendFunc, raHandle, "ra_socket_send"); + DL_LOAD_SYM(gRaSocketRecv, raSocketRecvFunc, raHandle, "ra_socket_recv"); + DL_LOAD_SYM(gRaGetIfNum, raGetIfNumFunc, raHandle, "ra_get_ifnum"); + DL_LOAD_SYM(gRaGetIfAddrs, raGetIfAddrsFunc, raHandle, "ra_get_ifaddrs"); + DL_LOAD_SYM(gRaSocketWhiteListAdd, raSocketWhiteListAddFunc, raHandle, "ra_socket_white_list_add"); + DL_LOAD_SYM(gRaSocketWhiteListDel, raSocketWhiteListDelFunc, raHandle, "ra_socket_white_list_del"); + DL_LOAD_SYM(gRaQpCreate, raQpCreateFunc, raHandle, "ra_qp_create"); + DL_LOAD_SYM(gRaQpAiCreate, raQpAiCreateFunc, raHandle, "ra_ai_qp_create"); + DL_LOAD_SYM(gRaQpDestroy, raQpDestroyFunc, raHandle, "ra_qp_destroy"); + DL_LOAD_SYM(gRaGetQpStatus, raGetQpStatusFunc, raHandle, "ra_get_qp_status"); + DL_LOAD_SYM(gRaQpConnectAsync, raQpConnectAsyncFunc, raHandle, "ra_qp_connect_async"); + DL_LOAD_SYM(gRaRegisterMR, raRegisterMrFunc, raHandle, "ra_register_mr"); + DL_LOAD_SYM(gRaDeregisterMR, raDeregisterMrFunc, raHandle, "ra_deregister_mr"); + DL_LOAD_SYM(gRaMrReg, raMrRegFunc, raHandle, "ra_mr_reg"); + DL_LOAD_SYM(gRaMrDereg, raMrDeregFunc, raHandle, "ra_mr_dereg"); + DL_LOAD_SYM(gRaSendWr, raSendWrFunc, raHandle, "ra_send_wr"); + DL_LOAD_SYM(gRaPollCq, raPollCqFunc, raHandle, "ra_poll_cq"); + + DL_LOAD_SYM(gTsdOpen, tsdOpenFunc, tsdHandle, "TsdOpen"); + SHM_LOG_INFO("LoadLibrary for DlHccpApi success"); + gLoaded = true; + return 0; +} + +void DlHccpApi::CleanupLibrary() +{ + std::lock_guard guard(gMutex); + if (!gLoaded) { + return; + } + + gRaRdevGetHandle = nullptr; + gRaInit = nullptr; + gRaGetInterfaceVersion = nullptr; + gRaSocketInit = nullptr; + gRaSocketDeinit = nullptr; + gRaRdevInitV2 = nullptr; + gRaSocketBatchConnect = nullptr; + gRaSocketBatchClose = nullptr; + gRaSocketBatchAbort = nullptr; + gRaSocketListenStart = nullptr; + gRaSocketListenStop = nullptr; + gRaGetSockets = nullptr; + gRaSocketSend = nullptr; + gRaSocketRecv = nullptr; + gRaGetIfNum = nullptr; + gRaGetIfAddrs = nullptr; + gRaSocketWhiteListAdd = nullptr; + gRaSocketWhiteListDel = nullptr; + gRaQpCreate = nullptr; + gRaQpAiCreate = nullptr; + gRaQpDestroy = nullptr; + gRaGetQpStatus = nullptr; + gRaQpConnectAsync = nullptr; + gRaRegisterMR = nullptr; + gRaDeregisterMR = nullptr; + gRaMrReg = nullptr; + gRaMrDereg = nullptr; + gTsdOpen = nullptr; + gRaSendWr = nullptr; + gRaPollCq = nullptr; + + if (raHandle != nullptr) { + dlclose(raHandle); + raHandle = nullptr; + } + + if (tsdHandle != nullptr) { + dlclose(tsdHandle); + tsdHandle = nullptr; + } + gLoaded = false; +} \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_api.h b/src/modules/transport/rdma/dl_hccp_api.h new file mode 100644 index 0000000000000000000000000000000000000000..b8722bca25eb1a85a4da86c74602ad45a7dc9fcf --- /dev/null +++ b/src/modules/transport/rdma/dl_hccp_api.h @@ -0,0 +1,248 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DL_HCCP_API_H +#define DL_HCCP_API_H + +#include +#include "dl_hccp_def.h" + +using namespace shm; + +using raRdevGetHandleFunc = int (*)(uint32_t, void **); + +using raGetInterfaceVersionFunc = int (*)(uint32_t, uint32_t, uint32_t *); +using raInitFunc = int (*)(const HccpRaInitConfig *); +using raSocketInitFunc = int (*)(HccpNetworkMode, HccpRdev, void **); +using raSocketDeinitFunc = int (*)(void *); +using raRdevInitV2Func = int (*)(HccpRdevInitInfo, HccpRdev, void **); +using raSocketBatchConnectFunc = int (*)(HccpSocketConnectInfo[], uint32_t); +using raSocketBatchCloseFunc = int (*)(HccpSocketCloseInfo[], uint32_t); +using raSocketBatchAbortFunc = int (*)(HccpSocketConnectInfo[], uint32_t); +using raSocketListenStartFunc = int (*)(HccpSocketListenInfo[], uint32_t); +using raSocketListenStopFunc = int (*)(HccpSocketListenInfo[], uint32_t); +using raGetSocketsFunc = int (*)(uint32_t, HccpSocketInfo[], uint32_t, uint32_t *); +using raSocketSendFunc = int (*)(const void *, const void *, uint64_t, uint64_t *); +using raSocketRecvFunc = int (*)(const void *, void *, uint64_t, uint64_t *); +using raGetIfNumFunc = int (*)(const HccpRaGetIfAttr *, uint32_t *); +using raGetIfAddrsFunc = int (*)(const HccpRaGetIfAttr *, HccpInterfaceInfo[], uint32_t *); +using raSocketWhiteListAddFunc = int (*)(void *, const HccpSocketWhiteListInfo[], uint32_t num); +using raSocketWhiteListDelFunc = int (*)(void *, const HccpSocketWhiteListInfo[], uint32_t num); +using raQpCreateFunc = int (*)(void *, int, int, void **); +using raQpAiCreateFunc = int (*)(void *, const HccpQpExtAttrs *, HccpAiQpInfo *, void **); +using raQpDestroyFunc = int (*)(void *); +using raGetQpStatusFunc = int (*)(void *, int *); +using raQpConnectAsyncFunc = int (*)(void *, const void *); +using raRegisterMrFunc = int (*)(const void *, HccpMrInfo *, void **); +using raDeregisterMrFunc = int (*)(const void *, void *); +using raMrRegFunc = int (*)(void *, HccpMrInfo *); +using raMrDeregFunc = int (*)(void *, HccpMrInfo *); +using raSendWrFunc = int (*)(void *, send_wr *, send_wr_rsp *); +using tsdOpenFunc = uint32_t (*)(uint32_t, uint32_t); +using raPollCqFunc = int (*)(void *, bool, uint32_t, void *); + +class DlHccpApi { +public: + static Result LoadLibrary(); + static void CleanupLibrary(); + + static inline int RaGetInterfaceVersion(uint32_t deviceId, uint32_t opcode, uint32_t &version) + { + return gRaGetInterfaceVersion(deviceId, opcode, &version); + } + + static inline int RaSocketInit(HccpNetworkMode mode, const HccpRdev &rdev, void *&socketHandle) + { + return gRaSocketInit(mode, rdev, &socketHandle); + } + + static inline int RaInit(const HccpRaInitConfig &config) + { + return gRaInit(&config); + } + + static inline int RaSocketDeinit(void *socketHandle) + { + return gRaSocketDeinit(socketHandle); + } + + static inline int RaRdevInitV2(const HccpRdevInitInfo &info, const HccpRdev &rdev, void *&rdmaHandle) + { + return gRaRdevInitV2(info, rdev, &rdmaHandle); + } + + static inline int RaRdevGetHandle(uint32_t deviceId, void *&rdmaHandle) + { + return gRaRdevGetHandle(deviceId, &rdmaHandle); + } + + static inline int RaSocketBatchConnect(HccpSocketConnectInfo conn[], uint32_t num) + { + return gRaSocketBatchConnect(conn, num); + } + + static inline int RaSocketBatchClose(HccpSocketCloseInfo conn[], uint32_t num) + { + return gRaSocketBatchClose(conn, num); + } + + static inline int RaSocketBatchAbort(HccpSocketConnectInfo conn[], uint32_t num) + { + return gRaSocketBatchAbort(conn, num); + } + + static inline int RaSocketListenStart(HccpSocketListenInfo conn[], uint32_t num) + { + return gRaSocketListenStart(conn, num); + } + + static inline int RaSocketListenStop(HccpSocketListenInfo conn[], uint32_t num) + { + return gRaSocketListenStop(conn, num); + } + + static inline int RaGetSockets(uint32_t role, HccpSocketInfo conn[], uint32_t num, uint32_t &connectedNum) + { + return gRaGetSockets(role, conn, num, &connectedNum); + } + + static inline int RaSocketSend(const void *fd, const void *data, uint64_t size, uint64_t &sent) + { + return gRaSocketSend(fd, data, size, &sent); + } + + static inline int RaSocketRecv(const void *fd, void *data, uint64_t size, uint64_t &received) + { + return gRaSocketRecv(fd, data, size, &received); + } + + static inline int RaGetIfNum(const HccpRaGetIfAttr &config, uint32_t &num) + { + return gRaGetIfNum(&config, &num); + } + + static inline int RaGetIfAddrs(const HccpRaGetIfAttr &config, HccpInterfaceInfo infos[], uint32_t &num) + { + return gRaGetIfAddrs(&config, infos, &num); + } + + static inline int RaSocketWhiteListAdd(void *socket, const HccpSocketWhiteListInfo list[], uint32_t num) + { + return gRaSocketWhiteListAdd(socket, list, num); + } + + static inline int RaSocketWhiteListDel(void *socket, const HccpSocketWhiteListInfo list[], uint32_t num) + { + return gRaSocketWhiteListAdd(socket, list, num); + } + + static inline int RaQpCreate(void *rdmaHandle, int flag, int qpMode, void *&qpHandle) + { + return gRaQpCreate(rdmaHandle, flag, qpMode, &qpHandle); + } + + static inline int RaQpAiCreate(void *rdmaHandle, const HccpQpExtAttrs &attrs, HccpAiQpInfo &info, void *&qpHandle) + { + return gRaQpAiCreate(rdmaHandle, &attrs, &info, &qpHandle); + } + + static inline int RaQpDestroy(void *qpHandle) + { + return gRaQpDestroy(qpHandle); + } + + static inline int RaGetQpStatus(void *qpHandle, int &status) + { + return gRaGetQpStatus(qpHandle, &status); + } + + static inline int RaQpConnectAsync(void *qp, const void *socketFd) + { + return gRaQpConnectAsync(qp, socketFd); + } + + static inline int RaRegisterMR(const void *rdmaHandle, HccpMrInfo *info, void *&mrHandle) + { + return gRaRegisterMR(rdmaHandle, info, &mrHandle); + } + + static inline int RaDeregisterMR(const void *rdmaHandle, void *mrHandle) + { + return gRaDeregisterMR(rdmaHandle, mrHandle); + } + + static inline int RaMrReg(void *qpHandle, HccpMrInfo &info) + { + return gRaMrReg(qpHandle, &info); + } + + static inline int RaMrDereg(void *qpHandle, HccpMrInfo &info) + { + return gRaMrDereg(qpHandle, &info); + } + + static inline int RaSendWr(void *qp_handle, struct send_wr *wr, struct send_wr_rsp *op_rsp) + { + return gRaSendWr(qp_handle, wr, op_rsp); + } + + static inline int RaPollCq(void *qp_handle, bool is_send_cq, unsigned int num_entries, void *wc) + { + return gRaPollCq(qp_handle, is_send_cq, num_entries, wc); + } + + static inline uint32_t TsdOpen(uint32_t deviceId, uint32_t rankSize) + { + return gTsdOpen(deviceId, rankSize); + } + +private: + static std::mutex gMutex; + static bool gLoaded; + static void *raHandle; + static void *tsdHandle; + static const char *gRaLibName; + static const char *gTsdLibName; + + static raRdevGetHandleFunc gRaRdevGetHandle; + + static raGetInterfaceVersionFunc gRaGetInterfaceVersion; + static raInitFunc gRaInit; + static raSocketInitFunc gRaSocketInit; + static raSocketDeinitFunc gRaSocketDeinit; + static raRdevInitV2Func gRaRdevInitV2; + static raSocketBatchConnectFunc gRaSocketBatchConnect; + static raSocketBatchCloseFunc gRaSocketBatchClose; + static raSocketBatchAbortFunc gRaSocketBatchAbort; + static raSocketListenStartFunc gRaSocketListenStart; + static raSocketListenStopFunc gRaSocketListenStop; + static raGetSocketsFunc gRaGetSockets; + static raSocketSendFunc gRaSocketSend; + static raSocketRecvFunc gRaSocketRecv; + static raGetIfNumFunc gRaGetIfNum; + static raGetIfAddrsFunc gRaGetIfAddrs; + static raSocketWhiteListAddFunc gRaSocketWhiteListAdd; + static raSocketWhiteListDelFunc gRaSocketWhiteListDel; + static raQpCreateFunc gRaQpCreate; + static raQpAiCreateFunc gRaQpAiCreate; + static raQpDestroyFunc gRaQpDestroy; + static raGetQpStatusFunc gRaGetQpStatus; + static raQpConnectAsyncFunc gRaQpConnectAsync; + static raRegisterMrFunc gRaRegisterMR; + static raDeregisterMrFunc gRaDeregisterMR; + static raMrRegFunc gRaMrReg; + static raMrDeregFunc gRaMrDereg; + static raSendWrFunc gRaSendWr; + static raPollCqFunc gRaPollCq; + + static tsdOpenFunc gTsdOpen; +}; + +#endif // DL_HCCP_API_H \ No newline at end of file diff --git a/src/modules/transport/rdma/dl_hccp_def.h b/src/modules/transport/rdma/dl_hccp_def.h new file mode 100644 index 0000000000000000000000000000000000000000..4d905f8388f84e449235ca5b54fa2fbabeb61be1 --- /dev/null +++ b/src/modules/transport/rdma/dl_hccp_def.h @@ -0,0 +1,649 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DL_HCCP_DEF_H +#define DL_HCCP_DEF_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "shmem_api.h" +#include "common/shmemi_functions.h" +#include "common/shmemi_host_types.h" + +using Result = int32_t; + +constexpr uint32_t HCCL_ROOT_INFO_BYTES = 256; // 4108: root info length +constexpr uint32_t HCCP_SOCK_CONN_TAG_SIZE = 192; +constexpr uint32_t HCCP_MAX_INTERFACE_NAME_LEN = 256; + +constexpr uint64_t EXPORT_INFO_MAGIC = 0xAABB1234FFFFEEEEUL; +constexpr uint64_t EXPORT_SLICE_MAGIC = 0xAABB1234FFFFBBBBUL; +constexpr uint64_t EXPORT_INFO_VERSION = 0x1UL; + +struct HybmDeviceGlobalMeta { + uint64_t entityCount; + uint64_t reserved[15]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE +}; + +struct HybmDeviceMeta { + uint32_t entityId; + uint32_t rankId; + uint32_t rankSize; + uint32_t extraContextSize; + uint64_t symmetricSize; + uint64_t qpInfoAddress; + uint64_t reserved[12]; // total 128B, equal HYBM_DEVICE_PRE_META_SIZE +}; + + +/** + * @brief HCCL root info + */ +struct HcclRootInfo { + char internal[HCCL_ROOT_INFO_BYTES]; +}; + +struct HccpRaInitConfig { + uint32_t phyId; /**< physical device id */ + uint32_t nicPosition; /**< reference to HccpNetworkMode */ + int hdcType; /**< reference to drvHdcServiceType */ +}; + +/** + * @ingroup libinit + * ip address + */ +union HccpIpAddr { + struct in_addr addr; + struct in6_addr addr6; +}; + +struct HccpRdevInitInfo { + int mode; + uint32_t notifyType; + bool enabled910aLite; /**< true will enable 910A lite, invalid if enabled_2mb_lite is false; default is false */ + bool disabledLiteThread; /**< true will not start lite thread, flag invalid if enabled_910a/2mb_lite is false */ + bool enabled2mbLite; /**< true will enable 2MB lite(include 910A & 910B), default is false */ +}; + +/** + * @ingroup libinit + * hccp operating environment + */ +enum HccpNetworkMode { + NETWORK_PEER_ONLINE = 0, /**< Third-party online mode */ + NETWORK_OFFLINE, /**< offline mode */ + NETWORK_ONLINE, /**< online mode */ +}; + +/** + * @ingroup librdma + * Flag of mr access + */ +enum HccpMrAccessFlags { + RA_ACCESS_LOCAL_WRITE = 1, /**< mr local write access */ + RA_ACCESS_REMOTE_WRITE = (1 << 1), /**< mr remote write access */ + RA_ACCESS_REMOTE_READ = (1 << 2), /**< mr remote read access */ + RA_ACCESS_REDUCE = (1 << 8), +}; + +enum HccpNotifyType { + NO_USE = 0, + NOTIFY = 1, + EVENTID = 2, +}; + +/** + * @ingroup libsocket + * struct of the client socket + */ +struct HccpSocketConnectInfo { + void *handle; /**< socket handle */ + HccpIpAddr remoteIp; /**< IP address of remote socket, [0-7] is reserved for vnic */ + uint16_t port; /**< Socket listening port number */ + char tag[HCCP_SOCK_CONN_TAG_SIZE]; /**< tag must ended by '\0' */ +}; + +inline std::ostream &operator<<(std::ostream &output, const HccpSocketConnectInfo &info) +{ + output << "HccpSocketConnectInfo(socketHandle=" << info.handle << ", remoteIp=" << inet_ntoa(info.remoteIp.addr) + << ", port=" << info.port << ")"; + return output; +} + +/** + * @ingroup libsocket + * Details about socket after socket is linked + */ +struct HccpSocketCloseInfo { + void *handle; /**< socket handle */ + void *fd; /**< fd handle */ + int linger; /**< 0:use(default l_linger is RS_CLOSE_TIMEOUT), others:disuse */ +}; + +/** + * @ingroup libsocket + * struct of the listen info + */ +struct HccpSocketListenInfo { + void *handle; /**< socket handle */ + unsigned int port; /**< Socket listening port number */ + unsigned int phase; /**< refer to enum listen_phase */ + unsigned int err; /**< errno */ +}; + +/** + * @ingroup libsocket + * Details about socket after socket is linked + */ +struct HccpSocketInfo { + void *handle; /**< socket handle */ + void *fd; /**< fd handle */ + HccpIpAddr remoteIp; /**< IP address of remote socket */ + int status; /**< socket status:0 not connected 1:connected 2:connect timeout 3:connecting */ + char tag[HCCP_SOCK_CONN_TAG_SIZE]; /**< tag must ended by '\0' */ +}; + +/** + * @ingroup libinit + * hccp init info + */ +struct HccpRdev { + uint32_t phyId; /**< physical device id */ + int family; /**< AF_INET(ipv4) or AF_INET6(ipv6) */ + HccpIpAddr localIp; +}; + +struct HccpRaGetIfAttr { + uint32_t phyId; /**< physical device id */ + uint32_t nicPosition; /**< reference to network_mode */ + bool isAll; /**< valid when nic_position is NETWORK_OFFLINE. false: get specific rnic ip, true: get all rnic ip */ +}; + +struct HccpIfaddrInfo { + HccpIpAddr ip; /* Address of interface */ + struct in_addr mask; /* Netmask of interface */ +}; + +struct HccpInterfaceInfo { + int family; + int scopeId; + HccpIfaddrInfo ifaddr; /* Address and netmask of interface */ + char ifname[HCCP_MAX_INTERFACE_NAME_LEN]; /* Name of interface */ +}; + +struct HccpSocketWhiteListInfo { + HccpIpAddr remoteIp; /**< IP address of remote */ + uint32_t connLimit; /**< limit of whilte list */ + char tag[HCCP_SOCK_CONN_TAG_SIZE]; /**< tag used for whitelist must ended by '\0' */ +}; + +struct HccpMrInfo { + void *addr; /**< starting address of mr */ + unsigned long long size; /**< size of mr */ + int access; /**< access of mr, reference to HccpMrAccessFlags */ + unsigned int lkey; /**< local addr access key */ + unsigned int rkey; /**< remote addr access key */ +}; + +struct HccpCqExtAttr { + int sendCqDepth; + int recvDqDepth; + int sendCqCompVector; + int recvCqCompVector; +}; + +enum ibv_qp_type { + IBV_QPT_RC = 2, + IBV_QPT_UC, + IBV_QPT_UD, + IBV_QPT_RAW_PACKET = 8, + IBV_QPT_XRC_SEND = 9, + IBV_QPT_XRC_RECV, + IBV_QPT_DRIVER = 0xff, +}; + +enum ibv_wc_status { + IBV_WC_SUCCESS, + IBV_WC_LOC_LEN_ERR, + IBV_WC_LOC_QP_OP_ERR, + IBV_WC_LOC_EEC_OP_ERR, + IBV_WC_LOC_PROT_ERR, + IBV_WC_WR_FLUSH_ERR, + IBV_WC_MW_BIND_ERR, + IBV_WC_BAD_RESP_ERR, + IBV_WC_LOC_ACCESS_ERR, + IBV_WC_REM_INV_REQ_ERR, + IBV_WC_REM_ACCESS_ERR, + IBV_WC_REM_OP_ERR, + IBV_WC_RETRY_EXC_ERR, + IBV_WC_RNR_RETRY_EXC_ERR, + IBV_WC_LOC_RDD_VIOL_ERR, + IBV_WC_REM_INV_RD_REQ_ERR, + IBV_WC_REM_ABORT_ERR, + IBV_WC_INV_EECN_ERR, + IBV_WC_INV_EEC_STATE_ERR, + IBV_WC_FATAL_ERR, + IBV_WC_RESP_TIMEOUT_ERR, + IBV_WC_GENERAL_ERR +}; + +enum ibv_wc_opcode { + IBV_WC_SEND, + IBV_WC_RDMA_WRITE, + IBV_WC_RDMA_READ, + IBV_WC_COMP_SWAP, + IBV_WC_FETCH_ADD, + IBV_WC_BIND_MW, + /* + * Set value of IBV_WC_RECV so consumers can test if a completion is a + * receive by testing (opcode & IBV_WC_RECV). + */ + IBV_WC_RECV = 1 << 7, + IBV_WC_RECV_RDMA_WITH_IMM +}; + +struct ibv_wc { + uint64_t wr_id; + enum ibv_wc_status status; + enum ibv_wc_opcode opcode; + uint32_t vendor_err; + uint32_t byte_len; + uint32_t imm_data; /* in network byte order */ + uint32_t qp_num; + uint32_t src_qp; + int wc_flags; + uint16_t pkey_index; + uint16_t slid; + uint8_t sl; + uint8_t dlid_path_bits; +}; + +struct ibv_qp_cap { + uint32_t max_send_wr; + uint32_t max_recv_wr; + uint32_t max_send_sge; + uint32_t max_recv_sge; + uint32_t max_inline_data; +}; + +struct ibv_qp_init_attr { + void *qp_context; + struct ibv_cq *send_cq; + struct ibv_cq *recv_cq; + struct ibv_srq *srq; + struct ibv_qp_cap cap; + enum ibv_qp_type qp_type; + int sq_sig_all; +}; + +union ai_data_plane_cstm_flag { + struct { + uint32_t cq_cstm : 1; // 0: hccp poll cq; 1: caller poll cq + uint32_t reserved : 31; + } bs; + uint32_t value; +}; + +struct HccpQpExtAttrs { + int qpMode; + // cq attr + HccpCqExtAttr cqAttr; + // qp attr + struct ibv_qp_init_attr qp_attr; + // version control and reserved + int version; + int mem_align; // 0,1:4KB, 2:2MB + uint32_t udp_sport; + union ai_data_plane_cstm_flag data_plane_flag; // only valid in ra_ai_qp_create + uint32_t reserved[29]; +}; + +struct ai_data_plane_wq { + unsigned wqn; + unsigned long long buf_addr; + unsigned int wqebb_size; + unsigned int depth; + unsigned long long head_addr; + unsigned long long tail_addr; + unsigned long long swdb_addr; + unsigned long long db_reg; + unsigned int reserved[8U]; +}; + +struct ai_data_plane_cq { + unsigned int cqn; + unsigned long long buf_addr; + unsigned int cqe_size; + unsigned int depth; + unsigned long long head_addr; + unsigned long long tail_addr; + unsigned long long swdb_addr; + unsigned long long db_reg; + unsigned int reserved[2U]; +}; + +struct ai_data_plane_info { + struct ai_data_plane_wq sq; + struct ai_data_plane_wq rq; + struct ai_data_plane_cq scq; + struct ai_data_plane_cq rcq; + unsigned int reserved[8U]; +}; + +struct HccpAiQpInfo { + unsigned long long aiQpAddr; // refer to struct ibv_qp * + unsigned int sqIndex; // index of sq + unsigned int dbIndex; // index of db + + // below cq related info valid when data_plane_flag.bs.cq_cstm was 1 + unsigned long long ai_scq_addr; // refer to struct ibv_cq *scq + unsigned long long ai_rcq_addr; // refer to struct ibv_cq *rcq + struct ai_data_plane_info data_plane_info; +}; + +enum class DBMode : int32_t { INVALID_DB = -1, HW_DB = 0, SW_DB }; + +struct AiQpRMAWQ { + uint32_t wqn{0}; + uint64_t bufAddr{0}; + uint32_t wqeSize{0}; + uint32_t depth{0}; + uint64_t headAddr{0}; + uint64_t tailAddr{0}; + DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw + uint64_t dbAddr{0}; + uint32_t sl{0}; +}; + +struct AiQpRMACQ { + uint32_t cqn{0}; + uint64_t bufAddr{0}; + uint32_t cqeSize{0}; + uint32_t depth{0}; + uint64_t headAddr{0}; + uint64_t tailAddr{0}; + DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw + uint64_t dbAddr{0}; +}; + +struct RdmaMemRegionInfo { + uint64_t size{0}; // size of the memory region + uint64_t addr{0}; // start address of the memory region + uint32_t lkey{0}; + uint32_t rkey{0}; // key of the memory region +}; + +struct AiQpRMAQueueInfo { + uint32_t count; + struct AiQpRMAWQ *sq; + struct AiQpRMAWQ *rq; + struct AiQpRMACQ *scq; + struct AiQpRMACQ *rcq; + RdmaMemRegionInfo *mr; +}; + +/** + * @ingroup librdma + * Scatter and gather element + */ +struct sg_list { + uint64_t addr; /**< address of buf */ + uint32_t len; /**< len of buf */ + uint32_t lkey; /**< local addr access key */ +}; + +/** + * @ingroup librdma + * RDMA work request + */ +struct send_wr { + struct sg_list *buf_list; /**< list of sg */ + uint16_t buf_num; /**< num of buf_list */ + uint64_t dst_addr; /**< destination address */ + uint32_t rkey; /**< remote address access key */ + uint32_t op; /**< operations of RDMA supported:RDMA_WRITE:0 */ + int send_flag; /**< reference to ra_send_flags */ +}; + +/** + * @ingroup librdma + * wqe template info + */ +struct wqe_info { + unsigned int sq_index; /**< index of sq */ + unsigned int wqe_index; /**< index of wqe */ +}; + +enum ra_send_flags { + RA_SEND_FENCE = 1 << 0, /**< RDMA operation with fence */ + RA_SEND_SIGNALED = 1 << 1, /**< RDMA operation with signaled */ + RA_SEND_SOLICITED = 1 << 2, /**< RDMA operation with solicited */ + RA_SEND_INLINE = 1 << 3, /**< RDMA operation with inline */ +}; +/** + * @ingroup librdma + * doorbell info + */ +struct db_info { + unsigned int db_index; /**< index of db */ + unsigned long db_info; /**< db content */ +}; + +/** + * @ingroup librdma + * respond of sending work request + */ +struct send_wr_rsp { + union { + struct wqe_info wqe_tmp; /**< wqe template info */ + struct db_info db; /**< doorbell info */ + }; +}; +/** + * @brief handle to HCCL communicator + */ +typedef void *HcclComm; + +// macro for gcc optimization for prediction of if/else +#ifndef LIKELY +#define LIKELY(x) (__builtin_expect(!!(x), 1) != 0) +#endif + +#ifndef UNLIKELY +#define UNLIKELY(x) (__builtin_expect(!!(x), 0) != 0) +#endif + +#define HYBM_API __attribute__((visibility("default"))) + +#define DL_LOAD_SYM(TARGET_FUNC_VAR, TARGET_FUNC_TYPE, FILE_HANDLE, SYMBOL_NAME) \ + do { \ + TARGET_FUNC_VAR = (TARGET_FUNC_TYPE)dlsym(FILE_HANDLE, SYMBOL_NAME); \ + if ((TARGET_FUNC_VAR) == nullptr) { \ + std::cout << "Failed to call dlsym to load symbol" << SYMBOL_NAME << std::endl; \ + dlclose(FILE_HANDLE); \ + return -1; \ + } \ + } while (0) + + +enum HybmGvaVersion : uint32_t { + HYBM_GVA_V1 = 0, + HYBM_GVA_V2 = 1, + HYBM_GVA_V3 = 2, + HYBM_GVA_UNKNOWN +}; + +inline std::ostream &operator<<(std::ostream &output, const HccpRaInitConfig &config) +{ + output << "HccpRaInitConfig(phyId=" << config.phyId << ", nicPosition=" << config.nicPosition + << ", hdcType=" << config.hdcType << ")"; + return output; +} + +inline std::ostream &operator<<(std::ostream &output, const HccpRdevInitInfo &info) +{ + output << "HccpRdevInitInfo(mode=" << info.mode << ", notify=" << info.notifyType + << ", enabled910aLite=" << info.enabled910aLite << ", disabledLiteThread=" << info.disabledLiteThread + << ", enabled2mbLite=" << info.enabled2mbLite << ")"; + return output; +} + +inline std::ostream &operator<<(std::ostream &output, const HccpRdev &rdev) +{ + output << "HccpRdev(phyId=" << rdev.phyId << ", family=" << rdev.family + << ", rdev.ip=" << inet_ntoa(rdev.localIp.addr) << ")"; + return output; +} + +struct RegMemResult { + uint32_t reserved{0}; + uint64_t address{0}; + uint64_t size{0}; + void *mrHandle{nullptr}; + uint32_t lkey{0}; + uint32_t rkey{0}; + + RegMemResult() = default; + + RegMemResult(uint64_t addr, uint64_t sz, void *hd, uint32_t lk, uint32_t rk) + : address(addr), + size(sz), + mrHandle(hd), + lkey(lk), + rkey(rk) + { + } +}; + +inline std::ostream &operator<<(std::ostream &output, const RegMemResult &mr) +{ + output << "RegMemResult(address = " << mr.address << ", size = " << mr.size + << ", lkey = " << mr.lkey << ", rkey = " << mr.rkey << ")"; + return output; +} + +constexpr int32_t REG_MR_ACCESS_FLAG_LOCAL_WRITE = 0x1; +constexpr int32_t REG_MR_ACCESS_FLAG_REMOTE_WRITE = 0x2; +constexpr int32_t REG_MR_ACCESS_FLAG_REMOTE_READ = 0x4; +constexpr int32_t REG_MR_ACCESS_FLAG_BOTH_READ_WRITE = 0x7; + +typedef enum { + HYBM_ROLE_PEER = 0, + HYBM_ROLE_SENDER, + HYBM_ROLE_RECEIVER, + HYBM_ROLE_BUTT +} hybm_role_type; + +struct TransportOptions { + uint32_t rankId; + uint32_t rankCount; + uint32_t protocol; + hybm_role_type role; + int nic; + int32_t dev_id; + int32_t logic_dev_id; +}; + +struct TransportMemoryRegion { + uint64_t addr = 0; /* virtual address of memory could be hbm or host dram */ + uint64_t size = 0; /* size of memory to be registered */ + int32_t access = REG_MR_ACCESS_FLAG_BOTH_READ_WRITE; /* access right by local and remote */ + uint32_t flags = 0; /* optional flags: 加一个flag标识是DRAM还是HBM */ + + friend std::ostream &operator<<(std::ostream &output, const TransportMemoryRegion &mr) + { + output << "MemoryRegion address size=" << mr.size << ", access=" << mr.access + << ", flags=" << mr.flags << ")"; + return output; + } +}; + +using MemoryRegionMap = std::map>; + +struct TransportMemoryKey { + uint32_t keys[16]; + + friend std::ostream &operator<<(std::ostream &output, const TransportMemoryKey &key) + { + output << "MemoryKey" << std::hex; + for (auto i = 0U; i < sizeof(key.keys) / sizeof(key.keys[0]); i++) { + output << "-" << key.keys[i]; + } + output << std::dec; + return output; + } +}; + +#define container_of(ptr, type, member) \ + ({ \ + const typeof(((const type *)0)->member) *__mptr = (ptr); \ + (const type *)(const void *)((const char *)__mptr - offsetof(type, member)); \ + }) + +union RegMemKeyUnion { + TransportMemoryKey commonKey; + RegMemResult deviceKey; +}; + +struct ConnectRankInfo { + hybm_role_type role; + sockaddr_in network; + RegMemResult mr; + + ConnectRankInfo(hybm_role_type r, sockaddr_in nw, RegMemResult memory_region) : role{r}, + network{std::move(nw)}, mr{memory_region} {} +}; + +struct TransportRankPrepareInfo { + std::string nic; + hybm_role_type role{HYBM_ROLE_PEER}; + RegMemResult mr; + + TransportRankPrepareInfo() {} + + TransportRankPrepareInfo(std::string n, RegMemResult k) + : nic{std::move(n)}, role{HYBM_ROLE_PEER}, mr{k} {} + + TransportRankPrepareInfo(std::string n, hybm_role_type r, RegMemResult k) + : nic{std::move(n)}, role{r}, mr{k} {} + + friend std::ostream &operator<<(std::ostream &output, const TransportRankPrepareInfo &info) + { + output << "PrepareInfo(nic=" << info.nic << ", role=" << info.role << ", mr=" << info.mr; + return output; + } +}; + +struct HybmTransPrepareOptions { + std::unordered_map options; + + friend std::ostream &operator<<(std::ostream &output, const HybmTransPrepareOptions &info) + { + output << "PrepareOptions("; + for (auto &op : info.options) { + output << op.first << " => " << op.second << ", "; + } + output << ")"; + return output; + } +}; +#endif // DL_HCCP_DEF_H \ No newline at end of file diff --git a/src/modules/transport/rdma/rdma_manager.h b/src/modules/transport/rdma/rdma_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..9f2fbbe32dcfb69bdfcd15ffe3da5e66d9166fc4 --- /dev/null +++ b/src/modules/transport/rdma/rdma_manager.h @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef RDMA_MANAGER_H +#define RDMA_MANAGER_H + +#include +#include +#include +#include "dl_hccp_api.h" +#include "acl/acl.h" +#include "device_qp_manager.h" + +const char *g_rt_lib_name = "libascendcl.so"; +int (*rtGetLogicDevIdByUserDevIdFunc)(const int32_t, int32_t *const); + +class rdma_manager { +public: + rdma_manager() {} + + ~rdma_manager() { + delete qpManager_; + ClearAllRegisterMRs(); + tsdOpened_ = false; + raInitialized_ = false; + deviceIpRetired_ = false; + storedRdmaHandle_ = nullptr; + } + + void ClearAllRegisterMRs() + { + for (auto it = registerMRS_.begin(); it != registerMRS_.end(); ++it) { + auto ret = DlHccpApi::RaDeregisterMR(rdmaHandle_, it->second.mrHandle); + if (ret != 0) { + SHM_LOG_ERROR("Unregister:" << (void *)(ptrdiff_t)it->first << " : " << it->second << " failed: " << ret); + } + } + registerMRS_.clear(); + } + + int OpenDevice(const TransportOptions &options) + { + int32_t deviceId = options.dev_id; + int32_t logicDeviceId = options.logic_dev_id; + deviceId_ = static_cast(logicDeviceId); + rankId_ = options.rankId; + rankCount_ = options.rankCount; + role_ = options.role; + auto port = options.nic; + if (port < 0 || port > 65536) { + SHM_LOG_ERROR("Failed to parse nic info, nic = " << options.nic); + } + devicePort_ = static_cast(port); + DlHccpApi::LoadLibrary(); + + if (!PrepareOpenDevice(deviceId, rankCount_, deviceIp_, rdmaHandle_, deviceId_)) { + SHM_LOG_ERROR("PrepareOpenDevice failed."); + return -1; + } + + SHM_LOG_INFO("ip = " << inet_ntoa(deviceIp_) << ", port = " << devicePort_); + + sockaddr_in deviceAddr; + deviceAddr.sin_family = AF_INET; + deviceAddr.sin_addr = deviceIp_; + deviceAddr.sin_port = devicePort_; + qpManager_ = new DeviceQpManager(deviceId_, rankId_, rankCount_, deviceAddr, HYBM_ROLE_PEER); + + return 0; + } + + void* GetQPInfoAddr() { + return qpManager_->GetQpInfoAddress(); + } + + in_addr GetDeviceIP() { + return deviceIp_; + } + + Result RegisterMemoryRegion(const TransportMemoryRegion &mr) + { + void *mrHandle = nullptr; + HccpMrInfo info{}; + info.addr = (void *)(ptrdiff_t)mr.addr; + info.size = mr.size; + info.access = mr.access; + auto ret = DlHccpApi::RaRegisterMR(rdmaHandle_, &info, mrHandle); + if (ret != 0) { + SHM_LOG_ERROR("register MR=" << mr << " failed: " << ret); + return SHMEM_INNER_ERROR; + } + + RegMemResult result{mr.addr, mr.size, mrHandle, info.lkey, info.rkey}; + localMR_ = result; + SHM_LOG_DEBUG("register MR result=" << result); + + registerMRS_.emplace(mr.addr, result); + ret = qpManager_->SetLocalMemories(registerMRS_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("qp manager set mr failed: " << ret); + return ret; + } + return 0; + } + + Result UnregisterMemoryRegion(uint64_t addr) + { + auto pos = registerMRS_.find(addr); + if (pos == registerMRS_.end()) { + SHM_LOG_ERROR("input address not register!"); + return SHMEM_INVALID_PARAM; + } + + auto ret = DlHccpApi::RaDeregisterMR(rdmaHandle_, pos->second.mrHandle); + if (ret != 0) { + SHM_LOG_ERROR("Unregister MR addr failed: " << ret); + return SHMEM_INNER_ERROR; + } + + registerMRS_.erase(pos); + return 0; + } + + RegMemResult GetLocalMR() { + return localMR_; + } + + Result Prepare(const HybmTransPrepareOptions &options) + { + SHM_LOG_DEBUG("RdmaTransportManager Prepare with : " << options); + int ret; + if ((ret = CheckPrepareOptions(options)) != 0) { + return ret; + } + + sockaddr_in deviceNetwork; + std::unordered_map rankInfo; + for (auto it = options.options.begin(); it != options.options.end(); ++it) { + ret = ipPortStringToSockaddr(it->second.nic, deviceNetwork); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("parse networks[" << it->first << "]=" << it->second.nic << " failed: " << ret); + return SHMEM_INVALID_PARAM; + } + + rankInfo.emplace(it->first, ConnectRankInfo{it->second.role, deviceNetwork, it->second.mr}); + } + + ret = qpManager_->SetRemoteRankInfo(rankInfo); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("qp manager set remote rank info failed: " << ret); + return ret; + } + + ret = qpManager_->Startup(rdmaHandle_); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("qp manager startup failed: " << ret); + return ret; + } + + return SHMEM_SUCCESS; + } + + Result Connect() + { + auto ret = AsyncConnect(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("AsyncConnect() failed: " << ret); + return ret; + } + + ret = WaitForConnected(-1L); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("WaitForConnected(-1) failed: " << ret); + return ret; + } + + return SHMEM_SUCCESS; + } +private: + bool OpenTsd(uint32_t deviceId, uint32_t rankCount) + { + if (tsdOpened_) { + SHM_LOG_INFO("tsd already opened."); + return true; + } + + auto res = DlHccpApi::TsdOpen(deviceId, rankCount); + if (res != 0) { + SHM_LOG_ERROR("TsdOpen for (deviceId=" << deviceId << ", rankCount=" << rankCount << ") failed: " << res); + return false; + } + + SHM_LOG_DEBUG("open tsd for device id: " << deviceId << ", rank count: " << rankCount << " success."); + tsdOpened_ = true; + return true; + } + + bool RaInit(uint32_t deviceId) + { + if (raInitialized_) { + SHM_LOG_INFO("ra already initialized."); + return true; + } + + HccpRaInitConfig initConfig{}; + initConfig.phyId = deviceId; + initConfig.nicPosition = NETWORK_OFFLINE; + initConfig.hdcType = 6; // HDC_SERVICE_TYPE_RDMA = 6 + SHM_LOG_DEBUG("RaInit=" << initConfig); + auto ret = DlHccpApi::RaInit(initConfig); + if (ret != 0) { + SHM_LOG_ERROR("Hccp Init RA failed: " << ret); + return false; + } + + SHM_LOG_DEBUG("ra init for device id: " << deviceId << " success."); + raInitialized_ = true; + return true; + } + + bool RetireDeviceIp(uint32_t deviceId, in_addr &deviceIp) + { + static in_addr retiredIp{}; + + if (deviceIpRetired_) { + SHM_LOG_INFO("device ip already retired : " << inet_ntoa(retiredIp)); + deviceIp = retiredIp; + return true; + } + + uint32_t count = 0; + std::vector infos; + + HccpRaGetIfAttr config; + config.phyId = deviceId; + config.nicPosition = NETWORK_OFFLINE; + config.isAll = true; + + auto ret = DlHccpApi::RaGetIfNum(config, count); + if (ret != 0 || count == 0) { + SHM_LOG_ERROR("get interface count failed: " << ret << ", count: " << count); + return false; + } + + infos.resize(count); + ret = DlHccpApi::RaGetIfAddrs(config, infos.data(), count); + if (ret != 0) { + SHM_LOG_ERROR("get interface information failed: " << ret); + return false; + } + + for (auto &info : infos) { + if (info.family == AF_INET) { + deviceIp = retiredIp = info.ifaddr.ip.addr; + deviceIpRetired_ = true; + SHM_LOG_DEBUG("retire device ip success : " << inet_ntoa(deviceIp)); + return true; + } + } + + SHM_LOG_ERROR("not found network device of AF_INET on NPU."); + return false; + } + + bool RaRdevInit(uint32_t deviceId, in_addr deviceIp, void *&rdmaHandle) + { + if (storedRdmaHandle_ != nullptr) { + SHM_LOG_INFO("ra rdev already initialized."); + rdmaHandle = storedRdmaHandle_; + return true; + } + + HccpRdevInitInfo info{}; + HccpRdev rdev{}; + + info.mode = NETWORK_OFFLINE; + info.notifyType = NOTIFY; + info.enabled2mbLite = true; + rdev.phyId = deviceId; + rdev.family = AF_INET; + rdev.localIp.addr = deviceIp; + SHM_LOG_DEBUG("RaRdevInitV2, info=" << info << "rdev=" << rdev); + auto ret = DlHccpApi::RaRdevInitV2(info, rdev, rdmaHandle); + if (ret != 0) { + SHM_LOG_ERROR("Hccp Init RDev failed: " << ret); + return false; + } + + storedRdmaHandle_ = rdmaHandle; + SHM_LOG_INFO("initialize RDev success."); + return true; + } + + bool PrepareOpenDevice(uint32_t device, uint32_t rankCount, in_addr &deviceIp, void *&rdmaHandle, uint32_t logicDeviceId) + { + // If can get rdmaHanle, maybe the device has beed opened, can try get rdmaHanle directly. + if (DlHccpApi::RaRdevGetHandle(device, rdmaHandle) == 0) { + if (rdmaHandle != nullptr) { + if (!RetireDeviceIp(device, deviceIp)) { + SHM_LOG_ERROR("RetireDeviceIp failed."); + return false; + } + SHM_LOG_DEBUG("Had prepared device and get rdmaHandle success."); + return true; + } + SHM_LOG_INFO("Had prepared device, but RdmaHadle is null, need init again."); + } + if (!OpenTsd(device, rankCount)) { + SHM_LOG_ERROR("open tsd failed."); + return false; + } + + if (!RaInit(logicDeviceId)) { + SHM_LOG_ERROR("RaInit failed."); + return false; + } + + if (!RetireDeviceIp(logicDeviceId, deviceIp)) { + SHM_LOG_ERROR("RetireDeviceIp failed."); + return false; + } + + if (!RaRdevInit(logicDeviceId, deviceIp, rdmaHandle)) { + SHM_LOG_ERROR("RaRdevInit failed."); + return false; + } + return true; + } + + Result AsyncConnect() + { + return SHMEM_SUCCESS; + } + + Result WaitForConnected(int64_t timeoutNs) + { + if (qpManager_ == nullptr) { + SHM_LOG_ERROR("server side not listen!"); + return SHMEM_INNER_ERROR; + } + + auto ret = qpManager_->WaitingConnectionReady(); + if (ret != SHMEM_SUCCESS) { + SHM_LOG_ERROR("wait for server side connected on device failed: " << ret); + return ret; + } + + return SHMEM_SUCCESS; + } + + int CheckPrepareOptions(const HybmTransPrepareOptions &options) + { + if (role_ != HYBM_ROLE_PEER) { + SHM_LOG_INFO("transport role: " << role_ << " check options passed."); + return SHMEM_SUCCESS; + } + + if (options.options.size() > rankCount_) { + SHM_LOG_ERROR("options size():" << options.options.size() << " larger than rank count: " << rankCount_); + return SHMEM_INVALID_PARAM; + } + + if (options.options.find(rankId_) == options.options.end()) { + SHM_LOG_ERROR("options not contains self rankId: " << rankId_); + return SHMEM_INVALID_PARAM; + } + + for (auto it = options.options.begin(); it != options.options.end(); ++it) { + if (it->first >= rankCount_) { + SHM_LOG_ERROR("input options of nics contains rankId:" << it->first << ", rank count: " << rankCount_); + return SHMEM_INVALID_PARAM; + } + } + + return SHMEM_SUCCESS; + } + + Result ipPortStringToSockaddr(const std::string& ip_port_str, sockaddr_in& addr) { + std::memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + + size_t colon_pos = ip_port_str.find(':'); + if (colon_pos == std::string::npos || + colon_pos == 0 || + colon_pos == ip_port_str.length() - 1) { + SHM_LOG_ERROR("format mismatch"); + return SHMEM_INNER_ERROR; + } + + std::string ip_str = ip_port_str.substr(0, colon_pos); + std::string port_str = ip_port_str.substr(colon_pos + 1); + + if (port_str.empty()) { + SHM_LOG_ERROR("Port not available!"); + return SHMEM_INNER_ERROR; + } + + for (char c : port_str) { + if (!std::isdigit(static_cast(c))) { + SHM_LOG_ERROR("Port contains non-digit characters!"); + return SHMEM_INNER_ERROR; + } + } + + char* endptr; + unsigned long port = std::strtoul(port_str.c_str(), &endptr, 10); + + if (endptr == port_str.c_str() || *endptr != '\0' || + port == 0 || port > 65535) { + SHM_LOG_ERROR("Port out of range!"); + return SHMEM_INNER_ERROR; + } + + addr.sin_port = htons(static_cast(port)); + + // Transform IP address + if (inet_pton(AF_INET, ip_str.c_str(), &addr.sin_addr) != 1) { + SHM_LOG_ERROR("IP address invalid!"); + return SHMEM_INNER_ERROR; + } + + return SHMEM_SUCCESS; + } + + uint32_t rankId_{0}; + uint32_t rankCount_{1}; + uint32_t deviceId_{0}; + hybm_role_type role_{HYBM_ROLE_PEER}; + in_addr deviceIp_{0}; + uint16_t devicePort_{0}; + void *rdmaHandle_{nullptr}; + void *storedRdmaHandle_{nullptr}; + bool tsdOpened_{0}; + bool raInitialized_{0}; + bool deviceIpRetired_{0}; + DeviceQpManager* qpManager_; + RegMemResult localMR_; + MemoryRegionMap registerMRS_; +}; + +#endif // RDMA_MANAGER_H \ No newline at end of file diff --git a/src/modules/transport/shmemi_mte.cpp b/src/modules/transport/shmemi_mte.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9e952b6c84189cad2203dde9d22bc3de92d86fc7 --- /dev/null +++ b/src/modules/transport/shmemi_mte.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include +#include +#include +#include "mem/shmemi_heap.h" +#include "shmemi_host_common.h" +#include "internal/host_device/shmemi_types.h" +#include "transport/shmemi_transport.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int shmemi_mte_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport_pe_info_t *my_info, shmemi_transport *t) { + // origin access set to 0. + *access = 0; + + auto sName = aclrtGetSocName(); + std::string socName{sName}; + if (socName.find("Ascend910B") != std::string::npos) { // Ascend910B Topo + // Check Server ID. + if (my_info->server_id != peer_info->server_id) { + *access = 0; + return 0; + } + + // In same node, Check HCCS Connectivity. + int64_t hccs_connected = -1; + SHMEM_CHECK_RET(rtGetPairDevicesInfo(my_info->pe, peer_info->dev_id, 0, &hccs_connected)); + + // In 910B, Flag 0 -> HCCS. + const static int SELF_FLAG = 0; + if (hccs_connected == SELF_FLAG) { + *access = 1; + } + } else if (socName.find("Ascend910_93") != std::string::npos) { // Ascend910_93 Topo + // In same node, Check HCCS Connectivity. + int64_t hccs_connected = -1; + /* TODO: This func now doesn't support 910_93 multiNode HCCS Check. Only Check in the same Node. */ + SHMEM_CHECK_RET(rtGetPairDevicesInfo(my_info->pe, peer_info->dev_id, 0, &hccs_connected)); + + // In 910_93, Flag 0 -> SELF, 5 -> SIO, 6 -> HCCS. + const static int SELF_FLAG = 0; + const static int SIO_FLAG = 5; + const static int HCCS_FLAG = 6; + if (hccs_connected == SELF_FLAG || hccs_connected == SIO_FLAG || hccs_connected == HCCS_FLAG) { + *access = 1; + } + } + + return 0; +} + +int shmemi_mte_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs, shmemi_device_host_state_t *g_state) { + // EnablePeerAccess + for (int i = 0; i < num_selected_devs; i++) { + SHMEM_CHECK_RET(aclrtDeviceEnablePeerAccess(selected_dev_ids[i], 0)); + } + return 0; +} + +int shmemi_mte_finalize(shmemi_transport *t, shmemi_device_host_state_t *g_state) { + return 0; +} + +// control plane +int shmemi_mte_init(shmemi_transport_t *t, shmemi_device_host_state_t *g_state) { + t->can_access_peer = shmemi_mte_can_access_peer; + t->connect_peers = shmemi_mte_connect_peers; + t->finalize = shmemi_mte_finalize; + + return 0; +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/src/modules/transport/shmemi_rdma.cpp b/src/modules/transport/shmemi_rdma.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2792506e8da1a0343c1b3365b7c03bf4bbc54782 --- /dev/null +++ b/src/modules/transport/shmemi_rdma.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include "host/shmem_host_def.h" +#include "rdma/rdma_manager.h" +#include "common/shmemi_host_types.h" +#include "common/shmemi_logger.h" +#include "internal/host_device/shmemi_types.h" + +#ifdef __cplusplus +extern "C" { +#endif +static rdma_manager* manager; + +int shmemi_rdma_can_access_peer(int *access, shmemi_transport_pe_info_t *peer_info, shmemi_transport_pe_info_t *my_info, shmemi_transport *t) { + if (peer_info->pe == my_info->pe) { + *access = 0; + } else { + *access = 1; + } + return 0; +} + +int shmemi_rdma_connect_peers(shmemi_transport *t, int *selected_dev_ids, int num_selected_devs, shmemi_device_host_state_t *state) { + auto local_device_ip = manager->GetDeviceIP(); + SHM_LOG_INFO("local ip = " << inet_ntoa(local_device_ip)); + std::vector device_ips(state->npes); + SHMEM_CHECK_RET((g_boot_handle.is_bootstraped != true), "boot_handle not bootstraped, Please check if the method call occurs before initialization or after finalization.", SHMEM_BOOTSTRAP_ERROR); + g_boot_handle.allgather(&local_device_ip, device_ips.data(), sizeof(in_addr), &g_boot_handle); + g_boot_handle.barrier(&g_boot_handle); + for (int i = 0; i < state->npes; i++) { + SHM_LOG_INFO("get rank " << i << ", device ip = " << inet_ntoa(device_ips[i])); + } + + auto local_mr = manager->GetLocalMR(); + SHM_LOG_INFO("local mr = " << local_mr); + std::vector mrs(state->npes); + g_boot_handle.allgather(&local_mr, mrs.data(), sizeof(RegMemResult), &g_boot_handle); + for (int i = 0; i < state->npes; i++) { + state->host_rdma_heap_base[i] = reinterpret_cast(mrs[i].address); + SHM_LOG_INFO("get rank " << i << ", mr info = " << mrs[i]); + } + + HybmTransPrepareOptions TransPrepareOp; + for (int i = 0; i < state->npes; i++) { + TransPrepareOp.options[i].nic = std::string(inet_ntoa(device_ips[i])) + ":4647"; + TransPrepareOp.options[i].mr = mrs[i]; + } + manager->Prepare(TransPrepareOp); + manager->Connect(); + state->qp_info = reinterpret_cast(manager->GetQPInfoAddr()); + return 0; +} + +int shmemi_rdma_finalize(shmemi_transport *t, shmemi_device_host_state_t *state) { + delete manager; + return 0; +} + +int shmemi_rdma_init(shmemi_transport *t, shmemi_device_host_state_t *state) { + manager = new rdma_manager; + TransportOptions options; + options.rankId = state->mype; + options.rankCount = state->npes; + options.protocol = 7; + options.nic = 10002; + options.dev_id = t->dev_id; + options.logic_dev_id = t->logical_dev_id; + manager->OpenDevice(options); + + TransportMemoryRegion mr; + mr.addr = reinterpret_cast(state->heap_base); + mr.size = reinterpret_cast(state->heap_size); + manager->RegisterMemoryRegion(mr); + t->can_access_peer = shmemi_rdma_can_access_peer; + t->connect_peers = shmemi_rdma_connect_peers; + t->finalize = shmemi_rdma_finalize; + return 0; +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt deleted file mode 100644 index 500174066a9163b96b33b71d7b69d63c91d7f699..0000000000000000000000000000000000000000 --- a/src/transport/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/adaptor/CMakeLists.txt b/src/transport/adaptor/CMakeLists.txt deleted file mode 100644 index 500174066a9163b96b33b71d7b69d63c91d7f699..0000000000000000000000000000000000000000 --- a/src/transport/adaptor/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/adaptor/hccs/CMakeLists.txt b/src/transport/adaptor/hccs/CMakeLists.txt deleted file mode 100644 index 500174066a9163b96b33b71d7b69d63c91d7f699..0000000000000000000000000000000000000000 --- a/src/transport/adaptor/hccs/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/adaptor/mte/CMakeLists.txt b/src/transport/adaptor/mte/CMakeLists.txt deleted file mode 100644 index 500174066a9163b96b33b71d7b69d63c91d7f699..0000000000000000000000000000000000000000 --- a/src/transport/adaptor/mte/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. \ No newline at end of file diff --git a/src/transport/include/.gitkeep b/src/transport/include/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tests/fuzz/device/mem/shmem_ptr_kernel.cpp b/tests/fuzz/device/mem/shmem_ptr_kernel.cpp index 498d07c3fefe3fca01898803145595a9b7575307..7ca9f7ec3d05b5e67a02528ac60b9faa36a73834 100644 --- a/tests/fuzz/device/mem/shmem_ptr_kernel.cpp +++ b/tests/fuzz/device/mem/shmem_ptr_kernel.cpp @@ -13,8 +13,8 @@ public: __aicore__ inline void Init(GM_ADDR gva) { gva_gm = (__gm__ int *)gva; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/fuzz/device/sync/barrier_kernel.cpp b/tests/fuzz/device/sync/barrier_kernel.cpp index be13d7c5c6b3c6086c2243a8486ce42e9a234473..9915458ce4b20cea7a2f17fdc0452b644debeeb3 100644 --- a/tests/fuzz/device/sync/barrier_kernel.cpp +++ b/tests/fuzz/device/sync/barrier_kernel.cpp @@ -25,7 +25,7 @@ extern "C" SHMEM_GLOBAL void increase(uint64_t config, GM_ADDR addr, int rank_id uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmem_barrier_all(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmem_barrier_all(); #endif @@ -39,7 +39,7 @@ extern "C" SHMEM_GLOBAL void increase_vec(uint64_t config, GM_ADDR addr, int ran uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmemx_barrier_all_vec(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmemx_barrier_all_vec(); #endif @@ -61,7 +61,7 @@ extern "C" SHMEM_GLOBAL void increase_odd_team(uint64_t config, GM_ADDR addr, in shmem_barrier(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmem_barrier(team_id); @@ -78,7 +78,7 @@ extern "C" SHMEM_GLOBAL void increase_vec_odd_team(uint64_t config, GM_ADDR addr shmemx_barrier_vec(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmemx_barrier_vec(team_id); diff --git a/tests/fuzz/device/sync/order_kernel.cpp b/tests/fuzz/device/sync/order_kernel.cpp index 3bcd8fa15a25de7fc32c97cde63027f676ded75a..73e28ad495fbc28670e92ef9c7ffd0c848fbd0ad 100644 --- a/tests/fuzz/device/sync/order_kernel.cpp +++ b/tests/fuzz/device/sync/order_kernel.cpp @@ -24,7 +24,7 @@ extern "C" SHMEM_GLOBAL void quiet_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 32)); seen_b = shmemi_load(remote + 32); @@ -51,7 +51,7 @@ extern "C" SHMEM_GLOBAL void fence_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 16)); seen_b = shmemi_load(remote + 16); diff --git a/tests/fuzz/device/team/team_kernel.cpp b/tests/fuzz/device/team/team_kernel.cpp index 85002089145006f1e2aea348f50481dd6c0571d2..64e9de464f835411f6eb2fdeb082e555aadc8390 100644 --- a/tests/fuzz/device/team/team_kernel.cpp +++ b/tests/fuzz/device/team/team_kernel.cpp @@ -18,8 +18,8 @@ public: gva_gm = (__gm__ int *)gva; team_idx= team_id; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/fuzz/host/sync/barrier_host_fuzz.cpp b/tests/fuzz/host/sync/barrier_host_fuzz.cpp index 6f74daa24e308eb86fd630a1e0a24215ad351edb..700bb1a0543adabcb8c7ef94f4564a8b0164e443 100644 --- a/tests/fuzz/host/sync/barrier_host_fuzz.cpp +++ b/tests/fuzz/host/sync/barrier_host_fuzz.cpp @@ -73,7 +73,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_barrier_black_box_success) ASSERT_EQ(aclrtSynchronizeStream(scope.stream), ACL_SUCCESS); ASSERT_EQ(aclrtMemcpy(addr_host, size, addr_dev, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } ASSERT_EQ(aclrtFreeHost(addr_host), ACL_SUCCESS); @@ -111,7 +111,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_vec_barrier_black_box_success) ASSERT_EQ(aclrtMemcpy(addr_host_vec, size, addr_dev_vec, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } ASSERT_EQ(aclrtFreeHost(addr_host_vec), ACL_SUCCESS); @@ -156,7 +156,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_barrier_black_box_odd_team_success) ASSERT_EQ(aclrtSynchronizeStream(scope.stream), ACL_SUCCESS); ASSERT_EQ(aclrtMemcpy(addr_host, size, addr_dev, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } } @@ -205,7 +205,7 @@ TEST_F(ShmemSyncBarrierFuzz, shmem_vec_barrier_black_box_odd_team_success) ASSERT_EQ(aclrtMemcpy(addr_host_vec, size, addr_dev_vec, size, ACL_MEMCPY_DEVICE_TO_HOST), ACL_SUCCESS); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } } diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index cd370614ee6c79f80c08fa6ccd04fe9b02a467c5..e5b00e9e6179ec406e12090e6b3b070d3f5d975f 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -19,4 +19,11 @@ target_link_directories(shmem_unittest PRIVATE ${PROJECT_SOURCE_DIR}/install/memfabric_hybrid/lib ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ) +target_link_libraries( + shmem_unittest PRIVATE + -Wl,--no-as-needed + shmem + -Wl,--as-needed +) + target_link_libraries(shmem_unittest PRIVATE shmem_unittest_device gtest gcov mf_smem shmem_unittest_include) diff --git a/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp b/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp index 58a8934863ea38af2c5539f50db10d908b68e0b4..21354d0f1626811624a9f9bb946e111b133a867a 100644 --- a/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp +++ b/tests/unittest/device/mem/atomic_add/atomic_add_kernel.cpp @@ -17,8 +17,8 @@ constexpr uint64_t MESSAGE_SIZE = 64; extern "C" __global__ __aicore__ void test_atomic_add_##NAME##_kernel(GM_ADDR gva, uint64_t config) \ { \ shmemx_set_ffts_config(config); \ - int64_t rank = smem_shm_get_global_rank(); \ - int64_t rank_size = smem_shm_get_global_rank_size(); \ + int64_t rank = shmem_my_pe(); \ + int64_t rank_size = shmem_n_pes(); \ GM_ADDR dst_addr; \ \ for (int64_t peer = 0; peer < rank_size; peer++) { \ diff --git a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp index 3fcf60c5649fa6c8ae5b28d96884a962827d0b3b..1b72a449363aa08c04e8836e0b880373998db28a 100644 --- a/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp +++ b/tests/unittest/device/mem/rdma_mem/rdma_mem_kernel.cpp @@ -20,8 +20,8 @@ extern "C" __global__ __aicore__ void RDMAGetTestLowLevel(GM_ADDR gva, uint64_t pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR dest_addr; for (int64_t peer = 0; peer < rank_size; peer++) { @@ -46,8 +46,8 @@ extern "C" __global__ __aicore__ void RDMAPutTestLowLevel(GM_ADDR gva, uint64_t pipe.InitBuffer(buf, UB_ALIGN_SIZE * 2); AscendC::LocalTensor ubLocal = buf.GetWithOffset(UB_ALIGN_SIZE * 2, 0); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR src_addr; for (int64_t peer = 0; peer < rank_size; peer++) { @@ -67,8 +67,8 @@ void test_rdma_put_low_level(uint32_t block_dim, void* stream, uint8_t* gva, uin extern "C" __global__ __aicore__ void RDMAGetTestHighLevel(GM_ADDR gva, uint64_t config) { shmemx_set_ffts_config(config); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR dest_addr; for (int64_t peer = 0; peer < rank_size; peer++) { @@ -89,8 +89,8 @@ void test_rdma_get_high_level(uint32_t block_dim, void* stream, uint8_t* gva, ui extern "C" __global__ __aicore__ void RDMAPutTestHighLevel(GM_ADDR gva, uint64_t config) { shmemx_set_ffts_config(config); - int64_t rank = smem_shm_get_global_rank(); - int64_t rank_size = smem_shm_get_global_rank_size(); + int64_t rank = shmem_my_pe(); + int64_t rank_size = shmem_n_pes(); GM_ADDR src_addr; for (int64_t peer = 0; peer < rank_size; peer++) { diff --git a/tests/unittest/device/mem/shmem_ptr_kernel.cpp b/tests/unittest/device/mem/shmem_ptr_kernel.cpp index 3cba7f57affbbb952ec27ef6e4f5fae664200a5a..1a76666a0f72c26f927ae9c5603e4b04e812a751 100644 --- a/tests/unittest/device/mem/shmem_ptr_kernel.cpp +++ b/tests/unittest/device/mem/shmem_ptr_kernel.cpp @@ -16,8 +16,8 @@ public: __aicore__ inline void Init(GM_ADDR gva) { gva_gm = (__gm__ int *)gva; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + rank = shmem_my_pe(); + rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp b/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp index 376616c7904b01036c8847aca3c428fb5bb852cc..af8940ff3d6f7991351da16985ac9edcaf801460 100644 --- a/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp +++ b/tests/unittest/device/mem/ub_mem/ub_mem_kernel.cpp @@ -8,8 +8,6 @@ * See LICENSE in the root of the software repository for the full text of the License. */ #include "kernel_operator.h" -#include "smem_shm_aicore_base_api.h" - #include "shmem_api.h" #include "unittest/utils/func_type.h" diff --git a/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp b/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp index 954c19245f571596f7e1cf1aa847d577434af528..ab620b9ca276921f9d60795ce5e91e885b815331 100644 --- a/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp +++ b/tests/unittest/device/mem/ub_non_contiguous/ub_non_contiguous_kernel.cpp @@ -8,8 +8,6 @@ * See LICENSE in the root of the software repository for the full text of the License. */ #include "kernel_operator.h" -#include "smem_shm_aicore_base_api.h" - #include "shmem_api.h" #include "unittest/utils/func_type.h" diff --git a/tests/unittest/device/sync/barrier/barrier_kernel.cpp b/tests/unittest/device/sync/barrier/barrier_kernel.cpp index 2fda972954b764b0390c018035ba0ee4c49434b4..49b09b594d226efbb8935e364e7518b57bab65fd 100644 --- a/tests/unittest/device/sync/barrier/barrier_kernel.cpp +++ b/tests/unittest/device/sync/barrier/barrier_kernel.cpp @@ -23,7 +23,7 @@ extern "C" SHMEM_GLOBAL void increase(uint64_t config, GM_ADDR addr, int rank_id uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmem_barrier_all(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmem_barrier_all(); #endif @@ -36,7 +36,7 @@ extern "C" SHMEM_GLOBAL void increase_vec(uint64_t config, GM_ADDR addr, int ran uint64_t val = shmemi_load((__gm__ uint64_t *)addr); shmemx_barrier_all_vec(); - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 1) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 1) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); shmemx_barrier_all_vec(); #endif @@ -57,7 +57,7 @@ extern "C" SHMEM_GLOBAL void increase_odd_team(uint64_t config, GM_ADDR addr, in shmem_barrier(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmem_barrier(team_id); @@ -73,7 +73,7 @@ extern "C" SHMEM_GLOBAL void increase_vec_odd_team(uint64_t config, GM_ADDR addr shmemx_barrier_vec(team_id); if (rank_id & 1) { - GM_ADDR remote = shmemi_ptr(addr, (rank_id + 2) % rank_size); + GM_ADDR remote = (GM_ADDR)shmem_ptr(addr, (rank_id + 2) % rank_size); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmemx_barrier_vec(team_id); @@ -104,7 +104,7 @@ extern "C" SHMEM_GLOBAL void partial_increase(uint64_t config, int team_pe = shmem_team_my_pe(team_id); int peer = (team_pe + 1) % count; if (team_pe < count) { - GM_ADDR remote = shmemi_ptr(addr, peer * stride + start); + GM_ADDR remote = shmem_ptr(addr, peer * stride + start); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmemx_partial_barrier(team_id, pes, count); @@ -127,7 +127,7 @@ extern "C" SHMEM_GLOBAL void partial_increase_vec(uint64_t config, int team_pe = shmem_team_my_pe(team_id); int peer = (team_pe + 1) % count; if (team_pe < count) { - GM_ADDR remote = shmemi_ptr(addr, peer * stride + start); + GM_ADDR remote = shmem_ptr(addr, peer * stride + start); shmemi_store((__gm__ uint64_t *)remote, val + 1); } shmemx_partial_barrier_vec(team_id, pes, count); diff --git a/tests/unittest/device/sync/order/order_kernel.cpp b/tests/unittest/device/sync/order/order_kernel.cpp index 391f77c6d24820af91c64dc0437d75c6aaee53aa..a6138d8f6912e97501c68f0a853c1fd99d57ca6f 100644 --- a/tests/unittest/device/sync/order/order_kernel.cpp +++ b/tests/unittest/device/sync/order/order_kernel.cpp @@ -22,7 +22,7 @@ extern "C" SHMEM_GLOBAL void quiet_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 32)); seen_b = shmemi_load(remote + 32); @@ -48,7 +48,7 @@ extern "C" SHMEM_GLOBAL void fence_order(uint64_t config, GM_ADDR addr, int rank if (rank_id == 1) { uint64_t seen_b; - __gm__ uint64_t *remote = shmemi_ptr(base, 0); + __gm__ uint64_t *remote = (__gm__ uint64_t *)shmem_ptr(base, 0); do { dcci_cacheline((__gm__ uint8_t *)(remote + 16)); seen_b = shmemi_load(remote + 16); diff --git a/tests/unittest/device/team/team/team_kernel.cpp b/tests/unittest/device/team/team/team_kernel.cpp index 47f41d5a42dac23ef0760cd197cb8cf6ad67a2a7..462e3cb1a54117286bf3e47de6c6f54d8ecbc680 100644 --- a/tests/unittest/device/team/team/team_kernel.cpp +++ b/tests/unittest/device/team/team/team_kernel.cpp @@ -19,8 +19,8 @@ public: gva_gm = (__gm__ int *)gva; team_idx= team_id; - rank = smem_shm_get_global_rank(); - rank_size = smem_shm_get_global_rank_size(); + rank = shmem_my_pe(); + rank_size = shmem_n_pes(); } __aicore__ inline void Process() { diff --git a/tests/unittest/host/init/init_host_test.cpp b/tests/unittest/host/init/init_host_test.cpp index c4b59298c8726bce33f21910bfae698f8320d5fb..78dc9c21b4458e9320d23eaee0014537298b311a 100644 --- a/tests/unittest/host/init/init_host_test.cpp +++ b/tests/unittest/host/init/init_host_test.cpp @@ -29,55 +29,21 @@ void test_shmem_init(int rank_id, int n_ranks, uint64_t local_mem_size) int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); - status = shmem_finalize(); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - -void test_shmem_init_attr(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - shmem_init_attr_t *attributes = new shmem_init_attr_t{ - rank_id, n_ranks, {}, local_mem_size, {0, SHMEM_DATA_OP_MTE, 120, 120, 120}}; - std::copy_n(test_global_ipport, SHMEM_MAX_IP_PORT_LEN, attributes->ip_port); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); + EXPECT_EQ(g_state.mype, rank_id); + EXPECT_EQ(g_state.npes, n_ranks); + EXPECT_NE(g_state.heap_base, nullptr); + EXPECT_NE(g_state.host_p2p_heap_base[rank_id], nullptr); + EXPECT_EQ(g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); + EXPECT_NE(g_state.team_pools[0], nullptr); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); status = shmem_finalize(); - delete attributes; EXPECT_EQ(status, SHMEM_SUCCESS); EXPECT_EQ(aclrtResetDevice(device_id), 0); EXPECT_EQ(aclFinalize(), 0); @@ -93,10 +59,13 @@ void test_shmem_init_invalid_rank_id(int rank_id, int n_ranks, uint64_t local_me int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); + shmem_init_attr_t *attributes; shmem_set_attr(erank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -114,14 +83,14 @@ void test_shmem_init_invalid_n_ranks(int rank_id, int n_ranks, uint64_t local_me int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); + shmemx_uniqueid_t uid; shmem_init_attr_t *attributes; - status = shmem_set_attr(rank_id, en_ranks, local_mem_size, test_global_ipport, &attributes); - EXPECT_EQ(status, SHMEM_INVALID_VALUE); - status = shmem_init_attr(attributes); - EXPECT_TRUE(status != 0); - attributes->n_ranks = en_ranks; - status = shmem_init_attr(attributes); + shmem_set_attr(rank_id, en_ranks, local_mem_size, test_global_ipport, &attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -138,10 +107,12 @@ void test_shmem_init_rank_id_over_size(int rank_id, int n_ranks, uint64_t local_ int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_set_conf_store_tls(false, nullptr, 0); + shmem_init_attr_t *attributes; shmem_set_attr(rank_id + n_ranks, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_INVALID_PARAM); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); @@ -159,113 +130,14 @@ void test_shmem_init_zero_mem(int rank_id, int n_ranks, uint64_t local_mem_size) int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_INVALID_VALUE); - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - -void test_shmem_init_invalid_mem(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - // local_mem_size = invalid - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SMEM_ERROR); - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - -void test_shmem_init_set_config(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - - shmem_set_data_op_engine_type(attributes, SHMEM_DATA_OP_MTE); - shmem_set_timeout(attributes, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); - - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); - status = shmem_finalize(); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(aclrtResetDevice(device_id), 0); - EXPECT_EQ(aclFinalize(), 0); - if (::testing::Test::HasFailure()) { - exit(1); - } -} - -void test_shmem_global_exit(int rank_id, int n_ranks, uint64_t local_mem_size) -{ - uint32_t device_id = rank_id % test_gnpu_num + test_first_npu; - int status = SHMEM_SUCCESS; - EXPECT_EQ(aclInit(nullptr), 0); - EXPECT_EQ(status = aclrtSetDevice(device_id), 0); - status = shmem_set_conf_store_tls(false, nullptr, 0); - EXPECT_EQ(status, 0); - shmem_init_attr_t *attributes; - shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - - shmem_set_data_op_engine_type(attributes, SHMEM_DATA_OP_MTE); - shmem_set_timeout(attributes, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); - - shmem_set_conf_store_tls(false, nullptr, 0); - status = shmem_init_attr(attributes); - EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); - - EXPECT_EQ(shm::g_attr.option_attr.control_operation_timeout, shm::timeout); - EXPECT_EQ(shm::g_attr.option_attr.data_op_engine_type, SHMEM_DATA_OP_MTE); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + EXPECT_EQ(status, SHMEM_INVALID_VALUE); status = shmem_init_status(); - EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); - shmem_global_exit(0); + EXPECT_EQ(status, SHMEM_STATUS_NOT_INITIALIZED); EXPECT_EQ(aclrtResetDevice(device_id), 0); EXPECT_EQ(aclFinalize(), 0); if (::testing::Test::HasFailure()) { @@ -280,13 +152,6 @@ TEST(TestInitAPI, TestShmemInit) test_mutil_task(test_shmem_init, local_mem_size, process_count); } -TEST(TestInitAPI, TestShmemInitAttr) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - test_mutil_task(test_shmem_init_attr, local_mem_size, process_count); -} - TEST(TestInitAPI, TestShmemInitErrorInvalidRankId) { const int process_count = test_gnpu_num; @@ -315,20 +180,6 @@ TEST(TestInitAPI, TestShmemInitErrorZeroMem) test_mutil_task(test_shmem_init_zero_mem, local_mem_size, process_count); } -TEST(TestInitAPI, TestShmemInitErrorInvalidMem) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL; - test_mutil_task(test_shmem_init_invalid_mem, local_mem_size, process_count); -} - -TEST(TestInitAPI, TestSetConfig) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - test_mutil_task(test_shmem_init_set_config, local_mem_size, process_count); -} - TEST(TestInitAPI, TestInfoGetVersion) { int major = 0; @@ -365,18 +216,17 @@ TEST(TestInitAPI, TestInfoGetNameNull) EXPECT_EQ(input, nullptr); } -TEST(TestInitAPI, TestShmemGlobalExit) -{ - const int process_count = test_gnpu_num; - uint64_t local_mem_size = 1024UL * 1024UL * 1024; - test_mutil_task(test_shmem_global_exit, local_mem_size, process_count); -} - TEST(TestInitAPI, TestShmemSetLogLevel) { auto ret = shmem_set_log_level(shm::DEBUG_LEVEL); EXPECT_EQ(ret, 0); + char* original_log_level = NULL; + const char* env_val = getenv("SHMEM_LOG_LEVEL"); + if (env_val != NULL) { + original_log_level = strdup(env_val); + } + setenv("SHMEM_LOG_LEVEL", "DEBUG", 1); EXPECT_EQ(shmem_set_log_level(-1), 0); @@ -393,52 +243,9 @@ TEST(TestInitAPI, TestShmemSetLogLevel) EXPECT_EQ(shmem_set_log_level(-1), 0); unsetenv("SHMEM_LOG_LEVEL"); -} - -TEST(TestInitAPI, TestShmemSetExternLogger) -{ - auto ret = shmem_set_extern_logger(shm::logger_test_example); - EXPECT_EQ(ret, 0); -} - -TEST(TestInitAPI, TestShmemGetUniqueId) -{ - const char *ipInfo = std::getenv("SHMEM_UID_SOCK_IFNAM"); - if (ipInfo == nullptr) { - return; - } - - for (int i = 0; i < 10; i++) { - shmem_uniqueid_t uid; - int ret = shmem_get_uniqueid(&uid); - EXPECT_EQ(ret, SHMEM_SUCCESS); - - shmem_uniqueid_inner_t *innerUID = reinterpret_cast(&uid); - - // test bind ip:port again - int sockfd = ::socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) { - std::cout << "create socket failed" << std::endl; - return; - } - - int reuse = 1; - ::setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); - - struct sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = innerUID->addr.addr.addr4.sin_port; - addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); // 绑定 127.0.0.1 - - bool inUse = (::bind(sockfd, reinterpret_cast(&addr), sizeof(addr)) != 0); - if (inUse) { - auto errorNum = errno; - std::cout << "the address is in use" << errorNum << std::endl; - EXPECT_TRUE(false); - break; - } - - ::close(sockfd); - EXPECT_TRUE(true); + if (original_log_level != NULL) { + setenv("SHMEM_LOG_LEVEL", original_log_level, 1); + free(original_log_level); + original_log_level = NULL; } } \ No newline at end of file diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp index c31aeefcd96008373cf0f0104bdd99b599f3de34..a3eb98d9b86b8dc260f0107c5d40d9ea99af95bf 100644 --- a/tests/unittest/host/main_test.cpp +++ b/tests/unittest/host/main_test.cpp @@ -12,6 +12,7 @@ #include #include "acl/acl.h" #include "shmem_api.h" +#include "shmemi_host_common.h" #include "unittest_main_test.h" int test_global_ranks; @@ -34,12 +35,13 @@ void test_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtStream *s EXPECT_EQ(status = aclrtSetDevice(device_id), 0); aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - shmem_init_attr_t* attributes; + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - status = shmem_init_attr(attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + EXPECT_EQ(status, 0); *st = stream; } @@ -58,13 +60,15 @@ int32_t test_rdma_init(int rank_id, int n_ranks, uint64_t local_mem_size, aclrtS EXPECT_EQ(status = aclrtSetDevice(device_id), 0); aclrtStream stream = nullptr; EXPECT_EQ(status = aclrtCreateStream(&stream), 0); - EXPECT_EQ(status = shmem_set_conf_store_tls(false, nullptr, 0), 0); - shmem_init_attr_t* attributes; + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); + attributes->option_attr.data_op_engine_type = SHMEM_DATA_OP_ROCE; - status = shmem_init_attr(attributes); + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); + + EXPECT_EQ(status, 0); *st = stream; return status; } @@ -95,6 +99,8 @@ void test_mutil_task(std::function func, uint64_t loca waitpid(pids[i], &status[i], 0); if (WIFEXITED(status[i]) && WEXITSTATUS(status[i]) != 0) { FAIL(); + } else if (WIFSIGNALED(status[i])) { + FAIL(); } } } diff --git a/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp b/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp index 1686c1135ccf5097e559f4ad7a61ef279430eb87..db9c05adcf651aa07a0c89fe82c93b039bbe85f8 100644 --- a/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp +++ b/tests/unittest/host/mem/atomic_add/atomic_add_host_test.cpp @@ -77,8 +77,13 @@ SHMEM_ATOMIC_ADD_FUNC_TYPE_HOST(TEST_SHMEM_ATOMIC_ADD_HOST); aclrtStream stream; \ test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ - test_atomic_add_##NAME##_host(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ - TEST_CLEANUP_AND_EXIT(rank_id, stream, device_id); \ + test_atomic_add_##NAME##_host(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ + std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ + test_finalize(stream, device_id); \ + if (::testing::Test::HasFailure()) \ + { \ + exit(1); \ + } \ } SHMEM_ATOMIC_ADD_FUNC_TYPE_HOST(TEST_SHMEM_ATOMIC_ADD); diff --git a/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp b/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp index 5fbf4473129a4ec8ef0039a3c4c95e06328513a3..b22516a011bd75bf552888ec5ca705d29231dff3 100644 --- a/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp +++ b/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp @@ -76,7 +76,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - test_##NAME##_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_##NAME##_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp b/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp index 98ed8bbaaa06f2159ef8ffc99164606cd514d947..b56186f6f1bb61b5093868f1d12ea9a80972af9b 100644 --- a/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp +++ b/tests/unittest/host/mem/gm_mem_disable_L2/gm_mem_disable_L2_host_test.cpp @@ -74,7 +74,7 @@ void test_shmemx_mte_mem(int rank_id, int n_ranks, uint64_t local_mem_size) test_init(rank_id, n_ranks, local_mem_size, &stream); ASSERT_NE(stream, nullptr); - test_shmemx_mte_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); + test_shmemx_mte_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; test_finalize(stream, device_id); if (::testing::Test::HasFailure()) { diff --git a/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp b/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp index 572f885a2990c07754bc8c2edf5efaa1aebe7508..dfa22eaddf840cb60ce0b7a4408532bc69bcf863 100644 --- a/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp +++ b/tests/unittest/host/mem/gm_non_contiguous/gm_non_contiguous_host_test.cpp @@ -85,7 +85,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_NON_CONTIGUOUS_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - test_##NAME##_non_contiguous_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_##NAME##_non_contiguous_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp index 4b45d532fb2445232d5cef3d55a8847a209de37f..68ab11494589aa8041072a936c3b97776f61f563 100644 --- a/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp +++ b/tests/unittest/host/mem/rdma_mem/rdma_mem_host_test.cpp @@ -44,40 +44,40 @@ static void test_rdma_put_get(aclrtStream stream, uint8_t *gva, uint32_t rank_id } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_put_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_get_low_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_put_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); } ASSERT_EQ(aclrtMemcpy(gva, totalSize, inHost, totalSize, ACL_MEMCPY_HOST_TO_DEVICE), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); test_rdma_get_high_level(block_dim, stream, (uint8_t *)gva, shmemx_get_ffts_config()); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); ASSERT_EQ(aclrtMemcpy(outHost, totalSize, gva, totalSize, ACL_MEMCPY_DEVICE_TO_HOST), 0); for (uint32_t i = 0; i < rank_size; i++) { ASSERT_EQ(outHost[i * messageSize / sizeof(uint32_t)], i + rankOffset); diff --git a/tests/unittest/host/mem/shmem_host_get_stream_test.cpp b/tests/unittest/host/mem/shmem_host_get_stream_test.cpp index b36db8fa600a850745cae4486e39f08485c3d01f..d37340fc524f665de556bf118febcee8c6a3bd33 100644 --- a/tests/unittest/host/mem/shmem_host_get_stream_test.cpp +++ b/tests/unittest/host/mem/shmem_host_get_stream_test.cpp @@ -118,7 +118,7 @@ static void host_test_put_get_mem_stream(int rank_id, int rank_size, uint64_t lo void *ptr = shmem_malloc(1024); host_putmem(ptr, dev_ptr, rank_id, input_size); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); sleep(sleep_time); ASSERT_EQ(aclrtMemcpy(out.data(), total_size, ptr, total_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); @@ -136,7 +136,7 @@ static void host_test_put_get_mem_stream(int rank_id, int rank_size, uint64_t lo std::cout << std::endl; size_t ele_size = 16; host_test_getmem_stream((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_size, ele_size, stream); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); sleep(sleep_time); ASSERT_EQ(aclrtMemcpy(input.data(), input_size, dev_ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); diff --git a/tests/unittest/host/mem/shmem_host_heap_test.cpp b/tests/unittest/host/mem/shmem_host_heap_test.cpp index 4d6f4e20fe5cecdf3344183b79830c3d5f288205..7ac741e0472e6787e4a2b4ad41adc01aa59fcec6 100644 --- a/tests/unittest/host/mem/shmem_host_heap_test.cpp +++ b/tests/unittest/host/mem/shmem_host_heap_test.cpp @@ -27,17 +27,18 @@ protected: int status = SHMEM_SUCCESS; EXPECT_EQ(aclInit(nullptr), 0); EXPECT_EQ(status = aclrtSetDevice(device_id), 0); + shmem_init_attr_t *attributes; shmem_set_attr(rank_id, n_ranks, local_mem_size, test_global_ipport, &attributes); - status = shmem_init_attr(attributes); + + status = shmem_init_attr(SHMEMX_INIT_WITH_DEFAULT, attributes); EXPECT_EQ(status, SHMEM_SUCCESS); - EXPECT_EQ(shm::g_state.mype, rank_id); - EXPECT_EQ(shm::g_state.npes, n_ranks); - EXPECT_NE(shm::g_state.heap_base, nullptr); - EXPECT_NE(shm::g_state.p2p_heap_host_base[rank_id], nullptr); - EXPECT_NE(shm::g_state.p2p_heap_device_base[rank_id], nullptr); - EXPECT_EQ(shm::g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); - EXPECT_NE(shm::g_state.team_pools[0], nullptr); + EXPECT_EQ(g_state.mype, rank_id); + EXPECT_EQ(g_state.npes, n_ranks); + EXPECT_NE(g_state.heap_base, nullptr); + EXPECT_NE(g_state.host_p2p_heap_base[rank_id], nullptr); + EXPECT_EQ(g_state.heap_size, local_mem_size + SHMEM_EXTRA_SIZE); + EXPECT_NE(g_state.team_pools[0], nullptr); status = shmem_init_status(); EXPECT_EQ(status, SHMEM_STATUS_IS_INITIALIZED); testingRank = true; @@ -103,7 +104,7 @@ TEST_F(ShareMemoryManagerTest, allocate_large_memory_failed) int32_t device_id = rank_id % test_gnpu_num + test_first_npu; aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); - auto ptr = shmem_malloc(heap_memory_size + 1UL); + auto ptr = shmem_malloc(heap_memory_size + SHMEM_EXTRA_SIZE + 1UL); EXPECT_EQ(nullptr, ptr); test_finalize(stream, device_id); }, @@ -186,13 +187,39 @@ TEST_F(ShareMemoryManagerTest, calloc_large_memory_failed) aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); const size_t nmemb = 16; - auto ptr = shmem_calloc(nmemb, heap_memory_size / nmemb + 1UL); + auto ptr = shmem_calloc(nmemb, (heap_memory_size + SHMEM_EXTRA_SIZE) / nmemb + 1UL); EXPECT_EQ(nullptr, ptr); test_finalize(stream, device_id); }, local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, calloc_multiply_overflow_size_t_max) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + const size_t nmemb = static_cast(~0ULL); + const size_t each = 2; + + void *p = shmem_calloc(nmemb, each); + EXPECT_EQ(nullptr, p); + + void *ok = shmem_malloc(4096UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, align_zero) { const int process_count = test_gnpu_num; @@ -205,7 +232,6 @@ TEST_F(ShareMemoryManagerTest, align_zero) const size_t alignment = 16; auto ptr = shmem_align(alignment, 0UL); EXPECT_EQ(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -224,7 +250,7 @@ TEST_F(ShareMemoryManagerTest, align_one_piece_success) const size_t size = 128UL; auto ptr = shmem_align(alignment, size); EXPECT_NE(nullptr, ptr); - EXPECT_EQ(reinterpret_cast(ptr) & alignment, 0u); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -242,6 +268,7 @@ TEST_F(ShareMemoryManagerTest, align_full_space_success) const size_t alignment = 16; auto ptr = shmem_align(alignment, heap_memory_size); EXPECT_NE(nullptr, ptr); + EXPECT_EQ(reinterpret_cast(ptr) & (alignment - 1), 0u); test_finalize(stream, device_id); }, local_mem_size, process_count); @@ -257,7 +284,7 @@ TEST_F(ShareMemoryManagerTest, align_large_memory_failed) aclrtStream stream; test_init(rank_id, n_ranks, local_mem_size, &stream); const size_t alignment = 16; - auto ptr = shmem_align(alignment, heap_memory_size + 1UL); + auto ptr = shmem_align(alignment, heap_memory_size + SHMEM_EXTRA_SIZE + 1UL); EXPECT_EQ(nullptr, ptr); test_finalize(stream, device_id); }, @@ -282,6 +309,165 @@ TEST_F(ShareMemoryManagerTest, align_not_two_power_failed) local_mem_size, process_count); } +TEST_F(ShareMemoryManagerTest, stress_malloc_calloc_align_no_leak) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + constexpr int rounds = 100; + std::vector ptrs; + ptrs.reserve(rounds * 3); + + for (int i = 0; i < rounds; ++i) { + void *p1 = shmem_malloc(1024UL + (i % 7) * 128UL); + EXPECT_NE(nullptr, p1); + ptrs.push_back(p1); + + void *p2 = shmem_calloc(32, 16 + (i % 5)); + EXPECT_NE(nullptr, p2); + ptrs.push_back(p2); + + void *p3 = shmem_align(64, 1536UL + (i % 3) * 64UL); + EXPECT_NE(nullptr, p3); + ptrs.push_back(p3); + + if ((i % 4) == 0) { + shmem_free(p1); + ptrs[ptrs.size()-3] = nullptr; + } + if ((i % 6) == 0) { + shmem_free(p2); + ptrs[ptrs.size()-2] = nullptr; + } + } + + for (void *p : ptrs) { + if (p) shmem_free(p); + } + + void *big = shmem_malloc(heap_memory_size / 2); + EXPECT_NE(nullptr, big); + shmem_free(big); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, calls_before_init_and_after_finalize) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(4, 256)); + EXPECT_EQ(nullptr, shmem_align(64, 4096UL)); + shmem_free(nullptr); + + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + void *ok = shmem_malloc(2048UL); + EXPECT_NE(nullptr, ok); + shmem_free(ok); + test_finalize(stream, device_id); + + EXPECT_EQ(nullptr, shmem_malloc(1024UL)); + EXPECT_EQ(nullptr, shmem_calloc(2, 512)); + EXPECT_EQ(nullptr, shmem_align(32, 1024UL)); + shmem_free(nullptr); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_nullptr_is_noop) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_free(nullptr); + + void *p = shmem_malloc(8192UL); + EXPECT_NE(nullptr, p); + shmem_free(p); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, double_free_should_not_corrupt_heap) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 64UL * 1024UL; + void *p = shmem_malloc(sz); + ASSERT_NE(nullptr, p); + + shmem_free(p); + shmem_free(p); + + void *q = shmem_malloc(sz); + EXPECT_NE(nullptr, q); + shmem_free(q); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + +TEST_F(ShareMemoryManagerTest, free_middle_pointer_should_not_work_and_not_corrupt) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = heap_memory_size; + + test_mutil_task( + [this](int rank_id, int n_ranks, uint64_t local_mem_size) { + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + size_t sz = 128UL * 1024UL; + uint8_t *base = static_cast(shmem_malloc(sz)); + ASSERT_NE(nullptr, base); + + void *middle = base + 64; + + shmem_free(middle); + + shmem_free(base); + + void *again = shmem_malloc(sz); + EXPECT_NE(nullptr, again); + shmem_free(again); + + test_finalize(stream, device_id); + }, + local_mem_size, process_count); +} + TEST_F(ShareMemoryManagerTest, free_merge) { const int process_count = test_gnpu_num; diff --git a/tests/unittest/host/mem/shmem_ptr_host_test.cpp b/tests/unittest/host/mem/shmem_ptr_host_test.cpp index ae01a65f131660acce6fe0f392e1963a3fefdf4e..575f08666729b87bdbb48a478f24e57757efc1a5 100644 --- a/tests/unittest/host/mem/shmem_ptr_host_test.cpp +++ b/tests/unittest/host/mem/shmem_ptr_host_test.cpp @@ -24,7 +24,7 @@ static int32_t test_get_device_ptr(aclrtStream stream, uint8_t *ptr, int rank_id uint32_t block_dim = 1; int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice); + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); get_device_ptr(block_dim, stream, ptr); EXPECT_EQ(aclrtSynchronizeStream(stream), 0); @@ -92,9 +92,9 @@ TEST(TestMemApi, TestShmemMteSetUbParams) uint32_t event_id = 0; ASSERT_EQ(shmem_mte_set_ub_params(offset, ub_size, event_id), 0); - ASSERT_EQ(shm::g_state.mte_config.shmem_ub, offset); - ASSERT_EQ(shm::g_state.mte_config.ub_size, ub_size); - ASSERT_EQ(shm::g_state.mte_config.event_id, event_id); + ASSERT_EQ(g_state.mte_config.shmem_ub, offset); + ASSERT_EQ(g_state.mte_config.ub_size, ub_size); + ASSERT_EQ(g_state.mte_config.event_id, event_id); test_finalize(stream, device_id); }, local_mem_size, process_count); diff --git a/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp b/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp index f3463adad84f81ab118ab2af903c3aa8a0296bad..0c46a7af9f910219c81207217a4c125d1206d99d 100644 --- a/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp +++ b/tests/unittest/host/mem/ub_mem/ub_mem_host_test.cpp @@ -75,7 +75,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_UB_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - test_ub_##NAME##_put_get(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + test_ub_##NAME##_put_get(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp b/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp index 4d478b6aa47a69bdaa319b2b3e3816db9522e9b1..be49150409ab595351dc6857982530fc96260b80 100644 --- a/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp +++ b/tests/unittest/host/mem/ub_non_contiguous/ub_non_contiguous_host_test.cpp @@ -83,7 +83,7 @@ SHMEM_FUNC_TYPE_HOST(TEST_UB_NON_CONTIGUOUS_PUT_GET); test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ \ - TestUB##NAME##NonContiguousPutGet(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks); \ + TestUB##NAME##NonContiguousPutGet(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks); \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ if (::testing::Test::HasFailure()) { \ diff --git a/tests/unittest/host/sync/barrier/barrier_host_test.cpp b/tests/unittest/host/sync/barrier/barrier_host_test.cpp index dd6491b01b5e1db65831c073a8d93e5ac7adb5e8..d1d57d57ace3d9975929b21fa15a121087b1772a 100644 --- a/tests/unittest/host/sync/barrier/barrier_host_test.cpp +++ b/tests/unittest/host/sync/barrier/barrier_host_test.cpp @@ -37,7 +37,7 @@ static void test_barrier_black_box(int32_t rank_id, int32_t n_ranks, uint64_t lo ASSERT_EQ(aclrtSynchronizeStream(stream), 0); ASSERT_EQ(aclrtMemcpy(addr_host, sizeof(uint64_t), addr_dev, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), 0); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } uint64_t *addr_dev_vec = static_cast(shmem_malloc(sizeof(uint64_t))); @@ -52,7 +52,7 @@ static void test_barrier_black_box(int32_t rank_id, int32_t n_ranks, uint64_t lo ASSERT_EQ( aclrtMemcpy(addr_host_vec, sizeof(uint64_t), addr_dev_vec, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), 0); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); + shmemi_control_barrier_all(); } ASSERT_EQ(aclrtFreeHost(addr_host), 0); @@ -89,29 +89,33 @@ static void test_barrier_black_box_odd_team(int32_t rank_id, int32_t n_ranks, ui uint64_t *addr_host_vec; ASSERT_EQ(aclrtMallocHost(reinterpret_cast(&addr_host_vec), sizeof(uint64_t)), 0); - if (rank_id & 1) { - for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + + for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + if (rank_id & 1) { std::cout << "[TEST] barriers test blackbox rank_id: " << rank_id << " time: " << i << std::endl; increase_do_odd_team(stream, shmemx_get_ffts_config(), (uint8_t *)addr_dev, rank_id, n_ranks, team_odd); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); ASSERT_EQ(aclrtMemcpy(addr_host, sizeof(uint64_t), addr_dev, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), - 0); + 0); ASSERT_EQ((*addr_host), i); - shm::shmemi_control_barrier_all(); } + shmemi_control_barrier_all(); + } - for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + for (int32_t i = 1; i <= SHMEM_BARRIER_TEST_NUM; i++) { + if (rank_id & 1) { std::cout << "[TEST] vec barriers test blackbox rank_id: " << rank_id << " time: " << i << std::endl; increase_vec_do_odd_team(stream, shmemx_get_ffts_config(), (uint8_t *)addr_dev_vec, rank_id, n_ranks, - team_odd); + team_odd); ASSERT_EQ(aclrtSynchronizeStream(stream), 0); ASSERT_EQ( aclrtMemcpy(addr_host_vec, sizeof(uint64_t), addr_dev_vec, sizeof(uint64_t), ACL_MEMCPY_DEVICE_TO_HOST), 0); ASSERT_EQ((*addr_host_vec), i); - shm::shmemi_control_barrier_all(); } + shmemi_control_barrier_all(); } + ASSERT_EQ(aclrtFreeHost(addr_host), 0); shmem_free(addr_dev); diff --git a/tests/unittest/host/team/team/team_host_test.cpp b/tests/unittest/host/team/team/team_host_test.cpp index 1dc7bef47e842670aebbba141fbb61971cb83de4..04a83adf039d928159f9249b6f8c667a17978557 100644 --- a/tests/unittest/host/team/team/team_host_test.cpp +++ b/tests/unittest/host/team/team/team_host_test.cpp @@ -32,7 +32,7 @@ static int32_t test_get_device_state(aclrtStream stream, uint8_t *gva, uint32_t uint32_t block_dim = 1; void *ptr = shmem_malloc(1024); int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice); + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); get_device_state(block_dim, stream, (uint8_t *)ptr, team_id); EXPECT_EQ(aclrtSynchronizeStream(stream), 0); sleep(1); @@ -132,7 +132,7 @@ void test_shmem_team(int rank_id, int n_ranks, uint64_t local_mem_size) // #################### device代码测试 ############################## - auto status = test_get_device_state(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, team_odd, stride); + auto status = test_get_device_state(stream, (uint8_t *)g_state.heap_base, rank_id, n_ranks, team_odd, stride); EXPECT_EQ(status, SHMEM_SUCCESS); // #################### 相关资源释放 ################################ diff --git a/tests/unittest/include/unittest/mem_host_direct_test.cpp b/tests/unittest/include/unittest/mem_host_direct_test.cpp index e6f228743559cbb11053ffd79c05804675d91cae..a8ef5dd4d1bd1fda2915fd6df0542e9954e18102 100644 --- a/tests/unittest/include/unittest/mem_host_direct_test.cpp +++ b/tests/unittest/include/unittest/mem_host_direct_test.cpp @@ -124,7 +124,7 @@ SHMEM_MEM_PUT_GET_FUNC(GET_MEM_TEST) void *ptr = shmem_malloc(1024); \ host_test_##NAME##_put((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_id, rank_size, is_nbi); \ \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(input.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ @@ -137,7 +137,7 @@ SHMEM_MEM_PUT_GET_FUNC(GET_MEM_TEST) std::cout << std::endl; \ host_test_##NAME##_get((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_id, rank_size, is_nbi); \ \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(input.data(), input_size, dev_ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ @@ -166,7 +166,7 @@ SHMEM_MEM_PUT_GET_FUNC(PUT_GET_TEST) aclrtStream stream; \ test_init(rank_id, n_ranks, local_mem_size, &stream); \ ASSERT_NE(stream, nullptr); \ - host_test_##NAME##_put_get((uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, is_nbi); \ + host_test_##NAME##_put_get((uint8_t *)g_state.heap_base, rank_id, n_ranks, is_nbi); \ \ std::cout << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; \ test_finalize(stream, device_id); \ diff --git a/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp b/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp index 8700db3be0a4e00524423f01e180da348dd0beb1..dcb26bddd55c18b918a971bc5e1664bb0f53407c 100644 --- a/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp +++ b/tests/unittest/include/unittest/mem_host_get_and_put_test.cpp @@ -111,7 +111,7 @@ static void host_test_put_get_mem(int rank_id, int rank_size, uint64_t local_mem void *ptr = shmem_malloc(1024); host_test_putmem(ptr, dev_ptr, rank_id, input_size); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); sleep(sleep_time); ASSERT_EQ(aclrtMemcpy(input.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); @@ -124,7 +124,7 @@ static void host_test_put_get_mem(int rank_id, int rank_size, uint64_t local_mem std::cout << std::endl; size_t ele_size = 16; host_test_getmem((uint8_t *)ptr, (uint8_t *)dev_ptr, rank_size, ele_size); - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); ASSERT_EQ(aclrtMemcpy(input.data(), input_size, dev_ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); @@ -167,7 +167,7 @@ void test_host_shmem_putmem_and_getmem(int rank_id, int n_ranks, uint64_t local_ \ void *ptr = shmem_malloc(input_size); \ shmem_##NAME##_p(static_cast(ptr), static_cast(rank_id + 10), rank_id); \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ TYPE msg; \ diff --git a/tests/unittest/include/unittest/mem_putmem_signal_test.cpp b/tests/unittest/include/unittest/mem_putmem_signal_test.cpp index 9f9264c42e31b7741d34ddf112b5f8d016f754c0..883356dc074a6aef11acdc50c93da1426448acdd 100644 --- a/tests/unittest/include/unittest/mem_putmem_signal_test.cpp +++ b/tests/unittest/include/unittest/mem_putmem_signal_test.cpp @@ -81,7 +81,7 @@ SHMEM_MEM_PUT_GET_FUNC(PUT_MEM_SIGNAL) void *ptr = shmem_malloc(1024); \ int32_t signal = 6; \ putmem_##NAME##_signal_test((TYPE *)ptr, (TYPE *)dev_ptr, (uint8_t *)signal_addr, signal, rank_id, sig_op); \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(output.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \ @@ -180,7 +180,7 @@ SHMEM_MEM_PUT_GET_FUNC(PUT_MEM_SIGNAL_NBI) int32_t signal = 6; \ putmem_signal_##NAME##_test_nbi((TYPE *)ptr, (TYPE *)dev_ptr, (uint8_t *)signal_addr, signal, rank_id, \ sig_op); \ - ASSERT_EQ(aclrtSynchronizeStream(shm::g_state_host.default_stream), 0); \ + ASSERT_EQ(aclrtSynchronizeStream(g_state_host.default_stream), 0); \ sleep(2); \ \ ASSERT_EQ(aclrtMemcpy(output.data(), input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST), 0); \