Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions mlx/backend/cpu/encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,29 @@

namespace mlx::core::cpu {

// Look up the CPU command encoder registered for stream `s`.
//
// The calling thread's own encoder map is consulted first; if the stream is
// not found there, the process-wide map is tried. Throws std::runtime_error
// when neither map has an entry for `s.index`.
CommandEncoder& get_command_encoder(Stream s) {
  auto& local = get_command_encoders();
  if (auto hit = local.find(s.index); hit != local.end()) {
    return hit->second;
  }

  auto& shared = get_global_command_encoders();
  if (auto hit = shared.find(s.index); hit != shared.end()) {
    return hit->second;
  }

  throw std::runtime_error(
      fmt::format("There is no Stream(cpu, {}) in current thread.", s.index));
}

// Per-thread map of CPU command encoders, keyed by stream index. Every thread
// gets its own independent instance.
std::unordered_map<int, CommandEncoder>& get_command_encoders() {
  thread_local std::unordered_map<int, CommandEncoder> per_thread_encoders;
  return per_thread_encoders;
}

CommandEncoder& get_command_encoder(Stream stream) {
auto& encoders = get_command_encoders();
auto it = encoders.find(stream.index);
if (it == encoders.end()) {
throw std::runtime_error(
fmt::format(
"There is no Stream(cpu, {}) in current thread.", stream.index));
}
return it->second;
// Process-wide map of CPU command encoders, keyed by stream index. A single
// instance is shared by all threads; this accessor performs no locking.
std::unordered_map<int, CommandEncoder>& get_global_command_encoders() {
  static std::unordered_map<int, CommandEncoder> shared_encoders;
  return shared_encoders;
}

} // namespace mlx::core::cpu
4 changes: 3 additions & 1 deletion mlx/backend/cpu/encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ struct MLX_API CommandEncoder {
int num_ops_{0};
};

MLX_API CommandEncoder& get_command_encoder(Stream stream);
MLX_API CommandEncoder& get_command_encoder(Stream s);

// Encoders registered on the calling thread (thread-local storage), keyed by
// stream index.
std::unordered_map<int, CommandEncoder>& get_command_encoders();
// Encoders shared by every thread (single static instance), keyed by stream
// index. No synchronization is provided by the accessor itself.
std::unordered_map<int, CommandEncoder>& get_global_command_encoders();

} // namespace mlx::core::cpu
5 changes: 5 additions & 0 deletions mlx/backend/cpu/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ void new_stream(Stream s) {
encoders.try_emplace(s.index, s);
}

// Register a CPU command encoder for `s` in the process-wide encoder map.
// Nothing happens if an encoder with the same stream index already exists.
void new_thread_unsafe_stream(Stream s) {
  get_global_command_encoders().try_emplace(s.index, s);
}

// Remove every CPU command encoder held in the calling thread's encoder map.
// NOTE(review): the process-wide map returned by
// get_global_command_encoders() is not cleared here - confirm this is
// intended.
void clear_streams() {
  auto& local_encoders = get_command_encoders();
  local_encoders.clear();
}
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/cpu/eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
namespace mlx::core::cpu {

void new_stream(Stream s);
void new_thread_unsafe_stream(Stream s);
void eval(array& arr);
void clear_streams();

Expand Down
14 changes: 12 additions & 2 deletions mlx/backend/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,13 @@ CommandEncoder& get_command_encoder(Stream s) {
auto& encoders = get_command_encoders();
auto it = encoders.find(s.index);
if (it == encoders.end()) {
throw std::runtime_error(
fmt::format("There is no Stream(gpu, {}) in current thread.", s.index));
auto& global_encoders = get_global_command_encoders();
it = global_encoders.find(s.index);
if (it == global_encoders.end()) {
throw std::runtime_error(
fmt::format(
"There is no Stream(gpu, {}) in current thread.", s.index));
}
}
return it->second;
}
Expand All @@ -577,4 +582,9 @@ std::unordered_map<int, CommandEncoder>& get_command_encoders() {
return encoders;
}

// Process-wide map of CUDA command encoders, keyed by stream index. A single
// instance is shared by all threads; this accessor performs no locking.
std::unordered_map<int, CommandEncoder>& get_global_command_encoders() {
  static std::unordered_map<int, CommandEncoder> shared_encoders;
  return shared_encoders;
}

} // namespace mlx::core::cu
1 change: 1 addition & 0 deletions mlx/backend/cuda/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,6 @@ MLX_API Device& device(mlx::core::Device d);
MLX_API CommandEncoder& get_command_encoder(Stream s);

std::unordered_map<int, CommandEncoder>& get_command_encoders();
std::unordered_map<int, CommandEncoder>& get_global_command_encoders();

} // namespace mlx::core::cu
21 changes: 15 additions & 6 deletions mlx/backend/cuda/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "mlx/backend/cuda/event.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

#include <nvtx3/nvtx3.hpp>

Expand All @@ -15,23 +16,28 @@ namespace mlx::core::gpu {
void init() {
// Force initialization of CUDA, so the CUDA runtime gets destroyed last.
cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
// Make sure native resources get destroyed after CommandEncoder.
mlx::core::cu::CudaEvent::init_pool();
}

void new_stream(Stream s) {
// Make sure the handles get destroyed after CommandEncoder.
init_cublas_handles_cache();
init_cudnn_handles_cache();
init_cudnn_conv_cache();
init_cudnn_sdpa_cache();
// Create CommandEncoder.
}

// Create the command encoder for GPU stream `s` in the regular encoder map.
// Nothing happens if an encoder with the same stream index already exists.
void new_stream(Stream s) {
  assert(s.device == Device::gpu);
  cu::get_command_encoders().try_emplace(s.index, cu::device(s.device));
}

// Create the command encoder for GPU stream `s` in the process-wide encoder
// map. Nothing happens if an encoder with the same stream index already
// exists.
void new_thread_unsafe_stream(Stream s) {
  assert(s.device == Device::gpu);
  cu::get_global_command_encoders().try_emplace(
      s.index, cu::device(s.device));
}

void eval(array& arr) {
nvtx3::scoped_range r("gpu::eval");
// Ensure CUDA context is active on this thread. Required when MLX is called
Expand Down Expand Up @@ -82,6 +88,9 @@ void synchronize(Stream s) {

// Drop the encoders in the regular encoder map. The process-wide map is only
// emptied when the caller is the main thread, i.e. the first thread that ever
// called is_main_thread().
void clear_streams() {
  cu::get_command_encoders().clear();
  if (!is_main_thread()) {
    return;
  }
  cu::get_global_command_encoders().clear();
}

} // namespace mlx::core::gpu
3 changes: 2 additions & 1 deletion mlx/backend/gpu/eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
namespace mlx::core::gpu {

void init();
void new_stream(Stream stream);
void new_stream(Stream s);
void new_thread_unsafe_stream(Stream s);
void eval(array& arr);
void finalize(Stream s);
void synchronize(Stream s);
Expand Down
14 changes: 12 additions & 2 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,13 @@ CommandEncoder& get_command_encoder(Stream s) {
auto& encoders = get_command_encoders();
auto it = encoders.find(s.index);
if (it == encoders.end()) {
throw std::runtime_error(
fmt::format("There is no Stream(gpu, {}) in current thread.", s.index));
auto& global_encoders = get_global_command_encoders();
it = global_encoders.find(s.index);
if (it == global_encoders.end()) {
throw std::runtime_error(
fmt::format(
"There is no Stream(gpu, {}) in current thread.", s.index));
}
}
return it->second;
}
Expand All @@ -821,6 +826,11 @@ std::unordered_map<int, CommandEncoder>& get_command_encoders() {
return encoders;
}

// Process-wide map of Metal command encoders, keyed by stream index. A single
// instance is shared by all threads; this accessor performs no locking.
std::unordered_map<int, CommandEncoder>& get_global_command_encoders() {
  static std::unordered_map<int, CommandEncoder> shared_encoders;
  return shared_encoders;
}

// Allocate and initialize a fresh autorelease pool and hand its ownership to
// the returned shared pointer.
NS::SharedPtr<NS::AutoreleasePool> new_scoped_memory_pool() {
  auto* pool = NS::AutoreleasePool::alloc()->init();
  return NS::TransferPtr(pool);
}
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ MLX_API Device& device(mlx::core::Device);
MLX_API CommandEncoder& get_command_encoder(Stream s);

std::unordered_map<int, CommandEncoder>& get_command_encoders();
std::unordered_map<int, CommandEncoder>& get_global_command_encoders();
NS::SharedPtr<NS::AutoreleasePool> new_scoped_memory_pool();

bool is_nax_available();
Expand Down
11 changes: 11 additions & 0 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"

namespace mlx::core::gpu {

Expand All @@ -18,6 +19,13 @@ void new_stream(Stream s) {
encoders.try_emplace(s.index, d, s.index, d.residency_set());
}

// Create the command encoder for GPU stream `s` in the process-wide encoder
// map. Nothing happens if an encoder with the same stream index already
// exists.
void new_thread_unsafe_stream(Stream s) {
  assert(s.device == Device::gpu);
  auto& shared_encoders = metal::get_global_command_encoders();
  auto& dev = metal::device(s.device);
  shared_encoders.try_emplace(s.index, dev, s.index, dev.residency_set());
}

inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
Expand Down Expand Up @@ -89,6 +97,9 @@ void synchronize(Stream s) {

// Drop the encoders in the regular encoder map. The process-wide map is only
// emptied when the caller is the main thread, i.e. the first thread that ever
// called is_main_thread().
void clear_streams() {
  metal::get_command_encoders().clear();
  if (!is_main_thread()) {
    return;
  }
  metal::get_global_command_encoders().clear();
}

} // namespace mlx::core::gpu
5 changes: 5 additions & 0 deletions mlx/backend/no_gpu/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ void new_stream(Stream) {
"[new_stream] Cannot make gpu stream without gpu backend.");
}

// Stub for builds without a GPU backend: always throws std::invalid_argument.
void new_thread_unsafe_stream(Stream) {
  static constexpr const char* kNoGpuBackend =
      "[new_thread_unsafe_stream] Cannot make gpu stream without gpu backend.";
  throw std::invalid_argument(kNoGpuBackend);
}

// Stub for builds without a GPU backend: always throws std::runtime_error.
void eval(array&) {
  static constexpr const char* kUnavailable =
      "[gpu::eval] GPU backend is not available";
  throw std::runtime_error(kUnavailable);
}
Expand Down
2 changes: 2 additions & 0 deletions mlx/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "mlx/scheduler.h"
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/utils.h"

namespace mlx::core {

Expand Down Expand Up @@ -33,6 +34,7 @@ void clear_streams() {
namespace scheduler {

Scheduler::Scheduler() {
  // is_main_thread() remembers the first thread that ever calls it. The
  // result is deliberately ignored: the call is made so that, unless an
  // earlier call already happened, the thread constructing the scheduler is
  // the one recorded as the main thread.
  static_cast<void>(is_main_thread());
  gpu::init();
}

Expand Down
13 changes: 13 additions & 0 deletions mlx/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ Stream new_stream(Device d) {
return s;
}

/**
 * Make a new stream on device `d` whose command encoder is stored in the
 * backend's process-wide encoder map rather than in a per-thread one.
 *
 * The new stream's index is its position in the global stream list, which is
 * modified while holding that list's mutex.
 *
 * If the backend fails to create the encoder (for example a GPU stream is
 * requested in a build without a GPU backend, which throws
 * std::invalid_argument), the stream that was just appended is removed again
 * before the exception propagates, so no stream without an encoder is left
 * registered.
 *
 * @param d Device the stream belongs to.
 * @return A copy of the newly registered stream.
 */
Stream new_thread_unsafe_stream(Device d) {
  auto& [streams, mtx] = all_streams();
  std::unique_lock lock(mtx);
  int index = static_cast<int>(streams.size());
  auto& s = streams.emplace_back(index, d);
  try {
    if (d == Device::gpu) {
      gpu::new_thread_unsafe_stream(s);
    } else {
      cpu::new_thread_unsafe_stream(s);
    }
  } catch (...) {
    // Roll back the registration so the stream list stays consistent with
    // the encoder maps, then let the caller see the original error.
    streams.pop_back();
    throw;
  }
  return s;
}

ThreadLocalStream new_thread_local_stream(Device d) {
auto& [streams, mtx] = thread_local_streams();
std::lock_guard lock(mtx);
Expand Down
3 changes: 3 additions & 0 deletions mlx/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ MLX_API void set_default_stream(Stream s);
/** Make a new stream on the given device. */
MLX_API Stream new_stream(Device d);

/** Make a new stream that can be used in any thread. */
MLX_API Stream new_thread_unsafe_stream(Device d);

/** Make a new stream that will be unique per thread. */
MLX_API ThreadLocalStream new_thread_local_stream(Device d);

Expand Down
9 changes: 4 additions & 5 deletions mlx/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <iomanip>
#include <iostream>
#include <sstream>
#include <thread>
#include <vector>

#include "mlx/dtype_utils.h"
Expand Down Expand Up @@ -118,11 +119,9 @@ void set_printoptions(PrintOptions options) {
formatter.format_options = options;
}

void abort_with_exception(const std::exception& error) {
std::ostringstream msg;
msg << "Terminating due to uncaught exception: " << error.what();
std::cerr << msg.str() << std::endl;
std::abort();
// Returns true when the calling thread is the first thread that ever invoked
// this function; every other thread gets false.
bool is_main_thread() {
  // Initialized exactly once, by whichever thread gets here first.
  static const std::thread::id first_caller = std::this_thread::get_id();
  const std::thread::id current = std::this_thread::get_id();
  return current == first_caller;
}

Dtype result_type(const std::vector<array>& arrays) {
Expand Down
4 changes: 2 additions & 2 deletions mlx/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ MLX_API void set_printoptions(PrintOptions options);

MLX_API PrintFormatter& get_global_formatter();

/** Print the exception and then abort. */
MLX_API void abort_with_exception(const std::exception& error);
/** Return whether current thread is the first one that called this function. */
bool is_main_thread();

/** Holds information about floating-point types. */
struct MLX_API finfo {
Expand Down
19 changes: 18 additions & 1 deletion python/src/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,24 @@ void init_stream(nb::module_& m) {
"new_stream",
&mx::new_stream,
"device"_a,
R"pbdoc(Make a new stream on the given device.)pbdoc");
R"pbdoc(
Make a new stream on the given device.

The stream can only be used on the thread where it was created; using
it on any other thread results in errors.
)pbdoc");
// Python binding for mx::new_thread_unsafe_stream.
m.def(
    "new_thread_unsafe_stream",
    &mx::new_thread_unsafe_stream,
    "device"_a,
    R"pbdoc(
Make a new stream that can be used in any thread.

Unlike :func:`new_stream` which can only work on the thread of creation,
streams created by this API can be passed to and evaluated anywhere, but
note that currently all nodes in a graph must be evaluated in sequence
and it is the user's responsibility to ensure there is no race condition.
)pbdoc");
m.def(
"new_thread_local_stream",
&mx::new_thread_local_stream,
Expand Down
20 changes: 15 additions & 5 deletions tests/scheduler_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,7 @@ TEST_CASE("test default stream in threads") {
}

TEST_CASE("test access stream in other thread") {
if (!gpu::is_available()) {
return;
}

auto main_thread_stream = new_stream(Device::gpu);
auto main_thread_stream = new_stream(default_device());
eval(arange(10, main_thread_stream));

bool error_caught = false;
Expand Down Expand Up @@ -104,6 +100,20 @@ TEST_CASE("test new stream in threads") {
}
}

TEST_CASE("test thread unsafe stream") {
  // Evaluate once on the creating thread to obtain the reference value.
  auto s = new_thread_unsafe_stream(default_device());
  const int expected = sum(arange(10, s)).item<int>();

  // Repeat the same computation on a second thread using the very same
  // stream, then clear that thread's streams before it exits.
  int actual = 0;
  std::thread worker([&] {
    actual = sum(arange(10, s)).item<int>();
    clear_streams();
  });
  worker.join();

  CHECK_EQ(expected, actual);
}

TEST_CASE("test thread local stream") {
auto s = new_thread_local_stream(default_device());
int result = sum(arange(10, s)).item<int>();
Expand Down
Loading