Skip to content

Validate segmented_mm inner dimensions#3572

Closed
fallintoplace wants to merge 1 commit into
ml-explore:mainfrom
fallintoplace:fix-segmented-mm-shape-check
Closed

Validate segmented_mm inner dimensions#3572
fallintoplace wants to merge 1 commit into
ml-explore:mainfrom
fallintoplace:fix-segmented-mm-shape-check

Conversation

@fallintoplace
Copy link
Copy Markdown

Summary

  • Add an inner-dimension check to segmented_mm before creating the primitive.
  • Add a regression test for mismatched MxK and KxN inputs.

Why

segmented_mm expects a to have shape MxK and b to have shape KxN, but it did not reject incompatible inner dimensions. This now matches the validation style used by neighboring matmul APIs.

Validation

  • git diff --check HEAD~1..HEAD
  • python3 -m py_compile python/tests/test_blas.py

I could not run the full MLX Python test locally because this machine does not have the Metal compiler available through xcrun -sdk macosx metal.

Fixes #3571

@fallintoplace fallintoplace marked this pull request as ready for review May 21, 2026 10:08
@zcbenz
Copy link
Copy Markdown
Collaborator

zcbenz commented May 22, 2026

This is not how segmented_mm is supposed to work, see reference implementation for example:

mlx/python/tests/test_blas.py

Lines 1295 to 1300 in 2d2d59e

def segmented_mm_ref(a, b, s):
s = s.tolist()
c = []
for s1, s2 in s:
c.append(a[:, s1:s2] @ b[s1:s2, :])
return mx.stack(c, axis=0)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

segmented_mm should validate inner matrix dimensions

2 participants