Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/solver/lanczos_solver_int64_float.cu
src/raft_runtime/solver/lanczos_solver_int_double.cu
src/raft_runtime/solver/lanczos_solver_int_float.cu
src/raft_runtime/solver/randomized_svds_float.cu
src/raft_runtime/solver/randomized_svds_double.cu
)
set_target_properties(
raft_objs
Expand Down
163 changes: 163 additions & 0 deletions cpp/include/raft/sparse/solver/detail/cholesky_qr.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/detail/cusolver_wrappers.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/qr.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

namespace raft::sparse::solver::detail {

/**
* @brief Single pass of CholeskyQR: orthogonalize Q in-place via Cholesky factorization
* of the Gram matrix W = Q^T @ Q.
*
* @return true on success, false if Cholesky factorization failed (matrix not SPD / rank deficient)
*/
template <typename ValueTypeT>
bool cholesky_qr_pass(raft::resources const& handle,
ValueTypeT* Q,
int m,
int k,
ValueTypeT* W,
ValueTypeT* workspace,
int workspace_size,
int* dev_info)
{
auto stream = raft::resource::get_cuda_stream(handle);
auto cublas_h = raft::resource::get_cublas_handle(handle);
auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle);

const ValueTypeT one = 1;
const ValueTypeT zero = 0;

// W = Q^T @ Q (k x k)
// Q is col-major (m x k), so: W = Q^T * Q via gemm(TRANS, NOTRANS, k, k, m)
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_h,
CUBLAS_OP_T,
CUBLAS_OP_N,
k,
k,
m,
&one,
Q,
m,
Q,
m,
&zero,
W,
k,
stream));
Comment thread
achirkin marked this conversation as resolved.
Outdated

// L = cholesky(W, LOWER) — W is overwritten with L in lower triangle
RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(
cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W, k, workspace, workspace_size, dev_info, stream));

// Check if Cholesky succeeded
int h_dev_info = 0;
raft::update_host(&h_dev_info, dev_info, 1, stream);
raft::resource::sync_stream(handle);
if (h_dev_info != 0) { return false; }

// Q = Q @ L^{-T}
// This is equivalent to solving X * L^T = Q for X, i.e. trsm with RIGHT, LOWER, TRANS
RAFT_CUBLAS_TRY(raft::linalg::detail::cublastrsm(cublas_h,
CUBLAS_SIDE_RIGHT,
CUBLAS_FILL_MODE_LOWER,
CUBLAS_OP_T,
CUBLAS_DIAG_NON_UNIT,
m,
k,
&one,
W,
k,
Q,
m,
stream));

return true;
}

/**
* @brief CholeskyQR2 orthogonalization: two passes of CholeskyQR for numerical stability.
*
* This is the GPU-optimized orthogonalization from Tomás, Quintana-Ortí, Anzt (2024),
* "Fast Truncated SVD of Sparse and Dense Matrices on Graphics Processors".
* It uses GEMM + Cholesky + TRSM operations which are highly efficient on GPU,
* providing ~3x speedup over standard Householder QR.
*
* If Cholesky factorization fails (input is rank-deficient), falls back to standard QR.
*
* @param handle raft resources handle
* @param Q matrix to orthogonalize of shape (m, k), col-major, modified in-place
* @return true if CholeskyQR2 succeeded, false if fell back to QR
*/
template <typename ValueTypeT>
bool cholesky_qr2(raft::resources const& handle,
Comment thread
Intron7 marked this conversation as resolved.
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Q)
{
int m = Q.extent(0);
int k = Q.extent(1);

auto stream = raft::resource::get_cuda_stream(handle);
auto cusolver_h = raft::resource::get_cusolver_dn_handle(handle);

// Allocate workspace for Gram matrix and Cholesky
rmm::device_uvector<ValueTypeT> W(k * k, stream);
rmm::device_scalar<int> dev_info(stream);

// Query workspace size for potrf
int potrf_workspace_size = 0;
RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize(
cusolver_h, CUBLAS_FILL_MODE_LOWER, k, W.data(), k, &potrf_workspace_size));
rmm::device_uvector<ValueTypeT> potrf_workspace(potrf_workspace_size, stream);

// First pass
if (!cholesky_qr_pass(handle,
Q.data_handle(),
m,
k,
W.data(),
potrf_workspace.data(),
potrf_workspace_size,
dev_info.data())) {
// Fallback to standard QR
rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream);
raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream);
raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream);
raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream);
raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream);
raft::linalg::qrGetQ(handle, Q.data_handle(), Q.data_handle(), m, k, stream);

qrGetQ already does a copy internally. Using Q.data_handle() twice would allow the operation to work inplace (even the copy could be avoided as src==dst). To double check though.

Additionally, m * k has an integer overflow risk.

return false;
}

// Second pass for improved numerical stability
if (!cholesky_qr_pass(handle,
Q.data_handle(),
m,
k,
W.data(),
potrf_workspace.data(),
potrf_workspace_size,
dev_info.data())) {
// Fallback to standard QR
rmm::device_uvector<ValueTypeT> Q_copy(m * k, stream);
raft::copy(Q_copy.data(), Q.data_handle(), m * k, stream);
raft::linalg::qrGetQ(handle, Q_copy.data(), Q.data_handle(), m, k, stream);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment here.

return false;
}

return true;
}

} // namespace raft::sparse::solver::detail
209 changes: 209 additions & 0 deletions cpp/include/raft/sparse/solver/detail/randomized_svds.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/svd.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/sparse/solver/detail/cholesky_qr.cuh>
#include <raft/sparse/solver/detail/svds_sign_correction.cuh>
#include <raft/sparse/solver/svds_types.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>

#include <algorithm>

namespace raft::sparse::solver::detail {

/**
* @brief Randomized SVD for sparse matrices using block power iteration with CholeskyQR2.
*
* Implements randomized SVD (Halko et al. 2009) with GPU-optimized CholeskyQR2
* orthogonalization (Tomás et al. 2024).
*
* The operator interface allows implicit operators (e.g. mean-centered sparse matrices)
* without materializing the dense matrix.
*
* @tparam ValueTypeT Data type (float or double)
* @tparam OperatorT Linear operator type providing apply() and apply_transpose()
*
* @param handle raft resources handle
* @param config SVD configuration (n_components, n_oversamples, n_power_iters, seed)
* @param op linear operator representing the matrix to decompose
* @param singular_values output singular values of shape (k,) in descending order
* @param U output left singular vectors of shape (m, k), col-major
* @param Vt output right singular vectors of shape (k, n), col-major
*/
template <typename ValueTypeT, typename OperatorT>
void sparse_randomized_svd(
raft::resources const& handle,
sparse_svd_config<ValueTypeT> const& config,
OperatorT const& op,
raft::device_vector_view<ValueTypeT, uint32_t> singular_values,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> U,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> Vt)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"raft::sparse::solver::sparse_randomized_svd(%d, %d, %d)",
op.rows(),
op.cols(),
config.n_components);

int m = op.rows();
int n = op.cols();
int k = config.n_components;
int p = config.n_oversamples;

RAFT_EXPECTS(k > 0, "n_components must be positive");
RAFT_EXPECTS(k < std::min(m, n), "n_components must be less than min(m, n)");
RAFT_EXPECTS(p >= 0, "n_oversamples must be non-negative");
RAFT_EXPECTS(config.n_power_iters >= 0, "n_power_iters must be non-negative");
RAFT_EXPECTS(singular_values.extent(0) == static_cast<uint32_t>(k),
"singular_values must have size n_components");
RAFT_EXPECTS(U.extent(0) == static_cast<uint32_t>(m) && U.extent(1) == static_cast<uint32_t>(k),
"U must have shape (m, n_components)");
RAFT_EXPECTS(
Vt.extent(0) == static_cast<uint32_t>(k) && Vt.extent(1) == static_cast<uint32_t>(n),
"Vt must have shape (n_components, n)");

auto stream = raft::resource::get_cuda_stream(handle);

int block_size = std::min(k + p, std::min(m, n));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should avoid silently clamping as it affects approximation quality. Please add a warning when n_components + n_oversamples < min(n_rows, n_features).

RAFT_EXPECTS(block_size >= k, "block_size (n_components + n_oversamples) must be >= n_components");

// Initialize RNG
uint64_t seed = config.seed.value_or(0);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
uint64_t seed = config.seed.value_or(0);
uint64_t seed = config.seed.value_or(std::random_device{}());

Setting seed to std::nullopt is still deterministic and do not produce different runs. We should fix this.

raft::random::RngState rng_state(seed);

// Step 1-3: Y = A @ Omega, orthogonalize
auto Y = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(
handle, static_cast<uint32_t>(m), static_cast<uint32_t>(block_size));
{
auto Omega = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(
handle, static_cast<uint32_t>(n), static_cast<uint32_t>(block_size));
raft::random::normal(handle,
rng_state,
Omega.data_handle(),
static_cast<std::size_t>(n) * block_size,
ValueTypeT(0),
ValueTypeT(1));
op.apply(handle,
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>(
Omega.data_handle(), n, block_size),
Y.view());
} // Omega freed here
cholesky_qr2(handle, Y.view());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should maybe check the return value of cholesky_qr2 calls and emit a warning when there's a fallback to standard QR.


// Step 4: Power iterations
auto Z = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(
handle, static_cast<uint32_t>(n), static_cast<uint32_t>(block_size));

for (int iter = 0; iter < config.n_power_iters; ++iter) {
// Z = A^T @ Q -> (n, block_size)
op.apply_transpose(
handle,
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>(
Y.data_handle(), m, block_size),
Z.view());
cholesky_qr2(handle, Z.view());

// Y = A @ Z -> (m, block_size)
op.apply(handle,
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>(
Z.data_handle(), n, block_size),
Y.view());
cholesky_qr2(handle, Y.view());
}

// Q = Y after power iterations (already orthogonal)
// Q is (m, block_size)

// Step 5: Bt = A^T @ Q -> (n, block_size)
op.apply_transpose(
handle,
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>(
Y.data_handle(), m, block_size),
Z.view());

// Step 6-7: SVD of B = Bt^T where Bt = Z (n x block_size, tall matrix)
// We compute SVD(Bt) directly to avoid cuSOLVER gesvd issues with wide matrices.
// SVD(Bt) = U_bt * S * Vt_bt → SVD(B) has U_b = V_bt and Vt_b = U_bt^T
auto S_full = raft::make_device_vector<ValueTypeT, uint32_t>(handle, block_size);

// Bt = B^T is (n x block_size), we already have Bt = Z from step 5
// Z is (n, block_size) = A^T @ Q = B^T. We can reuse Z directly!
// SVD(Bt) with Bt being (n x block_size): jobu='S' gives U_bt (n x block_size),
// jobvt='A' gives Vt_bt (block_size x block_size)
auto U_bt = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(
handle, static_cast<uint32_t>(n), static_cast<uint32_t>(block_size));
auto Vt_bt = raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(
handle, static_cast<uint32_t>(block_size), static_cast<uint32_t>(block_size));

// Z is consumed by svdQR (modifies input in-place) — this is fine since Z is not used after
raft::linalg::svdQR(handle,
Z.data_handle(),
n,
block_size,
S_full.data_handle(),
U_bt.data_handle(),
Vt_bt.data_handle(),
true, // transpose right vectors: Vt_bt -> V_bt (block_size x block_size)
true, // generate left vectors
true, // generate right vectors
stream);
// After svdQR with trans_right=true:
// U_bt is (n, block_size) — left singular vectors of Bt
// Vt_bt is now V_bt (block_size x block_size) — right singular vectors of Bt (transposed)
// S_full has block_size singular values
//
// For B = Bt^T: U_b = V_bt = Vt_bt (after transpose), Vt_b = U_bt^T
// So: U_b[:, :k] = Vt_bt[:, :k] and Vt_b[:k, :] = U_bt[:, :k]^T

// Step 8: U = Q @ U_b[:, :k] = Q @ V_bt[:, :k]
// Q is Y (m, block_size), V_bt is (block_size, block_size)
// U = Y @ V_bt[:, :k] → (m, block_size) * (block_size, k) → (m, k)
const ValueTypeT one = 1;
const ValueTypeT zero = 0;
raft::linalg::gemm(handle,
Y.data_handle(),
m,
block_size,
Vt_bt.data_handle(), // This is V_bt after trans_right=true
U.data_handle(),
m,
k,
CUBLAS_OP_N,
CUBLAS_OP_N,
one,
zero,
stream);

// Step 9: Truncate S and Vt
raft::copy(singular_values.data_handle(), S_full.data_handle(), k, stream);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A nitpick

Suggested change
raft::copy(singular_values.data_handle(), S_full.data_handle(), k, stream);
raft::copy(handle, singular_values, raft::make_const_mdspan(S_full));


// Vt[:k, :] = U_bt[:, :k]^T
// U_bt is col-major (n, block_size), we need the first k columns transposed to (k, n)
// Vt(i,j) = U_bt(j,i) for i < k
// This is: Vt = (U_bt[:, :k])^T
// Use GEMM with identity: Vt = I_k @ U_bt[:, :k]^T doesn't help
// Just use transpose: transpose the first k columns of U_bt
// U_bt[:, :k] is (n, k) col-major → transpose to (k, n) col-major = Vt
raft::linalg::transpose(handle, U_bt.data_handle(), Vt.data_handle(), n, k, stream);

// Step 10: Sign correction
svd_sign_correction(handle, U, Vt);
}

} // namespace raft::sparse::solver::detail
Loading
Loading