From fb488bce763423f42c34fd1d52f67ab481f482d7 Mon Sep 17 00:00:00 2001
From: Vedant
Date: Fri, 22 May 2026 14:44:51 +0530
Subject: [PATCH 1/3] Route large 1D dot products through batched matmul

---
Review notes (ignored by `git am`):
- Adds inner_1d_gpu_chunked(): the leading main_size elements of both 1D
  operands are reshaped to {batches, 1, kChunkSize} and
  {batches, kChunkSize, 1}, multiplied with matmul, and the per-chunk results
  are summed. Any remaining tail is handled with multiply + sum and added on.
- tensordot() dispatches to the helper when both operands are 1D, exactly one
  axis is contracted on each side, the selected device is the GPU, and
  csize >= 32 * 16384.
- NOTE(review): the threshold literal in tensordot() repeats the value of
  kChunkSize, which is local to the helper; the two have to be edited together.
- NOTE(review): main_size is an int that is compared against a.size(), and the
  tail slices narrow a.size() / b.size() to int. Presumably size() is an
  unsigned, wider type -- confirm the comparison and casts are safe for the
  largest supported inputs.
- NOTE(review): the tail end index for b uses b.size() while main_size is
  derived from a.size() only; presumably tensordot() has already verified the
  contracted lengths match -- confirm.
- NOTE(review): this copy of the series had its angle-bracketed text stripped.
  `static_cast<int>`, `std::optional<array>` and `std::vector<bool>` were
  restored from usage; confirm against the tree. The From: header carries no
  e-mail address and needs one before `git am`.

 mlx/ops.cpp | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 6ad41e2e38..410978bd50 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -58,6 +58,28 @@ Dtype at_least_float(const Dtype& d) {
   return issubdtype(d, inexact) ? d : promote_types(d, float32);
 }
 
+array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) {
+  constexpr int kChunkSize = 16384;
+
+  int main_size = (a.size() / kChunkSize) * kChunkSize;
+  int batches = main_size / kChunkSize;
+
+  auto a_main =
+      reshape(slice(a, {0}, {main_size}, s), {batches, 1, kChunkSize}, s);
+  auto b_main =
+      reshape(slice(b, {0}, {main_size}, s), {batches, kChunkSize, 1}, s);
+
+  // Route large 1D dot products through batched matmul so gemv parallelizes
+  array total = sum(reshape(matmul(a_main, b_main, s), {batches}, s), false, s);
+  if (main_size == a.size()) {
+    return total;
+  }
+
+  auto a_tail = slice(a, {main_size}, {static_cast<int>(a.size())}, s);
+  auto b_tail = slice(b, {main_size}, {static_cast<int>(b.size())}, s);
+  return add(total, sum(multiply(a_tail, b_tail, s), false, s), s);
+}
+
 array indices_or_default(
     std::optional<array> indices,
     const array& x,
@@ -5406,6 +5428,18 @@ array tensordot(
     }
   }
 
+  auto stream = to_stream(s);
+  auto device = stream.device;
+  if (a.has_primitive()) {
+    device = a.primitive().stream().device;
+  } else if (b.has_primitive()) {
+    device = b.primitive().stream().device;
+  }
+  if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 &&
+      axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 16384) {
+    return inner_1d_gpu_chunked(a, b, stream);
+  }
+
   std::vector<bool> cdims1(x.ndim(), false);
   std::vector<bool> cdims2(y.ndim(), false);
   for (const auto n : axes_a) {

From 24d639d4e14e1f8ec39ad95e2ac3273a7279cbe5 Mon Sep 17 00:00:00 2001
From: Vedant
Date: Fri, 22 May 2026 15:24:57 +0530
Subject: [PATCH 2/3] Change chunk size

---
Review notes (ignored by `git am`):
- Lowers kChunkSize from 16384 to 4096 and changes the tensordot() threshold
  literal from 32 * 16384 to 32 * 4096 in the same commit, so the chunked path
  now starts at a smaller contracted size.
- NOTE(review): the commit message gives no measurement for the new value;
  presumably it was chosen by benchmarking -- worth recording which sizes and
  devices were measured.

 mlx/ops.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 410978bd50..9c0b35cbf9 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -59,7 +59,7 @@ Dtype at_least_float(const Dtype& d) {
 }
 
 array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) {
-  constexpr int kChunkSize = 16384;
+  constexpr int kChunkSize = 4096;
 
   int main_size = (a.size() / kChunkSize) * kChunkSize;
   int batches = main_size / kChunkSize;
@@ -5436,7 +5436,7 @@ array tensordot(
     device = b.primitive().stream().device;
   }
   if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 &&
-      axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 16384) {
+      axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 4096) {
     return inner_1d_gpu_chunked(a, b, stream);
   }
 

From 0d2e6209d55e942256b45fb01a0674796f3bd1e5 Mon Sep 17 00:00:00 2001
From: Vedant
Date: Sun, 24 May 2026 10:44:58 +0530
Subject: [PATCH 3/3] Remove operand level checking

---
Review notes (ignored by `git am`):
- Drops the lookup of the operands' primitive streams; the GPU check now reads
  stream.device, i.e. only the stream resolved from `s` by to_stream().
- NOTE(review): after this change an operand whose primitive was scheduled on
  a different device no longer influences the dispatch -- confirm that is the
  intended behaviour.

 mlx/ops.cpp | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 9c0b35cbf9..b95c9422fe 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -5429,14 +5429,9 @@ array tensordot(
   }
 
   auto stream = to_stream(s);
-  auto device = stream.device;
-  if (a.has_primitive()) {
-    device = a.primitive().stream().device;
-  } else if (b.has_primitive()) {
-    device = b.primitive().stream().device;
-  }
   if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 &&
-      axes_b.size() == 1 && device == Device::gpu && csize >= 32 * 4096) {
+      axes_b.size() == 1 && stream.device == Device::gpu &&
+      csize >= 32 * 4096) {
     return inner_1d_gpu_chunked(a, b, stream);
   }
 