Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions mlx/backend/cuda/custom_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ CustomKernelFunction cuda_kernel(
<< "```" << std::endl;
}

auto output_shapes_copy = output_shapes;
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
Expand All @@ -236,7 +237,8 @@ CustomKernelFunction cuda_kernel(
init_value,
std::vector<ScalarArg>{},
false,
shared_memory),
shared_memory,
std::move(output_shapes_copy)),
Comment thread
zcbenz marked this conversation as resolved.
Outdated
std::move(inputs));
};
}
Expand Down Expand Up @@ -270,7 +272,8 @@ std::vector<array> precompiled_cuda_kernel(
init_value,
scalars,
true,
shared_memory),
shared_memory,
output_shapes),
inputs);
}

Expand Down
4 changes: 3 additions & 1 deletion mlx/backend/metal/custom_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ CustomKernelFunction metal_kernel(
<< "```" << std::endl;
}

auto output_shapes_copy = output_shapes;
return array::make_arrays(
std::move(output_shapes),
std::move(output_dtypes),
Expand All @@ -319,7 +320,8 @@ CustomKernelFunction metal_kernel(
init_value,
std::vector<ScalarArg>{},
false,
0),
0,
std::move(output_shapes_copy)),
std::move(inputs));
};
}
Expand Down
14 changes: 12 additions & 2 deletions mlx/fast_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ class CustomKernel : public Primitive {
std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
int shared_memory)
int shared_memory,
std::vector<Shape> output_shapes = {})
: Primitive(stream),
name_(std::move(name)),
source_(std::move(source)),
Expand All @@ -386,7 +387,8 @@ class CustomKernel : public Primitive {
init_value_(init_value),
scalar_arguments_(std::move(scalar_arguments)),
is_precompiled_(is_precompiled),
shared_memory_(shared_memory) {}
shared_memory_(shared_memory),
output_shapes_(std::move(output_shapes)) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
Expand All @@ -397,6 +399,13 @@ class CustomKernel : public Primitive {
override;

DEFINE_NAME(CustomKernel);

std::vector<Shape> output_shapes(const std::vector<array>&) override {
if (output_shapes_.empty())
return Primitive::output_shapes({});
return output_shapes_;
}

auto state() const {
return std::make_tuple(
name_,
Expand All @@ -422,6 +431,7 @@ class CustomKernel : public Primitive {
std::vector<ScalarArg> scalar_arguments_;
bool is_precompiled_;
int shared_memory_;
std::vector<Shape> output_shapes_;
};

} // namespace mlx::core::fast
15 changes: 15 additions & 0 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,21 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_NAME(GatherQMM)
bool is_equivalent(const Primitive& other) const override;

// inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx}
// other → {x, w, scales, lhs_idx, rhs_idx}
Comment thread
zcbenz marked this conversation as resolved.
Outdated
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
const auto& x = inputs[0];
const auto& w = inputs[1];
const auto& lhs_idx =
(mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3];
int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_;
auto out_shape = lhs_idx.shape();
out_shape.push_back(x.shape(-2));
out_shape.push_back(w_outer);
return {out_shape};
}

auto state() const {
return std::make_tuple(
group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);
Expand Down