From c47648f0304902abad561c9c22718e7ae7801be2 Mon Sep 17 00:00:00 2001
From: caixilong
Date: Thu, 4 Dec 2025 17:26:32 +0800
Subject: [PATCH] add rma test cases

---
 .../device/mem/gm_mem/gm_mem_kernel.cpp       |  56 +++++++
 tests/unittest/host/main_test.cpp             |   9 +-
 .../host/mem/gm_mem/gm_mem_host_test.cpp      |  70 ++++++++
 .../unittest/host/mem/shmem_ptr_host_test.cpp | 150 ++++++++++++++++-
 4 files changed, 278 insertions(+), 7 deletions(-)

diff --git a/tests/unittest/device/mem/gm_mem/gm_mem_kernel.cpp b/tests/unittest/device/mem/gm_mem/gm_mem_kernel.cpp
index bd0f464b..680810bf 100644
--- a/tests/unittest/device/mem/gm_mem/gm_mem_kernel.cpp
+++ b/tests/unittest/device/mem/gm_mem/gm_mem_kernel.cpp
@@ -173,3 +173,59 @@ SHMEM_FUNC_TYPE_KERNEL(GET_NUM_TEST);
     }
 
 SHMEM_FUNC_TYPE_KERNEL(TEST_GET);
+
+// Kernel for the zero/one-size RMA ring test: every PE does non-blocking float
+// puts to itself and to PE (rank + 1) % n_ranks, then gets from that neighbour.
+class kernel_float_gm_mem_zero_one_ring {
+public:
+    __aicore__ inline kernel_float_gm_mem_zero_one_ring() {}
+    __aicore__ inline void Init(GM_ADDR gva, GM_ADDR dev)
+    {
+        gva_gm = (__gm__ float *)gva;
+        dev_gm = (__gm__ float *)dev;
+    }
+    __aicore__ inline void Process(uint64_t config)
+    {
+        shmemx_set_ffts_config(config);
+        int rank = shmem_my_pe();
+        int n_ranks = shmem_n_pes();
+        int dst_pe = (rank + 1) % n_ranks;
+
+        // NOTE(review): SetFlag/WaitFlag show no template argument in this copy of
+        // the patch; confirm the intended hardware event before applying.
+        for (int i = 0; i < 4; i++) {
+            shmem_put_float_mem_nbi(gva_gm + i, dev_gm + i, 1, rank);
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+            shmem_put_float_mem_nbi(gva_gm + i, dev_gm + i, 0, dst_pe);
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+        }
+        shmemx_barrier_all_vec();
+        // Size 0 from the neighbour, then size 1; the host test expects only
+        // element 1 of dev_gm to take the neighbour's value.
+        shmem_get_float_mem_nbi(dev_gm, gva_gm, 0, dst_pe);
+        AscendC::SetFlag(EVENT_ID0);
+        AscendC::WaitFlag(EVENT_ID0);
+        shmem_get_float_mem_nbi(dev_gm + 1, gva_gm + 1, 1, dst_pe);
+        shmemx_barrier_all_vec();
+    }
+private:
+    __gm__ float *gva_gm;
+    __gm__ float *dev_gm;
+};
+
+// Device entry point: binds the two buffers and runs the ring test.
+extern "C" __global__ __aicore__ void gm_mem_float_zero_one_ring_kernel(
+    GM_ADDR gva, GM_ADDR dev, uint64_t config)
+{
+    kernel_float_gm_mem_zero_one_ring op;
+    op.Init(gva, dev);
+    op.Process(config);
+}
+
+// Host-side launcher called from gm_mem_host_test.cpp.
+void test_float_gm_mem_zero_one_ring(uint32_t block_dim, void *stream, uint64_t config, uint8_t *gva, uint8_t *dev)
+{
+    gm_mem_float_zero_one_ring_kernel<<<block_dim, nullptr, stream>>>(gva, dev, config);
+}
diff --git a/tests/unittest/host/main_test.cpp b/tests/unittest/host/main_test.cpp
index a91dbaba..8b14b152 100644
--- a/tests/unittest/host/main_test.cpp
+++ b/tests/unittest/host/main_test.cpp
@@ -88,10 +88,17 @@ void test_mutil_task(std::function<void(int, int, uint64_t)> func, uint64_t loca
             std::cout << "fork failed ! " << pids[i] << std::endl;
         } else if (pids[i] == 0) {
             func(i + test_first_rank, test_global_ranks, local_mem_size);
-            exit(0);
+            // _exit() may skip stdio flushing, so flush the child's output first.
+            std::cout.flush();
+            _exit(::testing::Test::HasFailure() ? 1 : 0);
         }
     }
     for (int i = 0; i < process_count; ++i) {
+        if (pids[i] <= 0) {
+            // fork() failed for this rank: report it instead of waiting on it.
+            ADD_FAILURE() << "fork failed for child " << i;
+            continue;
+        }
         waitpid(pids[i], &status[i], 0);
         if (WIFEXITED(status[i]) && WEXITSTATUS(status[i]) != 0) {
             FAIL();
diff --git a/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp b/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp
index fe75b96c..a4401a4a 100644
--- a/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp
+++ b/tests/unittest/host/mem/gm_mem/gm_mem_host_test.cpp
@@ -26,6 +26,9 @@ const int test_offset = 10;
     extern void test_##NAME##_put(uint32_t block_dim, void *stream, uint64_t config, uint8_t *gva, uint8_t *dev_ptr); \
     extern void test_##NAME##_get(uint32_t block_dim, void *stream, uint64_t config, uint8_t *gva, uint8_t *dev_ptr)
 
+extern void test_float_gm_mem_zero_one_ring(uint32_t block_dim, void *stream, uint64_t config, uint8_t *gva,
+                                            uint8_t *dev_ptr);
+
 SHMEM_FUNC_TYPE_HOST(TEST_FUNC);
 
 #define TEST_PUT_GET(NAME, TYPE) \
@@ -95,3 +98,70 @@ SHMEM_FUNC_TYPE_HOST(TEST_SHMEM_MEM);
     }
 
 SHMEM_FUNC_TYPE_HOST(TESTAPI);
+
+// Runs the zero/one-size ring kernel on this rank and checks what ends up in
+// the local device buffer afterwards.
+static void test_float_gm_mem_zero_one_ring_impl(aclrtStream stream, int rank_id, int n_ranks)
+{
+    const int elem_count = 4;
+    const size_t bytes = elem_count * sizeof(float);
+
+    // Rank-dependent values so data coming from a neighbour is distinguishable.
+    std::vector<float> src(elem_count);
+    src[0] = rank_id + 10.5f;
+    src[1] = rank_id + 20.5f;
+    src[2] = rank_id + 30.5f;
+    src[3] = rank_id + 40.5f;
+
+    void *dev_ptr = nullptr;
+    ASSERT_EQ(aclrtMalloc(&dev_ptr, bytes, ACL_MEM_MALLOC_NORMAL_ONLY), 0);
+    ASSERT_EQ(aclrtMemcpy(dev_ptr, bytes, src.data(), bytes, ACL_MEMCPY_HOST_TO_DEVICE), 0);
+
+    void *ptr = shmem_malloc(bytes);
+    ASSERT_NE(ptr, nullptr);
+
+    uint32_t block_dim = 1;
+    uint64_t config = shmemx_get_ffts_config();
+
+    test_float_gm_mem_zero_one_ring(block_dim, stream, config, static_cast<uint8_t *>(ptr),
+                                    static_cast<uint8_t *>(dev_ptr));
+    ASSERT_EQ(aclrtSynchronizeStream(stream), 0);
+
+    std::vector<float> y_host(elem_count, 0.0f);
+
+    ASSERT_EQ(aclrtMemcpy(y_host.data(), bytes, dev_ptr, bytes, ACL_MEMCPY_DEVICE_TO_HOST), 0);
+
+    shmem_free(ptr);
+    aclrtFree(dev_ptr);
+
+    int next = (rank_id + 1) % n_ranks;
+
+    // Only element 1 is fetched from the neighbour with a non-zero size; the
+    // other elements must keep this rank's own values.
+    EXPECT_FLOAT_EQ(y_host[0], rank_id + 10.5f);
+    EXPECT_FLOAT_EQ(y_host[1], next + 20.5f);
+    EXPECT_FLOAT_EQ(y_host[2], rank_id + 30.5f);
+    EXPECT_FLOAT_EQ(y_host[3], rank_id + 40.5f);
+}
+
+// Per-rank entry point: initialise, run the ring check, finalise.
+static void test_float_shmem_mem_zero_one_ring(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    int32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    aclrtStream stream = nullptr;
+
+    test_init(rank_id, n_ranks, local_mem_size, &stream);
+    ASSERT_NE(stream, nullptr);
+
+    test_float_gm_mem_zero_one_ring_impl(stream, rank_id, n_ranks);
+
+    test_finalize(stream, device_id);
+}
+
+TEST(TestMemApi, TestShmemGMFloatMemZeroOneRing)
+{
+    const int process_count = test_gnpu_num;
+    const uint64_t local_mem_size = 1024UL * 1024UL * 1024;
+
+    test_mutil_task(test_float_shmem_mem_zero_one_ring, local_mem_size, process_count);
+}
diff --git a/tests/unittest/host/mem/shmem_ptr_host_test.cpp b/tests/unittest/host/mem/shmem_ptr_host_test.cpp
index c4714178..07c863c6 100644
--- a/tests/unittest/host/mem/shmem_ptr_host_test.cpp
+++ b/tests/unittest/host/mem/shmem_ptr_host_test.cpp
@@ -24,7 +24,7 @@ static int32_t test_get_device_ptr(aclrtStream stream, uint8_t *ptr, int rank_id
 
     uint32_t block_dim = 1;
     int32_t device_id;
-    SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice);
+    SHMEM_CHECK_RET(aclrtGetDevice(&device_id));
     get_device_ptr(block_dim, stream, ptr);
     EXPECT_EQ(aclrtSynchronizeStream(stream), 0);
 
@@ -46,7 +46,8 @@ void test_shmem_ptr(int rank_id, int n_ranks, uint64_t local_mem_size)
     test_init(rank_id, n_ranks, local_mem_size, &stream);
     ASSERT_NE(stream, nullptr);
 
-    int *ptr = static_cast<int *>(shmem_malloc(2 * sizeof(int)));
+    const int heap_size = 16;
+    int *ptr = static_cast<int *>(shmem_malloc(heap_size * sizeof(int)));
     ASSERT_NE(ptr, nullptr);
 
     void *host_self = shmem_ptr(ptr, rank_id);
@@ -54,10 +55,19 @@ void test_shmem_ptr(int rank_id, int n_ranks, uint64_t local_mem_size)
     EXPECT_EQ(host_self, ptr);
 
     int peer = (rank_id + 1) % n_ranks;
-    void *host_remote = shmem_ptr(ptr, peer);
-    void *next_remote = shmem_ptr(ptr + 1, peer);
-    ASSERT_NE(host_remote, nullptr);
-    EXPECT_EQ(static_cast<int *>(next_remote) - static_cast<int *>(host_remote), 1);
+    // For the neighbour, PE 0 and the last PE, shmem_ptr must preserve element
+    // offsets within the allocation (second, middle and last element).
+    const int target_pes[] = {peer, 0, n_ranks - 1};
+    const int elem_offsets[] = {1, heap_size / 2, heap_size - 1};
+    for (int pe : target_pes) {
+        int *remote_base = static_cast<int *>(shmem_ptr(ptr, pe));
+        ASSERT_NE(remote_base, nullptr);
+        for (int offset : elem_offsets) {
+            int *remote_elem = static_cast<int *>(shmem_ptr(ptr + offset, pe));
+            ASSERT_NE(remote_elem, nullptr);
+            EXPECT_EQ(remote_elem - remote_base, offset);
+        }
+    }
 
     auto status = test_get_device_ptr(stream, (uint8_t *)ptr, rank_id, n_ranks);
     EXPECT_EQ(status, SHMEM_SUCCESS);
@@ -78,6 +88,55 @@ TEST(TestMemApi, TestShmemPtr)
     test_mutil_task(test_shmem_ptr, local_mem_size, process_count);
 }
 
+// Negative cases for shmem_ptr: an address that did not come from shmem_malloc
+// and PE ids outside [0, n_ranks) are both expected to yield nullptr.
+void test_shmem_ptr_invalid_addr_and_pe(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    int32_t device_id = rank_id % test_gnpu_num + test_first_npu;
+    aclrtStream stream{};
+    test_init(rank_id, n_ranks, local_mem_size, &stream);
+    ASSERT_NE(stream, nullptr);
+
+    {
+        // Host memory from aclrtMallocHost is not a shmem_malloc allocation.
+        size_t size = 1024;
+        void *non_symmetric = nullptr;
+        EXPECT_EQ(aclrtMallocHost(&non_symmetric, size), 0);
+        ASSERT_NE(non_symmetric, nullptr);
+
+        void *p0 = shmem_ptr(non_symmetric, 0);
+        void *plast = shmem_ptr(non_symmetric, n_ranks - 1);
+        EXPECT_EQ(p0, nullptr);
+        EXPECT_EQ(plast, nullptr);
+
+        EXPECT_EQ(aclrtFreeHost(non_symmetric), 0);
+    }
+
+    {
+        // Valid symmetric address, invalid PE ids (-1 and n_ranks).
+        size_t size = 1024;
+        void *ptr = shmem_malloc(size);
+        ASSERT_NE(ptr, nullptr);
+
+        void *p_negative = shmem_ptr(ptr, -1);
+        EXPECT_EQ(p_negative, nullptr);
+
+        void *p_too_large = shmem_ptr(ptr, n_ranks);
+        EXPECT_EQ(p_too_large, nullptr);
+
+        shmem_free(ptr);
+    }
+
+    test_finalize(stream, device_id);
+}
+
+TEST(TestMemApi, TestShmemPtrInvalidAddrAndPe)
+{
+    const int process_count = test_gnpu_num;
+    const uint64_t local_mem_size = 1024UL * 1024UL * 16;
+    test_mutil_task(test_shmem_ptr_invalid_addr_and_pe, local_mem_size, process_count);
+}
+
 TEST(TestMemApi, TestShmemMteSetUbParams)
 {
     const int process_count = test_gnpu_num;
@@ -98,4 +157,83 @@ TEST(TestMemApi, TestShmemMteSetUbParams)
             test_finalize(stream, device_id);
         },
         local_mem_size, process_count);
+}
+
+// Accepted-parameter cases for shmem_mte_set_ub_params on a single rank:
+// smallest size, a larger size, and event ids 0 and 15.
+void test_shmem_mte_set_ub_params_basic_and_boundary(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    (void)rank_id;
+    (void)n_ranks;
+
+    int32_t device_id = test_first_npu;
+    aclrtStream stream{};
+    test_init(0, 1, local_mem_size, &stream);
+    ASSERT_NE(stream, nullptr);
+
+    constexpr uint64_t ub_offset_min = 0;
+    constexpr uint64_t ub_size_min = 16;
+    constexpr int event_id_min = 0;
+    constexpr int event_id_max = 15;
+
+    int ret = shmem_mte_set_ub_params(ub_offset_min, ub_size_min, event_id_min);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+
+    ret = shmem_mte_set_ub_params(ub_offset_min, ub_size_min * 4, event_id_min);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+
+    ret = shmem_mte_set_ub_params(ub_offset_min, ub_size_min, event_id_max);
+    EXPECT_EQ(ret, SHMEM_SUCCESS);
+
+    test_finalize(stream, device_id);
+}
+
+TEST(TestMemApi, TestShmemMteSetUbParamsBasicAndBoundary)
+{
+    const int process_count = 1;
+    const uint64_t local_mem_size = 1024UL * 1024UL * 16;
+    test_mutil_task(test_shmem_mte_set_ub_params_basic_and_boundary, local_mem_size, process_count);
+}
+
+// Calls shmem_mte_set_ub_params and expects it to be rejected.
+static void shmem_mte_set_ub_params_before_after_init()
+{
+    constexpr uint64_t ub_offset = 0;
+    constexpr uint64_t ub_size = 16;
+    constexpr int event_id = 0;
+
+    int ret = shmem_mte_set_ub_params(ub_offset, ub_size, event_id);
+    EXPECT_NE(ret, SHMEM_SUCCESS);
+}
+
+// shmem_mte_set_ub_params must fail both before shmem is initialised and
+// after it has been finalised.
+void test_shmem_mte_set_ub_params_invalid_lifecycle(int rank_id, int n_ranks, uint64_t local_mem_size)
+{
+    (void)rank_id;
+    (void)n_ranks;
+
+    if (shmem_init_status() == SHMEM_STATUS_IS_INITIALIZED) {
+        EXPECT_EQ(shmem_finalize(), SHMEM_SUCCESS);
+        EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED);
+    }
+    ASSERT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED);
+
+    shmem_mte_set_ub_params_before_after_init();
+
+    int32_t device_id = test_first_npu;
+    aclrtStream stream{};
+    test_init(0, 1, local_mem_size, &stream);
+    ASSERT_NE(stream, nullptr);
+
+    test_finalize(stream, device_id);
+
+    shmem_mte_set_ub_params_before_after_init();
+}
+
+TEST(TestMemApi, TestShmemMteSetUbParamsInvalidLifecycle)
+{
+    const int process_count = 1;
+    const uint64_t local_mem_size = 1024UL * 1024UL * 16;
+    test_mutil_task(test_shmem_mte_set_ub_params_invalid_lifecycle, local_mem_size, process_count);
 }
\ No newline at end of file
-- 
Gitee