From 63fbc244ec0196b7e900c87ca75a3582a349dddc Mon Sep 17 00:00:00 2001
From: Jiacheng Huang
Date: Tue, 28 Oct 2025 17:53:11 +0800
Subject: [PATCH] Add `matmul` operator

---
 src/ntops/torch/__init__.py |  2 ++
 src/ntops/torch/matmul.py   | 27 +++++++++++++++++++++++++++
 tests/test_matmul.py        | 22 ++++++++++++++++++++++
 3 files changed, 51 insertions(+)
 create mode 100644 src/ntops/torch/matmul.py
 create mode 100644 tests/test_matmul.py

diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 056c0eb..702877e 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -19,6 +19,7 @@
 from ntops.torch.layer_norm import layer_norm
 from ntops.torch.le import le
 from ntops.torch.lt import lt
+from ntops.torch.matmul import matmul
 from ntops.torch.mm import mm
 from ntops.torch.mul import mul
 from ntops.torch.ne import ne
@@ -58,6 +59,7 @@
     "layer_norm",
     "le",
     "lt",
+    "matmul",
     "mm",
     "mul",
     "ne",
diff --git a/src/ntops/torch/matmul.py b/src/ntops/torch/matmul.py
new file mode 100644
index 0000000..9cf7b26
--- /dev/null
+++ b/src/ntops/torch/matmul.py
@@ -0,0 +1,27 @@
+import ntops
+
+
+def matmul(input, other, *, out=None):
+    """Matrix product of two tensors, limited to 2D and 3D inputs.
+
+    A 2D x 2D product dispatches to `ntops.torch.mm`; otherwise any 2D
+    operand is promoted to a singleton batch and the pair is handed to
+    `ntops.torch.bmm`.
+
+    NOTE(review): the mixed 2D/3D path gives `bmm` a batch of 1 against a
+    batch of b -- this assumes `ntops.torch.bmm` broadcasts the batch dim.
+    """
+    assert input.ndim in (2, 3) and other.ndim in (2, 3), (
+        "Currently, only 2D and 3D tensors are supported."
+    )
+
+    if input.ndim == 2 and other.ndim == 2:
+        return ntops.torch.mm(input, other, out=out)
+
+    if input.ndim < 3:
+        input = input.unsqueeze(0)
+
+    if other.ndim < 3:
+        other = other.unsqueeze(0)
+
+    return ntops.torch.bmm(input, other, out=out)
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
new file mode 100644
index 0000000..e800f8f
--- /dev/null
+++ b/tests/test_matmul.py
@@ -0,0 +1,22 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.test_mm import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+@pytest.mark.parametrize("b", (None, 1, 2, 3))
+def test_matmul(b, m, n, k, dtype, device, rtol, atol):
+    input_shape = (b, m, k) if b is not None else (m, k)
+    other_shape = (b, k, n) if b is not None else (k, n)
+
+    input = torch.randn(input_shape, dtype=dtype, device=device)
+    other = torch.randn(other_shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.matmul(input, other)
+    reference_output = torch.matmul(input, other)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)