diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 37b8375197..51a1ec1d56 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -164,6 +165,8 @@ void bitset_view::repeat(const raft::resources& res, index_t times, bitset_t* output_device_ptr) const { + // Only a copy and kernel run below this point. + if (resource::get_dry_run_flag(res)) { return; } constexpr index_t bits_per_element = sizeof(bitset_t) * 8; if (bitset_len_ % bits_per_element == 0) { diff --git a/cpp/include/raft/core/bitset.hpp b/cpp/include/raft/core/bitset.hpp index d6b3fb7b63..8b6f8ab70c 100644 --- a/cpp/include/raft/core/bitset.hpp +++ b/cpp/include/raft/core/bitset.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -131,9 +132,11 @@ struct bitset_view { auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); count(res, count_gpu_scalar.view()); index_t count_cpu = 0; - raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); + if (!resource::get_dry_run_flag(res)) { + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + } return count_cpu; } @@ -406,9 +409,11 @@ struct bitset { auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); count(res, count_gpu_scalar.view()); index_t count_cpu = 0; - raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); + if (!resource::get_dry_run_flag(res)) { + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); 
+ resource::sync_stream(res); + } return count_cpu; } /** diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp index 62ad6fda0a..0c8dff62c4 100644 --- a/cpp/include/raft/core/coo_matrix.hpp +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -179,8 +179,8 @@ class coordinate_structure : public coordinate_structure_tget_n_rows() + 1); - c_indices_.resize(nnz); + c_indptr_.reallocate(this->get_n_rows() + 1); + c_indices_.reallocate(nnz); } protected: diff --git a/cpp/include/raft/core/detail/copy.hpp b/cpp/include/raft/core/detail/copy.hpp index 8905f4c29b..058fcaac34 100644 --- a/cpp/include/raft/core/detail/copy.hpp +++ b/cpp/include/raft/core/detail/copy.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -398,6 +399,10 @@ mdspan_copyable_t copy(resources const& res, DstType&& dst, Sr RAFT_EXPECTS(src.extent(i) == dst.extent(i), "Must copy between mdspans of the same shape"); } + // Dry-run guard: raft::copy is a pure data-movement utility with no + // allocations that callers would need tracked. + if (resource::get_dry_run_flag(res)) { return; } + if constexpr (config::use_intermediate_src) { #ifndef RAFT_DISABLE_CUDA // Copy to intermediate source on device, then perform necessary diff --git a/cpp/include/raft/core/device_container_policy.hpp b/cpp/include/raft/core/device_container_policy.hpp index 9a9871a3ab..ff60d99c10 100644 --- a/cpp/include/raft/core/device_container_policy.hpp +++ b/cpp/include/raft/core/device_container_policy.hpp @@ -126,6 +126,29 @@ class device_uvector { void resize(size_type size) { data_.resize(size, data_.stream()); } + /** + * @brief Resize the internal buffer without copying old data. + * + * Unlike resize(), this never copies old data. 
+ * Thus, unlike in resize(), there's no point in time where the old and the new buffers are both + * alive, and the peak memory usage is lower. + * + * Unlike resize(), this deallocates the old buffer even if the new size is smaller. + * This ensures the memory is released promptly. + */ + void reallocate(size_type size) + { + if (size != data_.size()) { + auto stream = data_.stream(); + auto mr = data_.memory_resource(); + // Resize and shrink rmm::device_uvector: force deallocation without copying old data + data_.resize(0, data_.stream()); + data_.shrink_to_fit(data_.stream()); + // Assign a new value after the old one is deallocated + data_ = rmm::device_uvector(size, stream, mr); + } + } + [[nodiscard]] auto data() noexcept -> pointer { return data_.data(); } [[nodiscard]] auto data() const noexcept -> const_pointer { return data_.data(); } }; diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index c575546be2..40b48f0f6e 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -163,7 +164,7 @@ auto make_device_scalar(raft::resources const& handle, ElementType const& v) using policy_t = typename device_scalar::container_policy_type; policy_t policy{}; auto scalar = device_scalar{handle, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/host_container_policy.hpp b/cpp/include/raft/core/host_container_policy.hpp index 47db081771..87a0acea77 100644 --- a/cpp/include/raft/core/host_container_policy.hpp +++ b/cpp/include/raft/core/host_container_policy.hpp @@ -104,6 +104,27 @@ requires cuda::mr::synchronous_resource_with *this = std::move(new_container); } + /** + * @brief Resize the internal buffer without copying old data. + * + * Unlike resize(), this never copies old data. 
+ * Thus, unlike in resize(), there's no point in time where the old and the new buffers are both + * alive, and the peak memory usage is lower. + * + * Unlike resize(), this deallocates the old buffer even if the new size is smaller. + * This ensures the memory is released promptly. + */ + void reallocate(size_type count) + { + if (bytesize_ == sizeof(value_type) * count) { return; } + if (data_ != nullptr) { + mr_.deallocate_sync(data_, bytesize_); + data_ = nullptr; + } + auto tmp = host_container{count, mr_}; + std::swap(tmp, *this); + } + [[nodiscard]] auto data() noexcept -> pointer { return data_; } [[nodiscard]] auto data() const noexcept -> const_pointer { return data_; } }; diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 535d4f47bf..b0751b28d3 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -223,7 +224,7 @@ auto make_host_scalar(raft::resources const& res, ElementType const& v) using policy_t = typename host_scalar::container_policy_type; policy_t policy; auto scalar = host_scalar{res, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(res)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/managed_mdarray.hpp b/cpp/include/raft/core/managed_mdarray.hpp index 21db4b52aa..b52d7c6eba 100644 --- a/cpp/include/raft/core/managed_mdarray.hpp +++ b/cpp/include/raft/core/managed_mdarray.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -117,7 +118,7 @@ auto make_managed_scalar(raft::resources const& handle, ElementType const& v) using policy_t = typename managed_scalar::container_policy_type; policy_t policy{}; auto scalar = managed_scalar{handle, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/pinned_mdarray.hpp 
b/cpp/include/raft/core/pinned_mdarray.hpp index f01f00f897..3f1ae81244 100644 --- a/cpp/include/raft/core/pinned_mdarray.hpp +++ b/cpp/include/raft/core/pinned_mdarray.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -117,7 +118,7 @@ auto make_pinned_scalar(raft::resources const& handle, ElementType const& v) using policy_t = typename pinned_scalar::container_policy_type; policy_t policy{}; auto scalar = pinned_scalar{handle, extents, policy}; - scalar(0) = v; + if (!resource::get_dry_run_flag(handle)) { scalar(0) = v; } return scalar; } diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp index 690bd610f9..454082d7c3 100644 --- a/cpp/include/raft/core/resource/cuda_stream.hpp +++ b/cpp/include/raft/core/resource/cuda_stream.hpp @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include #include #include #include @@ -82,13 +83,18 @@ inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_v */ inline void sync_stream(const resources& res, rmm::cuda_stream_view stream) { + if (raft::resource::get_dry_run_flag(res)) { return; } interruptible::synchronize(stream); } /** * @brief synchronize main stream on the resources instance */ -inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream(res)); } +inline void sync_stream(const resources& res) +{ + if (raft::resource::get_dry_run_flag(res)) { return; } + sync_stream(res, get_cuda_stream(res)); +} /** * @} diff --git a/cpp/include/raft/core/resource/dry_run_flag.hpp b/cpp/include/raft/core/resource/dry_run_flag.hpp new file mode 100644 index 0000000000..4d0c9e27b5 --- /dev/null +++ b/cpp/include/raft/core/resource/dry_run_flag.hpp @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA 
CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +#include + +namespace raft::resource { + +/** + * @defgroup dry_run_flag Dry-run flag resource + * @{ + */ + +/** + * @brief Resource that holds a boolean dry-run flag. + * + * When the dry-run flag is set, algorithms should skip kernel execution + * and only perform allocations to measure memory usage. + */ +class dry_run_flag_resource : public resource { + public: + dry_run_flag_resource() = default; + explicit dry_run_flag_resource(bool value) : flag_(value) {} + ~dry_run_flag_resource() override = default; + + auto get_resource() -> void* override { return &flag_; } + + void set(bool value) { flag_ = value; } + [[nodiscard]] auto get() const -> bool { return flag_; } + + private: + bool flag_{false}; +}; + +/** + * @brief Factory that creates a dry_run_flag_resource. + */ +class dry_run_flag_resource_factory : public resource_factory { + public: + explicit dry_run_flag_resource_factory(bool initial_value = false) : initial_value_(initial_value) + { + } + + auto get_resource_type() -> resource_type override { return resource_type::DRY_RUN_FLAG; } + auto make_resource() -> resource* override { return new dry_run_flag_resource(initial_value_); } + + private: + bool initial_value_; +}; + +/** + * @brief Get the dry-run flag from a resources handle. + * + * @param res raft resources object + * @return true if dry-run mode is active + */ +inline auto get_dry_run_flag(resources const& res) -> bool +{ + if (!res.has_resource_factory(resource_type::DRY_RUN_FLAG)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::DRY_RUN_FLAG); +} + +/** + * @brief Set the dry-run flag on a resources handle. 
+ * + * @param res raft resources object + * @param value true to enable dry-run mode, false to disable + */ +inline void set_dry_run_flag(resources const& res, bool value) +{ + if (!res.has_resource_factory(resource_type::DRY_RUN_FLAG)) { + res.add_resource_factory(std::make_shared(value)); + } else { + // The resource may already be instantiated; update it directly + auto* flag = res.get_resource(resource_type::DRY_RUN_FLAG); + *flag = value; + } +} + +/** @} */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index cda3c8ecae..105adc4018 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -40,6 +40,7 @@ enum resource_type { MULTI_GPU, // resource that tracks resource of each device in multi-gpu world PINNED_MEMORY_RESOURCE, // memory resource for pinned (page-locked) host allocations MANAGED_MEMORY_RESOURCE, // resource for managed (unified) allocations + DRY_RUN_FLAG, // dry-run mode flag for allocation profiling LAST_KEY // reserved for the last key }; diff --git a/cpp/include/raft/core/sparse_types.hpp b/cpp/include/raft/core/sparse_types.hpp index c0de7ca673..acdf14cf2a 100644 --- a/cpp/include/raft/core/sparse_types.hpp +++ b/cpp/include/raft/core/sparse_types.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -177,7 +177,7 @@ class sparse_matrix { ~sparse_matrix() noexcept(std::is_nothrow_destructible::value) = default; - void initialize_sparsity(nnz_type nnz) { c_elements_.resize(nnz); }; + void initialize_sparsity(nnz_type nnz) { c_elements_.reallocate(nnz); }; raft::span get_elements() { diff --git a/cpp/include/raft/label/classlabels.cuh b/cpp/include/raft/label/classlabels.cuh index 6e299182da..66a3af3b52 100644 --- a/cpp/include/raft/label/classlabels.cuh +++ b/cpp/include/raft/label/classlabels.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2022, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __CLASS_LABELS_H @@ -7,11 +7,37 @@ #pragma once +#include +#include +#include #include namespace raft { namespace label { +/** + * Get unique class labels. + * + * The y array is assumed to store class labels. The unique values are selected + * from this array. + * + * @tparam value_t numeric type of the arrays with class labels + * @param [in] handle raft resources handle (dry-run aware) + * @param [inout] unique output unique labels + * @param [in] y device array of labels, size [n] + * @param [in] n number of labels + * @returns number of unique labels (upper bound in dry-run mode) + */ +template +int getUniquelabels(raft::resources const& handle, + rmm::device_uvector& unique, + value_t* y, + size_t n) +{ + return detail::getUniquelabels( + resource::get_dry_run_flag(handle), unique, y, n, resource::get_cuda_stream(handle)); +} + /** * Get unique class labels. * diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index 8b0a296eb3..ab66d82a32 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -29,15 +29,17 @@ namespace detail { * from this array. 
* * \tparam value_t numeric type of the arrays with class labels - * \param [in] y device array of labels, size [n] - * \param [in] n number of labels + * \param [in] dry_run if true, perform allocations but skip CUDA work * \param [out] unique device array of unique labels, unallocated on entry, * on exit it has size [n_unique] - * \param [out] n_unique number of unique labels + * \param [in] y device array of labels, size [n] + * \param [in] n number of labels * \param [in] stream cuda stream + * \return number of unique labels (upper bound when dry_run is true) */ template -int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, cudaStream_t stream) +int getUniquelabels( + bool dry_run, rmm::device_uvector& unique, value_t* y, size_t n, cudaStream_t stream) { rmm::device_scalar d_num_selected(stream); rmm::device_uvector workspace(n, stream); @@ -53,6 +55,11 @@ int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, bytes = std::max(bytes, bytes2); rmm::device_uvector cub_storage(bytes, stream); + if (dry_run) { + if (unique.size() < n) { unique = rmm::device_uvector(n, stream); } + return static_cast(n); + } + // Select Unique classes cub::DeviceRadixSort::SortKeys( cub_storage.data(), bytes, y, workspace.data(), n, 0, sizeof(value_t) * 8, stream); @@ -72,6 +79,26 @@ int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, return n_unique; } +/** + * Get unique class labels. + * + * The y array is assumed to store class labels. The unique values are selected + * from this array. 
+ * + * \tparam value_t numeric type of the arrays with class labels + * \param [out] unique device array of unique labels, unallocated on entry, + * on exit it has size [n_unique] + * \param [in] y device array of labels, size [n] + * \param [in] n number of labels + * \param [in] stream cuda stream + * \return number of unique labels + */ +template +int getUniquelabels(rmm::device_uvector& unique, value_t* y, size_t n, cudaStream_t stream) +{ + return getUniquelabels(false, unique, y, n, stream); +} + /** * Assign one versus rest labels. * diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index b1953470b0..c0e086f43f 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __ADD_H @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace raft { @@ -102,6 +103,7 @@ template > void add(raft::resources const& handle, InType in1, InType in2, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -139,6 +141,7 @@ void add_scalar(raft::resources const& handle, OutType out, raft::device_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -174,6 +177,7 @@ void add_scalar(raft::resources const& handle, OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 3ed5ed7736..ca6548f28b 100644 --- 
a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __COALESCED_REDUCTION_H @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace raft { @@ -62,7 +63,7 @@ void coalescedReduction(OutType* dots, FinalLambda final_op = raft::identity_op()) { detail::coalescedReduction( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + false, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } /** @@ -120,30 +121,32 @@ void coalesced_reduction(raft::resources const& handle, RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), "Output should be equal to number of rows in Input"); - coalescedReduction(dots.data_handle(), - data.data_handle(), - data.extent(1), - data.extent(0), - init, - resource::get_cuda_stream(handle), - inplace, - main_op, - reduce_op, - final_op); + detail::coalescedReduction(resource::get_dry_run_flag(handle), + dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + resource::get_cuda_stream(handle), + inplace, + main_op, + reduce_op, + final_op); } else if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), "Output should be equal to number of columns in Input"); - coalescedReduction(dots.data_handle(), - data.data_handle(), - data.extent(0), - data.extent(1), - init, - resource::get_cuda_stream(handle), - inplace, - main_op, - reduce_op, - final_op); + detail::coalescedReduction(resource::get_dry_run_flag(handle), + dots.data_handle(), + data.data_handle(), + data.extent(0), + data.extent(1), + init, + resource::get_cuda_stream(handle), + inplace, + main_op, + reduce_op, + final_op); } } diff --git a/cpp/include/raft/linalg/detail/axpy.cuh 
b/cpp/include/raft/linalg/detail/axpy.cuh index 1ab690937d..40634b6428 100644 --- a/cpp/include/raft/linalg/detail/axpy.cuh +++ b/cpp/include/raft/linalg/detail/axpy.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include "cublas_wrappers.hpp" #include +#include #include #include @@ -24,6 +25,7 @@ void axpy(raft::resources const& handle, const int incy, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } auto cublas_h = resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream)); diff --git a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh index b05449f90a..d997377d54 100644 --- a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -53,6 +54,7 @@ void choleskyRank1Update(raft::resources const& handle, *n_bytes = offset + 1 * sizeof(math_t); return; } + if (resource::get_dry_run_flag(handle)) { return; } math_t* s = reinterpret_cast(((char*)workspace) + offset); math_t* L_22 = L + (n - 1) * ld + n - 1; diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index 2d513b433d..f44aa48cfb 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -498,7 +498,8 @@ template -void coalescedReductionThick(OutType* dots, +void coalescedReductionThick(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -517,6 +518,8 @@ void coalescedReductionThick(OutType* dots, rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); + if (dry_run) { return; } + /* We apply a two-step reduction: * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It * applies the main_op but not the final op. @@ -550,7 +553,8 @@ template -void coalescedReductionThickDispatcher(OutType* dots, +void coalescedReductionThickDispatcher(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -564,7 +568,7 @@ void coalescedReductionThickDispatcher(OutType* dots, // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + dry_run, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } // Primitive to perform reductions along the coalesced dimension of the matrix, i.e. 
reduce along @@ -579,7 +583,8 @@ template -void coalescedReduction(OutType* dots, +void coalescedReduction(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -600,12 +605,16 @@ void coalescedReduction(OutType* dots, */ const IdxType numSMs = raft::getMultiProcessorCount(); if (D <= IdxType(512) || (N >= IdxType(16) * numSMs && D < IdxType(2048))) { + if (dry_run) { return; } coalescedReductionThinDispatcher( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (N < numSMs && D >= IdxType(1 << 17)) { + // Must call through to coalescedReductionThick even in dry-run so workspace + // allocations are recorded (coalescedReductionThick allocates before guarding). coalescedReductionThickDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + dry_run, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { + if (dry_run) { return; } coalescedReductionMediumDispatcher( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } diff --git a/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp b/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp index 469780ba1f..3ffa4ded84 100644 --- a/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cublaslt_wrappers.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -282,6 +283,8 @@ template batch_scope( "linalg::matmul(m = %d, n = %d, k = %d)", m, n, k); std::shared_ptr mm_desc{nullptr}; diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index 5b64add128..d8d31fc411 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -44,9 +45,13 @@ void eigDC_legacy(raft::resources const& handle, eig_vals, &lwork)); + // TODO(achirkin): Consider using the workspace resource for these temporary allocations. rmm::device_uvector d_work(lwork, stream); rmm::device_scalar d_dev_info(stream); + // The workspace is already allocated, no more allocation are foreseeable. + if (resource::get_dry_run_flag(handle)) { return; } + raft::matrix::copy(handle, make_device_matrix_view(in, n_rows, n_cols), make_device_matrix_view(eig_vectors, n_rows, n_cols)); @@ -115,6 +120,12 @@ void eigDC(raft::resources const& handle, rmm::device_scalar d_dev_info(stream_new); std::vector h_work(workspaceHost / sizeof(math_t)); + if (resource::get_dry_run_flag(handle)) { + // No more allocations beyond this points, but need to cleanup. + RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params)); + return; + } + raft::copy(eig_vectors, in, n_rows * n_cols, stream_new); RAFT_CUSOLVER_TRY(cusolverDnxsyevd(cusolverH, @@ -181,7 +192,9 @@ void eigSelDC(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar d_dev_info(stream); - rmm::device_uvector d_eig_vectors(0, stream); + rmm::device_uvector d_eig_vectors(memUsage == COPY_INPUT ? 
n_rows * n_cols : 0, stream); + + if (resource::get_dry_run_flag(handle)) { return; } if (memUsage == OVERWRITE_INPUT) { RAFT_CUSOLVER_TRY(cusolverDnsyevdx(cusolverH, @@ -202,7 +215,6 @@ void eigSelDC(raft::resources const& handle, d_dev_info.data(), stream)); } else if (memUsage == COPY_INPUT) { - d_eig_vectors.resize(n_rows * n_cols, stream); raft::matrix::copy(handle, make_device_matrix_view(in, n_rows, n_cols), make_device_matrix_view(eig_vectors, n_rows, n_cols)); @@ -279,6 +291,12 @@ void eigJacobi(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar dev_info(stream); + if (resource::get_dry_run_flag(handle)) { + // No more allocations beyond this point, but need to cleanup. + RAFT_CUSOLVER_TRY(cusolverDnDestroySyevjInfo(syevj_params)); + return; + } + raft::matrix::copy(handle, make_device_matrix_view(in, n_rows, n_cols), make_device_matrix_view(eig_vectors, n_rows, n_cols)); diff --git a/cpp/include/raft/linalg/detail/gemv.hpp b/cpp/include/raft/linalg/detail/gemv.hpp index 3233940a66..905ecab0c5 100644 --- a/cpp/include/raft/linalg/detail/gemv.hpp +++ b/cpp/include/raft/linalg/detail/gemv.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include "cublas_wrappers.hpp" #include +#include #include #include @@ -31,6 +32,7 @@ void gemv(raft::resources const& handle, const int incy, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } cublasHandle_t cublas_h = resource::get_cublas_handle(handle); detail::cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(detail::cublasgemv(cublas_h, @@ -109,6 +111,7 @@ void gemv(raft::resources const& handle, const math_t beta, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } cublasHandle_t cublas_h = resource::get_cublas_handle(handle); cublasOperation_t op_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; RAFT_CUBLAS_TRY( diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 176c7763ba..1d37119a45 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -130,6 +131,9 @@ void lstsqSvdQR(raft::resources const& handle, + 1 // devInfo , stream); + + if (resource::get_dry_run_flag(handle)) { return; } + math_t* cusolverWorkSet = workset.data(); math_t* U = cusolverWorkSet + cusolverWorkSetSize; math_t* Vt = U + n_rows * minmn; @@ -204,6 +208,12 @@ void lstsqSvdJacobi(raft::resources const& handle, + 1 // devInfo , stream); + + if (resource::get_dry_run_flag(handle)) { + RAFT_CUSOLVER_TRY(cusolverDnDestroyGesvdjInfo(gesvdj_params)); + return; + } + math_t* cusolverWorkSet = workset.data(); math_t* U = cusolverWorkSet + cusolverWorkSetSize; math_t* V = U + n_rows * minmn; @@ -248,21 +258,27 @@ void lstsqEig(raft::resources const& handle, { rmm::cuda_stream_view mainStream = rmm::cuda_stream_view(stream); rmm::cuda_stream_view multAbStream = resource::get_next_usable_stream(handle); + bool dry_run = resource::get_dry_run_flag(handle); bool concurrent; - // Check if the two streams can run 
concurrently. This is needed because a legacy default stream - // would synchronize with other blocking streams. To avoid synchronization in such case, we try to - // use an additional stream from the pool. - if (!are_implicitly_synchronized(mainStream, multAbStream)) { - concurrent = true; - } else if (resource::get_stream_pool_size(handle) > 1) { - mainStream = resource::get_next_usable_stream(handle); - concurrent = true; + if (dry_run) { + concurrent = false; } else { - multAbStream = mainStream; - concurrent = false; + // Check if the two streams can run concurrently. This is needed because a legacy default stream + // would synchronize with other blocking streams. To avoid synchronization in such case, we try + // to use an additional stream from the pool. + if (!are_implicitly_synchronized(mainStream, multAbStream)) { + concurrent = true; + } else if (resource::get_stream_pool_size(handle) > 1) { + mainStream = resource::get_next_usable_stream(handle); + concurrent = true; + } else { + multAbStream = mainStream; + concurrent = false; + } } rmm::device_uvector workset(n_cols * n_cols * 3 + n_cols * 2, mainStream); + // the event is created only if the given raft handle is capable of running // at least two CUDA streams without implicit synchronization. 
DeviceEvent worksetDone(concurrent); @@ -302,8 +318,8 @@ void lstsqEig(raft::resources const& handle, raft::common::nvtx::pop_range(); // QS <- Q invS - raft::linalg::matrixVectorOp( - QS, Q, S, n_cols, n_cols, DivideByNonZero(), mainStream); + raft::linalg::detail::matrixVectorOp( + dry_run, QS, Q, S, n_cols, n_cols, DivideByNonZero(), mainStream); // covA <- QS Q* == Q invS Q* == inv(A* A) raft::linalg::gemm(handle, QS, @@ -392,6 +408,8 @@ void lstsqQR(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); + if (resource::get_dry_run_flag(handle)) { return; } + // #TODO: Call from public API when ready RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngeqrf( cusolverH, m, n, A, lda, d_tau.data(), d_work.data(), lwork, d_info.data(), stream)); diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index 3153de5396..5678f8e39b 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -206,6 +207,7 @@ template > void map(const raft::resources& res, OutType out, Func f, InTypes... ins) { + if (resource::get_dry_run_flag(res)) { return; } RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); (map_check_shape(out, ins), ...); diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index 64de01a3fe..3275410bac 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -1,11 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include #include namespace raft { @@ -19,7 +20,8 @@ template -void matrixVectorOp(MatT* out, +void matrixVectorOp(bool dry_run, + MatT* out, const MatT* matrix, const VecT* vec, IdxType D, @@ -27,6 +29,7 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { + if (dry_run) { return; } raft::resources handle; resource::set_cuda_stream(handle, stream); constexpr raft::Apply apply = @@ -56,7 +59,8 @@ template -void matrixVectorOp(MatT* out, +void matrixVectorOp(bool dry_run, + MatT* out, const MatT* matrix, const Vec1T* vec1, const Vec2T* vec2, @@ -65,6 +69,7 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { + if (dry_run) { return; } raft::resources handle; resource::set_cuda_stream(handle, stream); constexpr raft::Apply apply = diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index ea7f5c8d28..549ecda0f5 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -19,18 +19,23 @@ template -void rowNormCaller( - OutType* dots, const Type* data, IdxType D, IdxType N, cudaStream_t stream, Lambda fin_op) +void rowNormCaller(bool dry_run, + OutType* dots, + const Type* data, + IdxType D, + IdxType N, + cudaStream_t stream, + Lambda fin_op) { if constexpr (norm_type == L1Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == L2Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == LinfNorm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); } else { THROW("Unsupported norm type: %d", norm_type); } @@ -42,18 +47,23 @@ template -void colNormCaller( - OutType* dots, const Type* data, IdxType D, IdxType N, cudaStream_t stream, Lambda fin_op) +void colNormCaller(bool dry_run, + OutType* dots, + const Type* data, + IdxType D, + IdxType N, + cudaStream_t stream, + Lambda fin_op) { if constexpr (norm_type == L1Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == L2Norm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::sq_op(), raft::add_op(), fin_op); } else if constexpr (norm_type == 
LinfNorm) { - raft::linalg::reduce( - dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); + reduce( + dry_run, dots, data, D, N, (OutType)0, stream, false, raft::abs_op(), raft::max_op(), fin_op); } else { THROW("Unsupported norm type: %d", norm_type); } diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 63cba5d73c..14b453203c 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,6 +9,7 @@ #include "cusolver_wrappers.hpp" #include +#include #include #include @@ -39,15 +40,26 @@ void qrGetQ_inplace( { RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle); + auto is_dry_run = resource::get_dry_run_flag(handle); rmm::device_uvector tau(n_cols, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); + } rmm::device_scalar dev_info(stream); - int ws_size; + int ws_size_Dngeqrf; + int ws_size_Dnorgqr; + + RAFT_CUSOLVER_TRY( + cusolverDngeqrf_bufferSize(cusolver, n_rows, n_cols, Q, n_rows, &ws_size_Dngeqrf)); + RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize( + cusolver, n_rows, n_cols, n_cols, Q, n_rows, tau.data(), &ws_size_Dnorgqr)); + + rmm::device_uvector workspace(std::max(ws_size_Dngeqrf, ws_size_Dnorgqr), stream); + + if (is_dry_run) { return; } - RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize(cusolver, n_rows, n_cols, Q, n_rows, &ws_size)); - rmm::device_uvector workspace(ws_size, stream); RAFT_CUSOLVER_TRY(cusolverDngeqrf(cusolver, n_rows, n_cols,
workspace.data(), - ws_size, + ws_size_Dngeqrf, dev_info.data(), stream)); - RAFT_CUSOLVER_TRY( - cusolverDnorgqr_bufferSize(cusolver, n_rows, n_cols, n_cols, Q, n_rows, tau.data(), &ws_size)); - workspace.resize(ws_size, stream); RAFT_CUSOLVER_TRY(cusolverDnorgqr(cusolver, n_rows, n_cols, @@ -70,7 +79,7 @@ void qrGetQ_inplace( n_rows, tau.data(), workspace.data(), - ws_size, + ws_size_Dnorgqr, dev_info.data(), stream)); } @@ -83,7 +92,7 @@ void qrGetQ(raft::resources const& handle, int n_cols, cudaStream_t stream) { - raft::copy(Q, M, n_rows * n_cols, stream); + if (!resource::get_dry_run_flag(handle)) { raft::copy(Q, M, n_rows * n_cols, stream); } qrGetQ_inplace(handle, Q, n_rows, n_cols, stream); } @@ -99,19 +108,32 @@ void qrGetQR(raft::resources const& handle, cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int m = n_rows, n = n_cols; + int R_full_nrows = m, R_full_ncols = n; + int Q_nrows = m, Q_ncols = n; + int Lwork_Dngeqrf, Lwork_Dnorgqr; rmm::device_uvector R_full(m * n, stream); rmm::device_uvector tau(std::min(m, n), stream); + rmm::device_scalar devInfo(stream); + + RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize( + cusolverH, R_full_nrows, R_full_ncols, R_full.data(), R_full_nrows, &Lwork_Dngeqrf)); + RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize(cusolverH, + Q_nrows, + Q_ncols, + std::min(Q_ncols, Q_nrows), + Q, + Q_nrows, + tau.data(), + &Lwork_Dnorgqr)); + + rmm::device_uvector workspace(std::max(Lwork_Dngeqrf, Lwork_Dnorgqr), stream); + + if (resource::get_dry_run_flag(handle)) { return; } + RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * std::min(m, n), stream)); - int R_full_nrows = m, R_full_ncols = n; RAFT_CUDA_TRY( cudaMemcpyAsync(R_full.data(), M, sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); - int Lwork; - rmm::device_scalar devInfo(stream); - - RAFT_CUSOLVER_TRY(cusolverDngeqrf_bufferSize( - cusolverH, R_full_nrows, R_full_ncols, R_full.data(), R_full_nrows, &Lwork)); - rmm::device_uvector 
workspace(Lwork, stream); RAFT_CUSOLVER_TRY(cusolverDngeqrf(cusolverH, R_full_nrows, R_full_ncols, @@ -119,7 +141,7 @@ void qrGetQR(raft::resources const& handle, R_full_nrows, tau.data(), workspace.data(), - Lwork, + Lwork_Dngeqrf, devInfo.data(), stream)); @@ -130,11 +152,7 @@ void qrGetQR(raft::resources const& handle, RAFT_CUDA_TRY( cudaMemcpyAsync(Q, R_full.data(), sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); - int Q_nrows = m, Q_ncols = n; - RAFT_CUSOLVER_TRY(cusolverDnorgqr_bufferSize( - cusolverH, Q_nrows, Q_ncols, std::min(Q_ncols, Q_nrows), Q, Q_nrows, tau.data(), &Lwork)); - workspace.resize(Lwork, stream); RAFT_CUSOLVER_TRY(cusolverDnorgqr(cusolverH, Q_nrows, Q_ncols, @@ -143,7 +161,7 @@ void qrGetQR(raft::resources const& handle, Q_nrows, tau.data(), workspace.data(), - Lwork, + Lwork_Dnorgqr, devInfo.data(), stream)); } diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index 4d90e32e99..a9ea95ca28 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -21,7 +21,8 @@ template -void reduce(OutType* dots, +void reduce(bool dry_run, + OutType* dots, const InType* data, IdxType D, IdxType N, @@ -33,17 +34,19 @@ void reduce(OutType* dots, FinalLambda final_op = raft::identity_op()) { if constexpr (rowMajor && alongRows) { - raft::linalg::coalescedReduction( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + coalescedReduction( + dry_run, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if constexpr (rowMajor && !alongRows) { + if (dry_run) { return; } // no allocations in strided reduction raft::linalg::stridedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if constexpr (!rowMajor && alongRows) { + if (dry_run) { return; } // no allocations in strided reduction raft::linalg::stridedReduction( dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); } else { - raft::linalg::coalescedReduction( - dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); + coalescedReduction( + dry_run, dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); } } diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 9dcdd1ed14..85b3a0cbcc 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -85,6 +86,8 @@ void randomized_svd(const raft::resources& handle, auto h_workspace = raft::make_host_vector(workspaceHost); auto devInfo = raft::make_device_scalar(handle, 0); + if (resource::get_dry_run_flag(handle)) { return; } + RAFT_CUSOLVER_TRY(cusolverDnxgesvdr(cusolverH, jobu, jobv, @@ -154,6 +157,7 @@ void rsvdFixedRank(raft::resources const& handle, int max_sweeps, cudaStream_t stream) { + bool is_dry_run = resource::get_dry_run_flag(handle); cusolverDnHandle_t cusolverH = 
resource::get_cusolver_dn_handle(handle); cublasHandle_t cublasH = resource::get_cublas_handle(handle); @@ -171,7 +175,9 @@ void rsvdFixedRank(raft::resources const& handle, // Build temporary U, S, V matrices rmm::device_uvector S_vec_tmp(l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(S_vec_tmp.data(), 0, sizeof(math_t) * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(S_vec_tmp.data(), 0, sizeof(math_t) * l, stream)); + } // build random matrix rmm::device_uvector RN(n * l, stream); @@ -187,9 +193,11 @@ void rsvdFixedRank(raft::resources const& handle, rmm::device_uvector Z(n * l, stream); rmm::device_uvector Yorth(m * l, stream); rmm::device_uvector Zorth(n * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Z.data(), 0, sizeof(math_t) * n * l, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(Yorth.data(), 0, sizeof(math_t) * m * l, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(Zorth.data(), 0, sizeof(math_t) * n * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Z.data(), 0, sizeof(math_t) * n * l, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(Yorth.data(), 0, sizeof(math_t) * m * l, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(Zorth.data(), 0, sizeof(math_t) * n * l, stream)); + } // power sampling scheme for (int j = 1; j < q; j++) { @@ -236,30 +244,40 @@ void rsvdFixedRank(raft::resources const& handle, // orthogonalize on exit from loop to get Q rmm::device_uvector Q(m * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Q.data(), 0, sizeof(math_t) * m * l, stream)); + if (!is_dry_run) { RAFT_CUDA_TRY(cudaMemsetAsync(Q.data(), 0, sizeof(math_t) * m * l, stream)); } raft::linalg::qrGetQ(handle, Y.data(), Q.data(), m, l, stream); // either QR of B^T method, or eigendecompose BB^T method if (!use_bbt) { // form Bt = Mt*Q : nxm * mxl = nxl rmm::device_uvector Bt(n * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Bt.data(), 0, sizeof(math_t) * n * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Bt.data(), 0, sizeof(math_t) * n * l, stream)); + } 
raft::linalg::gemm( handle, M, m, n, Q.data(), Bt.data(), n, l, CUBLAS_OP_T, CUBLAS_OP_N, alpha, beta, stream); // compute QR factorization of Bt // M is mxn ; Q is mxn ; R is min(m,n) x min(m,n) */ rmm::device_uvector Qhat(n * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Qhat.data(), 0, sizeof(math_t) * n * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Qhat.data(), 0, sizeof(math_t) * n * l, stream)); + } rmm::device_uvector Rhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Rhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Rhat.data(), 0, sizeof(math_t) * l * l, stream)); + } raft::linalg::qrGetQR(handle, Bt.data(), Qhat.data(), Rhat.data(), n, l, stream); // compute SVD of Rhat (lxl) rmm::device_uvector Uhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + } rmm::device_uvector Vhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Vhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Vhat.data(), 0, sizeof(math_t) * l * l, stream)); + } if (use_jacobi) raft::linalg::svdJacobi(handle, Rhat.data(), @@ -350,9 +368,13 @@ void rsvdFixedRank(raft::resources const& handle, // compute eigendecomposition of BBt rmm::device_uvector Uhat(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); + } rmm::device_uvector Uhat_dup(l * l, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Uhat_dup.data(), 0, sizeof(math_t) * l * l, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Uhat_dup.data(), 0, sizeof(math_t) * l * l, stream)); + } raft::matrix::upper_triangular( handle, @@ -397,9 +419,13 @@ void rsvdFixedRank(raft::resources const& handle, // 
Sigma^{-1}[(p+1):l, (p+1):l] nxl * lxk * kxk = nxk if (gen_right_vec) { rmm::device_uvector Sinv(k * k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(Sinv.data(), 0, sizeof(math_t) * k * k, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(Sinv.data(), 0, sizeof(math_t) * k * k, stream)); + } rmm::device_uvector UhatSinv(l * k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(UhatSinv.data(), 0, sizeof(math_t) * l * k, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(UhatSinv.data(), 0, sizeof(math_t) * l * k, stream)); + } math_t scalar = 1.0; raft::matrix::reciprocal( handle, diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index ba831822d7..d4100cf473 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -59,6 +60,8 @@ void svdQR(raft::resources const& handle, RAFT_CUSOLVER_TRY(cusolverDngesvd_bufferSize(cusolverH, n_rows, n_cols, &lwork)); rmm::device_uvector d_work(lwork, stream); + if (resource::get_dry_run_flag(handle)) { return; } + char jobu = 'S'; char jobvt = 'A'; @@ -216,6 +219,11 @@ void svdJacobi(raft::resources const& handle, rmm::device_uvector d_work(lwork, stream); + if (resource::get_dry_run_flag(handle)) { + RAFT_CUSOLVER_TRY(cusolverDnDestroyGesvdjInfo(gesvdj_params)); + return; + } + RAFT_CUSOLVER_TRY(cusolverDngesvdj(cusolverH, CUSOLVER_EIG_MODE_VECTOR, econ, @@ -280,16 +288,19 @@ bool evaluateSVDByL2Norm(raft::resources const& handle, math_t tol, cudaStream_t stream) { - cublasHandle_t cublasH = resource::get_cublas_handle(handle); - int m = n_rows, n = n_cols; + bool is_dry_run = resource::get_dry_run_flag(handle); // form product matrix rmm::device_uvector P_d(m * n, stream); rmm::device_uvector S_mat(k * k, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, 
sizeof(math_t) * k * k, stream)); + if (!is_dry_run) { + RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, sizeof(math_t) * k * k, stream)); + } + + // These RAFT functions have their own dry-run guards at the leaf level raft::matrix::set_diagonal(handle, make_device_vector_view(S_vec, k), make_device_matrix_view(S_mat.data(), k, k)); @@ -307,8 +318,12 @@ bool evaluateSVDByL2Norm(raft::resources const& handle, // calculate percent error const math_t alpha = 1.0, beta = -1.0; rmm::device_uvector A_minus_P(m * n, stream); + + if (is_dry_run) { return false; } + RAFT_CUDA_TRY(cudaMemsetAsync(A_minus_P.data(), 0, sizeof(math_t) * m * n, stream)); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); RAFT_CUBLAS_TRY(cublasgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 9efac50763..beda4aaa04 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -87,6 +88,7 @@ void transpose_half(raft::resources const& handle, const IndexType stride_out = 1) { if (n_cols == 0 || n_rows == 0) return; + if (resource::get_dry_run_flag(handle)) { return; } auto stream = resource::get_cuda_stream(handle); int dev_id, sm_count; @@ -134,6 +136,7 @@ void transpose(raft::resources const& handle, int n_cols, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return; } int out_n_rows = n_cols; int out_n_cols = n_rows; @@ -188,6 +191,7 @@ void transpose_row_major_impl( raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { + if (resource::get_dry_run_flag(handle)) { return; } auto out_n_rows = in.extent(1); auto out_n_cols = in.extent(0); T constexpr kOne = 1; @@ -230,6 +234,7 @@ void transpose_col_major_impl( raft::mdspan, 
LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { + if (resource::get_dry_run_flag(handle)) { return; } auto out_n_rows = in.extent(1); auto out_n_cols = in.extent(0); T constexpr kOne = 1; diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 69600f016c..0a64b8db55 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __DIVIDE_H @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -61,6 +62,7 @@ void divide_scalar(raft::resources const& handle, OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index af40c07459..086633745b 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __DOT_H @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -40,6 +41,7 @@ void dot(raft::resources const& handle, { RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); + if (resource::get_dry_run_flag(handle)) { return; } RAFT_CUBLAS_TRY(detail::cublasdot(resource::get_cublas_handle(handle), x.size(), @@ -70,6 +72,7 @@ void dot(raft::resources const& handle, { RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); + if (resource::get_dry_run_flag(handle)) { return; } RAFT_CUBLAS_TRY(detail::cublasdot(resource::get_cublas_handle(handle), x.size(), diff --git a/cpp/include/raft/linalg/map_reduce.cuh b/cpp/include/raft/linalg/map_reduce.cuh index e5176dda01..3c206bc11b 100644 --- a/cpp/include/raft/linalg/map_reduce.cuh +++ b/cpp/include/raft/linalg/map_reduce.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __MAP_REDUCE_H @@ -11,6 +11,7 @@ #include #include +#include namespace raft::linalg { @@ -89,6 +90,7 @@ void map_reduce(raft::resources const& handle, ReduceLambda op, Args... args) { + if (resource::get_dry_run_flag(handle)) { return; } mapReduce( out.data_handle(), in.extent(0), diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 47a3cd9ce8..abd437ab91 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __MATRIX_VECTOR_OP_H @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -56,7 +57,7 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - detail::matrixVectorOp(out, matrix, vec, D, N, op, stream); + detail::matrixVectorOp(false, out, matrix, vec, D, N, op, stream); } /** @@ -100,7 +101,8 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - detail::matrixVectorOp(out, matrix, vec1, vec2, D, N, op, stream); + detail::matrixVectorOp( + false, out, matrix, vec1, vec2, D, N, op, stream); } /** @@ -156,13 +158,14 @@ void matrix_vector_op(raft::resources const& handle, "Size mismatch between matrix and vector"); } - matrixVectorOp(out.data_handle(), - matrix.data_handle(), - vec.data_handle(), - out.extent(1), - out.extent(0), - op, - resource::get_cuda_stream(handle)); + detail::matrixVectorOp(resource::get_dry_run_flag(handle), + out.data_handle(), + matrix.data_handle(), + vec.data_handle(), + out.extent(1), + out.extent(0), + op, + resource::get_cuda_stream(handle)); } /** @@ -221,14 +224,15 @@ void matrix_vector_op(raft::resources const& handle, "Size mismatch between matrix and vector"); } - matrixVectorOp(out.data_handle(), - matrix.data_handle(), - vec1.data_handle(), - vec2.data_handle(), - out.extent(1), - out.extent(0), - op, - resource::get_cuda_stream(handle)); + detail::matrixVectorOp(resource::get_dry_run_flag(handle), + out.data_handle(), + matrix.data_handle(), + vec1.data_handle(), + vec2.data_handle(), + out.extent(1), + out.extent(0), + op, + resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_vector_op diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index 70c04ccc6b..85ca248cf5 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA 
CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __MSE_H @@ -11,6 +11,7 @@ #include #include +#include namespace raft { namespace linalg { @@ -57,6 +58,7 @@ void mean_squared_error(raft::resources const& handle, raft::device_scalar_view out, OutValueType weight) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs"); meanSquaredError(out.data_handle(), diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh index 22c89a5883..325918868e 100644 --- a/cpp/include/raft/linalg/multiply.cuh +++ b/cpp/include/raft/linalg/multiply.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __MULTIPLY_H @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace raft { @@ -63,6 +64,7 @@ void multiply_scalar( OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index e16fbf4353..c0839aca44 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __NORM_H @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -54,7 +55,7 @@ void rowNorm(OutType* dots, cudaStream_t stream, Lambda fin_op = raft::identity_op()) { - detail::rowNormCaller(dots, data, D, N, stream, fin_op); + detail::rowNormCaller(false, dots, data, D, N, stream, fin_op); } /** @@ -85,7 +86,7 @@ void colNorm(OutType* dots, cudaStream_t stream, Lambda fin_op = raft::identity_op()) { - detail::colNormCaller(dots, data, D, N, stream, fin_op); + detail::colNormCaller(false, dots, data, D, N, stream, fin_op); } /** @@ -128,21 +129,23 @@ void norm(raft::resources const& handle, if constexpr (along_rows) { RAFT_EXPECTS(static_cast(out.size()) == in.extent(0), "Output should be equal to number of rows in Input"); - rowNorm(out.data_handle(), - in.data_handle(), - in.extent(1), - in.extent(0), - resource::get_cuda_stream(handle), - fin_op); + detail::rowNormCaller(resource::get_dry_run_flag(handle), + out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + resource::get_cuda_stream(handle), + fin_op); } else { RAFT_EXPECTS(static_cast(out.size()) == in.extent(1), "Output should be equal to number of columns in Input"); - colNorm(out.data_handle(), - in.data_handle(), - in.extent(1), - in.extent(0), - resource::get_cuda_stream(handle), - fin_op); + detail::colNormCaller(resource::get_dry_run_flag(handle), + out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + resource::get_cuda_stream(handle), + fin_op); } } diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index 730d5aff25..86b59751f5 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -53,6 +54,7 @@ void row_normalize(raft::resources const& handle, FinalLambda fin_op, ElementType eps = ElementType(1e-8)) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(in.extent(0) == out.extent(0), diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index ae4820cda3..de6461bc83 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __POWER_H @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -74,6 +75,7 @@ template > void power(raft::resources const& handle, InType in1, InType in2, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -112,6 +114,7 @@ void power_scalar( OutType out, const raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index ce2c324f24..e3650469df 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __REDUCE_H @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -71,7 +72,7 @@ void reduce(OutType* dots, FinalLambda final_op = raft::identity_op()) { detail::reduce( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + false, dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } /** @@ -166,16 +167,18 @@ void reduce(raft::resources const& handle, "Output should be equal to number of columns in Input"); } - reduce(dots.data_handle(), - data.data_handle(), - data.extent(1), - data.extent(0), - init, - resource::get_cuda_stream(handle), - inplace, - main_op, - reduce_op, - final_op); + detail::reduce( + resource::get_dry_run_flag(handle), + dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + resource::get_cuda_stream(handle), + inplace, + main_op, + reduce_op, + final_op); } /** @} */ // end of group reduction diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh index e0ac2d6544..eb90244cc3 100644 --- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __REDUCE_COLS_BY_KEY @@ -11,6 +11,7 @@ #include #include +#include #include namespace raft { @@ -81,6 +82,7 @@ void reduce_cols_by_key( IndexType nkeys = 0, bool reset_sums = true) { + if (resource::get_dry_run_flag(handle)) { return; } if (nkeys > 0) { RAFT_EXPECTS(out.extent(1) == nkeys, "Output doesn't have nkeys columns"); } else { diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index 7e7e91bcb9..685f8fb962 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __REDUCE_ROWS_BY_KEY @@ -11,6 +11,7 @@ #include #include +#include #include namespace raft { @@ -147,6 +148,7 @@ void reduce_rows_by_key( std::optional> d_weights = std::nullopt, bool reset_sums = true) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(d_A.extent(0) == d_A.extent(0) && d_sums.extent(1) == n_unique_keys, "Output is not of size ncols * n_unique_keys"); RAFT_EXPECTS(d_keys.extent(0) == d_A.extent(1), "Keys is not of size nrows"); diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index e0c232e62a..abf19e765e 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __SQRT_H @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace raft { @@ -51,6 +52,7 @@ template > void sqrt(raft::resources const& handle, InType in, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index efbd80126e..eb34a99452 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -127,6 +128,7 @@ void strided_reduction(raft::resources const& handle, ReduceLambda reduce_op = raft::add_op(), FinalLambda final_op = raft::identity_op()) { + if (resource::get_dry_run_flag(handle)) { return; } if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), "Output should be equal to number of columns in Input"); diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh index 1aba864100..08e5f38fbe 100644 --- a/cpp/include/raft/linalg/subtract.cuh +++ b/cpp/include/raft/linalg/subtract.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace raft { @@ -98,6 +99,7 @@ template > void subtract(raft::resources const& handle, InType in1, InType in2, OutType out) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -136,6 +138,7 @@ void subtract_scalar( OutType out, raft::device_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -172,6 +175,7 @@ void subtract_scalar( OutType out, raft::host_scalar_view scalar) { + if (resource::get_dry_run_flag(handle)) { return; } using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh index 69e2130adb..6cf4b3a266 100644 --- a/cpp/include/raft/linalg/unary_op.cuh +++ b/cpp/include/raft/linalg/unary_op.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #ifndef __UNARY_OP_H @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -109,6 +110,7 @@ template > void write_only_unary_op(const raft::resources& handle, OutType out, Lambda op) { + if (resource::get_dry_run_flag(handle)) { return; } return writeOnlyUnaryOp(out.data_handle(), out.size(), op, resource::get_cuda_stream(handle)); } diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 36a8999b64..caa477fa8e 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,6 +7,7 @@ #include #include +#include #include namespace raft::matrix { @@ -27,6 +28,7 @@ void argmax(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); detail::argmax(in.data_handle(), diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh index a168d3969a..9531b6a426 100644 --- a/cpp/include/raft/matrix/argmin.cuh +++ b/cpp/include/raft/matrix/argmin.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,6 +7,7 @@ #include #include +#include #include namespace raft::matrix { @@ -27,6 +28,7 @@ void argmin(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); detail::argmin(in.data_handle(), diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 0347797a4c..7e5d95f3eb 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #ifndef __COL_WISE_SORT_H @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace raft::matrix { @@ -38,8 +39,16 @@ void sort_cols_per_row(const InType* in, cudaStream_t stream, InType* sortedKeys = nullptr) { - detail::sortColumnsPerRow( - in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); + detail::sortColumnsPerRow(false, + in, + out, + n_rows, + n_columns, + bAllocWorkspace, + workspacePtr, + workspaceSize, + stream, + sortedKeys); } /** @@ -78,12 +87,14 @@ void sort_cols_per_row(raft::resources const& handle, "Input and `sorted_keys` matrices must have the same shape."); } + bool dry_run = resource::get_dry_run_flag(handle); size_t workspace_size = 0; bool alloc_workspace = false; in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; - detail::sortColumnsPerRow(in.data_handle(), + detail::sortColumnsPerRow(dry_run, + in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), @@ -96,7 +107,10 @@ void sort_cols_per_row(raft::resources const& handle, if (alloc_workspace) { auto workspace = raft::make_device_vector(handle, workspace_size); - detail::sortColumnsPerRow(in.data_handle(), + if (dry_run) { return; } + + detail::sortColumnsPerRow(dry_run, + in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 8c3f00eca5..0aca60483a 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -34,6 +35,7 @@ void copy_rows(raft::resources const& handle, raft::device_matrix_view out, raft::device_vector_view indices) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(in.extent(1) == out.extent(1), "Input and output matrices must have same number of columns"); RAFT_EXPECTS(indices.extent(0) == out.extent(0), @@ -59,6 +61,7 @@ void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); @@ -79,6 +82,7 @@ void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); @@ -100,6 +104,7 @@ void trunc_zero_origin(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { + if (resource::get_dry_run_flag(handle)) { return; } RAFT_EXPECTS(out.extent(0) <= in.extent(0) && out.extent(1) <= in.extent(1), "Output matrix must have less or equal number of rows and columns"); diff --git a/cpp/include/raft/matrix/detail/columnWiseSort.cuh b/cpp/include/raft/matrix/detail/columnWiseSort.cuh index a8f654557d..a36e9ee4da 100644 --- a/cpp/include/raft/matrix/detail/columnWiseSort.cuh +++ b/cpp/include/raft/matrix/detail/columnWiseSort.cuh @@ -163,7 +163,8 @@ cudaError_t layoutSortOffset(T* in, T value, int n_times, cudaStream_t stream) * @param sortedKeys: Optional, output matrix for sorted keys (input) */ template -void sortColumnsPerRow(const InType* in, +void sortColumnsPerRow(bool dry_run, + const InType* in, OutType* out, int n_rows, int n_columns, @@ -203,6 +204,8 @@ void 
sortColumnsPerRow(const InType* in, // more elements per thread --> more register pressure // 512(blockSize) * 8 elements per thread = 71 register / thread + if (dry_run) { return; } + // instantiate some kernel combinations if (n_columns <= 512) INST_BLOCK_SORT(in, sortedKeys, out, n_rows, n_columns, 128, 4, stream); @@ -250,6 +253,8 @@ void sortColumnsPerRow(const InType* in, // for segment offsets workspaceSize += raft::alignTo(sizeof(int) * (size_t)numSegments, memAlignWidth); } else { + if (dry_run) { return; } + size_t workspaceOffset = 0; if (!sortedKeys) { @@ -301,6 +306,8 @@ void sortColumnsPerRow(const InType* in, workspaceSize += raft::alignTo(sizeof(OutType) * (size_t)n_columns, memAlignWidth); } else { + if (dry_run) { return; } + size_t workspaceOffset = 0; bool userKeyOutputBuffer = true; diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 20bbc4271d..d931f433f4 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -550,13 +551,15 @@ void gather(raft::resources const& res, device_vector_view indices, raft::device_matrix_view output) { + auto dry_run = resource::get_dry_run_flag(res); raft::common::nvtx::range fun_scope("gather"); IdxT n_dim = output.extent(1); IdxT n_train = output.extent(0); auto indices_host = raft::make_host_vector(n_train); - raft::copy( - indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); - resource::sync_stream(res); + if (!dry_run) { + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + } const size_t buffer_size = 32768 * 1024; // bytes const size_t max_batch_size = @@ -568,6 +571,10 @@ void gather(raft::resources const& res, auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); auto out_tmp2 = raft::make_pinned_matrix(res, 
max_batch_size, n_dim); + if (dry_run) { return; } + + resource::sync_stream(res); + // Usually a limited number of threads provide sufficient bandwidth for gathering data. #if defined(_OPENMP) int n_threads = std::min(omp_get_max_threads(), 32); diff --git a/cpp/include/raft/matrix/detail/gather_inplace.cuh b/cpp/include/raft/matrix/detail/gather_inplace.cuh index beaad13657..ac9105b1cc 100644 --- a/cpp/include/raft/matrix/detail/gather_inplace.cuh +++ b/cpp/include/raft/matrix/detail/gather_inplace.cuh @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include #include @@ -38,12 +39,14 @@ void gatherInplaceImpl(raft::resources const& handle, // re-assign batch_size for default case if (batch_size == 0 || batch_size > n) batch_size = n; + auto scratch_space = raft::make_device_vector(handle, map_length * batch_size); + + if (resource::get_dry_run_flag(handle)) { return; } + auto exec_policy = resource::get_thrust_policy(handle); IndexT n_batches = raft::ceildiv(n, batch_size); - auto scratch_space = raft::make_device_vector(handle, map_length * batch_size); - for (IndexT bid = 0; bid < n_batches; bid++) { IndexT batch_offset = bid * batch_size; IndexT cols_per_batch = min(batch_size, n - batch_offset); diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 05416d16be..9eefcf547e 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -186,10 +187,10 @@ template void ratio( raft::resources const& handle, const math_t* src, math_t* dest, IdxType len, cudaStream_t stream) { - auto d_src = src; - auto d_dest = dest; - rmm::device_scalar d_sum(stream); + if (resource::get_dry_run_flag(handle)) { return; } + auto d_src = src; + auto d_dest = dest; auto* d_sum_ptr = d_sum.data(); raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::identity_op{}, stream, src); raft::linalg::unaryOp( @@ -200,15 
+201,16 @@ template ( - data, data, vec, n_col, n_row, raft::mul_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::mul_op(), stream); } template void matrixVectorBinaryMultSkipZero( Type* data, const Type* vec, IdxType n_row, IdxType n_col, cudaStream_t stream) { - raft::linalg::matrixVectorOp( + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, @@ -227,8 +229,8 @@ template ( - data, data, vec, n_col, n_row, raft::div_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::div_op(), stream); } template @@ -240,7 +242,8 @@ void matrixVectorBinaryDivSkipZero(Type* data, bool return_zero = false) { if (return_zero) { - raft::linalg::matrixVectorOp( + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, @@ -254,7 +257,8 @@ void matrixVectorBinaryDivSkipZero(Type* data, }, stream); } else { - raft::linalg::matrixVectorOp( + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, @@ -274,16 +278,16 @@ template ( - data, data, vec, n_col, n_row, raft::add_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::add_op(), stream); } template void matrixVectorBinarySub( Type* data, const Type* vec, IdxType n_row, IdxType n_col, cudaStream_t stream) { - raft::linalg::matrixVectorOp( - data, data, vec, n_col, n_row, raft::sub_op(), stream); + raft::linalg::detail::matrixVectorOp( + false, data, data, vec, n_col, n_row, raft::sub_op(), stream); } // Computes an argmin/argmax column-wise in a DxN matrix diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index af42e12037..9e2989bee8 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -296,6 +297,7 @@ void getDiagonalInverseMatrix(m_t* in, idx_t len, cudaStream_t stream) template m_t 
getL2Norm(raft::resources const& handle, const m_t* in, idx_t size, cudaStream_t stream) { + if (resource::get_dry_run_flag(handle)) { return m_t{0}; } cublasHandle_t cublasH = resource::get_cublas_handle(handle); m_t normval = 0; RAFT_EXPECTS( diff --git a/cpp/include/raft/matrix/detail/scatter_inplace.cuh b/cpp/include/raft/matrix/detail/scatter_inplace.cuh index 7ffa697f71..0c3ea275b7 100644 --- a/cpp/include/raft/matrix/detail/scatter_inplace.cuh +++ b/cpp/include/raft/matrix/detail/scatter_inplace.cuh @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include #include @@ -63,12 +64,14 @@ void scatterInplaceImpl( // re-assign batch_size for default case if (batch_size == 0 || batch_size > n) batch_size = n; + auto scratch_space = raft::make_device_vector(handle, m * batch_size); + + if (resource::get_dry_run_flag(handle)) { return; } + auto exec_policy = resource::get_thrust_policy(handle); IndexT n_batches = raft::ceildiv(n, batch_size); - auto scratch_space = raft::make_device_vector(handle, m * batch_size); - for (IndexT bid = 0; bid < n_batches; bid++) { IndexT batch_offset = bid * batch_size; IndexT cols_per_batch = min(batch_size, n - batch_offset); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 37411ba0bd..f7b13d9977 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -125,6 +126,8 @@ void segmented_sort_by_key(raft::resources const& handle, auto d_temp_storage = raft::make_device_mdarray( handle, mr, raft::make_extents(temp_storage_bytes)); + if (resource::get_dry_run_flag(handle)) { return; } + if (asc) { // Run sorting operation cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(), diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 
a6dd7e0ce5..28eb3b411c 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -873,7 +874,8 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) } template -void radix_topk(const T* in, +void radix_topk(bool dry_run, + const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -907,6 +909,8 @@ void radix_topk(const T* in, rmm::device_buffer bufs(max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); + if (dry_run) { return; } + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( @@ -1148,7 +1152,8 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // used. It's used when len is relatively small or when the number of blocks per row calculated by // `calc_grid_dim()` is 1. template -void radix_topk_one_block(const T* in, +void radix_topk_one_block(bool dry_run, + const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -1170,6 +1175,8 @@ void radix_topk_one_block(const T* in, rmm::device_buffer bufs(max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); + if (dry_run) { return; } + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); const IdxT* chunk_len_i = len_i ? 
(len_i + offset) : nullptr; @@ -1266,9 +1273,11 @@ void select_k(raft::resources const& res, RAFT_EXPECTS(RowLayout::is_uniform || len_i != nullptr, "CSR layout requires a non-null indptr array (len_i)!"); - auto stream = resource::get_cuda_stream(res); - auto mr = resource::get_workspace_resource_ref(res); + bool dry_run = resource::get_dry_run_flag(res); + auto stream = resource::get_cuda_stream(res); + auto mr = resource::get_workspace_resource_ref(res); if (k == len && RowLayout::is_uniform) { + if (dry_run) { return; } RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { @@ -1288,15 +1297,27 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); + dry_run, in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); + impl::radix_topk_one_block(dry_run, + in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + len_i, + sm_cnt, + stream, + mr); } else { - impl::radix_topk(in, + impl::radix_topk(dry_run, + in, in_idx, batch_size, len, diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index a480743664..5860de1c8b 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -1042,7 +1043,8 @@ template