-
Notifications
You must be signed in to change notification settings - Fork 231
Add Randomized SVDs #2999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add Randomized SVDs #2999
Changes from all commits
d20ea68
307d18e
e52ad48
8a08ffc
9dcc910
a803923
5dc8079
9b44f7a
c5fd1a4
b601ea2
15b6285
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <raft/core/resource/cublas_handle.hpp> | ||
| #include <raft/core/resource/cuda_stream.hpp> | ||
| #include <raft/core/resource/cusolver_dn_handle.hpp> | ||
| #include <raft/core/resources.hpp> | ||
| #include <raft/linalg/detail/cublas_wrappers.hpp> | ||
| #include <raft/linalg/detail/cusolver_wrappers.hpp> | ||
| #include <raft/linalg/gemm.cuh> | ||
| #include <raft/linalg/qr.cuh> | ||
| #include <raft/util/cuda_utils.cuh> | ||
| #include <raft/util/cudart_utils.hpp> | ||
|
|
||
| #include <rmm/device_scalar.hpp> | ||
| #include <rmm/device_uvector.hpp> | ||
|
|
||
| namespace raft::sparse::solver::detail { | ||
|
|
||
| /** | ||
| * @brief Single pass of CholeskyQR: orthogonalize Q in-place via Cholesky factorization | ||
| * of the Gram matrix W = Q^T @ Q. | ||
| * | ||
| * @return true on success, false if Cholesky factorization failed (matrix not SPD / rank deficient) | ||
| */ | ||
| template <typename ValueTypeT> | ||
| bool cholesky_qr_pass(raft::resources const& handle, | ||
| ValueTypeT* Q, | ||
| int m, | ||
| int k, | ||
| ValueTypeT* W, | ||
| ValueTypeT* workspace, | ||
| int workspace_size, | ||
| int* dev_info) | ||
| { | ||
| auto stream = raft::resource::get_cuda_stream(handle); | ||
| auto cublas_h = raft::resource::get_cublas_handle(handle); | ||
| auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle); | ||
|
|
||
| const ValueTypeT one = 1; | ||
| const ValueTypeT zero = 0; | ||
|
|
||
| // W = Q^T @ Q (k x k) | ||
| // Q is col-major (m x k), so: W = Q^T * Q via gemm(TRANS, NOTRANS, k, k, m) | ||
| raft::linalg::gemm(handle, | ||
| true, // trans_a | ||
| false, // trans_b | ||
| k, | ||
| k, | ||
| m, | ||
| &one, | ||
| Q, | ||
| m, | ||
| Q, | ||
| m, | ||
| &zero, | ||
| W, | ||
| k, | ||
| stream); | ||
|
|
||
| // L = cholesky(W, LOWER) — W is overwritten with L in lower triangle | ||
| RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf( | ||
| cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W, k, workspace, workspace_size, dev_info, stream)); | ||
|
|
||
| // Check if Cholesky succeeded | ||
| int h_dev_info = 0; | ||
| raft::update_host(&h_dev_info, dev_info, 1, stream); | ||
| raft::resource::sync_stream(handle); | ||
| if (h_dev_info != 0) { return false; } | ||
|
|
||
| // Q = Q @ L^{-T} | ||
| // This is equivalent to solving X * L^T = Q for X, i.e. trsm with RIGHT, LOWER, TRANS | ||
| RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(cublas_h, | ||
| CUBLAS_SIDE_RIGHT, | ||
| CUBLAS_FILL_MODE_LOWER, | ||
| CUBLAS_OP_T, | ||
| CUBLAS_DIAG_NON_UNIT, | ||
| m, | ||
| k, | ||
| &one, | ||
| W, | ||
| k, | ||
| Q, | ||
| m, | ||
| stream)); | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| /** | ||
| * @brief CholeskyQR2 orthogonalization: two passes of CholeskyQR for numerical stability. | ||
| * | ||
| * This is the GPU-optimized orthogonalization from Tomás, Quintana-Ortí, Anzt (2024), | ||
| * "Fast Truncated SVD of Sparse and Dense Matrices on Graphics Processors". | ||
| * It uses GEMM + Cholesky + TRSM operations which are highly efficient on GPU, | ||
| * providing ~3x speedup over standard Householder QR. | ||
| * | ||
| * If Cholesky factorization fails (input is rank-deficient), falls back to standard QR. | ||
| * | ||
| * @param handle raft resources handle | ||
| * @param Q matrix to orthogonalize of shape (m, k), col-major, modified in-place | ||
| * @return true if CholeskyQR2 succeeded, false if fell back to QR | ||
| */ | ||
| template <typename ValueTypeT> | ||
| bool cholesky_qr2(raft::resources const& handle, | ||
| raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Q) | ||
| { | ||
| int m = Q.extent(0); | ||
| int k = Q.extent(1); | ||
|
|
||
| auto stream = raft::resource::get_cuda_stream(handle); | ||
| auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle); | ||
|
|
||
| // Allocate workspace for Gram matrix and Cholesky | ||
| rmm::device_uvector<ValueTypeT> W(k * k, stream); | ||
| rmm::device_scalar<int> dev_info(stream); | ||
|
|
||
| // Query workspace size for potrf | ||
| int potrf_workspace_size = 0; | ||
| RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize( | ||
| cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W.data(), k, &potrf_workspace_size)); | ||
| rmm::device_uvector<ValueTypeT> potrf_workspace(potrf_workspace_size, stream); | ||
|
|
||
| // First pass | ||
| if (!cholesky_qr_pass(handle, | ||
| Q.data_handle(), | ||
| m, | ||
| k, | ||
| W.data(), | ||
| potrf_workspace.data(), | ||
| potrf_workspace_size, | ||
| dev_info.data())) { | ||
| // Fallback to standard QR (qrGetQ handles src==dst via internal copy) | ||
| raft::linalg::qrGetQ(handle, Q.data_handle(), Q.data_handle(), m, k, stream); | ||
| return false; | ||
| } | ||
|
|
||
| // Second pass for improved numerical stability | ||
| if (!cholesky_qr_pass(handle, | ||
| Q.data_handle(), | ||
| m, | ||
| k, | ||
| W.data(), | ||
| potrf_workspace.data(), | ||
| potrf_workspace_size, | ||
| dev_info.data())) { | ||
| // Fallback to standard QR (qrGetQ handles src==dst via internal copy) | ||
| raft::linalg::qrGetQ(handle, Q.data_handle(), Q.data_handle(), m, k, stream); | ||
| return false; | ||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| } // namespace raft::sparse::solver::detail | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <raft/core/device_csr_matrix.hpp> | ||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/resources.hpp> | ||
| #include <raft/sparse/linalg/spmm.hpp> | ||
|
|
||
| namespace raft::sparse::solver::detail { | ||
|
|
||
| /** | ||
| * @brief CSR matrix wrapper providing apply() / apply_transpose() for the sparse SVD solver. | ||
| * | ||
| * This is NOT a general-purpose LinearOperator (in the sense of scipy's | ||
| * `scipy.sparse.linalg.LinearOperator`): it only exposes the two products the | ||
| * randomized SVD inner loop needs (Y = A @ X and Z = A^T @ X) via cuSPARSE SpMM. | ||
| * It intentionally lives in `detail/` and is not part of the public API. | ||
| * | ||
| * @note The cuSPARSE spmm wrapper requires `int` for indptr/indices types, so this wrapper | ||
| * is currently limited to int-indexed CSR matrices. | ||
| * | ||
| * @tparam ValueTypeT Data type of matrix values | ||
| * @tparam NNZTypeT Type for number of non-zeros | ||
| */ | ||
| template <typename ValueTypeT, typename NNZTypeT = int> | ||
| struct csr_linear_operator { | ||
| /** | ||
| * @brief Construct from a const CSR matrix view | ||
| * | ||
| * @note `m_`/`n_` are cached at construction because | ||
| * `raft::device_csr_matrix_view::structure_view()` is not const-qualified | ||
| * and cannot be invoked from a const member function. | ||
| */ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great to see a LinearOperator implementation! It could be useful in the lanczos solver as well (#2705). Is this one similar to https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html I see this is in a detail file, but it might be nice to have a public interface.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a full Linear Operator. Its just a very limited version for raft to do PCA with the SVDs |
||
| explicit csr_linear_operator(raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A) | ||
| : A_(A), m_(A.structure_view().get_n_rows()), n_(A.structure_view().get_n_cols()) | ||
| { | ||
| } | ||
|
|
||
| int rows() const { return m_; } | ||
| int cols() const { return n_; } | ||
|
|
||
| /** @brief Access the underlying const CSR matrix view */ | ||
| raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> csr_view() const { return A_; } | ||
|
|
||
| /** | ||
| * @brief Compute Y = A @ X | ||
| * @param[in] handle raft resources handle | ||
| * @param[in] X input dense matrix of shape (n, k) col-major | ||
| * @param[out] Y output dense matrix of shape (m, k) col-major | ||
| */ | ||
| void apply(raft::resources const& handle, | ||
| raft::device_matrix_view<const ValueTypeT, uint32_t, raft::col_major> X, | ||
| raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Y) const | ||
| { | ||
| ValueTypeT alpha = 1; | ||
| ValueTypeT beta = 0; | ||
| raft::sparse::linalg::spmm(handle, false, false, &alpha, A_, X, &beta, Y); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Compute Z = A^T @ X | ||
| * @param[in] handle raft resources handle | ||
| * @param[in] X input dense matrix of shape (m, k) col-major | ||
| * @param[out] Z output dense matrix of shape (n, k) col-major | ||
| */ | ||
| void apply_transpose(raft::resources const& handle, | ||
| raft::device_matrix_view<const ValueTypeT, uint32_t, raft::col_major> X, | ||
| raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Z) const | ||
| { | ||
| ValueTypeT alpha = 1; | ||
| ValueTypeT beta = 0; | ||
| raft::sparse::linalg::spmm(handle, true, false, &alpha, A_, X, &beta, Z); | ||
| } | ||
|
|
||
| private: | ||
| raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A_; | ||
| int m_; | ||
| int n_; | ||
| }; | ||
|
|
||
| } // namespace raft::sparse::solver::detail | ||
Uh oh!
There was an error while loading. Please reload this page.