diff --git a/cpp/include/raft/core/compressed_mdarray.hpp b/cpp/include/raft/core/compressed_mdarray.hpp new file mode 100644 index 0000000000..4492860094 --- /dev/null +++ b/cpp/include/raft/core/compressed_mdarray.hpp @@ -0,0 +1,528 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { + +// ============================================================================ +// compressed_data_handle: the data_handle_type for compressed mdspan/mdarray +// ============================================================================ + +template +struct compressed_data_handle { + CodebookView codebook; + CodesView codes; + size_t base_offset = 0; +}; + +// ============================================================================ +// Reconstruction specs: the only compression-type-specific code +// ============================================================================ + +/** + * @brief Scalar quantization reconstruction spec (matches cuvs global min/max SQ). + * + * Codebook is a [2] vector holding {min, max}. Scale and offset are derived + * at construction time. Codes are int8_t in [-128, 127]. + */ +template +struct sq_spec { + using math_type = std::remove_const_t; + using codebook_view_type = CodebookView; + using codes_view_type = CodesView; + + math_type scale; + math_type offset; + + RAFT_INLINE_FUNCTION + static uint32_t dim(compressed_data_handle const& h) + { + return static_cast(h.codes.extent(1)); + } + + RAFT_INLINE_FUNCTION + math_type reconstruct(compressed_data_handle const& h, + size_t row, + size_t col) const + { + auto code = h.codes(row, col); + return static_cast((static_cast(code) - offset) / scale); + } +}; + +template +auto make_sq_spec(MathT min_val, MathT max_val) -> sq_spec +{ + using T = std::remove_const_t; + T q_min = T(-128); + T q_max = T(127); + T range = static_cast(max_val) - static_cast(min_val); + T scale = (range > T(0)) ? ((q_max - q_min) / range) : T(1); + T offset = q_min - static_cast(min_val) * scale; + return {scale, offset}; +} + +/** + * @brief PQ reconstruction spec with a single global codebook (stateless). + * + * Codebook shape: [n_centers, pq_len]. Codes shape: [n_rows, pq_dim]. + * All parameters are derived from the handle views at access time. + */ +template +struct pq_spec { + using math_type = std::remove_const_t; + using codebook_view_type = CodebookView; + using codes_view_type = CodesView; + + RAFT_INLINE_FUNCTION + static uint32_t dim(compressed_data_handle const& h) + { + return static_cast(h.codes.extent(1)) * static_cast(h.codebook.extent(1)); + } + + RAFT_INLINE_FUNCTION + static math_type reconstruct(compressed_data_handle const& h, + size_t row, + size_t col) + { + uint32_t pq_len = static_cast(h.codebook.extent(1)); + uint32_t subspace = static_cast(col) / pq_len; + uint32_t component = static_cast(col) % pq_len; + uint32_t code = static_cast(h.codes(row, subspace)); + return static_cast(h.codebook(code, component)); + } +}; + +/** + * @brief PQ reconstruction spec with per-subspace codebooks (rank-3, stateless). + * + * Codebook shape: [pq_dim, pq_len, n_centers] (IVF-PQ convention). + * Codes shape: [n_rows, pq_dim]. + * All parameters are derived from the handle views at access time. + */ +template +struct pq_subspace_spec { + using math_type = std::remove_const_t; + using codebook_view_type = CodebookView; + using codes_view_type = CodesView; + + RAFT_INLINE_FUNCTION + static uint32_t dim(compressed_data_handle const& h) + { + return static_cast(h.codes.extent(1)) * static_cast(h.codebook.extent(1)); + } + + RAFT_INLINE_FUNCTION + static math_type reconstruct(compressed_data_handle const& h, + size_t row, + size_t col) + { + uint32_t pq_len = static_cast(h.codebook.extent(1)); + uint32_t subspace = static_cast(col) / pq_len; + uint32_t component = static_cast(col) % pq_len; + uint32_t code = static_cast(h.codes(row, subspace)); + return static_cast(h.codebook(subspace, component, code)); + } +}; + +// ============================================================================ +// compressed_accessor: the accessor_policy for compressed mdspan +// ============================================================================ + +template +struct compressed_accessor { + using element_type = MathT; + using codebook_view_type = typename ReconstructSpec::codebook_view_type; + using codes_view_type = typename ReconstructSpec::codes_view_type; + using data_handle_type = compressed_data_handle; + using reference = std::remove_const_t; + using offset_policy = compressed_accessor; + + ReconstructSpec spec; + + compressed_accessor() = default; + constexpr explicit compressed_accessor(ReconstructSpec s) : spec(s) {} + + RAFT_INLINE_FUNCTION + reference access(data_handle_type h, size_t off) const + { + size_t d = spec.dim(h); + size_t total = h.base_offset + off; + size_t row = total / d; + size_t col = total % d; + return spec.reconstruct(h, row, col); + } + + RAFT_INLINE_FUNCTION + data_handle_type offset(data_handle_type h, size_t i) const noexcept + { + return {h.codebook, h.codes, h.base_offset + i}; + } +}; + +// ============================================================================ +// compressed_container: owns codebook and codes as inner mdarrays +// ============================================================================ + +/** + * @brief Owning container that stores a codebook mdarray and a codes mdarray. + * + * The container is a simple pair — no reconstruction logic here. + * The container_policy's access() handles reconstruction via the spec. + */ +template +class compressed_container { + public: + using codebook_view_type = typename CodebookMdarray::const_view_type; + using codes_view_type = typename CodesMdarray::const_view_type; + + using value_type = typename CodebookMdarray::element_type; + using size_type = size_t; + using reference = std::remove_const_t; + using const_reference = reference; + using pointer = compressed_data_handle; + using const_pointer = pointer; + + private: + CodebookMdarray codebook_; + CodesMdarray codes_; + + public: + compressed_container(CodebookMdarray&& codebook, CodesMdarray&& codes) + : codebook_(std::move(codebook)), codes_(std::move(codes)) + { + } + + [[nodiscard]] auto data() const noexcept -> pointer + { + return {codebook_.view(), codes_.view(), 0}; + } + + [[nodiscard]] auto data() noexcept -> pointer + { + return const_cast(this)->data(); + } + + [[nodiscard]] auto codebook() const noexcept -> CodebookMdarray const& { return codebook_; } + [[nodiscard]] auto codes() const noexcept -> CodesMdarray const& { return codes_; } +}; + +// ============================================================================ +// compressed_container_policy: bridges container and accessor for mdarray +// ============================================================================ + +template +class compressed_container_policy { + public: + using element_type = MathT; + using spec_type = ReconstructSpec; + using container_type = compressed_container; + + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using reference = typename container_type::reference; + using const_reference = typename container_type::const_reference; + + using accessor_policy = compressed_accessor; + using const_accessor_policy = compressed_accessor; + + private: + spec_type spec_{}; + + public: + compressed_container_policy() = default; + + explicit compressed_container_policy(spec_type spec) : spec_(spec) {} + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + auto acc = accessor_policy{spec_}; + return acc.access(c.data(), n); + } + + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + auto acc = const_accessor_policy{spec_}; + return acc.access(c.data(), n); + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{spec_}; } + [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{spec_}; } +}; + +// ============================================================================ +// Type aliases: mdspan views +// ============================================================================ + +// ---- SQ ---- + +template +using sq_host_matrix_view = + host_mdspan, + layout_c_contiguous, + compressed_accessor, + host_matrix_view>>>; + +template +using sq_device_matrix_view = + device_mdspan, + layout_c_contiguous, + compressed_accessor, + device_matrix_view>>>; + +// ---- PQ global ---- + +template +using pq_host_matrix_view = + host_mdspan, + layout_c_contiguous, + compressed_accessor, + host_matrix_view>>>; + +template +using pq_device_matrix_view = + device_mdspan, + layout_c_contiguous, + compressed_accessor, + device_matrix_view>>>; + +// ---- PQ per-subspace (rank-3 codebook) ---- + +template +using host_pq_subspace_codebook_view = + host_mdspan>; + +template +using device_pq_subspace_codebook_view = + device_mdspan>; + +template +using pq_subspace_host_matrix_view = + host_mdspan, + layout_c_contiguous, + compressed_accessor, + host_matrix_view>>>; + +template +using pq_subspace_device_matrix_view = device_mdspan< + MathT const, + matrix_extent, + layout_c_contiguous, + compressed_accessor, + device_matrix_view>>>; + +// ============================================================================ +// Type aliases: inner codebook mdarrays (rank-3 for PQ-subspace) +// ============================================================================ + +template +using host_pq_subspace_codebook = + host_mdarray>; + +template +using device_pq_subspace_codebook = + device_mdarray>; + +// ============================================================================ +// Type aliases: owning mdarray +// ============================================================================ + +// ---- SQ ---- + +template +using sq_host_matrix = + host_mdarray, + layout_c_contiguous, + compressed_container_policy, + host_matrix_view>, + host_vector, + host_matrix>>; + +template +using sq_device_matrix = + device_mdarray, + layout_c_contiguous, + compressed_container_policy, + device_matrix_view>, + device_vector, + device_matrix>>; + +// ---- PQ global ---- + +template +using pq_host_matrix = + host_mdarray, + layout_c_contiguous, + compressed_container_policy, + host_matrix_view>, + host_matrix, + host_matrix>>; + +template +using pq_device_matrix = + device_mdarray, + layout_c_contiguous, + compressed_container_policy, + device_matrix_view>, + device_matrix, + device_matrix>>; + +// ---- PQ per-subspace ---- + +template +using pq_subspace_host_matrix = host_mdarray< + MathT, + matrix_extent, + layout_c_contiguous, + compressed_container_policy, + host_matrix_view>, + host_pq_subspace_codebook, + host_matrix>>; + +template +using pq_subspace_device_matrix = device_mdarray< + MathT, + matrix_extent, + layout_c_contiguous, + compressed_container_policy, + device_matrix_view>, + device_pq_subspace_codebook, + device_matrix>>; + +// ============================================================================ +// Factory functions: non-owning views +// ============================================================================ + +template +auto make_sq_host_matrix_view(host_vector_view codebook, + host_matrix_view codes, + MathT min_val, + MathT max_val) -> sq_host_matrix_view +{ + auto spec = make_sq_spec, + host_matrix_view>(min_val, max_val); + + using accessor_t = typename sq_host_matrix_view::accessor_type; + using handle_t = typename accessor_t::data_handle_type; + + handle_t handle{codebook, codes, 0}; + auto dim = static_cast(codes.extent(1)); + auto n_rows = static_cast(codes.extent(0)); + auto mapping = + typename sq_host_matrix_view::mapping_type{make_extents(n_rows, dim)}; + + return sq_host_matrix_view(handle, mapping, accessor_t{spec}); +} + +template +auto make_pq_host_matrix_view(host_matrix_view codebook, + host_matrix_view codes, + uint32_t dim) -> pq_host_matrix_view +{ + using accessor_t = typename pq_host_matrix_view::accessor_type; + using handle_t = typename accessor_t::data_handle_type; + + handle_t handle{codebook, codes, 0}; + auto n_rows = static_cast(codes.extent(0)); + auto mapping = + typename pq_host_matrix_view::mapping_type{make_extents(n_rows, IdxT(dim))}; + + return pq_host_matrix_view(handle, mapping, accessor_t{}); +} + +template +auto make_pq_subspace_host_matrix_view(host_pq_subspace_codebook_view codebook, + host_matrix_view codes, + uint32_t dim) -> pq_subspace_host_matrix_view +{ + using accessor_t = typename pq_subspace_host_matrix_view::accessor_type; + using handle_t = typename accessor_t::data_handle_type; + + handle_t handle{codebook, codes, 0}; + auto n_rows = static_cast(codes.extent(0)); + auto mapping = typename pq_subspace_host_matrix_view::mapping_type{ + make_extents(n_rows, IdxT(dim))}; + + return pq_subspace_host_matrix_view(handle, mapping, accessor_t{}); +} + +// ============================================================================ +// Factory functions: owning mdarray from existing component mdarrays +// ============================================================================ + +/** + * @brief Create a PQ host matrix from existing codebook and codes mdarrays. + * + * All parameters (dim, pq_len, n_centers) are derived from the component shapes. + */ +template +auto make_pq_host_matrix(host_matrix&& codebook, + host_matrix&& codes) -> pq_host_matrix +{ + uint32_t pq_len = static_cast(codebook.extent(1)); + uint32_t pq_dim = static_cast(codes.extent(1)); + uint32_t dim = pq_dim * pq_len; + auto n_rows = static_cast(codes.extent(0)); + + using container_t = typename pq_host_matrix::container_type; + using policy_t = typename pq_host_matrix::container_policy_type; + using mapping_t = typename pq_host_matrix::mapping_type; + + auto container = container_t(std::move(codebook), std::move(codes)); + auto mapping = mapping_t{make_extents(n_rows, IdxT(dim))}; + return pq_host_matrix(mapping, std::move(container), policy_t{}); +} + +} // namespace raft diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index fc1cd87b7a..b301257c5e 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -1,6 +1,6 @@ /* * SPDX-FileCopyrightText: Copyright (2019) Sandia Corporation - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause */ /* @@ -193,6 +193,17 @@ class mdarray { } + /** + * @brief Construct from a pre-built container, bypassing cp_.create(). + * + * Use this when the container has already been constructed externally + * (e.g. by compressed_mdarray or shared_mdarray factories). + */ + constexpr mdarray(mapping_type const& m, container_type&& c, container_policy_type const& cp) + : cp_(cp), map_(m), c_(std::move(c)) + { + } + /** * @brief Get an mdspan */ @@ -210,6 +221,11 @@ class mdarray [[nodiscard]] auto data_handle() noexcept -> pointer { return c_.data(); } [[nodiscard]] constexpr auto data_handle() const noexcept -> const_pointer { return c_.data(); } + /** + * @brief Extract the underlying container (only callable on rvalues). + */ + [[nodiscard]] constexpr auto release_container() && -> container_type { return std::move(c_); } + /** * @brief Indexing operator, use it sparingly since it triggers a device<->host copy. */ diff --git a/cpp/include/raft/core/shared_mdarray.hpp b/cpp/include/raft/core/shared_mdarray.hpp new file mode 100644 index 0000000000..f28e1fd05f --- /dev/null +++ b/cpp/include/raft/core/shared_mdarray.hpp @@ -0,0 +1,288 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { + +/** + * @brief A copyable container that wraps an inner (move-only) container in a shared_ptr. + * + * All type aliases (pointer, reference, etc.) are forwarded from the inner container, + * so the shared_container is a drop-in replacement that adds copy semantics via + * reference-counted shared ownership. + */ +template +class shared_container { + std::shared_ptr owner_; + + public: + using value_type = typename InnerContainer::value_type; + using size_type = typename InnerContainer::size_type; + using reference = typename InnerContainer::reference; + using const_reference = typename InnerContainer::const_reference; + using pointer = typename InnerContainer::pointer; + using const_pointer = typename InnerContainer::const_pointer; + using iterator = typename InnerContainer::iterator; + using const_iterator = typename InnerContainer::const_iterator; + + shared_container() = default; + + explicit shared_container(InnerContainer&& c) + : owner_(std::make_shared(std::move(c))) + { + } + + shared_container(shared_container const&) = default; + shared_container(shared_container&&) = default; + shared_container& operator=(shared_container const&) = default; + shared_container& operator=(shared_container&&) = default; + + [[nodiscard]] auto data() noexcept -> pointer { return owner_->data(); } + [[nodiscard]] auto data() const noexcept -> const_pointer { return owner_->data(); } + + template + auto operator[](Index i) noexcept -> reference + { + return (*owner_)[i]; + } + template + auto operator[](Index i) const noexcept -> const_reference + { + return (*owner_)[i]; + } +}; + +/** + * @brief A container policy that wraps any inner container policy, replacing its + * container_type with shared_container. + * + * All other type aliases (pointer, reference, accessor_policy, etc.) are forwarded + * from the inner policy, preserving type identity with the corresponding non-shared mdarray. + */ +template +class shared_container_policy { + InnerPolicy inner_; + + public: + using element_type = typename InnerPolicy::element_type; + using container_type = shared_container; + using pointer = typename InnerPolicy::pointer; + using const_pointer = typename InnerPolicy::const_pointer; + using reference = typename InnerPolicy::reference; + using const_reference = typename InnerPolicy::const_reference; + using accessor_policy = typename InnerPolicy::accessor_policy; + using const_accessor_policy = typename InnerPolicy::const_accessor_policy; + + shared_container_policy() = default; + explicit shared_container_policy(InnerPolicy inner) : inner_(std::move(inner)) {} + + auto create(raft::resources const& res, size_t n) -> container_type + { + return container_type(inner_.create(res, n)); + } + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + return c[n]; + } + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + return c[n]; + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return inner_.make_accessor_policy(); } + [[nodiscard]] auto make_accessor_policy() const noexcept { return inner_.make_accessor_policy(); } +}; + +/** + * @defgroup shared_mdarray_aliases Shared mdarray type aliases + * @{ + */ + +template > +using shared_device_mdarray = mdarray>>; + +template +using shared_device_scalar = shared_device_mdarray>; + +template +using shared_device_vector = + shared_device_mdarray, LayoutPolicy>; + +template +using shared_device_matrix = + shared_device_mdarray, LayoutPolicy>; + +template > +using shared_host_mdarray = mdarray>>; + +template +using shared_host_scalar = shared_host_mdarray>; + +template +using shared_host_vector = shared_host_mdarray, LayoutPolicy>; + +template +using shared_host_matrix = shared_host_mdarray, LayoutPolicy>; + +/** @} */ + +/** + * @defgroup shared_mdarray_factories Shared mdarray factory functions + * @{ + */ + +/** + * @brief Move a regular mdarray into a shared (reference-counted, copyable) mdarray. + * + * This is a zero-copy operation: the underlying storage is moved into a shared_ptr. + * The source mdarray is left in a moved-from state. + * + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + * @param src the source mdarray (consumed via move) + * @return a shared mdarray with the same data, shape, and layout + */ +template +auto make_shared_mdarray(mdarray&& src) +{ + using inner_policy_type = typename ContainerPolicy::accessor_type; + using shared_policy_type = shared_container_policy; + using shared_cp_type = host_device_accessor; + using shared_mdarray_type = mdarray; + + using shared_container_t = typename shared_mdarray_type::container_type; + + auto mapping = src.mapping(); + shared_container_t sc{std::move(src).release_container()}; + return shared_mdarray_type(mapping, std::move(sc), shared_cp_type{}); +} + +/** + * @brief Create a shared device mdarray. + * @tparam ElementType the data type of the elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::resources + * @param exts dimensionality of the array (series of integers) + * @return raft::shared_device_mdarray + */ +template +auto make_shared_device_mdarray(raft::resources const& handle, extents exts) +{ + using mdarray_t = shared_device_mdarray; + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{}; + return mdarray_t{handle, layout, policy}; +} + +/** + * @brief Create a shared host mdarray. + * @tparam ElementType the data type of the elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::resources + * @param exts dimensionality of the array (series of integers) + * @return raft::shared_host_mdarray + */ +template +auto make_shared_host_mdarray(raft::resources const& handle, extents exts) +{ + using mdarray_t = shared_host_mdarray; + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{}; + return mdarray_t{handle, layout, policy}; +} + +/** + * @brief Create a 2-dim shared device matrix. + */ +template +auto make_shared_device_matrix(raft::resources const& handle, IndexType n_rows, IndexType n_cols) +{ + return make_shared_device_mdarray( + handle, make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a 1-dim shared device vector. + */ +template +auto make_shared_device_vector(raft::resources const& handle, IndexType n) +{ + return make_shared_device_mdarray( + handle, make_extents(n)); +} + +/** + * @brief Create a 2-dim shared host matrix. + */ +template +auto make_shared_host_matrix(raft::resources const& handle, IndexType n_rows, IndexType n_cols) +{ + return make_shared_host_mdarray( + handle, make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a 1-dim shared host vector. + */ +template +auto make_shared_host_vector(raft::resources const& handle, IndexType n) +{ + return make_shared_host_mdarray(handle, + make_extents(n)); +} + +/** @} */ + +} // namespace raft diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 47ac6fc286..7eece762a7 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -102,6 +102,8 @@ if(BUILD_TESTS) core/interruptible.cu core/nvtx.cpp core/mdarray.cu + core/shared_mdarray.cu + core/compressed_mdarray.cu core/mdbuffer.cu core/mdspan_copy.cpp core/mdspan_copy.cu diff --git a/cpp/tests/core/compressed_mdarray.cu b/cpp/tests/core/compressed_mdarray.cu new file mode 100644 index 0000000000..08f34f31d2 --- /dev/null +++ b/cpp/tests/core/compressed_mdarray.cu @@ -0,0 +1,191 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include + +#include + +#include +#include + +namespace raft { + +TEST(CompressedMDArray, PQGlobalReconstruct) +{ + std::vector codebook_data = { + 0.0f, + 0.0f, + 1.0f, + 0.0f, + 0.0f, + 1.0f, + 1.0f, + 1.0f, + }; + + std::vector codes_data = { + 0, + 1, + 2, + 3, + 0, + 1, + }; + + auto codebook_view = make_host_matrix_view(codebook_data.data(), 4, 2); + auto codes_view = make_host_matrix_view(codes_data.data(), 2, 3); + + auto view = make_pq_host_matrix_view(codebook_view, codes_view, 6); + + ASSERT_EQ(view.extent(0), 2); + ASSERT_EQ(view.extent(1), 6); + + EXPECT_FLOAT_EQ(view(0, 0), 0.0f); + EXPECT_FLOAT_EQ(view(0, 1), 0.0f); + EXPECT_FLOAT_EQ(view(0, 2), 1.0f); + EXPECT_FLOAT_EQ(view(0, 3), 0.0f); + EXPECT_FLOAT_EQ(view(0, 4), 0.0f); + EXPECT_FLOAT_EQ(view(0, 5), 1.0f); + + EXPECT_FLOAT_EQ(view(1, 0), 1.0f); + EXPECT_FLOAT_EQ(view(1, 1), 1.0f); + EXPECT_FLOAT_EQ(view(1, 2), 0.0f); + EXPECT_FLOAT_EQ(view(1, 3), 0.0f); + EXPECT_FLOAT_EQ(view(1, 4), 1.0f); + EXPECT_FLOAT_EQ(view(1, 5), 0.0f); +} + +TEST(CompressedMDArray, PQSubspaceReconstruct) +{ + // codebook(subspace, component, center) with shape [2, 2, 3] + std::vector codebook_data = { + 10.0f, + 20.0f, + 30.0f, + 11.0f, + 21.0f, + 31.0f, + 100.0f, + 200.0f, + 300.0f, + 101.0f, + 201.0f, + 301.0f, + }; + + std::vector codes_data = { + 0, + 2, + 1, + 0, + }; + + using cb_view_t = host_pq_subspace_codebook_view; + auto codebook_view = + cb_view_t(codebook_data.data(), make_extents(uint32_t(2), uint32_t(2), uint32_t(3))); + auto codes_view = make_host_matrix_view(codes_data.data(), 2, 2); + + auto view = make_pq_subspace_host_matrix_view(codebook_view, codes_view, 4); + + ASSERT_EQ(view.extent(0), 2); + ASSERT_EQ(view.extent(1), 4); + + EXPECT_FLOAT_EQ(view(0, 0), 10.0f); + EXPECT_FLOAT_EQ(view(0, 1), 11.0f); + EXPECT_FLOAT_EQ(view(0, 2), 300.0f); + EXPECT_FLOAT_EQ(view(0, 3), 301.0f); + + EXPECT_FLOAT_EQ(view(1, 0), 20.0f); + EXPECT_FLOAT_EQ(view(1, 1), 21.0f); + EXPECT_FLOAT_EQ(view(1, 2), 100.0f); + EXPECT_FLOAT_EQ(view(1, 3), 101.0f); +} + +TEST(CompressedMDArray, SQReconstruct) +{ + float min_val = -1.0f; + float max_val = 1.0f; + + std::vector codes_data = { + 0, + 127, + -128, + 64, + -64, + 0, + }; + + std::vector codebook_data = {min_val, max_val}; + + auto codebook_view = + make_host_vector_view(codebook_data.data(), uint32_t(2)); + auto codes_view = make_host_matrix_view(codes_data.data(), 2, 3); + + auto view = + make_sq_host_matrix_view(codebook_view, codes_view, min_val, max_val); + + ASSERT_EQ(view.extent(0), 2); + ASSERT_EQ(view.extent(1), 3); + + float scale = 127.5f; + float offset = -0.5f; + + EXPECT_NEAR(view(0, 0), (0.0f - offset) / scale, 1e-5f); + EXPECT_NEAR(view(0, 1), (127.0f - offset) / scale, 1e-5f); + EXPECT_NEAR(view(0, 2), (-128.0f - offset) / scale, 1e-5f); + + EXPECT_NEAR(view(1, 0), (64.0f - offset) / scale, 1e-5f); + EXPECT_NEAR(view(1, 1), (-64.0f - offset) / scale, 1e-5f); + EXPECT_NEAR(view(1, 2), (0.0f - offset) / scale, 1e-5f); +} + +TEST(CompressedMDArray, PQHostMdarrayFactory) +{ + raft::resources handle; + + // Create codebook and codes as regular host mdarrays + auto codebook = raft::make_host_matrix(handle, std::uint32_t(4), std::uint32_t(2)); + float cb_values[] = {0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f}; + std::copy(cb_values, cb_values + 8, codebook.data_handle()); + + auto codes = + raft::make_host_matrix(handle, std::uint32_t(2), std::uint32_t(3)); + uint8_t code_values[] = {0, 1, 2, 3, 0, 1}; + std::copy(code_values, code_values + 6, codes.data_handle()); + + // Wrap into a compressed mdarray — dim is derived from codebook/codes shapes + auto pq_mat = make_pq_host_matrix(std::move(codebook), std::move(codes)); + + ASSERT_EQ(pq_mat.extent(0), 2); + ASSERT_EQ(pq_mat.extent(1), 6); + + EXPECT_FLOAT_EQ(pq_mat(0, 0), 0.0f); + EXPECT_FLOAT_EQ(pq_mat(0, 2), 1.0f); + EXPECT_FLOAT_EQ(pq_mat(1, 0), 1.0f); + EXPECT_FLOAT_EQ(pq_mat(1, 1), 1.0f); + + auto v = pq_mat.view(); + EXPECT_FLOAT_EQ(v(0, 4), 0.0f); + EXPECT_FLOAT_EQ(v(0, 5), 1.0f); +} + +TEST(CompressedMDArray, ViewTypeIsConst) +{ + using pq_view_t = pq_host_matrix_view; + static_assert(std::is_const_v, + "pq_host_matrix_view element_type must be const"); + + using sq_view_t = sq_host_matrix_view; + static_assert(std::is_const_v, + "sq_host_matrix_view element_type must be const"); + + using pqs_view_t = pq_subspace_host_matrix_view; + static_assert(std::is_const_v, + "pq_subspace_host_matrix_view element_type must be const"); +} + +} // namespace raft diff --git a/cpp/tests/core/shared_mdarray.cu b/cpp/tests/core/shared_mdarray.cu new file mode 100644 index 0000000000..c41f2948d3 --- /dev/null +++ b/cpp/tests/core/shared_mdarray.cu @@ -0,0 +1,107 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include + +#include + +namespace raft { + +TEST(SharedMDArray, MakeSharedFromDeviceMdarray) +{ + raft::resources handle; + + auto src = raft::make_device_matrix(handle, 4, 4); + auto src_ptr = src.data_handle(); + + auto shared = raft::make_shared_mdarray(std::move(src)); + + ASSERT_EQ(shared.data_handle(), src_ptr); + ASSERT_EQ(shared.extent(0), 4); + ASSERT_EQ(shared.extent(1), 4); + ASSERT_EQ(shared.size(), 16u); +} + +TEST(SharedMDArray, CopySharesOwnership) +{ + raft::resources handle; + + auto shared1 = raft::make_shared_mdarray(raft::make_device_matrix(handle, 3, 5)); + + auto shared2 = shared1; + + ASSERT_EQ(shared1.data_handle(), shared2.data_handle()); + + raft::device_matrix_view v1 = shared1.view(); + raft::device_matrix_view v2 = shared2.view(); + ASSERT_EQ(v1.data_handle(), v2.data_handle()); + ASSERT_EQ(v1.extent(0), v2.extent(0)); + ASSERT_EQ(v1.extent(1), v2.extent(1)); +} + +TEST(SharedMDArray, SharedOutlivesOriginal) +{ + raft::resources handle; + float* ptr = nullptr; + + auto shared1 = raft::make_shared_mdarray( + raft::make_device_vector(handle, std::uint32_t{128})); + ptr = shared1.data_handle(); + + decltype(shared1) survivor(handle); + survivor = shared1; + shared1 = decltype(shared1)(handle); + + ASSERT_EQ(survivor.data_handle(), ptr); + ASSERT_EQ(survivor.extent(0), 128); +} + +TEST(SharedMDArray, AllocateDirectly) +{ + raft::resources handle; + + auto shared = raft::make_shared_device_matrix(handle, 10, 20); + + ASSERT_EQ(shared.extent(0), 10); + ASSERT_EQ(shared.extent(1), 20); + ASSERT_NE(shared.data_handle(), nullptr); + + auto copy = shared; + ASSERT_EQ(copy.data_handle(), shared.data_handle()); +} + +TEST(SharedMDArray, HostSharedMdarray) +{ + raft::resources handle; + + auto src = raft::make_host_vector(handle, 10); + for (int i = 0; i < 10; i++) { + src(i) = static_cast(i); + } + + auto shared = raft::make_shared_mdarray(std::move(src)); + auto copy = shared; + + ASSERT_EQ(copy.data_handle(), shared.data_handle()); + for (int i = 0; i < 10; i++) { + ASSERT_EQ(shared(i), static_cast(i)); + } +} + +TEST(SharedMDArray, ViewTypeIdentity) +{ + static_assert(std::is_same_v::view_type, + raft::device_matrix::view_type>, + "shared and regular device_matrix must produce the same view_type"); + + static_assert(std::is_same_v::const_view_type, + raft::device_matrix::const_view_type>, + "shared and regular device_matrix must produce the same const_view_type"); +} + +} // namespace raft