Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions cpp/include/raft/comms/comms_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,35 @@ namespace raft {
namespace comms {

/**
* @brief A simple sanity check that NCCL is able to perform a collective operation
* @brief A simple sanity check that NCCL is able to perform a collective all-to-all
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_alltoall(raft::resources const& handle, int root)
{
return detail::test_collective_alltoall(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective allreduce
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_allreduce(raft::resources const& handle, int root)
{
return detail::test_collective_allreduce(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective operation
* @brief A simple sanity check that NCCL is able to perform a collective broadcast
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_broadcast(raft::resources const& handle, int root)
{
Expand All @@ -41,7 +53,7 @@ bool test_collective_broadcast(raft::resources const& handle, int root)
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_reduce(raft::resources const& handle, int root)
{
Expand All @@ -53,19 +65,43 @@ bool test_collective_reduce(raft::resources const& handle, int root)
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_allgather(raft::resources const& handle, int root)
{
return detail::test_collective_allgather(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective scatter
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_scatter(raft::resources const& handle, int root)
{
return detail::test_collective_scatter(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective scatterv
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_scatterv(raft::resources const& handle, int root)
{
return detail::test_collective_scatterv(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective gather
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_gather(raft::resources const& handle, int root)
{
Expand All @@ -77,7 +113,7 @@ bool test_collective_gather(raft::resources const& handle, int root)
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_gatherv(raft::resources const& handle, int root)
{
Expand All @@ -89,7 +125,7 @@ bool test_collective_gatherv(raft::resources const& handle, int root)
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
* @param[in] root the root rank id
*/
bool test_collective_reducescatter(raft::resources const& handle, int root)
{
Expand Down
54 changes: 44 additions & 10 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,16 @@ class mpi_comms : public comms_iface {
RAFT_MPI_TRY(MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE));
}

void alltoall(const void* sendbuff,
void* recvbuff,
size_t count,
datatype_t datatype,
cudaStream_t stream) const
{
RAFT_NCCL_TRY(
ncclAlltoAll(sendbuff, recvbuff, count, get_nccl_datatype(datatype), nccl_comm_, stream));
}

void allreduce(const void* sendbuff,
void* recvbuff,
size_t count,
Expand Down Expand Up @@ -308,30 +318,54 @@ class mpi_comms : public comms_iface {
RAFT_NCCL_TRY(ncclGroupEnd());
}

void gather(const void* sendbuff,
void* recvbuff,
size_t sendcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
void scatter(const void* sendbuff,
void* recvbuff,
size_t recvcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
{
RAFT_NCCL_TRY(ncclScatter(
sendbuff, recvbuff, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
}

void scatterv(const void* sendbuf,
void* recvbuf,
const size_t* sendcounts,
const size_t* displs,
size_t recvcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
{
size_t dtype_size = get_datatype_size(datatype);
RAFT_NCCL_TRY(ncclGroupStart());
RAFT_NCCL_TRY(
ncclRecv(recvbuf, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + sendcount * r * dtype_size,
sendcount,
RAFT_NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + displs[r] * dtype_size,
sendcounts[r],
get_nccl_datatype(datatype),
r,
nccl_comm_,
stream));
}
}
RAFT_NCCL_TRY(
ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
RAFT_NCCL_TRY(ncclGroupEnd());
}

void gather(const void* sendbuff,
void* recvbuff,
size_t sendcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
{
RAFT_NCCL_TRY(ncclGather(
sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
}

void gatherv(const void* sendbuff,
void* recvbuff,
size_t sendcount,
Expand Down
54 changes: 44 additions & 10 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,16 @@ class std_comms : public comms_iface {
}
}

void alltoall(const void* sendbuff,
void* recvbuff,
size_t count,
datatype_t datatype,
cudaStream_t stream) const
{
RAFT_NCCL_TRY(
ncclAlltoAll(sendbuff, recvbuff, count, get_nccl_datatype(datatype), nccl_comm_, stream));
}

void allreduce(const void* sendbuff,
void* recvbuff,
size_t count,
Expand Down Expand Up @@ -443,30 +453,54 @@ class std_comms : public comms_iface {
RAFT_NCCL_TRY(ncclGroupEnd());
}

void gather(const void* sendbuff,
void* recvbuff,
size_t sendcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
void scatter(const void* sendbuff,
void* recvbuff,
size_t recvcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
{
RAFT_NCCL_TRY(ncclScatter(
sendbuff, recvbuff, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
}

void scatterv(const void* sendbuf,
void* recvbuf,
const size_t* sendcounts,
const size_t* displs,
size_t recvcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
{
size_t dtype_size = get_datatype_size(datatype);
RAFT_NCCL_TRY(ncclGroupStart());
RAFT_NCCL_TRY(
ncclRecv(recvbuf, recvcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
RAFT_NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + sendcount * r * dtype_size,
sendcount,
RAFT_NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + displs[r] * dtype_size,
sendcounts[r],
get_nccl_datatype(datatype),
r,
nccl_comm_,
stream));
}
}
RAFT_NCCL_TRY(
ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
RAFT_NCCL_TRY(ncclGroupEnd());
}

void gather(const void* sendbuff,
void* recvbuff,
size_t sendcount,
datatype_t datatype,
int root,
cudaStream_t stream) const
{
RAFT_NCCL_TRY(ncclGather(
sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream));
}

void gatherv(const void* sendbuff,
void* recvbuff,
size_t sendcount,
Expand Down
Loading
Loading