[CUDA] Implement gather_mm_rhs #2902
Conversation
Force-pushed from 28e5ca7 to 3b2a857
```cpp
array zero(0, a.dtype());
encoder.add_temporary(zero);
fill_gpu(zero, out, s);
```
Just a note, not necessary for this PR, but we should probably do these with cudaMemsetAsync.
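A minimal sketch of what that could look like; `out_ptr`, `out_nbytes`, and `stream` are hypothetical stand-ins for the output array's device pointer, its byte size, and the encoder's CUDA stream, not this PR's actual API:

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Sketch: one async byte-wise memset replaces allocating a scalar zero
// array and launching fill_gpu.
void zero_output(void* out_ptr, size_t out_nbytes, cudaStream_t stream) {
  cudaMemsetAsync(out_ptr, 0, out_nbytes, stream);
}
```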
It would need to go in the graph, but I think that should be fairly straightforward.
CUDA graphs have a memset node we can use; that would be much better than running a kernel.
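For reference, a sketch of adding such a node with the runtime API; `graph`, `out_ptr`, and `out_nbytes` are hypothetical stand-ins, and error checking is omitted:

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Sketch: record a memset node in the CUDA graph instead of a fill kernel.
cudaGraphNode_t add_zero_memset_node(
    cudaGraph_t graph, void* out_ptr, size_t out_nbytes) {
  cudaMemsetParams params = {};
  params.dst = out_ptr;
  params.value = 0;
  params.elementSize = 1;     // bytes; CUDA allows 1, 2, or 4
  params.width = out_nbytes;  // elements per row
  params.height = 1;          // single row, so `pitch` is unused
  cudaGraphNode_t node;
  cudaGraphAddMemsetNode(&node, graph, /*pDependencies=*/nullptr,
                         /*numDependencies=*/0, &params);
  return node;
}
```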
Force-pushed from 3b2a857 to ce6070d
angeloskath left a comment
Awesome start!
I left a comment regarding a small bug in the current code. Very weird that no tests caught that. I verified with the following small test:
```python
import mlx.core as mx

x = mx.random.normal((1024, 1, 1024))
w = mx.random.normal((16, 1024, 1024))
indices = mx.sort((mx.random.uniform(shape=(1024,)) * 16).astype(mx.int32))
y = mx.gather_mm(x, w.swapaxes(-1, -2), rhs_indices=indices, sorted_indices=True)
z = []
for i in range(1024):
    z.append(x[i] @ w[indices[i]].T)
z = mx.stack(z)
mx.eval(y, z)
```

```cpp
cutlass::ComplexTransform::kNone,
kAlignment,
T,
cutlass::layout::RowMajor,
```
This should depend on the passed-in matrix. Basically, when `b_transposed` in matmul.cpp is true, this should be ColMajor.
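One way this could look, as a sketch: since the CUTLASS kernel is templated on the layout, both instantiations must exist and the choice happens at runtime. `run_grouped_gemm` and `dispatch_on_layout` are hypothetical names, not this PR's code:

```cpp
#include <cutlass/layout/matrix.h>

// Hypothetical launcher, templated on the layout of B.
template <typename LayoutB>
void run_grouped_gemm(/* problem sizes, pointers, stream, ... */) {
  // instantiate and launch the CUTLASS grouped GEMM with LayoutB
}

// Pick the layout from the same flag that matmul.cpp computes.
void dispatch_on_layout(bool b_transposed) {
  if (b_transposed) {
    run_grouped_gemm<cutlass::layout::ColumnMajor>();
  } else {
    run_grouped_gemm<cutlass::layout::RowMajor>();
  }
}
```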
Thanks for noticing this! I have fixed it and added tests.
Force-pushed from ce6070d to a6e2da6
Force-pushed from a6e2da6 to fb6740f
angeloskath left a comment
Nice! Thanks.
This PR implements `gather_mm` for the cases that can be transformed into a grouped GEMM. The grouped GEMM code uses the CUTLASS 2.x API, which allows us to choose kernels that do not require large alignment. The performance is not good: running `benchmarks/python/gather_mm_bench.py` shows that it takes 7x the time of an equivalent matmul (the Metal kernel takes 1.5x). There are a lot of things remaining to be done, but the current work can serve as a baseline and a good foundation for progressive improvements.
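For intuition, a small standalone sketch (hypothetical, not this PR's code) of why sorted indices make this a grouped GEMM: rows of the LHS that share the same RHS matrix are contiguous, so each expert becomes one GEMM problem whose row range falls out of a prefix sum over the index counts:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Sorted rhs_indices: which RHS matrix each LHS row multiplies.
  std::vector<int> indices = {0, 0, 1, 1, 1, 3};
  int num_experts = 4;

  // Count rows per expert, then prefix-sum into row offsets.
  std::vector<int> offsets(num_experts + 1, 0);
  for (int idx : indices) {
    offsets[idx + 1]++;
  }
  for (int e = 0; e < num_experts; e++) {
    offsets[e + 1] += offsets[e];
  }

  // Each non-empty range is one problem of the grouped GEMM:
  // x[offsets[e]:offsets[e+1]] @ w[e]
  for (int e = 0; e < num_experts; e++) {
    if (offsets[e] < offsets[e + 1]) {
      std::printf("expert %d: rows %d:%d\n", e, offsets[e], offsets[e + 1]);
    }
  }
  return 0;
}
```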