diff --git a/tests/unittest/device/team/team/team_kernel.cpp b/tests/unittest/device/team/team/team_kernel.cpp index 26ff4ce8f89b17214536148ebaf01d7cc823a45f..6b8472e2355aafcf36ac0c192fc6703c106848e3 100644 --- a/tests/unittest/device/team/team/team_kernel.cpp +++ b/tests/unittest/device/team/team/team_kernel.cpp @@ -50,4 +50,48 @@ extern "C" __global__ __aicore__ void device_state_test(GM_ADDR gva, int team_id void get_device_state(uint32_t block_dim, void* stream, uint8_t* gva, shmem_team_t team_id) { device_state_test<<>>(gva, (int)team_id); -} \ No newline at end of file +} + +class kernel_invalid_test { +public: + __aicore__ inline kernel_invalid_test() {} + __aicore__ inline void Init(GM_ADDR gva, shmem_team_t team_id) + { + gva_gm = (__gm__ int *)gva; + destroyed= team_id; + + rank = smem_shm_get_global_rank(); + rank_size = smem_shm_get_global_rank_size(); + } + __aicore__ inline void Process() + { + AscendC::PipeBarrier(); + shmem_int32_p(gva_gm, shmem_team_my_pe(SHMEM_TEAM_INVALID), rank); + shmem_int32_p(gva_gm + 1U, shmem_team_n_pes(SHMEM_TEAM_INVALID), rank); + shmem_int32_p(gva_gm + 2U, shmem_team_translate_pe(SHMEM_TEAM_INVALID, 0, SHMEM_TEAM_WORLD), rank); + shmem_int32_p(gva_gm + 3U, shmem_team_translate_pe(SHMEM_TEAM_WORLD, 0, SHMEM_TEAM_INVALID), rank); + shmem_int32_p(gva_gm + 4U, shmem_team_translate_pe(SHMEM_TEAM_WORLD, -1, SHMEM_TEAM_WORLD), rank); + shmem_int32_p(gva_gm + 5U, shmem_team_translate_pe(SHMEM_TEAM_WORLD, rank_size, SHMEM_TEAM_WORLD), rank); + shmem_int32_p(gva_gm + 6U, shmem_team_my_pe(destroyed), rank); + shmem_int32_p(gva_gm + 7U, shmem_team_n_pes(destroyed), rank); + } +private: + __gm__ int *gva_gm; + shmem_team_t destroyed; + + int64_t rank; + int64_t rank_size; +}; + +extern "C" __global__ __aicore__ void device_team_invalid_test(GM_ADDR gva, int team_id) +{ + kernel_invalid_test op; + op.Init(gva, (shmem_team_t)team_id); + op.Process(); +} + +void get_device_team_invalid_state(uint32_t block_dim, void *stream, + uint8_t *gva, shmem_team_t destroyed_team) +{ + device_team_invalid_test<<>>(gva, (int)destroyed_team); +} diff --git a/tests/unittest/host/team/team/team_host_test.cpp b/tests/unittest/host/team/team/team_host_test.cpp index 3c17926f2d7411a5787a952e76ff5ff1356310e9..a91afc4cddc87ff7e3db240028d72078d231ca6f 100644 --- a/tests/unittest/host/team/team/team_host_test.cpp +++ b/tests/unittest/host/team/team/team_host_test.cpp @@ -21,7 +21,7 @@ using namespace std; static int32_t test_get_device_state(aclrtStream stream, uint8_t *gva, uint32_t rank_id, uint32_t rank_size, - shmem_team_t team_id, int stride) + shmem_team_t team_id, int stride, bool is_odd) { int *y_host; int num3 = 3; @@ -32,14 +32,13 @@ static int32_t test_get_device_state(aclrtStream stream, uint8_t *gva, uint32_t uint32_t block_dim = 1; void *ptr = shmem_malloc(1024); int32_t device_id; - SHMEM_CHECK_RET(aclrtGetDevice(&device_id), aclrtGetDevice); + SHMEM_CHECK_RET(aclrtGetDevice(&device_id)); get_device_state(block_dim, stream, (uint8_t *)ptr, team_id); EXPECT_EQ(aclrtSynchronizeStream(stream), 0); - sleep(1); EXPECT_EQ(aclrtMemcpy(y_host, num5 * sizeof(int), ptr, num5 * sizeof(int), ACL_MEMCPY_DEVICE_TO_HOST), 0); - if (rank_id & 1) { + if (is_odd == rank_id & 1) { int idx = 0; EXPECT_EQ(y_host[idx++], rank_size); EXPECT_EQ(y_host[idx++], rank_id); @@ -95,6 +94,9 @@ void test_shmem_team(int rank_id, int n_ranks, uint64_t local_mem_size) int invalid_dest = shmem_team_translate_pe(SHMEM_TEAM_WORLD, start - 1, team_odd); ASSERT_EQ(invalid_dest, -1); + int negative_src = shmem_team_translate_pe(SHMEM_TEAM_WORLD, -1, team_odd); + ASSERT_EQ(negative_src, -1); + int invalid_odd_even = shmem_team_translate_pe(team_odd, local_idx, team_even); ASSERT_EQ(invalid_odd_even, -1); } @@ -132,11 +134,17 @@ void test_shmem_team(int rank_id, int n_ranks, uint64_t local_mem_size) // #################### device代码测试 ############################## - auto status = test_get_device_state(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, team_odd, stride); + auto status = test_get_device_state(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, team_odd, stride, true); + EXPECT_EQ(status, SHMEM_SUCCESS); + + status = test_get_device_state(stream, (uint8_t *)shm::g_state.heap_base, rank_id, n_ranks, team_even, stride, false); EXPECT_EQ(status, SHMEM_SUCCESS); // #################### 相关资源释放 ################################ shmem_team_destroy(team_odd); + EXPECT_EQ(shmem_team_n_pes(team_odd), -1); + shmem_team_destroy(team_even); + EXPECT_EQ(shmem_team_n_pes(team_even), -1); std::cerr << "[TEST] begin to exit...... rank_id: " << rank_id << std::endl; test_finalize(stream, device_id); @@ -198,6 +206,12 @@ TEST(TestTeamApi, TestShmemTeamSplitStrided_failConditions) ret = shmem_team_split_strided(-1, 0, 1, 1, &team_odd); EXPECT_EQ(ret, SHMEM_INVALID_PARAM); + ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, 1, 0, &team_odd); + EXPECT_EQ(ret, SHMEM_INVALID_PARAM); + + ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, 1, -1, &team_odd); + EXPECT_EQ(ret, SHMEM_INVALID_PARAM); + const int32_t pe_size = 2; ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, -1, pe_size, &team_odd); EXPECT_EQ(ret, SHMEM_INVALID_PARAM); @@ -206,6 +220,13 @@ TEST(TestTeamApi, TestShmemTeamSplitStrided_failConditions) const int32_t pe_start = SHMEM_MAX_RANKS - 1; ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, pe_start, stride, pe_size, &team_odd); EXPECT_EQ(ret, SHMEM_INVALID_PARAM); + + shmem_team_t team; + ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, 1, SHMEM_MAX_RANKS + 1, &team); + EXPECT_EQ(ret, SHMEM_INVALID_PARAM); + + ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, -1, 1, 1, &team_odd); + EXPECT_EQ(ret, SHMEM_INVALID_PARAM); } TEST(TestTeamApi, ShmemTeamSplit2d_failConditions) @@ -221,4 +242,330 @@ TEST(TestTeamApi, ShmemTeamSplit2d_failConditions) errorCode = shmem_team_split_2d(-1, 0, &team_x, nullptr); EXPECT_EQ(errorCode, SHMEM_INVALID_PARAM); -} \ No newline at end of file +} + +void test_shmem_team_split_from_subteam(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_team_t team_even = SHMEM_TEAM_INVALID; + int stride = 2; + int team_size = n_ranks / stride; + + int ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, stride, team_size, &team_even); + ASSERT_EQ(ret, SHMEM_SUCCESS); + + int even_size = shmem_team_n_pes(team_even); + + if (!(rank_id & 1)) { + int sub_size = even_size / 2; + if (sub_size == 0) { + sub_size = 1; + } + + shmem_team_t team_sub = SHMEM_TEAM_INVALID; + int my_pe_in_even = shmem_team_my_pe(team_even); + + ret = shmem_team_split_strided(team_even, 0, 1, sub_size, &team_sub); + ASSERT_EQ(ret, SHMEM_SUCCESS); + + if (rank_id % 4 == 0) { + ASSERT_NE(team_sub, SHMEM_TEAM_INVALID); + EXPECT_EQ(shmem_team_n_pes(team_sub), sub_size); + } + + if (team_sub != SHMEM_TEAM_INVALID) { + shmem_team_destroy(team_sub); + EXPECT_EQ(shmem_team_n_pes(team_sub), -1); + } + } + + if (team_even != SHMEM_TEAM_INVALID) { + shmem_team_destroy(team_even); + EXPECT_EQ(shmem_team_n_pes(team_even), -1); + } + + test_finalize(stream, device_id); +} + + +TEST(TestTeamApi, TestShmemTeamSplitFromSubteam) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_split_from_subteam, local_mem_size, process_count); +} + +void test_shmem_team_split_all(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_team_t team_all; + int ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, 1, n_ranks, &team_all); + ASSERT_EQ(ret, SHMEM_SUCCESS); + + EXPECT_EQ(shmem_team_n_pes(team_all), n_ranks); + EXPECT_EQ(shmem_team_my_pe(team_all), shmem_my_pe()); + + shmem_team_destroy(team_all); + EXPECT_EQ(shmem_team_n_pes(team_all), -1); + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestShmemTeamSplitAllRanks) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_split_all, local_mem_size, process_count); +} + +void test_shmem_team_split_last_pe_only(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + ASSERT_NE(stream, nullptr); + + shmem_team_t last_team = SHMEM_TEAM_INVALID; + int ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, n_ranks - 1, 1, 1, &last_team); + ASSERT_EQ(ret, SHMEM_SUCCESS); + + if (rank_id == n_ranks - 1) { + ASSERT_NE(last_team, SHMEM_TEAM_INVALID); + EXPECT_EQ(shmem_team_n_pes(last_team), 1); + EXPECT_EQ(shmem_team_my_pe(last_team), 0); + + int world_pe = shmem_team_translate_pe(last_team, 0, SHMEM_TEAM_WORLD); + EXPECT_EQ(world_pe, rank_id); + } else { + EXPECT_EQ(last_team, SHMEM_TEAM_INVALID); + } + + if (last_team != SHMEM_TEAM_INVALID) { + shmem_team_destroy(last_team); + EXPECT_EQ(shmem_team_n_pes(last_team), -1); + EXPECT_EQ(shmem_team_my_pe(last_team), -1); + EXPECT_EQ(shmem_team_translate_pe(last_team, 0, SHMEM_TEAM_WORLD), -1); + } + + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestShmemTeamSplitLastPEOnly) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_split_last_pe_only, local_mem_size, process_count); +} + +void test_shmem_team_split_strided_lifecycle(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + int stride = 2; + int team_size = std::max(1, n_ranks / stride); + int start = rank_id % stride; + + for (int i = 0; i < 100; ++i) { + shmem_team_t team; + int ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, start, stride, team_size, &team); + ASSERT_EQ(ret, SHMEM_SUCCESS); + EXPECT_EQ(shmem_team_n_pes(team), team_size); + shmem_team_destroy(team); + EXPECT_EQ(shmem_team_n_pes(team), -1); + } + + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestShmemTeamSplitStridedLifecycle) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_split_strided_lifecycle, local_mem_size, process_count); +} + +void test_shmem_team_destroy_edge_cases(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + shmem_team_t team; + int ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, 1, n_ranks, &team); + ASSERT_EQ(ret, SHMEM_SUCCESS); + ASSERT_NE(team, SHMEM_TEAM_INVALID); + + EXPECT_EQ(shmem_team_n_pes(team), n_ranks); + EXPECT_EQ(shmem_team_my_pe(team), shmem_my_pe()); + + shmem_team_destroy(team); + EXPECT_EQ(shmem_team_n_pes(team), -1); + EXPECT_EQ(shmem_team_my_pe(team), -1); + EXPECT_EQ(shmem_team_translate_pe(team, 0, SHMEM_TEAM_WORLD), -1); + + shmem_team_destroy(team); + EXPECT_EQ(shmem_team_n_pes(team), -1); + EXPECT_EQ(shmem_team_my_pe(team), -1); + EXPECT_EQ(shmem_team_translate_pe(team, 0, SHMEM_TEAM_WORLD), -1); + + shmem_team_t invalid = SHMEM_TEAM_INVALID; + shmem_team_destroy(invalid); + EXPECT_EQ(shmem_team_n_pes(invalid), -1); + EXPECT_EQ(shmem_team_my_pe(invalid), -1); + + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestShmemTeamDestroyEdgeCases) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_destroy_edge_cases, local_mem_size, process_count); +} + +void test_device_team_invalid(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + int stride = 2; + int team_size = n_ranks / stride; + shmem_team_t team_odd; + ASSERT_EQ(shmem_team_split_strided(SHMEM_TEAM_WORLD, 1, stride, team_size, &team_odd), SHMEM_SUCCESS); + shmem_team_t destroyed_team = team_odd; + shmem_team_destroy(team_odd); + std::cout << "teamodd is " << team_odd << std::endl; + + void *ptr = shmem_malloc(1024); + ASSERT_NE(ptr, nullptr); + int *y_host = nullptr; + ASSERT_EQ(aclrtMallocHost(reinterpret_cast(&y_host), 8 * sizeof(int)), 0); + + uint32_t block_dim = 1; + get_device_team_invalid_state(block_dim, stream, + (uint8_t *)ptr, destroyed_team); + ASSERT_EQ(aclrtSynchronizeStream(stream), 0); + + ASSERT_EQ(aclrtMemcpy(y_host, 8 * sizeof(int), ptr, 8 * sizeof(int), ACL_MEMCPY_DEVICE_TO_HOST), 0); + + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(y_host[i], -1); + } + + EXPECT_EQ(aclrtFreeHost(y_host), 0); + shmem_free(ptr); + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestDeviceTeamInvalidCases) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_device_team_invalid, local_mem_size, process_count); +} + +void test_shmem_team_destroy_world(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + + EXPECT_EQ(shmem_team_n_pes(SHMEM_TEAM_WORLD), shmem_n_pes()); + EXPECT_EQ(shmem_team_my_pe(SHMEM_TEAM_WORLD), shmem_my_pe()); + + shmem_team_destroy(SHMEM_TEAM_WORLD); + + EXPECT_EQ(shmem_team_n_pes(SHMEM_TEAM_WORLD), -1); + EXPECT_EQ(shmem_team_my_pe(SHMEM_TEAM_WORLD), -1); + EXPECT_EQ(shmem_team_translate_pe(SHMEM_TEAM_WORLD, 0, SHMEM_TEAM_WORLD), -1); + + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestShmemTeamDestroyWorld) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_destroy_world, local_mem_size, process_count); +} + +void test_shmem_team_out_of_range_handle(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + ASSERT_NE(stream, nullptr); + + shmem_team_t out_of_range_team = static_cast(SHMEM_MAX_TEAMS + 1); + + EXPECT_EQ(shmem_team_n_pes(out_of_range_team), -1); + EXPECT_EQ(shmem_team_my_pe(out_of_range_team), -1); + + EXPECT_EQ(shmem_team_translate_pe(out_of_range_team, 0, SHMEM_TEAM_WORLD), -1); + EXPECT_EQ(shmem_team_translate_pe(SHMEM_TEAM_WORLD, 0, out_of_range_team), -1); + + EXPECT_EQ(shmem_team_translate_pe(out_of_range_team, -1, SHMEM_TEAM_WORLD), -1); + + test_finalize(stream, device_id); +} + +TEST(TestTeamApi, TestShmemTeamOutOfRangeHandle) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_shmem_team_out_of_range_handle, local_mem_size, process_count); +} + +static void team_apis_pack() +{ + EXPECT_EQ(shmem_team_n_pes(SHMEM_TEAM_WORLD), -1); + EXPECT_EQ(shmem_team_my_pe(SHMEM_TEAM_WORLD), -1); + EXPECT_EQ(shmem_team_translate_pe(SHMEM_TEAM_WORLD, 0, SHMEM_TEAM_WORLD), -1); + + shmem_team_t team = SHMEM_TEAM_INVALID; + int ret = shmem_team_split_strided(SHMEM_TEAM_WORLD, 0, 1, 1, &team); + EXPECT_NE(ret, SHMEM_SUCCESS); + EXPECT_EQ(team, SHMEM_TEAM_INVALID); + shmem_team_destroy(team); + + (void)shmem_my_pe(); + (void)shmem_n_pes(); + return; +} + +void test_team_host_apis_init_finalize_seq(int rank_id, int n_ranks, uint64_t local_mem_size) +{ + if (shmem_init_status() == SHMEM_STATUS_IS_INITIALIZED) { + EXPECT_EQ(shmem_finalize(), SHMEM_SUCCESS); + EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED); + } + + ASSERT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED); + + team_apis_pack(); + + int32_t device_id = rank_id % test_gnpu_num + test_first_npu; + aclrtStream stream; + test_init(rank_id, n_ranks, local_mem_size, &stream); + ASSERT_NE(stream, nullptr); + + test_finalize(stream, device_id); + EXPECT_EQ(shmem_init_status(), SHMEM_STATUS_NOT_INITIALIZED); + + team_apis_pack(); +} + +TEST(TestTeamApi, TestTeamHostApisInitFinalizeSeq) +{ + const int process_count = test_gnpu_num; + uint64_t local_mem_size = 1024UL * 1024UL * 16; + test_mutil_task(test_team_host_apis_init_finalize_seq, local_mem_size, process_count); +} diff --git a/tests/unittest/team/team/team_kernel.h b/tests/unittest/team/team/team_kernel.h index c5fcf37b168384fca3974f75d493b58f548b4c83..41f77d974bbc564dd42a4b2aa7c43c1c11d42dff 100644 --- a/tests/unittest/team/team/team_kernel.h +++ b/tests/unittest/team/team/team_kernel.h @@ -12,4 +12,6 @@ void get_device_state(uint32_t block_dim, void* stream, uint8_t* gva, shmem_team_t team_id); +void get_device_team_invalid_state(uint32_t block_dim, void *stream, uint8_t *gva, shmem_team_t destroyed_team); + #endif \ No newline at end of file