diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index d678461e3a..a0af6674de 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -11,6 +11,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/event.h" #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/utils.h" #include "mlx/utils.h" @@ -442,12 +443,64 @@ void CommandEncoder::end_encoding() { all_inputs_.clear(); } +void CommandEncoder::signal_event( + std::shared_ptr event, + uint64_t value) { + end_encoding(); + buffer_->encodeSignalEvent(event->mtl_event(), value); + signal_events_.push_back({std::move(event), value}); +} + +void CommandEncoder::wait_event( + std::shared_ptr event, + uint64_t value) { + end_encoding(); + buffer_->encodeWait(event->mtl_event(), value); + wait_events_.push_back(std::move(event)); +} + bool CommandEncoder::needs_commit() const { auto [max_ops, max_mb] = device_.get_max_ops_mb_per_buffer(); return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb); } -void CommandEncoder::commit() { +void CommandEncoder::commit(std::function completion) { + buffer_->addCompletedHandler( + [&error_ = error_, + wait_events = std::move(wait_events_), + signal_events = std::move(signal_events_), + completion = std::move(completion)](MTL::CommandBuffer* cbuf) { + if (completion) { + completion(); + } + // If any of the waited events has an error, poison the encoder. + for (auto& event : wait_events) { + if (event->error()) { + error_ = event->error(); + break; + } + } + // Set error only when no error happened before, to preserve the + // earliest error. + if (!error_ && cbuf->status() == MTL::CommandBufferStatusError) { + error_ = std::make_shared(fmt::format( + "[METAL] Command buffer execution failed: {}.", + cbuf->error()->localizedDescription()->utf8String())); + } + // Poison all the signaled events when error happened. 
+ if (error_) { + for (auto& [event, value] : signal_events) { + event->set_error(error_); + } + } + // Metal won't signal the events for us on error, manually signal them + // to avoid infinite waiting. + if (cbuf->status() == MTL::CommandBufferStatusError) { + for (auto& [event, value] : signal_events) { + event->signal(value); + } + } + }); buffer_->commit(); buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); buffer_ops_ = 0; @@ -456,17 +509,14 @@ void CommandEncoder::commit() { void CommandEncoder::synchronize() { auto pool = new_scoped_memory_pool(); - auto cb = NS::RetainPtr(get_command_buffer()); + auto cbuf = buffer_; // retained end_encoding(); commit(); - cb->waitUntilCompleted(); - if (!exiting_) { - if (cb->status() == MTL::CommandBufferStatusError) { - throw std::runtime_error( - fmt::format( - "[METAL] Command buffer execution failed: {}.", - cb->error()->localizedDescription()->utf8String())); - } + cbuf->waitUntilCompleted(); + + if (error_ && !exiting_) { + auto error = std::move(error_); + throw std::runtime_error(*error); } } @@ -475,6 +525,9 @@ MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { encoder_ = NS::RetainPtr( buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent)); fence_ = NS::TransferPtr(device_.mtl_device()->newFence()); + // Reset error when user starts to encode new commands, they are supposed to + // have handled the error in synchronize() or Event::wait(). 
+ error_.reset(); } return encoder_.get(); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5f2e72f915..99e98076c3 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -20,6 +20,7 @@ using MTLFCList = std::vector>; class Device; +class EventImpl; class MLX_API CommandEncoder { public: @@ -90,13 +91,12 @@ class MLX_API CommandEncoder { void barrier(); void end_encoding(); + void wait_event(std::shared_ptr event, uint64_t value); + void signal_event(std::shared_ptr event, uint64_t value); bool needs_commit() const; - void commit(); + void commit(std::function completion = nullptr); void synchronize(); - MTL::CommandQueue* get_command_queue() const { - return queue_.get(); - } MTL::CommandBuffer* get_command_buffer() const { return buffer_.get(); } @@ -113,6 +113,13 @@ class MLX_API CommandEncoder { int buffer_ops_{0}; size_t buffer_sizes_{0}; + // The events hooked to current command buffer. + std::vector> wait_events_; + std::vector, uint64_t>> signal_events_; + + // Error from previous committed command buffer. + std::shared_ptr error_; + // Encoder for issuing GPU commands. // The members are used within a single ComputeCommandEncoder and will be // reset after calling end_encoding(). 
diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 6f55976efe..8d1200d71f 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -18,15 +18,6 @@ void new_stream(Stream s) { encoders.try_emplace(s.index, d, s.index, d.residency_set()); } -inline void check_error(MTL::CommandBuffer* cbuf) { - if (cbuf->status() == MTL::CommandBufferStatusError) { - std::ostringstream msg; - msg << "[METAL] Command buffer execution failed: " - << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); - } -} - void eval(array& arr) { auto pool = metal::new_scoped_memory_pool(); auto s = arr.primitive().stream(); @@ -60,17 +51,12 @@ void eval(array& arr) { if (encoder.needs_commit()) { encoder.end_encoding(); scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - encoder.commit(); + encoder.commit([s, buffers = std::move(buffers)]() { + scheduler::notify_task_completion(s); + }); } else { command_buffer->addCompletedHandler( - [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); + [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {}); } } @@ -79,7 +65,6 @@ void finalize(Stream s) { auto& encoder = metal::get_command_encoder(s); auto* cb = encoder.get_command_buffer(); encoder.end_encoding(); - cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); encoder.commit(); } diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 78ed4fafe2..77f48f0838 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -1,62 +1,92 @@ // Copyright © 2024 Apple Inc. 
-#include "mlx/event.h" -#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/event.h" #include "mlx/scheduler.h" namespace mlx::core { -Event::Event(Stream stream) : stream_(stream) { - auto dtor = [](void* ptr) { - auto p = metal::new_scoped_memory_pool(); - static_cast(ptr)->release(); - }; - auto p = metal::new_scoped_memory_pool(); - event_ = std::shared_ptr( - metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor); - if (event_ == nullptr) { +/////////////////////////////////////////////////////////////////////////////// +// EventImpl implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace metal { + +EventImpl::EventImpl(Device& d) { + auto p = new_scoped_memory_pool(); + mtl_event_ = NS::TransferPtr(d.mtl_device()->newSharedEvent()); + if (!mtl_event_) { throw std::runtime_error( "[Event::Event] Failed to create Metal shared event."); } } -void Event::wait() { - if (!static_cast(event_.get()) - ->waitUntilSignaledValue(value(), -1)) { - throw std::runtime_error("[Event::wait] Timed out"); +EventImpl::~EventImpl() { + auto p = new_scoped_memory_pool(); + mtl_event_.reset(); +} + +void EventImpl::wait(uint64_t value) { + check_error(); + mtl_event_->waitUntilSignaledValue(value, -1); // never times out + check_error(); +} + +void EventImpl::signal(uint64_t value) { + mtl_event_->setSignaledValue(value); +} + +void EventImpl::set_error(std::shared_ptr error) { + std::atomic_store(&error_, std::move(error)); +} + +void EventImpl::check_error() { + auto error = std::atomic_exchange(&error_, {}); + if (error) { + throw std::runtime_error(*error); } } +} // namespace metal + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +Event::Event(Stream stream) : stream_(stream) { + event_ = std::make_shared(metal::device(stream.device)); +} + +void 
Event::wait() { + static_cast(event_.get())->wait(value()); +} + void Event::wait(Stream stream) { + auto impl = std::static_pointer_cast(event_); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [*this]() mutable { wait(); }); + scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() { + impl->wait(value); + }); } else { auto& encoder = metal::get_command_encoder(stream); - encoder.end_encoding(); - auto* command_buffer = encoder.get_command_buffer(); - command_buffer->encodeWait(static_cast(event_.get()), value()); - command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {}); + encoder.wait_event(std::move(impl), value()); } } void Event::signal(Stream stream) { + auto impl = std::static_pointer_cast(event_); if (stream.device == Device::cpu) { - scheduler::enqueue(stream, [*this]() mutable { - static_cast(event_.get())->setSignaledValue(value()); + scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() { + impl->signal(value); }); } else { auto& encoder = metal::get_command_encoder(stream); - encoder.end_encoding(); - auto* command_buffer = encoder.get_command_buffer(); - command_buffer->encodeSignalEvent( - static_cast(event_.get()), value()); - command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {}); + encoder.signal_event(std::move(impl), value()); } } bool Event::is_signaled() const { - return static_cast(event_.get())->signaledValue() >= - value(); + auto* mtl_event = static_cast(event_.get())->mtl_event(); + return mtl_event->signaledValue() >= value(); } } // namespace mlx::core diff --git a/mlx/backend/metal/event.h b/mlx/backend/metal/event.h new file mode 100644 index 0000000000..c5c82a7cd3 --- /dev/null +++ b/mlx/backend/metal/event.h @@ -0,0 +1,33 @@ +// Copyright © 2026 Apple Inc. 
+ +#include "mlx/backend/metal/device.h" +#include "mlx/event.h" + +namespace mlx::core::metal { + +class EventImpl { + public: + EventImpl(Device& d); + ~EventImpl(); + + void wait(uint64_t value); + void signal(uint64_t value); + void set_error(std::shared_ptr error); + void check_error(); + + const auto& error() const { + return error_; + } + + auto* mtl_event() { + return mtl_event_.get(); + } + + private: + // TODO: Use std::atomic when it gets supported in Xcode. + std::shared_ptr error_; + + NS::SharedPtr mtl_event_; +}; + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 0ff7e7f3b4..6fdd57a5f6 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -7,8 +7,8 @@ namespace mlx::core { struct FenceImpl { - FenceImpl() { - auto d = metal::device(Device::gpu).mtl_device(); + FenceImpl(Stream stream) { + auto d = metal::device(stream.device).mtl_device(); if (!d->supportsFamily(MTL::GPUFamilyMetal3)) { use_fast = false; } else if (__builtin_available(macOS 15, iOS 18, *)) { @@ -16,8 +16,7 @@ struct FenceImpl { } if (!use_fast) { - auto p = metal::new_scoped_memory_pool(); - fence = static_cast(d->newSharedEvent()); + event = std::make_unique(stream); } else { auto buf = allocator::malloc(sizeof(uint32_t)).ptr(); fence = static_cast(buf); @@ -26,17 +25,14 @@ struct FenceImpl { } ~FenceImpl() { - if (!use_fast) { - // Wraps Metal SharedEvent - auto p = metal::new_scoped_memory_pool(); - static_cast(fence)->release(); - } else { + if (use_fast) { allocator::free(allocator::Buffer{static_cast(fence)}); } } bool use_fast{false}; uint32_t count{0}; void* fence; + std::unique_ptr event; std::atomic_uint* cpu_value() { return static_cast( @@ -44,24 +40,22 @@ struct FenceImpl { } }; -Fence::Fence(Stream) { +Fence::Fence(Stream stream) { auto dtor = [](void* ptr) { delete static_cast(ptr); }; - fence_ = std::shared_ptr(new FenceImpl{}, dtor); + fence_ = std::shared_ptr(new FenceImpl(stream), 
dtor); } void Fence::wait(Stream stream, const array& x) { auto& f = *static_cast(fence_.get()); + if (!f.use_fast) { + f.event->wait(stream); + return; + } + if (stream.device == Device::cpu) { scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { auto& f = *static_cast(fence_.get()); - if (!f.use_fast) { - if (!static_cast(f.fence)->waitUntilSignaledValue( - count, -1)) { - throw std::runtime_error("[Fence::wait] Timed out"); - } - return; - } while (f.cpu_value()[0] < count) { } }); @@ -71,15 +65,6 @@ void Fence::wait(Stream stream, const array& x) { auto& d = metal::device(stream.device); auto& compute_encoder = metal::get_command_encoder(stream); - if (!f.use_fast) { - compute_encoder.end_encoding(); - auto* command_buffer = compute_encoder.get_command_buffer(); - command_buffer->encodeWait(static_cast(f.fence), f.count); - command_buffer->addCompletedHandler( - [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); - return; - } - // Register outputs to ensure that no kernels which depends on the // output starts before this one is done compute_encoder.register_output_array(x); @@ -101,14 +86,15 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { auto& f = *static_cast(fence_.get()); f.count++; + if (!f.use_fast) { + f.event->set_value(f.count); + f.event->signal(stream); + return; + } + if (stream.device == Device::cpu) { scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { auto& f = *static_cast(fence_.get()); - if (!f.use_fast) { - static_cast(f.fence)->setSignaledValue(count); - return; - } - f.cpu_value()[0] = count; }); return; @@ -117,16 +103,6 @@ void Fence::update(Stream stream, const array& x, bool cross_device) { auto& d = metal::device(stream.device); auto& compute_encoder = metal::get_command_encoder(stream); - if (!f.use_fast) { - compute_encoder.end_encoding(); - auto* command_buffer = compute_encoder.get_command_buffer(); - command_buffer->encodeSignalEvent( - static_cast(f.fence), 
f.count); - command_buffer->addCompletedHandler( - [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); - return; - } - // Launch input visibility kernels if (cross_device) { auto kernel = d.get_kernel("input_coherent"); diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py index dab40e48dd..31db2ed5f7 100644 --- a/python/tests/ring_test_distributed.py +++ b/python/tests/ring_test_distributed.py @@ -14,8 +14,7 @@ def setUpClass(cls): def test_groups(self): world = mx.distributed.init() - self.assertEqual(world.size(), 8) - self.assertTrue(0 <= world.rank() < 8) + self.assertTrue(0 <= world.rank() < world.size()) world2 = mx.distributed.init() self.assertEqual(world.size(), world2.size())