Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,28 @@ Dtype at_least_float(const Dtype& d) {
return issubdtype(d, inexact) ? d : promote_types(d, float32);
}

array inner_1d_gpu_chunked(const array& a, const array& b, Stream s) {
constexpr int kChunkSize = 4096;

int main_size = (a.size() / kChunkSize) * kChunkSize;
int batches = main_size / kChunkSize;

auto a_main =
reshape(slice(a, {0}, {main_size}, s), {batches, 1, kChunkSize}, s);
auto b_main =
reshape(slice(b, {0}, {main_size}, s), {batches, kChunkSize, 1}, s);

// Route large 1D dot products through batched matmul so gemv parallelizes
array total = sum(reshape(matmul(a_main, b_main, s), {batches}, s), false, s);
if (main_size == a.size()) {
return total;
}

auto a_tail = slice(a, {main_size}, {static_cast<int>(a.size())}, s);
auto b_tail = slice(b, {main_size}, {static_cast<int>(b.size())}, s);
return add(total, sum(multiply(a_tail, b_tail, s), false, s), s);
}

array indices_or_default(
std::optional<array> indices,
const array& x,
Expand Down Expand Up @@ -5406,6 +5428,13 @@ array tensordot(
}
}

auto stream = to_stream(s);
if (a.ndim() == 1 && b.ndim() == 1 && axes_a.size() == 1 &&
axes_b.size() == 1 && stream.device == Device::gpu &&
csize >= 32 * 4096) {
return inner_1d_gpu_chunked(a, b, stream);
}

std::vector<bool> cdims1(x.ndim(), false);
std::vector<bool> cdims2(y.ndim(), false);
for (const auto n : axes_a) {
Expand Down
Loading