diff --git a/cpp/include/raft/core/numpy_serializer.hpp b/cpp/include/raft/core/numpy_serializer.hpp new file mode 100644 index 0000000000..eb9632fecf --- /dev/null +++ b/cpp/include/raft/core/numpy_serializer.hpp @@ -0,0 +1,78 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace raft::numpy_serializer { + +/** + * @defgroup numpy_serializer NumPy serialization helpers + * @{ + */ + +/** + * @brief Integer type used for NumPy array shape extents (alias of the detail type). + */ +using ndarray_len_t = detail::numpy_serializer::ndarray_len_t; + +/** + * @brief NumPy dtype descriptor (alias of detail::numpy_serializer::dtype_t). + */ +using dtype_t = detail::numpy_serializer::dtype_t; + +/** + * @brief Parsed NumPy header metadata (alias of detail::numpy_serializer::header_t). + */ +using header_t = detail::numpy_serializer::header_t; + +/** + * @brief Return the NumPy dtype descriptor corresponding to a C++ element type. + * + * @tparam T C++ element type. + * @return NumPy dtype descriptor for T, as produced by the detail implementation. + */ +template +inline dtype_t get_numpy_dtype() +{ + return detail::numpy_serializer::get_numpy_dtype(); +} + +/** + * @brief Parse a NumPy dtype descriptor string. + * + * @param typestr NumPy dtype descriptor string (taken by value, passed on as an lvalue). + * @return Parsed NumPy dtype descriptor, as returned by the detail implementation. + */ +inline dtype_t parse_descr(std::string typestr) +{ + return detail::numpy_serializer::parse_descr(typestr); +} + +/** + * @brief Write a NumPy `.npy` header to an output stream. + * + * @param os Output stream the header is written to. + * @param header Header metadata to write (forwarded to the detail implementation). + */ +inline void write_header(std::ostream& os, const header_t& header) +{ + detail::numpy_serializer::write_header(os, header); +} + +/** + * @brief Read and parse a NumPy `.npy` header from an input stream. + * + * The stream is left positioned immediately after the header. + * + * @param is Input stream to read the header from. + * @return Parsed NumPy header metadata, as returned by detail::numpy_serializer::read_header. 
+ */ +inline header_t read_header(std::istream& is) { return detail::numpy_serializer::read_header(is); } + +/** @} */ + +} // namespace raft::numpy_serializer diff --git a/cpp/tests/core/numpy_serializer.cu b/cpp/tests/core/numpy_serializer.cu index cc327aa886..76c478ed9b 100644 --- a/cpp/tests/core/numpy_serializer.cu +++ b/cpp/tests/core/numpy_serializer.cu @@ -1,10 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include #include +#include #include #include @@ -98,6 +99,23 @@ TEST(NumPySerializerMDSpan, HeaderRoundTrip) } } +TEST(NumPySerializerMDSpan, PublicHeaderRoundTrip)  // write_header -> read_header round trip via the numpy_serializer:: wrappers +{ + std::ostringstream oss; + numpy_serializer::header_t header{ + numpy_serializer::get_numpy_dtype(), false, {4, 8}}; + + numpy_serializer::write_header(oss, header); + + std::istringstream iss(oss.str()); + auto header2 = numpy_serializer::read_header(iss); + + EXPECT_EQ(header, header2); + EXPECT_EQ(numpy_serializer::parse_descr(header2.dtype.to_string()), header2.dtype); + EXPECT_EQ(header2.dtype.to_string(), + numpy_serializer::get_numpy_dtype().to_string()); +} + TEST(NumPySerializerMDSpan, ManagedMDSpan) { raft::resources handle{}; diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index f159c85af8..d365abb866 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -21,5 +21,6 @@ expose in public APIs. core_interruptible.rst core_operators.rst core_math.rst + core_numpy_serializer.rst core_bitset.rst core_bitmap.rst diff --git a/docs/source/cpp_api/core_numpy_serializer.rst b/docs/source/cpp_api/core_numpy_serializer.rst new file mode 100644 index 0000000000..13715ef122 --- /dev/null +++ b/docs/source/cpp_api/core_numpy_serializer.rst @@ -0,0 +1,15 @@ +NumPy Serialization Helpers +=========================== + +.. 
role:: py(code) + :language: c++ + :class: highlight + +``#include <raft/core/numpy_serializer.hpp>`` + +namespace *raft::numpy_serializer* + + .. doxygengroup:: numpy_serializer + :project: RAFT + :members: + :content-only: