diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..313ded3545 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1700,6 +1700,21 @@ class GatherQMM : public UnaryPrimitive { DEFINE_GRADS() DEFINE_NAME(GatherQMM) bool is_equivalent(const Primitive& other) const override; + + // inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx} + // other → {x, w, scales, lhs_idx, rhs_idx} + std::vector output_shapes(const std::vector& inputs) override { + const auto& x = inputs[0]; + const auto& w = inputs[1]; + const auto& lhs_idx = + (mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3]; + int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_; + auto out_shape = lhs_idx.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w_outer); + return {out_shape}; + } + auto state() const { return std::make_tuple( group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 20f1145223..0c3e49fa56 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -504,6 +504,34 @@ def ones_fun(x): self.assertEqual(compiled_zero_like(y).shape, y_shape) self.assertEqual(compiled_ones_like(y).shape, y_shape) + def test_shapeless_compile_gather_qmm(self): + # GatherQMM must implement output_shapes() so shapeless compile can + # re-trace without throwing "GatherQMM cannot infer output shapes". + K, N, num_experts = 64, 32, 4 + + w = mx.random.normal((num_experts, N, K)) + qw, s, b = mx.quantize(w) + mx.eval(qw, s, b) + + # x has shape (num_experts, M, K): the batch dim is indexed by idx, + # which stays fixed so that lhs_indices and rhs_indices (auto-generated + # from w's batch shape) always broadcast. Only M changes between calls. + idx = mx.array([0, 1, 2, 3]) + x4 = mx.ones((num_experts, 4, K)) + x8 = mx.ones((num_experts, 8, K)) + + def fn(x): + return mx.gather_qmm( + x, qw, s, b, lhs_indices=idx, rhs_indices=idx, transpose=True + ) + + cfn = mx.compile(fn, shapeless=True) + + self.assertEqual(cfn(x4).shape, fn(x4).shape) + + # Different M — must reuse compiled graph without throwing. + self.assertEqual(cfn(x8).shape, fn(x8).shape) + def test_compile_with_constant(self): # Test float @partial(mx.compile)