Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/solver/lanczos_solver_int64_float.cu
src/raft_runtime/solver/lanczos_solver_int_double.cu
src/raft_runtime/solver/lanczos_solver_int_float.cu
src/raft_runtime/solver/randomized_svds_float.cu
src/raft_runtime/solver/randomized_svds_double.cu
)
set_target_properties(
raft_objs
Expand Down
159 changes: 159 additions & 0 deletions cpp/include/raft/sparse/solver/detail/cholesky_qr.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/qr.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

#include <cstddef>

namespace raft::sparse::solver::detail {

/**
 * @brief Run one CholeskyQR sweep over Q, overwriting Q in place.
 *
 * The sweep consists of three steps:
 *   1. Gram matrix:  W = Q^T @ Q            (gemm, k x k result)
 *   2. Factorization: W = L @ L^T           (potrf, lower triangle of W receives L)
 *   3. Update:        Q = Q @ L^{-T}        (trsm, right side / lower / transposed)
 *
 * The potrf status is copied to the host and the handle's stream is synchronized
 * before it is inspected, so this call blocks the calling host thread.
 *
 * @tparam ValueTypeT scalar type of the matrices
 * @param[in]     handle         raft resources handle
 * @param[in,out] Q              col-major matrix of shape (m, k), leading dimension m
 * @param[in]     m              number of rows of Q
 * @param[in]     k              number of columns of Q
 * @param[out]    W              device buffer of shape (k, k) that receives the Gram matrix
 *                               and is then overwritten by potrf
 * @param[in]     workspace      device scratch buffer forwarded to potrf
 * @param[in]     workspace_size size of `workspace`, forwarded to potrf
 * @param[out]    dev_info       device integer that receives the potrf status
 * @return true if potrf reported status 0 and Q was updated; false otherwise, in which
 *         case Q has only been read and still holds its input values
 */
template <typename ValueTypeT>
bool cholesky_qr_pass(raft::resources const& handle,
                      ValueTypeT* Q,
                      int m,
                      int k,
                      ValueTypeT* W,
                      ValueTypeT* workspace,
                      int workspace_size,
                      int* dev_info)
{
  auto stream = raft::resource::get_cuda_stream(handle);

  const ValueTypeT alpha = 1;
  const ValueTypeT beta  = 0;

  // Step 1: W = Q^T @ Q. Q is col-major (m x k), so op(A) = Q^T and op(B) = Q,
  // giving a (k x k) product with inner dimension m.
  constexpr bool transpose_lhs = true;
  constexpr bool transpose_rhs = false;
  raft::linalg::gemm(
    handle, transpose_lhs, transpose_rhs, k, k, m, &alpha, Q, m, Q, m, &beta, W, k, stream);

  // Step 2: Cholesky-factorize W; L is written into the lower triangle of W.
  auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle);
  RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(
    cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W, k, workspace, workspace_size, dev_info, stream));

  // Bring the factorization status back to the host before touching Q.
  int potrf_status = 0;
  raft::update_host(&potrf_status, dev_info, 1, stream);
  raft::resource::sync_stream(handle);

  const bool factorized = (potrf_status == 0);
  if (factorized) {
    // Step 3: solve X * L^T = Q for X and store X back into Q, i.e. Q = Q @ L^{-T}.
    auto cublas_h = raft::resource::get_cublas_handle(handle);
    RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(cublas_h,
                                                     CUBLAS_SIDE_RIGHT,
                                                     CUBLAS_FILL_MODE_LOWER,
                                                     CUBLAS_OP_T,
                                                     CUBLAS_DIAG_NON_UNIT,
                                                     m,
                                                     k,
                                                     &alpha,
                                                     W,
                                                     k,
                                                     Q,
                                                     m,
                                                     stream));
  }

  return factorized;
}

/**
 * @brief CholeskyQR2 orthogonalization: two passes of CholeskyQR for numerical stability.
 *
 * This is the GPU-optimized orthogonalization from Tomás, Quintana-Ortí, Anzt (2024),
 * "Fast Truncated SVD of Sparse and Dense Matrices on Graphics Processors".
 * It uses GEMM + Cholesky + TRSM operations which are highly efficient on GPU,
 * providing ~3x speedup over standard Householder QR.
 *
 * If the Cholesky factorization fails in either pass (e.g. the input is rank-deficient),
 * the current contents of Q are orthogonalized with standard QR instead.
 *
 * The Gram matrix, the potrf status scalar and the potrf workspace are allocated once on
 * the handle's stream and reused by both passes.
 *
 * @tparam ValueTypeT scalar type of the matrix
 * @param[in]     handle raft resources handle
 * @param[in,out] Q      matrix to orthogonalize of shape (m, k), col-major, modified in-place
 * @return true if both CholeskyQR passes succeeded, false if the routine fell back to QR
 */
template <typename ValueTypeT>
bool cholesky_qr2(raft::resources const& handle,
                  raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Q)
{
  // The cuBLAS / cuSOLVER wrappers used below take `int` dimensions.
  const int m = static_cast<int>(Q.extent(0));
  const int k = static_cast<int>(Q.extent(1));

  auto stream     = raft::resource::get_cuda_stream(handle);
  auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle);

  // Gram matrix storage (k x k). The element count is computed in std::size_t because
  // the `int` product k * k overflows once k exceeds 46340.
  const std::size_t gram_elems = static_cast<std::size_t>(k) * static_cast<std::size_t>(k);
  rmm::device_uvector<ValueTypeT> W(gram_elems, stream);
  rmm::device_scalar<int> dev_info(stream);

  // Query workspace size for potrf and allocate it once for both passes.
  int potrf_workspace_size = 0;
  RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize(
    cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W.data(), k, &potrf_workspace_size));
  rmm::device_uvector<ValueTypeT> potrf_workspace(potrf_workspace_size, stream);

  // The second pass improves the orthogonality obtained by the first one.
  constexpr int n_passes = 2;
  for (int pass = 0; pass < n_passes; ++pass) {
    const bool pass_ok = cholesky_qr_pass(handle,
                                          Q.data_handle(),
                                          m,
                                          k,
                                          W.data(),
                                          potrf_workspace.data(),
                                          potrf_workspace_size,
                                          dev_info.data());
    if (!pass_ok) {
      // Fallback to standard QR (qrGetQ handles src==dst via internal copy)
      raft::linalg::qrGetQ(handle, Q.data_handle(), Q.data_handle(), m, k, stream);
      return false;
    }
  }

  return true;
}

} // namespace raft::sparse::solver::detail
85 changes: 85 additions & 0 deletions cpp/include/raft/sparse/solver/detail/csr_linear_operator.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/linalg/spmm.hpp>

namespace raft::sparse::solver::detail {

/**
* @brief CSR matrix wrapper providing apply() / apply_transpose() for the sparse SVD solver.
*
* This is NOT a general-purpose LinearOperator (in the sense of scipy's
* `scipy.sparse.linalg.LinearOperator`): it only exposes the two products the
* randomized SVD inner loop needs (Y = A @ X and Z = A^T @ X) via cuSPARSE SpMM.
* It intentionally lives in `detail/` and is not part of the public API.
*
* @note The cuSPARSE spmm wrapper requires `int` for indptr/indices types, so this wrapper
* is currently limited to int-indexed CSR matrices.
*
* @tparam ValueTypeT Data type of matrix values
* @tparam NNZTypeT Type for number of non-zeros
*/
template <typename ValueTypeT, typename NNZTypeT = int>
struct csr_linear_operator {
/**
* @brief Construct from a const CSR matrix view
*
* @note `m_`/`n_` are cached at construction because
* `raft::device_csr_matrix_view::structure_view()` is not const-qualified
* and cannot be invoked from a const member function.
*/
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great to see a LinearOperator implementation! It could be useful in the lanczos solver as well (#2705). Is this one similar to SciPy's LinearOperator (https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html)?

I see this is in a detail file, but it might be nice to have a public interface.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a full LinearOperator. It's just a very limited version that raft uses to do PCA with the sparse SVD solver.

// Stores the view and caches its row/column counts so the const accessors
// never need to call structure_view() again.
explicit csr_linear_operator(raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A)
  : A_(A)
{
  auto structure = A.structure_view();
  m_             = structure.get_n_rows();
  n_             = structure.get_n_cols();
}

/** @brief Number of rows of the wrapped matrix, as cached at construction. */
int rows() const
{
  return m_;
}
/** @brief Number of columns of the wrapped matrix, as cached at construction. */
int cols() const
{
  return n_;
}

/** @brief Return a copy of the stored const CSR matrix view. */
raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> csr_view() const
{
  return A_;
}

/**
 * @brief Forward product: Y = A @ X.
 *
 * Calls spmm with alpha = 1 and beta = 0 and with neither operand transposed.
 *
 * @param[in]  handle raft resources handle
 * @param[in]  X      dense input of shape (n, k), col-major
 * @param[out] Y      dense output of shape (m, k), col-major
 */
void apply(raft::resources const& handle,
           raft::device_matrix_view<const ValueTypeT, uint32_t, raft::col_major> X,
           raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Y) const
{
  constexpr bool transpose_a = false;
  constexpr bool transpose_x = false;

  ValueTypeT scale_product = 1;
  ValueTypeT scale_output  = 0;

  raft::sparse::linalg::spmm(
    handle, transpose_a, transpose_x, &scale_product, A_, X, &scale_output, Y);
}

/**
 * @brief Transposed product: Z = A^T @ X.
 *
 * Calls spmm with alpha = 1 and beta = 0, requesting the transpose of the sparse
 * operand only.
 *
 * @param[in]  handle raft resources handle
 * @param[in]  X      dense input of shape (m, k), col-major
 * @param[out] Z      dense output of shape (n, k), col-major
 */
void apply_transpose(raft::resources const& handle,
                     raft::device_matrix_view<const ValueTypeT, uint32_t, raft::col_major> X,
                     raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Z) const
{
  constexpr bool transpose_a = true;
  constexpr bool transpose_x = false;

  ValueTypeT scale_product = 1;
  ValueTypeT scale_output  = 0;

  raft::sparse::linalg::spmm(
    handle, transpose_a, transpose_x, &scale_product, A_, X, &scale_output, Z);
}

private:
raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A_;
int m_;
int n_;
};

} // namespace raft::sparse::solver::detail
Loading
Loading