diff --git a/cpp/include/raft/mr/host_memory_resource.hpp b/cpp/include/raft/mr/host_memory_resource.hpp index 03baee661c..0c0c3e66ce 100644 --- a/cpp/include/raft/mr/host_memory_resource.hpp +++ b/cpp/include/raft/mr/host_memory_resource.hpp @@ -31,18 +31,18 @@ namespace detail { struct default_host_resource_holder { private: std::mutex lock_; - raft::mr::host_resource_ref ref_{raft::mr::new_delete_resource()}; + raft::mr::host_resource res_{raft::mr::new_delete_resource()}; public: - inline auto set(raft::mr::host_resource_ref ref) -> raft::mr::host_resource_ref + inline auto set(raft::mr::host_resource res) -> raft::mr::host_resource { std::unique_lock guard(lock_); - return std::exchange(ref_, ref); + return std::exchange(res_, res); } inline auto get() -> raft::mr::host_resource_ref { std::unique_lock guard(lock_); - return ref_; + return raft::mr::host_resource_ref{res_}; } }; @@ -64,16 +64,14 @@ inline auto get_default_host_resource() -> raft::mr::host_resource_ref /** * @brief Set the default host memory resource. * - * The caller must keep the underlying resource alive while it is set as the default * (same contract as rmm::mr::set_current_device_resource). * - * @param ref Non-owning reference to the resource to install. - * @return The previous default host resource ref. + * @param res The resource to install. + * @return The previous default host resource. */ -inline auto set_default_host_resource(raft::mr::host_resource_ref ref) - -> raft::mr::host_resource_ref +inline auto set_default_host_resource(raft::mr::host_resource res) -> raft::mr::host_resource { - return detail::default_host_resource_holder_.set(ref); + return detail::default_host_resource_holder_.set(res); } } // namespace mr diff --git a/cpp/include/raft/mr/statistics_adaptor.hpp b/cpp/include/raft/mr/statistics_adaptor.hpp index 7993ff3c0c..1553f730a5 100644 --- a/cpp/include/raft/mr/statistics_adaptor.hpp +++ b/cpp/include/raft/mr/statistics_adaptor.hpp @@ -75,6 +75,32 @@ class statistics_adaptor : public cuda::forward_property +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Snapshot of memory usage across the six tracked resource types. + * + * Returned by accessor methods on dry_run_resources and + * memory_stats_resources (e.g. get_bytes_peak(), get_bytes_current()). + */ +struct memory_stats { + std::size_t device_workspace{0}; + std::size_t device_large_workspace{0}; + std::size_t device_global{0}; + std::size_t device_managed{0}; + std::size_t host{0}; + std::size_t host_pinned{0}; + + /** + * @brief Sum of all memory stats across the six tracked categories. + * + * The three resource wrapper classes (dry_run_resources, memory_stats_resources, + * memory_tracking_resources) guarantee that every category is tracked by its own + * independent adaptor: each wrapper force-initializes all resources, captures their + * upstream refs *before* replacing the global device resource, and wraps those + * originals. Workspace and large-workspace allocations therefore bypass the + * device-global tracking adaptor and are counted exactly once, making this sum + * an accurate total when used with stats produced by any of the three wrappers. + */ + [[nodiscard]] inline constexpr auto total() const -> std::size_t + { + return device_workspace + device_large_workspace + device_global + device_managed + host + + host_pinned; + } +}; + +/** + * @brief Resources handle that wraps all reachable memory resources with + * statistics adaptors to track actual allocation usage. + * + * Inherits from raft::resources, so it can be passed anywhere a + * raft::resources& is expected. On construction the handle: + * - Materializes all tracked resource types (host, device, pinned, + * managed, workspace, large_workspace). + * - Takes a snapshot of the original resources to keep them alive. + * - Wraps each with statistics_adaptor. + * - Replaces global host and device resources with tracked versions. + * + * On destruction the handle restores global resources. + */ +class memory_stats_resources : public resources { + public: + explicit memory_stats_resources(const resources& existing) + : resources(existing), + old_host_(mr::get_default_host_resource()), + old_device_(rmm::mr::get_current_device_resource_ref()) + { + init(); + } + + ~memory_stats_resources() override + { + mr::set_default_host_resource(old_host_); + rmm::mr::set_current_device_resource(std::move(old_device_)); + } + + memory_stats_resources(memory_stats_resources const&) = delete; + memory_stats_resources& operator=(memory_stats_resources const&) = delete; + memory_stats_resources(memory_stats_resources&&) = delete; + memory_stats_resources& operator=(memory_stats_resources&&) = delete; + + [[nodiscard]] auto get_bytes_current() const -> memory_stats + { + return read_field(&mr::resource_stats::bytes_current); + } + + [[nodiscard]] auto get_bytes_peak() const -> memory_stats + { + return read_field(&mr::resource_stats::bytes_peak); + } + + [[nodiscard]] auto get_bytes_total_allocated() const -> memory_stats + { + return read_field(&mr::resource_stats::bytes_total_allocated); + } + + [[nodiscard]] auto get_bytes_total_deallocated() const -> memory_stats + { + return read_field(&mr::resource_stats::bytes_total_deallocated); + } + + [[nodiscard]] auto get_num_allocations() const -> memory_stats + { + return read_field(&mr::resource_stats::num_allocations); + } + + [[nodiscard]] auto get_num_deallocations() const -> memory_stats + { + return read_field(&mr::resource_stats::num_deallocations); + } + + private: + using field_ptr = std::atomic mr::resource_stats::*; + + [[nodiscard]] auto read_field(field_ptr field) const -> memory_stats + { + auto load = [&](const std::shared_ptr& s) -> std::size_t { + return static_cast((s.get()->*field).load(std::memory_order_relaxed)); + }; + return { + .device_workspace = load(ws_stats_), + .device_large_workspace = load(lws_stats_), + .device_global = load(device_stats_), + .device_managed = load(managed_stats_), + .host = load(host_stats_), + .host_pinned = load(pinned_stats_), + }; + } + + std::vector snapshot_; + + raft::mr::host_resource old_host_; + raft::mr::device_resource old_device_; + + using host_stats_adaptor_t = mr::statistics_adaptor; + std::unique_ptr host_adaptor_; + + using device_stats_adaptor_t = mr::statistics_adaptor; + std::unique_ptr device_adaptor_; + + std::shared_ptr host_stats_; + std::shared_ptr pinned_stats_; + std::shared_ptr managed_stats_; + std::shared_ptr ws_stats_; + std::shared_ptr lws_stats_; + std::shared_ptr device_stats_; + + void init() + { + // Independent-counting invariant + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // 1. Force-initialize all lazily-created resources (workspace, large workspace, + // pinned, managed) so that their factories resolve against the *original* + // global device MR, not a tracking wrapper we install later. + // 2. Capture every upstream ref while it still points to the original resource. + // 3. Snapshot the resource map to keep the originals alive. + // 4. Only *then* replace the global device resource with the tracking bridge. + // 5. Wrap each captured upstream with a separate statistics_adaptor. + // + // Because step 2 happens before step 4, workspace/lws allocations flow through + // their own adaptor directly to the original device MR, bypassing the device adaptor. + // Each allocation is therefore counted in exactly one category, and + // memory_stats::total() returns an accurate, non-overlapping sum. + auto* ws = resource::get_workspace_resource(*this); + auto ws_free = resource::get_workspace_free_bytes(*this); + auto ws_upstream = ws->get_upstream_resource(); + auto lws_ref = resource::get_large_workspace_resource_ref(*this); + auto pinned_ref = resource::get_pinned_memory_resource_ref(*this); + auto managed_ref = resource::get_managed_memory_resource_ref(*this); + + snapshot_ = resources_; + + // --- Host (global) --- + { + host_adaptor_ = std::make_unique(mr::host_resource_ref{old_host_}); + host_stats_ = host_adaptor_->get_stats(); + mr::set_default_host_resource(mr::host_resource_ref{*host_adaptor_}); + } + + // --- Pinned --- + { + mr::statistics_adaptor sa{pinned_ref}; + pinned_stats_ = sa.get_stats(); + resource::set_pinned_memory_resource(*this, std::move(sa)); + } + + // --- Managed --- + { + mr::statistics_adaptor sa{managed_ref}; + managed_stats_ = sa.get_stats(); + resource::set_managed_memory_resource(*this, std::move(sa)); + } + + // --- Device (global) --- + // Invalidate the cached thrust policy (the resource_ref it captured + // will be stale once we replace the global device resource). + factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( + resource::resource_type::LAST_KEY, std::make_shared()); + resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( + resource::resource_type::LAST_KEY, std::make_shared()); + { + device_stats_adaptor_t sa{rmm::device_async_resource_ref{old_device_}}; + device_stats_ = sa.get_stats(); + device_adaptor_ = std::make_unique(std::move(sa)); + rmm::mr::set_current_device_resource(*device_adaptor_); + } + // --- Workspace --- + { + mr::statistics_adaptor sa{ws_upstream}; + ws_stats_ = sa.get_stats(); + resource::set_workspace_resource(*this, std::move(sa), ws_free); + } + + // --- Large workspace --- + { + mr::statistics_adaptor sa{lws_ref}; + lws_stats_ = sa.get_stats(); + resource::set_large_workspace_resource(*this, std::move(sa)); + } + } +}; + +} // namespace raft diff --git a/cpp/include/raft/util/memory_tracking_resources.hpp b/cpp/include/raft/util/memory_tracking_resources.hpp index 6994deda01..65c0b70f41 100644 --- a/cpp/include/raft/util/memory_tracking_resources.hpp +++ b/cpp/include/raft/util/memory_tracking_resources.hpp @@ -107,8 +107,8 @@ class memory_tracking_resources : public resources { ~memory_tracking_resources() override { report_.stop(); - raft::mr::set_default_host_resource(old_host_ref_); - rmm::mr::set_current_device_resource(old_device_ref_); + raft::mr::set_default_host_resource(old_host_); + rmm::mr::set_current_device_resource(old_device_); } memory_tracking_resources(memory_tracking_resources const&) = delete; @@ -127,8 +127,8 @@ class memory_tracking_resources : public resources { : resources(existing ? *existing : resources{}), owned_stream_(std::move(owned_stream)), report_(out_override ? *out_override : *owned_stream_, sample_interval), - old_host_ref_(raft::mr::get_default_host_resource()), - old_device_ref_(rmm::mr::get_current_device_resource_ref()) + old_host_(raft::mr::get_default_host_resource()), + old_device_(rmm::mr::get_current_device_resource_ref()) { init(); } @@ -141,9 +141,8 @@ class memory_tracking_resources : public resources { std::unique_ptr owned_stream_; raft::mr::resource_monitor report_; - raft::mr::host_resource_ref old_host_ref_; - rmm::device_async_resource_ref old_device_ref_; - std::size_t saved_ws_limit_{}; + raft::mr::host_resource old_host_; + raft::mr::device_resource old_device_; using host_stats_t = raft::mr::statistics_adaptor; using host_notify_t = raft::mr::notifying_adaptor; @@ -156,8 +155,22 @@ class memory_tracking_resources : public resources { void init() { + // Independent-counting invariant + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // 1. Force-initialize all lazily-created resources (workspace, large workspace, + // pinned, managed) so that their factories resolve against the *original* + // global device MR, not a tracking wrapper we install later. + // 2. Capture every upstream ref while it still points to the original resource. + // 3. Snapshot the resource map to keep the originals alive. + // 4. Only *then* replace the global device resource with the tracking bridge. + // 5. Wrap each captured upstream with a separate statistics/notifying adaptor. + // + // Because step 2 happens before step 4, workspace/lws allocations flow through + // their own adaptor directly to the original device MR, bypassing the device adaptor. + // Each allocation is therefore counted in exactly one category, and + // memory_stats::total() returns an accurate, non-overlapping sum. auto* ws = raft::resource::get_workspace_resource(*this); - saved_ws_limit_ = ws->get_allocation_limit(); + auto ws_free = raft::resource::get_workspace_free_bytes(*this); auto upstream_ref = ws->get_upstream_resource(); auto lws_ref = raft::resource::get_large_workspace_resource_ref(*this); auto pinned_ref = raft::resource::get_pinned_memory_resource_ref(*this); @@ -168,7 +181,7 @@ class memory_tracking_resources : public resources { // --- Host (global) --- { - host_stats_t sa{old_host_ref_}; + host_stats_t sa{raft::mr::host_resource_ref{old_host_}}; report_.register_source("host", sa.get_stats()); host_adaptor_ = std::make_unique(std::move(sa), report_.get_notifier()); raft::mr::set_default_host_resource(*host_adaptor_); @@ -195,8 +208,14 @@ class memory_tracking_resources : public resources { } // --- Device (global) --- + // Invalidate the cached thrust policy (the resource_ref it captured + // will be stale once we replace the global device resource). + factories_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( + resource::resource_type::LAST_KEY, std::make_shared()); + resources_.at(resource::resource_type::THRUST_POLICY) = std::make_pair( + resource::resource_type::LAST_KEY, std::make_shared()); { - device_stats_t sa{old_device_ref_}; + device_stats_t sa{rmm::device_async_resource_ref{old_device_}}; report_.register_source("device", sa.get_stats()); device_adaptor_ = std::make_unique(std::move(sa), report_.get_notifier()); rmm::mr::set_current_device_resource(*device_adaptor_); @@ -209,7 +228,7 @@ class memory_tracking_resources : public resources { ws_stats_t sa{upstream_ref}; report_.register_source("workspace", sa.get_stats()); raft::resource::set_workspace_resource( - *this, ws_notify_t{std::move(sa), report_.get_notifier()}, saved_ws_limit_); + *this, ws_notify_t{std::move(sa), report_.get_notifier()}, ws_free); } // --- Large workspace --- diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 47ac6fc286..c5fb50fdbd 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -318,6 +318,7 @@ if(BUILD_TESTS) util/device_atomics.cu util/integer_utils.cpp util/integer_utils.cu + util/memory_stats_resources.cpp util/memory_type_dispatcher.cu util/popc.cu util/pow2_utils.cu diff --git a/cpp/tests/util/memory_stats_resources.cpp b/cpp/tests/util/memory_stats_resources.cpp new file mode 100644 index 0000000000..fb8a06b765 --- /dev/null +++ b/cpp/tests/util/memory_stats_resources.cpp @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +namespace raft { + +TEST(MemoryStatsResources, IndependentCounting_DefaultWorkspace) +{ + raft::resources res; + + memory_stats_resources stat_res(res); + + constexpr std::size_t kWsSize = 1024; + constexpr std::size_t kGlobalSize = 2048; + + auto ws_ref = resource::get_workspace_resource_ref(stat_res); + void* ws_ptr = ws_ref.allocate(cuda::stream_ref{cudaStreamLegacy}, kWsSize); + + auto dev_mr = rmm::mr::get_current_device_resource_ref(); + void* dev_ptr = dev_mr.allocate(cuda::stream_ref{cudaStreamLegacy}, kGlobalSize); + + auto peak = stat_res.get_bytes_peak(); + EXPECT_EQ(peak.device_workspace, kWsSize); + EXPECT_EQ(peak.device_global, kGlobalSize); + EXPECT_EQ(peak.total(), kWsSize + kGlobalSize); + + ws_ref.deallocate(cuda::stream_ref{cudaStreamLegacy}, ws_ptr, kWsSize); + dev_mr.deallocate(cuda::stream_ref{cudaStreamLegacy}, dev_ptr, kGlobalSize); +} + +TEST(MemoryStatsResources, IndependentCounting_WorkspaceSetToGlobal) +{ + raft::resources res; + resource::set_workspace_to_global_resource(res); + + memory_stats_resources stat_res(res); + + constexpr std::size_t kWsSize = 1024; + constexpr std::size_t kGlobalSize = 2048; + + auto ws_ref = resource::get_workspace_resource_ref(stat_res); + void* ws_ptr = ws_ref.allocate(cuda::stream_ref{cudaStreamLegacy}, kWsSize); + + auto dev_mr = rmm::mr::get_current_device_resource_ref(); + void* dev_ptr = dev_mr.allocate(cuda::stream_ref{cudaStreamLegacy}, kGlobalSize); + + auto peak = stat_res.get_bytes_peak(); + EXPECT_EQ(peak.device_workspace, kWsSize); + EXPECT_EQ(peak.device_global, kGlobalSize); + EXPECT_EQ(peak.total(), kWsSize + kGlobalSize); + + ws_ref.deallocate(cuda::stream_ref{cudaStreamLegacy}, ws_ptr, kWsSize); + dev_mr.deallocate(cuda::stream_ref{cudaStreamLegacy}, dev_ptr, kGlobalSize); +} + +TEST(MemoryStatsResources, IndependentCounting_PoolWorkspace) +{ + raft::resources res; + constexpr std::size_t kPoolLimit = 64UL * 1024UL * 1024UL; + resource::set_workspace_to_pool_resource(res, kPoolLimit); + + memory_stats_resources stat_res(res); + + constexpr std::size_t kWsSize = 1024; + constexpr std::size_t kGlobalSize = 2048; + + auto ws_ref = resource::get_workspace_resource_ref(stat_res); + void* ws_ptr = ws_ref.allocate(cuda::stream_ref{cudaStreamLegacy}, kWsSize); + + auto dev_mr = rmm::mr::get_current_device_resource_ref(); + void* dev_ptr = dev_mr.allocate(cuda::stream_ref{cudaStreamLegacy}, kGlobalSize); + + auto peak = stat_res.get_bytes_peak(); + EXPECT_EQ(peak.device_workspace, kWsSize); + EXPECT_EQ(peak.device_global, kGlobalSize); + EXPECT_EQ(peak.total(), kWsSize + kGlobalSize); + + ws_ref.deallocate(cuda::stream_ref{cudaStreamLegacy}, ws_ptr, kWsSize); + dev_mr.deallocate(cuda::stream_ref{cudaStreamLegacy}, dev_ptr, kGlobalSize); +} + +} // namespace raft