diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6ad41e2e38..b95c9422fe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -58,6 +58,28 @@ Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? d : promote_types(d, float32); } +array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) { + constexpr int kChunkSize = 4096; + + int main_size = (a.size() / kChunkSize) * kChunkSize; + int batches = main_size / kChunkSize; + + auto a_main = + reshape(slice(a, {0}, {main_size}, s), {batches, 1, kChunkSize}, s); + auto b_main = + reshape(slice(b, {0}, {main_size}, s), {batches, kChunkSize, 1}, s); + + // Route large 1D dot products through batched matmul so gemv parallelizes + array total = sum(reshape(matmul(a_main, b_main, s), {batches}, s), false, s); + if (main_size == a.size()) { + return total; + } + + auto a_tail = slice(a, {main_size}, {static_cast(a.size())}, s); + auto b_tail = slice(b, {main_size}, {static_cast(b.size())}, s); + return add(total, sum(multiply(a_tail, b_tail, s), false, s), s); +} + array indices_or_default( std::optional indices, const array& x, @@ -5406,6 +5428,13 @@ array tensordot( } } + auto stream = to_stream(s); + if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 && + axes_b.size() == 1 && stream.device == Device::gpu && + csize >= 32 * 4096) { + return inner_1d_gpu_chunked(a, b, stream); + } + std::vector cdims1(x.ndim(), false); std::vector cdims2(y.ndim(), false); for (const auto n : axes_a) {