diff --git a/cpp/include/cugraph/utilities/device_comm.hpp b/cpp/include/cugraph/utilities/device_comm.hpp index c68f6d4859..730a83589e 100644 --- a/cpp/include/cugraph/utilities/device_comm.hpp +++ b/cpp/include/cugraph/utilities/device_comm.hpp @@ -349,29 +349,10 @@ device_alltoall_impl(raft::comms::comms_t const& comm, using value_type = typename std::iterator_traits::value_type; static_assert( std::is_same_v::value_type, value_type>); -#if 1 // FIXME: we should add comm.device_alltoall to raft (which calls ncclAlltoAll) - std::vector sizes(comm.get_size(), count_per_rank); - std::vector displs(comm.get_size()); - for (size_t i = 0; i < displs.size(); ++i) { - displs[i] = i * count_per_rank; - } - std::vector ranks(comm.get_size()); - std::iota(ranks.begin(), ranks.end(), int{0}); - comm.device_multicast_sendrecv(iter_to_raw_ptr(input_first), - sizes, - displs, - ranks, - iter_to_raw_ptr(output_first), - sizes, - displs, - ranks, - stream_view.value()); -#else - comm.device_alltoall(iter_to_raw_ptr(input_first), - iter_to_raw_ptr(output_first), - count_per_rank, - stream_view.value()); -#endif + comm.alltoall(iter_to_raw_ptr(input_first), + iter_to_raw_ptr(output_first), + count_per_rank, + stream_view.value()); } template @@ -716,6 +697,210 @@ struct device_allgatherv_tuple_iterator_element_impl +std::enable_if_t::value, void> device_scatter_impl( + raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + // no-op +} + +template +std::enable_if_t< + std::is_arithmetic::value_type>::value, + void> +device_scatter_impl(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + static_assert(std::is_same_v::value_type, + typename std::iterator_traits::value_type>); + comm.scatter(iter_to_raw_ptr(input_first), + iter_to_raw_ptr(output_first), + recvcount, + root, + stream_view.value()); +} + +template +struct device_scatter_tuple_iterator_element_impl { + void run(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) const + { + device_scatter_impl(comm, + cuda::std::get(input_first.get_iterator_tuple()), + cuda::std::get(output_first.get_iterator_tuple()), + recvcount, + root, + stream_view); + device_scatter_tuple_iterator_element_impl().run( + comm, input_first, output_first, recvcount, root, stream_view); + } +}; + +template +struct device_scatter_tuple_iterator_element_impl { + void run(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) const + { + } +}; + +template +std::enable_if_t::value, void> device_scatterv_impl( + raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + raft::host_span sendcounts, + raft::host_span displacements, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + // no-op +} + +template +std::enable_if_t< + std::is_arithmetic::value_type>::value, + void> +device_scatterv_impl(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + raft::host_span sendcounts, + raft::host_span displacements, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + static_assert(std::is_same_v::value_type, + typename std::iterator_traits::value_type>); + comm.scatterv(iter_to_raw_ptr(input_first), + iter_to_raw_ptr(output_first), + sendcounts.data(), + displacements.data(), + recvcount, + root, + stream_view.value()); +} + +template +struct device_scatterv_tuple_iterator_element_impl { + void run(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + raft::host_span sendcounts, + raft::host_span displacements, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) const + { + device_scatterv_impl(comm, + cuda::std::get(input_first.get_iterator_tuple()), + cuda::std::get(output_first.get_iterator_tuple()), + sendcounts, + displacements, + recvcount, + root, + stream_view); + device_scatterv_tuple_iterator_element_impl().run( + comm, input_first, output_first, sendcounts, displacements, recvcount, root, stream_view); + } +}; + +template +struct device_scatterv_tuple_iterator_element_impl { + void run(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + raft::host_span sendcounts, + raft::host_span displacements, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) const + { + } +}; + +template +std::enable_if_t::value, void> device_gather_impl( + raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t sendcount, + int root, + rmm::cuda_stream_view stream_view) +{ + // no-op +} + +template +std::enable_if_t< + std::is_arithmetic::value_type>::value, + void> +device_gather_impl(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t sendcount, + int root, + rmm::cuda_stream_view stream_view) +{ + static_assert(std::is_same_v::value_type, + typename std::iterator_traits::value_type>); + comm.gather(iter_to_raw_ptr(input_first), + iter_to_raw_ptr(output_first), + sendcount, + root, + stream_view.value()); +} + +template +struct device_gather_tuple_iterator_element_impl { + void run(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t sendcount, + int root, + rmm::cuda_stream_view stream_view) const + { + device_gather_impl(comm, + cuda::std::get(input_first.get_iterator_tuple()), + cuda::std::get(output_first.get_iterator_tuple()), + sendcount, + root, + stream_view); + device_gather_tuple_iterator_element_impl().run( + comm, input_first, output_first, sendcount, root, stream_view); + } +}; + +template +struct device_gather_tuple_iterator_element_impl { + void run(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t sendcount, + int root, + rmm::cuda_stream_view stream_view) const + { + } +}; + template std::enable_if_t::value, void> device_gatherv_impl( raft::comms::comms_t const& comm, @@ -1222,6 +1407,131 @@ device_allgatherv(raft::comms::comms_t const& comm, .run(comm, input_first, output_first, recvcounts, displacements, stream_view); } +template +std::enable_if_t< + std::is_arithmetic::value_type>::value, + void> +device_scatter(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + detail::device_scatter_impl(comm, input_first, output_first, recvcount, root, stream_view); +} + +template +std::enable_if_t< + is_thrust_tuple_of_arithmetic::value_type>::value && + is_thrust_tuple::value_type>::value, + void> +device_scatter(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + static_assert( + cuda::std::tuple_size::value_type>::value == + cuda::std::tuple_size::value_type>::value); + + size_t constexpr tuple_size = + cuda::std::tuple_size::value_type>::value; + + detail::device_scatter_tuple_iterator_element_impl() + .run(comm, input_first, output_first, recvcount, root, stream_view); +} + +template +std::enable_if_t< + std::is_arithmetic::value_type>::value, + void> +device_scatterv(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + raft::host_span sendcounts, + raft::host_span displacements, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + detail::device_scatterv_impl( + comm, input_first, output_first, sendcounts, displacements, recvcount, root, stream_view); +} + +template +std::enable_if_t< + is_thrust_tuple_of_arithmetic::value_type>::value && + is_thrust_tuple::value_type>::value, + void> +device_scatterv(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + raft::host_span sendcounts, + raft::host_span displacements, + size_t recvcount, + int root, + rmm::cuda_stream_view stream_view) +{ + static_assert( + cuda::std::tuple_size::value_type>::value == + cuda::std::tuple_size::value_type>::value); + + size_t constexpr tuple_size = + cuda::std::tuple_size::value_type>::value; + + detail::device_scatterv_tuple_iterator_element_impl() + .run(comm, input_first, output_first, sendcounts, displacements, recvcount, root, stream_view); +} + +template +std::enable_if_t< + std::is_arithmetic::value_type>::value, + void> +device_gather(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t sendcount, + int root, + rmm::cuda_stream_view stream_view) +{ + detail::device_gather_impl(comm, input_first, output_first, sendcount, root, stream_view); +} + +template +std::enable_if_t< + is_thrust_tuple_of_arithmetic::value_type>::value && + is_thrust_tuple::value_type>::value, + void> +device_gather(raft::comms::comms_t const& comm, + InputIterator input_first, + OutputIterator output_first, + size_t sendcount, + int root, + rmm::cuda_stream_view stream_view) +{ + static_assert( + cuda::std::tuple_size::value_type>::value == + cuda::std::tuple_size::value_type>::value); + + size_t constexpr tuple_size = + cuda::std::tuple_size::value_type>::value; + + detail::device_gather_tuple_iterator_element_impl() + .run(comm, input_first, output_first, sendcount, root, stream_view); +} + template std::enable_if_t< std::is_arithmetic::value_type>::value,