Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 63 additions & 10 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
Expand Down Expand Up @@ -442,12 +443,64 @@ void CommandEncoder::end_encoding() {
all_inputs_.clear();
}

void CommandEncoder::signal_event(
std::shared_ptr<EventImpl> event,
uint64_t value) {
end_encoding();
buffer_->encodeSignalEvent(event->mtl_event(), value);
signal_events_.push_back({std::move(event), value});
}

void CommandEncoder::wait_event(
std::shared_ptr<EventImpl> event,
uint64_t value) {
end_encoding();
buffer_->encodeWait(event->mtl_event(), value);
wait_events_.push_back(std::move(event));
}

bool CommandEncoder::needs_commit() const {
auto [max_ops, max_mb] = device_.get_max_ops_mb_per_buffer();
return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb);
}

void CommandEncoder::commit() {
void CommandEncoder::commit(std::function<void()> completion) {
buffer_->addCompletedHandler(
[&error_ = error_,
wait_events = std::move(wait_events_),
signal_events = std::move(signal_events_),
completion = std::move(completion)](MTL::CommandBuffer* cbuf) {
if (completion) {
completion();
}
// If any of the waited event has error in it, poison the encoder.
for (auto& event : wait_events) {
if (event->error()) {
error_ = event->error();
break;
}
}
// Set error only when no error happended before, to preserve the
// earliest error.
if (!error_ && cbuf->status() == MTL::CommandBufferStatusError) {
error_ = std::make_shared<std::string>(fmt::format(
"[METAL] Command buffer execution failed: {}.",
cbuf->error()->localizedDescription()->utf8String()));
}
// Poison all the signaled events when error happened.
if (error_) {
for (auto& [event, value] : signal_events) {
event->set_error(error_);
}
}
// Metal won't signal the events for us on error, manually signal them
// to avoid infinite waiting.
if (cbuf->status() == MTL::CommandBufferStatusError) {
for (auto& [event, value] : signal_events) {
event->signal(value);
}
}
});
buffer_->commit();
buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences());
buffer_ops_ = 0;
Expand All @@ -456,17 +509,14 @@ void CommandEncoder::commit() {

void CommandEncoder::synchronize() {
auto pool = new_scoped_memory_pool();
auto cb = NS::RetainPtr(get_command_buffer());
auto cbuf = buffer_; // retained
end_encoding();
commit();
cb->waitUntilCompleted();
if (!exiting_) {
if (cb->status() == MTL::CommandBufferStatusError) {
throw std::runtime_error(
fmt::format(
"[METAL] Command buffer execution failed: {}.",
cb->error()->localizedDescription()->utf8String()));
}
cbuf->waitUntilCompleted();

if (error_ && !exiting_) {
auto error = std::move(error_);
throw std::runtime_error(*error);
}
}

Expand All @@ -475,6 +525,9 @@ MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() {
encoder_ = NS::RetainPtr(
buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent));
fence_ = NS::TransferPtr(device_.mtl_device()->newFence());
// Reset error when user starts to encode new commands, they are supposed to
// have handled the error in synchronize() or Event::wait().
error_.reset();
}
return encoder_.get();
}
Expand Down
15 changes: 11 additions & 4 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

class Device;
class EventImpl;

class MLX_API CommandEncoder {
public:
Expand Down Expand Up @@ -90,13 +91,12 @@ class MLX_API CommandEncoder {

void barrier();
void end_encoding();
void wait_event(std::shared_ptr<EventImpl> event, uint64_t value);
void signal_event(std::shared_ptr<EventImpl> event, uint64_t value);
bool needs_commit() const;
void commit();
void commit(std::function<void()> completion = nullptr);
void synchronize();

MTL::CommandQueue* get_command_queue() const {
return queue_.get();
}
MTL::CommandBuffer* get_command_buffer() const {
return buffer_.get();
}
Expand All @@ -113,6 +113,13 @@ class MLX_API CommandEncoder {
int buffer_ops_{0};
size_t buffer_sizes_{0};

// The events hooked to current command buffer.
std::vector<std::shared_ptr<EventImpl>> wait_events_;
std::vector<std::tuple<std::shared_ptr<EventImpl>, uint64_t>> signal_events_;

// Error from previous commited command buffer.
std::shared_ptr<std::string> error_;

// Encoder for issuing GPU commands.
// The members are used within a single ComputeCommandEncoder and will be
// reset after calling end_encoding().
Expand Down
23 changes: 4 additions & 19 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@ void new_stream(Stream s) {
encoders.try_emplace(s.index, d, s.index, d.residency_set());
}

inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}

void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
Expand Down Expand Up @@ -60,17 +51,12 @@ void eval(array& arr) {
if (encoder.needs_commit()) {
encoder.end_encoding();
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
encoder.commit();
encoder.commit([s, buffers = std::move(buffers)]() {
scheduler::notify_task_completion(s);
});
} else {
command_buffer->addCompletedHandler(
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {});
}
}

Expand All @@ -79,7 +65,6 @@ void finalize(Stream s) {
auto& encoder = metal::get_command_encoder(s);
auto* cb = encoder.get_command_buffer();
encoder.end_encoding();
cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
encoder.commit();
}

Expand Down
88 changes: 59 additions & 29 deletions mlx/backend/metal/event.cpp
Original file line number Diff line number Diff line change
@@ -1,62 +1,92 @@
// Copyright © 2024 Apple Inc.

#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/scheduler.h"

namespace mlx::core {

Event::Event(Stream stream) : stream_(stream) {
auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(ptr)->release();
};
auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>(
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
if (event_ == nullptr) {
///////////////////////////////////////////////////////////////////////////////
// EventImpl implementations
///////////////////////////////////////////////////////////////////////////////

namespace metal {

EventImpl::EventImpl(Device& d) {
auto p = new_scoped_memory_pool();
mtl_event_ = NS::TransferPtr(d.mtl_device()->newSharedEvent());
if (!mtl_event_) {
throw std::runtime_error(
"[Event::Event] Failed to create Metal shared event.");
}
}

void Event::wait() {
if (!static_cast<MTL::SharedEvent*>(event_.get())
->waitUntilSignaledValue(value(), -1)) {
throw std::runtime_error("[Event::wait] Timed out");
EventImpl::~EventImpl() {
auto p = new_scoped_memory_pool();
mtl_event_.reset();
}

void EventImpl::wait(uint64_t value) {
check_error();
mtl_event_->waitUntilSignaledValue(value, -1); // never times out
check_error();
}

void EventImpl::signal(uint64_t value) {
mtl_event_->setSignaledValue(value);
}

void EventImpl::set_error(std::shared_ptr<std::string> error) {
std::atomic_store(&error_, std::move(error));
}

void EventImpl::check_error() {
auto error = std::atomic_exchange(&error_, {});
if (error) {
throw std::runtime_error(*error);
}
}

} // namespace metal

///////////////////////////////////////////////////////////////////////////////
// Event implementations
///////////////////////////////////////////////////////////////////////////////

Event::Event(Stream stream) : stream_(stream) {
event_ = std::make_shared<metal::EventImpl>(metal::device(stream.device));
}

void Event::wait() {
static_cast<metal::EventImpl*>(event_.get())->wait(value());
}

void Event::wait(Stream stream) {
auto impl = std::static_pointer_cast<metal::EventImpl>(event_);
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() {
impl->wait(value);
});
} else {
auto& encoder = metal::get_command_encoder(stream);
encoder.end_encoding();
auto* command_buffer = encoder.get_command_buffer();
command_buffer->encodeWait(static_cast<MTL::Event*>(event_.get()), value());
command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
encoder.wait_event(std::move(impl), value());
}
}

void Event::signal(Stream stream) {
auto impl = std::static_pointer_cast<metal::EventImpl>(event_);
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
scheduler::enqueue(stream, [impl = std::move(impl), value = value()]() {
impl->signal(value);
});
} else {
auto& encoder = metal::get_command_encoder(stream);
encoder.end_encoding();
auto* command_buffer = encoder.get_command_buffer();
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(event_.get()), value());
command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
encoder.signal_event(std::move(impl), value());
}
}

bool Event::is_signaled() const {
return static_cast<MTL::SharedEvent*>(event_.get())->signaledValue() >=
value();
auto* mtl_event = static_cast<metal::EventImpl*>(event_.get())->mtl_event();
return mtl_event->signaledValue() >= value();
}

} // namespace mlx::core
33 changes: 33 additions & 0 deletions mlx/backend/metal/event.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright © 2026 Apple Inc.

#include "mlx/backend/metal/device.h"
#include "mlx/event.h"

namespace mlx::core::metal {

class EventImpl {
public:
EventImpl(Device& d);
~EventImpl();

void wait(uint64_t value);
void signal(uint64_t value);
void set_error(std::shared_ptr<std::string> error);
void check_error();

const auto& error() const {
return error_;
}

auto* mtl_event() {
return mtl_event_.get();
}

private:
// TODO: Use std::atomic<std::shared_ptr> when it gets supported in Xcode.
std::shared_ptr<std::string> error_;

NS::SharedPtr<MTL::SharedEvent> mtl_event_;
};

} // namespace mlx::core::metal
Loading
Loading