Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
356 changes: 333 additions & 23 deletions cpp/include/cugraph/utilities/device_comm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,10 @@ device_alltoall_impl(raft::comms::comms_t const& comm,
using value_type = typename std::iterator_traits<InputIterator>::value_type;
static_assert(
std::is_same_v<typename std::iterator_traits<OutputIterator>::value_type, value_type>);
#if 1 // FIXME: we should add comm.device_alltoall to raft (which calls ncclAlltoAll)
std::vector<size_t> sizes(comm.get_size(), count_per_rank);
std::vector<size_t> displs(comm.get_size());
for (size_t i = 0; i < displs.size(); ++i) {
displs[i] = i * count_per_rank;
}
std::vector<int> ranks(comm.get_size());
std::iota(ranks.begin(), ranks.end(), int{0});
comm.device_multicast_sendrecv(iter_to_raw_ptr(input_first),
sizes,
displs,
ranks,
iter_to_raw_ptr(output_first),
sizes,
displs,
ranks,
stream_view.value());
#else
comm.device_alltoall(iter_to_raw_ptr(input_first),
iter_to_raw_ptr(output_first),
count_per_rank,
stream_view.value());
#endif
comm.alltoall(iter_to_raw_ptr(input_first),
iter_to_raw_ptr(output_first),
count_per_rank,
stream_view.value());
}

template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
Expand Down Expand Up @@ -716,6 +697,210 @@ struct device_allgatherv_tuple_iterator_element_impl<InputIterator, OutputIterat
}
};

// Scatter overload chosen when is_discard_iterator<OutputIterator>::value is true.
// The body is empty: no communicator call is made and all arguments are ignored.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<is_discard_iterator<OutputIterator>::value, void> device_scatter_impl(
  raft::comms::comms_t const& comm,
  InputIterator input_first,
  OutputIterator output_first,
  size_t recvcount,
  int root,
  rmm::cuda_stream_view stream_view)
{
  // Intentionally a no-op.
  return;
}

// Scatter overload chosen when the output iterator's value type is arithmetic.
//
// Checks at compile time that the input and output value types are identical, converts both
// iterators with iter_to_raw_ptr(), and forwards them together with recvcount, root and
// stream_view.value() to comm.scatter().
// NOTE(review): the meaning of recvcount/root is defined by comm.scatter() — confirm there.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  std::is_arithmetic<typename std::iterator_traits<OutputIterator>::value_type>::value,
  void>
device_scatter_impl(raft::comms::comms_t const& comm,
                    InputIterator input_first,
                    OutputIterator output_first,
                    size_t recvcount,
                    int root,
                    rmm::cuda_stream_view stream_view)
{
  using input_value_t  = typename std::iterator_traits<InputIterator>::value_type;
  using output_value_t = typename std::iterator_traits<OutputIterator>::value_type;
  static_assert(std::is_same_v<input_value_t, output_value_t>);

  auto const send_ptr = iter_to_raw_ptr(input_first);
  auto const recv_ptr = iter_to_raw_ptr(output_first);
  auto const stream   = stream_view.value();

  comm.scatter(send_ptr, recv_ptr, recvcount, root, stream);
}

// Compile-time recursion over the component iterators of a pair of tuple iterators.
//
// Step I picks component I out of input_first.get_iterator_tuple() and
// output_first.get_iterator_tuple(), hands the pair to device_scatter_impl(), and then runs
// step I + 1 with the same arguments. The <..., I, I> partial specialization stops the recursion.
template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
struct device_scatter_tuple_iterator_element_impl {
  void run(raft::comms::comms_t const& comm,
           InputIterator input_first,
           OutputIterator output_first,
           size_t recvcount,
           int root,
           rmm::cuda_stream_view stream_view) const
  {
    using next_step =
      device_scatter_tuple_iterator_element_impl<InputIterator, OutputIterator, I + 1, N>;

    // Copy the I-th component iterators out of the tuples before dispatching.
    auto const input_element  = cuda::std::get<I>(input_first.get_iterator_tuple());
    auto const output_element = cuda::std::get<I>(output_first.get_iterator_tuple());
    device_scatter_impl(comm, input_element, output_element, recvcount, root, stream_view);

    // Continue with the remaining components.
    next_step{}.run(comm, input_first, output_first, recvcount, root, stream_view);
  }
};

// Partial specialization for the case where both index parameters are equal (I == N):
// run() does nothing, which terminates the per-component recursion of the primary template.
template <typename InputIterator, typename OutputIterator, size_t I>
struct device_scatter_tuple_iterator_element_impl<InputIterator, OutputIterator, I, I> {
  void run(raft::comms::comms_t const& comm,
           InputIterator input_first,
           OutputIterator output_first,
           size_t recvcount,
           int root,
           rmm::cuda_stream_view stream_view) const
  {
    // Nothing left to process.
  }
};

// Scatterv overload chosen when is_discard_iterator<OutputIterator>::value is true.
// The body is empty: no communicator call is made and all arguments are ignored.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<is_discard_iterator<OutputIterator>::value, void> device_scatterv_impl(
  raft::comms::comms_t const& comm,
  InputIterator input_first,
  OutputIterator output_first,
  raft::host_span<size_t const> sendcounts,
  raft::host_span<size_t const> displacements,
  size_t recvcount,
  int root,
  rmm::cuda_stream_view stream_view)
{
  // Intentionally a no-op.
  return;
}

// Scatterv overload chosen when the output iterator's value type is arithmetic.
//
// Checks at compile time that the input and output value types are identical, converts both
// iterators with iter_to_raw_ptr(), and forwards them to comm.scatterv() together with the
// data() pointers of sendcounts and displacements, recvcount, root and stream_view.value().
// NOTE(review): the expected length/contents of sendcounts and displacements are defined by
// comm.scatterv() — confirm there.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  std::is_arithmetic<typename std::iterator_traits<OutputIterator>::value_type>::value,
  void>
device_scatterv_impl(raft::comms::comms_t const& comm,
                     InputIterator input_first,
                     OutputIterator output_first,
                     raft::host_span<size_t const> sendcounts,
                     raft::host_span<size_t const> displacements,
                     size_t recvcount,
                     int root,
                     rmm::cuda_stream_view stream_view)
{
  using input_value_t  = typename std::iterator_traits<InputIterator>::value_type;
  using output_value_t = typename std::iterator_traits<OutputIterator>::value_type;
  static_assert(std::is_same_v<input_value_t, output_value_t>);

  auto const send_ptr = iter_to_raw_ptr(input_first);
  auto const recv_ptr = iter_to_raw_ptr(output_first);
  auto const stream   = stream_view.value();

  comm.scatterv(
    send_ptr, recv_ptr, sendcounts.data(), displacements.data(), recvcount, root, stream);
}

// Compile-time recursion over the component iterators of a pair of tuple iterators.
//
// Step I picks component I out of input_first.get_iterator_tuple() and
// output_first.get_iterator_tuple(), hands the pair to device_scatterv_impl() with the same
// sendcounts/displacements/recvcount/root for every component, and then runs step I + 1.
// The <..., I, I> partial specialization stops the recursion.
template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
struct device_scatterv_tuple_iterator_element_impl {
  void run(raft::comms::comms_t const& comm,
           InputIterator input_first,
           OutputIterator output_first,
           raft::host_span<size_t const> sendcounts,
           raft::host_span<size_t const> displacements,
           size_t recvcount,
           int root,
           rmm::cuda_stream_view stream_view) const
  {
    using next_step =
      device_scatterv_tuple_iterator_element_impl<InputIterator, OutputIterator, I + 1, N>;

    // Copy the I-th component iterators out of the tuples before dispatching.
    auto const input_element  = cuda::std::get<I>(input_first.get_iterator_tuple());
    auto const output_element = cuda::std::get<I>(output_first.get_iterator_tuple());
    device_scatterv_impl(comm,
                         input_element,
                         output_element,
                         sendcounts,
                         displacements,
                         recvcount,
                         root,
                         stream_view);

    // Continue with the remaining components.
    next_step{}.run(
      comm, input_first, output_first, sendcounts, displacements, recvcount, root, stream_view);
  }
};

// Partial specialization for the case where both index parameters are equal (I == N):
// run() does nothing, which terminates the per-component recursion of the primary template.
template <typename InputIterator, typename OutputIterator, size_t I>
struct device_scatterv_tuple_iterator_element_impl<InputIterator, OutputIterator, I, I> {
  void run(raft::comms::comms_t const& comm,
           InputIterator input_first,
           OutputIterator output_first,
           raft::host_span<size_t const> sendcounts,
           raft::host_span<size_t const> displacements,
           size_t recvcount,
           int root,
           rmm::cuda_stream_view stream_view) const
  {
    // Nothing left to process.
  }
};

// Gather overload chosen when is_discard_iterator<OutputIterator>::value is true.
// The body is empty: no communicator call is made and all arguments are ignored.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<is_discard_iterator<OutputIterator>::value, void> device_gather_impl(
  raft::comms::comms_t const& comm,
  InputIterator input_first,
  OutputIterator output_first,
  size_t sendcount,
  int root,
  rmm::cuda_stream_view stream_view)
{
  // Intentionally a no-op.
  return;
}

// Gather overload chosen when the output iterator's value type is arithmetic.
//
// Checks at compile time that the input and output value types are identical, converts both
// iterators with iter_to_raw_ptr(), and forwards them together with sendcount, root and
// stream_view.value() to comm.gather().
// NOTE(review): the meaning of sendcount/root is defined by comm.gather() — confirm there.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  std::is_arithmetic<typename std::iterator_traits<OutputIterator>::value_type>::value,
  void>
device_gather_impl(raft::comms::comms_t const& comm,
                   InputIterator input_first,
                   OutputIterator output_first,
                   size_t sendcount,
                   int root,
                   rmm::cuda_stream_view stream_view)
{
  using input_value_t  = typename std::iterator_traits<InputIterator>::value_type;
  using output_value_t = typename std::iterator_traits<OutputIterator>::value_type;
  static_assert(std::is_same_v<input_value_t, output_value_t>);

  auto const send_ptr = iter_to_raw_ptr(input_first);
  auto const recv_ptr = iter_to_raw_ptr(output_first);
  auto const stream   = stream_view.value();

  comm.gather(send_ptr, recv_ptr, sendcount, root, stream);
}

// Compile-time recursion over the component iterators of a pair of tuple iterators.
//
// Step I picks component I out of input_first.get_iterator_tuple() and
// output_first.get_iterator_tuple(), hands the pair to device_gather_impl(), and then runs
// step I + 1 with the same arguments. The <..., I, I> partial specialization stops the recursion.
template <typename InputIterator, typename OutputIterator, size_t I, size_t N>
struct device_gather_tuple_iterator_element_impl {
  void run(raft::comms::comms_t const& comm,
           InputIterator input_first,
           OutputIterator output_first,
           size_t sendcount,
           int root,
           rmm::cuda_stream_view stream_view) const
  {
    using next_step =
      device_gather_tuple_iterator_element_impl<InputIterator, OutputIterator, I + 1, N>;

    // Copy the I-th component iterators out of the tuples before dispatching.
    auto const input_element  = cuda::std::get<I>(input_first.get_iterator_tuple());
    auto const output_element = cuda::std::get<I>(output_first.get_iterator_tuple());
    device_gather_impl(comm, input_element, output_element, sendcount, root, stream_view);

    // Continue with the remaining components.
    next_step{}.run(comm, input_first, output_first, sendcount, root, stream_view);
  }
};

// Partial specialization for the case where both index parameters are equal (I == N):
// run() does nothing, which terminates the per-component recursion of the primary template.
template <typename InputIterator, typename OutputIterator, size_t I>
struct device_gather_tuple_iterator_element_impl<InputIterator, OutputIterator, I, I> {
  void run(raft::comms::comms_t const& comm,
           InputIterator input_first,
           OutputIterator output_first,
           size_t sendcount,
           int root,
           rmm::cuda_stream_view stream_view) const
  {
    // Nothing left to process.
  }
};

template <typename InputIterator, typename OutputIterator>
std::enable_if_t<is_discard_iterator<OutputIterator>::value, void> device_gatherv_impl(
raft::comms::comms_t const& comm,
Expand Down Expand Up @@ -1222,6 +1407,131 @@ device_allgatherv(raft::comms::comms_t const& comm,
.run(comm, input_first, output_first, recvcounts, displacements, stream_view);
}

// Public scatter entry point, enabled when the input iterator's value type is arithmetic.
// All arguments are passed through unchanged to detail::device_scatter_impl(), which selects
// the concrete behavior based on the output iterator type.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  std::is_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value,
  void>
device_scatter(raft::comms::comms_t const& comm,
               InputIterator input_first,
               OutputIterator output_first,
               size_t recvcount,
               int root,
               rmm::cuda_stream_view stream_view)
{
  return detail::device_scatter_impl(
    comm, input_first, output_first, recvcount, root, stream_view);
}

// Public scatter entry point for tuple iterators. Enabled when the input value type satisfies
// is_thrust_tuple_of_arithmetic and the output value type satisfies is_thrust_tuple.
//
// Requires (at compile time) that both value types have the same tuple size, then starts the
// per-component recursion of detail::device_scatter_tuple_iterator_element_impl at index 0.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  is_thrust_tuple_of_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value &&
    is_thrust_tuple<typename std::iterator_traits<OutputIterator>::value_type>::value,
  void>
device_scatter(raft::comms::comms_t const& comm,
               InputIterator input_first,
               OutputIterator output_first,
               size_t recvcount,
               int root,
               rmm::cuda_stream_view stream_view)
{
  using input_tuple_t  = typename thrust::iterator_traits<InputIterator>::value_type;
  using output_tuple_t = typename thrust::iterator_traits<OutputIterator>::value_type;

  constexpr size_t element_count = cuda::std::tuple_size<input_tuple_t>::value;
  static_assert(element_count == cuda::std::tuple_size<output_tuple_t>::value);

  using first_step = detail::
    device_scatter_tuple_iterator_element_impl<InputIterator, OutputIterator, size_t{0}, element_count>;
  first_step{}.run(comm, input_first, output_first, recvcount, root, stream_view);
}

// Public scatterv entry point, enabled when the input iterator's value type is arithmetic.
// All arguments are passed through unchanged to detail::device_scatterv_impl(), which selects
// the concrete behavior based on the output iterator type.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  std::is_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value,
  void>
device_scatterv(raft::comms::comms_t const& comm,
                InputIterator input_first,
                OutputIterator output_first,
                raft::host_span<size_t const> sendcounts,
                raft::host_span<size_t const> displacements,
                size_t recvcount,
                int root,
                rmm::cuda_stream_view stream_view)
{
  return detail::device_scatterv_impl(comm,
                                      input_first,
                                      output_first,
                                      sendcounts,
                                      displacements,
                                      recvcount,
                                      root,
                                      stream_view);
}

// Public scatterv entry point for tuple iterators. Enabled when the input value type satisfies
// is_thrust_tuple_of_arithmetic and the output value type satisfies is_thrust_tuple.
//
// Requires (at compile time) that both value types have the same tuple size, then starts the
// per-component recursion of detail::device_scatterv_tuple_iterator_element_impl at index 0.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  is_thrust_tuple_of_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value &&
    is_thrust_tuple<typename std::iterator_traits<OutputIterator>::value_type>::value,
  void>
device_scatterv(raft::comms::comms_t const& comm,
                InputIterator input_first,
                OutputIterator output_first,
                raft::host_span<size_t const> sendcounts,
                raft::host_span<size_t const> displacements,
                size_t recvcount,
                int root,
                rmm::cuda_stream_view stream_view)
{
  using input_tuple_t  = typename thrust::iterator_traits<InputIterator>::value_type;
  using output_tuple_t = typename thrust::iterator_traits<OutputIterator>::value_type;

  constexpr size_t element_count = cuda::std::tuple_size<input_tuple_t>::value;
  static_assert(element_count == cuda::std::tuple_size<output_tuple_t>::value);

  using first_step = detail::
    device_scatterv_tuple_iterator_element_impl<InputIterator, OutputIterator, size_t{0}, element_count>;
  first_step{}.run(
    comm, input_first, output_first, sendcounts, displacements, recvcount, root, stream_view);
}

// Public gather entry point, enabled when the input iterator's value type is arithmetic.
// All arguments are passed through unchanged to detail::device_gather_impl(), which selects
// the concrete behavior based on the output iterator type.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  std::is_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value,
  void>
device_gather(raft::comms::comms_t const& comm,
              InputIterator input_first,
              OutputIterator output_first,
              size_t sendcount,
              int root,
              rmm::cuda_stream_view stream_view)
{
  return detail::device_gather_impl(
    comm, input_first, output_first, sendcount, root, stream_view);
}

// Public gather entry point for tuple iterators. Enabled when the input value type satisfies
// is_thrust_tuple_of_arithmetic and the output value type satisfies is_thrust_tuple.
//
// Requires (at compile time) that both value types have the same tuple size, then starts the
// per-component recursion of detail::device_gather_tuple_iterator_element_impl at index 0.
template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
  is_thrust_tuple_of_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value &&
    is_thrust_tuple<typename std::iterator_traits<OutputIterator>::value_type>::value,
  void>
device_gather(raft::comms::comms_t const& comm,
              InputIterator input_first,
              OutputIterator output_first,
              size_t sendcount,
              int root,
              rmm::cuda_stream_view stream_view)
{
  using input_tuple_t  = typename thrust::iterator_traits<InputIterator>::value_type;
  using output_tuple_t = typename thrust::iterator_traits<OutputIterator>::value_type;

  constexpr size_t element_count = cuda::std::tuple_size<input_tuple_t>::value;
  static_assert(element_count == cuda::std::tuple_size<output_tuple_t>::value);

  using first_step = detail::
    device_gather_tuple_iterator_element_impl<InputIterator, OutputIterator, size_t{0}, element_count>;
  first_step{}.run(comm, input_first, output_first, sendcount, root, stream_view);
}

template <typename InputIterator, typename OutputIterator>
std::enable_if_t<
std::is_arithmetic<typename std::iterator_traits<InputIterator>::value_type>::value,
Expand Down
Loading