From ddd7b9502f7c5b848a3f60136457a23e9bc5bd17 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 22 May 2026 10:39:09 +0900 Subject: [PATCH] Add new_thread_unsafe_stream API --- mlx/backend/cpu/encoder.cpp | 27 ++++++++++++++++++--------- mlx/backend/cpu/encoder.h | 4 +++- mlx/backend/cpu/eval.cpp | 5 +++++ mlx/backend/cpu/eval.h | 1 + mlx/backend/cuda/device.cpp | 14 ++++++++++++-- mlx/backend/cuda/device.h | 1 + mlx/backend/cuda/eval.cpp | 21 +++++++++++++++------ mlx/backend/gpu/eval.h | 3 ++- mlx/backend/metal/device.cpp | 14 ++++++++++++-- mlx/backend/metal/device.h | 1 + mlx/backend/metal/eval.cpp | 11 +++++++++++ mlx/backend/no_gpu/eval.cpp | 5 +++++ mlx/scheduler.cpp | 2 ++ mlx/stream.cpp | 13 +++++++++++++ mlx/stream.h | 3 +++ mlx/utils.cpp | 9 ++++----- mlx/utils.h | 4 ++-- python/src/stream.cpp | 19 ++++++++++++++++++- tests/scheduler_tests.cpp | 20 +++++++++++++++----- 19 files changed, 143 insertions(+), 34 deletions(-) diff --git a/mlx/backend/cpu/encoder.cpp b/mlx/backend/cpu/encoder.cpp index 292957a548..bb2c06ea0a 100644 --- a/mlx/backend/cpu/encoder.cpp +++ b/mlx/backend/cpu/encoder.cpp @@ -6,20 +6,29 @@ namespace mlx::core::cpu { +CommandEncoder& get_command_encoder(Stream s) { + auto& encoders = get_command_encoders(); + auto it = encoders.find(s.index); + if (it == encoders.end()) { + auto& global_encoders = get_global_command_encoders(); + it = global_encoders.find(s.index); + if (it == global_encoders.end()) { + throw std::runtime_error( + fmt::format( + "There is no Stream(cpu, {}) in current thread.", s.index)); + } + } + return it->second; +} + std::unordered_map& get_command_encoders() { static thread_local std::unordered_map encoders; return encoders; } -CommandEncoder& get_command_encoder(Stream stream) { - auto& encoders = get_command_encoders(); - auto it = encoders.find(stream.index); - if (it == encoders.end()) { - throw std::runtime_error( - fmt::format( - "There is no Stream(cpu, {}) in current thread.", stream.index)); - } - return it->second; +std::unordered_map& get_global_command_encoders() { + static std::unordered_map encoders; + return encoders; } } // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/encoder.h b/mlx/backend/cpu/encoder.h index 9a2c35a932..cd015623f6 100644 --- a/mlx/backend/cpu/encoder.h +++ b/mlx/backend/cpu/encoder.h @@ -62,7 +62,9 @@ struct MLX_API CommandEncoder { int num_ops_{0}; }; -MLX_API CommandEncoder& get_command_encoder(Stream stream); +MLX_API CommandEncoder& get_command_encoder(Stream s); + std::unordered_map& get_command_encoders(); +std::unordered_map& get_global_command_encoders(); } // namespace mlx::core::cpu diff --git a/mlx/backend/cpu/eval.cpp b/mlx/backend/cpu/eval.cpp index 23d055287a..354820f0fe 100644 --- a/mlx/backend/cpu/eval.cpp +++ b/mlx/backend/cpu/eval.cpp @@ -12,6 +12,11 @@ void new_stream(Stream s) { encoders.try_emplace(s.index, s); } +void new_thread_unsafe_stream(Stream s) { + auto& encoders = get_global_command_encoders(); + encoders.try_emplace(s.index, s); +} + void clear_streams() { get_command_encoders().clear(); } diff --git a/mlx/backend/cpu/eval.h b/mlx/backend/cpu/eval.h index 0b46663280..775f3e46f0 100644 --- a/mlx/backend/cpu/eval.h +++ b/mlx/backend/cpu/eval.h @@ -8,6 +8,7 @@ namespace mlx::core::cpu { void new_stream(Stream s); +void new_thread_unsafe_stream(Stream s); void eval(array& arr); void clear_streams(); diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 812a77277c..30248f5568 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -566,8 +566,13 @@ CommandEncoder& get_command_encoder(Stream s) { auto& encoders = get_command_encoders(); auto it = encoders.find(s.index); if (it == encoders.end()) { - throw std::runtime_error( - fmt::format("There is no Stream(gpu, {}) in current thread.", s.index)); + auto& global_encoders = get_global_command_encoders(); + it = global_encoders.find(s.index); + if (it == global_encoders.end()) { + throw std::runtime_error( + fmt::format( + "There is no Stream(gpu, {}) in current thread.", s.index)); + } } return it->second; } @@ -577,4 +582,9 @@ std::unordered_map& get_command_encoders() { return encoders; } +std::unordered_map& get_global_command_encoders() { + static std::unordered_map encoders; + return encoders; +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 99830e3cf1..78527315e8 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -208,5 +208,6 @@ MLX_API Device& device(mlx::core::Device d); MLX_API CommandEncoder& get_command_encoder(Stream s); std::unordered_map& get_command_encoders(); +std::unordered_map& get_global_command_encoders(); } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp index ef9ee20cfa..e21a57bea6 100644 --- a/mlx/backend/cuda/eval.cpp +++ b/mlx/backend/cuda/eval.cpp @@ -7,6 +7,7 @@ #include "mlx/backend/cuda/event.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" +#include "mlx/utils.h" #include @@ -15,23 +16,28 @@ namespace mlx::core::gpu { void init() { // Force initalization of CUDA, so CUDA runtime get destroyed last. cudaFree(nullptr); - // Make sure CUDA event pool get destroyed after device and stream. + // Make sure native resources get destroyed after CommandEncoder. mlx::core::cu::CudaEvent::init_pool(); -} - -void new_stream(Stream s) { - // Make sure the handles get destroyed after CommandEncoder. init_cublas_handles_cache(); init_cudnn_handles_cache(); init_cudnn_conv_cache(); init_cudnn_sdpa_cache(); - // Create CommandEncoder. +} + +void new_stream(Stream s) { assert(s.device == Device::gpu); auto& encoders = cu::get_command_encoders(); auto& d = cu::device(s.device); encoders.try_emplace(s.index, d); } +void new_thread_unsafe_stream(Stream s) { + assert(s.device == Device::gpu); + auto& encoders = cu::get_global_command_encoders(); + auto& d = cu::device(s.device); + encoders.try_emplace(s.index, d); +} + void eval(array& arr) { nvtx3::scoped_range r("gpu::eval"); // Ensure CUDA context is active on this thread. Required when MLX is called @@ -82,6 +88,9 @@ void synchronize(Stream s) { void clear_streams() { cu::get_command_encoders().clear(); + if (is_main_thread()) { + cu::get_global_command_encoders().clear(); + } } } // namespace mlx::core::gpu diff --git a/mlx/backend/gpu/eval.h b/mlx/backend/gpu/eval.h index ca86687dd4..d75b30abb5 100644 --- a/mlx/backend/gpu/eval.h +++ b/mlx/backend/gpu/eval.h @@ -11,7 +11,8 @@ namespace mlx::core::gpu { void init(); -void new_stream(Stream stream); +void new_stream(Stream s); +void new_thread_unsafe_stream(Stream s); void eval(array& arr); void finalize(Stream s); void synchronize(Stream s); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index d678461e3a..e5cb88f652 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -810,8 +810,13 @@ CommandEncoder& get_command_encoder(Stream s) { auto& encoders = get_command_encoders(); auto it = encoders.find(s.index); if (it == encoders.end()) { - throw std::runtime_error( - fmt::format("There is no Stream(gpu, {}) in current thread.", s.index)); + auto& global_encoders = get_global_command_encoders(); + it = global_encoders.find(s.index); + if (it == global_encoders.end()) { + throw std::runtime_error( + fmt::format( + "There is no Stream(gpu, {}) in current thread.", s.index)); + } } return it->second; } @@ -821,6 +826,11 @@ std::unordered_map& get_command_encoders() { return encoders; } +std::unordered_map& get_global_command_encoders() { + static std::unordered_map encoders; + return encoders; +} + NS::SharedPtr new_scoped_memory_pool() { return NS::TransferPtr(NS::AutoreleasePool::alloc()->init()); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5f2e72f915..95d54a2519 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -231,6 +231,7 @@ MLX_API Device& device(mlx::core::Device); MLX_API CommandEncoder& get_command_encoder(Stream s); std::unordered_map& get_command_encoders(); +std::unordered_map& get_global_command_encoders(); NS::SharedPtr new_scoped_memory_pool(); bool is_nax_available(); diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 6f55976efe..1188b2a7bd 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" +#include "mlx/utils.h" namespace mlx::core::gpu { @@ -18,6 +19,13 @@ void new_stream(Stream s) { encoders.try_emplace(s.index, d, s.index, d.residency_set()); } +void new_thread_unsafe_stream(Stream s) { + assert(s.device == Device::gpu); + auto& encoders = metal::get_global_command_encoders(); + auto& d = metal::device(s.device); + encoders.try_emplace(s.index, d, s.index, d.residency_set()); +} + inline void check_error(MTL::CommandBuffer* cbuf) { if (cbuf->status() == MTL::CommandBufferStatusError) { std::ostringstream msg; @@ -89,6 +97,9 @@ void synchronize(Stream s) { void clear_streams() { metal::get_command_encoders().clear(); + if (is_main_thread()) { + metal::get_global_command_encoders().clear(); + } } } // namespace mlx::core::gpu diff --git a/mlx/backend/no_gpu/eval.cpp b/mlx/backend/no_gpu/eval.cpp index fa93a23382..3966754679 100644 --- a/mlx/backend/no_gpu/eval.cpp +++ b/mlx/backend/no_gpu/eval.cpp @@ -14,6 +14,11 @@ void new_stream(Stream) { "[new_stream] Cannot make gpu stream without gpu backend."); } +void new_thread_unsafe_stream(Stream) { + throw std::invalid_argument( + "[new_thread_unsafe_stream] Cannot make gpu stream without gpu backend."); +} + void eval(array&) { throw std::runtime_error("[gpu::eval] GPU backend is not available"); } diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index e0b467abd8..7507917f5b 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -3,6 +3,7 @@ #include "mlx/scheduler.h" #include "mlx/backend/cpu/eval.h" #include "mlx/backend/gpu/eval.h" +#include "mlx/utils.h" namespace mlx::core { @@ -33,6 +34,7 @@ void clear_streams() { namespace scheduler { Scheduler::Scheduler() { + is_main_thread(); gpu::init(); } diff --git a/mlx/stream.cpp b/mlx/stream.cpp index 9f09596f90..b78ee67d67 100644 --- a/mlx/stream.cpp +++ b/mlx/stream.cpp @@ -77,6 +77,19 @@ Stream new_stream(Device d) { return s; } +Stream new_thread_unsafe_stream(Device d) { + auto& [streams, mtx] = all_streams(); + std::unique_lock lock(mtx); + int index = streams.size(); + auto& s = streams.emplace_back(index, d); + if (d == Device::gpu) { + gpu::new_thread_unsafe_stream(s); + } else { + cpu::new_thread_unsafe_stream(s); + } + return s; +} + ThreadLocalStream new_thread_local_stream(Device d) { auto& [streams, mtx] = thread_local_streams(); std::lock_guard lock(mtx); diff --git a/mlx/stream.h b/mlx/stream.h index fd938955ee..8243369876 100644 --- a/mlx/stream.h +++ b/mlx/stream.h @@ -34,6 +34,9 @@ MLX_API void set_default_stream(Stream s); /** Make a new stream on the given device. */ MLX_API Stream new_stream(Device d); +/** Make a new stream that can be used in any thread. */ +MLX_API Stream new_thread_unsafe_stream(Device d); + /** Make a new stream that will be unique per thread. */ MLX_API ThreadLocalStream new_thread_local_stream(Device d); diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 239e6603dd..046d849fe5 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "mlx/dtype_utils.h" @@ -118,11 +119,9 @@ void set_printoptions(PrintOptions options) { formatter.format_options = options; } -void abort_with_exception(const std::exception& error) { - std::ostringstream msg; - msg << "Terminating due to uncaught exception: " << error.what(); - std::cerr << msg.str() << std::endl; - std::abort(); +bool is_main_thread() { + static auto main_thread_id = std::this_thread::get_id(); + return main_thread_id == std::this_thread::get_id(); } Dtype result_type(const std::vector& arrays) { diff --git a/mlx/utils.h b/mlx/utils.h index d8b4c7ac99..7fde562efc 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -69,8 +69,8 @@ MLX_API void set_printoptions(PrintOptions options); MLX_API PrintFormatter& get_global_formatter(); -/** Print the exception and then abort. */ -MLX_API void abort_with_exception(const std::exception& error); +/** Return whether current thread is the first one that called this function. */ +bool is_main_thread(); /** Holds information about floating-point types. */ struct MLX_API finfo { diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 6202983914..76fe427696 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -106,7 +106,24 @@ void init_stream(nb::module_& m) { "new_stream", &mx::new_stream, "device"_a, - R"pbdoc(Make a new stream on the given device.)pbdoc"); + R"pbdoc( + Make a new stream on the given device. + + The stream can only be used on the thread where it was created on, using + it in any other thread would result in errors. + )pbdoc"); + m.def( + "new_thread_unsafe_stream", + &mx::new_thread_unsafe_stream, + "device"_a, + R"pbdoc( + Make a new stream that can be used in any thread. + + Unlike :func:`new_stream` which can only work on the thread of creation, + streams created by this API can be passed to and evaluated anywhere, but + note that currently all nodes in a graph must be evaluated in sequence + and it is user's responsibilty to ensure there is no race condition. + )pbdoc"); m.def( "new_thread_local_stream", &mx::new_thread_local_stream, diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index 93b6818e0e..3a8f7e86b7 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -69,11 +69,7 @@ TEST_CASE("test default stream in threads") { } TEST_CASE("test access stream in other thread") { - if (!gpu::is_available()) { - return; - } - - auto main_thread_stream = new_stream(Device::gpu); + auto main_thread_stream = new_stream(default_device()); eval(arange(10, main_thread_stream)); bool error_caught = false; @@ -104,6 +100,20 @@ TEST_CASE("test new stream in threads") { } } +TEST_CASE("test thread unsafe stream") { + auto s = new_thread_unsafe_stream(default_device()); + int expected = sum(arange(10, s)).item(); + + int actual = 0; + std::thread t([&] { + actual = sum(arange(10, s)).item(); + clear_streams(); + }); + t.join(); + + CHECK_EQ(expected, actual); +} + TEST_CASE("test thread local stream") { auto s = new_thread_local_stream(default_device()); int result = sum(arange(10, s)).item();