
Conversation


@zcbenz zcbenz commented Dec 12, 2025

This PR implements gather_mm for the cases that can be translated into a grouped GEMM.

The grouped GEMM code uses the CUTLASS 2.x API, which allows us to choose kernels that do not require large alignment. The performance is not good yet: running benchmarks/python/gather_mm_bench.py shows it takes 7x the time of an equivalent matmul (the Metal kernel takes 1.5x).
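
Roughly, the mapping works like this: when the indices are sorted, every contiguous run of rows with the same index multiplies the same rhs matrix, so each run becomes one problem in the grouped GEMM. Below is a minimal host-side sketch of building the problem list, assuming sorted indices, an lhs of shape (N, K), and rhs matrices of shape (K, M); the names are illustrative, not the PR's actual variables:

#include <cstdint>
#include <vector>

#include <cutlass/gemm_coord.h>

// Hypothetical sketch: each contiguous run of equal indices turns into one
// (run_length x K) * (K x M) GEMM in the grouped problem list.
std::vector<cutlass::gemm::GemmCoord> make_problem_sizes(
    const int32_t* indices, int N, int M, int K) {
  std::vector<cutlass::gemm::GemmCoord> problem_sizes;
  for (int start = 0; start < N;) {
    int end = start;
    while (end < N && indices[end] == indices[start]) {
      ++end;  // extend the run of rows sharing one index
    }
    problem_sizes.emplace_back(/*m=*/end - start, /*n=*/M, /*k=*/K);
    start = end;
  }
  return problem_sizes;
}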

There are a lot of things remaining to be done:

  • Implement the cases where lhs indices are passed.
  • Implement the cases where indices are broadcast.
  • Implement the cases where indices are not sorted.
  • Tune the GEMM tile sizes.
  • Enable tensor cores for sm80 and later.
  • Pad the group sizes so we can use much faster kernels.

But the current work can serve as a baseline and a good foundation for progressive improvements.

@zcbenz zcbenz force-pushed the cuda-grouped-mm branch 2 times, most recently from 28e5ca7 to 3b2a857 on December 15, 2025 23:49
Comment on lines +327 to +339
array zero(0, a.dtype());
encoder.add_temporary(zero);
fill_gpu(zero, out, s);
Member

Just a note, not necessary for this PR, but we should probably do these with cudaMemsetAsync.

Member

It would need to go in the graph, but I think that should be fairly straightforward.

Collaborator Author

CUDA graphs have a memset node we can use; that would be much better than running a kernel.
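
For reference, a minimal sketch of what adding such a node could look like (add_zero_fill, out_ptr, and nbytes are illustrative names, not code from this PR):

#include <cuda_runtime.h>

// Hypothetical sketch: zero a buffer with a CUDA graph memset node instead
// of launching a fill kernel.
cudaGraphNode_t add_zero_fill(cudaGraph_t graph, void* out_ptr, size_t nbytes) {
  cudaMemsetParams params = {};
  params.dst = out_ptr;    // device pointer to clear
  params.value = 0;        // byte value written to every element
  params.elementSize = 1;  // memset supports 1-, 2-, or 4-byte elements
  params.width = nbytes;   // elements per row
  params.height = 1;       // flat 1-D fill
  cudaGraphNode_t node;
  cudaGraphAddMemsetNode(
      &node, graph, /*pDependencies=*/nullptr, /*numDependencies=*/0, &params);
  return node;
}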

Member

@angeloskath angeloskath left a comment

Awesome start!

I left a comment regarding a small bug in the current code. Very weird that no tests caught it. I verified with the following small test:

import mlx.core as mx
x = mx.random.normal((1024, 1, 1024))
w = mx.random.normal((16, 1024, 1024))
indices = mx.sort((mx.random.uniform(shape=(1024,)) * 16).astype(mx.int32))
y = mx.gather_mm(x, w.swapaxes(-1, -2), rhs_indices=indices, sorted_indices=True)
# Reference: multiply each row by its selected matrix, one at a time.
z = []
for i in range(1024):
    z.append(x[i] @ w[indices[i]].T)
z = mx.stack(z)
mx.eval(y, z)
# y and z should match (loose tolerances for float32 accumulation).
assert mx.allclose(y, z, rtol=1e-4, atol=1e-4)

cutlass::ComplexTransform::kNone,
kAlignment,
T,
cutlass::layout::RowMajor,
Member

This should depend on the passed-in matrix. Basically, when b_transposed in matmul.cpp is true, this should be ColumnMajor.
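
Something like the following compile-time dispatch would cover it (run_grouped_gemm is a hypothetical stand-in for the PR's launch helper, not its real name):

#include <cutlass/layout/matrix.h>

// Sketch: the CUTLASS layout is a template parameter, so the transposed case
// needs a compile-time branch. A row-major B passed in transposed is the same
// memory as a column-major B.
template <typename LayoutB>
void run_grouped_gemm(/* problem sizes, pointers, ... */);

void dispatch(bool b_transposed) {
  if (b_transposed) {
    run_grouped_gemm<cutlass::layout::ColumnMajor>(/* ... */);
  } else {
    run_grouped_gemm<cutlass::layout::RowMajor>(/* ... */);
  }
}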

Collaborator Author

Thanks for noticing this! I have fixed it and added tests.

Member

@angeloskath angeloskath left a comment

Nice! Thanks.

@zcbenz zcbenz merged commit 1d21d0e into ml-explore:main Dec 24, 2025
15 checks passed
@zcbenz zcbenz deleted the cuda-grouped-mm branch December 24, 2025 00:42