-
Notifications
You must be signed in to change notification settings - Fork 231
Add Randomized SVDs #2999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add Randomized SVDs #2999
Changes from 3 commits
d20ea68
307d18e
e52ad48
8a08ffc
9dcc910
a803923
5dc8079
9b44f7a
c5fd1a4
b601ea2
15b6285
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,163 @@ | ||||||||||
| /* | ||||||||||
| * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | ||||||||||
| * SPDX-License-Identifier: Apache-2.0 | ||||||||||
| */ | ||||||||||
|
|
||||||||||
| #pragma once | ||||||||||
|
|
||||||||||
| #include <raft/core/resource/cublas_handle.hpp> | ||||||||||
| #include <raft/core/resource/cuda_stream.hpp> | ||||||||||
| #include <raft/core/resource/cusolver_dn_handle.hpp> | ||||||||||
| #include <raft/core/resources.hpp> | ||||||||||
| #include <raft/linalg/detail/cublas_wrappers.hpp> | ||||||||||
| #include <raft/linalg/detail/cusolver_wrappers.hpp> | ||||||||||
| #include <raft/linalg/gemm.cuh> | ||||||||||
| #include <raft/linalg/qr.cuh> | ||||||||||
| #include <raft/util/cuda_utils.cuh> | ||||||||||
| #include <raft/util/cudart_utils.hpp> | ||||||||||
|
|
||||||||||
| #include <rmm/device_scalar.hpp> | ||||||||||
| #include <rmm/device_uvector.hpp> | ||||||||||
|
|
||||||||||
| namespace raft::sparse::solver::detail { | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * @brief Single pass of CholeskyQR: orthogonalize Q in-place via Cholesky factorization | ||||||||||
| * of the Gram matrix W = Q^T @ Q. | ||||||||||
| * | ||||||||||
| * @return true on success, false if Cholesky factorization failed (matrix not SPD / rank deficient) | ||||||||||
| */ | ||||||||||
| template <typename ValueTypeT> | ||||||||||
| bool cholesky_qr_pass(raft::resources const& handle, | ||||||||||
| ValueTypeT* Q, | ||||||||||
| int m, | ||||||||||
| int k, | ||||||||||
| ValueTypeT* W, | ||||||||||
| ValueTypeT* workspace, | ||||||||||
| int workspace_size, | ||||||||||
| int* dev_info) | ||||||||||
| { | ||||||||||
| auto stream = raft::resource::get_cuda_stream(handle); | ||||||||||
| auto cublas_h = raft::resource::get_cublas_handle(handle); | ||||||||||
| auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle); | ||||||||||
|
|
||||||||||
| const ValueTypeT one = 1; | ||||||||||
| const ValueTypeT zero = 0; | ||||||||||
|
|
||||||||||
| // W = Q^T @ Q (k x k) | ||||||||||
| // Q is col-major (m x k), so: W = Q^T * Q via gemm(TRANS, NOTRANS, k, k, m) | ||||||||||
| RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_h, | ||||||||||
| CUBLAS_OP_T, | ||||||||||
| CUBLAS_OP_N, | ||||||||||
| k, | ||||||||||
| k, | ||||||||||
| m, | ||||||||||
| &one, | ||||||||||
| Q, | ||||||||||
| m, | ||||||||||
| Q, | ||||||||||
| m, | ||||||||||
| &zero, | ||||||||||
| W, | ||||||||||
| k, | ||||||||||
| stream)); | ||||||||||
|
|
||||||||||
| // L = cholesky(W, LOWER) — W is overwritten with L in lower triangle | ||||||||||
| RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf( | ||||||||||
| cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W, k, workspace, workspace_size, dev_info, stream)); | ||||||||||
|
|
||||||||||
| // Check if Cholesky succeeded | ||||||||||
| int h_dev_info = 0; | ||||||||||
| raft::update_host(&h_dev_info, dev_info, 1, stream); | ||||||||||
| raft::resource::sync_stream(handle); | ||||||||||
| if (h_dev_info != 0) { return false; } | ||||||||||
|
|
||||||||||
| // Q = Q @ L^{-T} | ||||||||||
| // This is equivalent to solving X * L^T = Q for X, i.e. trsm with RIGHT, LOWER, TRANS | ||||||||||
| RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(cublas_h, | ||||||||||
| CUBLAS_SIDE_RIGHT, | ||||||||||
| CUBLAS_FILL_MODE_LOWER, | ||||||||||
| CUBLAS_OP_T, | ||||||||||
| CUBLAS_DIAG_NON_UNIT, | ||||||||||
| m, | ||||||||||
| k, | ||||||||||
| &one, | ||||||||||
| W, | ||||||||||
| k, | ||||||||||
| Q, | ||||||||||
| m, | ||||||||||
| stream)); | ||||||||||
|
|
||||||||||
| return true; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * @brief CholeskyQR2 orthogonalization: two passes of CholeskyQR for numerical stability. | ||||||||||
| * | ||||||||||
| * This is the GPU-optimized orthogonalization from Tomás, Quintana-Ortí, Anzt (2024), | ||||||||||
| * "Fast Truncated SVD of Sparse and Dense Matrices on Graphics Processors". | ||||||||||
| * It uses GEMM + Cholesky + TRSM operations which are highly efficient on GPU, | ||||||||||
| * providing ~3x speedup over standard Householder QR. | ||||||||||
| * | ||||||||||
| * If Cholesky factorization fails (input is rank-deficient), falls back to standard QR. | ||||||||||
| * | ||||||||||
| * @param handle raft resources handle | ||||||||||
| * @param Q matrix to orthogonalize of shape (m, k), col-major, modified in-place | ||||||||||
| * @return true if CholeskyQR2 succeeded, false if fell back to QR | ||||||||||
| */ | ||||||||||
| template <typename ValueTypeT> | ||||||||||
| bool cholesky_qr2(raft::resources const& handle, | ||||||||||
|
Intron7 marked this conversation as resolved.
|
||||||||||
| raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Q) | ||||||||||
| { | ||||||||||
| int m = Q.extent(0); | ||||||||||
| int k = Q.extent(1); | ||||||||||
|
|
||||||||||
| auto stream = raft::resource::get_cuda_stream(handle); | ||||||||||
| auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle); | ||||||||||
|
|
||||||||||
| // Allocate workspace for Gram matrix and Cholesky | ||||||||||
| rmm::device_uvector<ValueTypeT> W(k * k, stream); | ||||||||||
| rmm::device_scalar<int> dev_info(stream); | ||||||||||
|
|
||||||||||
| // Query workspace size for potrf | ||||||||||
| int potrf_workspace_size = 0; | ||||||||||
| RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize( | ||||||||||
| cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W.data(), k, &potrf_workspace_size)); | ||||||||||
| rmm::device_uvector<ValueTypeT> potrf_workspace(potrf_workspace_size, stream); | ||||||||||
|
|
||||||||||
| // First pass | ||||||||||
| if (!cholesky_qr_pass(handle, | ||||||||||
| Q.data_handle(), | ||||||||||
| m, | ||||||||||
| k, | ||||||||||
| W.data(), | ||||||||||
| potrf_workspace.data(), | ||||||||||
| potrf_workspace_size, | ||||||||||
| dev_info.data())) { | ||||||||||
| // Fallback to standard QR | ||||||||||
| rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream); | ||||||||||
| raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream); | ||||||||||
| raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream); | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Additionally, |
||||||||||
| return false; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Second pass for improved numerical stability | ||||||||||
| if (!cholesky_qr_pass(handle, | ||||||||||
| Q.data_handle(), | ||||||||||
| m, | ||||||||||
| k, | ||||||||||
| W.data(), | ||||||||||
| potrf_workspace.data(), | ||||||||||
| potrf_workspace_size, | ||||||||||
| dev_info.data())) { | ||||||||||
| // Fallback to standard QR | ||||||||||
| rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream); | ||||||||||
| raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream); | ||||||||||
| raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream); | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar comment here. |
||||||||||
| return false; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| return true; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| } // namespace raft::sparse::solver::detail | ||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,209 @@ | ||||||
| /* | ||||||
| * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | ||||||
| * SPDX-License-Identifier: Apache-2.0 | ||||||
| */ | ||||||
|
|
||||||
| #pragma once | ||||||
|
|
||||||
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/svd.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/sparse/solver/detail/cholesky_qr.cuh>
#include <raft/sparse/solver/detail/svds_sign_correction.cuh>
#include <raft/sparse/solver/svds_types.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>

#include <algorithm>
|
|
||||||
| namespace raft::sparse::solver::detail { | ||||||
|
|
||||||
| /** | ||||||
| * @brief Randomized SVD for sparse matrices using block power iteration with CholeskyQR2. | ||||||
| * | ||||||
| * Implements randomized SVD (Halko et al. 2009) with GPU-optimized CholeskyQR2 | ||||||
| * orthogonalization (Tomás et al. 2024). | ||||||
| * | ||||||
| * The operator interface allows implicit operators (e.g. mean-centered sparse matrices) | ||||||
| * without materializing the dense matrix. | ||||||
| * | ||||||
| * @tparam ValueTypeT Data type (float or double) | ||||||
| * @tparam OperatorT Linear operator type providing apply() and apply_transpose() | ||||||
| * | ||||||
| * @param handle raft resources handle | ||||||
| * @param config SVD configuration (n_components, n_oversamples, n_power_iters, seed) | ||||||
| * @param op linear operator representing the matrix to decompose | ||||||
| * @param singular_values output singular values of shape (k,) in descending order | ||||||
| * @param U output left singular vectors of shape (m, k), col-major | ||||||
| * @param Vt output right singular vectors of shape (k, n), col-major | ||||||
| */ | ||||||
| template <typename ValueTypeT, typename OperatorT> | ||||||
| void sparse_randomized_svd( | ||||||
| raft::resources const& handle, | ||||||
| sparse_svd_config<ValueTypeT> const& config, | ||||||
| OperatorT const& op, | ||||||
| raft::device_vector_view<ValueTypeT, uint32_t> singular_values, | ||||||
| raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> U, | ||||||
| raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Vt) | ||||||
| { | ||||||
| common::nvtx::range<common::nvtx::domain::raft> fun_scope( | ||||||
| "raft::sparse::solver::sparse_randomized_svd(%d, %d, %d)", | ||||||
| op.rows(), | ||||||
| op.cols(), | ||||||
| config.n_components); | ||||||
|
|
||||||
| int m = op.rows(); | ||||||
| int n = op.cols(); | ||||||
| int k = config.n_components; | ||||||
| int p = config.n_oversamples; | ||||||
|
|
||||||
| RAFT_EXPECTS(k > 0, "n_components must be positive"); | ||||||
| RAFT_EXPECTS(k < std::min(m, n), "n_components must be less than min(m, n)"); | ||||||
| RAFT_EXPECTS(p >= 0, "n_oversamples must be non-negative"); | ||||||
| RAFT_EXPECTS(config.n_power_iters >= 0, "n_power_iters must be non-negative"); | ||||||
| RAFT_EXPECTS(singular_values.extent(0) == static_cast<uint32_t>(k), | ||||||
| "singular_values must have size n_components"); | ||||||
| RAFT_EXPECTS(U.extent(0) == static_cast<uint32_t>(m) && U.extent(1) == static_cast<uint32_t>(k), | ||||||
| "U must have shape (m, n_components)"); | ||||||
| RAFT_EXPECTS( | ||||||
| Vt.extent(0) == static_cast<uint32_t>(k) && Vt.extent(1) == static_cast<uint32_t>(n), | ||||||
| "Vt must have shape (n_components, n)"); | ||||||
|
|
||||||
| auto stream = raft::resource::get_cuda_stream(handle); | ||||||
|
|
||||||
| int block_size = std::min(k + p, std::min(m, n)); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should avoid silently clamping as it affects approximation quality. Please add a warning when |
||||||
| RAFT_EXPECTS(block_size >= k, "block_size (n_components + n_oversamples) must be >= n_components"); | ||||||
|
|
||||||
| // Initialize RNG | ||||||
| uint64_t seed = config.seed.value_or(0); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Setting |
||||||
| raft::random::RngState rng_state(seed); | ||||||
|
|
||||||
| // Step 1-3: Y = A @ Omega, orthogonalize | ||||||
| auto Y = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>( | ||||||
| handle, static_cast<uint32_t>(m), static_cast<uint32_t>(block_size)); | ||||||
| { | ||||||
| auto Omega = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>( | ||||||
| handle, static_cast<uint32_t>(n), static_cast<uint32_t>(block_size)); | ||||||
| raft::random::normal(handle, | ||||||
| rng_state, | ||||||
| Omega.data_handle(), | ||||||
| static_cast<std::size_t>(n) * block_size, | ||||||
| ValueTypeT(0), | ||||||
| ValueTypeT(1)); | ||||||
| op.apply(handle, | ||||||
| raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>( | ||||||
| Omega.data_handle(), n, block_size), | ||||||
| Y.view()); | ||||||
| } // Omega freed here | ||||||
| cholesky_qr2(handle, Y.view()); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should maybe check the return value of |
||||||
|
|
||||||
| // Step 4: Power iterations | ||||||
| auto Z = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>( | ||||||
| handle, static_cast<uint32_t>(n), static_cast<uint32_t>(block_size)); | ||||||
|
|
||||||
| for (int iter = 0; iter < config.n_power_iters; ++iter) { | ||||||
| // Z = A^T @ Q -> (n, block_size) | ||||||
| op.apply_transpose( | ||||||
| handle, | ||||||
| raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>( | ||||||
| Y.data_handle(), m, block_size), | ||||||
| Z.view()); | ||||||
| cholesky_qr2(handle, Z.view()); | ||||||
|
|
||||||
| // Y = A @ Z -> (m, block_size) | ||||||
| op.apply(handle, | ||||||
| raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>( | ||||||
| Z.data_handle(), n, block_size), | ||||||
| Y.view()); | ||||||
| cholesky_qr2(handle, Y.view()); | ||||||
| } | ||||||
|
|
||||||
| // Q = Y after power iterations (already orthogonal) | ||||||
| // Q is (m, block_size) | ||||||
|
|
||||||
| // Step 5: Bt = A^T @ Q -> (n, block_size) | ||||||
| op.apply_transpose( | ||||||
| handle, | ||||||
| raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>( | ||||||
| Y.data_handle(), m, block_size), | ||||||
| Z.view()); | ||||||
|
|
||||||
| // Step 6-7: SVD of B = Bt^T where Bt = Z (n x block_size, tall matrix) | ||||||
| // We compute SVD(Bt) directly to avoid cuSOLVER gesvd issues with wide matrices. | ||||||
| // SVD(Bt) = U_bt * S * Vt_bt → SVD(B) has U_b = V_bt and Vt_b = U_bt^T | ||||||
| auto S_full = raft::make_device_vector<ValueTypeT, uint32_t>(handle, block_size); | ||||||
|
|
||||||
| // Bt = B^T is (n x block_size), we already have Bt = Z from step 5 | ||||||
| // Z is (n, block_size) = A^T @ Q = B^T. We can reuse Z directly! | ||||||
| // SVD(Bt) with Bt being (n x block_size): jobu='S' gives U_bt (n x block_size), | ||||||
| // jobvt='A' gives Vt_bt (block_size x block_size) | ||||||
| auto U_bt = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>( | ||||||
| handle, static_cast<uint32_t>(n), static_cast<uint32_t>(block_size)); | ||||||
| auto Vt_bt = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>( | ||||||
| handle, static_cast<uint32_t>(block_size), static_cast<uint32_t>(block_size)); | ||||||
|
|
||||||
| // Z is consumed by svdQR (modifies input in-place) — this is fine since Z is not used after | ||||||
| raft::linalg::svdQR(handle, | ||||||
| Z.data_handle(), | ||||||
| n, | ||||||
| block_size, | ||||||
| S_full.data_handle(), | ||||||
| U_bt.data_handle(), | ||||||
| Vt_bt.data_handle(), | ||||||
| true, // transpose right vectors: Vt_bt -> V_bt (block_size x block_size) | ||||||
| true, // generate left vectors | ||||||
| true, // generate right vectors | ||||||
| stream); | ||||||
| // After svdQR with trans_right=true: | ||||||
| // U_bt is (n, block_size) — left singular vectors of Bt | ||||||
| // Vt_bt is now V_bt (block_size x block_size) — right singular vectors of Bt (transposed) | ||||||
| // S_full has block_size singular values | ||||||
| // | ||||||
| // For B = Bt^T: U_b = V_bt = Vt_bt (after transpose), Vt_b = U_bt^T | ||||||
| // So: U_b[:, :k] = Vt_bt[:, :k] and Vt_b[:k, :] = U_bt[:, :k]^T | ||||||
|
|
||||||
| // Step 8: U = Q @ U_b[:, :k] = Q @ V_bt[:, :k] | ||||||
| // Q is Y (m, block_size), V_bt is (block_size, block_size) | ||||||
| // U = Y @ V_bt[:, :k] → (m, block_size) * (block_size, k) → (m, k) | ||||||
| const ValueTypeT one = 1; | ||||||
| const ValueTypeT zero = 0; | ||||||
| raft::linalg::gemm(handle, | ||||||
| Y.data_handle(), | ||||||
| m, | ||||||
| block_size, | ||||||
| Vt_bt.data_handle(), // This is V_bt after trans_right=true | ||||||
| U.data_handle(), | ||||||
| m, | ||||||
| k, | ||||||
| CUBLAS_OP_N, | ||||||
| CUBLAS_OP_N, | ||||||
| one, | ||||||
| zero, | ||||||
| stream); | ||||||
|
|
||||||
| // Step 9: Truncate S and Vt | ||||||
| raft::copy(singular_values.data_handle(), S_full.data_handle(), k, stream); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A nitpick
Suggested change
|
||||||
|
|
||||||
| // Vt[:k, :] = U_bt[:, :k]^T | ||||||
| // U_bt is col-major (n, block_size), we need the first k columns transposed to (k, n) | ||||||
| // Vt(i,j) = U_bt(j,i) for i < k | ||||||
| // This is: Vt = (U_bt[:, :k])^T | ||||||
| // Use GEMM with identity: Vt = I_k @ U_bt[:, :k]^T doesn't help | ||||||
| // Just use transpose: transpose the first k columns of U_bt | ||||||
| // U_bt[:, :k] is (n, k) col-major → transpose to (k, n) col-major = Vt | ||||||
| raft::linalg::transpose(handle, U_bt.data_handle(), Vt.data_handle(), n, k, stream); | ||||||
|
|
||||||
| // Step 10: Sign correction | ||||||
| svd_sign_correction(handle, U, Vt); | ||||||
| } | ||||||
|
|
||||||
| } // namespace raft::sparse::solver::detail | ||||||
Uh oh!
There was an error while loading. Please reload this page.