Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/solver/lanczos_solver_int64_float.cu
src/raft_runtime/solver/lanczos_solver_int_double.cu
src/raft_runtime/solver/lanczos_solver_int_float.cu
src/raft_runtime/solver/randomized_svds_float.cu
src/raft_runtime/solver/randomized_svds_double.cu
)
set_target_properties(
raft_objs
Expand Down
159 changes: 159 additions & 0 deletions cpp/include/raft/sparse/solver/detail/cholesky_qr.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/qr.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

namespace raft::sparse::solver::detail {

/**
 * @brief One CholeskyQR sweep that orthogonalizes the columns of Q in place.
 *
 * The sweep consists of three library calls issued on the stream of `handle`:
 *   1. GEMM:  W = Q^T * Q                (k x k Gram matrix)
 *   2. potrf: W = L * L^T, lower variant (L overwrites the lower triangle of W)
 *   3. TRSM:  Q <- Q * L^{-T}
 *
 * The potrf status word is copied back to the host and the stream is synchronized
 * before step 3, so this function blocks the calling thread.
 *
 * @tparam ValueTypeT element type of Q, W and the workspace
 * @param[in]     handle         raft resources handle (stream, cuBLAS and cuSOLVER handles)
 * @param[in,out] Q              col-major matrix with m rows, k columns and leading dimension m
 * @param[in]     m              number of rows of Q
 * @param[in]     k              number of columns of Q
 * @param[out]    W              device buffer for the k x k Gram matrix / Cholesky factor
 * @param[in]     workspace      device scratch buffer forwarded to potrf
 * @param[in]     workspace_size size of `workspace` as forwarded to potrf
 * @param[out]    dev_info       device integer receiving the potrf status
 * @return true if potrf reported status 0 and Q was updated; false for any non-zero
 *         status (e.g. Gram matrix not SPD / Q rank deficient), in which case the
 *         triangular solve is skipped and Q is left as it was on entry
 */
template <typename ValueTypeT>
bool cholesky_qr_pass(raft::resources const& handle,
                      ValueTypeT* Q,
                      int m,
                      int k,
                      ValueTypeT* W,
                      ValueTypeT* workspace,
                      int workspace_size,
                      int* dev_info)
{
  auto stream     = raft::resource::get_cuda_stream(handle);
  auto cublas_h   = raft::resource::get_cublas_handle(handle);
  auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle);

  // Scalars shared by the GEMM (alpha, beta) and the TRSM (alpha only).
  const ValueTypeT alpha = 1;
  const ValueTypeT beta  = 0;

  // Step 1: Gram matrix. Q is col-major (m x k), hence op(A) = Q^T, op(B) = Q
  // and the GEMM dimensions are (k, k, m).
  RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(
    cublas_h, CUBLAS_OP_T, CUBLAS_OP_N, k, k, m, &alpha, Q, m, Q, m, &beta, W, k, stream));

  // Step 2: lower Cholesky factorization of W; L is written into W's lower triangle.
  RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(
    cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W, k, workspace, workspace_size, dev_info, stream));

  // The factorization status lives on the device: fetch it and wait for the copy
  // to land before deciding whether the triangular solve may run.
  int potrf_status = 0;
  raft::update_host(&potrf_status, dev_info, 1, stream);
  raft::resource::sync_stream(handle);

  const bool factorized = (potrf_status == 0);
  if (factorized) {
    // Step 3: Q = Q @ L^{-T}, i.e. solve X * L^T = Q for X
    // (TRSM with RIGHT side, LOWER fill, TRANS, non-unit diagonal).
    RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(cublas_h,
                                                     CUBLAS_SIDE_RIGHT,
                                                     CUBLAS_FILL_MODE_LOWER,
                                                     CUBLAS_OP_T,
                                                     CUBLAS_DIAG_NON_UNIT,
                                                     m,
                                                     k,
                                                     &alpha,
                                                     W,
                                                     k,
                                                     Q,
                                                     m,
                                                     stream));
  }

  return factorized;
}

/**
 * @brief CholeskyQR2 orthogonalization: two passes of CholeskyQR for numerical stability.
 *
 * This is the GPU-optimized orthogonalization from Tomás, Quintana-Ortí, Anzt (2024),
 * "Fast Truncated SVD of Sparse and Dense Matrices on Graphics Processors".
 * It uses GEMM + Cholesky + TRSM operations which are highly efficient on GPU,
 * providing ~3x speedup over standard Householder QR.
 *
 * If the Cholesky factorization fails in either pass (e.g. the input is rank deficient),
 * the function falls back to standard QR via raft::linalg::qrGetQ on the current
 * contents of Q. When the failure happens in the second pass, Q has already been
 * updated by the first pass at that point.
 *
 * A k x k Gram buffer, a device status scalar and the potrf workspace are allocated on
 * the stream of `handle` for the duration of the call.
 *
 * @tparam ValueTypeT element type of Q
 * @param handle raft resources handle
 * @param Q matrix to orthogonalize of shape (m, k), col-major, modified in-place
 * @return true if CholeskyQR2 succeeded, false if fell back to QR
 */
template <typename ValueTypeT>
bool cholesky_qr2(raft::resources const& handle,
                  raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Q)
{
  // The cuBLAS / cuSOLVER wrappers used below take int dimensions, so the
  // uint32_t extents are narrowed explicitly here.
  const int m = static_cast<int>(Q.extent(0));
  const int k = static_cast<int>(Q.extent(1));

  auto stream     = raft::resource::get_cuda_stream(handle);
  auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle);

  // Gram matrix / Cholesky factor storage. The element count is computed in the
  // container's size type: evaluating k * k in int could overflow before the
  // result is converted for the allocation.
  using size_type       = typename rmm::device_uvector<ValueTypeT>::size_type;
  const auto gram_elems = static_cast<size_type>(k) * static_cast<size_type>(k);
  rmm::device_uvector<ValueTypeT> W(gram_elems, stream);
  rmm::device_scalar<int> dev_info(stream);

  // Query the workspace size required by potrf and allocate it once for both passes.
  int potrf_workspace_size = 0;
  RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize(
    cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W.data(), k, &potrf_workspace_size));
  rmm::device_uvector<ValueTypeT> potrf_workspace(potrf_workspace_size, stream);

  // First pass orthogonalizes, second pass improves numerical stability.
  constexpr int num_passes = 2;
  for (int pass = 0; pass < num_passes; ++pass) {
    const bool pass_ok = cholesky_qr_pass(handle,
                                          Q.data_handle(),
                                          m,
                                          k,
                                          W.data(),
                                          potrf_workspace.data(),
                                          potrf_workspace_size,
                                          dev_info.data());
    if (!pass_ok) {
      // Fallback to standard QR (qrGetQ handles src==dst via internal copy)
      raft::linalg::qrGetQ(handle, Q.data_handle(), Q.data_handle(), m, k, stream);
      return false;
    }
  }

  return true;
}

} // namespace raft::sparse::solver::detail
96 changes: 96 additions & 0 deletions cpp/include/raft/sparse/solver/detail/csr_linear_operator.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/linalg/spmm.hpp>

namespace raft::sparse::solver::detail {

/**
* @brief Linear operator wrapping a CSR sparse matrix for use with sparse SVD solvers.
*
* Provides apply() (Y = A @ X) and apply_transpose() (Z = A^T @ X) using cuSPARSE SpMM.
*
* @note The cuSPARSE spmm wrapper requires int for indptr/indices types, so this operator
* is currently limited to int-indexed CSR matrices.
*
* @tparam ValueTypeT Data type of matrix values
* @tparam NNZTypeT Type for number of non-zeros
*/
template <typename ValueTypeT, typename NNZTypeT = int>
struct csr_linear_operator {
/**
* @brief Construct from a const CSR matrix view
*/
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great to see a LinearOperator implementation! It could be useful in the lanczos solver as well (#2705). Is this one similar to https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html

I see this is in a detail file, but it might be nice to have a public interface.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a full Linear Operator. Its just a very limited version for raft to do PCA with the SVDs

explicit csr_linear_operator(
raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A)
: A_(A),
m_(A.structure_view().get_n_rows()),
n_(A.structure_view().get_n_cols())
{
}

/**
* @brief Construct from a mutable CSR matrix view (converts to const)
*/
explicit csr_linear_operator(
raft::device_csr_matrix_view<ValueTypeT, int, int, NNZTypeT> A)
: A_(raft::make_device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT>(
A.get_elements().data(), A.structure_view())),
m_(A.structure_view().get_n_rows()),
n_(A.structure_view().get_n_cols())
{
}

int rows() const { return m_; }
int cols() const { return n_; }

/** @brief Access the underlying const CSR matrix view (for SpMV operations) */
raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> csr_view() const
{
return A_;
}

/**
* @brief Compute Y = A @ X
* @param[in] handle raft resources handle
* @param[in] X input dense matrix of shape (n, k) col-major
* @param[out] Y output dense matrix of shape (m, k) col-major
*/
void apply(raft::resources const& handle,
raft::device_matrix_view<const ValueTypeT, uint32_t, raft::col_major> X,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Y) const
{
ValueTypeT alpha = 1;
ValueTypeT beta = 0;
raft::sparse::linalg::spmm(handle, false, false, &alpha, A_, X, &beta, Y);
}

/**
* @brief Compute Z = A^T @ X
* @param[in] handle raft resources handle
* @param[in] X input dense matrix of shape (m, k) col-major
* @param[out] Z output dense matrix of shape (n, k) col-major
*/
void apply_transpose(raft::resources const& handle,
raft::device_matrix_view<const ValueTypeT, uint32_t, raft::col_major> X,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Z) const
{
ValueTypeT alpha = 1;
ValueTypeT beta = 0;
raft::sparse::linalg::spmm(handle, true, false, &alpha, A_, X, &beta, Z);
}

private:
raft::device_csr_matrix_view<const ValueTypeT, int, int, NNZTypeT> A_;
int m_;
int n_;
};

} // namespace raft::sparse::solver::detail
Loading
Loading