Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 9 additions & 26 deletions cpp/include/rapidsmpf/memory/host_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/host_memory_resource.hpp>
#include <rapidsmpf/memory/pinned_memory_resource.hpp>

namespace rapidsmpf {

Expand All @@ -29,17 +28,6 @@ namespace rapidsmpf {
*/
class HostBuffer {
public:
/**
* @brief Type-erased deleter for owned storage.
*
* This deleter holds a callable that releases the underlying storage when invoked.
* It enables `HostBuffer` to take ownership of different storage types
* (e.g., `rmm::device_buffer`, `std::vector<std::uint8_t>`) without exposing their
* types. The deleter captures the owned object and destroys it when the deleter
* itself is destroyed (the `void*` parameter is ignored).
*/
using OwnedStorageDeleter = std::function<void(void*)>;

/**
* @brief Allocate a new host buffer.
*
Expand Down Expand Up @@ -176,14 +164,11 @@ class HostBuffer {
*
* @param data Vector to take ownership of (will be moved).
* @param stream CUDA stream to associate with this buffer.
* @param mr Host memory resource used to allocate the buffer.
*
* @return A new `HostBuffer` owning the vector's memory.
*/
static HostBuffer from_owned_vector(
std::vector<std::uint8_t>&& data,
rmm::cuda_stream_view stream,
rmm::host_async_resource_ref mr
std::vector<std::uint8_t>&& data, rmm::cuda_stream_view stream
);

/**
Expand All @@ -199,7 +184,6 @@ class HostBuffer {
*
* @param pinned_host_buffer Device buffer to take ownership of.
* @param stream CUDA stream to associate with this buffer.
* @param mr Pinned host memory resource used to allocate the buffer.
*
* @return A new `HostBuffer` owning the device buffer's memory.
*
Expand All @@ -212,8 +196,7 @@ class HostBuffer {
*/
static HostBuffer from_rmm_device_buffer(
std::unique_ptr<rmm::device_buffer> pinned_host_buffer,
rmm::cuda_stream_view stream,
PinnedMemoryResource& mr
rmm::cuda_stream_view stream
);

private:
Expand All @@ -222,21 +205,21 @@ class HostBuffer {
*
* @param span View of the owned memory.
* @param stream CUDA stream associated with this buffer.
* @param mr Dummy memory resource (not used for deallocation).
* @param owned_storage Unique pointer managing the owned storage lifetime.
* @param deallocate_fn Callable invoked with the current stream to release the
* underlying memory. It captures all resources needed for deallocation (e.g.
* memory resource, raw pointer, size).
*/
HostBuffer(
std::span<std::byte> span,
rmm::cuda_stream_view stream,
rmm::host_async_resource_ref mr,
std::unique_ptr<void, OwnedStorageDeleter> owned_storage
std::function<void(rmm::cuda_stream_view)> deallocate_fn
);

rmm::cuda_stream_view stream_;
rmm::host_async_resource_ref mr_;
std::span<std::byte> span_{};
/// @brief Optional owned storage that will be released when the buffer is destroyed.
std::unique_ptr<void, OwnedStorageDeleter> owned_storage_{nullptr, [](void*) {}};
/// @brief Callable that releases the underlying memory when invoked with the current
/// stream. Null when the buffer is empty.
std::function<void(rmm::cuda_stream_view)> deallocate_fn_{};
};

} // namespace rapidsmpf
71 changes: 32 additions & 39 deletions cpp/src/memory/host_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,39 @@ namespace rapidsmpf {
HostBuffer::HostBuffer(
std::size_t size, rmm::cuda_stream_view stream, rmm::host_async_resource_ref mr
)
: stream_{stream}, mr_{std::move(mr)} {
: stream_{stream} {
if (size > 0) {
auto owned_mr = cuda::mr::any_resource<cuda::mr::host_accessible>{mr};
Copy link
Copy Markdown
Member

@madsbk madsbk May 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t think owned_mr actually owns the memory resource. AFAIK, cuda::mr::any_resource<P> only owns the rmm::host_async_resource_ref mr, which is itself a non-owning reference. So owned_mr ends up owning a copy of a non-owning reference, not the underlying memory resource?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a reasonable thought but incorrect. any_resource(resource_ref) converts the non-owning reference into an owning type by taking a copy of the underlying resource that the reference points to.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wence-, good to know! @nirandaperera could you add a comment noting this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, any_resource takes in a copy of the underlying resource in the ref. The idea is to increment the refcount if that resource is shared.

auto* ptr = static_cast<std::byte*>(
mr_.allocate(stream_, size, alignof(::cuda::std::max_align_t))
owned_mr.allocate(stream_, size, alignof(::cuda::std::max_align_t))
);
span_ = std::span<std::byte>{ptr, size};
deallocate_fn_ =
[owned_mr = std::move(owned_mr), ptr, size](rmm::cuda_stream_view s) mutable {
owned_mr.deallocate(s, ptr, size, alignof(::cuda::std::max_align_t));
};
}
}

HostBuffer::HostBuffer(
std::span<std::byte> span,
rmm::cuda_stream_view stream,
rmm::host_async_resource_ref mr,
std::unique_ptr<void, OwnedStorageDeleter> owned_storage
std::function<void(rmm::cuda_stream_view)> deallocate_fn
)
: stream_{stream},
mr_{std::move(mr)},
span_{span},
owned_storage_{std::move(owned_storage)} {}
: stream_{stream}, span_{span}, deallocate_fn_{std::move(deallocate_fn)} {}

void HostBuffer::deallocate_async() noexcept {
if (!span_.empty()) {
// If we have owned storage, release it; otherwise deallocate via mr_.
if (owned_storage_) {
owned_storage_.reset();
} else {
mr_.deallocate(
stream_, span_.data(), span_.size(), alignof(::cuda::std::max_align_t)
);
}
if (!span_.empty() && deallocate_fn_) {
deallocate_fn_(stream_);
deallocate_fn_ = nullptr;
}
span_ = {};
}

HostBuffer::HostBuffer(HostBuffer&& other) noexcept
: stream_{other.stream_},
mr_{other.mr_},
span_{std::exchange(other.span_, {})},
owned_storage_{std::move(other.owned_storage_)} {}
deallocate_fn_{std::exchange(other.deallocate_fn_, {})} {}

HostBuffer& HostBuffer::operator=(HostBuffer&& other) {
if (this != &other) {
Expand All @@ -64,9 +58,8 @@ HostBuffer& HostBuffer::operator=(HostBuffer&& other) {
std::invalid_argument
);
stream_ = other.stream_;
mr_ = other.mr_;
span_ = std::exchange(other.span_, {});
owned_storage_ = std::move(other.owned_storage_);
deallocate_fn_ = std::exchange(other.deallocate_fn_, {});
}
return *this;
}
Expand Down Expand Up @@ -125,27 +118,24 @@ HostBuffer HostBuffer::from_uint8_vector(
}

HostBuffer HostBuffer::from_owned_vector(
std::vector<std::uint8_t>&& data,
rmm::cuda_stream_view stream,
rmm::host_async_resource_ref mr
std::vector<std::uint8_t>&& data, rmm::cuda_stream_view stream
) {
// Wrap in shared_ptr so the lambda is copyable (required by std::function).
auto shared_vec = std::make_shared<std::vector<std::uint8_t>>(std::move(data));
auto* ptr = reinterpret_cast<std::byte*>(shared_vec->data());
auto size = shared_vec->size();
std::span<std::byte> span{ptr, size};
std::span<std::byte> span{ptr, shared_vec->size()};

std::unique_ptr<void, OwnedStorageDeleter> owned_storage{
ptr, [shared_vec_ = std::move(shared_vec)](void*) mutable { shared_vec_.reset(); }
return HostBuffer{
span,
stream,
[shared_vec_ = std::move(shared_vec)](rmm::cuda_stream_view) mutable {
shared_vec_.reset();
}
};

return HostBuffer{span, stream, std::move(mr), std::move(owned_storage)};
}

HostBuffer HostBuffer::from_rmm_device_buffer(
std::unique_ptr<rmm::device_buffer> pinned_host_buffer,
rmm::cuda_stream_view stream,
PinnedMemoryResource& mr
std::unique_ptr<rmm::device_buffer> pinned_host_buffer, rmm::cuda_stream_view stream
) {
RAPIDSMPF_EXPECTS(
pinned_host_buffer != nullptr,
Expand All @@ -161,14 +151,17 @@ HostBuffer HostBuffer::from_rmm_device_buffer(

// Wrap in shared_ptr so the lambda is copyable (required by std::function).
auto shared_db = std::make_shared<rmm::device_buffer>(std::move(*pinned_host_buffer));
auto* ptr = static_cast<std::byte*>(shared_db->data());
std::span<std::byte> span{ptr, shared_db->size()};

std::unique_ptr<void, OwnedStorageDeleter> owned_storage{
ptr, [shared_db_ = std::move(shared_db)](void*) mutable { shared_db_.reset(); }
std::span<std::byte> span{
static_cast<std::byte*>(shared_db->data()), shared_db->size()
};

return HostBuffer{std::move(span), stream, mr, std::move(owned_storage)};
return HostBuffer{
std::move(span),
stream,
[shared_db_ = std::move(shared_db)](rmm::cuda_stream_view) mutable {
shared_db_.reset();
}
};
}

} // namespace rapidsmpf
6 changes: 3 additions & 3 deletions cpp/tests/test_host_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ TEST_P(HostMemoryResource, from_owned_vector) {

// Create a host buffer by taking ownership of a vector
auto buffer = rapidsmpf::HostBuffer::from_owned_vector(
std::vector<std::uint8_t>(source_data), stream, mr
std::vector<std::uint8_t>(source_data), stream
);

EXPECT_NO_THROW(test_buffer(std::move(buffer), source_data));
Expand Down Expand Up @@ -178,7 +178,7 @@ TEST_P(PinnedResource, from_owned_vector) {

// Create a host buffer by taking ownership of a vector
auto buffer = rapidsmpf::HostBuffer::from_owned_vector(
std::vector<std::uint8_t>(source_data), stream, *mr
std::vector<std::uint8_t>(source_data), stream
);

EXPECT_NO_THROW(test_buffer(std::move(buffer), source_data));
Expand All @@ -197,7 +197,7 @@ TEST_P(PinnedResource, from_rmm_device_buffer) {

// Create a host buffer by taking ownership of an rmm::device_buffer
auto buffer = rapidsmpf::HostBuffer::from_rmm_device_buffer(
std::move(pinned_host_buffer), stream, *mr
std::move(pinned_host_buffer), stream
);

EXPECT_NO_THROW(test_buffer(std::move(buffer), source_data));
Expand Down
Loading