Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions cpp/include/raft/core/numpy_serializer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <raft/core/detail/mdspan_numpy_serializer.hpp>

namespace raft::numpy_serializer {

/**
* @defgroup numpy_serializer NumPy serialization helpers
* @{
*/

/**
* @brief Integer type used for NumPy array shape extents.
*/
using ndarray_len_t = detail::numpy_serializer::ndarray_len_t;

/**
* @brief NumPy dtype descriptor.
*/
using dtype_t = detail::numpy_serializer::dtype_t;

/**
* @brief Parsed NumPy header metadata.
*/
using header_t = detail::numpy_serializer::header_t;

/**
* @brief Return the NumPy dtype descriptor corresponding to a C++ element type.
*
* @tparam T C++ element type.
* @return NumPy dtype descriptor for T.
*/
template <typename T>
inline dtype_t get_numpy_dtype()
{
return detail::numpy_serializer::get_numpy_dtype<T>();
}

/**
* @brief Parse a NumPy dtype descriptor string.
*
* @param typestr NumPy dtype descriptor string.
* @return Parsed NumPy dtype descriptor.
*/
inline dtype_t parse_descr(std::string typestr)
{
return detail::numpy_serializer::parse_descr(typestr);
}

/**
* @brief Write a NumPy `.npy` header to an output stream.
*
* @param os Output stream.
* @param header Header metadata to write.
*/
inline void write_header(std::ostream& os, const header_t& header)
{
detail::numpy_serializer::write_header(os, header);
}

/**
* @brief Read and parse a NumPy `.npy` header from an input stream.
*
* The stream is left positioned immediately after the header.
*
* @param is Input stream.
* @return Parsed NumPy header metadata.
*/
inline header_t read_header(std::istream& is) { return detail::numpy_serializer::read_header(is); }

/** @} */

} // namespace raft::numpy_serializer
20 changes: 19 additions & 1 deletion cpp/tests/core/numpy_serializer.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <raft/core/host_mdarray.hpp>
#include <raft/core/managed_mdspan.hpp>
#include <raft/core/numpy_serializer.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

Expand Down Expand Up @@ -98,6 +99,23 @@ TEST(NumPySerializerMDSpan, HeaderRoundTrip)
}
}

TEST(NumPySerializerMDSpan, PublicHeaderRoundTrip)
{
std::ostringstream oss;
numpy_serializer::header_t header{
numpy_serializer::get_numpy_dtype<std::uint32_t>(), false, {4, 8}};

numpy_serializer::write_header(oss, header);

std::istringstream iss(oss.str());
auto header2 = numpy_serializer::read_header(iss);

EXPECT_EQ(header, header2);
EXPECT_EQ(numpy_serializer::parse_descr(header2.dtype.to_string()), header2.dtype);
EXPECT_EQ(header2.dtype.to_string(),
numpy_serializer::get_numpy_dtype<std::uint32_t>().to_string());
}

TEST(NumPySerializerMDSpan, ManagedMDSpan)
{
raft::resources handle{};
Expand Down
1 change: 1 addition & 0 deletions docs/source/cpp_api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ expose in public APIs.
core_interruptible.rst
core_operators.rst
core_math.rst
core_numpy_serializer.rst
core_bitset.rst
core_bitmap.rst
15 changes: 15 additions & 0 deletions docs/source/cpp_api/core_numpy_serializer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
NumPy Serialization Helpers
===========================

.. role:: py(code)
:language: c++
:class: highlight

``#include <raft/core/numpy_serializer.hpp>``

namespace *raft::numpy_serializer*

.. doxygengroup:: numpy_serializer
:project: RAFT
:members:
:content-only:
Loading