diff --git a/cpp/include/raft/comms/comms_test.hpp b/cpp/include/raft/comms/comms_test.hpp index e71080cb74..a52c530d7e 100644 --- a/cpp/include/raft/comms/comms_test.hpp +++ b/cpp/include/raft/comms/comms_test.hpp @@ -13,11 +13,23 @@ namespace raft { namespace comms { /** - * @brief A simple sanity check that NCCL is able to perform a collective operation + * @brief A simple sanity check that NCCL is able to perform a collective all-to-all * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id + */ +bool test_collective_alltoall(raft::resources const& handle, int root) +{ + return detail::test_collective_alltoall(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective allreduce + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id */ bool test_collective_allreduce(raft::resources const& handle, int root) { @@ -25,11 +37,11 @@ bool test_collective_allreduce(raft::resources const& handle, int root) } /** - * @brief A simple sanity check that NCCL is able to perform a collective operation + * @brief A simple sanity check that NCCL is able to perform a collective broadcast * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_broadcast(raft::resources const& handle, int root) { @@ -41,7 +53,7 @@ bool test_collective_broadcast(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. 
- * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_reduce(raft::resources const& handle, int root) { @@ -53,19 +65,43 @@ bool test_collective_reduce(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_allgather(raft::resources const& handle, int root) { return detail::test_collective_allgather(handle, root); } +/** + * @brief A simple sanity check that NCCL is able to perform a collective scatter + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_scatter(raft::resources const& handle, int root) +{ + return detail::test_collective_scatter(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective scatterv + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_scatterv(raft::resources const& handle, int root) +{ + return detail::test_collective_scatterv(handle, root); +} + /** * @brief A simple sanity check that NCCL is able to perform a collective gather * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_gather(raft::resources const& handle, int root) { @@ -77,7 +113,7 @@ bool test_collective_gather(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. 
- * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_gatherv(raft::resources const& handle, int root) { @@ -89,7 +125,7 @@ bool test_collective_gatherv(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_reducescatter(raft::resources const& handle, int root) { diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index d0de62c461..3b1bc95660 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -227,6 +227,16 @@ class mpi_comms : public comms_iface { RAFT_MPI_TRY(MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE)); } + void alltoall(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + cudaStream_t stream) const + { + RAFT_NCCL_TRY( + ncclAlltoAll(sendbuff, recvbuff, count, get_nccl_datatype(datatype), nccl_comm_, stream)); + } + void allreduce(const void* sendbuff, void* recvbuff, size_t count, @@ -308,30 +318,54 @@ class mpi_comms : public comms_iface { RAFT_NCCL_TRY(ncclGroupEnd()); } - void gather(const void* sendbuff, - void* recvbuff, - size_t sendcount, - datatype_t datatype, - int root, - cudaStream_t stream) const + void scatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclScatter( + sendbuff, recvbuff, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + + void scatterv(const void* sendbuf, + void* recvbuf, + const size_t* sendcounts, + const size_t* displs, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const { size_t dtype_size = get_datatype_size(datatype); RAFT_NCCL_TRY(ncclGroupStart()); + RAFT_NCCL_TRY( + 
ncclRecv(recvbuf, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); if (get_rank() == root) { for (int r = 0; r < get_size(); ++r) { - RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + sendcount * r * dtype_size, - sendcount, + RAFT_NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + displs[r] * dtype_size, + sendcounts[r], get_nccl_datatype(datatype), r, nccl_comm_, stream)); } } - RAFT_NCCL_TRY( - ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); RAFT_NCCL_TRY(ncclGroupEnd()); } + void gather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclGather( + sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + void gatherv(const void* sendbuff, void* recvbuff, size_t sendcount, diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 40ac0cbc51..a4b9ad48ac 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -362,6 +362,16 @@ class std_comms : public comms_iface { } } + void alltoall(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + cudaStream_t stream) const + { + RAFT_NCCL_TRY( + ncclAlltoAll(sendbuff, recvbuff, count, get_nccl_datatype(datatype), nccl_comm_, stream)); + } + void allreduce(const void* sendbuff, void* recvbuff, size_t count, @@ -443,30 +453,54 @@ class std_comms : public comms_iface { RAFT_NCCL_TRY(ncclGroupEnd()); } - void gather(const void* sendbuff, - void* recvbuff, - size_t sendcount, - datatype_t datatype, - int root, - cudaStream_t stream) const + void scatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclScatter( + sendbuff, recvbuff, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + + void scatterv(const 
void* sendbuf, + void* recvbuf, + const size_t* sendcounts, + const size_t* displs, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const { size_t dtype_size = get_datatype_size(datatype); RAFT_NCCL_TRY(ncclGroupStart()); + RAFT_NCCL_TRY( + ncclRecv(recvbuf, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); if (get_rank() == root) { for (int r = 0; r < get_size(); ++r) { - RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + sendcount * r * dtype_size, - sendcount, + RAFT_NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + displs[r] * dtype_size, + sendcounts[r], get_nccl_datatype(datatype), r, nccl_comm_, stream)); } } - RAFT_NCCL_TRY( - ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); RAFT_NCCL_TRY(ncclGroupEnd()); } + void gather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclGather( + sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + void gatherv(const void* sendbuff, void* recvbuff, size_t sendcount, diff --git a/cpp/include/raft/comms/detail/test.hpp b/cpp/include/raft/comms/detail/test.hpp index 83306b6204..f2d2b26979 100644 --- a/cpp/include/raft/comms/detail/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -20,12 +20,52 @@ namespace raft { namespace comms { namespace detail { +/** + * @brief A simple sanity check that NCCL is able to perform a collective all-to-all + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. 
+ * @param[in] root the root rank id + */ +bool test_collective_alltoall(raft::resources const& handle, int root) +{ + comms_t const& communicator = resource::get_comms(handle); + + std::vector<int> sends(communicator.get_size(), communicator.get_rank()); + + cudaStream_t stream = resource::get_cuda_stream(handle); + + rmm::device_uvector<int> temp_d(communicator.get_size(), stream); + rmm::device_uvector<int> recv_d(communicator.get_size(), stream); + + RAFT_CUDA_TRY(cudaMemcpyAsync(temp_d.data(), + sends.data(), + sizeof(int) * communicator.get_size(), + cudaMemcpyHostToDevice, + stream)); + + communicator.alltoall(temp_d.data(), recv_d.data(), 1, stream); + communicator.sync_stream(stream); + std::vector<int> temp_h(communicator.get_size()); + RAFT_CUDA_TRY(cudaMemcpyAsync(temp_h.data(), + recv_d.data(), + sizeof(int) * communicator.get_size(), + cudaMemcpyDeviceToHost, + stream)); + resource::sync_stream(handle, stream); + + for (int i = 0; i < communicator.get_size(); i++) { + if (temp_h[i] != i) return false; + } + return true; +} + /** * @brief A simple sanity check that NCCL is able to perform a collective operation * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_allreduce(raft::resources const& handle, int root) { @@ -56,7 +96,7 @@ bool test_collective_allreduce(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_broadcast(raft::resources const& handle, int root) { @@ -91,7 +131,7 @@ bool test_collective_broadcast(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. 
- * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_reduce(raft::resources const& handle, int root) { @@ -127,7 +167,7 @@ bool test_collective_reduce(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_allgather(raft::resources const& handle, int root) { @@ -159,12 +199,98 @@ bool test_collective_allgather(raft::resources const& handle, int root) return true; } +/** + * @brief A simple sanity check that NCCL is able to perform a collective scatter + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_scatter(raft::resources const& handle, int root) +{ + comms_t const& communicator = resource::get_comms(handle); + + cudaStream_t stream = resource::get_cuda_stream(handle); + + rmm::device_uvector<int> temp_d(communicator.get_rank() == root ? communicator.get_size() : 0, + stream); + rmm::device_scalar<int> recv_d(stream); + + if (communicator.get_rank() == root) { + std::vector<int> sends(communicator.get_size(), communicator.get_rank()); + std::fill(sends.begin(), sends.end(), root); + RAFT_CUDA_TRY(cudaMemcpyAsync( + temp_d.data(), sends.data(), sizeof(int) * sends.size(), cudaMemcpyHostToDevice, stream)); + } + + communicator.scatter( + communicator.get_rank() == root ? 
temp_d.data() : nullptr, recv_d.data(), 1, root, stream); + communicator.sync_stream(stream); + + int temp_h = -1; // Verify more than one byte is being sent + RAFT_CUDA_TRY( + cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); + resource::sync_stream(handle, stream); + + return temp_h == root; +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective scatterv + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_scatterv(raft::resources const& handle, int root) +{ + comms_t const& communicator = resource::get_comms(handle); + + std::vector<size_t> sendcounts(communicator.get_size()); + std::iota(sendcounts.begin(), sendcounts.end(), size_t{1}); + std::vector<size_t> displacements(communicator.get_size() + 1, 0); + std::partial_sum(sendcounts.begin(), sendcounts.end(), displacements.begin() + 1); + + cudaStream_t stream = resource::get_cuda_stream(handle); + + rmm::device_uvector<int> temp_d(communicator.get_rank() == root ? displacements.back() : 0, + stream); + rmm::device_uvector<int> recv_d( + displacements[communicator.get_rank() + 1] - displacements[communicator.get_rank()], stream); + + if (communicator.get_rank() == root) { + std::vector<int> sends(displacements.back(), root); + RAFT_CUDA_TRY(cudaMemcpyAsync( + temp_d.data(), sends.data(), sends.size() * sizeof(int), cudaMemcpyHostToDevice, stream)); + } + + communicator.scatterv( + communicator.get_rank() == root ? temp_d.data() : nullptr, + recv_d.data(), + communicator.get_rank() == root ? sendcounts.data() : static_cast<size_t const*>(nullptr), + communicator.get_rank() == root ? 
displacements.data() : static_cast<size_t const*>(nullptr), + recv_d.size(), + root, + stream); + communicator.sync_stream(stream); + + std::vector<int> temp_h(recv_d.size(), 0); + RAFT_CUDA_TRY(cudaMemcpyAsync( + temp_h.data(), recv_d.data(), sizeof(int) * recv_d.size(), cudaMemcpyDeviceToHost, stream)); + resource::sync_stream(handle, stream); + + if (std::count_if(temp_h.begin(), temp_h.end(), [root](auto val) { return val != root; }) != 0) { + return false; + } + return true; +} + /** * @brief A simple sanity check that NCCL is able to perform a collective gather * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_gather(raft::resources const& handle, int root) { @@ -201,7 +327,7 @@ bool test_collective_gather(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_gatherv(raft::resources const& handle, int root) { @@ -260,7 +386,7 @@ bool test_collective_gatherv(raft::resources const& handle, int root) * * @param[in] handle the raft handle to use. This is expected to already have an * initialized comms instance. - * @param[in] root the root rank id + * @param[in] root the root rank id */ bool test_collective_reducescatter(raft::resources const& handle, int root) { diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index 7ff91a6cb1..ff15147e6c 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -131,6 +131,12 @@ class comms_iface { virtual void waitall(int count, request_t array_of_requests[]) const = 0; + virtual void alltoall(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + cudaStream_t stream) const = 0; + virtual void allreduce(const void* sendbuff, void* recvbuff, size_t count, @@ -169,6 +175,22 @@ datatype_t datatype, cudaStream_t stream) const = 0; + virtual void scatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const = 0; + + virtual void scatterv(const void* sendbuf, + void* recvbuf, + const size_t* sendcounts, + const size_t* displs, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const = 0; + virtual void gather(const void* sendbuff, void* recvbuff, size_t sendcount, @@ -322,6 +344,24 @@ class comms_t { impl_->waitall(count, array_of_requests); } + /** + * Perform an alltoall collective + * @tparam value_t datatype of underlying buffers + * @param sendbuff buffer containing data to send (size = # ranks * count) + * @param recvbuff buffer containing data to receive (size = # ranks * count) + * @param count number of elements to send to/receive from each rank + * @param stream CUDA stream to synchronize operation + */ + template <typename value_t> + void alltoall(const value_t* sendbuff, value_t* recvbuff, size_t count, cudaStream_t stream) const + { + impl_->alltoall(static_cast<const void*>(sendbuff), + static_cast<void*>(recvbuff), + count, + get_type<value_t>(), + stream); + } + /** * Perform an allreduce collective * @tparam value_t datatype of underlying buffers @@ -453,12 +493,69 @@ } /** - * Gathers data from each rank onto all ranks + * Scatters data from one rank to all ranks + * @tparam value_t datatype of underlying buffers + * @param sendbuff buffer containing data to scatter (only used in root) + * @param recvbuff buffer containing data received from the root rank + 
* @param recvcount number of elements in receive buffer + * @param root rank holding the data to scatter + * @param stream CUDA stream to synchronize operation + */ + template <typename value_t> + void scatter(const value_t* sendbuff, + value_t* recvbuff, + size_t recvcount, + int root, + cudaStream_t stream) const + { + impl_->scatter(static_cast<const void*>(sendbuff), + static_cast<void*>(recvbuff), + recvcount, + get_type<value_t>(), + root, + stream); + } + + /** + * Scatters data from one rank to all ranks (different ranks can receive different amounts of + * data) + * @tparam value_t datatype of underlying buffers + * @param sendbuf buffer containing data to scatter (only used in root) + * @param recvbuf buffer containing data received from the root rank + * @param sendcounts pointer to an array (of length num_ranks size) containing the number of + * elements that are to be sent to each rank + * @param recvcount number of elements in receive buffer + * @param displs pointer to an array (of length num_ranks size) to specify the displacement + * (relative to sendbuf) at which to start the outgoing data to each rank + * @param root rank holding the data to scatter + * @param stream CUDA stream to synchronize operation + */ + template <typename value_t> + void scatterv(const value_t* sendbuf, + value_t* recvbuf, + const size_t* sendcounts, + const size_t* displs, + size_t recvcount, + int root, + cudaStream_t stream) const + { + impl_->scatterv(static_cast<const void*>(sendbuf), + static_cast<void*>(recvbuf), + sendcounts, + displs, + recvcount, + get_type<value_t>(), + root, + stream); + } + + /** + * Gathers data from all ranks to one rank * @tparam value_t datatype of underlying buffers * @param sendbuff buffer containing data to gather - * @param recvbuff buffer containing gathered data from all ranks + * @param recvbuff buffer containing gathered data from all ranks (only used in root) * @param sendcount number of elements in send buffer - * @param root rank to store the results + * @param root rank to store the gathered data * @param stream CUDA 
stream to synchronize operation */ template <typename value_t> @@ -477,16 +574,16 @@ class comms_t { } /** - * Gathers data from all ranks and delivers to combined data to all ranks + * Gathers data from all ranks to one rank (different ranks can send different amounts of data) * @tparam value_t datatype of underlying buffers * @param sendbuf buffer containing data to send - * @param recvbuf buffer containing data to receive + * @param recvbuf buffer containing gathered data from all ranks (only used in root) * @param sendcount number of elements in send buffer * @param recvcounts pointer to an array (of length num_ranks size) containing the number of * elements that are to be received from each rank * @param displs pointer to an array (of length num_ranks size) to specify the displacement * (relative to recvbuf) at which to place the incoming data from each rank - * @param root rank to store the results + * @param root rank to store the gathered data * @param stream CUDA stream to synchronize operation */ template <typename value_t> diff --git a/cpp/tests/core/handle.cpp b/cpp/tests/core/handle.cpp index 322e56da91..c20bc01a2e 100644 --- a/cpp/tests/core/handle.cpp +++ b/cpp/tests/core/handle.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -52,6 +52,14 @@ class mock_comms : public comms_iface { void waitall(int count, request_t array_of_requests[]) const {} + void alltoall(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + cudaStream_t stream) const + { + } + void allreduce(const void* sendbuff, void* recvbuff, size_t count, @@ -99,6 +107,26 @@ class mock_comms : public comms_iface { { } + void scatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + } + + void scatterv(const void* sendbuf, + void* recvbuf, + const size_t* sendcounts, + const size_t* displs, + size_t recvcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + } + void gather(const void* sendbuff, void* recvbuff, size_t sendcount, diff --git a/python/raft-dask/raft_dask/common/__init__.py b/python/raft-dask/raft_dask/common/__init__.py index 4792076cb4..3c97bbdd11 100644 --- a/python/raft-dask/raft_dask/common/__init__.py +++ b/python/raft-dask/raft_dask/common/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2020-2022, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 # @@ -9,6 +9,7 @@ perform_test_comm_split, perform_test_comms_allgather, perform_test_comms_allreduce, + perform_test_comms_alltoall, perform_test_comms_bcast, perform_test_comms_device_multicast_sendrecv, perform_test_comms_device_send_or_recv, @@ -17,6 +18,8 @@ perform_test_comms_gatherv, perform_test_comms_reduce, perform_test_comms_reducescatter, + perform_test_comms_scatter, + perform_test_comms_scatterv, perform_test_comms_send_recv, ) from .ucx import UCX diff --git a/python/raft-dask/raft_dask/common/comms_utils.pyx b/python/raft-dask/raft_dask/common/comms_utils.pyx index 5b2a119cc0..4429dbff7f 100644 --- a/python/raft-dask/raft_dask/common/comms_utils.pyx +++ b/python/raft-dask/raft_dask/common/comms_utils.pyx @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: profile=False @@ -43,6 +43,7 @@ cdef extern from "raft/comms/std_comms.hpp" namespace "raft::comms": cdef extern from "raft/comms/comms_test.hpp" namespace "raft::comms": + bool test_collective_alltoall(const device_resources &h, int root) except + bool test_collective_allreduce(const device_resources &h, int root) \ except + bool test_collective_broadcast(const device_resources &h, int root) \ @@ -50,6 +51,8 @@ cdef extern from "raft/comms/comms_test.hpp" namespace "raft::comms": bool test_collective_reduce(const device_resources &h, int root) except + bool test_collective_allgather(const device_resources &h, int root) \ except + + bool test_collective_scatter(const device_resources &h, int root) except + + bool test_collective_scatterv(const device_resources &h, int root) except + bool test_collective_gather(const device_resources &h, int root) except + bool test_collective_gatherv(const device_resources &h, int root) except + bool test_collective_reducescatter(const device_resources &h, int root) \ @@ 
-65,6 +68,20 @@ cdef extern from "raft/comms/comms_test.hpp" namespace "raft::comms": bool test_commsplit(const device_resources &h, int n_colors) except + +def perform_test_comms_alltoall(handle, root): + """ + Performs an alltoall on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + """ + cdef const device_resources* h = \ + handle.getHandle() + return test_collective_alltoall(deref(h), root) + + def perform_test_comms_allreduce(handle, root): """ Performs an allreduce on the current worker @@ -135,6 +152,38 @@ def perform_test_comms_allgather(handle, root): return test_collective_allgather(deref(h), root) +def perform_test_comms_scatter(handle, root): + """ + Performs a scatter on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + root : int + Rank of the root worker + """ + cdef const device_resources* h = \ + handle.getHandle() + return test_collective_scatter(deref(h), root) + + +def perform_test_comms_scatterv(handle, root): + """ + Performs a scatterv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + root : int + Rank of the root worker + """ + cdef const device_resources* h = \ + handle.getHandle() + return test_collective_scatterv(deref(h), root) + + def perform_test_comms_gather(handle, root): """ Performs a gather on the current worker diff --git a/python/raft-dask/raft_dask/tests/test_comms.py b/python/raft-dask/raft_dask/tests/test_comms.py index 676d40f7de..8eeb8684b1 100644 --- a/python/raft-dask/raft_dask/tests/test_comms.py +++ b/python/raft-dask/raft_dask/tests/test_comms.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 # @@ -16,6 +16,7 @@ perform_test_comm_split, perform_test_comms_allgather, perform_test_comms_allreduce, + perform_test_comms_alltoall, perform_test_comms_bcast, perform_test_comms_device_multicast_sendrecv, perform_test_comms_device_send_or_recv, @@ -24,6 +25,8 @@ perform_test_comms_gatherv, perform_test_comms_reduce, perform_test_comms_reducescatter, + perform_test_comms_scatter, + perform_test_comms_scatterv, perform_test_comms_send_recv, ) @@ -161,11 +164,14 @@ def _has_handle(sessionId): functions = [ perform_test_comms_allgather, perform_test_comms_allreduce, + perform_test_comms_alltoall, perform_test_comms_bcast, perform_test_comms_gather, perform_test_comms_gatherv, perform_test_comms_reduce, perform_test_comms_reducescatter, + perform_test_comms_scatter, + perform_test_comms_scatterv, ] else: functions = [None]